티스토리 뷰

ICML2024; Oral

Motivation: the limitation of the MHA (Multi-Head-Attention)

  • MHA는 Dimension 이 head의 수로 나누어지는 특성상 low-rank bottleneck이 생길 수 있다.
  • MHA에서는 head-redundancy가 생길 수 있다. (여러 head들끼리 비슷한 동작을 할 수 있다)
  • 결론적으로, 이 문제들을 해소할 vanilla MHA의 훌륭한 대체제를 찾고 싶다.

본 논문에서는, 위와 같은 multi-head-attention의 문제를 "attention-score composition"을 통해 풀고자 한다. 논문에 걸쳐 어떻게 attention-score composition이 위 문제에 도움이 될 수 있는지를 설명하고, 어떤 방식으로 attention-score를 composite (or, mix) 할 것인지에 대해 제안한다. 


Attention matrix composition

Notations and definition of MHA and “composition”

논문의 전반적으로 걸쳐 attention head 혹은 attention score의 composition에 대해서 언급한다. 여기서의 composition은, 각 attention-head 별로 계산된 attention score $A_h$ 들의 linear combination을 의미한다: e.g. $A=c_1 A_1 + c_2A2 + \cdots + c_H A_H$.

 
  • 본문에서는 attention score를 QK 이후 (before-softmax)로 칭하고, attention weight matrix를 after-softmax로 칭한다. 이때, attention matrix A는 둘 다를 칭할 수 있다.

The functionality of matrix composition

이 장에서는 attention matrix를 composition 하는 것이 어떤 시나리오에서 도움이 될 수 있는지를 이야기한다.

 

아래의 Figure 1은 각각 composition matrix C를 나타낸다. 이 예시에서는 number of heads는 8이다 ($H=8$). 아래 그림에서, 각 matrix의 element가 밝으면 밝을수록 값이 높다. 특정 i 번째 row, j 번째 column에 해당하는 element가 값이 높다면, j 번째 attention head를 구성 (composite) 할 때, i 번째의 attention head의 영향을 많이 받음을 의미한다. 아래 Figure의 B를 예로 들자면, $C$를 통한 Attention composition을 할 때, 6번째 head는 4, 7 번째 attention head의 값에 많은 영향을 준다.

아래에서는 위의 Figure 1의 예시와 함께, composition matrix의 구성에 따라, 이러한 attention score composition 이 어떤 상황에서 긍정적인 도움이 될 수 있는지를 high-level에서 기술한다.

(시나리오 1) Figure 1-(a) “mutual excitation and inhibition“

  • (mutual excitation) 3번째, 8번째 두 attention head의 attention matrix가 서로에게 높은 weight를 주고 있다. 만약 두 head가 서로 positively-correlated 되어 있다면, 이런 composition matrix C는 도움이 될 것이다.
  • (mutual inhibition) 반면에, 2번째와 5번째 attention head는 서로에게 낮은 weight (negative value)를 주고 있다. 두 head가 negatively-correlated 되어 있다면, 이는 도움이 된다.

(시나리오 2) Figure 1-(b) “one-to-many“

  • 6번째 head는 4번째, 7번째 head에 많은 영향을 준다.

(시나리오 3) Figure 1-(c) “many-to-one“

  • Head 3과 7 Head 1에 많은 영향 (shares their -attention- weight)을 준다. 
    • Head 1 OV circuit (V, O weight for head 1) 이 noun을 hypernym/superclass로 변환하는 기능이 뛰어나다고 하자 (사과 → 과일의 semantic-transformation을 잘 수행한다고 해보자).
    • 하지만, Head 1의  QK-circuit은 “사과“라는 적절한 noun에 attend 하는 기능이 부족하다고 해보자 (even though its OV circuit has a somewhat superpower)
    • 이때, Head 3과 7의 의 attention weight을 잘 만드는 기능 (QK-circuit; attention score) 을 빌려올 수 있다.
    • (+) 이러한 case가 실제로 발생하는지를 본 블로그 글의 Experiment -> Verification 3 에서 보인다.

(시나리오 4) Figure 1-(d) “self-excitiaton“

  • Head 3,6은 스스로의 attention score에 더 집중하고, Head 4는 오히려 스스로의 attention score와 반대로 업데이트한다. 이러한 self-excitation은 attention score가 beneficial 한지, determental 한지를 구별하는데 유용할 수 있다.

이와 같은 Attention composition은, 구현상으로는 TxSxH의 attention matrix에서 1x1 convolution (H to H) 연산을 하는 것과 유사하게 동작한다.

What is the point of attention matrix composition?

앞서 언급한 attention-matrix composition이 긍정적으로 기능할 수 있는 시나리오들에 더해, 또 다른 장점을 논한다. 여기서 열쇠가 되는 것은 두 가지 직관적으로 받아들일 수 있는 사실이다.

  1. Attention head dimension을 키우는 것으로, attention head의 "low-rank bottleneck (low-rank로 인한 information loss)" 을 해소할 수 있다. (Bhojanapalli 2020; ICLR)
  2. attention matrix composition은 attention head-dimension을 H (number of head) 배 늘린 것과 유사한 효과를 가진다. 즉, low-rank bottleneck 이 완화된다.

논문에서는 조금 더 formal 하게 (사실 충분히 formal 한 지는 잘 모르겠다), 아래와 같이 기술한다:

Attention matrix composition은 각각의 attention head의 head dimension을 head의 숫자만큼 expansion 시킨 것과 동일한 효과를 가진다: Query weight을 weighted-concat (for last dimension; number_of_heads x head_dimension -to-> model_dimension) 시킨 것과 동치이다.

D_m: Model dimension, H: number of heads, D_h: dimension for each head

위에서의 Concat은 일반적인 torch에서의 concat이라 생각해도 좋다. (torch.cat(*, dim=-1)). 또한, 위의 QK circuit case와 비슷하게, OV에 대해서도 동일한 결과를 가진다 (본 글에서는 생략하도록 한다).

 

이처럼, Attention-matrix composition은 attention head들의 dimension을 expansion 시키는 효과가 있고, 심지어 composition matrix $C$를 trainable 하게 적절히 paramterize 하게 디자인할 수 있기 때문에 전체적으로 expressive power가 향상된다고 논문에서는 주장한다.

Dynamically Composable Multi-Head attention

위에서는 Attention의 Composition이 무엇인지, 어떤 시나리오에서 이것들이 도움이 될 수 있는지, Attention-score의 Composition이 Attention head의 dimension을 expand 하는 것과 동일하여 low-rank 문제를 해결한다는 내용을 다루었다. 이 장에서는 좀 더 실제적으로 넘어가서, (1) 기존의 Multi-head attention에서 이러한 composition 연산이 어디에 추가되는지, (2) 해당 composition 연산은 어떻게 구성되며 parameterized 되는지를 다룬다.

Vanilla MHA to DCMHA (Dynamically-Composable MHA)

DCMHA의 layer 구성은 간단하다. Multi-head attention에서 attention score의 softmax 전후에 compose를 추가해 주는 형태이다. 이러한 Compose function은 Head를 나누기 이전의 Query-Key 값들로부터 input을 받아내어, Compose layer 내부의 trainable parameter를 이용하여 다른 head 간의 attention matrix를 mixing 한다.

아래에서는 이러한 compose layer가 어떤 식으로 구성되고 parameterized 되어 있는지를 다룬다.

Compose layer

Compose layer는 위 그림과 같이, 어느 정도 복잡한 구성을 가지지만 이 모든 연산들의 목적은 위에서 언급한 것과 크게 다르지 않다:

attention score value를 여러 head에 걸쳐 mixing 하는 weight ($C$)을 찾는다.

 

그림에서는 5가지의 branch가 있으나, 크게는 3가지로 나누어 볼 수 있다.

Query-key wise dynamic projection

- Query 혹은 key 값을 input으로 받아, 아래와 같은 연산을 통해서 새로운 attention weight을 만든다. 이때, 학습 가능한 parameter는 W_1, W_2이다 (위 그림의 회색 w_q1 w_q2). 위 그림에서 분기 2 4 (B2, B4)가 이에 해당하며, Query Key에 각각 하나씩 수행해 준다.

"""
D: model dimension
T: sequence length for query
S: sequence length for key
H: number of heads
R: head-compression dimension (smaller than H)
"""

def get_dynamic_projection_weight(Q, W_1, W_2):
	"""
    W_1 (nn.Parameters): D x I (where, I=2HR)
    W_2 (nn.Parameters): I x I
    """
    w_q = F.gelu(Q @ W_1) @ W_2  # T x I
    w_q1, w_q2 = torch.chunk(w_q, dim=1) # T x HR
    w_q1 = RMSNorm(w_q1.reshape(T, H, R)) # T x H x R
    w_q2 = w_q1.reshape(T, R, H) # T x R x H
    return w_q1, w_q2

Q = query(emb) # T x D
w_q1, w_q2 = get_dynamic_projection_weight(Q, W_1, W_2)
composition_weight = w_q1 @ w_q2 # H x H

 

(+) 위 코드에서, Composition weight (dim: H)을 만드는 과정 중 H 보다 더 낮은 dimension R으로 projection 시킨다. 이 부분에 대해서 논문에서 언급하는 rationale는 이러하다: 더 낮은 dimension으로 projection 시킴으로써, 가능한 모든 composition weight를 전부 고려하지 않고 소수의 특정 패턴을 찾아내는 것으로 충분하다.

Query-Key wise dynamic gating

위에서 언급한 dynamic projection에 비해, 여기서 다루는 gating은 비교적 단순하다. Query와 Key 값에서 단순히 trainable weight 하나를 matrix-multiplication 해 주는 형식이다. 위 그림에서 분기 3, 5가 이에 해당한다.

"""
D: model dimension
T: sequence length for query
S: sequence length for key
H: number of heads
"""
# W_qg (nn.Parameter): D x H

composition_weight = F.tanh(Q @ W_qg) # T x H

 

Base projection

위 그림의 분기 1 (B1)에 해당하는 단순한 projection layer이다. HxH의 dimension을 가지는 단 하나의 trainable weight W_b로 구성된다. 즉 여기서는 composition weight이 단순히 W_b로 표현된다.

Put in all together

이런 모든 composition weight를 만들어내고 나서, 새로운 attention matrix는 아래와 같이 만들어낸 모든 composition weight 분기들의 합으로써 구할 수 있다. 아래의 새로운 new_A가 새로운 attention score (before-softmax)와 attention weight (after-softmax)가 되어, 기존의 MHA의 계산 방식과 같이 계산된다.

"""
D: model dimension
T: sequence length for query
S: sequence length for key
H: number of heads
"""
# A (torch.Tensor): attention matrix, TxSxH

w_q1, w_q2 = get_query_dynamic_projection(Q, W_q1, W_q2) # H x H
w_k1, w_k2 = get_key_dynamic_projection(K, W_k1, W_k2) # H x H

w_qg = get_query_gating(Q, W_qg) # T x H
w_kg = get_query_gating(Q, W_kg) # S x H

new_A = A @ W_b +\             # B1, base projection
	A @ (w_q1 @ w_q2) +\       # B2, Query-wise dynamic projection
    A @ (w_k1 @ w_k2) +\       # B4, Key-wise dynamic projection
    A.transpose(0,1) * w_qg +\ # B3, Query-wise dynamic gating
    A * w_kg                   # B5, Key-wise dynamic gating

 

Behind the design choice: tensor decomposition perspective

저자는 왜 위와 같이 composition을 디자인하였는가? 이에 대한 답은 효율적인 연산을 위해서이다. 어떤 multi-head attention matrix A (HxTxS)에 대해서 composition-weight (H)를 얻으려면, 특정 4D transformation tensor W (TxSxHxH)를 생각해 볼 수 있다 (A times W will be composition-weight with dimension H). 이러한 4D-tensor W를 decompose 하여 효율적으로 계산하는 과정이 지금까지 언급했던 Composition layer의 구성이다. 아래 그림의 low-rank 가 dynamic projection, diagonal 이 gating, W_b가 Base projection에 해당한다.

Experiments

이 장에서는, 논문에서 제시되었던 실험들을 "Verification N"으로 정리해서, DCMHA 방법론에 대해 가지는 물음들에 어떤 답과 실험을 하였는지를 정리해 보겠다.

Verification 1: DCMHA로 구성된 LLM 모델들은 Scaling-Law를 가지고 있는가?

그렇다. 기존 Transformer (이하 그림의 TFM) 구조의 scaling-law를 고스란히 따라가면서도, 더 낮은 loss를 유지한다.

Verification 2: Large-scale 모델에서 improvement를 보이는가?

그렇다. 2.8B, 6.9B, 12B의 scale을 가지는 Pythia 모델들에 Dynamic Compose를 추가했을 때, 같은 조건하에서 높은 성능 향상을 보인다

여러 downstream task에서는, 6.9B 모델이 12B 모델을 능가하기도 하였다 (아래 붉은 박스들)

 

Verification 3: composition은 KV circuit을 다른 head로부터 잘 빌려오는가?

그렇다고 볼 수 있다. 본 글의 The functionality of the matrix composition의 시나리오들 중에서, MHA는 특정 head가 Query-Key attend를 제대로 하지 못하여  훌륭한 OV circuit을 가지고 있음에도 제대로 동작을 하지 못할 수 있다는 것을 언급했었다. 이를 알아보기 위해, 아래처럼 QK-circuit (getting attention score; QK-pattern below) 이 특별히 중요한 synthetic dataset을 준비하여, 성능을 비교해 보았다.

 

이런 극단적인 패턴의 데이터셋에 evaluation을 할 경우, Dynamic composition의 유무에 따른 성능차는 더욱 도드라졌다.

Verificiation 4: head는 실제로 diverse 해 졌는가?

그렇다. 원래의 의도대로 MHA의 한계 중 하나인 Query, Key의 값들이 head 간에 비슷해지는 현상을 완화하였다. Mean-cumulative captured variance (the lower the diverse)를 통해 측정할 경우, 아래처럼 Dynamic Composition이 더 diverse 한 QK-circuit을 가짐을 알 수 있다.

(필자의 생각) Attention-head들을 "Mixing" 해준다고만 추상적으로 생각하면, head들이 더 Diverse 해지는 것은 와닿지 않는다. Graph-transformer에서 node들 간의 연산 (mixing)을 계속하다 보면 Over-smoothing 이슈가 생겨나 모든 노드들이 비슷한 값을 가지는 것을 생각해 보면, 어떤 이유에서 diverse 해졌는지에 의문을 가지게 된다. Composition weight들은 어째서 sparse 해 졌는가? composition weight들이 layer가 깊어지면 smoothing 될 가능성은 없는가?

-> 논문에서 사실 Diversity가 강조되는 부분은 아니었다. 좀 더 일반적인 범위에서 MHA improvement를 노렸다고 봐야 할듯. 본 실험은 그냥 부수적인 효과로 봐야할 듯하다. 혹은, MHA 상에서의 Performance와 Diversity는 필요충분조건인 걸까?

Verificiation 5: MHA와 DCMHA의 학습 과정은 근본적으로 다른가?

그럴 수 있다. 기존의 MHA에서 DC를 추가하여 continual pretraining을 수행할 경우, 큰 gain을 얻지 못한다. 이는 학습의 초기 단계부터 DCMHA가 MHA와는 다르게 학습된다는 것을 암시한다.

Verification 6: Composition의 추가가 학습과 추론 시간에 얼마나 영향을 미치는가?

최소 5.2%, 최대 25.5%의 throughput 감소가 있다. 다만, 모델이 커질수록 영향은 적어진다. 또한, 특정한 accleration이 전혀 추가되지 않은 버전에서의 drop이므로, 개선될 여지가 크다.

Verification 7: 다른 modality의 transformer에서도 유효한 방법론인가?

그렇다. 약 40% 정도로 작은 parameter 수로도 ViT를 통한 ImageNet classification에서 거의 동등하거나 높은 성능을 보여준다. 특히나, 더 적은 epoch로도 높은 성능에 빨리 도달한다.

공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2024/09   »
1 2 3 4 5 6 7
8 9 10 11 12 13 14
15 16 17 18 19 20 21
22 23 24 25 26 27 28
29 30
글 보관함