SGD Simples Não Consegue Treinar ANDREA
Descida de Gradiente Estocástico, o Ponto de Partida
Backprop computa um gradiente g para cada parâmetro. Descida de gradiente estocástico (SGD) simples atualiza cada parâmetro com p -= lr * g. Uma taxa de aprendizado, uma direção por passo, sem memória de gradientes passados.
SGD simples falha em escala por dois motivos:
1. Os gradientes têm magnitudes extremamente diferentes entre parâmetros. Um embedding para um token raro recebe um gradiente minúsculo na maioria dos passos; uma escala de layernorm recebe um grande. Uma taxa de aprendizado não pode servir a ambos.
2. Os gradientes oscilam. Um mini-batch ruidoso de um corpus de 16 fontes empurra um parâmetro para a esquerda, depois para a direita, depois para a esquerda. SGD simples desperdiça passos lutando contra si mesmo.
Adam (Kingma & Ba, 2015) corrige ambos com duas médias móveis por parâmetro.
Primeiro Momento & Segundo Momento
m: Direção Suavizada
O primeiro momento m faz a média de gradientes recentes de forma exponencial:
m = beta1 m + (1 - beta1) g
com beta1 = 0.9. Após vários passos, m carrega uma direção suavizada; um lote ruim mal o desloca.
v: Magnitude Suavizada
O segundo momento v faz a média de gradientes quadrados recentes:
v = beta2 v + (1 - beta2) g^2
com beta2 = 0.999. v rastreia o quão grande o gradiente de cada parâmetro tipicamente fica. Parâmetros com gradientes grandes recebem um v grande; parâmetros com gradientes minúsculos recebem um v pequeno.
Taxa de Aprendizado Adaptativa Por Parâmetro
Dividir a direção suavizada pela raiz quadrada da magnitude suavizada reescala todos os parâmetros para uma base comparável:
adam_step = m / sqrt(v + eps)
Os embeddings de gradiente pequeno são escalados para cima; os layernorms de gradiente grande são escalados para baixo. Um único lr global agora serve para todos os parâmetros.
Lendo os Moments
Por que os Passos Iniciais Precisam de Correção de Viés
Viés de Inicialização Fria
m & v começam em zero. Após o passo 1, m = 0.1 g_1 & v = 0.001 g_1^2. Ambas as estimativas ficam dramaticamente abaixo de uma média de longo prazo. Sem correção, o otimizador começa tímido & acelera lentamente, desperdiçando passos iniciais preciosos quando as representações se formam.
A Correção
Adam escala cada estimativa por 1 / (1 - beta^t) onde t é o número do passo:
m_hat = m / (1 - beta1^t)
v_hat = v / (1 - beta2^t)
No passo 1 com beta1 = 0.9, o divisor (1 - 0.9) = 0.1, então m_hat = m / 0.1 = 10 * m. A estimativa corrigida de viés corresponde ao que a média de longo prazo preveria. À medida que t cresce, beta^t se aproxima de 0, a correção se aproxima de 1, & os valores corrigidos & não corrigidos convergem.
Decaimento de Peso Desacoplado (a Inovação do AdamW)
Regularização L2 vs Decaimento de Peso
A regularização L2 clássica adiciona uma penalidade à perda: L_total = L_data + (lambda / 2) sum(p^2). O backprop vê essa penalidade como parte do gradiente: g_total = g_data + lambda p. O termo L2 flui pelas atualizações de m & v do Adam, sendo suavizado & reescalado por magnitudes por parâmetro.
Loshchilov & Hutter (2019) provaram que suavizar um regularizador através do Adam corrompe ambos. O escalonamento adaptativo do Adam reduz o weight decay em parâmetros com gradientes grandes (onde o decay deveria combater o overfitting mais intensamente) e amplifica em parâmetros com gradientes pequenos.
AdamW: Aplicar Decay Diretamente
O AdamW desacopla o weight decay do gradiente. O decay é aplicado diretamente a cada parâmetro durante a atualização do parâmetro, nunca tocando m ou v:
p -= lr (m_hat / (sqrt(v_hat) + eps) + weight_decay p)
Dois termos agora impulsionam cada passo:
1. Termo Adam: m_hat / (sqrt(v_hat) + eps) reescala a direção do gradiente pela história de magnitude por parâmetro.
2. Termo de decaimento: weight_decay * p reduz todos os parâmetros em direção a zero, uniformemente, sem passar pelo suavização do Adam.
ANDREA-120M v2 define weight_decay = 0.01. A cada passo, cada parâmetro encolhe 1% em direção a zero, além do que o termo Adam faz.
Por Que o Decaimento Desacoplado Importa
Evidência Empírica
Colapso v1 (sem weight decay)
ANDREA-120M v1 treinado por 165K passos com Adam vanilla. Saídas de amostra:
- Passo 80K: region region region region region region region
- Passo 110K: ''''' ''''' '' ''' '' ''' '''?' ''' ' '' '' '
- Passo 140K: games, games, games, games, games, games, games
- Passo 165K: Budy Budy Budy Budy Budy Budy Budy Budy Budy
Os números de perda permaneceram razoáveis (mínimo EMA 3.23 no passo 110K, vs chance aleatória 9.04). A perda sozinha oculta o colapso de repetição: um modelo que memoriza um token para sempre alcança baixa entropia cruzada em todo passo em que esse token aparece.
Estabilidade v2 (weight_decay = 0.01)
v2 adicionou AdamW (mais clipagem de gradiente, aquecimento de LR, monitoramento de amostras). No passo ~112K, amostras produzidas:
- O periquito-de-carolina foi declarado extinto em 1939 (factualmente correto)
- A transformada de Fourier decompõe sinais em componentes de frequência (definição de livro didático)
- Refrão rítmico da chuva, Riachos na janela, Respiro da dor da vida (restrição de haiku satisfeita)
Avaliação externa classificou as amostras v2 como 9.5/10, chamando-as de "impressionante coerência & retenção de conhecimento nesta escala."
Os 12M sobreviveram sem AdamW. Por quê?
ANDREA-12M treinado com Adam vanilla sem colapso. Com 12M parâmetros, as matrizes de pesos permanecem pequenas o suficiente para que o escalonamento adaptativo do Adam não consiga empurrar pesos individuais para magnitudes descontroladas que levam à repetição. Na escala de 120M, as magnitudes dos pesos derivam mais por passo e se acumulam; o decaimento uniforme aplica uma força restauradora constante em direção a zero. O decaimento de peso desacoplado importa mais à medida que o modelo escala.
Escolhendo weight_decay = 0.01
Atividades Adjacentes
AdamW se interliga com três atividades irmãs neste curso:
- Atividade 11: Aquecimento de LR + decaimento cosseno. AdamW sozinho não pode salvar um modelo de uma taxa de aprendizado de pico instantânea em pesos recém-inicializados. O aquecimento aumenta gradualmente lr ao longo de 2000 passos para que a correção de viés do AdamW e a decaimento de peso tenham tempo de estabilizar as representações.
- Atividade 12: Clipagem de gradiente. AdamW assume que os gradientes têm magnitude limitada. Transições de fonte a cada 7 a 42 passos no bandit de ANDREA produzem picos ocasionais de gradiente; a clipagem os limita a norma L2 1.0 ANTES que AdamW toque m, v ou p.
- Atividade 13: Precisão FP32 / FP16 / FP8. AdamW armazena m e v por parâmetro, dobrando a pegada de memória dos pesos sozinhos. FP16 reduz essa pegada pela metade; FP8 reduz novamente. Escolhas de precisão interagem com a estabilidade do otimizador.
AdamW, warmup, clipping e precisão formam um trevo de quatro folhas. Deixe cair uma folha, observe ANDREA colapsar.