LLM with the JAX ecosystem from scratch - Part 3

In this last part of the LLM With The JAX Ecosystem From Scratch series, I share my experience of using Maximal Update Parameterization (μP) and the scaling laws to train LLMs. I'll first give a brief overview of μP, mostly from a practitioner's perspective, and describe how I implemented it. Then I'll explain how I used the scaling laws to figure out the largest model I can scale up to.

At every step of building the training pipeline and settling on the model architecture and hyperparameters, there are many things to explore and optimize. However, the purpose of this series is to get hands-on practice using the JAX ecosystem to train modern LLMs, and to get a good sense of the high-level landscape of research and engineering directions in academia and industry. So I won't go any deeper in any one direction than the series already does. Instead, it serves as a good starting point for those explorations.

Maximal update parameterization#

One of the most compute-intensive problems in training LLMs is finding the optimal learning rate. One typically has to sweep across a large range of values. To make things worse, the optimal learning rate does not transfer when moving to larger models, so as we scale up the model to fit a scaling law, we would have to repeat this sweep at every scale.

μP solves this problem by providing a way to scale the model size, parameter initialization standard deviations, and the learning rate using a single scale factor $m_p$, defined as the ratio of the model width over the base model width:

$$m_p = \frac{d_{\text{model}}}{d_{\text{base}}},$$

where $d_{\cdot}$ denotes the model width, namely the hidden state size of the LLM.

The $m_p$-dependent scaling of the various hyperparameters ensures that, throughout training, the norms (more precisely, the spectral norms) of the activations, weights, and gradients are independent of $m_p$, staying at $\Theta(1)$. Since the scaling is only a function of $m_p$, one no longer needs to tune the hyperparameters at every scale. Instead, the base model hyperparameters can be reused in, or transferred to, larger-scale models. The following is a comparison of the training loss vs. learning rate plots under standard parameterization and μP, from the Tensor Programs V paper.

muP vs standard parameterization training loss vs learning rate plots from the Tensor Programs V paper

As we can see, with standard parameterization the optimal learning rate shifts as the model size changes, while it stays stable with μP.

The exact derivation of the scaling functions is rather mathy and complicated. It was first derived in the Tensor Programs series of papers (starting with Tensor Programs I). A more accessible introduction is here. For obvious reasons, I'll skip the derivation in this series and just cite the scaling functions relevant to my application. Another helpful source for learning μP is this blog post from Cerebras and EleutherAI.

| Parameter | Standard Parameterization | μP |
| --- | --- | --- |
| Embedding initialization variance | $\sigma^2_{\text{base}}$ | $\sigma^2_{\text{base}}$ |
| Embedding learning rate (Adam) | $\eta_{\text{base}}$ | $\eta_{\text{base}}$ |
| Embedding activation | $xW_{\text{emb}}$ | $\alpha_{\text{input}} \cdot xW_{\text{emb}}$ |
| Hidden layer initialization variance | $\sigma^2_{\text{base}}$ | $\sigma^2_{\text{base}} / m_p$ |
| Hidden layer learning rate (Adam) | $\eta_{\text{base}}$ | $\eta_{\text{base}} / m_p$ |
| Output/logits activation | $xW_{\text{out}}$ | $\alpha_{\text{output}} \cdot xW_{\text{out}} / m_p$ |
| Attention logits | $Q^T K / \sqrt{d_{\text{head}}}$ | $Q^T K / d_{\text{head}}$ |
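The table above can be packed into a small helper. The following is a sketch assuming Adam and width-only scaling; the dictionary structure and names are illustrative, not the actual code of this series:

```python
def mup_hparams(d_model, d_base, eta_base, sigma2_base,
                alpha_input=1.0, alpha_output=1.0):
    """Scale base hyperparameters per the muP column of the table above."""
    m_p = d_model / d_base
    return {
        "embedding": {"init_var": sigma2_base, "lr": eta_base,
                      "act_scale": alpha_input},
        "hidden": {"init_var": sigma2_base / m_p, "lr": eta_base / m_p},
        "output": {"lr": eta_base / m_p, "act_scale": alpha_output / m_p},
        # note: 1/d_head instead of 1/sqrt(d_head), see next section
        "attn_logit_scale": "1/d_head",
    }
```

For example, transferring a base model tuned at $d_{\text{base}} = 256$ to $d_{\text{model}} = 1280$ divides the hidden layer learning rate and initialization variance by $m_p = 5$, while the embedding hyperparameters stay fixed.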

Attention logits normalization#

One interesting thing to note is the difference in the normalization factor of the attention logits. In SP it is $1/\sqrt{d_{\text{head}}}$, but in μP it becomes $1/d_{\text{head}}$. The short answer: $\sqrt{d_{\text{head}}}$ is derived under the assumption that the elements of the query ($Q$) and key ($K$) vectors are independent random variables, which is exactly true only at initialization. During training, weight updates cause the vectors to become correlated, and when vectors are correlated their dot product scales linearly with the dimension ($d_{\text{head}}$), not its square root.

In the case of SP, at initialization, the projection matrices $W_q$ and $W_k$ are typically initialized from a Gaussian distribution, so the resulting query and key vectors $q$ and $k$ have entries that are independent, identically distributed (i.i.d.) random variables with a mean of zero.

Let $q_m$ and $k_m$ be the $m$-th coordinates of $q$ and $k$, both with $\Theta(1)$ variance. The dot product is:

$$q \cdot k = \sum_{m=1}^{d} q_m k_m$$

Because $q_m$ and $k_m$ are independent and zero-mean:

$$\mathbb{E}[q_m k_m] = \mathbb{E}[q_m]\,\mathbb{E}[k_m] = 0$$

According to the Central Limit Theorem (CLT), the sum of $d$ independent zero-mean variables behaves like a random walk. While the expectation remains $0$, the variance grows linearly with $d$:

$$\text{Var}(q \cdot k) = \sum_{m=1}^{d} \text{Var}(q_m k_m) = \Theta(d)$$

To keep the attention logits (pre-softmax values) at a stable $\Theta(1)$ variance so the softmax doesn't immediately saturate, standard parameterization divides by the standard deviation:

$$\text{Logits}_{\text{SP}} = \frac{q \cdot k}{\sqrt{d}}$$
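This is easy to check numerically. Below is a small sketch (not actual attention code) with i.i.d. unit-variance Gaussian entries, mimicking the state at initialization:

```python
import numpy as np

rng = np.random.default_rng(0)

def dot_variance(d, trials=20000):
    # q, k with i.i.d. zero-mean, unit-variance entries, as at initialization
    q = rng.standard_normal((trials, d))
    k = rng.standard_normal((trials, d))
    return (q * k).sum(axis=1).var()

# Var(q . k) grows linearly with d, so dividing by sqrt(d) keeps it Theta(1):
# dot_variance(d) / d stays close to 1 for any d
print(dot_variance(64) / 64, dot_variance(1024) / 1024)
```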

However, as training progresses, $W_q$ and $W_k$ become more and more correlated, and so do $q$ and $k$; they eventually align along meaningful feature directions. Because of this correlation, the expected value of their coordinate-wise product is no longer zero:

$$\mathbb{E}[q_m k_m] = c \neq 0$$

Now we must apply the Law of Large Numbers (LLN) rather than the CLT. The sum of $d$ variables with a non-zero mean $c$ scales linearly with $d$:

$$q \cdot k = \sum_{m=1}^{d} q_m k_m = d \left( \frac{1}{d} \sum_{m=1}^{d} q_m k_m \right)$$

As $d \to \infty$, the sample mean converges to the expected value:

$$q \cdot k \approx d \cdot c = \Theta(d)$$

In μP, we still want to keep the logits, and hence the softmax, well-behaved, so we divide by $d$ instead of $\sqrt{d}$:

$$\text{Logits}_{\mu\text{P}} = \frac{q \cdot k}{d} = \frac{\Theta(d)}{d} = \Theta(1)$$
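The correlated case can be simulated the same way. Here is a sketch where each key coordinate is constructed to have correlation $c$ with the matching query coordinate (the value of $c$ is illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
d, trials, c = 512, 10000, 0.5  # c: per-coordinate correlation after training

q = rng.standard_normal((trials, d))
# k has unit variance and correlation c with q, coordinate-wise
k = c * q + np.sqrt(1 - c**2) * rng.standard_normal((trials, d))
dots = (q * k).sum(axis=1)

# E[q . k] = c * d, so only dividing by d (not sqrt(d)) keeps logits Theta(1):
print(dots.mean() / d)  # close to c = 0.5
```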

Learning rate for the output/logits layer#

You might also wonder why the learning rate for the output/logits layer is still scaled by $1/m_p$ when the activation is already scaled by $1/m_p$. The answer lies in the fact that we use Adam as the optimizer, whose update rule is:

$$\Delta w = -\frac{\alpha_t}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t,$$

where $\alpha_t$ is the learning rate adjusted by the optimization steps taken, $\hat{m}_t$ and $\hat{v}_t$ are the first and second moment estimates of the gradients, and $\epsilon$ is a small constant to prevent division by zero. Since $\hat{v}_t$ scales with $m_p$ as $\hat{v}_t \to \hat{v}_t / m_p^2$, and $\hat{m}_t$ scales as $\hat{m}_t \to \hat{m}_t / m_p$, the update step scales as $\Delta w \to \Delta w \cdot \frac{1}{m_p} \cdot \sqrt{m_p^2} = \Delta w$, which has no dependency on $m_p$. Therefore we need an extra $1/m_p$ factor in the learning rate itself, even though the output activation is already scaled by $1/m_p$.
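Adam's invariance to gradient rescaling can be verified numerically. The following toy sketch takes a single Adam step from zero moment state, with $\epsilon = 0$ so the invariance is exact:

```python
import numpy as np

def adam_step(g, lr=1e-3, beta1=0.9, beta2=0.999, eps=0.0):
    # one Adam step starting from zero moment estimates, with bias correction
    m_hat = ((1 - beta1) * g) / (1 - beta1)
    v_hat = ((1 - beta2) * g**2) / (1 - beta2)
    return -lr * m_hat / (np.sqrt(v_hat) + eps)

g = np.array([0.3, -1.2, 4.0])
# shrinking every gradient by 1/m_p leaves the update unchanged,
# which is why muP must scale the learning rate itself by 1/m_p
print(np.allclose(adam_step(g), adam_step(g / 5.0)))  # True
```

For this one-step case the update reduces to $-\alpha \cdot \text{sign}(g)$, which makes the gradient-scale independence obvious.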

Scaling law#

There is a lot to say about scaling laws, such as pretraining scaling, RL scaling, inference-time scaling, and agent scaling. But for the purpose of this series, I'll mention just one thing that guides how I pick the largest model size given the compute budget I have. For more details on pretraining scaling laws, check out these resources:

The part I need for this exercise is the rule of thumb connecting the amount of data needed to the number of model parameters. Various careful studies of scaling laws for modern LLMs found that, to train a model to good quality, one typically needs $X$ times as many tokens as model parameters. $X$ was about 20 until recently, when it went up significantly, to values such as 150 or 200.

The reason for those higher ratios is that inference is much more expensive than training in aggregate, so it is preferable to train a smaller model for much longer (to reach performance similar to larger models trained on less data) so that the inference-time cost (which accrues repeatedly throughout the model's deployment lifespan) is reduced, even if the training budget is not spent optimally. But for my little experiment, 20 is a good rule of thumb.

Putting it together#

So now we have two nice guidelines to help us find the optimal training recipe.

  • The scaling law tells us, given the compute budget, what size of model to train.
  • μP tells us, for that model size, what the training hyperparameters should be, such as the learning rate and parameter initialization standard deviation.

For my little exercise, the recipe can then be divided into two stages: finding the optimal hyperparameters for the base model, and scaling it up to the largest model I can train given my compute budget, which is 2 hours on 8x H100 SXM.

Base model architecture#

The first step is to find a reasonable base model size that can be trained cheaply, such as on my laptop with an RTX 5080. The reason is that we will need to sweep the μP hyperparameters, which requires tens, if not hundreds, of training runs of the base model.

The model size is relatively easy to compute given the model architecture. I built this util to compute the various numbers associated with a model and its inference and training cost, such as the number of parameters, the memory cost of the model state and optimizer state, the memory cost of activations, and the FLOPs cost of one inference or training pass.
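For the parameter count alone, a minimal version of such a util might look like the sketch below, assuming a standard decoder-only transformer with tied embeddings and ignoring biases and norm parameters (the author's actual util is linked above and covers much more):

```python
def num_params(d_model, n_layers, vocab_size, d_ff=None):
    """Rough parameter count for a decoder-only transformer."""
    # per layer: Q, K, V, O projections (4 * d^2) + two MLP matrices (2 * d * d_ff)
    d_ff = d_ff if d_ff is not None else 4 * d_model
    per_layer = 4 * d_model**2 + 2 * d_model * d_ff
    # plus one tied embedding/output matrix
    return n_layers * per_layer + vocab_size * d_model
```

With a hypothetical vocab size of 32,000 and the base model settings below (d_model = 256, 36 layers), this gives roughly 37M parameters.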

Anything about the model itself is easy to estimate (number of parameters, model state and optimizer state memory cost, etc.), but activation memory and FLOPs are harder because, at runtime, various implementation details and optimizations can change the cost footprint. I initially tried a grid search over the number of layers and the batch size to find a good base model setup. If the estimates were accurate, I could easily find an appropriate batch size and layer count combo from a plot like this:

muP base model batch size num of layers grid search

But my util wildly underestimated the memory cost, so I ended up reducing the number of layers slightly and using microbatching to fit the base model training on my laptop's 5080. The grid search at least provided a reasonable starting point. I eventually settled on the following parameters:

  • $d_{\text{base}} = 256$. This can't be too small, so that the effects of the law of large numbers kick in (for μP).
  • Number of layers = 36.
  • Batch size = 16.
  • Max context window size = 1024.

The other parameters can be found in any of the sweep runs, such as this one.

Hyperparameter sweep#

As seen in the μP section, the hyperparameters we need to tune are:

  • Base learning rate $\eta_{\text{base}}$.
  • Base standard deviation for model parameter initialization $\sigma_{\text{base}}$.
  • Multiplicative factor $\alpha_{\text{input}}$ for the embedding layer.
  • Multiplicative factor $\alpha_{\text{output}}$ for the output/logits layer.
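A sweep over these four can be organized as a simple grid. The following sketch is illustrative only; the values shown are not the actual ranges used in the sweeps:

```python
import itertools

# hypothetical sweep grid; values are placeholders, not the real sweep ranges
grid = {
    "eta_base": [0.002, 0.004, 0.008, 0.016],
    "sigma_base": [0.05, 0.1, 0.25, 0.5],
    "alpha_input": [0.8, 1.6, 3.2],
    "alpha_output": [0.8, 1.6, 3.2],
}
configs = [dict(zip(grid, vals)) for vals in itertools.product(*grid.values())]
print(len(configs))  # 4 * 4 * 3 * 3 = 144 base-model runs for a full grid
```

Even this modest grid is 144 runs, which is why the base model has to be small enough to train cheaply.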

I did three rounds of sweeps (sweep 1, 2, 3) for the base model above, and found the following hyperparameters:

  • $\eta_{\text{base}} = 0.008$.
  • $\sigma_{\text{base}} = 0.25$.
  • $\alpha_{\text{input}} = 1.6$.
  • $\alpha_{\text{output}} = 3.2$.

Scaling up#

Given all the preparation in this post and the previous ones (in particular, parallel training and all the plumbing work), all that is left is to figure out what $m_p$ we should scale up to. This is determined by the compute budget, which is 2 hours on 8x H100 SXM. Based on the H100 SXM spec, the total FLOPs I have is $1979 \times 10^{12} \times 8 \times 2 \times 3600 \times \lambda = 1.14 \times 10^{20}\,\lambda$, where $\lambda$ is the Model FLOPS Utilization (MFU).

The FLOPs of a training step can be approximated by its matrix multiplications, which gives a total training FLOPs estimate of $6 \times N_P \times N_T$, where $N_P$ is the number of model parameters and $N_T$ is the number of training tokens. With the tokens-per-parameter ratio of 20 from above, $N_T = 20 N_P$, so the FLOPs as a function of $N_P$ is $120 N_P^2$. Note how FLOPs grow quadratically with the number of parameters.

Setting the two equal, $120 N_P^2 = 1.14 \times 10^{20}\,\lambda$, gives a rough estimate of the largest model I can train. $\lambda$ can be estimated by actually running the training pipeline: from the theoretical FLOP/s on the spec we can derive a theoretical steps/sec, and $\lambda$ is then the actual steps/sec (as measured in test runs) divided by that theoretical value. In my case, it turned out to be about $1/7$. This puts my max $m_p$ at 5, which just shows how expensive this way of training LLMs is.
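The arithmetic above can be worked out in a few lines (peak bf16 throughput from the H100 spec; the MFU of 1/7 and the tokens-per-parameter ratio of 20 are the values discussed above):

```python
import math

peak = 1979e12 * 8 * 2 * 3600   # total peak FLOPs over 2 hours on 8x H100 SXM
mfu = 1 / 7                     # measured model FLOPS utilization

# training FLOPs ~ 6 * N_P * N_T with N_T = 20 * N_P, i.e. 120 * N_P^2,
# so the largest trainable model solves 120 * N_P^2 = peak * mfu
n_p = math.sqrt(peak * mfu / 120)
print(f"{n_p:.2e}")             # a few hundred million trainable parameters
```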

The following are the loss curves of my large run with this setup on 8x H100 SXM.

training loss of the largest training run

validation loss of the largest training run

Later I made some further optimizations, such as gradient checkpointing and moving to fp8 quantization. I'm sure they can improve the training pipeline's efficiency and hence the final model quality, but I'll leave that to future experiments when I have more time and hobby funds.

Closing#

Wheeew, that’s a wrap.

whew meme

I hope you enjoyed reading this series as much as I enjoyed writing it, and can take something useful away with you.

And please feel free to leave a comment below and let me know if you have any questions or suggestions! If you're interested, you can find the full code for this series in my repo: djwenren/llm-with-jax-practice.

Modern AI is fun. Let’s keep exploring!

LLM with the JAX ecosystem from scratch - Part 3
https://www.djwenren.com/posts/llm-with-the-jax-ecosystem-from-scratch-part-3/
Author
Danjie Wenren
Published at
2026-04-09
License
CC BY-NC-SA 4.0