Where Gradient Spikes Come From

A Calm Mini-Batch & a Shocking One

Most mini-batches produce gradients with reasonable magnitudes. Cross-entropy loss for a model that already roughly fits the data stays in a narrow band; backprop carries that signal back as gradients of similar size.


Some mini-batches do not. Three sources of gradient spikes:


1. Outlier examples. A single sequence with an extremely rare token combination produces a far-from-mean loss & a far-from-mean gradient.

2. Numerical edge cases. A near-zero softmax denominator, a NaN-producing layernorm, an FP16 overflow. Each can produce gradients orders of magnitude larger than typical, or non-finite values outright.

3. Distribution shifts. Switching data sources during a single training run shocks the model with a new distribution. ANDREA's bandit reshuffles source weights every 7 to 42 steps. Each switch is a small distribution shift.


ANDREA-120M v1: Spike Cascade

v1 had no gradient clipping. Source transitions every 7 to 42 steps from the bandit fed the model brief bursts of repo-docs (list-structured), then gutenberg (long prose), then hermes3-general (Q&A). Each transition produced gradient spikes, & with nothing to bound them, those spikes pushed weights toward degenerate attractors at 120M scale.


Key empirical fact. ANDREA-12M survived the same bandit without clipping. Smaller weight matrices stay robust to gradient shocks; a single bad batch cannot push 12M parameters into a runaway attractor the way it can push 120M. Clipping matters more as a model scales.

Global L2 Norm Clipping

Two Choices: Per-Tensor or Global

Two ways to bound gradient magnitudes:


Per-tensor clipping. Clip each gradient tensor independently: the embedding gradient is clipped against its own norm, the attention gradient against its own norm. Simple, but it distorts relative scales: a tensor whose gradient spikes gets shrunk while every other tensor passes through untouched, so the direction of the combined update changes.


Global L2 norm clipping. Treat all gradients as one big vector. Compute the total L2 norm across every parameter. If the norm exceeds max_norm, scale every gradient by the same factor. Preserves relative magnitudes across tensors.


ANDREA uses global. Pascanu et al. (2013) introduced gradient norm clipping as a remedy for exploding gradients in recurrent networks; rescaling by the global norm has since become a common default for transformer training.
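The difference between the two choices can be sketched on the CPU. This is a minimal illustration, not ANDREA's code: the tensor names, the gradient values, & the helper functions `clip_per_tensor` & `clip_global` are all hypothetical. It compares the ratio of the two tensor norms before & after each kind of clipping.

```python
import math

def l2_norm(values):
    """L2 norm of a flat list of floats."""
    return math.sqrt(sum(v * v for v in values))

def clip_per_tensor(grads, max_norm):
    """Clip each tensor against its own norm, independently of the others."""
    clipped = {}
    for name, g in grads.items():
        norm = l2_norm(g)
        scale = max_norm / norm if norm > max_norm else 1.0
        clipped[name] = [v * scale for v in g]
    return clipped

def clip_global(grads, max_norm):
    """Scale every tensor by one shared factor derived from the total norm."""
    total = math.sqrt(sum(v * v for g in grads.values() for v in g))
    scale = max_norm / total if total > max_norm else 1.0
    return {name: [v * scale for v in g] for name, g in grads.items()}

def norm_ratio(grads):
    """How much larger the embedding gradient is than the attention gradient."""
    return l2_norm(grads["embedding"]) / l2_norm(grads["attention"])

# Hypothetical gradients: the embedding tensor spikes, attention stays small.
grads = {"embedding": [30.0, 40.0], "attention": [0.3, 0.4]}

ratio_before = norm_ratio(grads)
ratio_per_tensor = norm_ratio(clip_per_tensor(grads, max_norm=1.0))
ratio_global = norm_ratio(clip_global(grads, max_norm=1.0))

print(ratio_before, ratio_per_tensor, ratio_global)
```

Under these made-up values, per-tensor clipping changes the ratio between the two tensors, while global clipping leaves it as it was.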


The Math

Compute the global L2 norm:


norm = sqrt(sum over all params of g_i^2)


If norm <= max_norm, gradients pass through unchanged. If norm > max_norm, scale every gradient by max_norm / norm:


g_i_clipped = g_i * (max_norm / norm)


After scaling, the new norm equals exactly max_norm. ANDREA uses max_norm = 1.0.
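The two formulas above can be written as a few lines of plain Python. This is a sketch under simple assumptions (gradients held as flat lists of floats, a hypothetical function name `clip_grad_norm`), not the CUDA implementation.

```python
import math

def clip_grad_norm(grads, max_norm):
    """Return (clipped_grads, norm_before) for a list of gradient tensors,
    each given as a flat list of floats."""
    norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    if norm <= max_norm:
        # Below the threshold: gradients pass through unchanged.
        return [list(t) for t in grads], norm
    scale = max_norm / norm
    return [[g * scale for g in t] for t in grads], norm

# Two hypothetical tensors; global norm = sqrt(9 + 16) = 5.0
grads = [[3.0, 0.0], [0.0, 4.0]]
clipped, norm_before = clip_grad_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for t in clipped for g in t))

print(norm_before, norm_after)
```

Up to floating-point rounding, `norm_after` comes out equal to `max_norm` whenever clipping fires.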

Computing a Scale Factor

Suppose during one training step, the global L2 norm of all gradients computes to `3.5`. ANDREA's `max_norm = 1.0`. Compute (a) the scale factor that gets applied, (b) what the new global L2 norm equals after scaling, & (c) what would happen if the unclipped norm were `0.4` instead of `3.5`. Show your arithmetic.

Why Gradient Norm Computation Needs Three Kernels

The Naive Algorithm Cannot Run on a GPU

Pseudocode for global L2 norm computation:


total = 0
for each param p:
  for each element g in p.grad:
    total += g * g
norm = sqrt(total)

On a GPU, this naive loop fails for two reasons:


1. Sequential accumulation. A single total accumulator serializes the work: every thread has to take its turn adding to the same memory location, defeating GPU parallelism.

2. Heterogeneous tensors. ANDREA-120M has tensors of vastly different shapes: embedding (8449 x 768), attention QKV (768 x 768), layernorm (768). One kernel cannot efficiently iterate all shapes.


ANDREA's Three-Kernel Pipeline

Split the work into three CUDA kernels in microgpt_cuda.cu:


Kernel 1: k_grad_norm_partial. For each parameter tensor, compute a partial sum of squares. Each thread block reduces a chunk of the tensor; results write to a small scratch buffer. Parallelism: one block per chunk, hundreds of blocks across all tensors.


Kernel 2: k_grad_norm_final. Reduce the scratch buffer to a single scalar. Take its square root. One small kernel, runs in microseconds.


Kernel 3: k_grad_scale. If norm > max_norm, compute scale = max_norm / norm & multiply every gradient element by scale. One pass over every gradient tensor, embarrassingly parallel.
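The shape of the pipeline can be mirrored on the CPU. The sketch below is not the code in microgpt_cuda.cu; the function names, the `CHUNK` size, & the sample tensors are assumptions chosen for illustration, & each Python loop stands in for work that the real kernels would spread across thread blocks.

```python
import math

CHUNK = 4  # hypothetical number of elements handled by one "thread block"

def grad_norm_partial(tensors, chunk=CHUNK):
    """Stage 1: one partial sum of squares per chunk of each tensor,
    written to a scratch buffer."""
    scratch = []
    for t in tensors:
        for start in range(0, len(t), chunk):
            scratch.append(sum(g * g for g in t[start:start + chunk]))
    return scratch

def grad_norm_final(scratch):
    """Stage 2: reduce the scratch buffer to one scalar & take its square root."""
    return math.sqrt(sum(scratch))

def grad_scale(tensors, norm, max_norm):
    """Stage 3: scale every gradient element in place, only if norm > max_norm."""
    if norm > max_norm:
        scale = max_norm / norm
        for t in tensors:
            for i in range(len(t)):
                t[i] *= scale

# Tensors of different lengths, standing in for heterogeneous shapes.
tensors = [[1.0] * 10, [2.0] * 3, [0.5] * 4]

scratch = grad_norm_partial(tensors)
norm = grad_norm_final(scratch)
grad_scale(tensors, norm, max_norm=1.0)
norm_after = grad_norm_final(grad_norm_partial(tensors))

print(len(scratch), norm, norm_after)
```

Stage 1 only ever sees fixed-size chunks, so tensor shape stops mattering after it, & stage 2 works on a buffer far smaller than the gradients themselves.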


Order Matters: Pre-Adam

The clipping pipeline runs BEFORE AdamW updates m, v, or any parameter. Why?


Clipped gradients feed AdamW's exponential moving averages. If a spike were allowed to flow into m & v, it would corrupt those running averages & slow recovery for many steps after the spike. Clipping pre-Adam keeps the spike's effect confined to the single bad step.
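A one-parameter sketch of that ordering follows. The starting values of m & v, the spike size, & `beta2 = 0.999` are assumptions picked for illustration; with a single scalar parameter, the global norm is just the absolute value of the gradient.

```python
def adam_moments(m, v, g, beta1=0.9, beta2=0.999):
    """One update of the first & second moment running averages."""
    m = beta1 * m + (1.0 - beta1) * g
    v = beta2 * v + (1.0 - beta2) * g * g
    return m, v

def clip_scalar(g, max_norm):
    """Global-norm clipping for a single scalar gradient."""
    norm = abs(g)
    return g if norm <= max_norm else g * (max_norm / norm)

m0, v0 = 0.05, 0.0025   # hypothetical running averages before the spike
spike = 20.0            # hypothetical spiking gradient

# Spike reaches the moments directly (no clipping).
m_raw, v_raw = adam_moments(m0, v0, spike)

# Spike is clipped first, then the moments are updated.
m_clip, v_clip = adam_moments(m0, v0, clip_scalar(spike, max_norm=1.0))

print(m_raw, m_clip)
print(v_raw, v_clip)
```

In this toy case the unclipped spike moves both running averages far more than the clipped one does, which is the contamination the pre-Adam ordering is meant to avoid.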


Gradient Clipping with 3 CUDA Kernels

Why Three Kernels, Not One?

Suppose someone proposed merging `k_grad_norm_partial` & `k_grad_norm_final` into a single kernel that computes the entire global norm in one pass. Give one specific reason why this merger would either fail or perform worse on a GPU. Reference how GPU thread blocks share memory & synchronize.

How No-Clipping Killed v1

Bandit Source Transitions Every 7 to 42 Steps

ANDREA's bandit operates in phases. Each phase lasts 7, 14, 21, 28, or 42 steps (chosen randomly). At each phase boundary, source weights shift: maybe repo-docs jumps from 0.1 to 0.6, gutenberg drops from 0.4 to 0.1, hermes3-general drops from 0.5 to 0.3.


Each transition is a distribution shock to the model. Loss spikes briefly. Gradients spike with it: a model that was minimizing loss against gutenberg-flavored prose now sees repo-docs-flavored list structures, & gradients carry corrective signal that can be 10x or 100x typical magnitude.
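To get a feel for how often those shocks arrive, the phase boundaries can be simulated. The sketch assumes a uniform random choice among the five phase lengths & a hypothetical helper `transition_steps`; the text only says the length is chosen randomly, & how the bandit picks the new source weights is not modeled here.

```python
import random

PHASE_LENGTHS = (7, 14, 21, 28, 42)

def transition_steps(total_steps, seed=0):
    """Return the step indices at which a new phase begins.
    Assumes each phase length is drawn uniformly from PHASE_LENGTHS."""
    rng = random.Random(seed)
    steps, step = [], 0
    while True:
        step += rng.choice(PHASE_LENGTHS)
        if step >= total_steps:
            return steps
        steps.append(step)

boundaries = transition_steps(1000)

# Distance between consecutive boundaries, i.e. the sampled phase lengths.
gaps = [b - a for a, b in zip([0] + boundaries, boundaries)]

print(len(boundaries), "transitions in 1000 steps")
```

Even with the longest phase every time, 1000 steps would contain more than twenty transitions, so an unclipped run meets a distribution shock many times per thousand steps.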


v1 Failure Mode

Without clipping, those 10-100x gradient spikes flowed into the optimizer's m & v averages. The smoothing in those averages meant the spike effect persisted for many steps after the actual bad batch. Combined with no weight decay (vanilla Adam in v1), spike-driven weight updates compounded over phases until weights drifted into a degenerate attractor: one token's logit dominated softmax, sampled output was that token, training context contained that token, gradient reinforced that token. Repetition lock-in.


v2 Stability

v2 added clipping with max_norm = 1.0, alongside AdamW & LR warmup. Spike effect on m & v is bounded; weights cannot drift faster than lr * max_norm = 0.0003 * 1.0 = 0.0003 per parameter per step at peak. Phase transitions still produce spikes, but those spikes are capped before they reach the optimizer.


Result: v2 (after data filter v2.5 & v3 polish) reached factual recall, multi-paragraph coherence, & 9.5/10 external grades on biology & signal-processing samples.


The Capacity-Brittleness Coupling

Same bandit. Same data. Same hyperparameters except clipping. Why did 12M survive without clipping while 120M collapsed?


Two compounding factors:


1. Larger weight matrices store more attractors. A 768x768 attention projection has 590K parameters; even small per-parameter drift produces meaningful changes in attention behavior. A 384x384 attention projection has 147K parameters & stays in a more constrained subspace.

2. More layers means more multiplicative interactions. v3 has 12 transformer layers (vs 6 for 12M). Spikes propagate through 12 layers of compounding nonlinearities; each layer can amplify the prior layer's drift.


Brittleness compounds with capacity. Clipping becomes mandatory above some scale threshold; ANDREA puts that threshold somewhere between 12M & 120M parameters.

Diagnosing the v1 Cascade

Suppose at step 50,000 of v1 training, a single mini-batch produces a gradient with global L2 norm of 50.0 (typical batches produce ~0.5). Trace what happens to Adam's first moment `m` over the next 10 steps if subsequent batches return to typical gradient magnitudes. Consider how `m = beta1 * m + (1 - beta1) * g` with beta1=0.9 propagates the spike.

Where Else Does Clipping Apply?

Beyond ANDREA's bandit-driven curriculum, name one OTHER training scenario where global L2 gradient clipping would be similarly important, & give one mechanism that makes it so.

Adjacent Activities

Three siblings link to clipping:


- Activity 10: AdamW. Clipping protects AdamW's m & v from spike contamination. Without clipping, one bad batch corrupts optimizer state for 50+ steps.

- Activity 11: LR warmup. Warmup damps lr; clipping damps g. Together: at step 1, the worst-case parameter update is lr_after_warmup * max_norm = 1.5e-7 * 1.0 = 1.5e-7, vs 0.0003 * 50 = 0.015 without either guard. A 100,000x reduction in worst-case early update magnitude.

- Activity 14: Multi-armed bandits. The bandit phase length (7 to 42 steps) is short specifically to prevent any one source from dominating; clipping is what makes those frequent transitions safe.
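The warmup arithmetic in the list above can be checked directly. The variable names are illustrative; the numbers are the ones quoted in the list (peak lr of 0.0003, step-1 lr of 1.5e-7, max_norm of 1.0, & a spike norm of 50).

```python
peak_lr = 0.0003          # learning rate with no warmup
lr_after_warmup = 1.5e-7  # learning rate at step 1 with warmup
max_norm = 1.0            # clipping threshold
spike_norm = 50.0         # unclipped gradient norm of a bad batch

# Worst-case update with both guards vs with neither.
guarded = lr_after_warmup * max_norm
unguarded = peak_lr * spike_norm
reduction = unguarded / guarded

print(f"guarded={guarded:.1e}  unguarded={unguarded:.3f}  reduction={reduction:,.0f}x")
```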


Clipping is the cheapest stability win in transformer training: 3 small CUDA kernels, microseconds per step, decisive impact on whether 120M+ models converge or collapse.