This is a zero-to-one guide on scaling modern transformers with n-dimensional parallelism in JAX. The JAXformer blog is a from-scratch guide covering distributed data processing, FSDP, pipeline parallelism, tensor parallelism, weight sharding, activation sharding, MoE scaling, and much more. It aims to bridge the gap between theory and end-to-end implementation by demonstrating how to scale a modern language model.
The model built throughout the blog is defined in model.py, and the main training script is in main.py. utils.py and dataset.py contain the dataclasses and dataset-processing implementations. debug_tpu.sh launches a tmux session with 8 panes, SSHing into 8 nodes at once and running the command stored in the command variable. launcher.sh SSHes headlessly into each node and executes run.sh, creating tmux sessions inside the SSH connection so that runs continue even if the connection is broken. setup_tpu.sh sets up all the dependencies on the TPU. The data directory contains all the code relevant to tokenization.
Results: a 1B-parameter model (300M active) trained to 3.28 validation loss using 3-D sharding on a cluster of 32 TPU v4 chips (8-way FSDP, 2-way pipeline parallelism, 2-way tensor parallelism).
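For reference, here is a minimal sketch of how a 3-D device mesh like the one above could be set up in JAX. It assumes 32 available devices; the axis names and the weight shape are illustrative assumptions, not the repo's exact configuration.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Arrange 32 devices as 8-way FSDP x 2-way pipeline x 2-way tensor.
# (Assumes jax.devices() returns 32 devices, e.g. on a TPU v4 slice.)
devices = np.array(jax.devices()).reshape(8, 2, 2)
mesh = Mesh(devices, axis_names=("fsdp", "pipeline", "tensor"))

# Example: shard a weight matrix across the FSDP and tensor axes,
# replicating it along the pipeline axis.
w = jax.numpy.zeros((4096, 4096))
sharding = NamedSharding(mesh, PartitionSpec("fsdp", "tensor"))
w = jax.device_put(w, sharding)
```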
If you spot any issues or have questions, open an issue or send in a PR. You can also leave a comment on the website itself (powered by Giscus) or in the GitHub Discussions.
This guide was written by Aditya Makkar, Divya Makkar, and Chinmay Jindal. We are all undergraduate students studying Computer Science at the University of Waterloo.
The website uses a Distill-style Jekyll theme called Al-Folio. The idea for the blog and the front-end structure are inspired by Google DeepMind's How to Scale Your Model guide. Google's TPU Research Cloud (TRC) provided the compute needed. Thanks!


