시간: 14:50 - 15:30
KeyWord: JAX, TPU, XLA(compiler), MegatronLM, Maxtext

연구 소개
연사자: 류인호 | AI GDE

Jax기반 LLM 학습 프레임워크(Maxtext)
Megatron LM

TPU 특징들

  • 시스톨릭 어레이 기반 아키텍처
  • 유의한 확장성.
    • GPU끼리 하는 NVLINK와는 다르게 훨씬 더 많은 데이터 통신 가능
      • ex) 8 vs 9000 장의 GPU 카드끼리 고배속 데이터 전송 가능

JAX

  • numpy-like 하되 torch마냥 autograd를 지원하고, 대신 numpy와는 다르게 gpu지원을 잘한다는 거지.
  • parallelism을 구현할 떄 신경써야될께 줄지.
    • All-gather
    • All-reduce
  • batch처리 시 torch 보다 데이터 dim 관린 용이.

MaxText:
@deepspeed
FSDP(Fully sharded Data Parallell)