Звичайний SGD не може навчати ANDREA
Стохастичний градієнтний спуск, точка відліку
Backprop обчислює градієнт g для кожного параметра. Звичайний стохастичний градієнтний спуск (SGD) оновлює кожен параметр за формулою p -= lr * g. Одна швидкість навчання, один напрямок на крок, без пам’яті про минулі градієнти.
Звичайний SGD ламається на великих масштабах з двох причин:
1. Градиєнти мають радикально різні величини серед параметрів. Ембединг для рідкісного токена отримує крихітний градієнт більшість кроків; масштаб layernorm отримує великий. Один learning rate не може підійти обом.
2. Градиєнти коливаються. Шумний міні-батч з корпусу 16 джерел штовхає параметр ліворуч, потім праворуч, потім ліворуч. Звичайний SGD витрачає кроки, борючись сам із собою.
Adam (Kingma & Ba, 2015) виправляє обидві проблеми за допомогою двох ковзних середніх на параметр.
Перший момент & Другий момент
m: Згладжений напрямок
Перший момент m експоненційно усереднює недавні градієнти:
m = beta1 m + (1 - beta1) g
з beta1 = 0.9. Після кількох кроків m несе згладжений напрямок; один поганий батч ледь його змінює.
v: Згладжена величина
Другий момент v усереднює недавні квадрати градієнтів:
v = beta2 v + (1 - beta2) g^2
з beta2 = 0.999. v відстежує, наскільки великим зазвичай є градієнт кожного параметра. Параметри з великими градієнтами отримують великий v; параметри з крихітними градієнтами отримують малий v.
Адаптивна швидкість навчання для кожного параметра
Ділення згладженого напрямку на квадратний корінь згладженої величини переналаштовує кожен параметр на порівнянну основу:
adam_step = m / sqrt(v + eps)
Тонкі градієнтні ембедінги масштабуються вгору; великі градієнтні layernorm масштабуються вниз. Один глобальний lr тепер підходить для кожного параметра.
Читання моментів
Чому ранні кроки потребують корекції зміщення
Зміщення холодного старту
m та v починають з нуля. Після кроку 1, m = 0.1 g_1 та v = 0.001 g_1^2. Обидві оцінки значно недооцінюють довгострокову середню. Без корекції оптимізатор починає обережно та повільно нарощує, витрачаючи дорогоцінні ранні кроки, коли формуються представлення.
Корекція
Adam масштабує кожну оцінку на 1 / (1 - beta^t), де t — номер кроку:
m_hat = m / (1 - beta1^t)
v_hat = v / (1 - beta2^t)
На кроці 1 з beta1 = 0.9 дільник (1 - 0.9) = 0.1, тому m_hat = m / 0.1 = 10 * m. Виправлена оцінка збігається з тим, що передбачає довгострокова середня. Зі зростанням t значення beta^t наближається до 0, виправлення наближається до 1, та виправлені та невиправлені значення збігаються.
Розділена вага занепаду (інновація AdamW)
L2 регуляризація проти ваги занепаду
Класична L2 регуляризація додає штраф до втрат: L_total = L_data + (lambda / 2) sum(p^2). Зворотне поширення бачить цей штраф як частину градієнта: g_total = g_data + lambda p. Член L2 проходить через оновлення m та v Adam, згладжуючись та перемасштабовуючись залежно від величин параметрів.
Loshchilov & Hutter (2019) довели, що згладжування регуляризатора через Adam псує обидва. Адаптивне масштабування Adam зменшує weight decay на параметрах з великими градієнтами (де decay повинен найсильніше боротися з перенавчанням) та підсилює його на параметрах з малими градієнтами.
AdamW: Застосовувати Decay безпосередньо
AdamW відокремлює weight decay від градієнта. Decay застосовується безпосередньо до кожного параметра під час оновлення параметра, ніколи не торкаючись m чи v:
p -= lr (m_hat / (sqrt(v_hat) + eps) + weight_decay p)
Тепер два члени керують кожним кроком:
1. Термін Adam: m_hat / (sqrt(v_hat) + eps) перемасштабує напрямок градієнта з урахуванням історії магнітуди для кожного параметра.
2. Термін затухання: weight_decay * p зменшує кожен параметр до нуля рівномірно, без проходження через згладжування Adam.
ANDREA-120M v2 встановлює weight_decay = 0.01. На кожному кроці кожен параметр зменшується на 1% до нуля, окрім того, що робить термін Adam.
Чому роз'єднане затухання важливе
Емпіричні докази
Колапс v1 (без weight decay)
ANDREA-120M v1 тренувалася 165K кроків з vanilla Adam. Зразки виводів:
- Крок 80K: region region region region region region region
- Крок 110K: ''''' ''''' '' ''' '' ''' '''?' ''' ' '' '' '
- Крок 140K: games, games, games, games, games, games, games
- Крок 165K: Budy Budy Budy Budy Budy Budy Budy Budy Budy
Значення втрат залишалися розумними (EMA мінімум 3.23 на кроці 110K, проти випадкового шансу 9.04). Втрати самі по собі приховують колапс повторення: модель, яка запам'ятовує один токен назавжди, досягає низької крос-ентропії на кожному кроці, де з'являється цей токен.
Стабільність v2 (weight_decay = 0.01)
v2 додав AdamW (плюс обрізання градієнтів, розігрів LR, моніторинг зразків). На кроці ~112K згенеровано зразки:
- Караолінський папужка був оголошений вимерлим у 1939 році (фактично правильно)
- Дискретне перетворення Фур'є розкладає сигнали на частотні компоненти (підручникове визначення)
- Ритмічний рефрен дощу, Струмочки на вікні, Перепочинок від болю життя (вимоги хайку виконані)
Зовнішня оцінка оцінила зразки v2 у 9.5/10, назвавши їх "вражаючою когерентністю та збереженням знань на цьому масштабі."
12M вижив без AdamW. Чому?
ANDREA-12M тренувалася на vanilla Adam без колапсу. При 12M параметрів матриці ваг залишаються достатньо малими, щоб адаптивне масштабування Adam не могло штовхати окремі ваги до неконтрольованих величин, що призводять до повторення. На масштабі 120M величини ваг дрейфують далі на крок та накопичуються; уніформний затухання застосовує постійну відновлювальну силу до нуля. Відокремлене затухання ваг має більше значення при масштабуванні моделі.
Вибір weight_decay = 0.01
Суміжні активності
AdamW переплітається з трьома суміжними активностями в цьому курсі:
- Активність 11: LR warmup + cosine decay. AdamW сам по собі не може врятувати модель від миттєвої пікової швидкості навчання на свіжо ініціалізованих вагах. Warmup поступово підвищує lr протягом 2000 кроків, щоб корекція зміщення AdamW та вага занепаду встигли стабілізувати представлення.
- Активність 12: Gradient clipping. AdamW припускає, що градієнти мають обмежену величину. Джерело переходів кожні 7–42 кроки в бандиті ANDREA створює періодичні сплески градієнтів; кліпінг обмежує їх нормою L2 1.0 ДО того, як AdamW торкнеться m, v чи p.
- Активність 13: Точність FP32 / FP16 / FP8. AdamW зберігає m та v для кожного параметра, подвоюючи обсяг пам'яті тільки для ваг. FP16 скорочує цей обсяг удвічі; FP8 скорочує ще раз. Вибір точності взаємодіє зі стабільністю оптимізатора.
AdamW, warmup, clipping та precision утворюють чотирилисткову конюшину. Випустіть один листочок — і спостерігайте, як ANDREA руйнується.