FlashAttention — IO-aware 정확 어텐션 가속 알고리즘

Digest: Transformer의 self-attention은 시퀀스 길이 에 대해 시간·메모리 모두 으로 커지지만, 기존 “효율적 attention” 연구들은 FLOPs 감소에만 집중해 실제 wall-clock 속도를 크게 못 올렸다(sparse/low-rank 근사 대부분이 실측에서 표준 attention보다 느린 역설). 저자들의 핵심 통찰은 GPU에서 병목은 연산량이 아니라 HBM(High-Bandwidth Memory, 느리고 큰 글로벌 메모리)과 SRAM(작지만 빠른 on-chip 캐시) 사이의 IO라는 점이며, 따라서 알고리즘을 IO-aware하게 재설계해야 한다는 것이다. 해결책은 (1) Q,K,V를 블록 단위로 SRAM에 올려 타일링(tiling) 하고 softmax를 블록별로 누적하는 online softmax(블록마다 running max 과 running sum 을 유지하여 결과를 rescale), (2) 역전파 시 attention matrix를 저장하지 않고 만 저장해 재계산(recomputation) 으로 복원하는 두 트릭이다. 결과적으로 HBM 접근이 표준 attention의 에서 (은 SRAM 크기)로 줄어들고, BERT-large에서 MLPerf 1.1 대비 15% 가속, GPT-2에서 3× 가속, long-range arena에서 2.4× 가속, 그리고 Path-X(61.4%)/Path-256(63.1%) 에서 Transformer 최초로 chance 이상을 달성했다. 한계: 알고리즘이 CUDA로 수작업 커널 구현되어 있어 다른 가속기/컴파일러로 이식하려면 상당한 재작성이 필요하며, head별 병렬화 효율은 head 수가 적을수록 떨어진다. 열린 질문: IO-aware 원칙을 attention 외 다른 primitive(MLP, layernorm chain)까지 확장하고, 컴파일러가 자동으로 타일링을 합성할 수 있는가.


섹션별 요약

Introduction

  • Transformer의 확장 병목은 self-attention의 시간·메모리. 기존 efficient attention(Linformer, Performer, Reformer, Longformer 등)은 FLOPs는 줄였지만 실측 속도·메모리에서 표준 attention을 꾸준히 이기지 못함.
  • 문제 재정의: FLOPs가 아니라 메모리 계층 간 IO가 실제 병목이다. GPU의 HBM(수십 GB, 대역폭 ~1.5–2 TB/s)과 SRAM(수백 KB/SM, 수십 TB/s)의 속도 차가 거대하여, attention은 memory-bound.
  • 기여: (i) GPU IO를 명시적으로 고려한 exact attention 알고리즘 FlashAttention, (ii) HBM 접근에 대한 하한 증명 및 알고리즘의 근사 최적성, (iii) 블록 희소 확장(block-sparse FlashAttention)으로 근사 방법과 결합.

Background

  • GPU 메모리 계층: Register → SRAM(shared memory, on-chip) → HBM(off-chip DRAM). Attention에서 , softmax, dropout, 같은 연산 사이마다 중간 행렬을 HBM에 쓰고 읽는 다중 왕복이 발생.

📚 잠깐: SRAM vs HBM — 무엇이 다른가?

FlashAttention의 모든 설계는 이 두 메모리의 비대칭에서 출발한다. 핵심만 짚고 넘어가자.

SRAM (Static RAM, on-chip)

  • GPU die 위에 직접 박혀 있는 shared memory / L1 cache. SM(Streaming Multiprocessor)마다 별도 보유.
  • 비트당 트랜지스터 6개(6T cell) — 면적·전력 효율이 나빠 용량은 작지만, 캐패시터 refresh가 필요 없어 접근 latency가 1–2 cycle로 극히 짧다.
  • 연산 유닛 바로 옆이라 데이터 전송 거리가 짧고 대역폭이 압도적.

HBM (High-Bandwidth Memory, off-chip)

  • GPU die 옆에 3D-stacked DRAM을 놓고 silicon interposer로 광폭 버스(1024-bit+) 연결.
  • DRAM 셀(1T1C)이라 면적당 밀도가 높아 수십 GB 단위 용량이 가능. 대신 refresh·row activation 비용 때문에 latency가 수백 cycle.
  • off-chip이므로 데이터를 읽으려면 칩 경계를 넘어야 함 → 에너지 비용도 크다(대략 SRAM 접근의 수십~100배).
항목SRAM (on-chip)HBM (off-chip)비율 (HBM ÷ SRAM)
위치GPU die 내부 (per-SM)die 옆 stacked DRAM
용량 (A100 기준)약 192 KB / SM, 총 약 20 MB40–80 GB약 4,000× 큼
대역폭약 19 TB/s (aggregate on-chip)약 1.5–2.0 TB/s약 10× 느림
Latency약 1–2 cycle약 200–400 cycle약 200× 느림
에너지/byte약 수 pJ약 수백 pJ (약 50–100×)약 50–100× 비쌈
비트당 트랜지스터6T (SRAM cell)1T1C (DRAM cell)
단위 용량당 die 면적 비용매우 높음낮음
휘발성 / refresh휘발성, refresh 불필요휘발성, 주기적 refresh 필요

읽는 법: HBM은 SRAM보다 4,000× 크지만 10× 느리고 100× 비싸다. 그래서 큰 행렬은 HBM에 둘 수밖에 없지만, 연산은 SRAM에 올린 블록 안에서 끝내야 한다. FlashAttention의 tiling/recomputation은 정확히 이 제약을 푸는 알고리즘이다.

※ 수치는 NVIDIA A100 기준 대략값. H100·B100은 SRAM·HBM 모두 더 큼·빠름이지만 상대적 비대칭은 동일.

  • 표준 attention: , , . 중간 가 HBM에 머무르므로 HBM IO는 . 이 커질수록 항이 지배.
  • Kernel fusion의 한계: 프레임워크 레벨 fusion은 softmax의 전역적 normalization 때문에 불가능했음. 핵심 도전은 “softmax를 블록 단위로 쪼개면서도 수학적으로 정확한 결과를 얻는 것”.

Method

(1) Tiling: 개의 블록으로, 개의 블록으로 분할. 블록 크기는 SRAM 크기 에 맞춰 , .

(2) Online softmax (running max/sum): 전역 softmax를 계산하려면 전체 행의 max와 sum이 필요한데, 이를 블록을 누적 갱신으로 대체. 블록 의 부분 점수 에 대해 블록별 를 구하고, 직전까지의 를 다음 식으로 갱신:

이 갱신은 수학적으로 표준 softmax와 동일(exact)하며, 중간 행렬을 HBM에 쓰지 않는다.

(3) Recomputation (backward): forward에서 를 저장하는 대신 만 HBM에 저장( 메모리). backward에서 을 다시 SRAM에 불러 를 블록 단위로 재계산. FLOPs는 늘지만 HBM IO가 줄어 wall-clock 시간은 오히려 감소.

(4) IO 분석: 블록마다 블록을 SRAM에 적재하면서 바깥 루프, 안쪽 루프(또는 그 반대)로 순회. 전체 HBM 접근은

표준 attention의 보다 대개 배 작다 (예: , 100 KB이면 크게 유리). 저자들은 어떤 알고리즘도 을 깰 수 없음을 SRAM 모델에서 증명.

(5) Block-sparse FlashAttention: 미리 정한 block mask 에 대해 0인 블록을 skip. IO가 sparsity 에 비례해 줄어 . 기존 sparse attention(Longformer, BigBird 등)과 결합 가능.

Experiments

  • BERT-large (seq 512): end-to-end training에서 MLPerf 1.1 record 대비 15% 가속.
  • GPT-2 (seq 1K, small/medium): HuggingFace 대비 3×, Megatron 대비 1.7× 가속. perplexity 0.7 향상(더 긴 컨텍스트 훈련 가능해서).
  • Long-Range Arena (seq 1K–4K): 2.4× 가속, 모든 baseline 효율적 attention보다 빠름.
  • Path-X(16K)/Path-256(64K): Transformer 최초로 chance 이상(각각 61.4%, 63.1%)을 달성 — 지금까지 메모리 문제로 학습 자체가 불가능했던 길이.
  • 메모리: 시퀀스 길이에 대해 선형(그림 기준 표준 attention 대비 최대 20×).
  • 정확성 검증: exact algorithm이므로 numerical 오차 제외 표준 attention과 동일 출력.

Limitations

  • 수작업 CUDA 커널: 새로운 하드웨어(TPU, 다른 GPU 세대), 새로운 variant(예: ALiBi, relative bias) 마다 재작성 필요. 컴파일러 자동화가 미해결.
  • 병렬성: 현재 구현은 head/batch 단위 병렬에 주로 의존 → head 수가 적고 batch가 작은 inference 시 SM 활용률 저하(후속작 FlashAttention-2에서 개선).
  • 근사 비교 불가 영역: kernel-based linear attention처럼 FLOPs를 가진 방법 대비, 아주 긴 시퀀스()에서는 점근적으로 불리할 수 있음.
  • IO 하한 증명은 SRAM 모델 가정: multi-level 계층(L2 cache 등)을 완전히 반영하지 않음.

메타데이터

항목내용
제목FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
저자Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
소속Stanford (Hazy Research), University at Buffalo
연도2022 (NeurIPS)
발표arXiv:2205.14135
링크arXiv, GitHub
키워드attention, IO-aware, tiling, online softmax, recomputation, GPU kernel

왜 이 연구를 하는가?

핵심 질문

“효율적 attention들이 FLOPs를 줄였는데도 왜 실제 GPU에서 표준 attention보다 느린가 — 그리고 exact attention을 근본부터 빠르게 만들 수 있는가?”

기존 접근법의 한계

한계설명
FLOPs 중심 최적화Linformer/Performer/Reformer 등은 이론 연산량을 또는 으로 낮췄지만 memory-bound GPU에서 실측 속도는 표준보다 느린 경우가 대부분.
중간 행렬 HBM 왕복표준 구현은 , , 를 단계별로 HBM에 썼다 읽으며 메모리 트래픽 발생.
Softmax가 fusion을 막음softmax의 행 단위 전역 정규화 때문에 단순 kernel fusion으로는 를 SRAM에 가둘 수 없다고 여겨졌음.
근사 품질 저하sparse/low-rank 근사는 정확도 손실을 동반하여 long-range 태스크(Path-X 등)에서 수렴 자체가 실패.

핵심 통찰

  • Attention은 compute-bound가 아니라 memory-bound → 최적화 목표는 FLOPs가 아니라 HBM 접근 횟수.
  • Softmax의 normalization은 running max/sum만 추적하면 블록 단위로 쪼갤 수 있다(online/streaming softmax). 이는 수학적으로 exact.
  • Backward의 활성화 저장 메모리는 재계산으로 대체할 수 있고, HBM IO가 줄어드는 덕분에 재계산의 추가 FLOPs가 실측 속도에 오히려 유리.

방법 (Method)

프레임워크 개요

flowchart TB
    subgraph HBM["HBM (slow, large)"]
        Q[("Q (N×d)")]
        K[("K (N×d)")]
        V[("V (N×d)")]
        O[("O (N×d)")]
        ML[("m, ℓ (N)")]
    end
    subgraph SRAM["SRAM (fast, small M)"]
        Qi["Q_i block (Br×d)"]
        Kj["K_j block (Bc×d)"]
        Vj["V_j block (Bc×d)"]
        Sij["S_ij = Q_i K_jᵀ"]
        Pij["P̃_ij = exp(S_ij - m̃)"]
        Oi["O_i (Br×d) + running m_i, ℓ_i"]
    end

    Q -- load block --> Qi
    K -- load block --> Kj
    V -- load block --> Vj
    Qi --> Sij --> Pij
    Pij --> Oi
    Vj --> Oi
    Oi -- write once --> O
    Oi -. running stats .-> ML

    subgraph BW["Backward (recomputation)"]
        direction LR
        RQ["reload Q,K,V"] --> RS["recompute S,P in SRAM<br/>using stored m,ℓ"]
        RS --> dQKV["grad Q, K, V"]
    end

    O --> BW
    ML --> BW

핵심 구성요소

  • Outer loop over K,V blocks, inner loop over Q blocks: 각 에서 를 SRAM에 적재 → 계산 → block-local 계산 → 누적 update.
  • Running statistics: 출력 는 새 블록이 올 때마다 이전 누적분을 로 rescale하여 수치적으로 안정.
  • Causal masking: 대각 위쪽 블록을 skip하여 학습 시 약 2× 추가 속도.
  • Backward kernel: 로부터 를 블록 단위 재계산, 는 atomic add(혹은 행 순회 재구성)로 누적.

발견 (Findings)

주요 결과

Model/MethodDatasetMetricScorevs. Baseline
FlashAttention (BERT-large)WikipediaTraining time15% fasterMLPerf 1.1 record
FlashAttention (GPT-2 small/medium)OpenWebTextTraining time3× fasterHuggingFace
FlashAttention (GPT-2)OpenWebTextPerplexity−0.7 PPL동일 compute 대비
FlashAttention (Transformer)Long-Range ArenaWall-clock2.4× fasterStandard attention
FlashAttention (Transformer)Path-X (N=16K)Accuracy61.4%< chance for prior Transformers
Block-sparse FlashAttentionPath-256 (N=64K)Accuracy63.1%First Transformer > chance
Memory usageseq 1K–16KPeak memlinear in NStandard = quadratic (최대 20× 감소)

핵심 발견

  • HBM IO 가 표준의 보다 작아지는 임계에서 wall-clock 가속이 FLOPs 비율을 초과(재계산으로 FLOPs는 늘지만 시간은 준다).
  • 선형 메모리 덕에 학습 가능한 시퀀스 길이 자체가 확장 → Path-X/256 같은 초장 시퀀스에서 chance 이상이 처음 가능.
  • 희소성과 결합 시 효과가 곱연산적으로 누적 (block-sparse + tiling).

이론적 의의

IO 복잡도를 1차 시민으로

딥러닝 알고리즘 분석의 기본 단위를 FLOPs에서 memory transfers로 옮긴 선례. Tiling이라는 HPC의 오래된 기법이 Transformer 현대 워크로드에 정확히 들어맞음을 보여, 이후 FlashAttention-2/3, PagedAttention(vLLM), RingAttention 등 IO-aware kernel 연구 물결의 출발점이 됨.

Exact ≠ slow

“근사해야 빠르다”는 암묵적 가정을 반박. Exact한 결과를 유지하면서도 메모리 계층 재설계만으로 2–3× 속도와 10–20× 메모리 감소가 가능. 근사 방법의 정확도 페널티 없이 장점을 흡수.

Recomputation을 공격적으로 쓰는 설계 원칙

활성화 체크포인트가 단순히 메모리 절약 수단이 아니라 속도 최적화의 일부가 될 수 있음을 입증. 메모리-compute tradeoff가 단조롭지 않다는 관찰.


재현성 및 신뢰도 평가

항목등급비고
코드 공개flash-attention GitHub, CUDA/Triton 구현
데이터 공개OpenWebText, LRA, Path-X 등 공개 벤치마크
하이퍼파라미터블록 크기 , SRAM 가정 , 학습 세팅 전부 보고
실험 환경A100 40GB, CUDA version, mixed precision 기록
통계적 신뢰도⚠️multi-seed stddev는 일부 태스크에만 보고
종합 등급A업계 표준으로 자리잡을 만큼 재현이 충분히 이루어짐

주장별 신뢰도

#주장근거신뢰도
1HBM 접근이 로 감소알고리즘 분석 + 벤치 측정 HBM traffic 일치🟢
2BERT/GPT-2 end-to-end 2–3× 가속실측 table + 공식 MLPerf 기록 대비🟢
3Path-X/256 chance 이상 최초 달성공개된 LRA 리더보드 추종🟢
4IO 하한 최적성SRAM 모델 가정 하 증명 (모델 한정적)🟡
5block-sparse로 sparsity 비례 가속실험 데이터 기반, 일부 패턴만 검증🟡

읽기 난이도: ⭐⭐⭐

필요 배경지식: GPU 아키텍처(HBM/SRAM, warp, shared memory), softmax 수치 안정성(log-sum-exp), backprop 체크포인팅, 기본 attention. 알고리즘 자체는 단순하지만 IO 분석과 CUDA-level 관점 때문에 난이도가 높음.


관련 연구 비교 매트릭스

FlashAttentionStandard AttentionReformer (LSH)Performer (FAVOR+)Longformer (sliding+global)
핵심 접근IO-aware tiling + online softmax + recompute3단계 naive kernelLSH 버킷 내 sparserandom feature kernel 근사window + global token sparse
문제 정의exact, memory-bound 해소baselineapproximate, FLOPs 감소approximate, linear FLOPsapproximate, linear FLOPs
정확성exactexact근사 (버킷 에러)근사 (random feature 분산)근사 (mask 제약)
메모리
HBM IO
실측 속도(1K–4K)2.4× faster종종 더 느림종종 더 느림비슷하거나 느림
품질표준과 동일 (exact)baseline하락 가능하락 가능하락 가능
확장성16K–64K까지 검증2K 근방 한계이론상 길지만 실측 한계긴 시퀀스 가능4K–16K
한계CUDA 종속, 커널 재작성 비용메모리/시간 LSH 해시 품질 의존random feature 분산마스크 편향
코드 공개

관련 연구


원자적 인사이트 (Zettelkasten)

💡 딥러닝 최적화의 1차 지표는 FLOPs가 아니라 memory traffic이다

출처: FlashAttention - Fast and Memory-Efficient Exact Attention with IO-Awareness (Dao et al., 2022)
유형: 이론적

현대 가속기(A100, H100)는 compute throughput 대비 메모리 대역폭이 훨씬 느리게 성장해 왔다. attention처럼 low-arithmetic-intensity 연산은 memory-bound이고, 따라서 FLOPs를 줄여도 HBM 왕복이 남아 있으면 실측 속도가 개선되지 않는다. 최적화 지표를 FLOPs → HBM bytes로 바꾸면 전혀 다른 알고리즘 설계가 나온다.

핵심 조건/맥락: arithmetic intensity가 peak FLOP/BW 비율보다 낮은 연산에 성립. GEMM 같은 compute-bound 연산에는 덜 적용.
연결: Mamba - Linear-Time Sequence Modeling with Selective State Spaces, roofline model
활용 가능성: layernorm chain, MoE routing, KV cache paging 등 새로운 primitive에 IO-aware 재설계 적용.

💡 Softmax의 전역 normalization은 running max/sum으로 스트리밍 분해된다

출처: FlashAttention - Fast and Memory-Efficient Exact Attention with IO-Awareness (Dao et al., 2022)
유형: 방법론적

는 행 전체를 봐야 정의되지만, rescaling으로 블록 단위 누적이 수학적으로 정확히 전역 softmax와 같아진다. 이는 “global reduction”을 “streaming update”로 바꾸는 일반적 트릭이며, kernel fusion이 불가능했던 reduction 연산에 적용 가능.

핵심 조건/맥락: log-sum-exp 안정성 유지를 위해 max tracking 필수. 수치 정밀도는 FP16/BF16에서도 충분히 안정.
연결: online normalization, parallel prefix sum, block-cyclic reduction
활용 가능성: normalization layer, 대형 vocab에서의 top-k softmax, distributed softmax across GPUs.

💡 Recomputation이 항상 “느려지지만 메모리를 아끼는” 교환인 것은 아니다

출처: FlashAttention - Fast and Memory-Efficient Exact Attention with IO-Awareness (Dao et al., 2022)
유형: 실험적

Gradient checkpointing의 고전적 해석은 “활성화 저장을 줄이는 대신 재계산 FLOPs를 추가”이지만, FlashAttention에서는 중간 를 저장하지 않아서 HBM 쓰기·읽기가 사라지고, 재계산은 SRAM에서 빠르게 끝난다. 결과: 메모리 and wall-clock . 메모리-compute tradeoff 곡선이 파레토적이 아닐 수 있다.

핵심 조건/맥락: 재계산이 on-chip에서 끝나고, 저장을 없앰으로써 제거되는 IO가 재계산 FLOPs의 시간보다 클 때 성립.
연결: activation checkpointing, recompute-vs-bandwidth tradeoff
활용 가능성: 역전파가 병목인 다른 레이어(LayerNorm 통계, rotary embedding cache) 재설계.

💡 “Exact”는 “approximate”에게 내준 자리를 되찾을 수 있다

출처: FlashAttention - Fast and Memory-Efficient Exact Attention with IO-Awareness (Dao et al., 2022)
유형: 연결

효율성 연구는 오래도록 “더 빠르려면 근사해야 한다”는 전제를 따랐지만, 병목이 알고리즘 복잡도가 아니라 시스템 구현에 있다면 정확한 알고리즘의 시스템 재설계만으로 근사를 넘어설 수 있다. FlashAttention은 sparse/low-rank attention 대부분을 exact로 이긴 첫 사례.

핵심 조건/맥락: 병목이 구현 레벨 병목(IO, 커널 호출 오버헤드)일 때. 점근 복잡도 자체가 문제이면 여전히 근사 필요.
연결: “systems research vs algorithmic research” 보완 관계
활용 가능성: 다른 “효율적 X” 문제(효율적 convolution, 효율적 MoE) 재검토.

💡 IO 하한은 SRAM 모델 파라미터 에 명시적으로 의존한다

출처: FlashAttention - Fast and Memory-Efficient Exact Attention with IO-Awareness (Dao et al., 2022)
유형: 이론적

하한은 SRAM 크기 이 커질수록 IO가 선형적으로 줄어든다는 것을 의미. 새 GPU 세대(H100의 더 큰 shared memory)는 자동으로 IO가 줄어든다는 예측과 실측이 일치. 이는 알고리즘-하드웨어 co-design의 정량적 가이드가 된다.

핵심 조건/맥락: 모델은 two-level SRAM↔HBM만 가정. L2 cache 등 중간 계층은 미반영.
연결: red-blue pebble game, external memory model
활용 가능성: 하드웨어 세대별 성능 예측, 타일 크기 자동 튜닝.


핵심 용어 정리

용어정의
HBMHigh-Bandwidth Memory. GPU의 off-chip DRAM(A100: 40–80 GB, ~1.5–2 TB/s). 크지만 상대적으로 느림.
SRAMOn-chip shared memory/L1 cache. SM당 ~100–200 KB, 수십 TB/s 수준. 작지만 매우 빠름.
IO-aware알고리즘 설계 시 FLOPs 대신 메모리 계층 간 전송량을 최소화 목표로 삼는 관점.
Tiling큰 행렬을 SRAM에 맞는 블록으로 쪼개 loop을 재구성하여 재사용성을 극대화하는 HPC 기법.
Online softmax스트리밍으로 들어오는 부분합에 대해 running max 과 running sum 을 유지해 전역 softmax를 수치안정적으로 계산.
Recomputation (checkpointing)순전파에서 중간 활성화를 저장하지 않고 역전파 시 재계산해 메모리를 절약하는 기법.
Arithmetic intensityFLOPs/byte 비율. 낮으면 memory-bound, 높으면 compute-bound.
Memory-bound / Compute-bound연산의 wall-clock이 각각 메모리 대역폭 / FLOP throughput에 의해 제한되는 상태.
Block-sparse attentionattention 행렬을 블록 단위로 마스킹하여 특정 블록만 계산. Longformer/BigBird 등에서 활용.
FlashAttention의 row/column 블록 크기. , .
각 query 행의 running max와 running exp-sum (softmax denominator). 역전파용 저장.

태그

paper #2022 attention flash-attention io-awareness gpu-kernel efficient-transformer