FlashAttention-2: 더 나은 병렬화와 작업 분할을 통한 빠른 어텐션
Digest: Transformer의 시퀀스 길이 확장을 가로막던 어텐션 연산은 메모리와 런타임이 시퀀스 길이에 대해 이차적으로 증가한다는 한계를 갖는다(Context). FlashAttention-1은 IO-aware tiling으로 메모리 병목을 해소했지만 A100에서 이론 최대 FLOPs/s의 25-40%만 달성하여 최적화된 GEMM(80-90%)에 크게 미치지 못하는 문제를 드러냈다(Insight). 저자는 (1) 비-matmul FLOPs를 줄이는 알고리즘 개선(스케일링 지연, log-sum-exp 단일 저장), (2) 단일 head에서도 시퀀스 길이 차원으로 thread block을 병렬화해 occupancy 향상, (3) warp 간 K/V 대신 Q를 분할해 공유 메모리 통신을 최소화하는 세 가지 핵심 재설계를 제안한다(Solution). 그 결과 A100에서 FlashAttention 대비 약 2배(Figure 5-6), 이론 최대의 50-73% FLOPs/s를 달성하고, GPT-3 스타일 end-to-end 학습에서 A100당 225 TFLOPs/s(MFU 72%, Table 1)를 기록했다(Evidence). 다만 Hopper(H100)의 TMA/FP8 등 신규 기능 활용이나 더 비표준적인 attention 변형에 대한 커버리지는 제한적이다(Limitations). 향후 새로운 하드웨어에 대한 자동 튜닝과 블록 희소·선형 어텐션과의 결합이 열린 과제로 남는다(Open Questions).
섹션별 요약
Introduction
Transformer의 확장은 긴 문맥(8k, 16k, 32k+)을 요구하는 LLM, 고해상도 이미지, 코드, 오디오·비디오 모델링의 핵심 요구사항이다. 어텐션은 시퀀스 길이 N에 대해 O(N^2) 런타임과 O(N^2) 메모리를 갖지만, FlashAttention-1은 softmax의 online 재계산과 SRAM tiling을 통해 HBM 접근을 O(N^2 d^2/M)으로 줄였다. 그러나 GPU의 연산 자원(Tensor Core) 활용은 여전히 낮았다. 본 논문은 FlashAttention 알고리즘의 워크플로를 재분해해 matmul 대비 비싼 비-matmul 연산을 줄이고, 병렬화 축을 추가하며, warp 내·warp 간 공유 메모리 접근을 개선한다.
Methods
세 가지 축에서 개선이 이루어진다. 첫째, forward pass에서 softmax 보정 스케일링을 마지막 블록 처리 후로 지연시키고 log-sum-exp 하나만 저장해 backward에 재사용한다. 둘째, batch·head에 더해 sequence length(N) 차원에서도 thread block을 분할해 긴 시퀀스·작은 batch에서도 SM occupancy를 유지한다. 셋째, FlashAttention-1의 split-K 방식(warp가 K/V를 나눠 가짐)을 split-Q로 바꿔 각 warp가 Q 조각을 갖고 K/V를 공유함으로써 warp 간 reduce 및 공유 메모리 왕복을 제거한다. Backward pass에도 동일한 재분할을 적용한다.
Results
A100 80GB SXM에서 FP16/BF16 기준 FlashAttention 대비 1.7-3.0배, xformers 대비 1.3-2.5배, PyTorch standard 대비 3-10배 빠르다. 헤드 크기 128, 시퀀스 8k에서 forward+backward 기준 약 230 TFLOPs/s에 도달한다. GPT-3 스타일 1.3B/2.7B 모델을 8k 문맥으로 학습할 때 A100당 225 TFLOPs/s(MFU 72%)를 기록, FlashAttention 대비 약 1.3배 end-to-end 학습 속도 향상을 보였다.
| 실험 | 세팅 | FlashAttention | FlashAttention-2 | 상대 속도 |
|---|---|---|---|---|
| Forward (head=128, N=8k) | A100 FP16 | ~124 TFLOPs/s | ~335 TFLOPs/s | 2.7배 |
| Forward+Backward (head=128, N=8k) | A100 FP16 | ~124 TFLOPs/s | ~230 TFLOPs/s | 1.9배 |
| GPT 1.3B E2E 학습 | A100, 8k 문맥 | 170 TFLOPs/s | 225 TFLOPs/s | 1.32배 |
| 이론 최대 대비 | A100 FP16 (312 TFLOPs/s) | 25-40% | 50-73% | — |
Discussion
개선의 핵심은 “올바른 병렬화 축 선택”이다. 어텐션은 matmul 내부 구조가 GEMM과 동일하지만, softmax·masking·log-sum-exp 등 비-matmul 연산이 Tensor Core idle을 유발한다. Q-분할과 sequence-length 병렬화는 공유 메모리 병목과 SM 저활용 두 문제를 동시에 해결한다. 또한 causal masking 경계 블록을 조기 skip해 불필요 연산을 제거한다.
Insights
- 하드웨어 활용도는 알고리즘 IO 복잡도뿐 아니라 warp-level 데이터 배치에 크게 의존한다.
- Online softmax는 최종 스케일링을 지연할 수 있어 boundary block에서만 비-matmul 연산을 수행하면 된다.
- log-sum-exp만 저장해도 backward의 수치적 안정성과 메모리 효율을 모두 확보할 수 있다.
Discussion Points
- H100의 TMA, asynchronous WGMMA, FP8 tensor core를 활용한 FlashAttention-3와의 관계는 무엇인가.
- 블록 희소 어텐션·ALiBi·rotary positional encoding 같은 변형과 결합 시 성능 저하는 어느 정도인가.
- 추론(Decoding) 단계의 KV cache 접근 패턴에서도 Q-split이 최적인가, 아니면 별도 커널(Paged/FlashDecoding)이 필요한가.
메타데이터
| 항목 | 내용 |
|---|---|
| 제목 | FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning |
| 저자 | Tri Dao |
| 소속 | Princeton University / Together AI |
| 연도 | 2023 |
| 발표 | arXiv preprint (2307.08691), 이후 ICLR 2024 |
| 링크 | https://arxiv.org/abs/2307.08691 |
| 키워드 | FlashAttention, Attention, GPU kernel, Work Partitioning, Transformer |
왜 이 연구를 하는가?
핵심 질문
긴 시퀀스에서도 최적화된 GEMM 수준(이론 최대 FLOPs의 80-90%)으로 어텐션 커널을 돌릴 수 있는가? FlashAttention-1이 메모리 병목은 풀었지만 왜 여전히 25-40%에 머무르는가?
기존 접근법의 한계
| 접근 | 한계 |
|---|---|
| Standard attention (PyTorch) | O(N^2) HBM 읽기/쓰기로 메모리 병목, 긴 시퀀스에서 OOM |
| xformers memory-efficient attention | 메모리는 줄였지만 커널 fusion과 warp 스케줄링이 미흡 |
| FlashAttention-1 | IO-aware tiling은 도입했으나 비-matmul 연산·warp 간 통신·병렬화 축 부족으로 Tensor Core 저활용 (25-40%) |
| 블록 희소 어텐션 | 근사 기반이라 정확도 손실 가능, 일반적 dense 대비 범용성 낮음 |
핵심 통찰
어텐션의 병목은 더 이상 HBM이 아니라 SM 내부 스케줄링이다. 따라서 커널 수준에서 (1) 비-matmul 연산 최소화, (2) sequence-length 병렬화, (3) Q-분할 warp 레이아웃을 통해 Tensor Core 가동률을 2배 끌어올릴 수 있다.
방법 (Method)
프레임워크 개요
flowchart LR A[Q, K, V in HBM] --> B[Thread block: batch x head x seq-block] B --> C[Load Q tile to SRAM] C --> D[Iterate K,V tiles] D --> E[Matmul QKT + online softmax stats] E --> F[Matmul softmax times V, accumulate] F --> G[Final rescale once at end] G --> H[Write O, log-sum-exp to HBM]
핵심 구성요소
- Non-matmul FLOPs 감소: 매 블록마다 수행하던 스케일링 1/sum(exp)를 마지막에 한 번만 적용. log-sum-exp만 저장하여 backward에서 recompute.
- Sequence-length 병렬화: grid를 (batch, head, seq_block)로 확장해 길이 8k/16k·batch=1에서도 SM occupancy 유지.
- Split-Q warp 파티셔닝: 기존 split-K는 warp 간 partial sum을 공유 메모리에 쓰고 reduce해야 했음. Split-Q에서는 각 warp가 별도 Q row를 담당해 교차 통신 제거.
- Causal masking 최적화: 대각선 경계 블록만 masking 계산을 수행하고, 완전 상삼각 블록은 아예 건너뛴다.
- Backward 재분해: forward와 동일한 Q-분할 유지, dQ/dK/dV 재계산 시 log-sum-exp 1회 재사용.
발견 (Findings)
주요 결과
| 벤치마크 | 조건 | FA-2 성능 | 비교 |
|---|---|---|---|
| Forward kernel | A100, head=128, N=2k-16k | 최대 ~335 TFLOPs/s | FA-1 대비 2.7배 |
| Forward+Backward | A100, head=128, N=8k | ~230 TFLOPs/s | FA-1 대비 1.9배 |
| 이론 최대 대비 | A100 FP16 312 TFLOPs/s | 50-73% | FA-1: 25-40% |
| End-to-end 학습 | GPT-1.3B, 8k ctx | 225 TFLOPs/s, MFU 72% | FA-1 대비 1.3배 |
| H100 포팅 | 초기 결과 | FA-1 대비 2배 이상 | 단 H100 전용 기능 미활용 |
핵심 발견
- 병목은 HBM에서 SM 내부 warp 스케줄링으로 이동했다.
- Sequence-length 병렬화 한 축 추가만으로도 긴 문맥·작은 batch에서 2배 이상의 이득.
- Causal mask skip은 인과적 LM 학습에서 forward를 약 1.7배 가속.
이론적 의의
FlashAttention-2는 exact attention도 GEMM과 같은 수준으로 최적화 가능함을 실증하여, 긴 문맥 LLM 학습의 실용적 장벽을 낮췄다. 또한 online softmax, log-sum-exp 저장, warp-level layout이 후속 연구(FlashAttention-3, FlashDecoding, PagedAttention)의 기반이 되었다. 이는 알고리즘-하드웨어 공동 설계(hardware-aware algorithms)가 단순 근사(approximate attention)보다 강력한 접근이라는 메시지를 강화한다.
재현성 및 신뢰도 평가
| 항목 | 평가 | 근거 |
|---|---|---|
| 코드 공개 | A | 공식 저장소 Dao-AILab/flash-attention에 CUDA/CUTLASS 커널 공개 |
| 데이터 공개 | A | 학습 벤치마크는 표준 GPT 학습 세팅, 공개 데이터로 재현 가능 |
| 하이퍼파라미터 | A | 블록 크기, warp 수 등 커널 설정 명시 |
| 실험 환경 | A | A100 80GB SXM, CUDA 11.8, cuDNN 등 명시 |
| 통계적 유의성 | B | 여러 시퀀스 길이·head dim 그리드 측정, 반복 실행은 명시 제한적 |
| 종합 등급 | A | 커널·학습 모두 재현 가능, 실제 오픈소스 생태계에서 광범위 검증 |
| 주장 | 신뢰도 | 근거 |
|---|---|---|
| A100에서 FA-1 대비 2배 속도 | 높음 | 커널 벤치마크 Figure 5-6, 외부 재현 다수 |
| 이론 최대의 50-73% 달성 | 높음 | A100 FP16 312 TFLOPs/s 대비 측정값 일관 |
| GPT 학습 MFU 72% | 중상 | 특정 모델·batch·context 조합에 한정, 일반화는 세팅 의존 |
| H100 이식 후에도 2배 | 중 | 초기 실험, H100 전용 기능(WGMMA/TMA) 미활용 |
읽기 난이도
GPU 아키텍처(Warp, SM, SRAM/HBM), CUTLASS, online softmax에 대한 사전 지식이 필요하다. 알고리즘 섹션은 8/10, 코드 이해까지 포함하면 9/10 수준의 전문성을 요구한다.
관련 연구
- FlashAttention - Fast and Memory-Efficient Exact Attention with IO-Awareness
- PagedAttention - Efficient Memory Management for LLM Serving with vLLM
- Mistral 7B - Sliding Window Attention
원자적 인사이트 (Zettelkasten)
Insight 1: 비-matmul 연산 지연이 Tensor Core 가동률을 좌우한다
- 출처: FlashAttention-2, Section 3.1
- 유형: 최적화 원리
- 맥락: softmax의 rescale을 블록마다 적용하면 Tensor Core가 idle 상태가 되어 이론 최대의 25-40%에 머무른다. 마지막에 한 번만 rescale하면 50-73%로 도약한다.
- 연결: FlashAttention - Fast and Memory-Efficient Exact Attention with IO-Awareness, Online softmax(Milakov & Gimelshein 2018)
Insight 2: Warp 간 공유 메모리 왕복은 split-Q로 제거한다
- 출처: FlashAttention-2, Section 3.2
- 유형: 병렬 알고리즘 설계
- 맥락: FA-1의 split-K는 warp가 K/V를 분할해 Q에 대해 partial 결과를 공유 메모리로 reduce해야 했다. Q를 분할하면 warp 간 통신이 사라지고 warp 내부 accumulate만 남아 레이턴시가 감소한다.
- 연결: CUTLASS warp-level GEMM, GQA - Training Generalized Multi-Query Transformer Models
Insight 3: Sequence-length 병렬화는 긴 문맥·작은 batch 체제를 구한다
- 출처: FlashAttention-2, Section 3.2
- 유형: 스케줄링 통찰
- 맥락: batch=1, N=16k 같은 LLM 학습/추론 케이스에서 batch·head 축만으로는 SM을 다 채울 수 없다. 시퀀스 축 분할이 이를 보완한다.
- 연결: PagedAttention - Efficient Memory Management for LLM Serving with vLLM, FlashDecoding
핵심 용어 정리
| 용어 | 정의 |
|---|---|
| FlashAttention | HBM 접근을 줄이기 위한 IO-aware tiling 기반 exact attention 커널 |
| Online softmax | 블록 단위로 max/sum을 점진적으로 업데이트하는 softmax 계산 방식 |
| Log-sum-exp (LSE) | log(sum(exp(x)))로, softmax 정규화 상수이자 backward 재계산에 사용 |
| Split-Q / Split-K | warp들이 Q 또는 K를 나누어 계산하는 병렬화 전략 |
| MFU | Model FLOPs Utilization, 실제 달성 FLOPs 대비 이론 최대 FLOPs |
| SM / Warp / SRAM | GPU Streaming Multiprocessor, 32-thread 실행 단위, 온칩 공유 메모리 |
| Tensor Core | NVIDIA GPU의 mixed-precision matmul 전용 하드웨어 |
태그
paper #2023 attention gpu-optimization flashattention transformer