FlashAttention-2: 더 나은 병렬화와 작업 분할을 통한 빠른 어텐션

Digest: Transformer의 시퀀스 길이 확장을 가로막던 어텐션 연산은 메모리와 런타임이 시퀀스 길이에 대해 이차적으로 증가한다는 한계를 갖는다(Context). FlashAttention-1은 IO-aware tiling으로 메모리 병목을 해소했지만 A100에서 이론 최대 FLOPs/s의 25-40%만 달성하여 최적화된 GEMM(80-90%)에 크게 미치지 못하는 문제를 드러냈다(Insight). 저자는 (1) 비-matmul FLOPs를 줄이는 알고리즘 개선(스케일링 지연, log-sum-exp 단일 저장), (2) 단일 head에서도 시퀀스 길이 차원으로 thread block을 병렬화해 occupancy 향상, (3) warp 간 K/V 대신 Q를 분할해 공유 메모리 통신을 최소화하는 세 가지 핵심 재설계를 제안한다(Solution). 그 결과 A100에서 FlashAttention 대비 약 2배(Figure 5-6), 이론 최대의 50-73% FLOPs/s를 달성하고, GPT-3 스타일 end-to-end 학습에서 A100당 225 TFLOPs/s(MFU 72%, Table 1)를 기록했다(Evidence). 다만 Hopper(H100)의 TMA/FP8 등 신규 기능 활용이나 더 비표준적인 attention 변형에 대한 커버리지는 제한적이다(Limitations). 향후 새로운 하드웨어에 대한 자동 튜닝과 블록 희소·선형 어텐션과의 결합이 열린 과제로 남는다(Open Questions).

섹션별 요약

Introduction

Transformer의 확장은 긴 문맥(8k, 16k, 32k+)을 요구하는 LLM, 고해상도 이미지, 코드, 오디오·비디오 모델링의 핵심 요구사항이다. 어텐션은 시퀀스 길이 N에 대해 O(N^2) 런타임과 O(N^2) 메모리를 갖지만, FlashAttention-1은 softmax의 online 재계산과 SRAM tiling을 통해 HBM 접근을 O(N^2 d^2/M)으로 줄였다. 그러나 GPU의 연산 자원(Tensor Core) 활용은 여전히 낮았다. 본 논문은 FlashAttention 알고리즘의 워크플로를 재분해해 matmul 대비 비싼 비-matmul 연산을 줄이고, 병렬화 축을 추가하며, warp 내·warp 간 공유 메모리 접근을 개선한다.

Methods

세 가지 축에서 개선이 이루어진다. 첫째, forward pass에서 softmax 보정 스케일링을 마지막 블록 처리 후로 지연시키고 log-sum-exp 하나만 저장해 backward에 재사용한다. 둘째, batch·head에 더해 sequence length(N) 차원에서도 thread block을 분할해 긴 시퀀스·작은 batch에서도 SM occupancy를 유지한다. 셋째, FlashAttention-1의 split-K 방식(warp가 K/V를 나눠 가짐)을 split-Q로 바꿔 각 warp가 Q 조각을 갖고 K/V를 공유함으로써 warp 간 reduce 및 공유 메모리 왕복을 제거한다. Backward pass에도 동일한 재분할을 적용한다.

Results

A100 80GB SXM에서 FP16/BF16 기준 FlashAttention 대비 1.7-3.0배, xformers 대비 1.3-2.5배, PyTorch standard 대비 3-10배 빠르다. 헤드 크기 128, 시퀀스 8k에서 forward+backward 기준 약 230 TFLOPs/s에 도달한다. GPT-3 스타일 1.3B/2.7B 모델을 8k 문맥으로 학습할 때 A100당 225 TFLOPs/s(MFU 72%)를 기록, FlashAttention 대비 약 1.3배 end-to-end 학습 속도 향상을 보였다.

실험세팅FlashAttentionFlashAttention-2상대 속도
Forward (head=128, N=8k)A100 FP16~124 TFLOPs/s~335 TFLOPs/s2.7배
Forward+Backward (head=128, N=8k)A100 FP16~124 TFLOPs/s~230 TFLOPs/s1.9배
GPT 1.3B E2E 학습A100, 8k 문맥170 TFLOPs/s225 TFLOPs/s1.32배
이론 최대 대비A100 FP16 (312 TFLOPs/s)25-40%50-73%

Discussion

개선의 핵심은 “올바른 병렬화 축 선택”이다. 어텐션은 matmul 내부 구조가 GEMM과 동일하지만, softmax·masking·log-sum-exp 등 비-matmul 연산이 Tensor Core idle을 유발한다. Q-분할과 sequence-length 병렬화는 공유 메모리 병목과 SM 저활용 두 문제를 동시에 해결한다. 또한 causal masking 경계 블록을 조기 skip해 불필요 연산을 제거한다.

Insights

  • 하드웨어 활용도는 알고리즘 IO 복잡도뿐 아니라 warp-level 데이터 배치에 크게 의존한다.
  • Online softmax는 최종 스케일링을 지연할 수 있어 boundary block에서만 비-matmul 연산을 수행하면 된다.
  • log-sum-exp만 저장해도 backward의 수치적 안정성과 메모리 효율을 모두 확보할 수 있다.

Discussion Points

  • H100의 TMA, asynchronous WGMMA, FP8 tensor core를 활용한 FlashAttention-3와의 관계는 무엇인가.
  • 블록 희소 어텐션·ALiBi·rotary positional encoding 같은 변형과 결합 시 성능 저하는 어느 정도인가.
  • 추론(Decoding) 단계의 KV cache 접근 패턴에서도 Q-split이 최적인가, 아니면 별도 커널(Paged/FlashDecoding)이 필요한가.

메타데이터

항목내용
제목FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
저자Tri Dao
소속Princeton University / Together AI
연도2023
발표arXiv preprint (2307.08691), 이후 ICLR 2024
링크https://arxiv.org/abs/2307.08691
키워드FlashAttention, Attention, GPU kernel, Work Partitioning, Transformer

왜 이 연구를 하는가?

핵심 질문

긴 시퀀스에서도 최적화된 GEMM 수준(이론 최대 FLOPs의 80-90%)으로 어텐션 커널을 돌릴 수 있는가? FlashAttention-1이 메모리 병목은 풀었지만 왜 여전히 25-40%에 머무르는가?

기존 접근법의 한계

접근한계
Standard attention (PyTorch)O(N^2) HBM 읽기/쓰기로 메모리 병목, 긴 시퀀스에서 OOM
xformers memory-efficient attention메모리는 줄였지만 커널 fusion과 warp 스케줄링이 미흡
FlashAttention-1IO-aware tiling은 도입했으나 비-matmul 연산·warp 간 통신·병렬화 축 부족으로 Tensor Core 저활용 (25-40%)
블록 희소 어텐션근사 기반이라 정확도 손실 가능, 일반적 dense 대비 범용성 낮음

핵심 통찰

어텐션의 병목은 더 이상 HBM이 아니라 SM 내부 스케줄링이다. 따라서 커널 수준에서 (1) 비-matmul 연산 최소화, (2) sequence-length 병렬화, (3) Q-분할 warp 레이아웃을 통해 Tensor Core 가동률을 2배 끌어올릴 수 있다.

방법 (Method)

프레임워크 개요

flowchart LR
    A[Q, K, V in HBM] --> B[Thread block: batch x head x seq-block]
    B --> C[Load Q tile to SRAM]
    C --> D[Iterate K,V tiles]
    D --> E[Matmul QKT + online softmax stats]
    E --> F[Matmul softmax times V, accumulate]
    F --> G[Final rescale once at end]
    G --> H[Write O, log-sum-exp to HBM]

핵심 구성요소

  1. Non-matmul FLOPs 감소: 매 블록마다 수행하던 스케일링 1/sum(exp)를 마지막에 한 번만 적용. log-sum-exp만 저장하여 backward에서 recompute.
  2. Sequence-length 병렬화: grid를 (batch, head, seq_block)로 확장해 길이 8k/16k·batch=1에서도 SM occupancy 유지.
  3. Split-Q warp 파티셔닝: 기존 split-K는 warp 간 partial sum을 공유 메모리에 쓰고 reduce해야 했음. Split-Q에서는 각 warp가 별도 Q row를 담당해 교차 통신 제거.
  4. Causal masking 최적화: 대각선 경계 블록만 masking 계산을 수행하고, 완전 상삼각 블록은 아예 건너뛴다.
  5. Backward 재분해: forward와 동일한 Q-분할 유지, dQ/dK/dV 재계산 시 log-sum-exp 1회 재사용.

발견 (Findings)

주요 결과

벤치마크조건FA-2 성능비교
Forward kernelA100, head=128, N=2k-16k최대 ~335 TFLOPs/sFA-1 대비 2.7배
Forward+BackwardA100, head=128, N=8k~230 TFLOPs/sFA-1 대비 1.9배
이론 최대 대비A100 FP16 312 TFLOPs/s50-73%FA-1: 25-40%
End-to-end 학습GPT-1.3B, 8k ctx225 TFLOPs/s, MFU 72%FA-1 대비 1.3배
H100 포팅초기 결과FA-1 대비 2배 이상단 H100 전용 기능 미활용

핵심 발견

  • 병목은 HBM에서 SM 내부 warp 스케줄링으로 이동했다.
  • Sequence-length 병렬화 한 축 추가만으로도 긴 문맥·작은 batch에서 2배 이상의 이득.
  • Causal mask skip은 인과적 LM 학습에서 forward를 약 1.7배 가속.

이론적 의의

FlashAttention-2는 exact attention도 GEMM과 같은 수준으로 최적화 가능함을 실증하여, 긴 문맥 LLM 학습의 실용적 장벽을 낮췄다. 또한 online softmax, log-sum-exp 저장, warp-level layout이 후속 연구(FlashAttention-3, FlashDecoding, PagedAttention)의 기반이 되었다. 이는 알고리즘-하드웨어 공동 설계(hardware-aware algorithms)가 단순 근사(approximate attention)보다 강력한 접근이라는 메시지를 강화한다.

재현성 및 신뢰도 평가

항목평가근거
코드 공개A공식 저장소 Dao-AILab/flash-attention에 CUDA/CUTLASS 커널 공개
데이터 공개A학습 벤치마크는 표준 GPT 학습 세팅, 공개 데이터로 재현 가능
하이퍼파라미터A블록 크기, warp 수 등 커널 설정 명시
실험 환경AA100 80GB SXM, CUDA 11.8, cuDNN 등 명시
통계적 유의성B여러 시퀀스 길이·head dim 그리드 측정, 반복 실행은 명시 제한적
종합 등급A커널·학습 모두 재현 가능, 실제 오픈소스 생태계에서 광범위 검증
주장신뢰도근거
A100에서 FA-1 대비 2배 속도높음커널 벤치마크 Figure 5-6, 외부 재현 다수
이론 최대의 50-73% 달성높음A100 FP16 312 TFLOPs/s 대비 측정값 일관
GPT 학습 MFU 72%중상특정 모델·batch·context 조합에 한정, 일반화는 세팅 의존
H100 이식 후에도 2배초기 실험, H100 전용 기능(WGMMA/TMA) 미활용

읽기 난이도

GPU 아키텍처(Warp, SM, SRAM/HBM), CUTLASS, online softmax에 대한 사전 지식이 필요하다. 알고리즘 섹션은 8/10, 코드 이해까지 포함하면 9/10 수준의 전문성을 요구한다.

관련 연구

원자적 인사이트 (Zettelkasten)

Insight 1: 비-matmul 연산 지연이 Tensor Core 가동률을 좌우한다

  • 출처: FlashAttention-2, Section 3.1
  • 유형: 최적화 원리
  • 맥락: softmax의 rescale을 블록마다 적용하면 Tensor Core가 idle 상태가 되어 이론 최대의 25-40%에 머무른다. 마지막에 한 번만 rescale하면 50-73%로 도약한다.
  • 연결: FlashAttention - Fast and Memory-Efficient Exact Attention with IO-Awareness, Online softmax(Milakov & Gimelshein 2018)

Insight 2: Warp 간 공유 메모리 왕복은 split-Q로 제거한다

  • 출처: FlashAttention-2, Section 3.2
  • 유형: 병렬 알고리즘 설계
  • 맥락: FA-1의 split-K는 warp가 K/V를 분할해 Q에 대해 partial 결과를 공유 메모리로 reduce해야 했다. Q를 분할하면 warp 간 통신이 사라지고 warp 내부 accumulate만 남아 레이턴시가 감소한다.
  • 연결: CUTLASS warp-level GEMM, GQA - Training Generalized Multi-Query Transformer Models

Insight 3: Sequence-length 병렬화는 긴 문맥·작은 batch 체제를 구한다

핵심 용어 정리

용어정의
FlashAttentionHBM 접근을 줄이기 위한 IO-aware tiling 기반 exact attention 커널
Online softmax블록 단위로 max/sum을 점진적으로 업데이트하는 softmax 계산 방식
Log-sum-exp (LSE)log(sum(exp(x)))로, softmax 정규화 상수이자 backward 재계산에 사용
Split-Q / Split-Kwarp들이 Q 또는 K를 나누어 계산하는 병렬화 전략
MFUModel FLOPs Utilization, 실제 달성 FLOPs 대비 이론 최대 FLOPs
SM / Warp / SRAMGPU Streaming Multiprocessor, 32-thread 실행 단위, 온칩 공유 메모리
Tensor CoreNVIDIA GPU의 mixed-precision matmul 전용 하드웨어

태그

paper #2023 attention gpu-optimization flashattention transformer