티스토리 뷰

Arxiv Link

 

Contributions: propose a novel architecture which possibly be better MLP-alternative, having surely better interpretability and probably better accuracy. the new architecture has opened up the possibility of learning activation function itself instead of the traditional approach of learning weights.

Background: Kolmogorov Arnold Representation theorem

Machine learning에서, MLP는 일반적으로 univarsial approximator로 해석된다 (any function can be approximated with the infinite-width MLP). 다만, MLP (Multi-Layer-Perceptron)는 univariate function을 근사하는 데에 한계를 보일 수 있다. 예를 들어, 아래와 같이 N-dimensional 한 input에 대해서, sin, exp 등의 연산을 가지는 함수는 MLP로 근사하기가 어렵다 (가능하지만, 많은 parameter를 필요로 할 수 있다).

$$f(x_1,\cdots, x_N) = \exp(\frac{1}{N} \sum^N_{i=1} \sin^2(x_i))$$

한편, Kolmogorov-Arnold Representation theorem은 어떠한 함수가 bounded domain (its domain is not infinite) 이기만 하면 어떤 single-variable continuous function들의 composition으로 표현할 수 있음을 이야기한다. Kolmogorov-Arnold representation theorem의 statement는 다음과 같다:

어떤 smooth 한 함수 $f: [0, 1]^n \rightarrow \mathbb{R}$ 이 있을 때:

$$f(\mathbf{x}) = f(x_1,\cdots,x_n)= \sum^{2n+1}_{q=1} \Phi_q (\sum^n_{p=1} \phi_{q,p} (x_p))$$

where $\phi_{q,p}:[0,1] \rightarrow \mathbb{R}$ and $\Phi_{q}:\mathbb{R} \rightarrow \mathbb{R}$. 이는 어떠한 finite 한 함수 $f$는 (theorem에서는 [0,1]로 bounded 되어 있지만 scale and shift를 하면은 되니까)  어떤 univariate 한 함수들의 sum-of-sum으로 표현 (represented) 될 수 있다는 의미이다 (아래 그림의 (b))

Motivation

limitation of the MLP structure: complex function approximation

MLP가 univarsial approximation theorem에 기반하여 어떤 함수를 근사하는 데 사용되었다면, KAN은 위 Kolomogorov-Arnold representation theorem에 근거하여 함수를 근사하는 데에 사용한다. 이를 통해서, 위에서 언급한 exp-sin 예시와 같은 복잡한 univariate function을 포함하는 어떤 함수를 좀 더 효과적으로 근사할 수 있을 것이라 기대할 수 있다.

limitation of the MLP structure 2: interpretability

neural-network를 통해 어떤 function이 근사적으로 "어떤 식을 가지는지"를 알고 싶다고 할 때, MLP는 그 해석이 직관적으로 표현되지 않을 수 있다. 하지만, Kolmogorov-Arnold Representation theorem은 각 input의 각 dimension들이 어떤 univariate function들을 타고 나중에 composition 되는지 확인하기가 쉽다. 일종의 symbolic regression 이랑 비슷한 objective를 가진다고 생각해도 좋을 듯.

limitation of Kolmogorov-Arnold Representation theorem and "physicist mind"

논문에서는 Kolmogorov-arnold representation theorem을 구성할 univariate function이 말 그대로 "any-function"이라서, fractal 구조를 가지거나 smooth 하지 않을 수도 있기 때문에 (극단적으로는, 모든 mapping에 대해서 외워버리는 특정 함수를 만들 수도 있겠다.) 외면당해 왔다고 서술한다. 다만, 물리학에서의 철학처럼 그러한 worst-case보다는 일반적인 경우---실제 물리세계는 smooth 한 function들의 composition일 것이라는 가정---를 생각하는 것을 목표로 한다고 언급한다.

 

Kolmogorov-arnold network

How the hell we can find such univariate functions?

MLP에서는 직관적으로 matrix-multiplication의 대상이 되는 weights들을 학습하지만, KAN에서는 Kolmogorov-Arnold representation theorem의 대상이 되는 univariate function (or, activation function) 들을 "학습"시켜야 한다. 이렇게 학습을 시키는 것이 목적이 되면 하나의 질문이 따라오게 된다: 이러한 activation function들을 어떻게 parameterize 할 수 있을까? 논문에서는 B-spline curve를 그 해답으로 생각한다.

Another background: spline and B-spline

Spline은 graphics, machinical engineering 등에서 종종 다루는 개념으로, 특정 점들 (control-points)을 기준으로 정의되는 어떤 곡선을 뜻한다. 그림을 그리는 소프트웨어 등에서 곡선을 그릴 때, 점들을 움직여 가면서 그리던 그것이다 (아래 그림).

Example of B-spline from Opensourc.ES(https://opensourc.es/blog/b-spline/). control-point의 움직임에 따라 곡선이 변화한다.

 

B-spline에 대한 정의를 모르더라도 KAN을 추상적으로 이해하는 데에 문제가 없으나, 여기서는 어떻게 B-spline이 정의되는지, 조정 가능한 변수는 어떤 것들이 있는지를 살펴보도록 한다.

 

이 중에서도, B-spline은 특정 점 $x$에 대해, recursive 한 형태로 정의된다. 첫 번째 단계 (0'th degree)에서의 B-spline $B^{0}(x)$ 는 다음과 같이 간단하게 정의된다 (어떤 특정 interval 사이에 있으면 1, 아니면 0):

$$B^0_i(x)= \begin{cases} 1 & \text{if } t_i \geq x < t_{i+1} \\ 0 & \text{otherwise} \end{cases}$$

여기서 $t_i$라는 것이 해당 "interval"을 결정하는데, 이 $t_i$를 knot 라고 하며, B-spline을 정의할 수 있는 파라미터이다 (e.g. $\mathbf{t} = [t_0 , ..., t_i, ..., t_n]$. 이들은 단순히 grid가 될 수도 있으며 ($\mathbf{t} = [0 , 1, 2, 3, ..., i, ..., n]$) 심지어 중복된 값을 가지도록 설계될 수 있다 ($\mathbf{t} = [0, 1, 1, 1, 2, 3, 3, 3]$). KAN에서는 이들이 단순한 grid-point로 표기된다. 재미있는 점은, 이러한 knot들은 KAN의 training 혹은 inference 중에 더 빽빽하게 만들 수도, 더 느슨하게 만들수도 있다는 것이다. 이것을 조정함에 따라 제안한 network의 accuracy가 변화할 수 있음을 논문에서 언급한다.

 

위처럼 첫 번째 단계에 해당하는 0'th degree의 B-spline을 만들었다면, p'th degree는 아래와 같이 재귀적으로 정의된다:

$$B^p_i(x) = \frac{x - t_i}{t_{i+p} - t_{i}}B^{p-1}_i(x) + \frac{t_{i+p+1} - x}{t_{i+p+1} - t_{i + 1}} B^{p-1}_{i+1}(x)$$

 

여기서, B-spline이 특정 모델에 사용되어 적절한 parameterization을 가질 수 있는 중요한 property가 있다: degree가 $n$이고, 특정 knot $\mathbf{t}$ 를 가지는 어떤 spline curve는 B-spline의 linear-combination으로 표현될 수 있다는 것이다. (엄밀히 말하면, (1) 주어진 점들을 지나는 곡선 (2) 그 곡선이 점들 사이 구간이 다항식으로 정의되는 모든 곡선을 표현할 수 있다. -> Taylor series를 생각하면 모든 곡선이라고 할 수 있을까? 잘 아시는 분들 댓글부탁..) 이 때문에 아래의 식처럼, B-spline은 어떤 곡선을 만드는 Basis-function (function version of the basis vector, informally speaking)처럼 기능할 수 있다.

$$S_{n,\mathbf{t}}(x) = \sum_{i} \alpha_i B^n_i(x)$$

KAN은 실제로 위 식에서의 $\alpha_i \in \mathbb{R}$ 를 trainable-parameter 로써 설정했다. (이후 아래의 구현예시에서 coefficient가 parameter처럼 동작하는 것을 확인할 수 있다)

 

위에서 정의한 방식에 기반하여, input $x$와 knot들에 대해서 B-spline value를 구하는 코드는 아래처럼 구현될 수 있다. (KAN의 original repository를 간단히 변형시킨 버전). 일반적인 torch code와 달리 batch-dimension이 input (`x`, `x_{eval}`)의 제일 마지막에 위치해 있음에 주의하자. 아래의 `num_spline`이 MLP에서의 hidden dimension처럼 작동한다고 생각하면 된다.

def coef2curve(x_eval, grid, coef, k):
    """
    Converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves (summing up B_batch results over B-spline basis).

    Args:
        x_eval (Tensor; [num_splines x num_samples])
        grid (Tensor; [num_splines x num_grid_points])
        coef (Tensor; [num_splines x num_coef_params])
            The B-spline coefficients. num_coef_params = num_grid_intervals + k.
        k (int): The degree of the B-spline.
        
    Returns:
        y_eval (Tensor; [num_splines x num_samples])
            The evaluated B-spline values.
    """
    # weighted sum of the splines (with coef as its weights)
    y_eval = torch.einsum('ij,ijk->ik', coef, B_batch(x_eval, grid, k, device=device))
    return y_eval


def B_batch(x, knot, k=0, device='cpu'):
	"""
    Args:
        x (Tensor; [num_splines x num_samples]
        knot (Tensor; [num_splines x num_grid_points]
    	k (int) degree of the B-spline
        
    Return:
    	spline values : (Tensor; [num_splines, num_B-spline_coefficients, num_samples]) 
        The number of B-spline coefficients = number of grid points + k - 1.
    
    """
    knot = knot.unsqueeze(dim=2).to(device)
    x = x.unsqueeze(dim=1).to(device)

    if k == 0:
        value = (x >= knot[:, :-1]) * (x < knot[:, 1:])
    else:
        B_km1 = B_batch(x[:, 0], knot=knot[:, :, 0], k=k - 1, device=device)
        value = (x - knot[:, :-(k + 1)]) / (knot[:, k:-1] - knot[:, :-(k + 1)]) * B_km1[:, :-1] + (
                    knot[:, k + 1:] - x) / (knot[:, k + 1:] - knot[:, 1:(-k)]) * B_km1[:, 1:]
    return value
    
    
>>> num_spline = 5
>>> num_sample = 100
>>> num_grid_interval = 10
>>> k = 3
>>> x_eval = torch.normal(0,1,size=(num_spline, num_sample))
>>> grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1))
>>> coef = torch.normal(0,1,size=(num_spline, num_grid_interval+k))
>>> coef2curve(x_eval, grids, coef, k=k).shape
torch.Size([5, 100])

 

 

KAN architecture

KAN의 구조는 본 글의 처음에서 언급한 Kolmogorov-arnold Representation Theorem (KRT)의 확장된 버전을 사용한다. 해당 부분을 기술하는 데에는 이전에 썼던 것들과 동일한 notation을 사용하도록 하겠다. 앞서 언급한 B-spline이 적절히 parameterized 된 곡선이라는 것을 알고 나면, KAN을 구성하는 layer와 네트워크를 이해하는 것에 큰 허들은 없다.

 

KAN은 KRT에서 기반한 "KAN layer"(whatever it is)를 계속해서 쌓은 neural-network 구조이다 (마치 MLP가 그런 것처럼). 초창기 MLP의 성공이 단순히 stack-more-layer에 있었던 것처럼, KAN도 KAN-layer를 많이 쌓는것을 목표로 하였다. 구체적으로는, 저자들은 KRT는 2-layer의 규모에 한정되어 있기 때문에, 실제로 이를 그대로 적용하면 real-world의 복잡한 함수를 근사하기에 어렵다고 주장한다. 따라서, 이들은 더 넓은 KAN layer를 더 깊게 쌓는 것을 시도했고 이것을 논문의 main-contribution 중 하나라고 말한다. 아래에서는 KAN의 building-block인 KAN Layer가 무엇인지를 이어서 다룰 것이다.

 

So, what the hell is the KAN Layer anyway?

저자는 input dimension $n_{in}$, output dimension $n_{out}$ 을 가지는 KAN layer $\Phi$ 를 다음과 같이 정의한다

$$\Phi = {\phi_{p,q}}, \qquad p = {1,2,\dots, n_{in}}, \qquad q={1,2, \dots, n_{out}}$$

총 $n_{in}*n_{out}$ 개의 activation function $\phi_{p,q}$ 가 하나의 KAN layer에 존재하는 것인데, matrix form으로 펼쳐보면 아래와 같다 (맨 왼쪽의 $l$은 몇 번째 layer인지를 뜻하는 것이므로, 여기서는 무시해도 좋다)

여기서의 $\phi_{1,1}$ 들은 MLP에서의 weight에 대응된다. KAN의 개념을 간단하게 소개할 때 쓰이는 "weight 대신 activation function을 배우겠다"라는 방식은 MLP에서의 $W={w_11, w_12, ..., w_21, w_22, ..., w_{n_{l}}, n_{l+1}}$ 대신 위 $\Phi$ 를 조정해 나가겠다는 말이다.

 

이제 KAN layer를 구성하는 각각의 activation function $\phi$ 가 어떻게 parameterized 되어 있는지 (즉, 무엇이 학습되는지)를 알아보자. $\phi$ 는 단순히 B-spline으로 대체하는 것이 아니라 MLP-based parameterization + B-spline과 같은 형태를 띤다:

$$\phi(x) = w(silu(x) + spline(x))$$

$$spline(x)=\sum_i c_i B_i(x)$$

위 식에서, trainable 한 것들은 $c_i$ 와 $w$ 뿐이다. 이전의 B-spline section에서, B-spline의 linear combiniation으로 어떤 spline 곡선이든 표현할 수 있다고 언급하였었다. 이것의 coefficient에 해당하는 것이 $c_i$이다. $w$는 식에서 보이는 대로 activation을 전체적으로 scaling 해주는 역할을 한다. 

 

초기의 initialization 단계에서, $c_i$ 들은 0으로 초기화되는데, 이렇게 되면 spline에 대한 term이 무시되어서 (activation first, matrix-multiplication later)를 수행하는 MLP와 비슷하게 동작한다.

 

논문에서 제안하는 KAN 구조란 결국, 이러한 KAN-layer $\Phi$ 를 계속해서 쌓은 구조이다.

$$KAN(x) = \Phi_n \circ \Phi_{n-1} \circ \cdots \circ \Phi_1 (\mathbf{x})$$

 

Theoretical guarantee and robustness to curse-of-dimensionality

저자는 이렇게 제시된 KAN과 KAN이 맞추고자 하는 함수와의 오차 (Error)가 KAN의 Spline function의 number of grid와 degree가 클수록 줄어든다고 이론적으로 보여준다. 논문에서의 full-statement는 아래와 같은데, 결국 위 문장을 formal 하게 풀어쓴 내용이다.

논문에서 하나 흥미롭게 보는 점은, KAN의 Error가 input의 dimension과는 independent 하다는 점이다. 저자는 느낌표를 붙여가면서 까지 이것을 "curse-of-dimensionality"로부터 KAN이 자유롭다고 한다. 이와 관련되어서 아래와 같은 썰들을 논문에서 푼다. 

 

1. 위 theorem에서 error가 input dimension에 대해 independent 한 것에 반해, A neural scaling raw from the dimension of the data manifold 논문에서, piecewise linear function (ReLU-network가 piecewise linear function이다)의 function approximation이 dimension에 dependent 한 것을 확인하였다.

 

2. "사실 MLP의 universal-approximation theorem (UAT) 도 input dimension에 independent 한 것 같은데 어떤 이유에서 저것을 강조하는지는 잘 모르겠다."라고 생각하고 있던 찰나에, 특정 문제 세팅에서 , MLP가 polynomial을 근사한다고 할 때에, neural의 숫자에 대해 aprroximation error가 dependent 하다는 논문을 cite 한 것을 확인했다. 다만... 그렇게 와닿지는 않는다.

 

On the interpretability of the KAN

KAN의 주요한 motvation 중 하나는, 1D activation function을 학습함으로써 가지는 interpretability이다. 논문에서는 이러한 강점을 더 강화시킬 수 있는 sparsification을 제안한다. 이름에서 보이듯, 그리고 다른 deep learning interpretability가 종종 그러하듯, L1-regularization 등을 이용해서 trainable-parameter를 sparse 한 형태로 만들어주는 것을 제안한다. (여기에 왜 L1 reg---or, lasso---가 sparse parameter를 만드는지에 대한 훌륭한 시각자료가 있다) 다만, 이러한 L1 regularization 만으로는 sparsity를 가지기 충분하지 않다고 주장한다. 그래서 KAN-version of entropy 같은 것을 정의하는데, KAN layer $\Phi$ 의 각 element $\phi_{i,j}$가 "probability" 같은 역할을 하는 형태이다. (recap: $\Phi$는 $\phi_{i,j}$ 들로 구성된 input_dim x output_dim의 matrix였으며, 각 $\phi$는 모두 input $x$를 받는 함수이다)

 

먼저 KAN layer의 각 activation function $\phi$의 L1 loss를 아래와 같이 정의한다.

$$| \phi | \equiv \frac{1}{N} \sum^{N}_{s=1} |\phi(x^{(s)})|$$

여기서 $N$은 batch-size라 생각해도 좋다. 이제 KAN layer $\Phi$의 L1 norm은 아래와 같이 sum-of-all-L1-norm으로 정의된다:

$$|\Phi| \equiv \sum^{n_in}_{i=1}\sum^{n_out}_{i=1} |\phi_{i,j}|$$

앞서 말했듯, 이로는 Sparse 하게 만드는데 부족하여, $\Phi$의 entropy를 아래와 같이 정의한다. (sum of values가 1이니, 확률에서의 entropy라 생각해도 무관하고, 이를 최소화시키면 각 element가 sparse해지게 된다)

$$S(\Phi) \equiv - \sum^{n_in}_{i=1}\sum^{n_out}_{i=1} \frac{|\phi_{i,j}|}{|\Phi|} \log (\frac{|\phi_{i,j}|}{|\Phi|})$$

 

위에서 정의된 loss들을 regularization으로 사용하여, KAN을 "sparse" 한 형태로 만들 수 있다. 이후에는, 낮은 norm을 가지는 값들을 pruning 하고, 각 activation 들을 symbolification (cos, sin, exp, log 등의 함수에 대응시키는 것) 하여 symbolic regression을 할 수 있다.

On the evaluation results of the KAN

KAN의 논문에 걸쳐서 제시되는 실험적 결과들은 대부분 PDE-solving, physic equation approximation 등의 scientific equation 근사로, modern deep learning의 시각에서는 단순한 문제들이다. (아래처럼, toy-dataset에 대한 결과가 많다.)

하지만,,, 저렇게 생긴 synthetic 1D regression dataset에 MLP를 많이 해본 사람은 느끼겠지만 저런 형태의 함수를 fitting 시키기가 몹시 어렵다. sin-curve에 대해 fitting 시키고, 조금만 옆으로 넘어가면 extrapolation이 전혀 되지 않는 것을 관찰할 수 있다. KAN에서도 이것이 완벽하게 해소된 것으로 보이지는 않지만. (MLP가 piecewise linear 하다는 한계 때문일까) 

 

Toy-dataset에서 흥미롭게 본 부분은 continual-learning인데, 너무 단순한 setting에서의 예시를 보였지만, 이게 정말 realistic setting에서의 catastrophic forgetting (TL; DR: 새 데이터를 학습할 때 이전 데이터를 까먹는 현상)을 해소시켜 준다면 굉장할 것으로 생각된다.

 

이외에, 많은 Deeplearning4Science에 대한 예시가 나오는데, background가 충분치 않아서 얼마나 유의미한 결과인지는 잘 이해하지 못하겠다. 다만, 위에서 언급했든 자연과학, 공학에서의 equation을 모방하는 데에 MLP가 적절치 못하다는 것은 확실하므로, KAN이 해당 부분에 contribution이 크다는 것을 말하고자 함은 알 수 있었다.

 

여담:

이미 realistic 한 setting에서의 KAN이 활용되고 있는 것으로 들었다만 (e.g. simply replace the MLP layers with KAN for LLMs) 그 gain 자체는 negligible 한 것으로 보인다.

공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2024/09   »
1 2 3 4 5 6 7
8 9 10 11 12 13 14
15 16 17 18 19 20 21
22 23 24 25 26 27 28
29 30
글 보관함