Where Gradient Spikes Come From
A Calm Mini-Batch & a Shocking One
Most mini-batches produce gradients with reasonable magnitudes. Cross-entropy loss for a model that already roughly fits the data stays in a narrow band; backprop carries that signal back as gradients of similar size.
Some mini-batches do not. Three sources of gradient spikes:
1. Outlier examples. A single sequence with an extremely rare token combination produces a far-from-mean loss & a far-from-mean gradient.
2. Numerical edge cases. A near-zero softmax denominator, a layernorm over near-constant activations (variance near zero), an FP16 overflow. Each can produce gradients orders of magnitude larger than typical, or outright inf/NaN values.
3. Distribution shifts. Switching data sources during a single training run shocks the model with a new distribution. ANDREA's bandit reshuffles source weights every 7 to 42 steps. Each switch is a small distribution shift.
ANDREA-120M v1: Spike Cascade
v1 had no gradient clipping. Every 7 to 42 steps the bandit switched sources, feeding the model brief bursts of repo-docs (list-structured), then gutenberg (long prose), then hermes3-general (Q&A). Each transition produced gradient spikes, & at 120M scale those spikes accumulated until they pushed weights into degenerate attractors.
Key empirical fact. ANDREA-12M survived the same bandit without clipping. Smaller weight matrices stay robust to gradient shocks; a single bad batch cannot push 12M parameters into a runaway attractor the way it can push 120M. Clipping matters more as a model scales.
Global L2 Norm Clipping
Two Choices: Per-Tensor or Global
Two ways to bound gradient magnitudes:
Per-tensor clipping. Clip each gradient tensor independently: the embedding gradient is scaled against its own norm, each attention gradient against its own norm. Simple, but it distorts relative scales: a tensor hit by a spike gets scaled down to max_norm while a tensor with an ordinary gradient passes through untouched, so the combined update no longer points along the true gradient.
Global L2 norm clipping. Treat all gradients as one big vector. Compute the total L2 norm across every parameter. If the norm exceeds max_norm, scale every gradient by the same factor. Preserves relative magnitudes across tensors.
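ANDREA uses global. Pascanu et al. (2013) introduced rescaling by the gradient norm for training recurrent networks, years before transformers existed; clipping the global norm has since become the common default for transformer training.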
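To see the distortion concretely, the sketch below clips the same two toy gradient tensors both ways with max_norm = 1.0. The helpers (l2, clip_per_tensor, clip_global) & the tensor values are hypothetical illustrations, not ANDREA code; each tensor is a flat Python list.

```python
import math


def l2(xs):
    """L2 norm of a flat list of floats."""
    return math.sqrt(sum(x * x for x in xs))


def clip_per_tensor(grads, max_norm):
    """Rescale each tensor against its own norm (hypothetical helper)."""
    out = []
    for g in grads:
        n = l2(g)
        s = max_norm / n if n > max_norm else 1.0
        out.append([x * s for x in g])
    return out


def clip_global(grads, max_norm):
    """Rescale every tensor by one factor derived from the norm of all elements."""
    total = math.sqrt(sum(x * x for g in grads for x in g))
    s = max_norm / total if total > max_norm else 1.0
    return [[x * s for x in g] for g in grads]


# Toy gradients: the "embedding" tensor carries a 50x spike, "attention" is ordinary.
emb = [30.0, 40.0]   # norm 50
attn = [0.6, 0.8]    # norm 1

per = clip_per_tensor([emb, attn], max_norm=1.0)
glob = clip_global([emb, attn], max_norm=1.0)

print("per-tensor ratio:", l2(per[0]) / l2(per[1]))   # ~1.0: spike flattened
print("global ratio:    ", l2(glob[0]) / l2(glob[1]))  # ~50.0: ratio preserved
```

Per-tensor clipping flattens the 50:1 ratio between the two tensors to 1:1, while global clipping keeps 50:1 & only shrinks the overall length.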
The Math
Compute the global L2 norm:
norm = sqrt(sum of g_i^2 over every gradient element g_i, across all parameters)
If norm <= max_norm, gradients pass through unchanged. If norm > max_norm, scale every gradient by max_norm / norm:
g_i_clipped = g_i * (max_norm / norm)
After scaling, the new norm equals exactly max_norm. ANDREA uses max_norm = 1.0.
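Both formulas above apply the same positive factor max_norm / norm to every element, so the claim that the clipped norm lands on max_norm follows in one line, using the same symbols:

```latex
\sqrt{\sum_i \left(g_i \cdot \frac{\text{max\_norm}}{\text{norm}}\right)^2}
  = \frac{\text{max\_norm}}{\text{norm}} \sqrt{\sum_i g_i^2}
  = \frac{\text{max\_norm}}{\text{norm}} \cdot \text{norm}
  = \text{max\_norm}
```

For example, a spike with norm = 50 gets scale 1.0 / 50 = 0.02, so every element shrinks by the same factor of 50. In floating point the result matches max_norm up to rounding error.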
Why Gradient Norm Computation Needs Three Kernels
The Naive Algorithm Runs Poorly on a GPU
Pseudocode for global L2 norm computation:
total = 0
for each param p:
    for each element g in p.grad:
        total += g * g
norm = sqrt(total)
On a GPU, this naive loop performs poorly for two reasons:
1. Sequential accumulation. A single total accumulator serializes the work: each addition depends on the previous one, so every thread waits on (or contends for) the same memory location, defeating GPU parallelism.
2. Heterogeneous tensors. ANDREA-120M has tensors of vastly different shapes: embedding (8449 x 768), attention QKV (768 x 768), layernorm (768). One kernel cannot efficiently iterate all shapes.
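The usual GPU answer to the sequential-accumulator problem is a tree reduction: inside a thread block, threads add pairs of values in log2(n) rounds instead of n dependent steps. The Python below simulates that pattern serially to show the round structure; the function name block_tree_reduce is hypothetical, & this is not a transcription of microgpt_cuda.cu.

```python
def block_tree_reduce(values):
    """Sum `values` in halving rounds, the way one thread block would.

    Assumes len(values) is a power of two (a real kernel pads with zeros).
    Within a round, each inner-loop iteration touches disjoint elements,
    so on a GPU all of them could run as separate threads at once.
    Returns (sum, number_of_rounds).
    """
    buf = list(values)
    n = len(buf)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a positive power of two")
    stride, rounds = n // 2, 0
    while stride > 0:
        for i in range(stride):          # independent "threads" in this round
            buf[i] += buf[i + stride]
        stride //= 2
        rounds += 1
    return buf[0], rounds


squares = [float(x * x) for x in range(1, 9)]   # 8 squared gradient values
total, rounds = block_tree_reduce(squares)
print(total, rounds)  # 204.0 3: three rounds instead of eight dependent additions
```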
ANDREA's Three-Kernel Pipeline
ANDREA splits the work into three CUDA kernels in microgpt_cuda.cu:
Kernel 1: k_grad_norm_partial. For each parameter tensor, compute a partial sum of squares. Each thread block reduces one chunk of the tensor & writes its partial sum to a small scratch buffer. Parallelism: one block per chunk, hundreds of blocks across all tensors.
Kernel 2: k_grad_norm_final. Reduce the scratch buffer to a single scalar. Take its square root. One small kernel, runs in microseconds.
Kernel 3: k_grad_scale. If norm > max_norm, compute scale = max_norm / norm & multiply every gradient element by scale. One pass over every gradient tensor, embarrassingly parallel.
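The three phases can be simulated serially to show the data flow. Phase 2 needs every partial sum from phase 1, & phase 3 needs the single norm from phase 2; an ordinary CUDA launch has no barrier spanning all of its blocks (cooperative launches aside), so ending one kernel & starting the next is the simple way to get that grid-wide synchronization. The sim_* functions & the tiny chunk size are hypothetical stand-ins named after the kernels, not ANDREA's code.

```python
import math

CHUNK = 4  # hypothetical; a real kernel gives each block far more elements


def sim_grad_norm_partial(grads, chunk=CHUNK):
    """Phase 1: one partial sum of squares per (tensor, chunk) -> scratch buffer."""
    scratch = []
    for g in grads:                        # each tensor is a flat list of floats
        for start in range(0, len(g), chunk):
            scratch.append(sum(x * x for x in g[start:start + chunk]))
    return scratch


def sim_grad_norm_final(scratch):
    """Phase 2: reduce the scratch buffer to one scalar, then take the square root."""
    return math.sqrt(sum(scratch))


def sim_grad_scale(grads, norm, max_norm):
    """Phase 3: scale every element in place only when norm exceeds max_norm.

    Returns the scale factor that was applied (1.0 means untouched).
    """
    if norm <= max_norm:
        return 1.0
    scale = max_norm / norm
    for g in grads:
        for i in range(len(g)):
            g[i] *= scale
    return scale


grads = [[3.0, 4.0], [0.0, 0.0, 0.0, 0.0, 12.0]]   # two tensors of different sizes
scratch = sim_grad_norm_partial(grads)             # [25.0, 0.0, 144.0]
norm = sim_grad_norm_final(scratch)                # 13.0
scale = sim_grad_scale(grads, norm, max_norm=1.0)  # 1/13; global norm now ~1.0
```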
Order Matters: Pre-Adam
The clipping pipeline runs BEFORE AdamW updates m, v, or any parameter. Why?
Clipped gradients feed AdamW's exponential moving averages. If a spike were allowed to flow into m & v, it would corrupt those running averages & slow recovery for many steps after the spike. Clipping pre-Adam bounds what any single bad step can contribute: no gradient with norm above max_norm ever reaches m & v.
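A minimal sketch of that ordering for a flat list of parameters, assuming standard AdamW with bias correction. The function names, weight_decay = 0.01, & the toy gradient sequence are illustrative assumptions; only lr = 0.0003 & max_norm = 1.0 come from the text. Because every clipped gradient has norm at most max_norm, m (an exponential average of clipped gradients starting at zero) can never exceed max_norm in any coordinate.

```python
import math


def clip_then_adamw_step(params, grads, m, v, t, lr=3e-4, max_norm=1.0,
                         beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01):
    """One optimizer step: clip first, then update moments, then parameters.

    weight_decay=0.01 is an assumed placeholder, not a documented ANDREA value.
    t is the 1-based step count used for bias correction.
    """
    # 1) Global-norm clipping happens before m and v see the gradient.
    norm = math.sqrt(sum(g * g for g in grads))
    scale = max_norm / norm if norm > max_norm else 1.0
    grads = [g * scale for g in grads]

    for i, g in enumerate(grads):
        # 2) The moving averages only ever receive clipped values.
        m[i] = beta1 * m[i] + (1 - beta1) * g
        v[i] = beta2 * v[i] + (1 - beta2) * g * g
        m_hat = m[i] / (1 - beta1 ** t)
        v_hat = v[i] / (1 - beta2 ** t)
        # 3) AdamW: adaptive step plus decoupled weight decay.
        params[i] -= lr * (m_hat / (math.sqrt(v_hat) + eps) + weight_decay * params[i])


def peak_first_moment(max_norm, steps=30, spike_at=5):
    """Toy 1-parameter run with one 100x spike; returns the largest |m| seen."""
    params, m, v = [0.0], [0.0], [0.0]
    peak = 0.0
    for t in range(1, steps + 1):
        g = 100.0 if t == spike_at else 0.1
        clip_then_adamw_step(params, [g], m, v, t, max_norm=max_norm)
        peak = max(peak, abs(m[0]))
    return peak


print(peak_first_moment(max_norm=1.0))           # clipped: stays <= 1.0
print(peak_first_moment(max_norm=float("inf")))  # unclipped: the spike dominates m
```

Passing max_norm=float("inf") disables clipping, which mimics v1's setup for comparison.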
How No-Clipping Killed v1
Bandit Source Transitions Every 7 to 42 Steps
ANDREA's bandit operates in phases. Each phase lasts 7, 14, 21, 28, or 42 steps (chosen randomly). At each phase boundary, source weights shift: maybe repo-docs jumps from 0.1 to 0.6, gutenberg drops from 0.4 to 0.1, & hermes3-general falls from 0.5 to 0.3.
Each transition is a distribution shock to the model. Loss spikes briefly. Gradients spike with it: a model that was minimizing loss against gutenberg-flavored prose now sees repo-docs-flavored list structures, & gradients carry corrective signal that can be 10x or 100x typical magnitude.
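A sketch of just the phase-length schedule, to make the transition rate concrete. It draws each phase length uniformly from {7, 14, 21, 28, 42}, which is an assumption (the text says only "chosen randomly"), & it does not model how the bandit picks source weights. With a uniform draw the mean phase lasts (7 + 14 + 21 + 28 + 42) / 5 = 22.4 steps, so a run sees roughly one source transition every 22 steps.

```python
import random

PHASE_LENGTHS = (7, 14, 21, 28, 42)


def phase_boundaries(total_steps, seed=0):
    """Return the steps at which a new phase starts (step 0 excluded).

    Assumes phase lengths are drawn uniformly from PHASE_LENGTHS.
    """
    rng = random.Random(seed)
    boundaries, step = [], 0
    while True:
        step += rng.choice(PHASE_LENGTHS)
        if step >= total_steps:
            return boundaries
        boundaries.append(step)


b = phase_boundaries(10_000)
print(len(b), "source-weight shifts in 10,000 steps")
print(sum(PHASE_LENGTHS) / len(PHASE_LENGTHS), "mean phase length")  # 22.4
```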
v1 Failure Mode
Without clipping, those 10-100x gradient spikes flowed straight into Adam's m & v averages. Adam's smoothing meant the spike's effect persisted for many steps after the actual bad batch. Combined with no weight decay (v1 used vanilla Adam, not AdamW), spike-driven weight updates compounded over phases until weights drifted into a degenerate attractor: one token's logit dominated softmax, sampled output was that token, training context contained that token, gradient reinforced that token. Repetition lock-in.
v2 Stability
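v2 added clipping with max_norm = 1.0, alongside AdamW & LR warmup. The spike's effect on m & v is bounded, & weights cannot drift faster than roughly lr * max_norm = 0.0003 * 1.0 = 0.0003 per parameter per step at peak. (For a plain SGD step that bound is exact, since no single gradient element can exceed the global norm; AdamW's division by sqrt(v) makes it approximate.) Phase transitions still produce spikes, but those spikes are capped before they reach the optimizer.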
Result: v2 (after data filter v2.5 & v3 polish) reached factual recall, multi-paragraph coherence, & 9.5/10 external grades on biology & signal-processing samples.
The Capacity-Brittleness Coupling
Same bandit. Same data. Neither run used clipping. Why did 12M survive while 120M collapsed?
Two compounding factors:
1. Larger weight matrices store more attractors. A 768x768 attention projection has 590K parameters; even small per-parameter drift produces meaningful changes in attention behavior. A 384x384 attention projection has 147K parameters & stays in a more constrained subspace.
2. More layers mean more multiplicative interactions. ANDREA-120M has 12 transformer layers (vs 6 for ANDREA-12M). Spikes propagate through 12 layers of compounding nonlinearities; each layer can amplify the prior layer's drift.
Brittleness compounds with capacity. Clipping becomes mandatory above some scale threshold; ANDREA puts that threshold somewhere between 12M & 120M parameters.
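The arithmetic behind both factors, as a toy. The projection sizes come straight from the text; the 10% per-layer amplification is an invented illustrative number, not a measurement, chosen only to show how one per-layer gain compounds over 6 versus 12 layers.

```python
def projection_params(d_model):
    """Weights in one square d_model x d_model attention projection (no bias)."""
    return d_model * d_model


def compounded_gain(per_layer_gain, n_layers):
    """Toy model: drift multiplied by the same gain at every layer."""
    return per_layer_gain ** n_layers


print(projection_params(768), projection_params(384))  # 589824 147456
print(round(compounded_gain(1.1, 6), 2))   # 1.77 for 6 layers (12M)
print(round(compounded_gain(1.1, 12), 2))  # 3.14 for 12 layers (120M)
```

Under this toy, doubling depth squares the total gain, while doubling width quadruples the parameters per projection.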
Where Else Does Clipping Apply?
Adjacent Activities
Three sibling activities connect to clipping:
- Activity 10: AdamW. Clipping protects AdamW's m & v from spike contamination. Without clipping, one bad batch corrupts optimizer state for 50+ steps.
- Activity 11: LR warmup. Warmup damps lr; clipping damps g. Together: at step 1, the worst-case parameter update is lr_after_warmup * max_norm = 1.5e-7 * 1.0 = 1.5e-7, vs 0.0003 * 50 = 0.015 without either guard (peak lr applied to a gradient spike of norm 50). A 100,000x reduction in worst-case early update magnitude.
- Activity 14: Multi-armed bandits. The bandit phase length (7 to 42 steps) is short specifically to prevent any one source from dominating; clipping is what makes those frequent transitions safe.
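The Activity 11 numbers can be reproduced in a few lines. The bound is the plain lr * gradient-norm product used in the text; the linear warmup over 2,000 steps is an assumption inferred from 1.5e-7 = 0.0003 / 2000, not a documented setting.

```python
PEAK_LR = 3e-4        # peak learning rate from the text
MAX_NORM = 1.0        # clipping threshold from the text
WARMUP_STEPS = 2000   # assumption: consistent with 1.5e-7, not stated in the source
SPIKE_NORM = 50.0     # gradient norm of an early spike, as in the text


def warmup_lr(step, peak=PEAK_LR, warmup=WARMUP_STEPS):
    """Linear warmup (assumed schedule): ramps from peak/warmup up to peak."""
    return peak * min(1.0, step / warmup)


guarded = warmup_lr(1) * MAX_NORM     # warmup + clipping
unguarded = PEAK_LR * SPIKE_NORM      # neither guard
print(guarded, unguarded, unguarded / guarded)
```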
Clipping is one of the cheapest stability wins in transformer training: three small CUDA kernels, a small fraction of each step's runtime, & a decisive impact on whether 120M+ models converge or collapse.