Transformer Attention Variants Survey

Transformer의 self-attention은 O(n²) 시간/공간 복잡도를 가지며, 긴 시퀀스 처리에 병목이 된다. 이를 해결하기 위한 다양한 attention 변형들을 정리한다.


1. Sparse Attention 계열

Generating Long Sequences with Sparse Transformers

핵심 내용

  • Attention matrix를 sparse factorization하여 O(n²) → O(n√n) 복잡도로 감소
  • Strided pattern과 fixed pattern을 조합한 sparse attention 패턴 제안
  • 수만 timestep 길이의 시퀀스를 수백 layer로 모델링 가능

방법론

  • Strided Attention: 고정 간격으로 attention 수행
  • Fixed Attention: 특정 위치에만 attention 수행
  • Recomputation으로 메모리 절약, fast attention kernel로 학습 가속

특징점

  • 최초의 sparse attention 제안, 이후 연구의 기반
  • 이미지, 오디오, 텍스트 모두에 적용 가능
  • Enwik8, CIFAR-10, ImageNet-64에서 SOTA 달성

Longformer: The Long-Document Transformer

핵심 내용

  • O(n²) → O(n) 선형 복잡도의 attention 메커니즘
  • Local windowed attention + task-specific global attention 조합
  • 수천 토큰 이상의 긴 문서 처리 가능

방법론

  • Sliding Window Attention: 각 토큰이 양쪽 w개 토큰에만 attention
  • Dilated Sliding Window: 더 넓은 receptive field를 위한 dilated 버전
  • Global Attention: [CLS] 등 특정 토큰은 전체 시퀀스에 attention

특징점

  • Standard self-attention의 drop-in replacement
  • RoBERTa 대비 long document task에서 일관된 성능 향상
  • WikiHop, TriviaQA에서 SOTA

Big Bird: Transformers for Longer Sequences

핵심 내용

  • O(n²) → O(n) 선형 복잡도의 sparse attention
  • Universal approximator이며 Turing complete함을 이론적으로 증명
  • 기존 대비 8배 긴 시퀀스 처리 가능

방법론

  • Random Attention: 무작위로 선택된 토큰들에 attention
  • Window Attention: 인접 토큰들에 대한 local attention
  • Global Attention: O(1)개의 global token이 전체에 attention

특징점

  • Sparse attention의 이론적 기반 제시
  • QA, 요약 등 NLP task에서 큰 성능 향상
  • Genomics 데이터에 새로운 적용 제안

2. Linear Attention 계열

Linformer: Self-Attention with Linear Complexity

핵심 내용

  • Self-attention matrix가 low-rank로 근사 가능함을 발견
  • O(n²) → O(n) 시간/공간 복잡도 달성
  • Standard Transformer와 동등한 성능 유지

방법론

  • Key와 Value를 저차원으로 projection (n → k)
  • Attention: Q(E·K)ᵀ(F·V) 형태로 계산
  • Projection matrix E, F를 학습 또는 고정

특징점

  • Low-rank 근사라는 간단한 아이디어로 효율성 달성
  • 긴 시퀀스에서 메모리/시간 효율적
  • 구현이 상대적으로 단순

Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention

핵심 내용

  • Self-attention을 kernel feature map의 linear dot-product로 표현
  • Matrix product의 associativity를 활용해 O(n²) → O(n) 달성
  • Transformer와 RNN의 관계 규명

방법론

  • Softmax attention을 φ(Q)φ(K)ᵀV로 근사
  • Kernel trick: φ(q)ᵀφ(k) ≈ softmax(qᵀk)
  • Iterative하게 계산 가능 → RNN처럼 동작

특징점

  • 이론적으로 Transformer와 RNN의 연결고리 제시
  • Autoregressive 생성에서 4000배 속도 향상
  • Vanilla Transformer와 유사한 성능

Rethinking Attention with Performers

핵심 내용

  • Softmax attention을 provable accuracy로 선형 근사
  • FAVOR+ (Fast Attention Via positive Orthogonal Random features) 제안
  • Sparsity나 low-rankness 가정 없이 선형 복잡도 달성

방법론

  • Softmax kernel을 random feature로 근사
  • Positive orthogonal random features로 분산 감소
  • Unbiased estimator와 uniform convergence 보장

특징점

  • 이론적 보장이 강함 (unbiased, low variance)
  • Regular Transformer와 완전히 호환
  • Protein sequence modeling 등 대규모 task에 적용

3. Multi-Query / Grouped-Query Attention

Fast Transformer Decoding: One Write-Head is All You Need (MQA)

핵심 내용

  • Key와 Value를 모든 attention head가 공유
  • KV cache 크기 대폭 감소 → 메모리 bandwidth 절약
  • Incremental decoding 속도 크게 향상

방법론

  • 기존 MHA: 각 head마다 별도의 K, V
  • MQA: 단일 K, V를 모든 head가 공유
  • Query만 head별로 다르게 유지

특징점

  • 추론 속도에 초점 (학습이 아닌)
  • 품질 저하가 minor함
  • 이후 GQA, MLA 등의 기반이 됨

GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

핵심 내용

  • MQA(1개 KV head)와 MHA(n개 KV head)의 중간 형태
  • 기존 MHA 모델을 5% pre-training compute로 GQA로 변환 가능
  • MHA 수준 품질 + MQA 수준 속도

방법론

  • Grouped-Query Attention: g개의 KV head 사용 (1 < g < h)
  • Query head들을 그룹으로 나누어 각 그룹이 KV head 공유
  • Uptraining: MHA checkpoint에서 GQA로 fine-tuning

특징점

  • LLaMA 2, Mistral 등 최신 LLM에서 표준으로 채택
  • 품질-속도 trade-off의 최적점
  • 기존 모델 재활용 가능 (uptraining)

4. IO-Aware Attention (Flash Attention)

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

핵심 내용

  • Approximate가 아닌 exact attention이면서 빠름
  • GPU memory hierarchy (HBM ↔ SRAM) IO를 최적화
  • 메모리 O(n) + 속도 2-4배 향상

방법론

  • Tiling: Attention을 block 단위로 나누어 SRAM에서 계산
  • Softmax를 online으로 계산 (전체 attention matrix 저장 불필요)
  • Recomputation으로 backward pass 메모리 절약

특징점

  • Approximation 없이 정확한 attention 계산
  • BERT 15%, GPT-2 3배, Long-range arena 2.4배 속도 향상
  • Path-X (16K), Path-256 (64K) 등 초장문 처리 가능
  • 현대 LLM 학습의 de facto standard

FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

핵심 내용

  • FlashAttention 대비 약 2배 추가 속도 향상
  • A100에서 이론적 최대 FLOPs의 50-73% 달성
  • GEMM 연산 효율에 근접

방법론

  • Non-matmul FLOPs 감소
  • Single head에서도 thread block 간 병렬화
  • Warp 간 shared memory 통신 최소화

특징점

  • FlashAttention의 최적화된 후속 버전
  • GPT 학습에서 A100당 225 TFLOPs/s 달성 (72% utilization)
  • 현재 가장 널리 사용되는 attention 구현

Summary Table

방법연도복잡도핵심 아이디어특징
Sparse Transformer2019O(n√n)Sparse factorization최초 sparse attention
Longformer2020O(n)Local + Global attentionDrop-in replacement
BigBird2020O(n)Random + Window + Global이론적 보장
Linformer2020O(n)Low-rank projection간단한 구현
Linear Attention2020O(n)Kernel feature mapRNN과의 연결
Performer2020O(n)FAVOR+ random features강한 이론적 보장
MQA2019O(n²)단일 KV head추론 속도 최적화
GQA2023O(n²)그룹 KV head품질-속도 균형
FlashAttention2022O(n²)*IO-aware tilingExact + Fast
FlashAttention-22023O(n²)*병렬화 최적화현재 표준

*FlashAttention은 계산 복잡도는 O(n²)이나, 메모리는 O(n)이며 wall-clock 시간이 크게 단축됨.


분류 기준

  1. Sparse Attention: 일부 토큰에만 attention (Sparse Transformer, Longformer, BigBird)
  2. Linear Attention: 수학적 근사로 선형 복잡도 (Linformer, Linear Attention, Performer)
  3. KV Head 공유: 추론 효율화 (MQA, GQA)
  4. IO 최적화: 하드웨어 수준 최적화 (FlashAttention)