NVIDIA AI Lancia Gated DeltaNet-2, Un Layer Lineare Attenzione Con Erase E Write Scissi
NVIDIA ha recentemente lanciato il Gated DeltaNet-2, un innovativo layer di attenzione lineare che affronta il problema fondamentale della manipolazione della memoria compressa senza alterare le relazioni esistenti. Questo layer introduce due gate separati, un erase gate e un write gate, per gestire in modo indipendente la rimozione e l’aggiunta di informazione. L’algoritmo è stato addestrato su un insieme di dati di 100 miliardi di token FineWeb-Edu con fino a 1,3 miliardi di parametri e si dimostra superiore ad altri modelli come Mamba-2, KDA e Mamba-3 rispetto a suite di benchmark estese.
Il problema del gate scalare in modelli a regola di differenza
Un layer ricorrente di attenzione lineare memorizza uno stato matriciale St e lo legge mediante il query. DeltaNet aggiunge modifiche attive sottraendo il valore associato attualmente alla chiave corrente. Usa un coefficiente di passo scalare βt per controllore quanto sovrascrivere. Mamba-2 aggiunge un decadimento scalare αt per dimenticati globali.
Gated DeltaNet unifica entrambe le operazioni ma mantiene entrambi i gate scalari, uno per testa. Kimi Delta Attention (KDA) perfeziona il lato decadimento sostituendo αt con un vettore channel-wise, mantenendo comunque un solo βt per la modifica attiva. Tuttavia, questo scalare controlla due aspetti distinti insieme: la quantità di contenuti esistenti da eliminare e la quantità di nuovi contenuti da aggiungere, creando una limitazione modellistica.
Regola a due gate
Gated DeltaNet-2 divide le due decisioni con la Regola Delta a Gate. Introduce un gate di cancellazione bt ∈ [0,1]dk sull’asse chiave e un gate di scrittura wt ∈ [0,1]dv sull’asse valore. Entrambi i gate sono prodotti da proiezioni sigmoidali della rappresentazione del token. L’aggiornamento applica decadimento prima dell’edit attivo.
Compattamente espressa, la ricorsione è: St = (I − kt (bt ⊙ kt)⊤) Dt St−1 + kt (wt ⊙ vt)⊤. Qui Dt = Diag(αt) rappresenta il decadimento su canale ereditato da KDA. La parte sinistra della matrice di rimozione è kt, preservando la direzione di scrittura della normale regola delta. La parte destra diventa bt ⊙ kt, rendono la direzione di lettura selezionabile su canale. Il termine di scrittura ktzt⊤ utilizza zt = wt ⊙ vt, rendendo l’aggiornamento dei valori su canale selezionabile.
Training a blocchi e backward conosciuto dei gate
La ricorrenza ammette una forma a chunk WY che emula la struttura utilizzata da KDA. Il decadimento cumulativo per canale è assorbito nei due fattori di ogni cancella. L’aggiornamento per chunk diviene un prodotto di matrici asimmetriche della forma I − k̄rēr⊤. L’implementazione utilizza un chunk size C = 64 con nuclei Triton fusi.
Per quanto riguarda la backward pass, il trucco scalare utilizzato da KDA non è più applicabile. Il lato di scrittura contiene un diverso gate diagonale per canali valore, il lato di cancella contiene un diverso gate diagonale per chiavi. Quindi, i fattori dei gate devono apparire dentro i prodotti scalari che accumulano gradient. L’articolo derivi esplicitamente questo prodotto vettoriale Jacobiano vettore. Su GPU Hopper, il nucleo backward fuso WY è limitato a due e quattro warp per evitare un’asserzione di layout WGMMA di Triton.
Design a blocchi e modello ibrido
Gated DeltaNet-2 è utilizzato come mixer ricorrente in un blocco standard a stile Transformer. Le strade query e chiave utilizzano proiezione lineare, breve convoluzione causale, SiLU e normalizzazione L2. La strada valore utilizza proiezione lineare, breve convoluzione e SiLU. Il decadimento αt, il gate di cancella bt, e il gate di scrittura wt provengono da rami lineari separati.
Il modello ricorrente produce un output RMS-normalizzato, moltiplicato per una porta SiLU e proiettato indietro. Una variante ibrida inserisce Sliding-Window Attention (SWA) dopo il mixer ricorrente. Un cellule ripetuta contiene Gated DeltaNet-2, un MLP, SWA e un’altra MLP. SWA gestisce interazioni locali esatte, mentre il mixer ricorrente comprime lunghe storie cronologiche. Il modello ibridizza mantiene la scalabilità lineare con una cache di attenzione limitata.
Confronto e Risultati
Tutti i modelli sono 1.3B parametri addestrati su 100B di token FineWeb-Edu. Il conto del parametri e la dimensione di stato ricorrente sono uguali per tutti i modelli. Lo stato ricorrente mantiene 262.144 float per layer per batch element. La lunghezza di addestramento è 4K token, e i modelli ibridi usano una finestra SWA di 2K. La base Mamba-3 MIMO usa rango R = 4.
In termini di modellazione linguistica e ragionamento comune, Gated DeltaNet-2 ha il migliore risultato medio in entrambi gli scenari. Il modello ricorrente media 53.11 lungo LAMBADA e suite di ragionamento. Si posiziona sopra Mamba-3 MIMO a 52.39 e KDA a 52.28. Nel setting ibrido, Gated DeltaNet-2 media 53.97 contro Mamba-3 MIMO a 52.72. Visto che la dimensione dello stato ricorrente è uguale, guadagno suggerisce l’aggiornamento di regola, non di più memoria.
I miglioramenti più evidenti appaiono su RULER long-context retrieval. Nel setting ricorrente, S-NIAH-2 su 4K va da 89.0 (KDA) a 93.0. S-NIAH-3 su 2K salta da 63.2 (KDA) a 89.8. MK-NIAH-1 su