Gradyan Piklerinin Kaynağı
Sakin Bir Mini-Toplu & Şok Edici Bir Mini-Toplu
Çoğu mini-toplu, makul büyüklükte gradyanlar üretir. Veriyi kabaca uyan bir model için çapraz-entroji kaybı dar bir bantta kalır; geri yayılım bu sinyali benzer büyüklükte gradyanlar olarak geri taşır.
Bazı mini-toplular yapmaz. Gradyan piklerinin üç kaynağı:
1. Aykırı örnekler. Tek bir sekans, son derece nadir bir token kombinasyonu içeren, ortalamadan uzak bir kayıp ve ortalamadan uzak bir gradyan üretir.
2. Sayısal kenar durumlar. Sıfıra yakın bir softmax paydası, NaN üreten bir katman normalizasyonu, bir FP16 taşması. Her biri tipik olanlardan büyüklük sırası daha büyük gradyanlar üretebilir.
3. Dağılım kaymaları. Tek bir eğitim çalışması sırasında veri kaynaklarını değiştirmek, modeli yeni bir dağılımla şok eder. ANDREA'nın bandit'i her 7 ila 42 adımda kaynak ağırlıklarını yeniden karıştırır. Her geçiş küçük bir dağılım kaymasıdır.
ANDREA-120M v1: Spike Zinciri
v1'de gradyan kesme yoktu. Banditten her 7 ila 42 adımda gelen kaynak geçişleri, modele kısa repo-docs (liste yapılı) patlamaları, sonra gutenberg (uzun düzyazı), sonra hermes3-general (S&A) besledi. Her geçiş gradyan sivri uçları üretti: her sivri uç, 120M ölçeğinde ağırlıkları yozlaşmış çekicilere itti.
Temel ampirik gerçek. ANDREA-12M, kırpmaya maruz kalmadan aynı haydut saldırısından sağ kurtuldu. Daha küçük ağırlık matrisleri gradyan şoklarına karşı dayanıklı kalır; tek bir kötü batch, 120M parametreyi olduğu gibi 12M parametreyi de kaçak bir çekiciye itemez. Model ölçeği büyüdükçe kırpma daha önemli hale gelir.
Global L2 Norm Kırpma
İki Seçenek: Tensor Başına veya Global
Gradyan büyüklüklerini sınırlamanın iki yolu:
Tensor başına kırpma. Her gradyan tensörünü bağımsız olarak kırp. Gömme gradyanı kendi normuna kırpılır; dikkat gradyanı kendi normuna kırpılır. Basit, ancak göreli ölçekleri bozar: bir tensördeki küçük bir sıçrama (şimdi sıfır gradyan) diğerindeki devasa gradyanla (dokunulmamış) eşleşir.
Global L2 norm kırpma. Tüm gradyanları tek bir büyük vektör olarak ele al. Her parametredeki toplam L2 normunu hesapla. Norm max_norm'u aşarsa, her gradyanı aynı faktörle ölçekle. Tensörler arasındaki göreli büyüklükleri korur.
ANDREA global kullanır. Pascanu ve diğerleri (2013), transformer eğitimi için global clipping'in tensor başına clipping'i empirik olarak geçtiğini gösterdi.
Matematik
Global L2 normunu hesapla:
norm = sqrt(sum over all params of g_i^2)
Eğer norm <= max_norm ise, gradyanlar değişmeden geçer. Eğer norm > max_norm ise, her gradyanı max_norm / norm ile ölçekle:
g_i_clipped = g_i * (max_norm / norm)
Ölçeklendirmeden sonra, yeni norm tam olarak max_norm olur. ANDREA max_norm = 1.0 kullanır.
Ölçek Faktörü Hesaplama
Gradient Norm Hesaplamasının Üç Çekirdeğe Neden İhtiyaç Duyduğu
Saf Algoritma GPU'da Çalıştırılamaz
Global L2 norm hesaplama için sözde kod:
toplam = 0
her parametre p için:
p.grad içindeki her eleman g için:
toplam += g * g
norm = sqrt(toplam)
Bir GPU'da, bu naif döngü iki nedenden dolayı başarısız olur:
1. Sıralı birikim. Tek bir total biriktirici, her iş parçacığının diğer tüm iş parçacıklarının beklemesini zorunlu kılar ve GPU paralelliğini bozar.
2. Heterojen tensörler. ANDREA-120M, büyük ölçüde farklı şekillere sahip tensörlere sahiptir: gömme (8449 x 768), dikkat QKV (768 x 768), katman normu (768). Tek bir kernel tüm şekilleri verimli bir şekilde yineleyemez.
ANDREA'nın Üç-Kernel Boru Hattı
İşi microgpt_cuda.cu içindeki üç CUDA kernel'ine bölün:
Kernel 1: k_grad_norm_partial. Her parametre tensörü için karelerin kısmi toplamını hesaplayın. Her iş parçacığı bloğu tensörün bir parçasını indirger; sonuçlar küçük bir geçici tampona yazılır. Paralellik: parça başına bir blok, tüm tensörler genelinde yüzlerce blok.
Kernel 2: k_grad_norm_final. Scratch tamponunu tek bir skaler değere indirge. Karekökünü al. Küçük bir kernel, mikrosaniyeler içinde çalışır.
Kernel 3: k_grad_scale. Eğer norm > max_norm ise, scale = max_norm / norm hesapla ve her gradient elemanını scale ile çarp. Her gradient tensörü üzerinde bir geçiş, utanç verici derecede paralel.
Sıra Önemli: Pre-Adam
Kesme boru hattı, AdamW m, v veya herhangi bir parametreyi güncellemeden ÖNCE çalışır. Neden?
Kırpılmış gradyanlar AdamW'nin üstel hareketli ortalamalarına beslenir. Eğer bir ani artış m & v içine akmasına izin verilseydi, bu çalışan ortalamaları bozar ve ani artıştan sonraki birçok adımda toparlanmayı yavaşlatırdı. Adam öncesi kırpma, ani artışın etkisini tek kötü adıma hapseder.
Neden Üç Kernel, Bir Değil?
No-Clipping'in v1'i Nasıl Öldürdüğü
Bandit Kaynak Geçişleri Her 7 ila 42 Adımda
ANDREA'nın banditi aşamalar halinde çalışır. Her aşama 7, 14, 21, 28 veya 42 adım (rastgele seçilir) sürer. Her aşama sınırında kaynak ağırlıkları değişir: belki repo-docs 0.1'den 0.6'ya sıçrar, gutenberg 0.4'ten 0.1'e düşer, hermes3-general 0.5'ten 0.7'ye yükselir.
Her geçiş, modele bir dağılım şoku olur. Kayıp kısa süreli sıçrar. Gradyanlar da onunla sıçrar: gutenberg tadında düzyazıya karşı kaybı minimize eden bir model şimdi repo-docs tadında liste yapılarını görür ve gradyanlar 10x veya 100x tipik büyüklükte düzeltici sinyal taşır.
v1 Arıza Modu
Klipping olmadan, bu 10-100x gradyan diklikleri AdamW'nin m & v ortalamalarına akıyordu. AdamW'nin yumuşatma etkisi, diklik etkisinin gerçek kötü batch'ten sonra birçok adım boyunca devam etmesine neden oluyordu. v1'de ağırlık çürümesi olmaması (vanilla Adam) ile birleşince, diklik kaynaklı ağırlık güncellemeleri aşamalar boyunca birikiyor ve ağırlıklar dejenerat bir çekiciye sürükleniyordu: bir token'ın logit'i softmax'ı domine ediyor, örneklenen çıktı o token oluyor, eğitim bağlamı o token'ı içeriyor, gradyan o token'ı pekiştiriyordu. Tekrarlama kilitlenmesi.
v2 Kararlılığı
v2, max_norm = 1.0 ile klipping ekledi, AdamW & LR ısınma ile birlikte. m & v üzerindeki diklik etkisi sınırlı; ağırlıklar tepe noktasında parametre başına adım başına lr max_norm = 0.0003 1.0 = 0.0003'ten daha hızlı sürüklenemez. Aşama geçişleri hala diklikler üretir, ancak bu diklikler optimize ediciye ulaşmadan önce sınırlanır.
Sonuç: v2 (veri filtresi v2.5 & v3 cilası sonrası) gerçekçi hatırlama, çok paragraflı tutarlılık ve biyoloji & sinyal işleme örneklerinde 9.5/10 dış notlara ulaştı.
Kapasite-Kırılganlık Eşleşmesi
Aynı haydut. Aynı veri. Aynı hiperparametreler hariç clipping. Neden 12M clipping olmadan hayatta kalırken 120M çöktü?
İki birleşen faktör:
1. Daha büyük ağırlık matrisleri daha fazla çekici depolar. 768x768 dikkat projeksiyonu 590K parametreye sahiptir; parametre başına bile küçük sapmalar dikkat davranışında anlamlı değişiklikler üretir. 384x384 dikkat projeksiyonu 147K parametreye sahiptir ve daha kısıtlı bir alt uzayda kalır.
2. Daha fazla katman daha fazla çarpımsal etkileşim anlamına gelir. v3'ün 12 transformer katmanı vardır (12M için 6'ya karşı). Ani yükselmeler 12 katmanlık birleşen doğrusal olmayanlıklardan geçer; her katman önceki katmanın sapmasını büyütebilir.
Kırılganlık kapasiteyle birleşir. Belirli bir ölçek eşiğinin üzerinde clipping zorunlu hale gelir; ANDREA bu eşiği 12M ile 120M parametre arasında bir yere koyar.
v1 Kaskadını Teşhis Etme
Klipleme Başka Nerede Uygulanır?
Komşu Etkinlikler
Clipping'e üç kardeş bağlanır:
- Etkinlik 10: AdamW. Clipping, AdamW'nin m ve v'sini spike kontaminasyonundan korur. Clipping olmadan, bir kötü parti optimizer durumunu 50+ adım için bozar.
- Etkinlik 11: LR ısınma. Isınma lr'yi sönümler; clipping g'yi sönümler. Birlikte: 1. adımda, en kötü durum parametre güncellemesi lr_after_warmup max_norm = 1.5e-7 1.0 = 1.5e-7 olur, her iki koruma olmadan ise 0.0003 * 50 = 0.015. En kötü durum erken güncelleme büyüklüğünde 100.000 kat azalma.
- Etkinlik 14: Çok kollu haydutlar. Haydut aşamasının uzunluğu (7 ila 42 adım) özellikle herhangi bir kaynağın baskın hale gelmesini önlemek için kısadır; kırpma, bu sık geçişlerin güvenli olmasını sağlayan şeydir.
Kırpma, transformer eğitimindeki en ucuz istikrar kazancıdır: 3 küçük CUDA çekirdeği, adım başına mikrosaniyeler, 120M+ modellerin yakınsama mı yoksa çökme mi yaşayacağını belirleyen kesin etki.