Nous Research presenta Token Superposition Training per accelerare la pre-addestrazione degli LLM del 2,5x
Introduzione a Token Superposition Training di Nous Research
Nous Research ha rilasciato un nuovo metodo per accelerare il pre-addestramento dei modelli linguistici di grandi dimensioni (LLM) chiamato Token Superposition Training (TST). Questo approccio riduce il tempo effettivo di addestramento mantenendo la stessa quantità di calcolo senza modificare l’architettura del modello, l'ottimizzatore, il tokenizzatore o il dataset.
L’implementazione di TST ha permesso di ottenere un vantaggio di 2,5x al 10B-A1B Mixture-of-Experts: TST richiede 4,768 ore di GPU B200 rispetto alle 12,311 necessarie dal baseline. Gli esperimenti sono stati pubblicati su arXiv.
https://arxiv.org/pdf/2605.06546
Questa tecnica si basa su due fasi distinte, una fase di sopposizione e una di recupero, permettendo al modello di gestire token in maniera più efficiente durante l’addestramento iniziale.
Il problema che TST risolve
Gli attuali modelli linguistici richiedono quantità immense di dati e risorse computazionali per pre-addestrarli. La chiave del successo di tali modelli è spesso rappresentata dal rapporto tra la quantità di dati processati e i Flops utilizzati. I tokenizer a sottoparole come i BPE aumentano notevolmente la quantità di dati gestiti grazie alla compressione delle sequenze e al conseguente aumento del throughput del testo.
Mentre il tokenizer gioca un ruolo importante, TST si propone di esaminare se ulteriori miglioramenti nel rapporto tra dati e FLOPS siano possibili durante l'addestramento, modificando solo la routine di addestramento.
Come funziona TST: due fasi
L’approccio di TST modifica il loop tradizionale di pre-addestramento in due fasi:
Fase 1 - Sopposizione
Per la prima parte del processo, che copre una frazione r delle iterazioni totali (es. r ∈ [0.2, 0.4]), il modello riceve sequenze di token diversamente da quanto avviene in un'addestramento tradizionale. Ogni gruppo di s token consecutivi è rappresentato da un singolo “s-token”, ottenuto facendo la media dei token embeddings.
Quindi, il modello elabora una sequenza di lunghezza L/s. Per preservare la parità con quanto farebbe un modello tradizionale in termini di FLOPs, la lunghezza della sequenza di input viene moltiplicata per un fattore s.
Ogni posizione latente del modello corrisponde a s token, per cui il modello vede s volte più testo, migliorando il rapporto token-FLOP.
In termini di funzione loss, invece di una cross-entropy standard, si utilizza una Multi-Hot Cross-Entropy (MCE), che assegna la stessa probabilità ad ogni token all’interno del gruppo di target. Questo può essere implementato con le librerie già esistenti, senza bisogno di modifiche.
Fase 2 - Recupero
Dopo la fase di sopposizione, il modello riprende l'addestramento come normalmente farebbe, utilizzando il checkpoint salvato dopo la fase 1. I dati generati duranti la fase 1 vengono completamente eliminati per evitare inquinamento sperimentale.
Sebbene all’inizio della fase 2 possa verificarsi uno spike della loss (tipicamente tra 1 e 2 nats), la situazione si normalizza in poche migliaia di iterazioni. Il modello finale è architettonicamente e comportamentalmente identico rispetto a uno addestrato senza TST.
I risultati sperimentali
I test sono stati condotti su modelli di varie dimensioni:
- Modelli densi 270M e 600M (SmolLM2 con archetettura adattata dal Llama3 e tokenizzazione del Llama3-8B)
- Modelli dense 3B (SmolLM3)
- Modelli MoE 10B-A1B (Qwen3 family)
Tutte le operazioni sono state eseguite utilizzando il dataset DCLM per i modelli più piccoli e una miscela 50/50 tra DCLM e FineWeb-Edu per i modelli Mixture of Experts.
I risultati mostrano che a una data dimensione e lunghezza di token (ad es. 3B/s = 6 e r ≈ 0.3), TST raggiunge un loss di 2.676 in 20k passi (rispetto a uno baseline a 36k passi). Gli score HellaSwag e ARC-Easy sono quasi identici ai valori del baseline.
Sui modelli MoE, TST è in grado di raggiungere un loss finale inferiore al baseline (2.236 vs. 2.252 su 10B-A1B) e supera il baseline in tutti i benchmark chiave (HellaSwag, ARC-Easy, ARC-Challenge e MMLU).
Due meccanismi distinti
Gli esperimenti condotti suggeriscono che TST abbia due meccanismi distinti, lato ingresso e lato uscita, che contribuiscono entrambi al miglioramento del modello.
- Lato Input: la media di sequenze scontigue crea una versione più grossolana dei dati iniziali, preparando i livelli successivi ad operare su rappresentazioni più complesse.
- Lato Output: la tecnica Multi-Hot Cross-Entropy (MCE), simile ad un MTP esteso, permette al modello di predire sequenze di token successive, migliorando l’efficienza. A differenza dell’MTP, però, non richiede aumenti di complessità architetturale.
Inoltre, un'analisi mirata ha dimostrato che la rappresentazione condivisa tra le due fasi (Phase 1 e Phase 2) è cruciale per il successo di TST. Quando questa rappresentazione è "rotta" (es. reinizializzando embedding), TST si perde totalmente.
Confronti: equal-FLOPs, equal-loss, equal-data
I confronti effettuati da Nous Research si basano su tre criteri:
- Equal-FLOPs: TST produce risultati migliori rispetto al baseline, riducendo drasticamente il costo GPU.
- Equal-Loss: Anche in condizioni di uguaglianza di perdita finale, TST vince in termini di tempo di addestramento.
- Equal-Data: Il baseline riesce a produrre risultati migliori in termini di dati totali, ma in termini di rapporto FLOP-token, TST risulta inferiore.
Questi confronti indicano i limiti di applicabilità di TST, fornendo un'analisi comprensiva delle situazioni in cui TST è vantaggioso.
<