PaLM: Scaling Language Modeling with Pathways

Digest: PaLM(Pathways Language Model)은 Google이 Pathways 시스템을 활용하여 6144 TPU v4 칩에서 학습한 540B 파라미터 Dense Transformer이다. 780B 토큰으로 학습되었으며, 29개 NLP 벤치마크 중 28개에서 SOTA를 달성했다. 핵심 기여는 (1) BIG-Bench에서 인간 평균 성능 초과, (2) 5-shot MMLU 69.3으로 당시 SOTA, (3) 스케일링에 따른 **chain-of-thought 추론 능력의 불연속적 출현(emergent abilities)**을 보인 점이다. 아키텍처적으로 SwiGLU, Multi-Query Attention, RoPE(부분), 병렬 Attention-FFN 구조를 채택했다.


아키텍처 상세

모델 스펙

모델ParamsLayersHeadsKV Headsd_modelFFN DimContext
PaLM-8B8B321616 (MHA)4096163842048
PaLM-62B62B643232 (MHA)8192327682048
PaLM-540B540B1184848 (MHA)18432737282048

아키텍처 핵심 구성요소

구성요소설명
Parallel Attention + FFNAttention과 FFN을 병렬로 실행: → 학습 속도 ~15% 향상
Multi-Query AttentionKV 헤드 1개를 모든 Q 헤드가 공유 (추론 시 메모리 절약) — 8B/62B에서 사용
SwiGLUFFN 활성 함수
RoPE회전 위치 임베딩
No BiasAttention, FFN에 bias 제거
SentencePiece256,000 vocab
NormalizationPre-LayerNorm
graph TD
    A["입력 x"] --> B["LayerNorm"]
    B --> C["Multi-Head Attention<br/>(MQA for 8B/62B)<br/>+ RoPE"]
    B --> D["SwiGLU FFN"]
    C --> E["x + Attn(Norm(x)) + FFN(Norm(x))"]
    D --> E

    style E fill:#e8f5e9

병렬 Attention+FFN: 순차적(Attn→FFN) 대신 병렬 실행으로 TPU 활용률 극대화


사전 학습

학습 데이터

데이터셋비율설명
Social media conversations50%필터링된 대화 데이터
Filtered webpages27%웹 크롤링 + 품질 필터
Books13%영어 도서
Wikipedia4%다국어
Code5%GitHub 코드
News1%뉴스 기사
합계100%780B 토큰 (영어 중심)

학습 하이퍼파라미터

항목PaLM-540B
OptimizerAdafactor (β₂ decay)
Learning RatePeak 1×10⁻² (540B), inverse sqrt schedule
Warmup10,000 steps
Weight Decay— (Adafactor 내장)
Batch Size2048 sequences × 2048 tokens = 4M tokens → 점진적 증가
Dropout0 (없음)
HardwareTPU v4 6144 chips (2 pods)
Pathways다중 TPU pod 간 분산 학습 시스템
학습 시간~1200 TPU v4 core-days

벤치마크 비교

주요 벤치마크

벤치마크PaLM-540BPaLM-62BGPT-3 (175B)Chinchilla (70B)Gopher (280B)
MMLU (5-shot)69.353.743.967.660.0
HellaSwag83.479.778.980.879.2
WinoGrande77.072.470.273.770.1
TriviaQA81.472.472.3
NaturalQuestions29.321.516.6
BIG-Bench Avg대부분 SOTA

BIG-Bench 주요 결과

  • BIG-Bench 전체: 인간 평균 성능 초과 (58개 태스크 중)
  • Emergent abilities: 8B→62B에서는 나타나지 않던 능력이 540B에서 불연속적으로 출현
    • 예: 3-digit addition, logical deduction, word unscrambling

Chain-of-Thought 추론

벤치마크PaLM-540B (standard)PaLM-540B (CoT)GPT-3 (CoT)
GSM8K17.9 (8-shot)56.946.9
MATH8.8
MGSM (다국어 수학)78.5

동시대 비교 매트릭스

특성PaLM (2022.04)Chinchilla (2022.04)GPT-3 (2020.05)Gopher (2021.12)
파라미터540B70B175B280B
학습 토큰780B1.4T300B300B
학습 시스템Pathways (6144 TPU)
MMLU69.367.643.960.0
AttentionMHA/MQAMHAMHA
FFNSwiGLUGELUGELU
병렬 Attn+FFN
핵심 기여스케일+Emergent abilities스케일링 법칙퓨샷안전성/분석

한계

  • Chinchilla-suboptimal: 540B에 780B 토큰 → 스케일링 법칙상 under-trained
  • Context 2048: 짧은 컨텍스트
  • 비공개: 가중치 미공개
  • 탄소 배출: 6144 TPU 학습의 환경 비용