SGD Simple No Puede Entrenar ANDREA
Descenso de Gradiente Estocástico, el Punto de Partida
Backprop calcula un gradiente g para cada parámetro. El descenso de gradiente estocástico (SGD) simple actualiza cada parámetro con p -= lr * g. Una tasa de aprendizaje, una dirección por paso, sin memoria de gradientes pasados.
SGD simple falla a escala por dos razones:
1. Los gradientes tienen magnitudes muy diferentes entre parámetros. Un embedding para un token raro recibe un gradiente diminuto en la mayoría de los pasos; una escala de layernorm recibe uno grande. Una sola tasa de aprendizaje no puede adaptarse a ambos.
2. Los gradientes oscilan. Un mini-batch ruidoso de un corpus de 16 fuentes empuja un parámetro hacia la izquierda, luego hacia la derecha, luego hacia la izquierda. SGD simple desperdicia pasos luchando contra sí mismo.
Adam (Kingma & Ba, 2015) corrige ambos con dos promedios móviles por parámetro.
Primer Momento & Segundo Momento
m: Dirección Suavizada
El primer momento m promedia gradientes recientes de forma exponencial:
m = beta1 m + (1 - beta1) g
con beta1 = 0.9. Después de varios pasos, m lleva una dirección suavizada; un lote malo apenas lo desplaza.
v: Magnitud Suavizada
El segundo momento v promedia gradientes al cuadrado recientes:
v = beta2 v + (1 - beta2) g^2
con beta2 = 0.999. v rastrea qué tan grande se vuelve típicamente el gradiente de cada parámetro. Los parámetros con gradientes grandes obtienen un v grande; los parámetros con gradientes diminutos obtienen un v pequeño.
Tasa de Aprendizaje Adaptativa por Parámetro
Dividir la dirección suavizada por la raíz cuadrada de la magnitud suavizada reescala cada parámetro a un pie de igualdad comparable:
adam_step = m / sqrt(v + eps)
Los embeddings de gradiente pequeño se escalan hacia arriba; las layernorms de gradiente grande se escalan hacia abajo. Ahora un lr global es adecuado para cada parámetro.
Leyendo los Moments
Por qué los pasos iniciales necesitan corrección de sesgo
Sesgo de inicio en frío
m & v comienzan en cero. Después del paso 1, m = 0.1 g_1 & v = 0.001 g_1^2. Ambas estimaciones subestiman drásticamente un promedio a largo plazo. Sin corrección, el optimizador comienza tímido y aumenta lentamente, desperdiciando preciados pasos iniciales cuando se forman las representaciones.
La Corrección
Adam escala cada estimación por 1 / (1 - beta^t) donde t es el número de paso:
m_hat = m / (1 - beta1^t)
v_hat = v / (1 - beta2^t)
En el paso 1 con beta1 = 0.9, el divisor (1 - 0.9) = 0.1, por lo que m_hat = m / 0.1 = 10 * m. La estimación corregida de sesgo coincide con lo que predeciría el promedio a largo plazo. A medida que t crece, beta^t se acerca a 0, la corrección se acerca a 1, y los valores corregidos y no corregidos convergen.
Decaimiento de Peso Desacoplado (la Innovación de AdamW)
Regularización L2 vs Decaimiento de Peso
La regularización L2 clásica agrega una penalización a la pérdida: L_total = L_data + (lambda / 2) sum(p^2). La retropropagación ve esa penalización como parte del gradiente: g_total = g_data + lambda p. El término L2 fluye a través de las actualizaciones m y v de Adam, siendo suavizado y reescalado por magnitudes por parámetro.
Loshchilov & Hutter (2019) demostraron que suavizar un regularizador a través de Adam corrompe ambos. La escala adaptativa de Adam reduce la weight decay en parámetros con gradientes grandes (donde la decay debería combatir el sobreajuste con más fuerza) y la amplifica en los de gradientes pequeños.
AdamW: Aplicar Decay Directamente
AdamW desacopla la weight decay del gradiente. La decay se aplica directamente a cada parámetro durante la actualización del parámetro, sin tocar nunca m ni v:
p -= lr (m_hat / (sqrt(v_hat) + eps) + weight_decay p)
Dos términos ahora impulsan cada paso:
1. Término Adam: m_hat / (sqrt(v_hat) + eps) reescala la dirección del gradiente por el historial de magnitud por parámetro.
2. Término de decaimiento: weight_decay * p reduce cada parámetro hacia cero de manera uniforme, sin pasar por el suavizado de Adam.
ANDREA-120M v2 establece weight_decay = 0.01. En cada paso, cada parámetro se reduce un 1% hacia cero, además de lo que haga el término Adam.
Por qué el Decaimiento Desacoplado Importa
Evidencia Empírica
Colapso v1 (sin weight decay)
ANDREA-120M v1 entrenado por 165K pasos con Adam vanilla. Salidas de muestra:
- Paso 80K: region region region region region region region
- Paso 110K: ''''' ''''' '' ''' '' ''' '''?' ''' ' '' '' '
- Paso 140K: games, games, games, games, games, games, games
- Paso 165K: Budy Budy Budy Budy Budy Budy Budy Budy Budy
Los números de pérdida se mantuvieron razonables (mínimo EMA 3.23 en el paso 110K, vs 9.04 por azar aleatorio). La pérdida por sí sola oculta el colapso por repetición: un modelo que memoriza un token para siempre logra una baja entropía cruzada en cada paso en que aparece ese token.
Estabilidad v2 (weight_decay = 0.01)
v2 añadió AdamW (más recorte de gradiente, calentamiento de LR, monitoreo de muestras). En el paso ~112K, las muestras producidas:
- El periquito de Carolina fue declarado extinto en 1939 (factualmente correcto)
- La transformada de Fourier descompone señales en componentes de frecuencia (definición de libro de texto)
- Lluvia rítmica repite, Regueros en la ventana, Respiro del dolor (restricción de haiku satisfecha)
La calificación externa valoró las muestras de v2 en 9.5/10, llamándolas "impresionante coherencia y retención de conocimiento a esta escala."
Los 12M sobrevivieron sin AdamW. ¿Por qué?
ANDREA-12M entrenado con Adam vanilla sin colapso. Con 12M parámetros, las matrices de pesos permanecen lo suficientemente pequeñas para que la escala adaptativa de Adam no pueda empujar pesos individuales a magnitudes descontroladas que provocan repetición. A escala de 120M, las magnitudes de los pesos se desvían más por paso y se acumulan; la decadencia uniforme aplica una fuerza restauradora constante hacia cero. La decadencia de pesos desacoplada importa más a medida que el modelo escala.
Elegir weight_decay = 0.01
Actividades adyacentes
AdamW se entrelaza con tres actividades hermanas en este curso:
- Actividad 11: Calentamiento de LR + decaimiento coseno. AdamW solo no puede salvar un modelo de una tasa de aprendizaje máxima instantánea en pesos recién inicializados. El calentamiento incrementa lr durante 2000 pasos para que la corrección de sesgo de AdamW y la decaimiento de pesos tengan tiempo de estabilizar las representaciones.
- Actividad 12: Recorte de gradientes. AdamW asume que los gradientes tienen magnitud acotada. Las transiciones de fuente cada 7 a 42 pasos en el bandido de ANDREA producen picos ocasionales de gradientes; el recorte los limita a norma L2 1.0 ANTES de que AdamW toque m, v o p.
- Actividad 13: Precisión FP32 / FP16 / FP8. AdamW almacena m y v por parámetro, duplicando la huella de memoria de los pesos solos. FP16 reduce esa huella a la mitad; FP8 la reduce nuevamente. Las elecciones de precisión interactúan con la estabilidad del optimizador.
AdamW, warmup, clipping y precision forman un trébol de cuatro hojas. Quita una hoja, observa cómo ANDREA colapsa.