Zwykły SGD Nie Może Trenować ANDREA
Stochastic Gradient Descent, Punkt Startowy
Backprop oblicza gradient g dla każdego parametru. Zwykły stochastic gradient descent (SGD) aktualizuje każdy parametr za pomocą p -= lr * g. Jeden learning rate, jeden kierunek na krok, brak pamięci poprzednich gradientów.
Zwykły SGD psuje się na dużą skalę z dwóch powodów:
1. Gradienty mają bardzo różne wartości wśród parametrów. Embedding dla rzadkiego tokena otrzymuje malutki gradient przez większość kroków; skala layernorm otrzymuje duży. Jedna szybkość uczenia nie pasuje do obu.
2. Gradienty oscylują. Hałaśliwa mini-batch z korpusu 16 źródeł pcha parametr w lewo, potem w prawo, potem w lewo. Zwykły SGD marnuje kroki walcząc sam ze sobą.
Adam (Kingma & Ba, 2015) naprawia oba problemy za pomocą dwóch średnich kroczących na parametr.
Pierwsza Chwila & Druga Chwila
m: Wygładzony Kierunek
Pierwsza chwila m uśrednia ostatnie gradienty wykładniczo:
m = beta1 m + (1 - beta1) g
z beta1 = 0.9. Po kilku krokach m zawiera wygładzony kierunek; jeden zły batch ledwo go przesuwa.
v: Wygładzona wielkość
Drugie momento v uśrednia ostatnie gradienty do kwadratu:
v = beta2 v + (1 - beta2) g^2
z beta2 = 0.999. v śledzi, jak duże są typowo gradienty każdego parametru. Parametry z dużymi gradientami otrzymują dużą wartość v; parametry z małymi gradientami otrzymują małą wartość v.
Adaptacyjna szybkość uczenia się per parametr
Podział wygładzonego kierunku przez pierwiastek kwadratowy z wygładzonej wielkości przeskalowuje każdy parametr na porównywalną skalę:
adam_step = m / sqrt(v + eps)
Małe gradienty embeddingów są skalowane w górę; duże gradienty layernormów są skalowane w dół. Jeden globalny lr teraz pasuje do każdego parametru.
Czytanie Momentów
Dlaczego wczesne kroki potrzebują korekty biasu
Cold-Start Bias
m & v start at zero. After step 1, m = 0.1 g_1 & v = 0.001 g_1^2. Both estimates undershoot a long-run average dramatically. Without correction, the optimizer starts timid & ramps up slowly, wasting precious early steps when representations form.
Korekta
Adam skaluje każdą estymatę przez 1 / (1 - beta^t) gdzie t to numer kroku:
m_hat = m / (1 - beta1^t)
v_hat = v / (1 - beta2^t)
W kroku 1 przy beta1 = 0.9 dzielnik (1 - 0.9) = 0.1, więc m_hat = m / 0.1 = 10 * m. Skorygowana estymacja błędu odpowiada temu, co przewiduje długoterminowa średnia. W miarę jak t rośnie, beta^t zbliża się do 0, korekta zbliża się do 1, a skorygowane i nieskorygowane wartości zbiegają się.
Rozdzielone Wygašanie Wagi (innowacja AdamW)
Regularizacja L2 vs Wygašanie Wagi
Klasyczna regularizacja L2 dodaje karę do straty: L_total = L_data + (lambda / 2) sum(p^2). Backprop widzi tę karę jako część gradientu: g_total = g_data + lambda p. Termin L2 przepływa przez aktualizacje m i v w Adamie, ulegając wygładzeniu i przeskalowaniu przez magnitudy per-parametru.
Loshchilov & Hutter (2019) udowodnili, że wygładzanie regularizatora za pomocą Adam korumpuje oba. Adaptacyjne skalowanie Adam zmniejsza weight decay na parametrach o dużych gradientach (gdzie decay powinien najmocniej zwalczać przetrenowanie) i wzmacnia go na tych o małych gradientach.
AdamW: Zastosuj Decay Bezpośrednio
AdamW oddziela weight decay od gradientu. Decay jest stosowany bezpośrednio do każdego parametru podczas aktualizacji parametru, nigdy nie dotykając m ani v:
p -= lr (m_hat / (sqrt(v_hat) + eps) + weight_decay p)
Dwa składniki teraz napędzają każdy krok:
1. Termin Adama: m_hat / (sqrt(v_hat) + eps) skaluje kierunek gradientu na podstawie historii magnitudy per-parametru.
2. Termin zaniku: weight_decay * p zmniejsza każdy parametr w kierunku zera, równomiernie, bez przechodzenia przez wygładzanie Adama.
ANDREA-120M v2 ustawia weight_decay = 0.01. Na każdym kroku każdy parametr zmniejsza się o 1% w kierunku zera, oprócz tego, co robi termin Adama.
Dlaczego dekoplowanie ma znaczenie
Dowody empiryczne
Kolaps v1 (brak weight decay)
ANDREA-120M v1 trenowane przez 165K kroków z vanilla Adam. Przykładowe wyjścia:
- Krok 80K: region region region region region region region
- Krok 110K: ''''' ''''' '' ''' '' ''' '''?' ''' ' '' '' '
- Krok 140K: games, games, games, games, games, games, games
- Krok 165K: Budy Budy Budy Budy Budy Budy Budy Budy Budy
Wartości straty pozostały rozsądne (minimum EMA 3,23 w kroku 110K, w porównaniu do losowego 9,04). Sama strata ukrywa załamanie powtarzania: model, który zapamiętuje jeden token na zawsze, osiąga niską entropię krzyżową w każdym kroku, w którym ten token się pojawia.
Stabilność v2 (weight_decay = 0.01)
v2 dodało AdamW (plus obcinanie gradientu, rozgrzewka LR, monitorowanie próbek). W kroku ~112K, wyprodukowane próbki:
- Karoliński papugaryk został uznany za wymarły w 1939 roku (faktycznie poprawne)
- Transformata Fouriera rozkłada sygnały na składowe częstotliwościowe (podręcznikowa definicja)
- Rytmiczny refren deszczu, Strumyczki na oknie, Ukojenie bólu życia (haiku spełnia ograniczenia)
Zewnętrzna ocena oceniła próbki v2 na 9.5/10, nazywając je „imponującą spójnością i retencją wiedzy na tę skalę.”
12M przetrwało bez AdamW. Dlaczego?
ANDREA-12M trenowana na vanilla Adam bez załamania. Przy 12M parametrów, macierze wag pozostają wystarczająco małe, że adaptacyjne skalowanie Adama nie może wypchnąć pojedynczych wag do niekontrolowanych wielkości powodujących powtarzanie. Przy skali 120M, wielkości wag dryfują dalej na krok i się akumulują; jednolity zanik stosuje stałą przywracającą siłę w kierunku zera. Odłączony zanik wag ma większe znaczenie w miarę skalowania modelu.
Wybór weight_decay = 0.01
Pokrewne aktywności
AdamW współdziała z trzema pokrewnymi aktywnościami w tym kursie:
- Aktywność 11: Rozgrzewka LR + cosine decay. Sam AdamW nie uratuje modelu przed natychmiastową maksymalną szybkością uczenia na świeżo zainicjalizowanych wagach. Rozgrzewka stopniowo zwiększa lr przez 2000 kroków, aby korekcja biasu AdamW oraz weight decay miały czas na ustabilizowanie reprezentacji.
- Aktywność 12: Gradient clipping. AdamW zakłada, że gradienty mają ograniczoną wielkość. Źródło przechodzi co 7 do 42 kroków w bandycie ANDREA, co powoduje okazjonalne skoki gradientów; clipping ogranicza je do normy L2 1.0 PRZED dotknięciem m, v lub p przez AdamW.
- Aktywność 13: Precyzja FP32 / FP16 / FP8. AdamW przechowuje m i v na parametr, podwajając zużycie pamięci wag. FP16 redukuje to zużycie o połowę; FP8 redukuje je ponownie. Wybory precyzji oddziałują na stabilność optymalizatora.
AdamW, warmup, clipping i precyzja tworzą czterolistną koniczynę. Zgubisz jeden listek, patrz, jak ANDREA się załamuje.