시간: 14:50 - 15:30
KeyWord: JAX, TPU, XLA(compiler), MegatronLM, Maxtext
연구 소개
연사자: 류인호 | AI GDE
Jax기반 LLM 학습 프레임워크(Maxtext)
Megatron LM
TPU 특징들
- 시스톨릭 어레이 기반 아키텍처
- 유의한 확장성.
- GPU끼리 하는 NVLINK와는 다르게 훨씬 더 많은 데이터 통신 가능
- ex) 8 vs 9000 장의 GPU 카드끼리 고배속 데이터 전송 가능
- GPU끼리 하는 NVLINK와는 다르게 훨씬 더 많은 데이터 통신 가능
JAX
- numpy-like 하되 torch마냥 autograd를 지원하고, 대신 numpy와는 다르게 gpu지원을 잘한다는 거지.
- parallelism을 구현할 떄 신경써야될께 줄지.
- All-gather
- All-reduce
- batch처리 시 torch 보다 데이터 dim 관린 용이.
MaxText:
@deepspeed
FSDP(Fully sharded Data Parallell)