빠른 Transformer 디코딩 — Multi-Query Attention
Digest: Transformer의 autoregressive 디코딩(한 토큰씩 순차 생성)은 매 스텝마다 과거 토큰들의 K/V 캐시(이전 key/value 텐서 저장소)를 전부 다시 메모리에서 읽어와야 해서 연산량이 아니라 memory bandwidth(메모리 대역폭)가 병목이 된다. Shazeer는 이 병목의 범인이 “헤드별로 별도의 K/V를 유지하는” 설계라는 통찰에서 출발해, Q는 h개 헤드로 유지하되 K/V는 모든 헤드가 공유하는 단 1쌍만 두는 Multi-Query Attention(MQA) 을 제안한다. WMT14 En-De 번역에서 MHA 대비 BLEU 저하는 28.7 → 28.5 수준(Table 2, ~0.2 차이)으로 미미하지만, 디코더 한 토큰당 시간은 46μs → 3.8μs로 약 10배 이상 단축(Table 2)되어 “속도-품질 trade-off에서 MQA가 거의 공짜 점심”임을 보였다. 한계로 저자는 단일 태스크(번역)만 검증했다는 점과 매우 작은 헤드 수에서 과도 공유 시 품질이 떨어질 가능성을 언급하고, 남은 열린 질문으로 대규모 LLM, 긴 문맥, 다양한 modality에서의 범용성이 제기되어 2023년 GQA(Grouped-Query Attention)와 LLaMA-2 이후 거의 모든 주요 LLM의 사실상 표준 설계로 흡수되었다.
섹션별 요약
Introduction
- 배경: Vaswani et al. 2017의 Transformer는 학습 시 시퀀스 길이에 대해 병렬화되지만, inference 시 autoregressive 디코딩은 토큰 단위 순차 계산이 불가피하다.
- 문제: 각 디코딩 스텝에서 누적된 K/V 캐시를 GPU/TPU 메모리에서 SRAM으로 반복 로드해야 하며, arithmetic intensity(연산/메모리 접근 비율)가 낮아 memory bandwidth-bound가 된다.
- 기여:
- MHA의 memory bandwidth 분석을 통해 K/V 로드가 지배적 비용임을 식별.
- 모든 헤드가 단일 K/V 쌍을 공유하는 MQA 제안.
- WMT14 En-De에서 품질 손실 최소화 + 디코딩 속도 대폭 향상을 실증.
Methods
Multi-Head Attention (MHA) 요약:
- h개 헤드 각각이 독립된 projection 를 가진다.
- 각 헤드는 dim 의 Q/K/V를 생성.
- 디코딩 스텝마다 모든 h개 K, V 텐서를 메모리에서 읽음.
Multi-Query Attention (MQA):
- Q만 h개 헤드로 유지 (, i=1..h).
- K, V는 헤드 간 공유되는 단일 projection (scalar in head dimension).
- 각 헤드의 Q는 같은 K, V에 질의를 던진다 (이름이 “one write-head”인 이유).
- 학습 시 병렬 계산 구조는 거의 동일, 차이는 K/V 텐서가 헤드 축을 잃어 차원을 유지.
Results
| 모델 | BLEU (WMT14 En-De) | Dec time / token (μs) | Enc time / token (μs) | Training step (ms) |
|---|---|---|---|---|
| MHA baseline (h=8) | 28.7 | 46 | 1.7 | 433 |
| MQA (공유 K/V) | 28.5 | 3.8 (~12배↓) | 1.5 | 397 |
| MHA (h=1, 작은 모델) | 27.2 | 31 | 2.3 | — |
| Local + MQA 등 변형 | 28.3~28.5 | 3~5 | — | — |
값은 논문 Table 2 기반 (정확한 열 명칭과 일부 수치는 본문 참조). 핵심은 BLEU 0.2 손실로 디코더 토큰당 시간 ~12배 단축.
Discussion
- 저자 인정 한계: 단일 벤치마크(WMT14 En-De)만 검증. 인코더 디코딩은 이미 빠르므로 개선폭이 작고, 주요 이득은 디코더 단계에 집중.
- trade-off 분석: 헤드 수 h가 클수록 MQA의 상대적 효율 이득이 커지지만, K/V 표현력 축소로 인한 품질 저하 위험도 함께 증가.
- 향후 방향: 더 큰 모델, 긴 시퀀스, 다양한 태스크에서 MQA의 품질–속도 곡선 탐색 (→ 이후 GQA로 연결).
Insights
- 주목할 점: 병목은 FLOPs가 아니라 memory bandwidth. 설계 개선의 출발점을 “무엇을 로드하는가”에서 찾는다.
- 연결 고리: KV 캐시 관리(PagedAttention, vLLM), GQA - Training Generalized Multi-Query Transformer Models, FlashAttention의 IO-aware 설계 사상과 동일 계보.
- 시사점: K와 V는 헤드 간 중복이 커서 공유해도 큰 손실이 없다 — 이후 GQA가 “부분 공유(그룹)“로 일반화할 수 있었던 경험적 근거.
- 비판적 코멘트: 논문은 2019년 수준의 작은 Transformer-base 크기에서만 검증되어, 대규모 LLM에서 MQA가 종종 품질 저하가 체감되는 이슈는 후속 GQA에서야 체계적으로 다뤄졌다.
Discussion Points
- 논쟁점: 헤드별 K/V의 표현력 손실이 실제로 downstream 태스크(추론, 긴 문맥)에서 누적되는가?
- 검증 필요 가정: “모든 헤드가 같은 K/V를 공유해도 충분하다”는 가정이 대규모/고난도 태스크에서 성립하는가.
- 후속 연구: GQA (2305.13245), MLA (Multi-head Latent Attention, DeepSeek), MQA를 LoRA로 upcycle하는 방법 등.
왜 이 연구를 하는가?
핵심 질문
Autoregressive 디코딩의 memory bandwidth 병목을 아키텍처 수준에서 어떻게 근본적으로 줄일 수 있는가?
기존 접근법의 한계
| 한계 | 설명 |
|---|---|
| MHA의 KV 캐시 크기 | 시퀀스 길이 , 배치 , 헤드 , 헤드당 차원 에서 로 증가 |
| Arithmetic intensity 저하 | 디코딩 스텝당 연산량 대비 메모리 접근량이 커서 GPU FLOPs를 활용 못함 |
| 학습-추론 비대칭 | 학습은 parallel하지만 추론은 sequential — 학습 최적화로 해결 불가 |
핵심 통찰
- 병목의 원인은 헤드별로 K/V가 중복 저장되어 있다는 것 → 공유로 h배 감소 가능.
- 쿼리(Q)는 다양성이 중요하지만 K/V는 “무엇을 참조할지의 좌표계” 역할이라 헤드 간 공유 여지가 크다.
방법 (Method)
프레임워크 개요
graph TB subgraph MHA["Multi-Head Attention (기존)"] X1[입력 X] --> Q1[Q heads: h개] X1 --> K1[K heads: h개] X1 --> V1[V heads: h개] Q1 --> A1[h개 독립 어텐션] K1 --> A1 V1 --> A1 A1 --> O1[Concat + Out proj] end subgraph MQA["Multi-Query Attention (제안)"] X2[입력 X] --> Q2[Q heads: h개] X2 --> KV2[K/V: 공유 1개] Q2 --> A2[h개 헤드 all share K,V] KV2 --> A2 A2 --> O2[Concat + Out proj] end MHA -.KV 캐시 b·m·h·d_k.-> Cache1[캐시 크기 X] MQA -.KV 캐시 b·m·d_k.-> Cache2[캐시 크기 X / h]
핵심 구성요소
1. Query projection (헤드별 유지)
2. Key/Value projection (공유)
3. 각 헤드의 attention (공유 K/V에 질의)
4. Output projection
KV 캐시 감소비
| 항목 | MHA | MQA | 비율 |
|---|---|---|---|
| K/V 캐시 요소 수 | 1 / h | ||
| 디코딩 스텝당 메모리 로드 | 1 / h |
예: 이면 KV 캐시 8배, (LLaMA 규모)이면 32배 감소.
발견 (Findings)
주요 결과 (WMT14 En-De, Table 2 기반)
| 모델 | BLEU | Dec μs/token | Enc μs/token |
|---|---|---|---|
| MHA (baseline) | 28.7 | 46 | 1.7 |
| MQA | 28.5 | 3.8 | 1.5 |
| MHA h=1 (small) | 27.2 | 31 | 2.3 |
핵심 발견
- 디코딩 속도: MQA가 MHA 대비 토큰당 약 12배 빠름 (46 → 3.8 μs, Table 2).
- 품질: BLEU 0.2 포인트 손실 — 통계적으로 미미.
- 인코딩: 본래 인코딩은 병렬이라 이득이 작지만 1.7 → 1.5 μs로 소폭 개선.
- 학습 속도: 파라미터 수 감소로 학습 스텝도 살짝 빨라짐 (433 → 397 ms).
이론적 의의
Memory-bandwidth-aware 아키텍처 설계의 효시
FLOPs 최적화(예: sparse attention)에 치우쳐 있던 흐름에서, “무엇을 메모리에서 읽는가”가 1차 비용이라는 관점을 명시화했다. 이 사고방식은 FlashAttention, PagedAttention(vLLM), Ring Attention 등 IO-aware 시스템 연구 전반에 영향을 주었다.
헤드 공유의 설계 공간 개방
MHA(h개 K/V)와 MQA(1개 K/V) 사이에 연속적인 설계 공간이 존재함을 암시했고, 이는 4년 뒤 GQA - Training Generalized Multi-Query Transformer Models (Ainslie et al. 2023)가 G개 그룹 공유로 일반화하면서 명시화되었다. LLaMA-2 70B, LLaMA-3, Mistral, Qwen, Gemma 등 거의 모든 주요 LLM이 GQA/MQA 계열을 채택.
재현성 및 신뢰도 평가
| 항목 | 등급 | 비고 |
|---|---|---|
| 코드 공개 | ⚠️ | 공식 구현 없음. Mesh-TensorFlow 기반 내부 코드 언급만 존재 |
| 데이터 공개 | ✅ | WMT14 En-De 공개 벤치마크 |
| 하이퍼파라미터 | ✅ | Transformer-base 표준 설정 + 부록에 상세 기재 |
| 실험 환경 | ⚠️ | TPU 기반, 정확한 하드웨어 세대/토폴로지 모호 |
| 통계적 신뢰도 | ❌ | 단일 실행, 표준편차/신뢰구간 없음 |
| 종합 등급 | B | 설계와 수치는 명확하나 다중 실행·외부 검증 부재 |
주장별 신뢰도
| # | 주장 | 근거 | 신뢰도 |
|---|---|---|---|
| 1 | MQA가 MHA 대비 디코딩을 10배 이상 가속 | Table 2 (46→3.8μs) | 🟢 |
| 2 | BLEU 손실은 미미 (< 0.3) | WMT14 En-De 단일 실험 | 🟡 (다중 seed/태스크 미검증) |
| 3 | 인코더보다 디코더에서 이득이 큼 | 메모리 대역폭 분석 + Table 2 | 🟢 |
| 4 | 대규모 모델/긴 문맥에서도 품질 유지 | 본 논문에는 직접 증거 없음 | 🔴 (후속 GQA에서 부분적 반박) |
읽기 난이도: ⭐⭐
Transformer 기본 구조와 GPU 메모리 계층(HBM vs SRAM), arithmetic intensity 개념을 알아야 함. 수식은 많지 않지만 memory bandwidth 분석 부분은 시스템 배경이 필요.
관련 연구 비교 매트릭스
| 축 | 본 논문 (MQA, 2019) | MHA (Vaswani 2017) | GQA (Ainslie 2023) | FlashAttention (Dao 2022) |
|---|---|---|---|---|
| 핵심 접근 | K/V 헤드 간 완전 공유 (h→1) | 헤드별 독립 K/V | 그룹 단위 K/V 공유 (h→G) | IO-aware tiling |
| 문제 정의 | 디코딩 memory bandwidth | 병렬 어텐션 표현력 | MQA 품질 저하 보완 | 전체 어텐션의 HBM I/O |
| 데이터 | WMT14 En-De | WMT14 En-De/Fr | T5 계열, 다국어/QA | MLM, LM 벤치마크 |
| 핵심 메트릭 | BLEU 28.5, 3.8μs/token | BLEU 28.7, 46μs/token | MHA 근접 + MQA 근접 속도 | 2-4배 wall-clock 가속 |
| 확장성 | 매우 높음 (모든 LLM) | 표준 | LLaMA-2/3 등 사실상 표준 | 매우 높음 |
| 한계 | 품질 저하 가능성 | 디코딩 느림 | 최적 G 탐색 필요 | 아키텍처 변경 없음 |
| 코드 공개 | ❌ | ✅ | ✅ | ✅ |
관련 연구
- GQA - Training Generalized Multi-Query Transformer Models — MQA(1 head)와 MHA(h head) 사이의 그룹화 일반화. MQA의 품질 저하 이슈를 보완한 후속.
- FlashAttention-2 - Faster Attention with Better Parallelism and Work Partitioning — 같은 “memory-bandwidth 관점”이지만 커널 레벨 해법.
- Attention Methods — 어텐션 변형 전체 지도에서 MQA의 위치.
원자적 인사이트 (Zettelkasten)
💡 디코딩 병목은 FLOPs가 아니라 KV 캐시 bandwidth다
출처: MQA - Fast Transformer Decoding with Multi-Query Attention (Shazeer, 2019)
유형: 이론적 / 시스템 관찰
Autoregressive 디코딩에서 한 토큰을 생성할 때의 arithmetic intensity(연산/바이트 비율)는 매우 낮아서, GPU의 FLOPs는 놀면서 HBM↔SRAM 대역폭이 포화된다. 따라서 아키텍처 최적화의 1차 목표는 **“스텝당 읽어야 하는 텐서의 크기”**를 줄이는 것이다.
핵심 조건/맥락: 배치가 작고 시퀀스가 길어 KV 캐시가 지배적인 경우에 특히 강하게 성립.
연결: FlashAttention-2 - Faster Attention with Better Parallelism and Work Partitioning, vLLM의 PagedAttention.
활용 가능성: 새 어텐션 변형 제안 시 FLOPs보다 메모리 이동량을 1차 지표로 삼아야 함.
💡 K와 V는 헤드 간 중복이 크다
출처: MQA - Fast Transformer Decoding with Multi-Query Attention (Shazeer, 2019)
유형: 실험적
h개 헤드가 각자 별도 K/V를 가지는 대신 모든 헤드가 하나의 K/V를 공유해도 번역 품질이 BLEU 0.2 이내로 유지된다(Table 2). 이는 K/V 공간이 Q 공간보다 표현 redundancy가 크다는 구조적 증거다.
핵심 조건/맥락: Transformer-base 크기, 번역 태스크. 대규모 LLM에서는 부분적으로만 성립 (→ GQA가 필요).
연결: GQA - Training Generalized Multi-Query Transformer Models — 공유 단위를 연속적으로 조절.
활용 가능성: 어떤 축(Q/K/V/FFN/layer)이 가장 많은 redundancy를 가지는지 체계적으로 탐색 → low-rank/shared design의 출발점.
💡 학습-추론 비대칭이 아키텍처 선택의 결정 요인이다
출처: MQA - Fast Transformer Decoding with Multi-Query Attention (Shazeer, 2019)
유형: 방법론적
학습은 시퀀스 방향으로 parallel하지만 inference는 sequential하므로, 동일 아키텍처라도 학습/추론 비용 프로파일이 근본적으로 다르다. MQA는 학습 속도에는 미미한 변화만 주면서 추론만 선택적으로 가속한다.
핵심 조건/맥락: Causal 디코더. 양방향 인코더에는 이득이 제한적.
연결: Speculative decoding, MoE, distillation — 모두 “학습-추론 비대칭”을 설계 변수로 삼음.
활용 가능성: 새 모듈 제안 시 학습/추론 각각에서 지배적 비용이 다름을 명시적으로 분해할 것.
💡 “단 1개로의 압축”이 극단값으로서 후속 일반화를 유발한다
출처: MQA - Fast Transformer Decoding with Multi-Query Attention (Shazeer, 2019)
유형: 연결 / 메타
MQA는 h→1이라는 극단적 압축을 먼저 제시함으로써, h와 1 사이의 연속 공간(=GQA의 G)이 열린다. 연구 전략적으로 극단값 먼저 제시 → 중간값 일반화라는 패턴은 sparse↔dense, on-policy↔off-policy, frozen↔full-finetune 등에서도 반복되는 유효한 접근이다.
핵심 조건/맥락: 극단값이 유의미한 이득을 증명할 때만 중간 공간 탐색에 동기가 생김.
연결: GQA - Training Generalized Multi-Query Transformer Models, LoRA ↔ full finetune, MoE top-1 ↔ top-k.
활용 가능성: 새 아이디어 제안 시 가장 단순/극단 버전을 먼저 검증하는 전략.
핵심 용어 정리
| 용어 | 정의 |
|---|---|
| MHA (Multi-Head Attention) | h개 헤드 각각이 독립된 Q/K/V projection을 갖는 표준 Transformer 어텐션 |
| MQA (Multi-Query Attention) | Q는 h개 헤드 유지, K/V는 모든 헤드가 공유하는 단일 쌍을 쓰는 변형 |
| KV 캐시 | 디코딩 시 과거 토큰들의 key/value 텐서를 누적 저장하는 메모리 버퍼 |
| Memory bandwidth | 연산 장치와 메모리(HBM) 간 데이터 전송 속도. arithmetic intensity 낮을 때 병목 |
| Arithmetic intensity | FLOPs / bytes loaded. 낮으면 memory-bound, 높으면 compute-bound |
| Autoregressive decoding | 한 토큰 생성 → 다음 토큰의 입력으로 사용하며 순차 진행하는 생성 방식 |
| Head | 어텐션을 서로 다른 subspace에서 병렬 수행하는 단위 |
| WMT14 En-De | 표준 영어→독일어 번역 벤치마크. Transformer 논문 이래 공인 baseline |
| BLEU | n-gram precision 기반 기계번역 품질 지표 |
태그
paper #2019 attention mqa kv-cache decoding multi-head-variants