In this last part of the LLM With The JAX Ecosystem From Scratch series, I share my experience of using Maximal Update Parameterization (μP) and scaling laws to train LLMs. I’ll first give a brief overview of μP, mostly from a practitioner’s perspective, and describe how I implemented it. Then I’ll explain how I used scaling laws to figure out the largest model I can scale up to.
At every step of this process of building the training pipeline and figuring out the model architecture and hyperparameters, there are many things to explore and optimize. However, the purpose of this series is to get hands-on practice using the JAX ecosystem to train modern LLMs, and to get a good sense of the high-level landscape of academic and industry research and engineering directions. So I won’t go any deeper in any single direction than the series already does. Instead, this series serves as a good starting point for those explorations.
Maximal update parameterization
One of the most compute-intensive problems in training LLMs is finding the optimal learning rate. One typically has to sweep across a large range of values. To make it worse, the optimal learning rate is not transferable when moving to larger models. So as we scale up the model trying to fit a scaling law, we would have to repeat this sweep at each scale.
μP solves this problem by finding a way to scale the parameter initialization standard deviations and the learning rate together with the model size, using a single scale factor $m_d$, which is defined as the ratio of the model width over the base model width:

$$m_d = \frac{d_{\text{model}}}{d_{\text{base}}},$$

where $d_{\text{model}}$ means the model width, namely the hidden state size of the LLM.
The scaling dependencies of the various hyperparameters on $m_d$ make sure that throughout the training, the norms (more precisely, the spectral norms) of the activations, weights, and gradients are independent of $m_d$, hence staying at $O(1)$. Since the scaling is only a function of $m_d$, one no longer needs to tune the hyperparameters at every scale. Instead, the base model hyperparameters can be reused in, or transferred to, larger-scale models. The following is a comparison of the training loss vs. learning rate plots under standard parameterization (SP) and μP, from the Tensor Programs V paper.

As we can see, with standard parameterization, the optimal learning rate shifts as the model size changes, while it stays stable with μP.
The exact derivation of the scaling functions is rather mathy and complicated. It was first worked out in the Tensor Programs series of papers (starting with Tensor Programs I). A more accessible introduction is here. For obvious reasons, I’ll skip the derivation in this series, and just cite the scaling functions that are relevant for my application. Another helpful source for learning μP is this blog post from Cerebras and EleutherAI.
| Parameter | Standard Parameterization (SP) | μP |
|---|---|---|
| Embedding initialization variance | $\sigma_{\text{base}}^2$ | $\sigma_{\text{base}}^2$ |
| Embedding learning rate (Adam) | $\eta_{\text{base}}$ | $\eta_{\text{base}}$ |
| Embedding activation | $x_{\text{emb}}$ | $\alpha_{\text{emb}} \cdot x_{\text{emb}}$ |
| Hidden layer initialization variance | $\sigma_{\text{base}}^2$ | $\sigma_{\text{base}}^2 / m_d$ |
| Hidden layer learning rate (Adam) | $\eta_{\text{base}}$ | $\eta_{\text{base}} / m_d$ |
| Output/logits activation | $W_{\text{out}} x$ | $\alpha_{\text{out}} \cdot W_{\text{out}} x / m_d$ |
| Attention logits | $q^\top k / \sqrt{d_{\text{head}}}$ | $q^\top k / d_{\text{head}}$ |
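To make the table concrete, here is a small sketch of how these rules could be wired up in code. The base values `sigma_base` and `lr_base`, and the widths used in the example, are hypothetical; the point is only the $m_d$ dependence of each parameter group.

```python
import math

def mup_scales(d_model, d_base, sigma_base=0.02, lr_base=3e-3):
    """Sketch of muP scaling rules for Adam, per parameter group.

    sigma_base and lr_base are hypothetical base-model values;
    m_d is the width multiplier relative to the base model.
    """
    m_d = d_model / d_base
    return {
        # Embedding: init variance and Adam LR are independent of width.
        "embedding": {"init_std": sigma_base, "lr": lr_base},
        # Hidden layers: variance shrinks by 1/m_d (so std by 1/sqrt(m_d)),
        # and the Adam LR shrinks by 1/m_d.
        "hidden": {"init_std": sigma_base / math.sqrt(m_d),
                   "lr": lr_base / m_d},
        # Output logits: the forward activation is multiplied by 1/m_d.
        "output_multiplier": 1.0 / m_d,
        # Attention logits: divide by d_head rather than sqrt(d_head).
        "attn_scale": lambda d_head: 1.0 / d_head,
    }

# Example: a model 4x wider than the base model.
scales = mup_scales(d_model=1024, d_base=256)
```

In a real training setup these per-group values would be fed to the optimizer as parameter-group-specific learning rates and to the initializers as per-layer standard deviations.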
Attention logits normalization
One interesting thing to note is the difference in the normalization factor of the attention logits. In SP it is $1/\sqrt{d_{\text{head}}}$, but in μP it becomes $1/d_{\text{head}}$. The short answer is: $1/\sqrt{d_{\text{head}}}$ is derived under the assumption that the elements of the query ($q$) and key ($k$) vectors are independent random variables, which is only exactly true at initialization. During training, weight updates cause the vectors to become correlated. When vectors are correlated, their dot product scales linearly with the dimension ($d_{\text{head}}$), not with its square root.
In the case of SP, at initialization, the projection matrices $W_Q$ and $W_K$ are typically initialized from a Gaussian distribution, meaning the resulting query and key vectors $q$ and $k$ have entries that are independent, identically distributed (i.i.d.) random variables with a mean of zero.
Let $q_i$ and $k_i$ be the $i$-th coordinates of $q$ and $k$, both with unit variance. The dot product is:

$$q^\top k = \sum_{i=1}^{d_{\text{head}}} q_i k_i.$$

Because $q_i$ and $k_i$ are independent and zero-mean:

$$\mathbb{E}[q_i k_i] = \mathbb{E}[q_i]\,\mathbb{E}[k_i] = 0.$$

According to the Central Limit Theorem (CLT), the sum of independent zero-mean variables acts like a random walk. While the expectation remains 0, the variance grows linearly with $d_{\text{head}}$:

$$\mathrm{Var}\!\left(q^\top k\right) = \sum_{i=1}^{d_{\text{head}}} \mathrm{Var}(q_i k_i) = d_{\text{head}}, \quad \text{so} \quad q^\top k = O\!\left(\sqrt{d_{\text{head}}}\right).$$

To keep the attention logits (pre-softmax values) at a stable variance so the softmax doesn’t immediately saturate, standard parameterization divides by the standard deviation:

$$\text{logits} = \frac{q^\top k}{\sqrt{d_{\text{head}}}}.$$
However, as training progresses, $W_Q$ and $W_K$ become more and more correlated, and so do $q$ and $k$. They eventually align along meaningful feature directions. Because of this correlation, the expected value of their coordinate-wise product is no longer zero:

$$\mathbb{E}[q_i k_i] = \mu \neq 0.$$

Now we must apply the Law of Large Numbers (LLN) rather than the CLT. The sum of variables with a non-zero mean scales linearly with $d_{\text{head}}$:

$$q^\top k = \sum_{i=1}^{d_{\text{head}}} q_i k_i \approx d_{\text{head}} \cdot \mu = O(d_{\text{head}}).$$

As $d_{\text{head}} \to \infty$, the sample mean converges to the expected value:

$$\frac{1}{d_{\text{head}}} \sum_{i=1}^{d_{\text{head}}} q_i k_i \to \mu.$$

In μP, we still want to keep the logits, and hence the softmax, well-behaved, so we divide by $d_{\text{head}}$ instead of $\sqrt{d_{\text{head}}}$:

$$\text{logits} = \frac{q^\top k}{d_{\text{head}}}.$$
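The two regimes are easy to check numerically. The sketch below uses a hypothetical head dimension and treats the fully correlated case as $k = q$ (so $\mathbb{E}[q_i k_i] = 1$): the independent dot product stays on the order of $\sqrt{d_{\text{head}}}$, while the correlated one is on the order of $d_{\text{head}}$.

```python
import random

random.seed(0)

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

d_head = 4096  # hypothetical head dimension, large so the asymptotics show
q = [random.gauss(0.0, 1.0) for _ in range(d_head)]
k_indep = [random.gauss(0.0, 1.0) for _ in range(d_head)]
k_corr = q  # extreme correlated case: E[q_i * k_i] = 1

# Independent entries (initialization): CLT regime, |q.k| ~ O(sqrt(d_head)).
indep_logit = dot(q, k_indep)

# Correlated entries (after training): LLN regime, q.k ~ d_head * mu = O(d_head).
corr_logit = dot(q, k_corr)
```

Dividing `corr_logit` by $\sqrt{d_{\text{head}}}$ would still leave a logit of order $\sqrt{d_{\text{head}}}$, saturating the softmax; dividing by $d_{\text{head}}$ keeps it $O(1)$.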
Learning rate for the output/logits layer
You might also wonder why the learning rate for the output/logits layer is still scaled by $1/m_d$ when the activation is already scaled by $1/m_d$. The answer lies in the fact that we use Adam as the optimizer, whose update rule is:

$$\theta_{t+1} = \theta_t - \eta_t \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon},$$

where $\eta_t$ is the learning rate adjusted by the optimization steps taken, $\hat{m}_t$ and $\hat{v}_t$ are the (bias-corrected) first and second moment estimates of the gradients, and $\epsilon$ is a small constant to prevent division by zero. Since $\hat{m}_t$ and $\sqrt{\hat{v}_t}$ scale with $m_d$ in exactly the same way, the update step $\hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)$ scales as $O(1)$, which has no dependency on $m_d$: Adam normalizes away the gradient scale. Therefore, we need the extra $1/m_d$ factor in the learning rate itself even though the output activation is already scaled by $1/m_d$.
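Adam’s scale invariance can be verified with a toy single-step sketch (zero-initialized moments, hypothetical hyperparameter values): scaling the gradient by four orders of magnitude leaves the update essentially unchanged.

```python
import math

def adam_update(g, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, t=1):
    """One Adam step from zero-initialized moments (toy sketch)."""
    m = (1 - beta1) * g            # first moment estimate
    v = (1 - beta2) * g * g        # second moment estimate
    m_hat = m / (1 - beta1 ** t)   # bias correction
    v_hat = v / (1 - beta2 ** t)
    return lr * m_hat / (math.sqrt(v_hat) + eps)

u_small = adam_update(0.01)
u_large = adam_update(100.0)  # gradient scaled up by 10,000x
```

Both updates come out at essentially `lr`, so shrinking the gradients of the output layer (via the $1/m_d$ activation multiplier) does not shrink its Adam updates; only scaling the learning rate does.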
Scaling law
There is a lot to talk about regarding scaling laws, such as pretraining scaling, RL scaling, inference-time scaling, and agent scaling. But for the purpose of this series, I’ll simply mention the one thing that guides how I pick the largest model size given the compute budget I have. For more details on pretraining scaling laws, check out these resources:
- The original scaling law paper.
- The Chinchilla paper.
- Stanford CS 336 lectures on scaling laws, parts one and two.
The part I need for this exercise is a rule of thumb connecting the amount of training data to the number of model parameters. Various careful studies of scaling laws for modern LLMs found that, to train a model to good quality, one typically needs $r$ times as many tokens as model parameters, i.e. $D \approx r \cdot N$. The ratio $r$ was about 20 until recently, when it went up significantly, to values such as 150 or 200.
The reason for those higher ratios is that inference is much more expensive than training over a model’s lifetime, so it is preferable to train a smaller model much longer (to reach performance similar to that of larger models trained on less data), reducing the inference-time cost (which accrues repeatedly throughout the model’s deployment lifespan) even if the training budget is not spent optimally. But for my little experiment, 20 is a good rule of thumb.
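As a concrete instance of the 20:1 rule of thumb (with a hypothetical 0.5B-parameter model, and the standard $C \approx 6ND$ estimate for training FLOPs):

```python
def training_budget(n_params, tokens_per_param=20):
    """Tokens and approximate training FLOPs for a model of n_params
    parameters, using D = tokens_per_param * N and C ~ 6 * N * D."""
    n_tokens = tokens_per_param * n_params
    flops = 6 * n_params * n_tokens
    return n_tokens, flops

# Hypothetical 0.5B-parameter model.
tokens, flops = training_budget(500_000_000)
# -> 10 billion tokens, ~3e19 training FLOPs
```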
Putting it together
So now we have two guidelines at play to help us find a good training recipe.
- The scaling law tells us, given the compute budget, what size of model to train.
- μP tells us, for that model size, what the training hyperparameters should be, such as the learning rate and the parameter initialization standard deviations.
For my little exercise, the recipe then could be divided into two stages: finding the optimal hyperparameters for the base model, and scaling it up to the largest model I can train given my compute budget, which is 2 hours on 8x H100 SXM.
Base model architecture
The first step is to find a reasonable base model size so that it can be trained cheaply, for example on my laptop with a 5080. The reason for this is that we will need to sweep the μP hyperparameters, which requires tens, if not hundreds, of training runs of the base model.
The model size is relatively easy to compute given the model architecture. I built this util to compute the various numbers associated with a model and its inference and training cost, such as the number of parameters, the memory cost of the model and optimizer state, the memory cost of activations, and the FLOPs cost of one inference or training pass.
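A minimal version of such a parameter-count estimate might look like the following. It assumes a vanilla GPT-style decoder and ignores biases and layer-norm parameters (which are negligible); the widths and vocabulary size in the example are hypothetical, not the ones used in my runs.

```python
def transformer_params(d_model, n_layers, vocab_size, d_ff=None):
    """Rough parameter count for a GPT-style decoder.

    A sketch that ignores biases and layer norms, and assumes
    untied input/output embeddings and d_ff = 4 * d_model by default.
    """
    d_ff = d_ff or 4 * d_model
    attn = 4 * d_model * d_model   # Q, K, V, and output projections
    mlp = 2 * d_model * d_ff       # up- and down-projections
    embed = vocab_size * d_model   # token embedding table
    head = vocab_size * d_model    # output/logits projection
    return n_layers * (attn + mlp) + embed + head

# Hypothetical example: d_model=1024, 36 layers, 32k vocabulary.
n = transformer_params(d_model=1024, n_layers=36, vocab_size=32_000)
```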
Anything about the model itself is easy to estimate (number of parameters, model state and optimizer state memory cost, etc.), but activation memory and FLOPs are harder, because at runtime various implementation details and optimizations can change the cost footprint. I initially tried a grid search over the number of layers and the batch size to find a good base model setup. If the estimates were accurate, I could easily find an appropriate batch size and layer count combo from a plot like this:

But my util wildly underestimated the memory cost, so I ended up reducing the number of layers slightly and using microbatching to fit the base model training on my laptop’s 5080. The grid search at least provided a somewhat good starting point. I eventually settled on the following parameters:
- Model width $d_{\text{model}}$: this can’t be too small, in order for the effects of the law of large numbers to kick in (for μP).
- Number of layers = 36.
- Batch size 16.
- Max context window size 1024.
The other parameters can be found in any of the sweep runs, such as this one.
Hyperparameter sweep
As seen from the μP section, the hyperparameters we need to tune are:
- The base learning rate $\eta_{\text{base}}$.
- The base standard deviation used for parameter initialization, $\sigma_{\text{base}}$.
- The multiplicative factor $\alpha_{\text{emb}}$ for the embedding layer.
- The multiplicative factor $\alpha_{\text{out}}$ for the output/logits layer.
I did three rounds of sweeps (sweep 1, 2, 3) for the base model above, and found the following hyperparameters:
- .
- .
- .
- .
Scaling up
Given all the preparation in this post and the previous ones (in particular, parallel training and all the plumbing work), all that is left is to figure out the model size $N$ we should scale up to. This is determined by the compute budget we have, which is 2 hours on 8x H100 SXM. Based on the spec of the H100 SXM, the total compute I have is $C = 2\,\text{hours} \times 8\,\text{GPUs} \times \text{peak FLOP/s} \times \text{MFU}$, where MFU is the Model FLOPS Utilization.
The FLOPs of a training step are dominated by matrix multiplications. This gives the standard total training FLOPs estimate of $C \approx 6ND$, where $N$ is the number of model parameters and $D$ is the number of training tokens. With $D = 20N$, the FLOPs as a function of $N$ are roughly $C(N) \approx 120N^2$. Note how the FLOPs grow quadratically with the number of parameters.
Setting the two equal to each other gives a rough estimate of the largest model I can train. MFU can be estimated by actually running the training pipeline: from the theoretical FLOP/s on the spec sheet, we can derive a theoretical steps/sec, and MFU is then basically the actual steps/sec (as measured in test runs) divided by that theoretical value. In my case, the measured MFU puts my max $N$ at roughly $5 \times 10^8$ parameters; this just shows how expensive this way of training LLMs is.
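Solving $120 N^2 = C$ for $N$ can be sketched as follows. The peak throughput figure is the H100 SXM BF16 dense spec (about 989 TFLOP/s), and the MFU of 0.4 is a hypothetical value for illustration, not my measured number.

```python
import math

def max_trainable_params(hours, n_gpus, peak_flops=989e12, mfu=0.4,
                         tokens_per_param=20):
    """Largest N such that 6 * N * (tokens_per_param * N) fits the budget.

    peak_flops is the H100 SXM BF16 dense spec figure; mfu=0.4 is a
    hypothetical utilization for illustration.
    """
    total_flops = hours * 3600 * n_gpus * peak_flops * mfu
    # Solve 6 * tokens_per_param * N^2 = total_flops for N.
    return math.sqrt(total_flops / (6 * tokens_per_param))

# My budget: 2 hours on 8x H100 SXM.
n_max = max_trainable_params(hours=2, n_gpus=8)
```

Under these assumptions, `n_max` lands in the mid-hundreds-of-millions of parameters, consistent with the estimate above.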
The following are the loss curves of my large run with this setup on 8x H100 SXM.


Later I made some further optimizations, such as gradient checkpointing and moving to fp8 quantization. I’m sure they can improve the training pipeline efficiency and hence the final model quality, but I’ll leave that to future experiments when I have more time and hobby funds.
Closing
Wheeew, that’s a wrap.

I hope you enjoyed reading this series as much as I enjoyed writing it, and can take something useful away with you.
And please feel free to leave a comment below and let me know if you have any questions or suggestions! If you’re interested, you can find the full code for this series in my repo.
Modern AI is fun. Let’s keep exploring!