How the NanoGPT Speedrun WR dropped by 20% in 3 months

In early 2024, Andrej Karpathy stood up the llm.c repo to train GPT-2 (124M); reaching 3.28 cross-entropy loss took the equivalent of 45 minutes on 8xH100 GPUs. By January 2025, collaborators on modded-nanogpt had brought that time down to 3 minutes. The record sat near 3 minutes until July 2025, with a large swath of optimizations already applied: RoPE, value embeddings, reduce-scatter gradient updates, Muon, QK norm, ReLU^2, a custom FP8 head, skip connections, flex attention, short-long windows, attention window warmup, linear learning-rate cooldown, and more. Yet in the last 3 months the record has fallen by another 20%, to 2 minutes and 20 seconds.

Many of the improvements in the last 20% have not yet been published outside of the modded-nanogpt repo. This post summarizes those improvements. Not everything will generalize to larger scales, but there are some core concepts that I believe are promising. Improvements are sorted into ML and Engineering, grouped by concept, and subjectively ranked by their general applicability. Each change includes an estimated runtime impact and links to the associated PRs. The post concludes with general thoughts on the process of finding improvements in transformer architectures and training recipes.

| Cat | Rank | Description | Rough Est. Impact | PR |
| --- | --- | --- | --- | --- |
| ML | 1 | Document Alignment | 3s | 108, 118 |
| ML | 2 | Dynamic Attention Window | 4s | 118, 122, 127, 131 |
| ML | 3 | Heterogeneous Batch Sizes | 4s | 136 |
| ML | 4 | Backout | 2s | 140 |
| ML | 5 | Polar Express | 1s | 134 |
| ML | 6 | Smear Module | 1.5s | 130 |
| ML | 7 | Sparse Attention Gate | 0.5s | 117 |
| ML | 8 | More Bfloat16 | 0.5s | 125 |
| ML | 9 | Softmax Skip Gate | 1s | 125 |
| ML | 10 | Drop initial MLP Layer | 1.5s | 120 |
| ENG | 1 | Flash Attention 3 | 3s | 118 |
| ENG | 2 | Parameter reshaping for shared reduce scatter | 1.5s | 109, 132 |
| ENG | 3 | Async Data Fetch and Index | 1.5s | 127 |
| ENG | 4 | Vectorized Optimizer Step | 0.5s | 125 |
| ENG | 5 | Triton Kernel for Symmetric Matmul | 1s | 109 |
| ENG | 6 | Resize Lambda Params | 1.5s | 140 |

Latest version with all implemented code: https://github.com/KellerJordan/modded-nanogpt/blob/ba3e54f378b11af1ee33c2d518820e4532020190/train_gpt.py (updates must be found through the open PR list due to an inactive repo owner).

ML Improvements

#1: Document Alignment

Intra-document masking[1] is a common technique, used in models such as Llama 3, to prevent attention queries from attending to positions in other documents. However, masking is only half the picture. NanoGPT applies a data-processing step during training such that each GPU receives the first 2048 tokens of at least 16 unique documents per step (the 2048-token limit was tuned by Varun Srivastava; a minimal sketch of the packing appears after the list below). This approach has several benefits:

  1. Lower-variance gradients. FineWeb documents can have up to 70,000 tokens. Since each training step contains 262,144 tokens, a naïve data sampling strategy may have 1/4 of its gradient estimate for a step coming from a single, highly correlated document. This sampling approach ensures that each gradient is informed by at least 128 documents.

  2. The beginning-of-sequence (BOS) token is kept in the context window. Prior research[2] demonstrated that having the BOS token in the context window can improve performance.

  3. No mid-context learning. The model does not need to waste effort trying to learn from samples that start in the middle of a document.
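
Below is a minimal sketch of the packing idea, assuming pre-tokenized documents and a per-GPU budget of 262,144 / 8 = 32,768 tokens; the helper and variable names are illustrative, not the repo's actual loader.

import torch

MAX_DOC_LEN = 2048               # keep only the first 2048 tokens of each document
TOKENS_PER_GPU = 262_144 // 8    # per-GPU share of the 262,144-token step

def build_gpu_batch(docs):
    # docs: iterator of token lists, each beginning with a BOS token (assumed)
    tokens, cu_seqlens = [], [0]
    for doc in docs:
        doc = doc[:MAX_DOC_LEN]                      # document-aligned truncation
        if len(tokens) + len(doc) > TOKENS_PER_GPU:
            break
        tokens.extend(doc)
        cu_seqlens.append(len(tokens))               # document boundary for intra-doc masking
    # With the 2048-token cap, a full 32,768-token budget holds at least 16 documents.
    return torch.tensor(tokens), torch.tensor(cu_seqlens, dtype=torch.int32)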

#2: Dynamic Attention Window Management by Layer

NanoGPT applies a window sizing scheme across its 10 attention layers of [short, short, short, long, short, short, short, short, short, long]. The short window is initialized to 128 tokens and the long window to 384 tokens. 3 transformations occur during training:

  • At 1/3 of training: Increase from 128/384 to 384/896. Apply YaRN[3].

  • At 2/3 of training: Increase from 384/896 to 640/1408. Apply YaRN.

  • At 3/3 of training: Increase from 640/1408 to 768/2560. Apply YaRN.

Partial RoPE is applied to 50% of the head dimensions. It was observed that the long windows primarily attend to the stationary (non-rotated) dimensions and are responsible for tasks such as 'find activations that look very similar to me, regardless of their position'. These long windows showed much more flexibility with window extensions, in particular the jump from 1408 to 2560 after training is complete.
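
A minimal sketch of the resulting per-layer schedule (the stage boundaries and sizes come from the list above; the YaRN frequency rescaling itself is not shown, and the helper is illustrative rather than the repo's code):

LONG_LAYERS = {3, 9}   # layers with the long window, per the [short, ..., long] pattern above

WINDOW_SCHEDULE = [    # (fraction of training completed, (short, long)) window sizes in tokens
    (0.0, (128, 384)),
    (1/3, (384, 896)),     # + YaRN
    (2/3, (640, 1408)),    # + YaRN
    (1.0, (768, 2560)),    # + YaRN, applied once training is complete
]

def window_sizes(frac_done: float) -> list[int]:
    # Return the attention window for each of the 10 attention layers.
    short, long = next(w for f, w in reversed(WINDOW_SCHEDULE) if frac_done >= f)
    return [long if i in LONG_LAYERS else short for i in range(10)]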

#3 Heterogeneous Batch Sizes

Critical batch size theory focuses on finding the single optimal batch size for a given model and dataset. However, parameters within a model have distinct training characteristics that lead to different optimal batch sizes. NanoGPT uses gradient accumulation to update the embedding and lm_head weights only every other step, creating heterogeneous batch sizes within the same model. This means the gradients for these parameters, which span roughly 50,000 vocabulary tokens, only need to be synced across GPUs half as often, leading to a faster time per step.

# on even steps, only step Muon params
# on odd steps, step all params
if step%2==0:
    optimizer2.step()
    optimizer2.zero_grad(set_to_none=True)
else:
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)

#4 Backout: Enabling a model to back out context for predictions

In the standard transformer architecture, contributions to the residual stream have to serve two purposes at once: providing context to downstream layers, and adding to the final prediction. However, information may be valuable as downstream context without mapping directly to the lm_head vector of the token to be predicted. To let the model modulate these two functions independently, it is given the ability to back out prior contributions just before making a prediction. The model learns to back out 50% of the contributions from the first 2/3 of its layers. The core of this idea is from Sebastian Müller.

# Back out a learned fraction of the residual stream as it stood after layer 8,
# just before the prediction head.
x -= backout_lambda * residual_stream_after_layer8
x = norm(x)
logits = self.lm_head(x)

#5 Polar Express

This is a more accurate method for computing the orthogonalization step in Muon compared to Newton Schulz. See the original paper[4] for details. Incorporation into the repo was performed by Varun Srivastava.
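
For intuition, Newton-Schulz and Polar Express share the same structure: repeatedly apply an odd matrix polynomial X ← aX + b(XXᵀ)X + c(XXᵀ)²X to push the singular values of the update toward 1; Polar Express chooses different, optimized coefficients for each iteration. The sketch below shows only that structure, with placeholder coefficients rather than the values derived in the paper.

import torch

# Placeholder per-iteration coefficients for illustration only; the paper derives
# optimized (a, b, c) values for each step.
ILLUSTRATIVE_COEFFS = [(3.0, -3.0, 1.0)] * 5

def polar_express_sketch(G: torch.Tensor) -> torch.Tensor:
    # Approximate the nearest (semi-)orthogonal matrix to G, i.e. U V^T from its SVD.
    X = G.bfloat16()
    transposed = X.size(-2) > X.size(-1)
    if transposed:
        X = X.mT                                             # work in the wide orientation
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)      # bound the singular values
    for a, b, c in ILLUSTRATIVE_COEFFS:
        A = X @ X.mT                                         # symmetric Gram matrix
        X = a * X + (b * A + c * A @ A) @ X                  # odd polynomial update
    return X.mT if transposed else X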

#6 Smear Module

It is common for several heads in a transformer to devolve into a "previous-token head" that always attends to the preceding position. However, attention is a computationally inefficient way to attend to the previous position. NanoGPT introduces the Smear Module, which lets each token directly peer back one position and smear the prior token forward. The contribution is gated by a sigmoid gate fed by the first 12 dimensions of the token embedding. On average, the model learns that (token + 0.07 * prior_token) is a more useful representation than (token) alone.

x = self.embed(input_seq)
# Learned scalar weighting for the smear contribution
smear_lambda = self.scalars[5 * len(self.blocks)]
# Sigmoid gate fed by the first 12 embedding dimensions of each token
smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)]))
# Smear each token's predecessor forward by the gated amount
x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]])
x = x0 = norm(x[None])

#7 Sparse Attention Gate

Attention does not have a built-in way to perform a no-op. Many mechanisms to alleviate this have been proposed, but they are often either not directly compatible with Flash Attention 3 or incur high runtime overhead (in the context of this speedrun). NanoGPT uses a sigmoid gate for each attention head to modulate the attention output. The gate is fed by only the first 12 dimensions of the residual stream, enabling fast updates while significantly reducing the BOS-token attention-sink behavior.

# init
self.attn_gate = CastedLinear(12, num_heads)

# perform attn out projection
y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1)

#8 More Bfloat16

The cross-entropy loss calculation is left in bfloat16 instead of being cast up to float32. The language model head and the rotary cos and sin terms are stored in bfloat16 instead of float32. This gives a faster forward pass with a minimal increase in loss. Adam gradient calculations are also left in bfloat16. Parameter storage for the MLP and attention matrices remains in float32 due to higher loss sensitivity. This change was implemented by the Hive AI.
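
A minimal sketch of the kind of change involved, with illustrative attribute names (the real script also keeps its custom FP8 head machinery, which is not shown):

import torch
import torch.nn.functional as F

def cast_head_and_rotary_to_bf16(model: torch.nn.Module) -> None:
    # Store the lm_head weights and the rotary cos/sin buffers in bfloat16 instead of float32.
    model.lm_head.weight.data = model.lm_head.weight.data.to(torch.bfloat16)
    model.rotary_cos = model.rotary_cos.to(torch.bfloat16)
    model.rotary_sin = model.rotary_sin.to(torch.bfloat16)

def bf16_cross_entropy(lm_head: torch.nn.Module, x: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Compute the loss directly on bfloat16 logits -- no .float() upcast beforehand.
    logits = lm_head(x)
    return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))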

#9 Softmax Skip Gate

Skip connections had previously been set up with weights initialized 1:1 with the main pathway. This is replaced with a sigmoid gate that is initialized to produce 0.18. The smaller initialization for skip connections gives the model worse initial training but better final performance, by encouraging the formation of deeper pathways. This change was implemented by the Hive AI.
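
A minimal sketch of the initialization, assuming one learned gate logit per skip connection (names are illustrative): for the gate to start at 0.18, the logit must be log(0.18 / 0.82) ≈ -1.52.

import math
import torch

num_skips = 3                                          # illustrative count of skip connections
target = 0.18
init_logit = math.log(target / (1 - target))           # ≈ -1.516, so sigmoid(init_logit) = 0.18
skip_gate_logits = torch.nn.Parameter(torch.full((num_skips,), init_logit))

# In the forward pass, the skip contribution is weighted by the gate instead of 1:1:
#   x = x + torch.sigmoid(skip_gate_logits[i]) * skip_value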

#10 Drop MLP Layer

EmelyanenkoK dropped the initial MLP layer and increased the step count to partially compensate, after running an ablation that showed it had the least impact of all MLP layers in the model.

Engineering Improvements

#1 Flash Attention 3

To make Flash Attention 3 compatible with torch.compile, an unmerged version is used. Varun Srivastava streamlined this process with the Hugging Face kernels library and used flash_attn_varlen_func() to maintain the intra-document masking that was previously applied via flex attention.

# Packed variable-length attention: cu_seqlens mark document boundaries,
# preserving the intra-document masking previously done with flex attention.
y = flash_attn_interface.flash_attn_varlen_func(
    q[0], k[0], v[0],
    cu_seqlens_q=seqlens,
    cu_seqlens_k=seqlens,
    max_seqlen_q=max_len,
    max_seqlen_k=max_len,
    causal=True,
    softmax_scale=attn_scale,
    window_size=(bm_size, 0),  # sliding window: each query looks at most bm_size tokens back
)

#2 Parameter reshaping for shared reduce scatter

The optimizer for the MLP and attention parameters is Muon, which requires the entire gradient of a matrix to be collected onto a single GPU for an accurate orthogonalization update. After each training step, every GPU holds its own gradients, which need to be gathered in one place. Torch has a distributed API call (reduce-scatter) that takes 8 parameters, picks a GPU to own each, and has the other 7 GPUs send their copies to the designated owner. However, the call only works if every GPU receives a parameter of equal size, so if there aren't exactly 8 equally sized parameters, extra padding tensors get created.

To minimize padding, all attention and MLP weights are reshaped to the same dimensions, [d_model, 4*d_model] (Bryon Xu implemented this for the MLP). On the forward pass, the MLP out-projection is transposed back to shape (4*d_model, d_model), and the attention matrix is reshaped to (4, d_model, d_model) for the Q, K, V, and output projections. The attention parameters are also reshaped prior to the orthogonalization update, as shown below.

# Compute zeropower for the entire chunk in a single, batched call.
original_shape = batched_update_grads.shape
# Reshape attn params from [hdim, dim*4] to [4, hdim, dim] to apply the
# orthogonalization independently to Q, K, V, O.
module_idx = start_idx if start_idx < len(params) else 0
if getattr(params[module_idx], 'module', 'none') == 'attn':
    batch = 4 * original_shape[0]
    d1 = original_shape[1]
    d2 = original_shape[2] // 4
    batched = batched_update_grads.view(batch, d1, d2)
    v_chunk = polar_express(batched)
    v_chunk = v_chunk.view(original_shape)
else:
    v_chunk = polar_express(batched_update_grads)

#3 Async Data Fetch and Index

Start prefetching and indexing the next shard immediately. Since this happens on the CPU, there is ample time to do it during the GPU-heavy workload, and GPU activity should never be bottlenecked on CPU data indexing. The first shard is only partially indexed before training starts on it; a parallel thread finishes the indexing, and its result is picked up on the 5th step.
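
A minimal sketch of the idea using a background thread; the real loader's shard format and indexing are more involved, and the helper names here are illustrative assumptions.

from concurrent.futures import ThreadPoolExecutor

def load_and_index_shard(path: str):
    # Illustrative stand-in for reading a token shard and building its document index.
    tokens = read_tokens(path)            # assumed helper: returns a 1-D token array
    doc_starts = build_doc_index(tokens)  # assumed helper: positions of BOS tokens
    return tokens, doc_starts

executor = ThreadPoolExecutor(max_workers=2)

def training_loop(shard_paths: list[str]):
    future = executor.submit(load_and_index_shard, shard_paths[0])
    for next_path in shard_paths[1:] + [None]:
        tokens, doc_starts = future.result()      # blocks only if the CPU fell behind
        if next_path is not None:
            # Start fetching + indexing the next shard while the GPUs train on this one.
            future = executor.submit(load_and_index_shard, next_path)
        train_on_shard(tokens, doc_starts)        # assumed helper: the GPU-heavy work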

#4 Vectorized Optimizer Step

Torch reduce-scatter, Muon orthogonalization, and torch all-gather can each be executed across multiple parameters at once, as long as the total parameter count is divisible by 8. This change was implemented by the Hive AI.
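
A minimal sketch of what "vectorized" means here, assuming 8 ranks and a group of 8 identically shaped gradients: stacking them lets a single reduce_scatter_tensor, one batched orthogonalization, and a single all_gather_into_tensor replace 8 separate calls (polar_express is the batched orthogonalizer used earlier; gradient averaging and momentum are omitted).

import torch
import torch.distributed as dist

def vectorized_muon_step(params: list[torch.nn.Parameter], lr: float, world_size: int = 8) -> None:
    assert len(params) == world_size                        # one parameter owned per rank
    stacked = torch.stack([p.grad for p in params])         # [8, d1, d2]
    owned = torch.empty_like(stacked[0])
    dist.reduce_scatter_tensor(owned, stacked)              # each rank receives one summed gradient
    update = polar_express(owned.unsqueeze(0)).squeeze(0)   # orthogonalize the owned gradient
    gathered = torch.empty_like(stacked)
    dist.all_gather_into_tensor(gathered, update)           # share all 8 updates back to every rank
    for p, u in zip(params, gathered):
        p.data.add_(u.to(p.dtype), alpha=-lr)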

#5 Triton Kernel for Symmetric Matmul

Multiplying two matrices of shape (m, k) and (k, n) requires 2*m*k*n FLOPs. However, multiplying a matrix of shape (m, k) by its own transpose (k, m) can be done with only m*k*m FLOPs: the result is symmetric, so only half of it needs to be computed and then mirrored across the diagonal. This is used for the X Xᵀ product in the first step of Newton-Schulz for the Muon update. This update is from Bryon Xu.
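
The real speedup comes from a Triton kernel that skips redundant tiles; the PyTorch sketch below only illustrates the blockwise idea (compute the tiles on or below the diagonal of X Xᵀ, then mirror them) and is not itself faster in eager mode.

import torch

def symmetric_matmul_blockwise(x: torch.Tensor, block: int = 128) -> torch.Tensor:
    # Compute x @ x.T by evaluating only the lower-triangular blocks and mirroring.
    m = x.size(0)
    out = x.new_empty(m, m)
    for i in range(0, m, block):
        for j in range(0, i + 1, block):                     # blocks on or below the diagonal
            tile = x[i:i + block] @ x[j:j + block].T
            out[i:i + block, j:j + block] = tile
            if i != j:
                out[j:j + block, i:i + block] = tile.T       # mirror across the diagonal
    return out

For any x, torch.allclose(symmetric_matmul_blockwise(x), x @ x.T) should hold up to floating-point error.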

#6 Resize Lambda Parameters

The model has a host of scalars that manage the weighting of various connections. Originally it was assumed that the exact update process for these scalars was less relevant, since their count (<100) is dwarfed by the core model parameters. However, it was later observed that setting the count to 56 or 72 scalars dropped the runtime meaningfully compared to 64 scalars. While the exact cause is not fully understood, it is weakly hypothesized that coalesced memory access patterns play a role, where each GPU can access 4 data values simultaneously. After the Adam optimizer splits the parameters 8 ways across the GPUs, 56 scalars leads to 7 scalars per GPU and 72 scalars to 9 per GPU.

Takeaways from the Journey

All changes above without a listed author are my own. I have learned a lot over the last 3 months about the process of discovering model improvements (and my bank account has also lost some weight). I hope I can keep learning, to the point where I'll look back and consider current me a bit clueless. Here are my takeaways from where I stand today.

Optimize for many ideas over good ideas. The best way to have a good idea is to first have 20 bad ideas and learn from them. When I was first starting on the speedrun, I spent a lot of effort trying to mentally anticipate how an idea might pan out. I have found it more advantageous to limit my initial mental effort to ‘Is this idea plausible’. If so, immediately shift into ‘How can I test it’. The most fruitful spot for learning is after I have test results and have to think through why an idea failed or worked. I want to go from ideation to that spot as quickly as possible.

Work backwards from the data. Moving the needle on a pre-optimized task means you have to find ideas that no one has thought of yet. The approach of 'read a paper', 'apply the paper', 'repeat' is a good way to keep your inspiration in the same spot as the people who have already been testing ideas. If you instead work backwards from the data, it will give you a completely unique perspective, simply because there are millions of different ways to look at the data. Here is an example where I explore how the phrase http://stickygooeycreamychewy.com gets perfectly predicted by the model on its second time ever seeing it. This gave me a unique perspective on how the last layer is functioning, leading to the post-training attention window improvements.

Let gradient magic work for you. I'm using the phrase 'gradient magic' to refer to how backpropagation can almost instantly find the local gradient across the entire parameter space. This is something I had heard for years but didn't understand until recently, because it is so remarkably different from how humans approach problems. If a human were in a footrace and had 100 million doors in front of them and needed to pick a route, it would be tremendously helpful if someone removed the worse half of the doors. Choice paralyzes humans. Backpropagation cuts through it. Instead of trying to help the model by eliminating choices, give it more context and more choices.

Find environments with feedback. I don’t work in AI, I don’t have professors or peers in AI, and none of my friends work on anything related to AI. As a result, I am rarely ever getting feedback. Most of my knowledge consumption in this space is unidirectional, which I’m realizing is horribly inefficient. I got lunch one time with a super cool guy at AI2, and had a video call a year ago with a very friendly research engineer. Those two experiences were very helpful for me, though sparse. The consistent feedback from a speedrun timing and some of the discussions coming off of it has been incredibly productive for my rate of learning. In a sense, it helps level the playing field between myself and those already in feedback rich environments. If I was better about networking I probably could have leveled that playing field a long time ago. Now that I’ve had this experience, “What is the level of feedback” is a question I’m asking about every new learning environment.

  1. Analysing The Impact of Sequence Composition on Language Model Pre-Training. https://arxiv.org/pdf/2402.13991

  2. Efficient Streaming Language Models with Attention Sinks. https://arxiv.org/abs/2309.17453

  3. YaRN: Efficient Context Window Extension of Large Language Models. https://arxiv.org/pdf/2309.00071

  4. The Polar Express: Optimal Matrix Sign Methods and Their Application to the Muon Algorithm. https://arxiv.org/pdf/2505.16932