LLM with the JAX ecosystem from scratch - Part 1

This is the first in a series of posts documenting my journey of building LLMs from scratch with the JAX ecosystem.

Why would I want to do such a thing, you might ask? After all, the vast majority of open-source model releases and libraries are based on PyTorch. Even Hugging Face's Transformers library dropped support for JAX last year (see this and this). Well, here are my main reasons:

  • The ease of setting up parallel training with JAX is amazing. Instead of writing explicit collective communications like this, I only need to specify the sharding of the arrays/tensors, and the compiler figures out the communication. Of course, JAX/XLA still leaves the flexibility to code those communications by hand.
  • The performance gains and cost reductions reported by various sources, thanks to JIT compilation (although PyTorch has torch.compile too). If I had extra hobby budget to spare, it would be interesting to benchmark this. For now, I'll take the side that JAX lets me squeeze more compute out of my hobby budget.
  • Curiosity, and just because I can.

This is also partly based on Stanford's CS 336, which I took last year (unofficially ofc). Basically, I'm trying to reproduce most of the interesting parts of the course, plus a few extras such as muP and scaling the model to 8 GPUs with experiments like FSDP + TP training.

My code is in this repo:

djwenren/llm-with-jax-practice

Most of the implementation is done as of March 2026. I have yet to do proper tuning (of, e.g., the architecture and training hyperparameters) and to implement RLVR training pipelines similar to the one explained in my other post.

The structure of this series is roughly as follows.

  • The basic components, including layers, data loader, checkpoint management, optimizers, and the basic pre-training pipeline.
  • Sharding and parallel training.
  • Scaling up. This part includes memory and FLOPs estimation and maximal update parameterization (muP).
  • RLVR training. This part is yet to be done as of mid-March 2026. I might have to use some open-source models and potentially port them to JAX, since the low-budget LLM I pretrained from scratch might not be capable enough to demonstrate any interesting post-training behavior.

Layers

For the layers and the framework in general, I used Flax NNX. I very briefly tried an earlier version of Flax a few years back. At the time, it felt very non-Pythonic, and from that perspective PyTorch was clearly the better framework. But with the latest Flax NNX, the framework has matured significantly, and implementing neural net layers feels pretty natural.

The state (parameters, generic variables, native Python numbers, etc.) now lives inside the modules, instead of users having to manage it explicitly. From JAX's point of view, a module is basically a computation with its class members (the state) captured. Layer implementations look much like their PyTorch counterparts and read naturally. See here for more details.

One interesting thing to note about JAX is the use of nnx.vmap and jax.lax.scan for defining and calling the stack of transformer blocks. A naive implementation wraps the blocks in an nnx.List and calls them in a for loop, similar to how one would in PyTorch:

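```python
import jax.numpy as jnp
from flax import nnx
from jaxtyping import Float, Int

# `L` (the layers module) and TransformerConfig are defined elsewhere in the repo.


class TransformerLm(nnx.Module):
    """Transformer language model."""

    def __init__(
        self,
        config: TransformerConfig,
        rngs: nnx.Rngs,
        dtype: jnp.dtype = jnp.float32,
    ):
        # ... (self.rope and the other submodules are created here; elided)
        # One independent module per layer, kept in a Python-style list.
        self.transformer_blocks = nnx.List(
            [
                L.TransformerBlock(
                    d_model=config.d_model,
                    num_heads=config.num_heads,
                    d_ff=self._get_d_ff(
                        d_model=config.d_model,
                        d_ff_to_d_model=config.d_ff_to_d_model,
                        d_ff=config.d_ff,
                    ),
                    rngs=rngs,
                    rope=self.rope,
                    dtype=dtype,
                )
                for _ in range(config.num_layers)
            ]
        )

    def __call__(
        self, input_tokens: Int[jnp.ndarray, "... seq_len"]
    ) -> Float[jnp.ndarray, "... seq_len vocab_size"]:
        # ... (token embedding elided; it produces `activation` and `token_positions`)
        for transformer_block in self.transformer_blocks:
            activation = transformer_block(
                in_features=activation,
                token_positions=token_positions,
            )
        # ... (output projection to vocab logits elided)
```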

The sharp bit (in the same spirit as the term is used in this guide) is that, under jit, the Python for loop is unrolled at trace time: the compiled program contains a separate copy of the block for every layer. This often leads to long compile times and a very large compiled program, with very little actual efficiency gain. One way to fix this is to create the stack of blocks with nnx.vmap and call them with jax.lax.scan:

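```python
import jax
import jax.numpy as jnp
from flax import nnx
from jaxtyping import Float, Int

# `L` (the layers module), TransformerConfig, TransformerLmSharding and
# _get_d_ff are defined elsewhere in the repo.


class TransformerLm(nnx.Module):
    """Transformer language model."""

    def __init__(
        self,
        config: TransformerConfig,
        rngs: nnx.Rngs,
        sharding: TransformerLmSharding,
        dtype: jnp.dtype = jnp.float32,
    ):
        # nnx.vmap builds all num_layers blocks at once: every parameter gets a
        # new leading layer axis. PARTITION_NAME: None leaves that axis unsharded.
        @nnx.vmap(transform_metadata={nnx.PARTITION_NAME: None}, in_axes=(0,))
        def _create_transformer_block(rngs: nnx.Rngs) -> L.TransformerBlock:
            return L.TransformerBlock(
                d_model=config.d_model,
                num_heads=config.num_heads,
                d_ff=_get_d_ff(
                    d_model=config.d_model,
                    d_ff_to_d_model=config.d_ff_to_d_model,
                    d_ff=config.d_ff,
                ),
                rngs=rngs,
                dtype=dtype,
                sharding=sharding.transformer_blocks,
                use_mu_p=False,
                attn_std=None,
                ffn_std=None,
            )

        # One RNG stream per layer, so each block gets its own initialization.
        self.transformer_blocks = _create_transformer_block(
            rngs.fork(split=config.num_layers)
        )
        # ... (self.rope and the other submodules elided)

    def __call__(
        self, input_tokens: Int[jnp.ndarray, "... seq_len"]
    ) -> Float[jnp.ndarray, "... seq_len vocab_size"]:
        # ... (token embedding elided; it produces `activation` and `token_positions`)
        def scan_over_transformer_blocks(activation, transformer_block):
            # `transformer_block` is one layer's slice of the stacked parameters.
            return (
                transformer_block(
                    in_features=activation,
                    token_positions=token_positions,
                    rope=self.rope,
                ),
                None,
            )

        # The scan body is traced once and executed num_layers times.
        activation, _ = jax.lax.scan(
            scan_over_transformer_blocks, activation, self.transformer_blocks
        )
        # ... (output projection to vocab logits elided)
```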

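Because the body of jax.lax.scan is traced only once, the block is compiled once regardless of the number of layers, so the loop is no longer unrolled. To tune runtime efficiency further, jax.lax.scan also has an unroll parameter that unrolls a chosen number of iterations.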
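To see the pattern without Flax, here is a minimal pure-JAX sketch. It assumes a toy residual block (`init_block` and `block_apply` are hypothetical names, not the repo's TransformerBlock): `jax.vmap` over one RNG key per layer produces stacked weights with a leading layer axis, much like nnx.vmap does for the blocks above, and `jax.lax.scan` applies the block once per layer slice. The Python-loop version is included only to check that both give the same result.

```python
import jax
import jax.numpy as jnp

D_MODEL, NUM_LAYERS = 8, 4


def init_block(key: jax.Array) -> dict[str, jax.Array]:
    """Hypothetical toy block: a single residual dense layer."""
    return {
        "w": jax.random.normal(key, (D_MODEL, D_MODEL)) / jnp.sqrt(D_MODEL),
        "b": jnp.zeros((D_MODEL,)),
    }


def block_apply(params: dict[str, jax.Array], x: jax.Array) -> jax.Array:
    return x + jnp.tanh(x @ params["w"] + params["b"])


# vmap over one key per layer: every leaf gets a leading (NUM_LAYERS, ...) axis.
stacked_params = jax.vmap(init_block)(jax.random.split(jax.random.key(0), NUM_LAYERS))


@jax.jit
def forward_scan(stacked, x):
    def step(carry, layer_params):
        # layer_params is one layer's slice of the stacked weights.
        return block_apply(layer_params, carry), None

    out, _ = jax.lax.scan(step, x, stacked)
    return out


@jax.jit
def forward_loop(stacked, x):
    # Unrolled at trace time: the compiled program holds NUM_LAYERS copies.
    for i in range(NUM_LAYERS):
        layer_params = jax.tree.map(lambda p: p[i], stacked)
        x = block_apply(layer_params, x)
    return x


x = jnp.ones((2, D_MODEL))
y_scan = forward_scan(stacked_params, x)
y_loop = forward_loop(stacked_params, x)
```

Passing `unroll=2` to `jax.lax.scan` would place two layers in each loop iteration, trading a somewhat larger program for fewer loop iterations.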

Data loader

For loading the training and validation datasets, I used the Grain library. What I like about it:

  • Natural integration with the rest of the JAX ecosystem.
  • Support for checkpointing, so training can resume from the same position in the data stream.
  • Support for multithreaded prefetching, which makes it more efficient.

The datasets are still the TinyStories and OpenWebText datasets used in CS 336. I didn't re-implement the BPE tokenizer; instead, I used the tokens produced there directly, stored as NumPy array dumps. So I only needed to write a custom Grain data source and configure a dataset batch iterator.
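A Grain random-access data source only needs `__len__` and `__getitem__`. The following is a minimal sketch, assuming the token dump is a 1-D NumPy array of token IDs and that samples are non-overlapping windows (both assumptions; CS 336's reference loader samples random offsets instead). `TokenWindowDataSource` is a hypothetical name, not the class from the repo:

```python
import numpy as np


class TokenWindowDataSource:
    """Random-access source of (input, target) windows over a 1-D token array.

    Implements the `__len__` / `__getitem__` protocol Grain expects from a
    random-access data source. Deterministic indexing keeps the data order
    reproducible, which is what makes iterator checkpointing meaningful.
    """

    def __init__(self, tokens: np.ndarray, context_length: int):
        if tokens.ndim != 1:
            raise ValueError("expected a 1-D array of token IDs")
        self._tokens = tokens
        self._context_length = context_length

    def __len__(self) -> int:
        # Each sample needs context_length + 1 tokens (targets are shifted by one).
        return (len(self._tokens) - 1) // self._context_length

    def __getitem__(self, index: int) -> tuple[np.ndarray, np.ndarray]:
        if not 0 <= index < len(self):
            raise IndexError(index)
        start = index * self._context_length
        window = np.asarray(
            self._tokens[start : start + self._context_length + 1], dtype=np.int32
        )
        # Input drops the last token; target is the same window shifted by one.
        return window[:-1], window[1:]
```

In practice the array could come from `np.load(path, mmap_mode="r")`, so the dump is paged in lazily instead of being read into RAM, and the source would then be wrapped in a Grain dataset (for example via `grain.MapDataset.source(...)` followed by shuffling and batching; check the Grain docs for the exact API of your version).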

One difference from data loading in the PyTorch ecosystem is that there doesn't seem to be a need for calling .pin_memory().to(device, non_blocking=True), at least according to my conversation with Gemini. This is because the JAX runtime already handles this memory management: when a NumPy array is passed to JAX (either via jax.device_put or into a jit-compiled function), the runtime handles the transfer efficiently and may pin memory internally.
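If you want to overlap host-to-device copies with compute explicitly, a small prefetch helper is enough, because `jax.device_put` dispatches the transfer asynchronously and returns immediately. This is a minimal sketch; `prefetch_to_device` is a hypothetical helper (not from the repo or from Grain) that keeps `size - 1` batches in flight ahead of the consumer:

```python
import collections
from typing import Any, Iterable, Iterator

import jax
import numpy as np


def prefetch_to_device(batches: Iterable[Any], size: int = 2) -> Iterator[Any]:
    """Yields device-resident batches while the next ones are being copied."""
    queue = collections.deque()
    for batch in batches:
        # Asynchronous host-to-device copy; works on any pytree of arrays.
        queue.append(jax.device_put(batch))
        if len(queue) == size:
            yield queue.popleft()
    # Drain whatever is still queued once the source is exhausted.
    while queue:
        yield queue.popleft()


# Usage with (input, target) pairs of NumPy arrays, e.g. from a Grain iterator.
host_batches = [(np.full((2, 4), i, np.int32), np.full((2, 4), i + 1, np.int32)) for i in range(3)]
device_batches = list(prefetch_to_device(host_batches, size=2))
```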

Checkpoint management

For checkpoint management, I used the Orbax library for similar reasons as above. For my use case, I need to save the following in the checkpoints:

  • Model state.
  • Optimizer state.
  • Metadata, such as train config and model config, so that the same configs can be loaded from checkpoint too.

Later, when using muP, there are two optimizers (one for the embedding, one for the other model parts), but the overall structure of the checkpoint manager remains the same.

The critical requirement when training large models is that, when restoring from a checkpoint, we shouldn't first materialize a placeholder model and optimizer and then immediately overwrite their parameters with the checkpoint's. Instead, we can pass in an abstract model returned from nnx.eval_shape, such as the following:

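```python
import jax
from flax import nnx

# `transformer`, `model_config` and `sharding` come from the surrounding code.
# nnx.eval_shape traces the constructor without allocating device memory.
abstract_model = nnx.eval_shape(
    lambda: transformer.TransformerLm(
        config=model_config, rngs=nnx.Rngs(jax.random.key(42)), sharding=sharding
    )
)
```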

This way, every parameter in the model is a jax.ShapeDtypeStruct rather than a real array, and since I used explicit sharding, the sharding information is retained as well.
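The core-JAX counterpart is `jax.eval_shape`. A minimal sketch with a hypothetical `init_params` function: the function is only traced, never executed, so even a large embedding table costs no device memory, and every output leaf comes back as a `jax.ShapeDtypeStruct`:

```python
import jax
import jax.numpy as jnp


def init_params(key: jax.Array) -> dict[str, jax.Array]:
    """Hypothetical initializer for a deliberately large embedding table."""
    return {
        "embedding": jax.random.normal(key, (50_000, 4_096), dtype=jnp.bfloat16),
        "bias": jnp.zeros((4_096,), dtype=jnp.float32),
    }


# Traced abstractly: no 50_000 x 4_096 array is ever allocated.
abstract_params = jax.eval_shape(init_params, jax.random.key(0))
```

Each leaf carries `shape`, `dtype` and (when set) `sharding`, which is exactly what a checkpoint restore needs as a target.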

With the abstract model, we can create an abstract optimizer, and then load real parameters from the checkpoint in place of the abstract ones (of type jax.ShapeDtypeStruct). So throughout the process, we only materialize the parameters once. See the following snippet:

```python
from typing import Any

import optax
import orbax.checkpoint as ocp
from flax import nnx
from jaxtyping import PyTree

# BaseCheckpointManager and _canonicalize_sharding are defined elsewhere in the repo.


class CheckpointManager(BaseCheckpointManager):
    def restore(
        self,
        step: int,
        abstract_model: nnx.Module,
        **kwargs,
    ) -> tuple[nnx.Module, PyTree[Any], nnx.Optimizer]:
        """Restores the model, metadata, and optimizer from a checkpoint."""
        assert "tx" in kwargs, "tx must be provided"
        tx = kwargs["tx"]
        assert isinstance(
            tx, optax.GradientTransformation
        ), "tx must be an instance of optax.GradientTransformation"
        # 1. Create abstract optimizer on top of abstract model.
        # Since abstract_model contains ShapeDtypeStructs, no real arrays are allocated here.
        abstract_optimizer = nnx.Optimizer(abstract_model, tx, wrt=nnx.Param)
        # 2. Split both together to get a unified GraphDef and combined abstract state.
        # This allows us to restore both in one merge call, ensuring correct linking.
        # Path 0: optimizer state, Path 1: model state.
        opt_model_graph_def, abstract_combined_state = nnx.split(
            (abstract_optimizer, abstract_model)
        )
        abstract_combined_state = _canonicalize_sharding(abstract_combined_state)
        # 3. Restore using fixed shardings from their respective checkpoint slots.
        restored_args = self._ocp_checkpoint_manager.restore(
            step=step,
            args=ocp.args.Composite(
                model_state=ocp.args.StandardRestore(abstract_combined_state[1]),
                optimizer_state=ocp.args.StandardRestore(abstract_combined_state[0]),
                metadata=ocp.args.JsonRestore(),
            ),
        )
        # 4. Merge everything back into real objects in one go.
        # This bypasses optax.init() and prevents materializing zero-filled states.
        full_restored_state = nnx.State(
            {0: restored_args.optimizer_state, 1: restored_args.model_state}
        )
        restored_optimizer, restored_model = nnx.merge(
            opt_model_graph_def, full_restored_state
        )
        return restored_model, restored_args.metadata, restored_optimizer
```
Side note about sharding

nnx.eval_shape may also abstract away the mesh specified in the shardings of the parameters. In other words, the shardings of the parameters in the input abstract model may refer to a jax.sharding.AbstractMesh instead of a physical mesh. Orbax's checkpoint restoration can only take shardings with physical meshes (as of mid-March 2026), so I needed to canonicalize the shardings, i.e., replace the abstract mesh with the physical one.
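A minimal sketch of such a canonicalization, assuming the physical `Mesh` is passed in explicitly (`canonicalize_sharding` and its signature are hypothetical; the repo's `_canonicalize_sharding` takes only the state). It rebuilds every `NamedSharding` whose mesh is not a concrete `jax.sharding.Mesh`, keeping the same `PartitionSpec`:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def canonicalize_sharding(tree, mesh: Mesh):
    """Swaps abstract meshes for `mesh` in every ShapeDtypeStruct's NamedSharding."""

    def _fix(leaf):
        if not isinstance(leaf, jax.ShapeDtypeStruct):
            return leaf  # real arrays and non-array leaves pass through untouched
        sharding = leaf.sharding
        if isinstance(sharding, NamedSharding) and not isinstance(sharding.mesh, Mesh):
            # Same partition spec, now bound to the physical devices.
            return jax.ShapeDtypeStruct(
                leaf.shape,
                leaf.dtype,
                sharding=NamedSharding(mesh, sharding.spec),
                weak_type=leaf.weak_type,
            )
        return leaf

    # ShapeDtypeStruct is not a pytree node, so tree.map visits it as a leaf.
    return jax.tree.map(_fix, tree)


physical_mesh = Mesh(np.array(jax.devices()[:1]), ("data",))
```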

Optimizers

For optimizers, I used the Optax library, for the same reasons as above. The distinctive thing about this library is that it treats every optimizer as a gradient transformation, or a chain of gradient transformations. That is what a user needs to implement, and the library handles applying the final output of that transformation to the model parameters.

I re-implemented Adam, weight decay, and the cosine learning-rate schedule, similar to CS 336. One sharp bit is to be careful about what we mark as an nnx.data attribute and what we mark as an nnx.static attribute. The difference is basically that nnx.static attributes are treated as compile-time constants when the update method is jitted. In other words, in later calls to the jitted update method, static attributes always keep the value they had at the very first call. This is a trap for the step attribute. I naively thought in the beginning that it could simply be a native Python int, but that made it stay at 0 throughout training. A solution is to store it as a JAX array (jax.Array) and mark it with nnx.data.
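Two small sketches for this paragraph, neither taken from the repo. `cosine_lr_schedule` (a hypothetical name) is the linear-warmup plus cosine-decay schedule from CS 336, written with `jnp` so that `step` can be a traced JAX array. The rest is a plain-JAX analogue of the nnx.static trap, not nnx itself: a Python int mutated inside a jitted function only changes while the function is being traced, so every later call reuses the value baked in at trace time, whereas a step carried as a JAX array advances on every call.

```python
import jax
import jax.numpy as jnp


def cosine_lr_schedule(
    step, max_lr: float, min_lr: float, warmup_steps: int, cosine_steps: int
) -> jax.Array:
    """Linear warmup to max_lr, cosine decay to min_lr at cosine_steps, then flat."""
    step = jnp.asarray(step, dtype=jnp.float32)
    warmup_lr = max_lr * step / warmup_steps
    progress = (step - warmup_steps) / (cosine_steps - warmup_steps)
    cosine_lr = min_lr + 0.5 * (1.0 + jnp.cos(jnp.pi * progress)) * (max_lr - min_lr)
    return jnp.where(
        step < warmup_steps,
        warmup_lr,
        jnp.where(step <= cosine_steps, cosine_lr, min_lr),
    )


class StepHolder:
    """Stand-in for an object whose step counter is a plain Python int."""

    def __init__(self):
        self.step = 0


holder = StepHolder()


@jax.jit
def lr_with_python_step() -> jax.Array:
    holder.step += 1  # Python side effect: runs once, during tracing only.
    return cosine_lr_schedule(holder.step, 1.0, 0.1, 10, 100)


@jax.jit
def lr_with_array_step(step: jax.Array) -> tuple[jax.Array, jax.Array]:
    step = step + 1  # traced computation: runs on every call
    return cosine_lr_schedule(step, 1.0, 0.1, 10, 100), step
```

Calling `lr_with_python_step()` twice returns the same learning rate and leaves `holder.step` at 1, while threading the array step through `lr_with_array_step` keeps advancing the schedule. Since optax treats any callable from step count to learning rate as a schedule, a wrapper such as `lambda count: cosine_lr_schedule(count, 3e-4, 3e-5, 1_000, 20_000)` (illustrative values) could be passed to `optax.scale_by_learning_rate`.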

Pre-training pipeline

The pre-training pipeline is fairly similar to one in PyTorch, such as this one from CS 336. Here are two interesting things unique to JAX.

JIT fresh model and optimizer creation

Jitting the creation of a fresh model and optimizer makes it more efficient. For example, when we specify shardings of the model across multiple devices, without jit the model parameters are first created in full (on the host or a single device) and then sent to their corresponding devices. With jit, the compiler generates code that creates each parameter directly on its target devices. Doing this can be as simple as the following:

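```python
import jax
from flax import nnx

# `_train_config`, `transformer`, `_sharding` and `checkpoint` are repo modules.


def _get_sp_model_and_optimizer(
    train_config: _train_config.TrainConfig,
    model_config: transformer.TransformerConfig,
    sharding: _sharding.TransformerLmSharding,
    ckpt_manager: checkpoint.CheckpointManager,
) -> tuple[nnx.Module, nnx.Optimizer]:
    """Gets the model and optimizer."""
    # ... (elided: build the optax transformation `tx` from train_config and
    # look up `latest_step` through ckpt_manager)

    # Under nnx.jit, each parameter is created directly on its target
    # device(s) according to its sharding.
    @nnx.jit
    def _get_fresh_model_and_optimizer():
        model = transformer.TransformerLm(
            config=model_config, rngs=nnx.Rngs(jax.random.key(42)), sharding=sharding
        )
        optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
        return model, optimizer

    if latest_step is None:
        return _get_fresh_model_and_optimizer()
    # ... (otherwise, restore from the checkpoint with ckpt_manager.restore)
```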
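In core JAX, the same effect can be obtained by jitting the initializer with `out_shardings`. A minimal sketch, assuming a 1-D mesh over all available devices with a hypothetical axis name `data` and a toy `init_params` function. It still runs on a single device; on several devices, the compiler can generate each row-shard of `w` directly on its own device:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

num_devices = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ("data",))


def init_params(key: jax.Array) -> dict[str, jax.Array]:
    """Hypothetical toy initializer; rows of `w` are split across devices."""
    return {"w": jax.random.normal(key, (4 * num_devices, 16))}


# out_shardings tells the compiler where each output lives, so no full
# unsharded copy has to be built and scattered afterwards.
sharded_init = jax.jit(
    init_params, out_shardings={"w": NamedSharding(mesh, P("data", None))}
)
params = sharded_init(jax.random.key(0))
```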

The second unique thing is buffer donation. From the guide:

When JAX executes a computation it uses buffers on the device for all inputs and outputs. If you know that one of the inputs is not needed after the computation, and if it matches the shape and element type of one of the outputs, you can specify that you want the corresponding input buffer to be donated to hold an output.

Take the train step, for example:

```python
import jax
import jax.numpy as jnp
from flax import nnx
from jaxtyping import Float, Int

# loss_fn(model, input_seq, target_seq) is defined elsewhere in the repo.


@nnx.jit(donate_argnames=("local_model", "local_optimizer"))
def _train_step(
    local_model: nnx.Module,
    local_optimizer: nnx.Optimizer,
    input_seq: Int[jnp.ndarray, "batch_size context_length"],
    target_seq: Int[jnp.ndarray, "batch_size context_length"],
) -> tuple[Float[jnp.ndarray, ""], Float[jnp.ndarray, ""]]:
    """Trains the model for one step."""
    loss, grads = nnx.value_and_grad(loss_fn)(local_model, input_seq, target_seq)
    # Updates both the optimizer state and the model parameters.
    local_optimizer.update(local_model, grads)
    return (
        loss,
        # Compute the total L2 norm of the gradients.
        jnp.sqrt(
            jax.tree.reduce(
                lambda acc, x: acc + jnp.sum(jnp.square(x)),
                grads,
                0,
            )
        ),
    )
```

Here local_model and local_optimizer are both inputs and (updated) outputs. By donating these two arguments, we let XLA update them in place instead of keeping the old and new copies alive at the same time, which roughly halves the HBM needed for the model and optimizer state. This issue and optimization are studied and discussed in detail in this GitHub issue.
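In core JAX the same mechanism is exposed through `donate_argnums`/`donate_argnames` on `jax.jit`. A minimal sketch with a hypothetical SGD step on a toy linear-regression pytree (not the repo's train step): `params` is donated, and its shapes and dtypes match the returned parameters, so XLA may write the new values into the old buffers. The caller therefore rebinds `params` on every call and never reads the donated arrays again; on backends that honor donation, they are deleted.

```python
from functools import partial

import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1


def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


# Argument 0 (params) is donated: its buffers may be reused for new_params.
@partial(jax.jit, donate_argnums=0)
def sgd_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree.map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss


params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
x, y = jnp.ones((8, 3)), jnp.ones((8, 1))
for _ in range(50):
    params, loss = sgd_step(params, x, y)  # rebind: the old buffers are gone
```

If a backend cannot reuse a donated buffer, JAX only emits a warning and allocates a new one, so donation is safe to request even where it has no effect.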

Closing

With the above, I was able to start pre-training on one device. Here is an example training curve:

Example pre-training learning curve

In part 2, I’ll share how I implemented sharding for various parallelisms, such as Fully Sharded Data Parallel (FSDP) and/or Tensor Parallelism (TP) ✌️

https://www.djwenren.com/posts/llm-with-the-jax-ecosystem-from-scratch-part-1/
Author: Danjie Wenren
Published at: 2026-03-18
License: CC BY-NC-SA 4.0