Nous Research Releases Token Superposition Training to Speed Up LLM Pre-Training by Up to 2.5x Across 270M to 10B Parameter Models
Pre-training large language models is expensive enough that even modest efficiency improvements can translate into meaningful cost and time savings.
Nous Research is releasing Token Superposition Training (TST), a method that substantially reduces pre-training wall-clock time at fixed compute without touching the model architecture, optimizer, tokenizer, parallelism strategy, or training data.
At the 10B-A1B mixture-of-experts scale, TST reaches a lower final training loss than a matched-FLOPs baseline while consuming 4,768 B200-GPU-hours versus the baseline’s 12,311 — roughly a 2.5x reduction in total pre-training time.
https://arxiv.org/pdf/2605.06546
The Problem TST is Solving
Modern LLM pre-training is heavily data-driven. Recent training regimes routinely overtrain well beyond compute-optimal estimates, so raw text throughput (how much data a model can process per FLOP) has become a key lever. Subword tokenizers like BPE already improve throughput by compressing sequences, and the research suggests that much of the BPE advantage over byte-level models comes simply from shorter sequences, which let the model see more text per unit of compute.
TST asks whether that throughput lever can be pulled further during training, independently of the tokenizer and without permanently changing the model.
How TST Works: Two Phases
TST modifies the standard pre-training loop in two sequential phases:
Phase 1 — Superposition:
For the first r fraction of total training steps (the paper finds r ∈ [0.2, 0.4] to be close to optimal across tested scales), the model does not receive individual tokens. Instead, the input sequence of length L is segmented into non-overlapping bags of s contiguous tokens. In the embedding layer, each bag is collapsed into a single latent “s-token” by averaging the s token embeddings. The transformer then processes a sequence of length L/s.
Crucially, each TST step is kept equal-FLOPs to a standard training step by increasing the data sequence length by s times during the superposition phase. Because each latent position corresponds to s source tokens, the model ingests s times as much text per unit of compute — this is what drives the throughput gain.
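To make the input-side transformation concrete, the following is a minimal sketch of the bag-averaging step (an illustration written for this summary, not code from the paper; the function name and tensor shapes are assumptions): token embeddings are looked up as usual, then non-overlapping groups of s contiguous embeddings are averaged into single latent positions before the transformer runs.

import torch

def superpose_embeddings(token_embeds: torch.Tensor, s: int) -> torch.Tensor:
    # token_embeds: [batch, seq_len, d_model], where seq_len has already been increased
    # s-fold so that seq_len / s latent positions cost the same FLOPs as a standard step.
    batch, seq_len, d_model = token_embeds.shape
    assert seq_len % s == 0, "sequence length must be a multiple of the bag size"
    # Group contiguous tokens into bags of s, then average each bag into one s-token.
    return token_embeds.view(batch, seq_len // s, s, d_model).mean(dim=2)

In this hypothetical sketch, averaging is the only change on the input side; everything downstream of the embedding layer is untouched, the transformer simply sees a sequence that is s times shorter.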
On the output side, each latent position predicts the next bag of s tokens rather than a single next token. The standard cross-entropy loss is replaced with a multi-hot cross-entropy (MCE) loss, which assigns equal probability mass 1/s to each token in the target bag. The MCE loss reduces to a simple mean of standard cross-entropy terms over the s targets — it can be implemented using the existing fused CE kernels already present in any major pre-training library, without writing a new kernel or adding an auxiliary head.
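The equivalence between the multi-hot target and the mean-of-CE form is easy to verify numerically. The sketch below (our own illustration with made-up shapes, not code from the paper) builds the 1/s soft target explicitly and compares it against the mean of s standard cross-entropy terms.

import torch
import torch.nn.functional as F

# Illustrative shapes: 4 latent positions, a 100-token vocabulary, bag size s = 3.
logits = torch.randn(4, 100)
bag_targets = torch.randint(0, 100, (4, 3))
s = bag_targets.shape[-1]

# Multi-hot target: 1/s probability mass on each token of the target bag.
soft_targets = torch.zeros_like(logits)
soft_targets.scatter_add_(1, bag_targets, torch.full(bag_targets.shape, 1.0 / s))
mce_soft = F.cross_entropy(logits, soft_targets)

# Same loss written as a mean of s standard CE terms, which reuses fused CE kernels.
mce_mean = torch.stack([F.cross_entropy(logits, bag_targets[:, i]) for i in range(s)]).mean()

assert torch.allclose(mce_soft, mce_mean, atol=1e-5)

Either form gives the same loss and gradient; the loop form is simply the one that maps directly onto the fused cross-entropy kernels pre-training libraries already ship.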
Phase 2 — Recovery:
After the superposition phase, training resumes from the saved checkpoint with standard next-token prediction for the remaining 1 - r fraction of steps. The TST code is fully removed at this boundary to avoid any experimental contamination. A transient loss spike occurs at the transition, typically between 1 and 2 nats, which resolves within a few thousand steps. After that, the recovered model crosses below the equal-FLOPs baseline and remains there.
The model produced at the end of Phase 2 is architecturally identical to one produced by conventional pre-training, with the same next-token prediction inference behavior.
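In training-loop terms, the schedule reduces to a single step-count boundary. The sketch below is our own illustration with made-up step counts and stand-in functions (superposition_step, standard_step), not code from the paper:

# Stand-ins for the two kinds of training step; names and values are illustrative only.
def superposition_step(step):   # Phase 1: bags of s tokens, s-times-longer sequences, MCE loss
    return f"phase 1, step {step}"

def standard_step(step):        # Phase 2: plain next-token prediction, TST code path removed
    return f"phase 2, step {step}"

total_steps, r = 20_000, 0.3
switch_step = int(r * total_steps)   # phase boundary; the paper finds r in [0.2, 0.4] works well
for step in range(total_steps):
    phase = superposition_step(step) if step < switch_step else standard_step(step)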
What the Experiments Show
TST was validated at four scales:
270M and 600M dense (SmolLM2 shapes adapted to the Llama3 modeling code, with the Llama3-8B tokenizer and untied input/output embeddings — which makes the 270M model equivalent in size to SmolLM2-135M and the 600M to SmolLM2-360M), 3B dense (SmolLM3 shape), and a 10B-A1B MoE in the Qwen3 family. Training used the DCLM dataset for the smaller runs and a 50/50 mix of DCLM and FineWeb-Edu for the MoE run. All runs used AdamW with the Warmup-Stable-Decay learning rate schedule and were run in TorchTitan under FSDP parallelism, on 64 NVIDIA B200 GPUs for the larger models and 8 B200 GPUs for the smaller ones.
At the 3B scale with bag size s = 6 and step ratio r = 0.3, TST at 20,000 steps reaches a final loss of 2.676 — nearly matching a 36,000-step baseline at 2.677 — while using 247 B200-GPU-hours versus 443. The 20k-step TST run scores 62.4 on HellaSwag and 66.3 on ARC-Easy, versus 62.3 and 65.9 for the 36k baseline.
At the 10B-A1B MoE scale with s = 16 and r ≈ 0.25, the TST run processes 2T data tokens and achieves a final loss of 2.236, below the baseline’s 2.252 after 1.05T tokens, while beating it on all four reported benchmarks: HellaSwag (71.2 vs. 70.1), ARC-Easy (74.2 vs. 73.8), ARC-Challenge (47.3 vs. 46.3), and MMLU (39.0 vs. 37.4).
The research team presents three comparison views against the baseline — equal-FLOPs, equal-loss, and equal-data. Under equal-FLOPs and equal-loss conditions, TST consistently wins. Under equal total token consumption, the baseline wins, because TST’s effective compute budget per data token is smaller. This is an important boundary condition that determines where TST applies.
Two Distinct Mechanisms
An ablation study isolates the input-side and output-side components. Both independently outperform the baseline; combining them produces further improvement without signs of interference. The authors interpret this as evidence that TST is two orthogonal mechanisms rather than a single trick.
The output-side mechanism — next-bag-of-tokens prediction — is conceptually related to multi-token prediction (MTP). Unlike MTP, which adds k independent prediction heads and extra parameters, TST keeps a single output head and replaces only the target. This makes it the least expensive member of a growing class of future-signal auxiliary objectives. Also unlike MTP, it shows consistent gains across all tested scales, including small models where MTP has been shown to degrade performance.
The input-side mechanism has no direct analog in the recent pre-training literature. The research team offers two plausible explanations: it may implicitly regularize the embedding geometry (since many random s-grams of tokens must remain linearly separable once averaged), or it may act as a form of pre-pre-training, exposing the model to a coarser version of the real data before fine-resolution language modeling begins.
A targeted ablation directly tests what happens when representation continuity is broken. The research team runs a 3B TST experiment where the input embedding and output LM head are randomly re-initialized at the start of Phase 2. The result: final loss jumps to 2.938 — worse than both the TST run (2.676) and the standard baseline (2.808). The Phase 1 TST steps contributed nothing to the final model. This confirms that shared representations across both phases are not incidental to TST’s success — they are what makes it work.
Marktechpost’s Visual Explainer
Token Superposition Training — Practical Guide
arXiv 2605.06546
01 / Overview
What Is Token Superposition Training?
Token Superposition Training (TST) is a two-phase pre-training method from Nous Research that increases token throughput per FLOP without changing the model architecture, optimizer, tokenizer, parallelism, or training data.
The core idea:
Instead of feeding one token at a time, average s contiguous token embeddings into one “s-token,” train on that for the first r fraction of steps, then switch back to standard next-token prediction. The final model is architecturally identical to one trained normally.
Phase 1 (Superposition) — model reads bags of s tokens, predicts the next bag
Phase 2 (Recovery) — standard next-token prediction resumes from the checkpoint
Inference — completely unchanged; no new heads, no new parameters
Validated at 270M, 600M, 3B dense and 10B-A1B MoE
TST gains compute efficiency at the cost of higher data consumption. It is best suited for compute-bound pre-training, not data-bound pre-training.
02 / Phase 1
Phase 1 — The Superposition Phase
For the first r fraction of total training steps, the input sequence of length L is split into non-overlapping bags of s contiguous tokens. Their embeddings are averaged into a single latent s-token. The transformer processes a sequence of length L/s — but each position corresponds to s real tokens, so throughput is s× higher at the same FLOPs.
Equal-FLOPs trick:
To keep each step equal-FLOPs to the baseline, the data sequence length is increased by s× — not the batch size. Every TST step costs the same compute as a standard step.
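A quick back-of-the-envelope check of that bookkeeping, with numbers chosen purely for illustration:

# Illustrative numbers (not from the paper): the raw text window grows by s,
# but the number of positions the transformer processes per step is unchanged.
base_positions = 4096                       # positions per step in standard training
s = 4                                       # bag size
phase1_data_tokens = base_positions * s     # raw tokens read per sequence in Phase 1
phase1_positions = phase1_data_tokens // s  # latent s-tokens actually processed
assert phase1_positions == base_positions   # same per-step FLOPs, s× more text ingested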
On the output side, the loss target shifts from a single next token to the next bag of s tokens. The multi-hot cross-entropy (MCE) loss assigns equal probability mass 1/s to each token in the target bag:
import torch

# pred: flattened logits of shape [batch * positions, vocab]
# labels: target bags of shape [batch, positions, s]
# L_MCE = mean of s standard CE terms
loss = 0.0
for i in range(superposition_bag_size):
    target = labels[..., i].flatten(0, 1)  # i-th token of every target bag
    loss = loss + torch.nn.functional.cross_entropy(pred, target)
loss = loss / superposition_bag_size
No new kernel needed — reuses the existing fused CE kernel in your pre-training library.
03 / Phase 2
Phase 2 — The Recovery Phase
After r × total_steps of superposition training, resume from the checkpoint with the TST code fully removed. Standard next-token prediction runs for the remaining (1 - r) × total_steps.
What happens at the switch:
A loss spike of 1–2 nats occurs at the phase boundary. It resolves within a few thousand steps. After that, the model crosses below the equal-FLOPs baseline and stays there.
Remove TST code fully — do not keep it as an auxiliary loss during Phase 2
Do not re-initialize the input embedding or LM head at the boundary
Shared representations across both phases are what make TST work
Re-initializing the embedding or LM head at the phase boundary completely breaks TST. In a 3B ablation, this raised final loss from 2.676 to 2.938 — worse than the 2.808 baseline. The Phase 1 TST steps contribute nothing to the final model in that case.