Plain SGD Cannot Train ANDREA

Stochastic Gradient Descent, the Starting Point

Backprop computes a gradient g for every parameter. Plain stochastic gradient descent (SGD) updates each parameter with p -= lr * g. One learning rate, one direction per step, no memory of past gradients.
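The update rule above can be sketched in a few lines of plain Python. The parameter values, gradients, and learning rate here are made up for illustration; they are not ANDREA's settings.

```python
def sgd_step(params, grads, lr):
    """Plain SGD: p -= lr * g for every parameter, with no state kept between steps."""
    return [p - lr * g for p, g in zip(params, grads)]

# Hypothetical values, chosen only to show the arithmetic.
params = [1.0, -2.0]
grads = [0.5, -0.25]
new_params = sgd_step(params, grads, lr=0.1)
print(new_params)
```

Every parameter moves by the same lr times its own raw gradient, which is exactly why mismatched gradient scales become a problem.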


Plain SGD breaks at scale for two reasons:


1. Gradients have wildly different magnitudes across parameters. An embedding for a rare token receives a tiny gradient most steps; a layernorm scale receives a large one. One learning rate cannot suit both.

2. Gradients oscillate. A noisy mini-batch from a 16-source corpus pushes a parameter left, then right, then left. Plain SGD wastes steps fighting itself.


Adam (Kingma & Ba, 2015) fixes both with two running averages per parameter.

First Moment & Second Moment

m: Smoothed Direction

The first moment m averages recent gradients exponentially:


m = beta1 * m + (1 - beta1) * g


with beta1 = 0.9. After several steps, m carries a smoothed direction; one bad batch barely shifts it.
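A minimal sketch of the first-moment update, fed with a made-up gradient sequence that flips sign every step (the numbers are illustrative assumptions, not measurements from ANDREA):

```python
def update_first_moment(m, g, beta1=0.9):
    """Exponential moving average of gradients: m = beta1 * m + (1 - beta1) * g."""
    return beta1 * m + (1 - beta1) * g

# Hypothetical oscillating gradients, roughly +1 then -1.
grads = [1.0, -0.8, 1.2, -0.9, 1.1, -1.0, 1.3, -0.7]

m = 0.0
history = []
for g in grads:
    m = update_first_moment(m, g)
    history.append(m)

print([round(x, 4) for x in history])
```

The raw gradients swing by about 1 in each direction, while m stays far smaller because consecutive swings largely cancel.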


v: Smoothed Magnitude

The second moment v averages recent squared gradients:


v = beta2 * v + (1 - beta2) * g^2


with beta2 = 0.999. v tracks how big each parameter's gradient typically gets. Big-gradient parameters get a large v; tiny-gradient parameters get a small v.
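As a rough illustration, the sketch below runs the second-moment update for two hypothetical parameters: one that always sees a gradient of 10.0 and one that always sees 0.001. Both values and the step count are assumptions chosen to make the contrast visible.

```python
def update_second_moment(v, g, beta2=0.999):
    """Exponential moving average of squared gradients."""
    return beta2 * v + (1 - beta2) * g * g

v_large, v_small = 0.0, 0.0
for _ in range(1000):
    v_large = update_second_moment(v_large, 10.0)   # hypothetical large-gradient parameter
    v_small = update_second_moment(v_small, 0.001)  # hypothetical tiny-gradient parameter

print(v_large, v_small)
```

The two v values differ by the square of the gradient ratio, so v records each parameter's typical gradient scale.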


Per-Parameter Adaptive Learning Rate

Dividing the smoothed direction m by the square root of v (the smoothed squared gradient) puts every parameter on a comparable footing:


adam_step = m / (sqrt(v) + eps)


Tiny-gradient embeddings get scaled up; large-gradient layernorms get scaled down. One global lr now suits every parameter.
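The rescaling can be checked numerically. This sketch assumes two hypothetical parameters with constant gradients 10.0 and 0.001, and places eps outside the square root to match the AdamW update rule given later in this lesson.

```python
import math

def run_moments(g, steps, beta1=0.9, beta2=0.999):
    """Feed the same gradient g for `steps` steps and return (m, v)."""
    m, v = 0.0, 0.0
    for _ in range(steps):
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
    return m, v

def adam_step(m, v, eps=1e-8):
    """Smoothed direction divided by the root of the smoothed squared gradient."""
    return m / (math.sqrt(v) + eps)

m_big, v_big = run_moments(10.0, steps=1000)     # hypothetical large-gradient parameter
m_small, v_small = run_moments(0.001, steps=1000)  # hypothetical tiny-gradient parameter

step_big = adam_step(m_big, v_big)
step_small = adam_step(m_small, v_small)
print(step_big, step_small)
```

The smoothed gradients differ by a factor of about 10,000, yet the two steps come out nearly identical. They are not exactly 1 because this sketch omits the bias correction covered in the next section.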

Reading the Moments

Two parameters get the same gradient `g = 0.1` on this step. Parameter A has accumulated `v = 0.0001` over many steps; parameter B has `v = 1.0`. Using the simplified step `adam_step = g / (sqrt(v) + eps)`, which parameter receives the larger update this step, & why does Adam want it that way?

Why Early Steps Need Bias Correction

Cold-Start Bias

m & v start at zero. After step 1, m = 0.1 * g_1 & v = 0.001 * g_1^2. Both estimates badly undershoot the long-run average, & they undershoot by different factors. Without correction the ratio m / sqrt(v) is distorted: at step 1 it equals 0.1 / sqrt(0.001), about 3.2 in magnitude instead of 1, so the earliest steps are mis-scaled exactly when representations form.


The Correction

Adam scales each estimate by 1 / (1 - beta^t) where t is the step number:


m_hat = m / (1 - beta1^t)


v_hat = v / (1 - beta2^t)


At step 1 with beta1 = 0.9, the divisor is 1 - 0.9 = 0.1, so m_hat = m / 0.1 = 10 * m = g_1: the corrected estimate recovers the full first gradient instead of a tenth of it. Likewise v_hat = v / 0.001 = g_1^2. As t grows, beta^t approaches 0, the correction factor approaches 1, & corrected & uncorrected values converge.
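A short sketch of the correction factor, using the lesson's beta1 = 0.9 and beta2 = 0.999. The first gradient g1 = 0.5 is an arbitrary illustrative value.

```python
def bias_correction(beta, t):
    """Factor 1 / (1 - beta^t) applied to a zero-initialized moving average at step t."""
    return 1.0 / (1.0 - beta ** t)

g1 = 0.5                      # hypothetical first gradient
m = (1 - 0.9) * g1            # first moment after step 1, starting from 0
v = (1 - 0.999) * g1 ** 2     # second moment after step 1, starting from 0

m_hat = m * bias_correction(0.9, 1)
v_hat = v * bias_correction(0.999, 1)
print(m_hat, v_hat)

# The factor decays toward 1 as t grows; beta2's factor decays much more slowly.
for t in (1, 10, 100, 1000, 10000):
    print(t, bias_correction(0.9, t), bias_correction(0.999, t))
```

After correction, m_hat equals g1 and v_hat equals g1 squared at step 1, and the loop shows both factors shrinking toward 1 over time.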

Decoupled Weight Decay (the AdamW Innovation)

L2 Regularization vs Weight Decay

Classic L2 regularization adds a penalty to the loss: L_total = L_data + (lambda / 2) * sum(p^2). Backprop sees that penalty as part of the gradient: g_total = g_data + lambda * p. The L2 term therefore flows through Adam's m & v updates, where it is smoothed & rescaled by each parameter's gradient-magnitude history.


Loshchilov & Hutter (2019) showed that routing the regularizer through Adam's moment estimates distorts it: the penalty no longer acts as a uniform pull toward zero. Adam's adaptive scaling shrinks weight decay on large-gradient parameters (where decay should fight overfitting hardest) & amplifies it on small-gradient ones.
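The distortion can be illustrated with simple arithmetic. The sketch below makes a steady-state simplification (v is assumed to have settled at g_total squared) and uses made-up gradient sizes. It compares the decay part of an Adam step when lambda * p is folded into the gradient against a decoupled decay of lambda * p.

```python
import math

def l2_decay_through_adam(p, g_data, lam, eps=1e-8):
    """Decay contribution to the Adam step when lam * p is added to the gradient.

    Simplifying assumption: v has settled at g_total ** 2.
    """
    g_total = g_data + lam * p
    v = g_total ** 2
    return lam * p / (math.sqrt(v) + eps)

def decoupled_decay(p, lam):
    """Decay contribution when applied directly to the parameter (before the lr factor)."""
    return lam * p

p, lam = 1.0, 0.01
big = l2_decay_through_adam(p, g_data=10.0, lam=lam)    # hypothetical large-gradient parameter
small = l2_decay_through_adam(p, g_data=0.01, lam=lam)  # hypothetical small-gradient parameter
direct = decoupled_decay(p, lam)
print(big, small, direct)
```

Under these assumptions the large-gradient parameter receives far less decay than lambda * p and the small-gradient parameter far more, while the decoupled form gives both the same pull.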


AdamW: Apply Decay Directly

AdamW decouples weight decay from the gradient. Decay applies to each parameter directly during the parameter update, never touching m or v:


p -= lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * p)


Two terms now drive each step:


1. Adam term: m_hat / (sqrt(v_hat) + eps) rescales gradient direction by per-parameter magnitude history.

2. Decay term: weight_decay * p shrinks every parameter toward zero, uniformly, without going through Adam's smoothing.


ANDREA-120M v2 sets weight_decay = 0.01. Because the decay term is multiplied by lr, each step shrinks every parameter by lr * 0.01 of its current value (not a full 1%), in addition to whatever the Adam term does.
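Putting the pieces together, a single-parameter AdamW step might look like the sketch below. It follows the formulas in this lesson; the learning rate, starting value, and gradient are hypothetical, and a real implementation would operate on whole tensors rather than one float.

```python
import math

def adamw_step(p, g, m, v, t, lr,
               beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01):
    """One AdamW update for a single scalar parameter p at step t (t starts at 1)."""
    # Moment updates see only the data gradient, never the decay term.
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    # Bias correction for the zero-initialized moments.
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # Adam term plus decoupled decay, both scaled by lr.
    p = p - lr * (m_hat / (math.sqrt(v_hat) + eps) + weight_decay * p)
    return p, m, v

# Hypothetical values for illustration only.
p, m, v = 2.0, 0.0, 0.0
lr = 1e-3
p_new, m, v = adamw_step(p, 0.5, m, v, t=1, lr=lr)
print(p_new)
```

At step 1 the Adam term is close to 1 in magnitude and the decay term is 0.01 * 2.0 = 0.02, so the parameter moves by about lr * 1.02.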


AdamW Optimizer Step

Why Decoupled Matters

ANDREA-120M v1 used vanilla Adam (no weight decay) & collapsed into repetition loops by step 110K. v2 used AdamW with `weight_decay = 0.01` & produced coherent multi-paragraph text. Explain (a) what specific term gets added in AdamW that vanilla Adam lacks, & (b) why putting that term INSIDE the gradient (classic L2) would be worse than AdamW's decoupled placement OUTSIDE.

Empirical Evidence

v1 Collapse (no weight decay)

ANDREA-120M v1 trained for 165K steps with vanilla Adam. Sample outputs:


- Step 80K: region region region region region region region

- Step 110K: ''''' ''''' '' ''' '' ''' '''?' ''' ' '' '' '

- Step 140K: games, games, games, games, games, games, games

- Step 165K: Budy Budy Budy Budy Budy Budy Budy Budy Budy


Loss numbers stayed reasonable (EMA minimum 3.23 at step 110K, vs random-chance 9.04). Loss alone hides repetition collapse: a model that memorizes one token forever achieves low cross-entropy on every step that token appears.


v2 Stability (weight_decay = 0.01)

v2 added AdamW (plus gradient clipping, LR warmup, sample monitoring). At step ~112K, samples produced:


- Carolina parakeet was declared extinct in 1939 (factually correct)

- The Fourier transform decomposes signals into frequency components (textbook definition)

- Rain's rhythmic refrain, Rivulets on the window, Respite from life's pain (haiku constraint satisfied)


External grading rated v2 samples 9.5/10, calling them "impressive coherence & knowledge retention at this scale."


The 12M Survived without AdamW. Why?

ANDREA-12M trained on vanilla Adam without collapse. At 12M parameters, weight matrices stay small enough that Adam's adaptive scaling cannot push individual weights into the runaway magnitudes that drive repetition. At 120M scale, weight magnitudes drift further per step & accumulate; uniform decay applies a constant restoring force toward zero. Decoupled weight decay matters more as a model scales.

Choosing weight_decay = 0.01

Why might `weight_decay = 0.01` work for a 120M-parameter model but a value 100x larger (`weight_decay = 1.0`) destroy training? Reason about the update rule `p -= lr * (adam_term + weight_decay * p)`. Pick a representative `p`, plug in `weight_decay = 1.0` for a single step, & describe what happens to `p` after a few steps.

Adjacent Activities

AdamW interlocks with three sibling activities in this course:


- Activity 11: LR warmup + cosine decay. AdamW alone cannot save a model from instant peak learning rate on freshly initialized weights. Warmup ramps lr over 2000 steps so AdamW's bias correction & weight decay get time to stabilize representations.

- Activity 12: Gradient clipping. AdamW assumes gradients have bounded magnitude. Source transitions every 7 to 42 steps in ANDREA's bandit produce occasional gradient spikes; clipping caps them at L2 norm 1.0 BEFORE AdamW touches m, v, or p.

- Activity 13: FP32 / FP16 / FP8 precision. AdamW stores m & v for every parameter, so optimizer state alone takes twice the memory of the weights at the same precision. Storing values in FP16 instead of FP32 halves the footprint; FP8 halves it again. Precision choices interact with optimizer stability.


AdamW, warmup, clipping, & precision form a four-leaf clover. Drop one leaf, watch ANDREA collapse.
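Two of those leaves are small enough to sketch here. The code below assumes clipping is applied to the global L2 norm over all gradients and that warmup is linear; the lesson states only the cap (1.0) and the warmup length (2000 steps), so both choices, along with the peak_lr argument, are assumptions. Cosine decay is left out because its settings are not given here.

```python
import math

def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale gradients so their combined L2 norm is at most max_norm."""
    total = math.sqrt(sum(g * g for g in grads))
    if total <= max_norm:
        return list(grads)
    scale = max_norm / total
    return [g * scale for g in grads]

def warmup_lr(step, peak_lr, warmup_steps=2000):
    """Linear ramp from 0 to peak_lr over warmup_steps, then hold at peak_lr."""
    return peak_lr * min(1.0, step / warmup_steps)

# A hypothetical gradient spike with norm 5.0 gets scaled down to norm 1.0.
clipped = clip_by_global_norm([3.0, 4.0])
print(clipped)

# Halfway through warmup the learning rate is half of the (hypothetical) peak.
print(warmup_lr(1000, peak_lr=1.0))
```

In a training loop, the clipped gradients would be what the optimizer step consumes, and warmup_lr would supply lr for that step.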

Optimizer Reflection

Pick one ANDREA-120M v1 sample (e.g., `region region region region`) & describe how AdamW's `weight_decay * p` term, applied every step from step 0, would have prevented that specific failure mode. One paragraph.