सादा SGD ANDREA को प्रशिक्षित नहीं कर सकता
स्टोकैस्टिक ग्रेडिएंट डिसेंट, प्रारंभिक बिंदु
बैकप्रॉप हर पैरामीटर के लिए एक ग्रेडिएंट g की गणना करता है। सादा स्टोकैस्टिक ग्रेडिएंट डिसेंट (SGD) हर पैरामीटर को p -= lr * g से अपडेट करता है। एक लर्निंग रेट, प्रति स्टेप एक दिशा, पिछले ग्रेडिएंट्स की कोई स्मृति नहीं।
सादा SGD स्केल पर दो कारणों से टूट जाता है:
1. ग्रेडिएंट्स की परिमाण बहुत भिन्न होते हैं पैरामीटर्स के पार। एक दुर्लभ टोकन के लिए एम्बेडिंग को अधिकांश स्टेप्स में छोटा ग्रेडिएंट मिलता है; एक लेयरनॉर्म स्केल को बड़ा मिलता है। एक लर्निंग रेट दोनों के लिए उपयुक्त नहीं हो सकता।
2. ग्रेडिएंट्स दोलन करते हैं। 16-स्रोत कॉर्पस से एक शोरयुक्त मिनी-बैच पैरामीटर को बाएँ धकेलता है, फिर दाएँ, फिर बाएँ। सादा SGD खुद से लड़ते हुए स्टेप्स बर्बाद करता है।
Adam (Kingma & Ba, 2015) दोनों समस्याओं को ठीक करता है प्रत्येक पैरामीटर के लिए दो रनिंग एवरेज के साथ।
प्रथम मोमेंट & द्वितीय मोमेंट
m: स्मूथ्ड दिशा
प्रथम मोमेंट m हालिया ग्रेडिएंट्स को एक्सपोनेंशियली औसत बनाता है:
m = beta1 m + (1 - beta1) g
beta1 = 0.9 के साथ। कई चरणों के बाद, m एक समतलित दिशा ले लेता है; एक खराब बैच इसे मुश्किल से हिलाता है।
v: समतलित परिमाण
दूसरा क्षण v हाल के वर्गीकृत ग्रेडिएंट्स का औसत लेता है:
v = beta2 v + (1 - beta2) g^2
beta2 = 0.999 के साथ। v प्रत्येक पैरामीटर के ग्रेडिएंट की सामान्य आकार को ट्रैक करता है। बड़े-ग्रेडिएंट पैरामीटर को बड़ा v मिलता है; छोटे-ग्रेडिएंट पैरामीटर को छोटा v मिलता है।
प्रति-पैरामीटर अनुकूली लर्निंग रेट
स्मूथ्ड दिशा को स्मूथ्ड परिमाण की वर्गमूल से विभाजित करने से हर पैरामीटर को तुलनीय आधार पर रीस्केल किया जाता है:
adam_step = m / sqrt(v + eps)
छोटे-ग्रेडिएंट एम्बेडिंग्स को स्केल अप किया जाता है; बड़े-ग्रेडिएंट लेयरनॉर्म्स को स्केल डाउन किया जाता है। अब एक ग्लोबल lr हर पैरामीटर के लिए उपयुक्त है।
मोमेंट्स को पढ़ना
प्रारंभिक चरणों को क्यों चाहिए बायस सुधार
कोल्ड-स्टार्ट बायस
m & v शून्य से शुरू होते हैं। चरण 1 के बाद, m = 0.1 g_1 & v = 0.001 g_1^2। दोनों अनुमान लंबे समय के औसत को बहुत कम आंकते हैं। सुधार के बिना, ऑप्टिमाइज़र सतर्क शुरू होता है & धीरे-धीरे बढ़ता है, प्रतिनिधित्व बनने के दौरान कीमती प्रारंभिक चरणों को बर्बाद करता है।
सुधार
Adam प्रत्येक अनुमान को 1 / (1 - beta^t) से स्केल करता है जहां t चरण संख्या है:
m_hat = m / (1 - beta1^t)
v_hat = v / (1 - beta2^t)
चरण 1 में beta1 = 0.9 के साथ, विभाजक (1 - 0.9) = 0.1 होता है, इसलिए m_hat = m / 0.1 = 10 * m। पूर्वाग्रह-सुधारा गया अनुमान लंबे समय के औसत की भविष्यवाणी से मेल खाता है। जैसे ही t बढ़ता है, beta^t 0 के करीब पहुँचता है, सुधार 1 के करीब पहुँचता है, और सुधारे गए तथा असुधारे मान अभिसरण करते हैं।
डिकपल्ड वेट डिके (AdamW नवाचार)
L2 नियमितीकरण बनाम वेट डिके
क्लासिक L2 नियमितीकरण हानि में एक दंड जोड़ता है: L_total = L_data + (lambda / 2) sum(p^2)। बैकप्रॉप उस दंड को ग्रेडिएंट का हिस्सा मानता है: g_total = g_data + lambda p। L2 पद Adam के m और v अपडेट्स के माध्यम से बहता है, जो पैरामीटर के प्रति परिमाणों द्वारा चिकना और पुनःस्केल किया जाता है।
लोशचिलोव और हुटर (2019) ने सिद्ध किया कि एडम के माध्यम से रेगुलराइज़र को स्मूथ करना दोनों को भ्रष्ट करता है। एडम का अनुकूली स्केलिंग बड़े-ग्रेडिएंट पैरामीटर्स पर वेट डिके को कम करता है (जहाँ डिके को ओवरफिटिंग से सबसे कठिन लड़ना चाहिए) और छोटे-ग्रेडिएंट वाले पर इसे बढ़ा देता है।
AdamW: डिके को सीधे लागू करें
AdamW वेट डिके को ग्रेडिएंट से अलग करता है। डिके प्रत्येक पैरामीटर पर पैरामीटर अपडेट के दौरान सीधे लागू होता है, कभी m या v को छूते बिना:
p -= lr (m_hat / (sqrt(v_hat) + eps) + weight_decay p)
अब दो टर्म्स प्रत्येक स्टेप को चलाते हैं:
1. Adam टर्म: m_hat / (sqrt(v_hat) + eps) ग्रेडिएंट दिशा को प्रति-पैरामीटर परिमाण इतिहास से पुनःस्केल करता है।
2. डिके टर्म: weight_decay * p हर पैरामीटर को शून्य की ओर समान रूप से सिकोड़ता है, बिना Adam के स्मूथिंग के माध्यम से जाए।
ANDREA-120M v2 सेट करता है weight_decay = 0.01। हर स्टेप में, हर पैरामीटर शून्य की ओर 1% सिकुड़ता है, Adam टर्म के अलावा जो भी करता है।
Decoupled क्यों महत्वपूर्ण है
प्रयोगात्मक साक्ष्य
v1 Collapse (कोई weight decay नहीं)
ANDREA-120M v1 को vanilla Adam के साथ 165K steps के लिए train किया गया। Sample outputs:
- चरण 80K: region region region region region region region
- चरण 110K: ''''' ''''' '' ''' '' ''' '''?' ''' ' '' '' '
- चरण 140K: games, games, games, games, games, games, games
- चरण 165K: Budy Budy Budy Budy Budy Budy Budy Budy Budy
नुकसान के आंकड़े उचित रहे (EMA न्यूनतम 3.23 चरण 110K पर, बनाम यादृच्छिक-संभावना 9.04)। नुकसान अकेले दोहराव पतन को छिपाता है: एक मॉडल जो एक टोकन को हमेशा याद रखता है, हर चरण पर कम क्रॉस-एंट्रॉपी प्राप्त करता है जब वह टोकन प्रकट होता है।
v2 स्थिरता (weight_decay = 0.01)
v2 में AdamW जोड़ा गया (प्लस ग्रेडिएंट क्लिपिंग, LR वार्मअप, सैंपल मॉनिटरिंग)। स्टेप ~112K पर, उत्पन्न सैंपल्स:
- Carolina parakeet को 1939 में विलुप्त घोषित किया गया (तथ्यात्मक रूप से सही)
- फूरियर ट्रांसफॉर्म सिग्नलों को फ्रीक्वेंसी घटकों में विघटित करता है (पाठ्यपुस्तक परिभाषा)
- Rain's rhythmic refrain, Rivulets on the window, Respite from life's pain (हाइकू बाधा संतुष्ट)
बाहरी ग्रेडिंग ने v2 सैंपल्स को 9.5/10 रेट किया, उन्हें "इस स्केल पर प्रभावशाली सुसंगति और ज्ञान प्रतिधारण" कहते हुए।
12M ने AdamW के बिना जीवित रह लिया। क्यों?
ANDREA-12M ने vanilla Adam पर प्रशिक्षण लिया बिना collapse के। 12M पैरामीटर्स पर, वजन मैट्रिक्स इतने छोटे रहते हैं कि Adam का adaptive scaling व्यक्तिगत वजनों को repetition चलाने वाली runaway magnitudes में धकेल नहीं पाता। 120M स्केल पर, वजन magnitudes प्रति स्टेप और दूर drift करते हैं और accumulate होते हैं; uniform decay एक constant restoring force zero की ओर लागू करता है। Decoupled weight decay मॉडल के स्केलिंग के साथ अधिक महत्वपूर्ण होता है।
weight_decay = 0.01 चुनना
संबंधित गतिविधियाँ
इस कोर्स में AdamW तीन भाई-बहन गतिविधियों के साथ इंटरलॉक करता है:
- गतिविधि 11: LR वार्मअप + कोसाइन डिके। AdamW अकेला मॉडल को ताज़ा इनिशियलाइज़्ड वेट्स पर तुरंत पीक लर्निंग रेट से नहीं बचा सकता। वार्मअप 2000 स्टेप्स पर lr को रैंप करता है ताकि AdamW का बायस करेक्शन और वेट डिके प्रतिनिधित्वों को स्थिर करने का समय मिले।
- गतिविधि 12: ग्रेडिएंट क्लिपिंग। AdamW मानता है कि ग्रेडिएंट्स की मैग्नीट्यूड बाउंडेड है। ANDREA के बैंडिट में सोर्स ट्रांज़िशन हर 7 से 42 स्टेप्स में आउटगोइंग ग्रेडिएंट स्पाइक्स पैदा करते हैं; क्लिपिंग उन्हें L2 नॉर्म 1.0 पर कैप करती है इससे पहले कि AdamW m, v, या p को छुए।
- गतिविधि 13: FP32 / FP16 / FP8 प्रेसिजन। AdamW हर पैरामीटर के लिए m और v स्टोर करता है, जो वेट्स के अकेले मेमोरी फुटप्रिंट को दोगुना कर देता है। FP16 उस फुटप्रिंट को आधा कर देता है; FP8 इसे फिर से काट देता है। प्रेसिजन चॉइस ऑप्टिमाइज़र स्टेबिलिटी के साथ इंटरैक्ट करती हैं।
AdamW, warmup, clipping, और precision एक चार-पत्ती के तिपतिया घास का रूप धारण करते हैं। एक पत्ता गिराएं, ANDREA को ढहते हुए देखें।