티스토리 뷰

  • 이 포스트는 Academic-reels와 특집 어딘가에 있는... shorts는 아니지만 그렇다고 10분짜리 비디오도 아닌 그런 구성입니다. 정리본 같은 거랄까요
  • 되게 Scientific 하게 잘 쓰인 논문입니다. 점수가 매우 높아요. 흥미로운 가설을 설정했고, 가설을 support 하는 좋은 관측들을 했고, 그에 따른 simple-but-effective 한 방법을 제시합니다. 글도 매우 잘 써진 것 같고요.
  • 별 5.

FAIR, Meta, ICLR 2024 Oral

Objective and motivation

  • (Our objective) Vision transformer의 마지막 attention layer를 visualize 해보면, 위와 같은 “abnormal patch“ 가 보인다 (semantically not important, but strongly attended by other patches). 이런 것들을 오른쪽처럼, “semantically meaningful“ 한 패치에만 attention score가 높도록 만들고 싶다
  • (Doubtful point) 근데, 애초에 왜 abnormal-patch들이 생겨나는 걸까? 이것들의 정체가 뭘까?

Observation: the problem of the abnormal patches

  • (Observation 1) self-supervise-trained 모델의 attention map을 활용해서 object detection, object discovery, segmentation 같은 문제를 풀려고 하면 (LOST method), 위 모델 중 “DINO“ 가 잘한다.
  • (Observation 2) 반면에, DINO 이후에 개발된 DINOv2는 다른 task에서 DINO보다 잘 함에도 불구하고, 위처럼 attention map이 좋지 않아서 attention-map-driven 방법론 (LOST 같은)을 사용할 수 없다.
    • 관측해 보니, 이러한 문제는 다른 많은 ViT 들에서도 나타나는 일반적인 현상이다

왜 이런 일이 일어날까? 그리고 이런 현상을 해소하는 방법은 무엇일까? 해소하게 되면 모델이 더 잘 동작하게 될까? 아래 Obesrvation 3, 4에서는 먼저 왜 이런 일이 일어날까? 에 대해서 설명한다.

  • (Observation 3) 위 attention-map의 heatmap에서 "abnormal patch"들은 다른 pixel에 비해 10배 이상의 높은 norm을 가진다. 그리고 이러한 현상은 ViT의 중간 layer에서, 그리고 충분히 큰 ViT에서 충분히 오랜시간 학습했을때 일어난다.
  • (Observation 4: Strong clue what the abnormal patch is) 해당 patch들에 linear layer를 붙이면, classification task를 아주 잘 수행한다. (다른 patch들에 붙인것에 비해서)
    • → Interpretation: semantically not meaningful token에 대해서, ViT는 해당 image의 “global information“을 해당 토큰에 저장한다. 따라서 attention score도 높으며, 해당 토큰 만을 이용한 downstream task도 잘 하게 된다.

Take a closer look at the problem

"Artifacts" in the local features of DINOv2

Definition of the -Artifacts-: they are high-norm outlier tokens

  • 위 figure-3에서의 결과처럼, high norm and outlier patches이다. outlier 하다는 것은, 특별히 이미지의 sementics에 별 영향을 주지 않는 어떤 patch라는 소리이다 (위 그림의 white-background처럼)

아래에서는, 이러한 outlier-patch 혹은 artifact들이 실제로 global-information을 견인하는가? 에 대한 가설을 위해 여러 evidence를 보여준다.

(Evidence 1) Outliers appear during the training of large models

  • 1/3 이상의 학습이 되어야 outlier 가 등장한다
  • Outlier는 중간 이후의 layer에 등장한다
  • Outlier는 모델 사이즈가 어느 정도 커야 증가한다

작은 모델, 학습이 덜 된 모델은 아예 global-information을 배우지 못한 것이라 해석해 볼 수 있다. 반면에, "정보를 더 많이 가졌다"라고 할 수 있는 큰 모델, 학습이 더 된 모델, 더 많은 정보를 가지고 있는 상위 layer에서 이런 norm이 커지는 token들이 많아지는 것은 global-information과 outlier-patch 간의 연관성을 생각할 수 있게 해 준다.

(Evidence 2) High norm tokens appear where patch information is redundant & High norm tokens hold little local information

  • Outlier patch는 다 또이 또 이하고 의미 없게 생겼다.. background 라던지 그런 것들.
  • 따라서, 걔네들끼리의 cosine similarity도 높다.
  • 또한, 이런 outlier patch들은 local information 대신 global information을 가지고 있어서, local-information이 필요한 어떤 위치에 있는지 맞추는 task와, 해당 픽셀이 무엇인지 reconstruction 하는 task 둘 다 맞추기 어렵다 (high errors and poor accuracy in the above figure-5 (b)).

(Evidence 3) Artifacts can resolve classification problem, maybe due to their global information-carrying

  • Outlier patch 만으로 classification task 등을 풀면, normal patch 만으로 푸는 것보다 잘 풀린다.

Hypothesis and remediation

아래와 같은 가설을 논문에 걸쳐서 세우고 있고, 실제로 이를 위 증거들로 어느 정도 shed-light를 해주었다.

(Hypothesis) 충분히 크고, 오래 학습된 ViT 모델은 어떤 불필요한 토큰을 global information을 저장하고, 처리하고, 탐색하는 데 사용된다

 

그렇다면, 이를 실제로 어떻게 활용해보아야 할까? 원래의 목적이었던, attention-map을 적절하게 유지하면서도, 이러한 global-information을 활용하는 방법은 없을까? 논문에서는 이에 대한 간단한 해답으로 register token을 제시한다.

Remediation: the register tokens

  • 정말 불필요한 token인 register token을 정의하고, 이러한 global information을 대신해서 받아주길 기대한다. (CLS token에 더해)
    • 맨 마지막 output에는 해당 register token을 뺀다
  • memory transformer와 비슷한 구조.. 라는데 memory transformer 잘 모름. translation에 사용되었다고 한다.
    • (Opinion 1) seq2 seq 생성에 쓰였으면… Diffusion도 이런 게 필요하지 않을까 하는 망상 (seq 2 seq이긴 하니까…?)
    • (Opinion 2) T2I diffusion에서, text embedding을 register token에 갖다 박아버리면 global-information에 text-conditioning을 더 잘할 수 있지 않을까?… -> 이미 CLIP의 pooled-embedding을 사용하면 비슷한 기능을 하지만 말이다.

Experiments

아래에서는, register token을 사용했을 때 실제로 기대했던 효과가 생기는지 (outlier가 없어지는지, attention map이 정상화가 되는지 등)에 대해서 입증한다. 

Verification: 진짜 outlier가 없어짐

(maybe) side-effect: down-stream task를 좀 더 잘하게 됨

  • 논문에서 주장하는 게 meaningful attention map without outlier patches leading to the better performance는 아니었기에… 일종의 side-effect로 해석됨 (의외로 이게 major-contribution이 아니었다. 리뷰어가 닦달했나?)
  • 큰 모델 (DINOv2)가 좀 더 향상이 있는 걸로 보임.

Effect of the number of register tokens

  • Global 보다 local이 더 중요할 것 같은 depth-estimation, segmentation이 필요한 reg가 더 적고… global information이 더 중요할것 같은 ImageNet이 필요한 token이 더 많은 게 이상하긴 하다.

Object discovery via attention map is now working.

  • 논문에서 제일 처음 이야기했던, attention-map을 활용한 downstream task의 degradation을 해소하는 파트이다.
  • 특히나 이러한 outlier가 심하게 발생하는 큰 모델인 DINOv2에서 꽤 큰 improvement가 일어난다.

attention map of the register token

흥미롭게도, 개별 register token과 다른 pixel들 간에 걸리는 attention map을 살펴보면 각 register token이 담당하는 어떤 "object"가 있는 것처럼 보인다. reg0 은 전체적인 edge, reg6은 캐러멜의, reg8은 스푼의, reg12는 커피의 texture. 이들이 interpretability 에도 도움이 될 것이라 기대해 볼 수 있다.

Discussion point: How about the registers for DiT?

여기서는 Diffusion-transformer에 이러한 register token을 적용하면 어떻게 될지에 대해서 짧게 이야기해 본다.

  • DINOv2랑 CLIP은 모델 구조는 DiT와 같지만, Objective는 전혀 다르다. 그래도 register token을 CLIP에 적용시켜도 될까?
    • CLIP은 text-image contrastive learning
    • DINOv2는 aligning randomly-cropped image; patch masking; and other regularization terms
    • 하지만, 적어도 위 실험에서는 abnormal-patch가 CLIP (OpenCLIP) 에도 존재하며, register token이 abnormal-patch를 없애주는 걸 관측할 수 있다.
      • 하지만, 그것이 downstream task에 미치는 영향은 어떨지 모른다. 즉, 개선된 attention-map이 과연 실제로 downstream task---feature embedding for image generation---에 도움이 될지는 명확하지 않은 것.
  • Self supervised learning regime에서는 이런 global-feature가 생기는데, diffusion에서는 아닐 수 있을까?
    • 이런 dummy token 콘셉트를 사용했던 memory transformer에서는 translation task에 register-token like 한 방법론을 적용한 적이 있다 (그것이 같은 문제 때문인지는 모르겠다). 그럼 아주아주 naive 하게 생각한다면, 같은 generative task인 diffusion에서도 적용해 볼 수 있지 않을까?
    • restoring masked patch도 Diffusion처럼 “input을 복원해 간다 “라는 개념에선 비슷하기도 하다.

Example of the super-easy implementation

실제 구현은, 아래와 같이 간단하게 trainable-parameter로써 register token을 추가해 주고, 기존 transformer block에 추가해 버리는 식으로 구현할 수 있다.

아래의 예시는 stable-diffusion 3에서 MMDiT model class (official code 아님, 자체 구현한 것)에서 이를 활용하는 예시이다.

  1. __init__()에서, self.register_tokens를 정의하고
  2. forward()에서, transformer의 input-sequence에 해당하는 emb_x에 해당 register-token을 sequence-dimension에 concat 시켜주고, 맨 마지막에는 이 부분을 제외시켜 준다.
class MMDiT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.patch_size = config.patch_size
        self.h, self.w = config.height, config.width
        self.register_token_num = config.register_token_num

        # Embedding for (1) text; (2) input image; (3) time 
        self.text_cond = TextConditionModule(config.text_emb_size, config.hidden_size)
        self.patching = LatentPatchModule(config.patch_size, config.hidden_size)
        self.time_emb = TextTimeEmbedding(config.time_embed_size, config.pooled_text_size, config.cond_size)
        
        self.mmdit_blocks = nn.ModuleList(
            [MMDiTBlock(config.hidden_size, config.time_embed_size, config.attn_embed_size, config.mlp_dim, config) for layer_idx in range(config.num_layers)]
        )
        
        self.final_linear = nn.Linear(config.hidden_size, config.out_channel)
        self.modulation = nn.Linear(config.cond_size, 2)
        self.pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(config.hidden_size, (self.h//self.patch_size, self.w//self.patch_size)))
        self.register_tokens = nn.Parameter(torch.zeros(1, self.register_token_num, config.hidden_size))

    def forward(self, latent, t, text_embs: List[torch.Tensor], pooled_text_embs):
        """
            latent (torch.Tensor)
            t (torch.Tensor)
            text_embs (List[torch.Tensor])
            pooled_text_embs (torch.Tensor)
        """
        emb_c = self.text_cond(*text_embs) # (N, L, D)
        emb_t = self.time_emb(pooled_text_embs, t) # (N, D)
        emb_x = self.patching(latent) + self.pos_embed # (N, T, D), where T = H*W / (patch_size ** 2)

        # additional "register" tokens, to convey the global information
        # see https://openreview.net/forum?id=2dnO3LLiJ1
        emb_x = torch.cat((self.register_tokens.expand(emb_x.shape[0], -1, -1), emb_x), dim=1)
        

        for block in self.mmdit_blocks:
            emb_x, emb_c = block(emb_x, emb_c, emb_t)

        # remove register token for the output layer
        emb_x = emb_x[:,self.register_token_num:] 

        scale, shift = self.modulation(emb_x, emb_t)
        emb_x = self.final_linear(scale*emb_x + shift) # (N, T, patch_size**2 * out_channels)
        return self.patching.unpatchify(emb_x) # (N, out_channels, H, W)
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2024/09   »
1 2 3 4 5 6 7
8 9 10 11 12 13 14
15 16 17 18 19 20 21
22 23 24 25 26 27 28
29 30
글 보관함