Mesh Transformer JAX is an implementation of model- and data-parallel autoregressive language models, using Haiku and the xmap/pjit operators in JAX to distribute computation across TPUs. It is the designated successor to GPT-Neo.
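The xmap/pjit operators are experimental JAX APIs, so as a rough illustration of the data-parallel half of this idea, here is a minimal sketch using the stable `jax.pmap` operator instead (this is not code from this repository): each device computes gradients on its own batch shard, and the shard gradients are averaged across devices.

```python
import jax
import jax.numpy as jnp

# Minimal data-parallelism sketch with jax.pmap (illustrative only;
# Mesh Transformer JAX itself uses the xmap/pjit operators).
def loss(w, x):
    # Toy quadratic loss for a linear map x @ w.
    return jnp.mean((x @ w) ** 2)

# Each device computes the gradient on its batch shard, then
# lax.pmean averages the gradients across devices.
grad_fn = jax.pmap(
    lambda w, x: jax.lax.pmean(jax.grad(loss)(w, x), axis_name="batch"),
    axis_name="batch",
    in_axes=(None, 0),  # replicate w, shard x along the leading axis
)

n_dev = jax.local_device_count()
w = jnp.ones((4, 2))         # replicated parameters
x = jnp.ones((n_dev, 8, 4))  # leading axis: one shard per device
grads = grad_fn(w, x)        # shape (n_dev, 4, 2), identical across devices
```

On a single-host CPU setup this runs with one device; on a TPU pod slice the same code shards the batch across all local devices.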

Mesh Transformer JAX was used to train a six-billion-parameter language model throughout May and the first week of June 2021. Upon its release on June 8, 2021, GPT-J-6B became the highest-performing autoregressive language model freely available to the public.

For more information on Mesh Transformer JAX and GPT-J-6B, see the GPT-J-6B announcement blog post by Aran Komatsuzaki.