Mesh Transformer JAX is an implementation of model- and data-parallel autoregressive language models, using Haiku and the
pjit operator in JAX to distribute computation across TPUs. It is the designated successor to GPT-Neo.
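To make the pjit-based parallelism concrete, here is a minimal sketch (not code from this repo) of sharding a single dense layer across a device mesh; the `dp`/`mp` axis names, toy shapes, and mesh layout are illustrative assumptions:

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec as P

# Lay the available devices out as a 2D mesh: a data-parallel axis
# ("dp") and a model-parallel axis ("mp"). These names are illustrative.
devices = mesh_utils.create_device_mesh((1, jax.device_count()))
mesh = Mesh(devices, axis_names=("dp", "mp"))

# Shard activations along the batch ("dp") axis and the weight matrix
# along its output-feature ("mp") axis; pjit inserts the collectives
# needed to compute the matmul across devices.
layer = pjit(
    lambda x, w: jnp.dot(x, w),
    in_shardings=(P("dp", None), P(None, "mp")),
    out_shardings=P("dp", "mp"),
)

x = jnp.ones((8, 512))      # toy activations: [batch, d_model]
w = jnp.ones((512, 2048))   # toy weights: [d_model, d_ff]

with mesh:
    y = layer(x, w)         # output sharded [batch over "dp", d_ff over "mp"]
```

Sharding the weight matrix over the model-parallel axis is what lets a model larger than a single device's memory be split across the mesh, while the batch axis scales throughput with more devices.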
Mesh Transformer JAX was used to train GPT-J-6B, a six billion parameter autoregressive language model, on the Pile dataset.
For more information on Mesh Transformer JAX and GPT-J-6B, see the GPT-J-6B announcement blog post by Aran Komatsuzaki.