Plain SGD Cannot Train ANDREA
Stochastic Gradient Descent, the Starting Point
Backprop computes a gradient g for every parameter. Plain stochastic gradient descent (SGD) updates each parameter with p -= lr * g. One learning rate, one direction per step, no memory of past gradients.
Plain SGD breaks at scale for two reasons:
1. Gradients have wildly different magnitudes across parameters. An embedding for a rare token receives a tiny gradient most steps; a layernorm scale receives a large one. One learning rate cannot suit both.
2. Gradients oscillate. A noisy mini-batch from a 16-source corpus pushes a parameter left, then right, then left. Plain SGD wastes steps fighting itself.
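Both failure modes can be reproduced with a few lines of arithmetic. The sketch below is illustrative only: the gradient magnitudes (1e-4 & 10.0), the learning rate, & the sign-flipping gradient stream are hypothetical values, not measurements from ANDREA.

```python
def sgd_step(p, g, lr):
    """Plain SGD update for one scalar parameter: p -= lr * g."""
    return p - lr * g

lr = 0.01  # hypothetical learning rate

# Failure 1: one learning rate, very different gradient magnitudes.
embedding_move = abs(sgd_step(0.5, 1e-4, lr) - 0.5)  # tiny gradient (rare-token embedding)
layernorm_move = abs(sgd_step(1.0, 10.0, lr) - 1.0)  # large gradient (layernorm scale)

# Failure 2: a gradient whose sign flips every step.
p = 1.0
distance_travelled = 0.0
for step in range(10):
    g = 1.0 if step % 2 == 0 else -1.0
    new_p = sgd_step(p, g, lr)
    distance_travelled += abs(new_p - p)
    p = new_p
net_progress = abs(p - 1.0)

print(embedding_move, layernorm_move)    # the two moves differ by about five orders of magnitude
print(distance_travelled, net_progress)  # lots of movement, almost no net progress
```

With these numbers the layernorm parameter moves about 100,000 times further than the embedding, & the oscillating parameter ends ten steps later almost exactly where it started.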
Adam (Kingma & Ba, 2015) fixes both with two running averages per parameter.
First Moment & Second Moment
m: Smoothed Direction
The first moment m averages recent gradients exponentially:
m = beta1 * m + (1 - beta1) * g
with beta1 = 0.9. After several steps, m carries a smoothed direction; one bad batch barely shifts it.
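A minimal sketch of that running average, assuming a hypothetical gradient stream of steady +1.0 values interrupted by a single "bad batch" of -1.0:

```python
def update_first_moment(m, g, beta1=0.9):
    """Exponential moving average of gradients (Adam's first moment)."""
    return beta1 * m + (1 - beta1) * g

# Hypothetical gradient stream: 30 consistent batches, one bad batch, one more good batch.
grads = [1.0] * 30 + [-1.0] + [1.0]

m = 0.0
history = []
for g in grads:
    m = update_first_moment(m, g)
    history.append(m)

m_before_bad = history[29]  # after 30 consistent gradients
m_after_bad = history[30]   # right after the single opposing gradient
print(m_before_bad, m_after_bad)
```

The raw gradient swings from +1.0 to -1.0, but m drops by only about 0.2 & keeps its sign.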
v: Smoothed Magnitude
The second moment v averages recent squared gradients:
v = beta2 * v + (1 - beta2) * g^2
with beta2 = 0.999. v tracks how big each parameter's gradient typically gets. Big-gradient parameters get a large v; tiny-gradient parameters get a small v.
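The same kind of sketch for v, again with hypothetical constant gradients of 1e-4 & 10.0 standing in for a small-gradient & a large-gradient parameter:

```python
def update_second_moment(v, g, beta2=0.999):
    """Exponential moving average of squared gradients (Adam's second moment)."""
    return beta2 * v + (1 - beta2) * g * g

v_small = 0.0  # parameter that always sees a gradient of 1e-4
v_large = 0.0  # parameter that always sees a gradient of 10.0
for _ in range(5000):
    v_small = update_second_moment(v_small, 1e-4)
    v_large = update_second_moment(v_large, 10.0)

print(v_small, v_large)  # each approaches its own g^2: about 1e-8 and about 100
```

With beta2 = 0.999 the average warms up slowly: even after 5000 steps each v is still slightly below its g^2.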
Per-Parameter Adaptive Learning Rate
Dividing the smoothed direction by the square root of the smoothed squared magnitude rescales every parameter's step to a comparable size:
adam_step = m / (sqrt(v) + eps)
Tiny-gradient embeddings get scaled up; large-gradient layernorms get scaled down. One global lr now suits every parameter.
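A small sketch of the rescaling, assuming both averages have fully warmed up on a constant gradient (so m = g & v = g^2). The gradient values are hypothetical, & eps is placed outside the square root to match the AdamW update later in this lesson.

```python
import math

def adam_direction(m, v, eps=1e-8):
    """Smoothed direction divided by the root of the smoothed squared magnitude."""
    return m / (math.sqrt(v) + eps)

# Hypothetical steady-state values: m = g and v = g^2 for each parameter.
small_g = 1e-4   # rare-token embedding
large_g = 10.0   # layernorm scale

small_step = adam_direction(m=small_g, v=small_g * small_g)
large_step = adam_direction(m=large_g, v=large_g * large_g)
print(small_step, large_step)  # both land close to 1.0
```

Gradients that differ by five orders of magnitude produce steps of nearly the same size, so lr alone sets how far each parameter moves.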
Reading the Moments
Why Early Steps Need Bias Correction
Cold-Start Bias
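m & v start at zero. After step 1, m = 0.1 * g_1 & v = 0.001 * g_1^2, so m undershoots the gradient by 10x & v undershoots the squared gradient by 1000x. The two errors do not cancel: the uncorrected ratio m / sqrt(v) has magnitude 0.1 / sqrt(0.001), about 3.16, where unbiased estimates would give 1. Without correction, the first steps are miscalibrated (here roughly 3x too large) during the precious early steps when representations form.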
The Correction
Adam scales each estimate by 1 / (1 - beta^t) where t is the step number:
m_hat = m / (1 - beta1^t)
v_hat = v / (1 - beta2^t)
At step 1 with beta1 = 0.9, the divisor is 1 - 0.9 = 0.1, so m_hat = m / 0.1 = 10 * m = g_1. Likewise with beta2 = 0.999, v_hat = v / 0.001 = g_1^2. The bias-corrected estimates match what a fully warmed-up average of a constant gradient g_1 would give. As t grows, beta^t approaches 0, the correction factor approaches 1, & corrected & uncorrected values converge.
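The step-1 arithmetic can be checked directly. In this sketch g1 = 2.0 is a hypothetical first gradient; everything else follows the formulas above.

```python
def bias_corrected(m, v, t, beta1=0.9, beta2=0.999):
    """Undo the cold-start bias of both running averages at step t (t starts at 1)."""
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    return m_hat, v_hat

beta1, beta2 = 0.9, 0.999
g1 = 2.0  # hypothetical first gradient

# One update from zero-initialized moments.
m = (1 - beta1) * g1       # 0.1 * g1
v = (1 - beta2) * g1 ** 2  # 0.001 * g1^2

m_hat, v_hat = bias_corrected(m, v, t=1)
print(m_hat, v_hat)  # recovers g1 and g1^2 up to floating-point rounding

# Much later in training the correction factor is close to 1.
late_factor = 1 / (1 - beta1 ** 100)
print(late_factor)
```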
Decoupled Weight Decay (the AdamW Innovation)
L2 Regularization vs Weight Decay
Classic L2 regularization adds a penalty to the loss: L_total = L_data + (lambda / 2) * sum(p^2). Backprop sees that penalty as part of the gradient: g_total = g_data + lambda * p. The L2 term therefore flows through Adam's m & v updates, getting smoothed & rescaled by per-parameter magnitudes.
Loshchilov & Hutter (2019) showed that routing the regularizer through Adam's moment estimates distorts it: the penalty gets divided by the same sqrt(v_hat) as the data gradient. Adam's adaptive scaling shrinks weight decay on large-gradient parameters (where decay should fight overfitting hardest) & amplifies it on small-gradient ones.
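A rough sketch of that distortion, under a simplifying steady-state assumption (constant gradient, so m = g_total & v = g_total^2). The parameter value, gradient magnitudes, & lambda = 0.01 are hypothetical.

```python
import math

def l2_share_of_adam_direction(p, g_data, lam, eps=1e-8):
    """Part of Adam's normalized direction contributed by the L2 term.

    Steady-state approximation: m = g_total and v = g_total^2, so the L2 term
    lam * p ends up divided by |g_total| like everything else.
    """
    g_total = g_data + lam * p
    return (lam * p) / (math.sqrt(g_total * g_total) + eps)

p, lam = 1.0, 0.01
raw_decay = lam * p  # what the regularizer asks for before Adam rescales it

on_large_grad = l2_share_of_adam_direction(p, g_data=10.0, lam=lam)
on_small_grad = l2_share_of_adam_direction(p, g_data=1e-4, lam=lam)
print(raw_decay, on_large_grad, on_small_grad)
```

For the same weight & the same lambda, the large-gradient parameter receives roughly a tenth of the raw decay while the small-gradient parameter receives close to a hundred times it.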
AdamW: Apply Decay Directly
AdamW decouples weight decay from the gradient. Decay applies to each parameter directly during the parameter update, never touching m or v:
p -= lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * p)
Two terms now drive each step:
1. Adam term: m_hat / (sqrt(v_hat) + eps) rescales gradient direction by per-parameter magnitude history.
2. Decay term: weight_decay * p shrinks every parameter toward zero, uniformly, without going through Adam's smoothing.
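ANDREA-120M v2 sets weight_decay = 0.01. Because the decay term is multiplied by lr, each step shrinks every parameter by a fraction lr * 0.01 of its current value (far less than 1% at typical learning rates), in addition to whatever the Adam term does.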
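Putting the pieces together, a single-parameter AdamW step might look like the sketch below. It is a scalar illustration of the update written above, not ANDREA's training code; the learning rate 3e-4 is a hypothetical value chosen only to make the arithmetic concrete.

```python
import math

def adamw_step(p, g, m, v, t, lr, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01):
    """One AdamW update for a scalar parameter. t is the 1-based step number."""
    m = beta1 * m + (1 - beta1) * g       # smoothed direction
    v = beta2 * v + (1 - beta2) * g * g   # smoothed squared magnitude
    m_hat = m / (1 - beta1 ** t)          # bias correction
    v_hat = v / (1 - beta2 ** t)
    adam_term = m_hat / (math.sqrt(v_hat) + eps)
    decay_term = weight_decay * p         # applied to p directly, never enters m or v
    p = p - lr * (adam_term + decay_term)
    return p, m, v

lr = 3e-4  # hypothetical learning rate

# Zero gradient isolates the decay term: p shrinks by lr * weight_decay * p.
p_decay_only, _, _ = adamw_step(p=1.0, g=0.0, m=0.0, v=0.0, t=1, lr=lr)
decay_only_shrink = 1.0 - p_decay_only

# A nonzero first gradient adds an Adam term of magnitude close to 1.
p_full, m1, v1 = adamw_step(p=1.0, g=0.5, m=0.0, v=0.0, t=1, lr=lr)
full_move = 1.0 - p_full
print(decay_only_shrink, full_move)
```

With these values the decay-only shrink is lr * weight_decay = 3e-6 of the parameter per step, & the full step is about lr * 1.01.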
Why Decoupled Matters
Empirical Evidence
v1 Collapse (no weight decay)
ANDREA-120M v1 trained for 165K steps with vanilla Adam. Sample outputs:
- Step 80K: region region region region region region region
- Step 110K: ''''' ''''' '' ''' '' ''' '''?' ''' ' '' '' '
- Step 140K: games, games, games, games, games, games, games
- Step 165K: Budy Budy Budy Budy Budy Budy Budy Budy Budy
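Loss numbers stayed reasonable (EMA minimum 3.23 at step 110K, vs random-chance 9.04). Loss alone hides repetition collapse: cross-entropy is averaged over next-token predictions made with the real preceding text as context, whereas sampling feeds the model its own output, so a reasonable loss can coexist with samples that lock onto one token.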
v2 Stability (weight_decay = 0.01)
v2 added AdamW (plus gradient clipping, LR warmup, sample monitoring). At step ~112K, samples produced:
- Carolina parakeet was declared extinct in 1939 (factually correct)
- The Fourier transform decomposes signals into frequency components (textbook definition)
- Rain's rhythmic refrain, Rivulets on the window, Respite from life's pain (haiku constraint satisfied)
External grading rated v2 samples 9.5/10, calling them "impressive coherence & knowledge retention at this scale."
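The contrast between the v1 & v2 samples can be made numeric with a crude repetition check. The metric below (distinct words divided by total words) & the whitespace split are illustrative assumptions; the lesson does not specify how ANDREA's sample monitoring works.

```python
def distinct_ratio(sample):
    """Fraction of whitespace-separated words in a sample that are unique."""
    words = sample.split()
    if not words:
        return 0.0
    return len(set(words)) / len(words)

# Samples quoted above.
v1_step_80k = "region region region region region region region"
v2_step_112k = "The Fourier transform decomposes signals into frequency components"

v1_ratio = distinct_ratio(v1_step_80k)
v2_ratio = distinct_ratio(v2_step_112k)
print(v1_ratio, v2_ratio)
```

A ratio near zero flags the kind of collapse that the loss curve missed; where to set an alarm threshold would have to be tuned on real samples.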
The 12M Survived without AdamW. Why?
ANDREA-12M trained on vanilla Adam without collapse. At 12M parameters, weight matrices stay small enough that Adam's adaptive scaling cannot push individual weights into the runaway magnitudes that drive repetition. At 120M scale, weight magnitudes drift further per step & accumulate; uniform decay applies a constant restoring force toward zero. Decoupled weight decay matters more as a model scales.
Choosing weight_decay = 0.01
Adjacent Activities
AdamW interlocks with three sibling activities in this course:
- Activity 11: LR warmup + cosine decay. AdamW alone cannot save a model from instant peak learning rate on freshly initialized weights. Warmup ramps lr over 2000 steps so AdamW's bias correction & weight decay get time to stabilize representations.
- Activity 12: Gradient clipping. AdamW assumes gradients have bounded magnitude. Source transitions every 7 to 42 steps in ANDREA's bandit produce occasional gradient spikes; clipping caps them at L2 norm 1.0 BEFORE AdamW touches m, v, or p.
- Activity 13: FP32 / FP16 / FP8 precision. AdamW stores m & v per parameter, adding optimizer state twice the size of the weights themselves (weights plus state take 3x the memory of weights alone). Relative to FP32, FP16 cuts that footprint in half; FP8 cuts it again. Precision choices interact with optimizer stability.
AdamW, warmup, clipping, & precision form a four-leaf clover. Drop one leaf, watch ANDREA collapse.
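The three sibling pieces can each be sketched in a few lines. Only the 2000-step warmup & the clipping norm of 1.0 come from this lesson; the peak learning rate, total step count, linear warmup shape, & the exact 120,000,000 parameter count are hypothetical placeholders.

```python
import math

def lr_at(step, peak_lr, total_steps, warmup_steps=2000, min_lr=0.0):
    """Linear warmup to peak_lr, then cosine decay to min_lr (assumed schedule shape)."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))

def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale a list of scalar gradients so their joint L2 norm is at most max_norm."""
    total = math.sqrt(sum(g * g for g in grads))
    if total <= max_norm:
        return list(grads)
    scale = max_norm / total
    return [g * scale for g in grads]

def adam_state_bytes(n_params, bytes_per_value):
    """Memory for m and v: two extra values per parameter."""
    return 2 * n_params * bytes_per_value

peak_lr, total_steps = 3e-4, 100_000  # hypothetical
first_lr = lr_at(0, peak_lr, total_steps)
peak_reached = lr_at(1999, peak_lr, total_steps)
final_lr = lr_at(total_steps, peak_lr, total_steps)

clipped = clip_by_global_norm([3.0, 4.0])  # a spike with L2 norm 5.0

n_params = 120_000_000  # nominal size, assumed exact for the arithmetic
weights_fp32 = n_params * 4
state_fp32 = adam_state_bytes(n_params, 4)
print(first_lr, peak_reached, final_lr, clipped, weights_fp32, state_fp32)
```

In a training loop the order would be: compute gradients, clip them, look up lr for the current step, then run the AdamW update.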