
Better & Faster Large Language Models via Multi-token Prediction

Fabian Gloeckle * 1 2 Badr Youbi Idrissi * 1 3 Baptiste Rozière 1 David Lopez-Paz + 1 Gabriel Synnaeve + 1

Abstract

Large language models such as GPT and Llama are trained with a next-token prediction loss. In this work, we suggest that training language models to predict multiple future tokens at once results in higher sample efficiency. More specifically, at each position in the training corpus, we ask the model to predict the following n tokens using n independent output heads, operating on top of a shared model trunk. Considering multi-token prediction as an auxiliary training task, we measure improved downstream capabilities with no overhead in training time for both code and natural language models. The method is increasingly useful for larger model sizes, and keeps its appeal when training for multiple epochs. Gains are especially pronounced on generative benchmarks like coding, where our models consistently outperform strong baselines by several percentage points. Our 13B parameter models solve 12% more problems on HumanEval and 17% more on MBPP than comparable next-token models. Experiments on small algorithmic tasks demonstrate that multi-token prediction is favorable for the development of induction heads and algorithmic reasoning capabilities. As an additional benefit, models trained with 4-token prediction are up to 3× faster at inference, even with large batch sizes.

1. Introduction

Humanity has condensed its most ingenious undertakings, surprising findings and beautiful productions into text. Large Language Models (LLMs) trained on all of these corpora are able to extract impressive amounts of world knowledge, as well as basic reasoning capabilities, by implementing a simple yet powerful unsupervised learning task: next-token prediction. Despite the recent wave of impressive achievements (OpenAI, 2023), next-token prediction remains an inefficient way of acquiring language, world knowledge and reasoning capabilities. More precisely, teacher forcing with next-token prediction latches onto local patterns and overlooks "hard" decisions. Consequently, state-of-the-art next-token predictors call for orders of magnitude more data than human children to arrive at the same level of fluency (Frank, 2023).

In this study, we argue that training LLMs to predict multiple tokens at once will drive these models toward better sample efficiency. As anticipated in Figure 1, multi-token prediction instructs the LLM to predict the n future tokens from each position in the training corpora, all at once and in parallel (Qi et al., 2020).

Contributions  While multi-token prediction has been studied in previous literature (Qi et al., 2020), the present work offers the following contributions:

1. We propose a simple multi-token prediction architecture with no train time or memory overhead (Section 2).

2. We provide experimental evidence that this training paradigm is beneficial at scale, with models up to 13B parameters solving around 15% more code problems on average (Section 3).

3. Multi-token prediction enables self-speculative decoding, making models up to 3 times faster at inference time across a wide range of batch sizes (Section 3.2).

While cost-free and simple, multi-token prediction is an effective modification to train stronger and faster transformer models. We hope that our work spurs interest in novel auxiliary losses for LLMs well beyond next-token prediction, so as to improve the performance, coherence, and reasoning abilities of these fascinating models.

*Equal contribution. 1 FAIR at Meta, 2 CERMICS Ecole des Ponts ParisTech, 3 LISN Université Paris-Saclay. Correspondence to: Fabian Gloeckle <fgloeckle@[Link]>, Badr Youbi Idrissi <byoubi@[Link]>.

Proceedings of the 41st International Conference on Machine Learning, Vienna, Austria. PMLR 235, 2024. Copyright 2024 by the author(s).

2. Method

Standard language modeling learns about a large text corpus x_1, ..., x_T by implementing a next-token prediction task. Formally, the learning objective is to minimize the cross-entropy loss

L_1 = − Σ_t log P_θ(x_{t+1} | x_{t:1}),                                   (1)

where P_θ is our large language model under training, trained so as to maximize the probability of x_{t+1} as the next future token, given the history of past tokens x_{t:1} = x_t, ..., x_1.

In this work, we generalize the above by implementing a multi-token prediction task, where at each position of the training corpus, the model is instructed to predict n future tokens at once. This translates into the cross-entropy loss

L_n = − Σ_t log P_θ(x_{t+n:t+1} | x_{t:1}).                               (2)

To make matters tractable, we assume that our large language model P_θ employs a shared trunk to produce a latent representation z_{t:1} of the observed context x_{t:1}, which is then fed into n independent heads to predict each of the n future tokens in parallel (see Figure 1). This leads to the following factorization of the multi-token prediction cross-entropy loss:

L_n = − Σ_t log P_θ(x_{t+n:t+1} | z_{t:1}) · P_θ(z_{t:1} | x_{t:1})
    = − Σ_t Σ_{i=1}^{n} log P_θ(x_{t+i} | z_{t:1}) · P_θ(z_{t:1} | x_{t:1}).

In practice, our architecture consists of a shared transformer trunk f_s producing the hidden representation z_{t:1} from the observed context x_{t:1}, n independent output heads implemented in terms of transformer layers f_{h_i}, and a shared unembedding matrix f_u. Therefore, to predict n future tokens, we compute

P_θ(x_{t+i} | x_{t:1}) = softmax(f_u(f_{h_i}(f_s(x_{t:1})))),

for i = 1, ..., n, where, in particular, P_θ(x_{t+1} | x_{t:1}) is our next-token prediction head. See Appendix B for other variations of multi-token prediction architectures.

Figure 1: Overview of multi-token prediction. (Top) During training, the model predicts 4 future tokens at once, by means of a shared trunk and 4 dedicated output heads. During inference, we employ only the next-token output head. Optionally, the other three heads may be used to speed up inference time. (Bottom) Multi-token prediction improves pass@1 on the MBPP code task, significantly so as model size increases. Error bars are 90% confidence intervals computed with bootstrapping over dataset samples.

Memory-efficient implementation  One big challenge in training multi-token predictors is reducing their GPU memory utilization. To see why this is the case, recall that in current LLMs the vocabulary size V is much larger than the dimension d of the latent representation; therefore, logit vectors become the GPU memory usage bottleneck. Naive implementations of multi-token predictors that materialize all logits and their gradients, both of shape (n, V), severely limit the allowable batch size and average GPU memory utilization. For these reasons, we propose to carefully adapt the sequence of forward and backward operations in our architecture, as illustrated in Figure 2. In particular, after the forward pass through the shared trunk f_s, we sequentially compute the forward and backward pass of each independent output head f_i, accumulating gradients at the trunk. While this creates logits (and their gradients) for the output head f_i, these are freed before continuing to the next output head f_{i+1}, requiring long-term storage only of the d-dimensional trunk gradient ∂L_n/∂f_s. In sum, we have reduced the peak GPU memory utilization from O(nV + d) to O(V + d), at no expense in runtime (Table S5).

Inference  During inference time, the most basic use of the proposed architecture is vanilla next-token autoregressive prediction using the next-token prediction head P_θ(x_{t+1} | x_{t:1}), while discarding all others. However, the additional output heads can be leveraged to speed up decoding from the next-token prediction head with self-speculative decoding methods such as blockwise parallel decoding (Stern et al., 2018), a variant of speculative decoding (Leviathan et al., 2023) that does not require an additional draft model, and speculative decoding with Medusa-like tree attention (Cai et al., 2024).
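To make the forward/backward ordering of the memory-efficient implementation concrete, here is a minimal PyTorch-style sketch of one training step under the assumptions above. The `trunk`, the list of per-head transformer layers `heads`, and the shared `unembed` module are hypothetical stand-ins for f_s, f_{h_i} and f_u; this is an illustration, not the released implementation.

```python
import torch
import torch.nn.functional as F

def multi_token_step(trunk, heads, unembed, tokens):
    """One training step. tokens: LongTensor of shape (batch, T + n)."""
    n = len(heads)
    x = tokens[:, :-n]                          # observed context x_{t:1}, length T
    z = trunk(x)                                # shared latent z_{t:1}, shape (B, T, d)
    z_detached = z.detach().requires_grad_(True)

    total_loss = 0.0
    for i, head in enumerate(heads, start=1):
        h = head(z_detached)                    # head-specific transformer layer f_{h_i}
        logits = unembed(h)                     # (B, T, V); only lives inside this iteration
        targets = tokens[:, i:i + x.shape[1]]   # i-th future token at every position
        loss = F.cross_entropy(logits.transpose(1, 2), targets)
        loss.backward()                         # frees this head's logit graph right away,
                                                # accumulating gradients in z_detached.grad
        total_loss += loss.item()

    z.backward(z_detached.grad)                 # single backward pass through the trunk
    return total_loss / n                       # an optimizer step would follow
```

Only the d-dimensional gradient accumulated at the detached trunk output is kept across heads; each head's (T, V)-shaped logits and their gradients are freed as soon as that head's backward pass completes, which is what reduces peak memory from O(nV + d) to O(V + d).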

Figure 2: Order of the forward/backward in an n-token prediction model with n = 2 heads. By performing the forward/backward on the heads in sequential order, we avoid materializing all unembedding layer gradients in memory simultaneously and reduce peak GPU memory usage.

Figure 3: Results of n-token prediction models on MBPP by model size. We train models of six sizes in the range of 300M to 13B total parameters on code, and evaluate pass@1, pass@10 and pass@100 on the MBPP (Austin et al., 2021) and HumanEval (Chen et al., 2021) benchmarks with 1000 samples. Multi-token prediction models are worse than the baseline for small model sizes, but outperform the baseline at scale. Error bars are 90% confidence intervals computed with bootstrapping over dataset samples.

3. Experiments on real data

We demonstrate the efficacy of multi-token prediction losses by seven large-scale experiments. Section 3.1 shows how multi-token prediction is increasingly useful when growing the model size. Section 3.2 shows how the additional prediction heads can speed up inference by a factor of 3× using speculative decoding. Section 3.3 demonstrates how multi-token prediction promotes learning longer-term patterns, a fact most apparent in the extreme case of byte-level tokenization. Section 3.4 shows that a 4-token predictor leads to strong gains with a tokenizer of size 32k. Section 3.5 illustrates that the benefits of multi-token prediction remain for training runs with multiple epochs. Section 3.6 showcases the rich representations promoted by pretraining with multi-token prediction losses by finetuning on the CodeContests dataset (Li et al., 2022). Section 3.7 shows that the benefits of multi-token prediction carry over to natural language models, improving generative evaluations such as summarization, while not regressing significantly on standard benchmarks based on multiple-choice questions and negative log-likelihoods.

To allow fair comparisons between next-token predictors and n-token predictors, the experiments that follow always compare models with an equal amount of parameters. That is, when we add n − 1 layers in the future prediction heads, we remove n − 1 layers from the shared model trunk. Please refer to Table S14 for the model architectures and to Table S13 for an overview of the hyperparameters we use in our experiments.

3.1. Benefits scale with model size

To study this phenomenon, we train models of six sizes in the range of 300M to 13B parameters from scratch on at least 91B tokens of code. The evaluation results in Figure 3 for MBPP (Austin et al., 2021) and HumanEval (Chen et al., 2021) show that it is possible, with the exact same computational budget, to squeeze much more performance out of large language models given a fixed dataset using multi-token prediction.

We believe this usefulness only at scale to be a likely reason why multi-token prediction has so far been largely overlooked as a promising training loss for large language model training.

3.2. Faster inference

We implement greedy self-speculative decoding (Stern et al., 2018) with heterogeneous batch sizes using xFormers (Lefaudeux et al., 2022) and measure decoding speeds of our best 4-token prediction model with 7B parameters on completing prompts taken from a test dataset of code and natural language (Table S2) not seen during training. We observe a speedup of 3.0× on code, with an average of 2.5 accepted tokens out of 3 suggestions, and of 2.7× on text. On an 8-byte prediction model, the inference speedup is 6.4× (Table S3). Pretraining with multi-token prediction allows the additional heads to be much more accurate than a simple finetuning of a next-token prediction model, thus allowing our models to unlock self-speculative decoding's full potential.
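To make the decoding scheme concrete, the following is a minimal sketch of greedy blockwise self-speculative decoding with k heads. It assumes a hypothetical `model(ids)` that returns logits of shape (k, len(ids), vocab_size), where entry [i, t] scores the (i+1)-th token after position t; this illustrates the Stern et al. (2018) scheme, not the code used for the measurements above.

```python
import torch

@torch.no_grad()
def self_speculative_decode(model, prompt_ids, max_new_tokens, k):
    """Greedy blockwise parallel decoding (Stern et al., 2018) with k prediction heads."""
    seq = prompt_ids.clone()                                # 1-D LongTensor
    target_len = seq.numel() + max_new_tokens
    while seq.numel() < target_len:
        t = seq.numel() - 1
        # Draft: the k heads propose the next k tokens from the last context position.
        draft = model(seq)[:, t].argmax(dim=-1)             # shape (k,)
        # Verify: one forward pass on the extended sequence; the next-token head
        # (head 0) re-scores every drafted position.
        base = model(torch.cat([seq, draft]))[0].argmax(dim=-1)
        # draft[0] comes from the next-token head itself, so it is always kept;
        # further drafted tokens are kept only while head 0 agrees with them.
        accept = 1
        while accept < k and draft[accept] == base[t + accept]:
            accept += 1
        seq = torch.cat([seq, draft[:accept]])
    return seq[:target_len]
```

This naive form spends two forward passes per block; in practice the draft and verification passes can be merged (the verification forward also produces the next block's draft), so that each accepted block costs roughly one forward pass, which is where the speedups reported above come from.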

Table 1: Multi-token prediction improves performance and unlocks efficient byte-level training. We compare models with 7B parameters trained from scratch on 200B tokens and on 314B bytes of code on the MBPP (Austin et al., 2021), HumanEval (Chen et al., 2021) and APPS (Hendrycks et al., 2021) benchmarks. Multi-token prediction largely outperforms next-token prediction in these settings. All numbers were calculated using the estimator from Chen et al. (2021) based on 200 samples per problem. The temperatures were chosen optimally (based on test scores; i.e. these are oracle temperatures) for each model, dataset and pass@k and are reported in Table S12.

Training data       Vocabulary    n     MBPP                  HumanEval             APPS/Intro
                                        @1    @10   @100      @1    @10   @100      @1    @10   @100
313B bytes          bytes         1     19.3  42.4  64.7      18.1  28.2  47.8      0.1   0.5   2.4
(0.5 epochs)                      8     32.3  50.0  69.6      21.8  34.1  57.9      1.2   5.7   14.0
                                  16    28.6  47.1  68.0      20.4  32.7  54.3      1.0   5.0   12.9
                                  32    23.0  40.7  60.3      17.2  30.2  49.7      0.6   2.8   8.8
200B tokens         32k tokens    1     30.0  53.8  73.7      22.8  36.4  62.0      2.8   7.8   17.4
(0.8 epochs)                      2     30.3  55.1  76.2      22.2  38.5  62.6      2.1   9.0   21.7
                                  4     33.8  55.9  76.9      24.0  40.1  66.1      1.6   7.1   19.9
                                  6     31.9  53.9  73.1      20.6  38.4  63.9      3.5   10.8  22.7
                                  8     30.7  52.2  73.4      20.0  36.6  59.6      3.5   10.4  22.1
1T tokens           32k tokens    1     40.7  65.4  83.4      31.7  57.6  83.0      5.4   17.8  34.1
(4 epochs)                        4     43.1  65.9  83.7      31.6  57.3  86.2      4.3   15.6  33.7

3.3. Learning global patterns with multi-byte prediction

To show that the next-token prediction task latches onto local patterns, we went to the extreme case of byte-level tokenization by training a 7B parameter byte-level transformer on 314B bytes, which is equivalent to around 116B tokens. The 8-byte prediction model achieves astounding improvements compared to next-byte prediction, solving 67% more problems on MBPP pass@1 and 20% more problems on HumanEval pass@1. Al-Rfou et al. (2019) also show that multi-target prediction has a positive effect on character-level language modeling.

Multi-byte prediction is therefore a very promising avenue to unlock efficient training of byte-level models. Self-speculative decoding can achieve speedups of 6 times for the 8-byte prediction model, which would allow us to fully compensate the cost of longer byte-level sequences at inference time and even be faster than a next-token prediction model by nearly two times. The 8-byte prediction model is a strong byte-based model, approaching the performance of token-based models despite having been trained on 1.7× less data.

3.4. Searching for the optimal n

To better understand the effect of the number of predicted tokens, we ran comprehensive ablations on models of scale 7B trained on 200B tokens of code. We try n = 1, 2, 4, 6 and 8 in this setting. Results in Table 1 show that training with 4 future tokens outperforms all the other models consistently throughout HumanEval and MBPP for the pass@1, pass@10 and pass@100 metrics: +3.8%, +2.1% and +3.2% for MBPP and +1.2%, +3.7% and +4.1% for HumanEval. Interestingly, for APPS/Intro, n = 6 takes the lead with +0.7%, +3.0% and +5.3%. It is very likely that the optimal window size depends on the input data distribution. For the byte-level models, the optimal window size is more consistent (8 bytes) across these benchmarks.

3.5. Training for multiple epochs

Multi-token training still maintains an edge over next-token prediction when trained on multiple epochs of the same data. The improvements diminish, but we still see a +2.4% increase in pass@1 on MBPP and a +3.2% increase in pass@100 on HumanEval, with similar performance for the rest. For APPS/Intro, a window size of 4 was already not optimal with 200B tokens of training.
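The pass@k values reported throughout this section (and in Table 1) are computed with the unbiased estimator of Chen et al. (2021) from a fixed number of samples per problem. For reference, a minimal sketch of that estimator (our illustration, not the authors' evaluation harness):

```python
import numpy as np

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimate given n samples per problem, of which c are correct."""
    if n - c < k:
        return 1.0   # every size-k subset contains at least one correct sample
    # 1 - C(n - c, k) / C(n, k), evaluated in a numerically stable product form.
    return float(1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))
```

For example, with 200 samples per problem, a problem for which 15 samples are correct gives pass@1 = 0.075 and pass@10 of roughly 0.55.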

3.6. Finetuning multi-token predictors

Pretrained models with multi-token prediction loss also outperform next-token models for use in finetunings. We evaluate this by finetuning 7B parameter models from Section 3.3 on the CodeContests dataset (Li et al., 2022). We compare the 4-token prediction model with the next-token prediction baseline, and include a setting where the 4-token prediction model is stripped of its additional prediction heads and finetuned using the classical next-token prediction target. According to the results in Figure 4, both ways of finetuning the 4-token prediction model outperform the next-token prediction model on pass@k across k. This means the models are both better at understanding and solving the task and at generating diverse answers. Note that CodeContests is the most challenging coding benchmark we evaluate in this study. Next-token prediction finetuning on top of 4-token prediction pretraining appears to be the best method overall, in line with the classical paradigm of pretraining with auxiliary tasks followed by task-specific finetuning. Please refer to Appendix F for details.

Figure 4: Comparison of finetuning performance on CodeContests. We finetune a 4-token prediction model on CodeContests (Li et al., 2022) (train split) using n′-token prediction as training loss with n′ = 4 or n′ = 1, and compare to a finetuning of the next-token prediction baseline model (n = n′ = 1). For evaluation, we generate 1000 samples per test problem for each temperature T ∈ {0.5, 0.6, 0.7, 0.8, 0.9}, and compute pass@k for each value of k and T. Shown is k ↦ max_T pass_at(k, T), i.e. we grant access to a temperature oracle. We observe that both ways of finetuning the 4-token prediction model outperform the next-token prediction baseline. Intriguingly, using next-token prediction finetuning on top of the 4-token prediction model appears to be the best method overall.

Figure 5: Multi-token training with 7B models does not improve performance on multiple-choice tasks. This figure shows the evolution of average accuracy over 6 standard NLP benchmarks during training for 7B models trained on 200B tokens of language data (detailed results in Appendix G). The 2-future-token model has the same performance as the baseline and the 4-future-token model regresses a bit. Larger model sizes might be necessary to see improvements on these tasks.

3.7. Multi-token prediction on natural language

To evaluate multi-token prediction training on natural language, we train models of size 7B parameters on 200B tokens of natural language with a 4-token, 2-token and next-token prediction loss, respectively. In Figure S12, we evaluate the resulting checkpoints on 6 standard NLP benchmarks. On these benchmarks, the 2-future-token prediction model performs on par with the next-token prediction baseline throughout training. The 4-future-token prediction model suffers a performance degradation. Detailed numbers are reported in Appendix G.

However, we do not believe that multiple-choice and likelihood-based benchmarks are suited to effectively discern generative capabilities of language models. In order to avoid the need for human annotations of generation quality or language model judges (which come with their own pitfalls, as pointed out by Koo et al. (2023)), we conduct evaluations on summarization and natural language mathematics benchmarks and compare pretrained models with training set sizes of 200B and 500B tokens and with next-token and multi-token prediction losses, respectively.

For summarization, we use eight benchmarks where ROUGE metrics (Lin, 2004) with respect to a ground-truth summary allow automatic evaluation of generated texts. We finetune each pretrained model on each benchmark's training dataset for three epochs and select the checkpoint with the highest ROUGE-L F1 score on the validation dataset. Figure 6 shows that multi-token prediction models with both n = 2 and n = 4 improve over the next-token baseline in ROUGE-L F1 scores for both training dataset sizes, with the performance gap shrinking with larger dataset size. All metrics can be found in Appendix H.

For natural language mathematics, we evaluate the pretrained models in 8-shot mode on the GSM8K benchmark

(Cobbe et al., 2021) and measure accuracy of the final answer produced after a chain-of-thought elicited by the few-shot examples. We evaluate pass@k metrics to quantify diversity and correctness of answers as in the code evaluations, and use sampling temperatures between 0.2 and 1.4. The results are depicted in Figure S13 in Appendix I. For 200B training tokens, the n = 2 model clearly outperforms the next-token prediction baseline, while the pattern reverses after 500B tokens, and n = 4 is worse throughout.

Figure 6: Performance on abstractive text summarization. Average ROUGE-L (longest common subsequence overlap) F1 score for 7B models trained on 200B and 500B tokens of natural language on eight summarization benchmarks. We finetune the respective models on each task's training data separately for three epochs and select the checkpoints with the highest ROUGE-L F1 validation score. Both n = 2 and n = 4 multi-token prediction models have an advantage over next-token prediction models. Individual scores per dataset and more details can be found in Appendix H.

Figure 7: Induction capability of n-token prediction models. Shown is accuracy on the second token of two-token names that have already been mentioned previously, for models trained with a next-token and a 2-token prediction loss, respectively, with two independent runs each. The lines denote per-loss averages. For small model sizes, next-token prediction models learn practically no or significantly worse induction capability than 2-token prediction models, with their disadvantage disappearing at the size of 100M nonembedding parameters.

4. Ablations on synthetic data

What drives the improvements in downstream performance of multi-token prediction models on all of the tasks we have considered? By conducting toy experiments on controlled training datasets and evaluation tasks, we demonstrate that multi-token prediction leads to qualitative changes in model capabilities and generalization behaviors. In particular, Section 4.1 shows that for small model sizes, induction capability (as discussed by Olsson et al. (2022)) either only forms when using multi-token prediction as training loss, or is vastly improved by it. Moreover, Section 4.2 shows that multi-token prediction improves generalization on an arithmetic task, even more so than tripling the model size.

4.1. Induction capability

Induction describes a simple pattern of reasoning that completes partial patterns by their most recent continuation (Olsson et al., 2022). In other words, if a sentence contains "AB" and later mentions "A", induction is the prediction that the continuation is "B". We design a setup to measure induction capability in a controlled way. Training small models of sizes 1M to 1B nonembedding parameters on a dataset of children's stories, we measure induction capability by means of an adapted test set: in 100 stories from the original test split, we replace the character names by randomly generated names that consist of two tokens with the tokenizer we employ. Predicting the first of these two tokens is linked to the semantics of the preceding text, while predicting the second token of each name's occurrence after it has been mentioned at least once can be seen as a pure induction task. In our experiments, we train for up to 90 epochs and perform early stopping with respect to the test metric (i.e. we allow an epoch oracle). Figure 7 reports induction capability as measured by accuracy on the names' second tokens in relation to model size for two runs with different seeds.

We find that 2-token prediction loss leads to a vastly improved formation of induction capability for models of size 30M nonembedding parameters and below, with their advantage disappearing for sizes of 100M nonembedding parameters and above [1]. We interpret this finding as follows:

multi-token prediction losses help models to learn transferring information across sequence positions, which lends itself to the formation of induction heads and other in-context learning mechanisms. However, once induction capability has been formed, these learned features transform induction into a task that can be solved locally at the current token and learned with next-token prediction alone. From this point on, multi-token prediction actually hurts on this restricted benchmark, but we surmise that there are higher forms of in-context reasoning to which it further contributes, as evidenced by the results in Section 3.1. In Figure S14, we provide evidence for this explanation: replacing the children's stories dataset by a higher-quality 9:1 mix of a books dataset with the children's stories, we enforce the formation of induction capability early in training by means of the dataset alone. As a consequence, except for the two smallest model sizes, the advantage of multi-token prediction on the task disappears: feature learning of induction features has converted the task into a pure next-token prediction task.

[1] Note that a perfect score is not reachable in this benchmark, as some of the tokens in the names in the evaluation dataset never appear in the training data, and in our architecture, embedding and unembedding parameters are not linked.

Figure 8: Accuracy on a polynomial arithmetic task with varying number of operations per expression. Training with multi-token prediction losses increases accuracy across task difficulties. In particular, it also significantly improves out-of-domain generalization performance, albeit at a low absolute level. Tripling the model size, on the other hand, has a considerably smaller effect than replacing next-token prediction with multi-token prediction loss (Figure S16). Shown are two independent runs per configuration with 100M parameter models.

4.2. Algorithmic reasoning

Algorithmic reasoning tasks allow us to measure more involved forms of in-context reasoning than induction alone. We train and evaluate models on a task of polynomial arithmetic in the ring F_7[X]/(X^5) with unary negation, addition, multiplication and composition of polynomials as operations. The coefficients of the operands and the operators are sampled uniformly. The task is to return the coefficients of the polynomials corresponding to the resulting expressions. The number m of operations contained in the expressions is selected uniformly from the range 1 to 5 at training time, and can be used to adjust the difficulty of both in-domain (m ≤ 5) and out-of-domain (m > 5) generalization evaluations. The evaluations are conducted with greedy sampling on a fixed test set of 2000 samples per number of operations. We train models of two small sizes with 30M and 100M nonembedding parameters, respectively. This simulates the conditions of large language models trained on massive text corpora, which are likewise under-parameterized and unable to memorize their entire training datasets.

Multi-token prediction improves algorithmic reasoning capabilities as measured by this task across task difficulties (Figure 8). In particular, it leads to impressive gains in out-of-distribution generalization, despite the low absolute numbers. Increasing the model size from 30M to 100M parameters, on the other hand, does not improve evaluation accuracy as much as replacing next-token prediction by multi-token prediction does (Figure S16). In Appendix K, we furthermore show that multi-token prediction models retain their advantage over next-token prediction models on this task when trained and evaluated with pause tokens (Goyal et al., 2023).

5. Why does it work? Some speculation

Why does multi-token prediction afford superior performance on coding evaluation benchmarks, and on small algorithmic reasoning tasks? Our intuition, developed in this section, is that multi-token prediction mitigates the distributional discrepancy between training-time teacher forcing and inference-time autoregressive generation. We support this view with an illustrative argument on the implicit weights multi-token prediction assigns to tokens depending on their relevance for the continuation of the text, as well as with an information-theoretic decomposition of multi-token prediction loss.

5.1. Lookahead reinforces choice points

Not all token decisions are equally important for generating useful texts from language models (Bachmann and Nagarajan, 2024; Lin et al., 2024). While some tokens allow stylistic variations that do not constrain the remainder of the text, others represent choice points that are linked with higher-level semantic properties of the text and may decide whether an answer is perceived as useful or derailing.

Figure 9: Multi-token prediction loss assigns higher implicit weights to consequential tokens. Shown is a sequence in which all transitions except "5 → A" are easy to predict, alongside the corresponding prediction targets in 3-token prediction. Since the consequences of the difficult transition "5 → A" are likewise hard to predict, this transition receives a higher implicit weight in the overall loss via its correlates "3 → A", ..., "5 → C".

Multi-token prediction implicitly assigns weights to training tokens depending on how closely they are correlated with their successors. As an illustrative example, consider the sequence depicted in Figure 9, where one transition is a hard-to-predict choice point while the other transitions are considered "inconsequential". Inconsequential transitions following a choice point are likewise hard to predict in advance. By marking and counting loss terms, we find that n-token prediction associates a weight of n(n+1)/2 to choice points via their correlates, and a smaller weight of n to inconsequential points. Please refer to Appendix L.3 for more details. Generally, we believe that the quality of text generations depends on picking the right decisions at choice points, and that n-token prediction losses promote those.

5.2. Information-theoretic argument

Language models are typically trained by teacher-forcing, where the model receives the ground truth for each future token during training. However, during test time, generation is unguided and autoregressive, whereby errors accumulate. Teacher-forcing, we argue, encourages models to focus on predicting well in the very short term, at the potential expense of ignoring longer-term dependencies in the overall structure of the generated sequence.

To illustrate the impact of multi-token prediction, consider the following information-theoretic argument. Here, X denotes the next future token, and Y the second-next future token. The production of both of these tokens is conditioned on some observed input context C, which we omit from our equations for simplicity. When placed before token X, vanilla next-token prediction concerns the quantity H(X), while multi-token prediction with n = 2 aims at H(X) + H(Y). We decompose these two quantities as:

H(X) = H(X | Y) + I(X; Y),
H(X) + H(Y) = H(X | Y) + 2 I(X; Y) + H(Y | X).

By discarding the term H(Y | X), which appears again when predicting at the following position, we observe that 2-token prediction increases the importance of I(X; Y) by a factor of 2. So, multi-token predictors are more accurate at predicting tokens X that are of relevance for the remainder of the text to come. In Appendix L.2, we give a relative version of the above equations that shows the increased weight of relative mutual information in a loss decomposition of the 2-token prediction loss.

6. Related work

Language modeling losses  Dong et al. (2019) and Tay et al. (2022) train on a mixture of denoising tasks with different attention masks (full, causal and prefix attention) to bridge the performance gap with next-token pretraining on generative tasks. Tay et al. (2022) use the span corruption objective, which replaces spans of tokens with special tokens for the encoder; the decoder then predicts the contents of those spans. Unlike UniLM, this allows full causal training with teacher forcing. Similarly, Yang et al. (2019) train on permuted sequences, while conserving the original positional embeddings, effectively training the model to predict various parts of the sequence given a mix of past and future information. This permuted language modeling is the closest task to ours since it allows predicting beyond the next token. However, all of these language modeling tasks train on a small percentage of the input text: on average, only 15% of the tokens are backpropagated through. For Dong et al. (2019), where the masking is done in BERT style, it is hard to mask more than 15% since that destroys too much information. For Tay et al. (2022), it is technically possible to use a larger proportion, but in practice the settings used have between 15% and 25% of masked tokens. Yang et al. (2019) also make it possible to train on the whole sequence, since it is only permuted and no information is lost. Yet, in practice, since a completely random permutation is very hard to reconstruct, only 15% of tokens are predicted, for training stability reasons.

Multi-token prediction in language modelling  Qi et al. (2020) argue that multi-token prediction encourages planning, improves representations and prevents the overfitting on local patterns that can result from teacher-forced training. However, their technical approach replicates the residual stream n-fold, while ours allows for compute-matched comparisons and makes the residual representations participate more directly in the auxiliary loss terms. Stern et al. (2018) and Cai et al. (2024) propose model finetunings with multi-token prediction for faster inference but do not study the

effects of such a loss during pretraining. Pal et al. (2023) use probing methods to show that next-token prediction models are able to predict additional consecutive tokens to a certain extent, but less so than our models, which are specifically trained for this task. Jianyu Zhang (2024) observe improvements in language modelling tasks with multi-label binary classification over the occurrence of vocabulary words in the future as an auxiliary learning task.

Self-speculative decoding  Stern et al. (2018) are, to the best of our knowledge, the first to suggest a speculative decoding scheme for faster inference. Our architecture replaces their linear prediction heads by transformer layers, but is otherwise similar. By reorganizing the order of the forward/backward, we can use all loss terms instead of stochastically picking one head for loss computation. Cai et al. (2024) present a more elaborate self-speculative decoding scheme that uses the top-k predictions of each head instead of the best one only. It can be used with the multi-token prediction models we train. Santilli et al. (2023) propose an alternative parallel decoding algorithm for encoder/decoder architectures where the decoded block is refined iteratively.

Multi-target prediction  Multi-task learning is the paradigm of training neural networks jointly on several tasks to improve performance on the tasks of interest (Caruana, 1997). Learning with such auxiliary tasks allows models to exploit dependencies between target variables and can even be preferable in the case of independent targets (Waegeman et al., 2019). While more specifically tailored architectures for multi-target prediction are conceivable (Spyromitros-Xioufis et al., 2016; Read et al., 2021), modern deep learning approaches usually rely on large shared model trunks with separate prediction heads for the respective tasks (Caruana, 1997; Silver et al., 2016; Lample et al., 2022), like we do. Multi-target prediction has been shown to be a successful strategy in various domains, e.g. for learning time series prediction with more distant time steps in the future as auxiliary targets (Vapnik and Vashist, 2009) or for learning from videos with several future frames (Mathieu et al., 2016; Srivastava et al., 2016) or representations of future frames (Vondrick et al., 2016) as auxiliary targets.

7. Conclusion

We have proposed multi-token prediction as an improvement over next-token prediction in training language models for generative or reasoning tasks. Our experiments (up to 7B parameters and 1T tokens) show that this is increasingly useful for larger models and in particular show strong improvements for code tasks. We posit that our method reduces distribution mismatch between teacher-forced training and autoregressive generation. When used with speculative decoding, exact inference gets 3 times faster.

In future work, we would like to better understand how to automatically choose n in multi-token prediction losses. One possibility to do so is to use loss scales and loss balancing (Défossez et al., 2022). Also, optimal vocabulary sizes for multi-token prediction are likely different from those for next-token prediction, and tuning them could lead to better results, as well as improved trade-offs between compressed sequence length and compute-per-byte expenses. Finally, we would like to develop improved auxiliary prediction losses that operate in embedding spaces (LeCun, 2022).

Impact statement

The goal of this paper is to make language models more compute and data efficient. While this may in principle reduce the ecological impact of training LLMs, we should be careful about rebound effects. All societal advantages, as well as risks, of LLMs should be considered while using this work.

Environmental impact

In aggregate, training all models reported in the paper required around 500K GPU hours of computation on hardware of type A100-80GB and H100. Estimated total emissions were around 50 tCO2eq, 100% of which were offset by Meta's sustainability program.

Acknowledgements

We thank Jianyu Zhang, Léon Bottou, Emmanuel Dupoux, Pierre-Emmanuel Mazaré, Yann LeCun, Quentin Garrido, Megi Dervishi, Mathurin Videau, Timothée Darcet and other FAIR PhD students and CodeGen team members for helpful discussions. We thank Jonas Gehring for his technical expertise and the original Llama team and xFormers team for enabling this kind of research.

References

Rami Al-Rfou, Dokook Choe, Noah Constant, Mandy Guo, and Llion Jones. Character-level language modeling with deeper self-attention. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 33, pages 3159–3166, 2019.

Jacob Austin, Augustus Odena, Maxwell Nye, Maarten Bosma, Henryk Michalewski, David Dohan, Ellen Jiang, Carrie Cai, Michael Terry, Quoc Le, et al. Program synthesis with large language models. arXiv preprint arXiv:2108.07732, 2021.

Gregor Bachmann and Vaishnavh Nagarajan. The pitfalls of next-token prediction, 2024.

Samy Bengio, Oriol Vinyals, Navdeep Jaitly, and Noam Shazeer. Scheduled sampling for sequence prediction with recurrent neural networks, 2015.

Yonatan Bisk, Rowan Zellers, Ronan Le Bras, Jianfeng Gao, and Yejin Choi. Piqa: Reasoning about physical commonsense in natural language, 2019.

Tianle Cai, Yuhong Li, Zhengyang Geng, Hongwu Peng, Jason D. Lee, Deming Chen, and Tri Dao. Medusa: Simple LLM inference acceleration framework with multiple decoding heads, 2024.

Rich Caruana. Multitask learning. Machine Learning, 28:41–75, 1997.

Mark Chen, Jerry Tworek, Heewoo Jun, Qiming Yuan, Henrique Ponde, Jared Kaplan, Harri Edwards, Yura Burda, Nicholas Joseph, Greg Brockman, et al. Evaluating large language models trained on code. arXiv preprint arXiv:2107.03374, 2021.

Nakhun Chumpolsathien. Using knowledge distillation from keyword extraction to improve the informativeness of neural cross-lingual summarization. Master's thesis, Beijing Institute of Technology, 2020.

Karl Cobbe, Vineet Kosaraju, Mohammad Bavarian, Mark Chen, Heewoo Jun, Lukasz Kaiser, Matthias Plappert, Jerry Tworek, Jacob Hilton, Reiichiro Nakano, et al. Training verifiers to solve math word problems. arXiv preprint arXiv:2110.14168, 2021.

Li Dong, Nan Yang, Wenhui Wang, Furu Wei, Xiaodong Liu, Yu Wang, Jianfeng Gao, Ming Zhou, and Hsiao-Wuen Hon. Unified language model pre-training for natural language understanding and generation. In Proceedings of the 33rd International Conference on Neural Information Processing Systems, pages 13063–13075, 2019.

Alexandre Défossez, Jade Copet, Gabriel Synnaeve, and Yossi Adi. High fidelity neural audio compression. arXiv preprint arXiv:2210.13438, 2022.

Moussa Kamal Eddine, Antoine J. P. Tixier, and Michalis Vazirgiannis. Barthez: a skilled pretrained French sequence-to-sequence model, 2021.

Alexander R. Fabbri, Irene Li, Tianwei She, Suyi Li, and Dragomir R. Radev. Multi-news: a large-scale multi-document summarization dataset and abstractive hierarchical model, 2019.

Mehrdad Farahani. Summarization using bert2bert model on wikisummary dataset. [Link]summary, 2020.

Mehrdad Farahani, Mohammad Gharachorloo, and Mohammad Manthouri. Leveraging ParsBERT and pretrained mT5 for Persian abstractive text summarization. In 2021 26th International Computer Conference, Computer Society of Iran (CSICC). IEEE, March 2021. doi: 10.1109/csicc52343.2021.9420563. URL [Link]org/10.1109/CSICC52343.2021.9420563.

Michael C. Frank. Bridging the data gap between children and large language models. Trends in Cognitive Sciences, 2023.

Bogdan Gliwa, Iwona Mochol, Maciej Biesek, and Aleksander Wawer. Samsum corpus: A human-annotated dialogue dataset for abstractive summarization. In Proceedings of the 2nd Workshop on New Frontiers in Summarization. Association for Computational Linguistics, 2019. doi: 10.18653/v1/d19-5409. URL http://[Link]/10.18653/v1/D19-5409.

Sachin Goyal, Ziwei Ji, Ankit Singh Rawat, Aditya Krishna Menon, Sanjiv Kumar, and Vaishnavh Nagarajan. Think before you speak: Training language models with pause tokens, 2023.

Dan Hendrycks, Steven Basart, Saurav Kadavath, Mantas Mazeika, Akul Arora, Ethan Guo, Collin Burns, Samir Puranik, Horace He, Dawn Song, et al. Measuring coding challenge competence with APPS. arXiv preprint arXiv:2105.09938, 2021.

Ari Holtzman, Jan Buys, Li Du, Maxwell Forbes, and Yejin Choi. The curious case of neural text degeneration, 2020.

Jianyu Zhang and Leon Bottou. Multi-label classification as an auxiliary loss for language modelling. Personal communication, 2024.

Mandar Joshi, Eunsol Choi, Daniel S. Weld, and Luke Zettlemoyer. TriviaQA: A large scale distantly supervised challenge dataset for reading comprehension, 2017.

Diederik Kingma and Jimmy Ba. Adam: A method for stochastic optimization. ICLR, 2015.

Ryan Koo, Minhwa Lee, Vipul Raheja, Jong Inn Park, Zae Myung Kim, and Dongyeop Kang. Benchmarking cognitive biases in large language models as evaluators. arXiv preprint arXiv:2309.17012, 2023.

Tom Kwiatkowski, Jennimaria Palomaki, Olivia Redfield, Michael Collins, Ankur Parikh, Chris Alberti, Danielle Epstein, Illia Polosukhin, Matthew Kelcey, Jacob Devlin, Kenton Lee, Kristina N. Toutanova, Llion Jones, Ming-Wei Chang, Andrew Dai, Jakob Uszkoreit, Quoc Le, and Slav Petrov. Natural questions: a benchmark for question answering research. Transactions of the Association of Computational Linguistics, 2019.

Guillaume Lample, Marie-Anne Lachaux, Thibaut Lavril, Xavier Martinet, Amaury Hayat, Gabriel Ebner, Aurélien Rodriguez, and Timothée Lacroix. Hypertree proof search for neural theorem proving, 2022.

Yann LeCun. A path towards autonomous machine intelligence version 0.9.2, 2022-06-27. Open Review, 62(1), 2022.

Benjamin Lefaudeux, Francisco Massa, Diana Liskovich, Wenhan Xiong, Vittorio Caggiano, Sean Naren, Min Xu, Jieru Hu, Marta Tintore, Susan Zhang, Patrick Labatut, and Daniel Haziza. xformers: A modular and hackable transformer modelling library. [Link]com/facebookresearch/xformers, 2022.

Yaniv Leviathan, Matan Kalman, and Yossi Matias. Fast inference from transformers via speculative decoding, 2023.

Yujia Li, David Choi, Junyoung Chung, Nate Kushman, Julian Schrittwieser, Rémi Leblond, Tom Eccles, James Keeling, Felix Gimeno, Agustin Dal Lago, et al. Competition-level code generation with AlphaCode. Science, 378(6624):1092–1097, 2022.

Chin-Yew Lin. ROUGE: A package for automatic evaluation of summaries. In Text Summarization Branches Out, pages 74–81, Barcelona, Spain, July 2004. Association for Computational Linguistics. URL https://[Link]/W04-1013.

Zhenghao Lin, Zhibin Gou, Yeyun Gong, Xiao Liu, Yelong Shen, Ruochen Xu, Chen Lin, Yujiu Yang, Jian Jiao, Nan Duan, and Weizhu Chen. Rho-1: Not all tokens are what you need, 2024.

Ilya Loshchilov and Frank Hutter. SGDR: Stochastic gradient descent with warm restarts, 2017.

Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization, 2019.

Michael Mathieu, Camille Couprie, and Yann LeCun. Deep multi-scale video prediction beyond mean square error, 2016.

Ramesh Nallapati, Bowen Zhou, Cicero Nogueira dos Santos, Caglar Gulcehre, and Bing Xiang. Abstractive text summarization using sequence-to-sequence RNNs and beyond, 2016.

Shashi Narayan, Shay B. Cohen, and Mirella Lapata. Don't give me the details, just the summary! Topic-aware convolutional neural networks for extreme summarization, 2018.

Catherine Olsson, Nelson Elhage, Neel Nanda, Nicholas Joseph, Nova DasSarma, Tom Henighan, Ben Mann, Amanda Askell, Yuntao Bai, Anna Chen, Tom Conerly, Dawn Drain, Deep Ganguli, Zac Hatfield-Dodds, Danny Hernandez, Scott Johnston, Andy Jones, Jackson Kernion, Liane Lovitt, Kamal Ndousse, Dario Amodei, Tom Brown, Jack Clark, Jared Kaplan, Sam McCandlish, and Chris Olah. In-context learning and induction heads. Transformer Circuits Thread, 2022. [Link]learning-and-induction-heads/[Link].

OpenAI. GPT-4 technical report, 2023.

Long Ouyang, Jeff Wu, Xu Jiang, Diogo Almeida, Carroll L. Wainwright, Pamela Mishkin, Chong Zhang, Sandhini Agarwal, Katarina Slama, Alex Ray, John Schulman, Jacob Hilton, Fraser Kelton, Luke Miller, Maddie Simens, Amanda Askell, Peter Welinder, Paul Christiano, Jan Leike, and Ryan Lowe. Training language models to follow instructions with human feedback, 2022.

Koyena Pal, Jiuding Sun, Andrew Yuan, Byron C. Wallace, and David Bau. Future lens: Anticipating subsequent tokens from a single hidden state, 2023.

Weizhen Qi, Yu Yan, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang, and Ming Zhou. ProphetNet: Predicting future n-gram for sequence-to-sequence pre-training, 2020.

Jesse Read, Bernhard Pfahringer, Geoffrey Holmes, and Eibe Frank. Classifier chains: A review and perspectives. Journal of Artificial Intelligence Research, 70:683–718, 2021.

Melissa Roemmele, Cosmin Adrian Bejan, and Andrew S. Gordon. Choice of plausible alternatives: An evaluation of commonsense causal reasoning. In 2011 AAAI Spring Symposium Series, 2011.

Andrea Santilli, Silvio Severino, Emilian Postolache, Valentino Maiorca, Michele Mancusi, Riccardo Marin, and Emanuele Rodolà. Accelerating transformer inference for translation via parallel decoding. arXiv preprint arXiv:2305.10427, 2023.

Maarten Sap, Hannah Rashkin, Derek Chen, Ronan LeBras, and Yejin Choi. SocialIQA: Commonsense reasoning about social interactions, 2019.

David Silver, Aja Huang, Chris J. Maddison, Arthur Guez, Laurent Sifre, George Van Den Driessche, Julian Schrittwieser, Ioannis Antonoglou, Veda Panneershelvam, Marc Lanctot, et al. Mastering the game of Go with deep neural networks and tree search. Nature, 529(7587):484–489, 2016.

Aaditya K. Singh, Stephanie C. Y. Chan, Ted Moskovitz, Erin Grant, Andrew M. Saxe, and Felix Hill. The transient nature of emergent in-context learning in transformers. arXiv preprint arXiv:2311.08360, 2023.

Eleftherios Spyromitros-Xioufis, Grigorios Tsoumakas, William Groves, and Ioannis Vlahavas. Multi-target regression via input space expansion: treating targets as inputs. Machine Learning, 104:55–98, 2016.

Nitish Srivastava, Elman Mansimov, and Ruslan Salakhutdinov. Unsupervised learning of video representations using LSTMs, 2016.

Mitchell Stern, Noam Shazeer, and Jakob Uszkoreit. Blockwise parallel decoding for deep autoregressive models, 2018.

Yi Tay, Mostafa Dehghani, Vinh Q. Tran, Xavier Garcia, Jason Wei, Xuezhi Wang, Hyung Won Chung, Siamak Shakeri, Dara Bahri, Tal Schuster, et al. UL2: Unifying language learning paradigms. arXiv preprint arXiv:2205.05131, 2022.

Vladimir Vapnik and Akshay Vashist. A new learning paradigm: Learning using privileged information. Neural Networks, 22(5-6):544–557, 2009.

Carl Vondrick, Hamed Pirsiavash, and Antonio Torralba. Anticipating visual representations from unlabeled video, 2016.

Willem Waegeman, Krzysztof Dembczyński, and Eyke Hüllermeier. Multi-target prediction: a unifying view on problems and methods. Data Mining and Knowledge Discovery, 33:293–324, 2019.

Vikas Yadav, Steven Bethard, and Mihai Surdeanu. Quick and (not so) dirty: Unsupervised selection of justification sentences for multi-hop question answering. arXiv preprint arXiv:1911.07176, 2019.

Zhilin Yang, Zihang Dai, Yiming Yang, Jaime Carbonell, Russ R. Salakhutdinov, and Quoc V. Le. XLNet: Generalized autoregressive pretraining for language understanding. In Advances in Neural Information Processing Systems, pages 5753–5763, 2019.

Rowan Zellers, Ari Holtzman, Yonatan Bisk, Ali Farhadi, and Yejin Choi. Hellaswag: Can a machine really finish your sentence?, 2019.

A. Additional results on self-speculative decoding

[Figure S10: two panels plotting relative throughput (left) and relative latency (right) against batch size for k = 1, 2, 3, 4 heads.]
Figure S10: Decoding speeds and latencies with self-speculative decoding relative to standard autoregressive decoding.
We use k heads of a 4-token prediction model and evaluate decoding speeds of a code model as explained in Table S2. All
numbers are relative to the autoregressive (k = 1) baseline with the same batch size.

Table S2: Relative speedups with self-speculative decoding. For wikipedia and books we prompt a 7B parameter model
trained on 500B tokens, and for code we prompt a 7B parameter model trained on 1T tokens of code on 4200 sequences of
512 tokens from a test dataset not seen during training, and generate completions consisting of 512 tokens using greedy
self-speculative decoding (Stern et al., 2018) using the indicated number of heads from a 4-token prediction model. Note
that the maximal speedup that can be obtained with self-speculative decoding using k heads is k. The last column shows the
average number of tokens retrieved from a forward containing this sequence (both verification and prediction). The speedup
was evaluated at the maximal batch size of 42, but is constant across batch sizes (Figure S10).

                   Wikipedia                        Books                            Code
# Heads used       Rel. speedup   Tokens/forward    Rel. speedup   Tokens/forward    Rel. speedup   Tokens/forward
1                  1.00           1.00              1.00           1.00              1.00           1.00
2                  1.79           1.88              1.77           1.87              1.85           1.94
3                  2.35           2.57              2.32           2.56              2.54           2.78
4                  2.74           3.12              2.67           3.09              3.05           3.50

Table S3: Relative speedups with self-speculative decoding with byte-level models on code. We prompt the 7B parameter
models from Section 3.3 on 4096 sequences of 1024 bytes of code not seen during training, and generate completions
consisting of 1024 bytes using greedy self-speculative decoding (Stern et al., 2018) as in Table S2. The speedup was
evaluated at a batch size of 16.

                   n = 8                            n = 16                           n = 32
# Heads used       Rel. speedup   Tokens/forward    Rel. speedup   Tokens/forward    Rel. speedup   Tokens/forward
1                  1.00           1.00              1.00           1.00              1.00           1.00
2                  1.94           1.98              1.94           1.98              1.93           1.97
4                  3.67           3.84              3.63           3.81              3.62           3.80
8                  6.39           7.04              6.25           6.92              6.22           6.89
12                 −              −                 8.07           9.36              8.01           9.30
16                 −              −                 9.24           11.20             9.15           11.15
20                 −              −                 −              −                 9.83           12.61
24                 −              −                 −              −                 10.34          13.67
28                 −              −                 −              −                 10.55          14.58
32                 −              −                 −              −                 10.84          15.35


B. Alternative architectures
Table S4: Alternative architectures improve on the baseline, but not as consistently. Alternative architectures for multi-token prediction are worth exploring to improve efficiency. Here, we tried anticausal, causal and linear head variants and observed no significant improvement with respect to the parallel architecture.

n   Head type     Architecture   +Layers    MBPP                  HumanEval             APPS/Intro
                                            @1    @10   @100      @1    @10   @100      @1    @10   @100
1   transformer   parallel       0          30.0  53.8  73.7      22.8  36.4  62.0      2.8   7.8   17.4
4   linear        parallel       0          33.6  55.0  76.2      21.9  38.5  63.7      3.1   10.1  23.0
4   transformer   anticausal     0          30.8  54.8  75.3      20.9  38.4  64.5      2.0   8.7   21.6
4   transformer   causal         0          31.9  54.9  74.9      20.9  38.1  67.3      4.0   11.6  22.8
4   transformer   parallel       0          33.8  55.9  76.9      24.0  40.1  66.1      1.6   7.1   19.9
4   transformer   parallel       3          33.3  55.7  77.3      22.4  39.4  66.7      2.6   9.5   22.1

The architecture described in Section 2 is not the only sensible option, but proved technically viable and well-performing in
our experiments. We describe and compare alternative architectures in this section.

Replicated unembeddings Replicating the unembedding matrix n times is a simple method for implementing multi-token
prediction architectures. However, it requires matrices with shapes (d, nV ) in the notation of Section 2, which is prohibitive
for large-scale training runs.

Linear heads Apart from using a single transformer layer for the heads Hi , other architectures are conceivable. We
experimented with a single linear layer without any nonlinearity as heads, amounting to linear probing of the model’s
residual representation z. Architectures with more than one layer per head are also possible, but we did not pursue this
direction further.

Causal and anticausal variant Instead of making the prediction heads P_i(x_{t+i} | z_{t:1}) architecturally independent of each
other, we can also allow them to rely on other heads' (pre-unembedding) outputs. In a causal variant, later prediction heads
are applied on top of the previous ones, i.e. the i-th prediction head P_i is given by

P_θ(x_{t+i} | ·) = softmax ∘ f_u ∘ f_{h_i} ∘ f_{h_{i−1}} ∘ · · · ∘ f_{h_1} ∘ f_s.

In an anticausal variant, the network starts by predicting the most distant token and gradually refines its predictions up to the
following token:

P_θ(x_{t+i} | ·) = softmax ∘ f_u ∘ f_{h_i} ∘ f_{h_{i+1}} ∘ · · · ∘ f_{h_n} ∘ f_s.

Like the parallel architecture from Section 2, these variants admit a sequential forward/backward order; this is
illustrated in Figure S11.
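
As a concrete illustration, the following is a minimal PyTorch sketch of the three head arrangements. The single linear layer with a ReLU is a placeholder for whatever block implements a head (a full transformer layer in our experiments), and the module and variable names are illustrative rather than taken from our training code.

import torch
import torch.nn as nn

class MultiTokenHeads(nn.Module):
    """Sketch of parallel, causal and anticausal prediction heads on a shared trunk output z."""
    def __init__(self, d, vocab, n, variant="parallel"):
        super().__init__()
        self.heads = nn.ModuleList([nn.Linear(d, d) for _ in range(n)])  # placeholder f_{h_i}
        self.unembed = nn.Linear(d, vocab, bias=False)                   # shared unembedding f_u
        self.variant = variant

    def forward(self, z):                         # z: trunk output f_s(x), shape (B, T, d)
        logits = []
        if self.variant == "parallel":            # heads act independently on z
            for h in self.heads:
                logits.append(self.unembed(torch.relu(h(z))))
        elif self.variant == "causal":            # head i stacks on top of heads 1..i-1
            x = z
            for h in self.heads:
                x = torch.relu(h(x))
                logits.append(self.unembed(x))
        elif self.variant == "anticausal":        # head i stacks on top of heads n..i+1
            x = z
            for h in reversed(self.heads):
                x = torch.relu(h(x))
                logits.append(self.unembed(x))
            logits = logits[::-1]                 # reorder as predictions for offsets 1..n
        return logits                             # n tensors of shape (B, T, vocab)

z = torch.randn(2, 8, 16)                         # toy trunk output
print([tuple(l.shape) for l in MultiTokenHeads(16, 100, 4, "causal")(z)])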


[Figure S11: schematic of a causal two-head prediction model, with numbered arrows indicating the order of forward and backward operations through Input, Trunk, Head 1 (Loss 1) and Head 2 (Loss 2).]

Figure S11: Order of the forward/backward passes in a causal n-token prediction model with n = 2 heads. As in the
forward/backward depicted for parallel prediction heads in Figure 2, we avoid materializing all unembedding layer gradients
in memory simultaneously and thereby reduce peak GPU memory usage significantly. The iteration over the heads starts with the
one furthest from the trunk. At each head, the gradients from the succeeding prediction heads and from the head's own loss are
accumulated, both for the head's output and for its weights.

C. Training speeds

Table S5: Training time relative to next-token prediction training. The slight overhead when using multi-token prediction
is explained by a suboptimal use of Fully Sharded Data Parallel: in our implementation, performing a separate backward
pass for each head loses the overlap between layer weight communication and computation, which incurs a small overhead
that a more careful implementation could remove.

Model n=1 n=2 n=4


0.3B 1.00 1.07 1.22
0.6B 1.00 1.05 1.13
1.3B 1.00 1.04 1.12
3B 1.00 1.02 1.07
6.7B 1.00 1.02 1.07
13B 1.00 1.04 1.09

D. Finetuning

Table S6: Finetuning Llama 2 with multi-token prediction does not significantly improve performance. We finetuned
Llama 2 with 4-token prediction, but this did not yield significant improvements over the baseline. We suspect that the new
loss perturbs the pretrained model too strongly at the start of finetuning and that it never fully recovers. We still see some
improvements, for example on MBPP pass@1. All runs use 200B tokens of code.

MBPP HumanEval APPS/Intro


n Head type +Layers @1 @10 @100 @1 @10 @100 @1 @10 @100
1 transformer 0 39.6 65.1 82.4 31.4 57.7 84.7 10.0 21.6 36.7
linear 0 39.3 63.7 81.3 29.0 53.4 82.2 6.9 20.0 34.0
4 0 38.3 62.2 80.1 27.9 53.6 82.4 5.8 18.2 34.3
transformer
3 42.5 64.4 81.3 28.7 56.9 82.4 7.8 21.2 37.3


E. Additional results on model scaling behavior

Table S7: Scaling model size. Full results for scaling the model size with n = 1, 2 and 4 future tokens (column "Fut").

MBPP HumanEval
Model Size Fut @1 @10 @100 @1 @10 @100
1 1.8 10.4 29.9 1.9 5.0 10.9
0.3B 2 1.7 10.1 27.2 1.5 4.4 10.3
4 1.0 6.3 20.1 1.2 4.0 8.6
1 4.7 21.0 45.2 2.9 8.5 16.7
0.6B 2 4.6 21.0 44.7 3.2 8.9 16.2
4 3.0 15.6 38.0 2.7 7.7 15.5
1 6.8 27.0 51.0 4.6 13.1 24.3
1.3B 2 7.3 27.5 51.7 5.4 13.6 23.3
4 7.4 27.6 50.1 4.8 12.3 22.5
1 11.1 36.4 60.4 7.2 17.2 29.8
3B 2 11.8 37.2 60.5 8.0 18.2 31.2
4 12.7 37.6 61.1 7.2 18.5 33.3
1 23.9 54.2 74.7 12.8 29.3 51.7
6.7B 2 24.7 54.8 76.4 13.2 32.2 53.9
4 26.0 55.8 76.0 13.8 33.2 58.5
1 26.0 57.1 77.0 14.1 33.6 56.0
13B 2 30.5 60.5 79.4 15.2 36.9 60.0
4 30.5 61.0 79.2 15.8 38.6 63.5

F. Details on CodeContests finetuning


We use the Python subset of the CodeContests (Li et al., 2022) train split with reward annotations (“correct” / “incorrect”)
and condition on correct solutions at evaluation time. For evaluation, we generate 1000 samples per problem from the test
split for each temperature T ∈ {0.5, 0.6, 0.7, 0.8, 0.9}, and compute the unbiased estimator for pass@k from Chen et al.
(2021) for each value of k and T. It is possible that models pretrained with different losses have different optimal
temperatures for pass@k, so we compute and show k ↦ max_T pass@k(T) in Figure 4; in other words, we grant pass@k access
to a temperature oracle. For small values of k, pass@k measures the capability of understanding and solving tasks, while for
large k it additionally rewards diversity in outputs. According to the results in Figure 4, multi-token prediction
pretraining leads to finetuned models that are better on both axes.
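
For reference, a small sketch of this evaluation in Python: pass_at_k is the unbiased estimator published by Chen et al. (2021), while the data layout and the temperature-oracle helper are hypothetical illustrations of the procedure described above.

import numpy as np

def pass_at_k(n, c, k):
    """Unbiased pass@k estimator from Chen et al. (2021): n samples per problem, c of them correct."""
    if n - c < k:
        return 1.0
    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

def pass_at_k_with_temperature_oracle(correct_counts, n, k):
    """correct_counts[T] is a hypothetical list with the number of correct samples per problem at temperature T."""
    per_temperature = {
        T: np.mean([pass_at_k(n, c, k) for c in counts])
        for T, counts in correct_counts.items()
    }
    return max(per_temperature.values())            # grant pass@k a temperature oracle

counts = {0.5: [3, 0, 12], 0.8: [5, 1, 9]}          # toy example: 3 problems, 2 temperatures
print(pass_at_k_with_temperature_oracle(counts, n=1000, k=10))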


G. Additional results on natural language benchmarks


We evaluate the models from Section 3.7 on standard natural language processing benchmarks: ARC Challenge (Yadav
et al., 2019), COPA (Roemmele et al., 2011), Hellaswag (Zellers et al., 2019), Natural Questions (Kwiatkowski et al., 2019),
PIQA (Bisk et al., 2019), SIQA (Sap et al., 2019) and TriviaQA (Joshi et al., 2017).

[Figure S12: evaluation accuracy as a function of training step (global_step) for 7B models with n = 1, 2 and 4, shown in separate panels for arc_challenge, copa, hellaswag, nq, piqa, siqa and tqa.]
Figure S12: Multi-token training with 7B models does not improve performance on choice tasks. This figure shows
the evolution of average accuracy on standard NLP benchmarks (ARC Challenge, COPA, Hellaswag, Natural Questions, PIQA,
SIQA and TriviaQA). For the 7B models trained on 200B tokens of language data, the 2-future-token model matches the
baseline while the 4-future-token model regresses slightly. Larger model sizes might be necessary to see improvements on
these tasks.


H. Additional results on abstractive text summarization


In this section, we report comprehensive evaluation results on summarization tasks for the 7B parameter models trained on
200B and 500B tokens of natural language from Section 3.7.

Table S8: Comprehensive evaluation on abstractive text summarization. ROUGE-n (n-gram overlap) and ROUGE-L
(longest common subsequence overlap) F1 scores for 7B models trained on 200B and 500B tokens of natural language,
respectively. The last three columns correspond to models trained on 500B tokens, the previous three to models trained on
200B tokens. Shown are numbers of the n = 1 baseline and the absolute difference of n = 2 and n = 4 models trained
on the same number of tokens. Summary-level ROUGE-L (“ROUGE-Lsum ”) is reported where it differs from ROUGE-L.
Model checkpoints with maximal validation ROUGE-L F1 are selected separately for each dataset and model type; the
corresponding evaluation epoch is reported in the first row for each dataset. Boldface marks numbers within 0.05 of the best
one, separately for each pretraining dataset size.

Task Metric Baseline 200B ∆n=2 ∆n=4 Baseline 500B ∆n=2 ∆n=4
evaluation epoch 2 2 2 2 2 2
ROUGE-1 42.88 +0.74 +0.74 43.77 +0.55 +0.50
ROUGE-2 19.56 +0.52 +0.53 20.34 +0.52 +0.34
CNN/Dailymail (Nallapati et al., 2016)
ROUGE-3 11.11 +0.39 +0.35 11.69 +0.36 +0.19
ROUGE-L 29.72 +0.66 +0.49 30.51 +0.48 +0.37
ROUGE-Lsum 40.18 +0.72 +0.68 41.02 +0.56 +0.52
evaluation epoch 1 3 3 2 3 2
ROUGE-1 44.48 +1.70 +1.72 45.87 +1.05 +0.69
Multi-News (Fabbri et al., 2019) ROUGE-2 16.88 +0.44 +0.70 17.56 +0.42 +0.40
ROUGE-3 9.63 -0.06 +0.17 9.91 +0.22 +0.18
ROUGE-L 23.82 +0.17 +0.40 24.22 +0.20 +0.26
evaluation epoch 2 2 3 2 1 3
ROUGE-1 32.95 +0.41 +0.35 33.37 +0.32 +0.78
OrangeSum (Eddine et al., 2021) ROUGE-2 13.90 +0.31 +0.36 14.22 +0.25 +0.53
ROUGE-3 8.01 +0.19 +0.21 8.12 +0.22 +0.48
ROUGE-L 23.62 +0.36 +0.51 23.91 +0.23 +0.66
evaluation epoch 1 1 1 1 2 3
ROUGE-1 1.03 +0.02 0.00 0.92 +0.09 +0.05
pn-summary (Farahani et al., 2021) ROUGE-2 0.13 +0.02 +0.03 0.15 0.00 0.00
ROUGE-3 0.02 0.00 +0.02 0.02 0.00 +0.02
ROUGE-L 1.02 +0.03 +0.01 0.91 +0.09 +0.05
evaluation epoch 3 3 3 3 3 3
ROUGE-1 51.39 +0.70 +0.63 52.54 -0.24 +0.69
SAMSum (Gliwa et al., 2019) ROUGE-2 26.46 +0.76 +0.30 27.74 -0.20 +0.82
ROUGE-3 16.40 +0.91 +0.28 17.56 -0.30 +0.71
ROUGE-L 42.59 +0.90 +0.51 43.92 -0.10 +0.63
evaluation epoch 2 3 3 3 3 3
ROUGE-1 45.08 +0.63 +1.12 45.48 +0.77 +0.91
ThaiSum (Chumpolsathien, 2020) ROUGE-2 27.85 +0.30 +0.73 28.07 +0.74 +0.64
ROUGE-3 15.73 +0.04 +0.43 15.82 +0.50 +0.30
ROUGE-L 44.92 +0.64 +1.12 45.31 +0.76 +0.89
evaluation epoch 3 3 3 3 3 3
ROUGE-1 10.16 +0.67 -0.23 12.80 -0.17 -0.99
WikiSummary (Farahani, 2020) ROUGE-2 4.46 -0.03 -0.09 6.17 -0.11 -0.69
ROUGE-3 1.31 +0.21 +0.13 1.98 -0.08 -0.33
ROUGE-L 10.11 +0.65 -0.28 12.69 -0.17 -0.99
evaluation epoch 2 2 3 2 2 3
ROUGE-1 42.16 +0.71 +1.07 43.42 +0.78 +0.67
XSum (Narayan et al., 2018) ROUGE-2 19.19 +0.54 +0.55 20.32 +0.68 +0.34
ROUGE-3 10.43 +0.38 +0.28 11.23 +0.48 +0.20
ROUGE-L 34.03 +0.67 +0.92 35.18 +0.79 +0.63


Table S9: Performance on abstractive text summarization. ROUGE-L (longest common subsequence overlap) F1 score
for 7B models trained on 200B and 500B tokens of natural language. We finetune the respective models on each task’s
training data separately for a given number of epochs and select the checkpoints with maximal ROUGE-L F1 on the
validation dataset. The second and fifth column report the numbers for a next-token prediction model, while the third, fourth,
sixth and seventh one report the absolute improvements for 2-token and 4-token prediction models trained on the same
amount of data, respectively. Boldface for numbers within 0.05 difference to the best one for each dataset size separately.

Dataset Baseline 200B ∆n=2 ∆n=4 Baseline 500B ∆n=2 ∆n=4


CNN/Dailymail 29.72 +0.66 +0.49 30.51 +0.48 +0.37
Multi-News 23.82 +0.17 +0.40 24.22 +0.20 +0.26
OrangeSum 23.62 +0.36 +0.51 23.91 +0.23 +0.66
pn-summary 1.02 +0.03 +0.01 0.91 +0.09 +0.05
SAMSum 42.59 +0.90 +0.51 43.92 -0.10 +0.63
ThaiSum 44.92 +0.64 +1.12 45.31 +0.76 +0.89
WikiSummary 10.11 +0.65 -0.28 12.69 -0.17 -0.99
XSum 34.03 +0.67 +0.92 35.18 +0.79 +0.63
Average 26.23 +0.51 +0.46 27.08 +0.28 +0.31

Table S10: Summary statistics for abstractive text summarization evaluations. Reported are averages for ROUGE-n and
ROUGE-L metrics across all datasets from Table S8, separately for precision, recall and F1 score. Both 2-token and 4-token
prediction models outperform the next-token prediction baseline. Trained on 500B tokens, 4-token prediction models appear
better at recall metrics while 2-token prediction models appear better at precision metrics. Model checkpoints are selected
as described in Table S8. Boldface for numbers within 0.05 difference to the best one for each dataset size separately.

Metric Aspect Baseline 200B ∆n=2 ∆n=4 Baseline 500B ∆n=2 ∆n=4
F1 33.77 +0.70 +0.68 34.77 +0.39 +0.41
ROUGE-1 precision 35.76 +0.88 +0.83 37.03 +0.42 -0.04
recall 34.37 +0.45 +0.45 35.14 +0.35 +0.68
F1 16.06 +0.36 +0.39 16.82 +0.29 +0.30
ROUGE-2 precision 16.97 +0.40 +0.43 17.91 +0.29 +0.03
recall 16.34 +0.28 +0.35 16.99 +0.32 +0.48
F1 9.08 +0.26 +0.23 9.54 +0.18 +0.22
ROUGE-3 precision 9.59 +0.29 +0.28 10.17 +0.18 +0.05
recall 9.26 +0.21 +0.20 9.65 +0.21 +0.35
F1 26.23 +0.51 +0.46 27.08 +0.28 +0.31
ROUGE-L precision 27.79 +0.62 +0.55 28.85 +0.28 -0.09
recall 26.71 +0.37 +0.32 27.40 +0.28 +0.57
F1 27.53 +0.52 +0.48 28.40 +0.29 +0.33
ROUGE-Lsum precision 29.07 +0.64 +0.58 30.15 +0.29 -0.08
recall 28.13 +0.35 +0.33 28.81 +0.29 +0.60


I. Additional results on mathematical reasoning in natural language


[Figure S13: GSM8K pass@1, pass@10 and pass@100 (%) as a function of sampling temperature (0.2 to 1.4), for models with n = 1, 2 and 4 trained on 200B tokens (left column) and 500B tokens (right column).]
Figure S13: Performance on the mathematical reasoning benchmark GSM8K (Cobbe et al., 2021). We evaluate
pretrained next-token and multi-token prediction models trained on 200B and 500B tokens of natural language in 8-shot
mode using nucleus sampling (Holtzman et al., 2020) with probability mass 0.95 and various sampling temperatures.
Reported are the frequencies of the correct final answer to appear among k samples, for k = 1, 10, 100, estimated from
200 samples like in code generation benchmarks (Chen et al., 2021). After 200B tokens, the 2-token prediction model
has a clear advantage over the next-token baseline but the order reverses after 500B tokens. The 4-token prediction model
is worse throughout. We interpret this similarly to the findings in Section 4.1: the follow-your-nose chains-of-thought
required for GSM8K may be difficult to learn from a limited amount of data, attesting to the data efficiency of multi-token
prediction training. Once the circuits required for correct autoregressive chains-of-thought in this domain have formed,
however, multi-token prediction comes at a cost.


J. Additional results on induction learning

[Figure S14: induction success (accuracy) as a function of model parameters (1M to 1000M) for n = 1 (baseline) and n = 2 (ours).]
Figure S14: Induction capability of n-token prediction models trained on higher-quality data. Shown is accuracy on the
second token of two-token names that have already been mentioned previously. Training on a 9:1 mix of a books dataset and
the children stories dataset, we observe that induction capability forms significantly earlier in training (not shown here) and to
a higher degree. We believe that this is explained both because our evaluation dataset no longer contains out-of-distribution
tokens (Section 4.1) and because the higher-quality data contained in the books dataset makes induction necessary earlier on
(especially for small models, cf. Singh et al. (2023)). In particular, by enforcing the formation of induction capability in the
model by means of the dataset – instead of the loss – the advantage of 2-token prediction models on this task disappears
except for the smallest models: feature learning converts the task into a pure next-token prediction task.


K. Additional results on algorithmic reasoning


We investigate the following computation-sharing hypothesis for explaining the efficacy of multi-token prediction as a
training loss.

The prediction difficulty of different tokens in natural text varies greatly. Some tokens may be the continuations
of partial words that are uniquely determined from their preceding context without any effort, while others may
require predicting theorem names in difficult mathematical proofs or the correct answer to an exam question.
Language models with residual connections have been shown to refine their output token distribution with each
successive layer, and can be trained with early exit strategies that spend variable amounts of computational
resources per token position. Multi-token prediction losses explicitly encourage information-sharing between
adjacent token positions and can thus be viewed as a method to learn to allocate computational resources in
language models more efficiently to the tokens that benefit most from them.

To test this hypothesis, we augment the polynomial arithmetic task from Section 4.2 with a varying number of
pause tokens (Goyal et al., 2023) inserted between the question and a token that denotes the beginning of the answer. Pause
tokens introduce additional computational resources that can be expended for computations that are expected to be useful
later on in the sequence, in other words: to start thinking about the answer. According to the computation-sharing hypothesis,
multi-token prediction models learn information-sharing and thus computation-sharing between token positions more easily,
and may be better at making use of these additional computational resources than next-token prediction models are. In
Figure S15, we show the evaluation results on the polynomial arithmetic task with a fixed number of pause tokens inserted
both at training and evaluation time. Multi-token prediction models likewise outperform next-token prediction models
on these task variants across task difficulties and model sizes. However, we do not see strong evidence of this gap widening
or shrinking, i.e. these experiments neither support nor refute the computation-sharing hypothesis.
In Table S11, we report results from another experiment in the same spirit: by adding spaces and newlines to HumanEval
and MBPP prompts, we add “pause tokens” in a somewhat natural way. According to these results, multi-token prediction
models have a slight advantage at using this additionally provided compute, but the effect is marginal.
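
As an illustration of the task layout, a training example with pause tokens could look as follows; the token strings below are illustrative placeholders, not the exact vocabulary of our arithmetic task.

def with_pause_tokens(question_tokens, answer_tokens, num_pause):
    # Insert num_pause pause tokens between the question and the token ("=")
    # that denotes the beginning of the answer, as described above.
    return question_tokens + ["<pause>"] * num_pause + ["="] + answer_tokens

print(with_pause_tokens(["(", "x", "+", "1", ")", "*", "(", "x", "+", "2", ")"],
                        ["x**2", "+", "3*x", "+", "2"], num_pause=5))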

[Figure S15: accuracy (%) as a function of the number of operations per expression (1 to 10, with in-domain and out-of-domain ranges marked) for n = 1, 2 and 4; panel (a) uses 5 pause tokens, panel (b) uses 10 pause tokens.]

Figure S15: Accuracy on a polynomial arithmetic task with varying number of operations per expression and pause
tokens. We train and evaluate models on the polynomial arithmetic task described in Section 4.2, modified by the addition
of pause tokens (Goyal et al., 2023): between the question and the equality sign that indicates the beginning of the answer,
we add a constant number of pause tokens both in training and evaluation. For both a variant with five and with ten pause
tokens, respectively, we observe comparable improvements from using multi-token prediction to the ones obtained in the
case without pause tokens (Figure 8).


Table S11: Utilization of additional whitespace tokens in code benchmarks.

Task Whitespace n=1 n=4


APPS/Intro spaces + newline +0.21 +0.34
APPS/Intro newline +0.79 +0.69
HumanEval spaces + newline -0.72 -0.16
HumanEval newline -0.26 +0.10
MBPP spaces + newline -0.10 -0.06
MBPP newline +0.03 -0.08
Average -0.01 +0.14

[Figure S16: accuracy (%) as a function of the number of operations per expression (1 to 10, with in-domain and out-of-domain ranges marked) for 30M and 100M parameter models with n = 1, 2 and 4.]
Figure S16: Accuracy on a polynomial arithmetic task for two model sizes. We train and evaluate models with 30M and
100M parameters on the polynomial arithmetic task described in Section 4.2. Tripling the model size has a smaller effect
on performance than replacing next-token prediction loss by multi-token prediction. Shown are two independent runs per
configuration and their means, the 100M parameter models being identical to the ones in Figure 8.


Table S12: Optimal temperatures for all numbers in Table 1.

                                   MBPP             HumanEval        APPS/Intro
Training data   Vocabulary    n    @1   @10  @100   @1   @10  @100   @1   @10  @100
313B bytes      bytes         1    0.2  0.8  0.8    0.1  0.8  0.8    0.8  0.8  0.8
(0.5 epochs)                  8    0.1  0.8  0.8    0.1  0.8  0.8    0.4  0.4  0.4
                              16   0.1  0.8  0.8    0.1  0.8  0.8    0.4  0.4  0.4
                              32   0.1  0.4  0.8    0.1  0.4  0.8    0.1  0.4  0.4
200B tokens     32k tokens    1    0.1  0.8  0.8    0.1  0.8  0.8    0.1  0.4  0.8
(0.8 epochs)                  2    0.1  0.8  0.8    0.2  0.8  0.8    0.4  0.4  0.8
                              4    0.1  0.8  0.8    0.1  0.8  0.8    0.2  0.8  0.8
                              6    0.1  0.8  0.8    0.2  0.8  0.8    0.4  0.4  0.8
                              8    0.1  0.8  0.8    0.1  0.8  0.8    0.2  0.4  0.8
1T tokens       32k tokens    1    0.1  0.8  0.8    0.1  0.8  0.8    0.1  0.4  0.8
(4 epochs)                    4    0.1  0.8  0.8    0.2  0.8  0.8    0.4  0.8  0.8

L. Additional intuitions on multi-token prediction


L.1. Comparison to scheduled sampling
In Section 5.2, we argued that multi-token prediction reduces the distribution mismatch between teacher-forced training and
autoregressive evaluation of language models. Scheduled sampling (Bengio et al., 2015) is a curriculum learning method
that likewise aims to bridge this gap in sequence prediction tasks by gradually replacing more and more input tokens with
model-generated ones.
While effective in areas such as time series forecasting, scheduled sampling is, in our opinion, inapplicable to language
modelling due to the discrete nature of text. Replacing ground-truth input sequences by interleavings of ground-truth and
model-generated tokens frequently results in ungrammatical, factually wrong or otherwise incoherent text, which should
be avoided at all cost. Moreover, unlike multi-token prediction, the technique, which was originally developed for recurrent
neural networks, cannot easily be adapted to the parallel training setups of transformer models.

L.2. Information-theoretic argument


We give details on the information-theoretic terms appearing in the decomposition in Section 5.2 and derive a relative
version that similarly allows us to decompose multi-token prediction losses. As in Section 5.2, denote by X the next token
and by Y the second-next one, and omit conditioning on the preceding context C for ease of notation. In Section 5.2, we
decomposed H(X) + H(Y )—the quantity of interest for 2-token prediction models—as follows:

H(X) + H(Y ) = H(X | Y ) + 2I(X; Y ) + H(Y | X). (3)

Let us explain each of the terms. The entropy terms denote the uncertainty contained in the ground-truth random variables
X and Y (in particular, they do not refer to model predictions). The term H(Y | X) is a classical next-token entropy for the
prefix (C, X). The conditional entropy H(X | Y) is a more theoretical entity not modelled by causal models. It describes the
uncertainty about X given the prefix C and the suffix Y, and therefore captures the local variations of X that do not affect the
continuation Y of the text. The mutual information I(X; Y), on the other hand, describes the information about Y contained
in X (and vice versa) and therefore captures the variations of X which constrain the continuation of the text.
However, the argument given in Section 5.2 relies on the assumption that multi-token prediction losses obey a decomposition
similar to that of the sum of the ground-truth entropies themselves. Let us make this rigorous. Denote by p(x, y) the
joint distribution of X and Y, by p(x) (short for p_X(x)) the marginal distribution of X, and by p(y) that of Y. Denote
the model's predicted distributions by q(x, y), q(x) and q(y), respectively, conditional distributions by p(x | y), the
Kullback-Leibler divergence from q to p by D(p ∥ q), and the cross-entropy from q to p by H(p, q).
Definition L.1. The conditional cross-entropy H(p_{X|Y}, q_{X|Y}) of X conditioned on Y from q to p is defined as the
expectation over y of the cross-entropy between the distributions p_X and q_X conditioned on y, in formulas:

H(p_{X|Y}, q_{X|Y}) = E_{y∼p_Y} H(p_{X|Y=y}, q_{X|Y=y}) = E_{y∼p_Y} H(p(· | y), q(· | y)).

Definition L.2. The relative mutual information I_{p∥q}(X; Y) of X and Y from q relative to p is defined by

I_{p∥q}(X; Y) = D(p ∥ q_X ⊗ q_Y) − D(p ∥ q).

We have I_{p∥q}(X; Y) = H(p_X, q_X) + H(p_Y, q_Y) − H(p, q); I_{p∥p}(X; Y) = I_p(X; Y) reduces to the standard mutual
information under the distribution p; and I_{p∥q}(X; Y) is symmetric in X and Y but can be negative.
We have the following relative version of the decomposition H(X) = H(X | Y) + I(X; Y).
Lemma L.3. H(p_X, q_X) = H(p_{X|Y}, q_{X|Y}) + I_{p∥q}(X; Y).

Proof. We calculate

H(p_X, q_X) = − ∑_x p(x) log q(x)
            = − ∑_{x,y} p(x, y) log q(x)
            = − ∑_{x,y} p(x, y) log [ (q(x) q(y) / p(x, y)) · (p(x, y) / q(x, y)) · (q(x, y) / q(y)) ]
            = D(p ∥ q_X ⊗ q_Y) − D(p ∥ q) − ∑_{x,y} p(y) p(x | y) log q(x | y)
            = I_{p∥q}(X; Y) + ∑_y p(y) H(p_{X|y}, q_{X|y})
            = I_{p∥q}(X; Y) + H(p_{X|Y}, q_{X|Y}).

Symmetrizing, we get the desired relative version of H(X) + H(Y) = H(X | Y) + 2I(X; Y) + H(Y | X):

H(p_X, q_X) + H(p_Y, q_Y) = H(p_{X|Y}, q_{X|Y}) + 2 I_{p∥q}(X; Y) + H(p_{Y|X}, q_{Y|X}).

Setting p to be the empirical distribution of the training data, the left-hand side describes the cross-entropy loss used to
train 2-token prediction models. The right-hand side gives the decomposition into a local cross-entropy term, a mutual
information term with weight two and a shifted next-token cross-entropy term. We interpret this as follows: by adding the
term H(pY , qY ) to the loss, 2-token prediction incentivizes models to precompute features which will become useful for
predicting Y in the next step and increases the weight of the relative mutual information term in the loss. What does relative
mutual information actually mean? Interpreting the Kullback-Leibler divergence D(p ∥ q) as the average number of additional
bits needed to send data from p with a code optimized for q instead of p, we see that minimizing

I_{p∥q}(X; Y) = D(p ∥ q_X ⊗ q_Y) − D(p ∥ q)

means minimizing the average number of additional bits needed to send data from p with a code optimized for q that treats
X and Y as independent, compared to one that does not. If this number is small, q has managed to exploit the mutual
information of X and Y under p.
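
The lemma and its symmetrized version can be sanity-checked numerically on random toy distributions; the following NumPy snippet is only such a check with arbitrary alphabet sizes, not part of any training code.

import numpy as np

rng = np.random.default_rng(0)

# Random joint distributions p(x, y) and q(x, y) over a small alphabet.
p = rng.random((5, 7)); p /= p.sum()
q = rng.random((5, 7)); q /= q.sum()

def H(a, b):                        # cross-entropy H(a, b) = -sum a log b
    return -(a * np.log(b)).sum()

def D(a, b):                        # Kullback-Leibler divergence D(a || b)
    return (a * np.log(a / b)).sum()

px, py = p.sum(1), p.sum(0)
qx, qy = q.sum(1), q.sum(0)

I_rel = D(p, np.outer(qx, qy)) - D(p, q)           # relative mutual information (Definition L.2)
H_x_given_y = -(p * np.log(q / qy)).sum()          # H(p_{X|Y}, q_{X|Y}), since q(x | y) = q(x, y) / q(y)
H_y_given_x = -(p * np.log(q / qx[:, None])).sum() # H(p_{Y|X}, q_{Y|X})

# Lemma L.3 and its symmetrized version.
print(np.isclose(H(px, qx), H_x_given_y + I_rel))
print(np.isclose(H(px, qx) + H(py, qy), H_x_given_y + 2 * I_rel + H_y_given_x))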

L.3. Lookahead reinforces choice points


Training with multi-head prediction increases the importance of choice points in the loss in comparison to inconsequential
decisions. To make this argument, we present a simplified model of language modelling. Consider a sequential decision task
and a model M that is trained in a teacher-forced way on optimal trajectories. We distinguish choice points, transitions that
lead to different outcomes, from inconsequential decisions, which do not (Figure S17 (a) and (b)).



Figure S17: Example of a sequential prediction task with derailing. The goal is to go from the arrow to the trophy.
Turning around is not allowed. Most transitions are unique, but there are two turns to be taken correctly, the consequential
decisions (a) and (c). Turn (b) is an inconsequential decision: the paths join right after it. Next to transitions (a) and (b),
we sketch how a 4-step prediction loss can place more emphasis on consequential transitions than inconsequential ones
during teacher-forced training. Next to transition (c), we sketch how a 4-step lookahead can prevent models from taking
irreversible suboptimal decisions during autoregressive decoding.

More formally, assume that the language model is deployed in a reinforcement learning setting like in reinforcement learning
from human feedback (Ouyang et al., 2022) (states are prompts followed by the partial sequence of tokens xt:1 generated so
far, actions are single tokens x_{t+1} to generate, and rewards are external R(x_{t:1})). The quantity

V_π(x_{t:1}) = E_{x_{t+i} ∼ π(x_{t+i−1:1}), i≥1} [ ∑_{i≥0} R(x_{t+i:1}) ]

is the value of the state x_{t:1} when following the policy π, while

σ_π(x_{t:1}) = sqrt( Var_{x_{t+1} ∼ π(x_{t:1})} [ V_π(x_{t+1:1}) ] )

quantifies the importance of the decision xt+1 on the value thereafter. Choice points can formally be viewed as steps t for
which σπ (xt:1 ) is large, while inconsequential points are steps where it is low. Note that for completion models, there is no
explicit reward, and our argument is merely meant to illustrate what we mean by choice points.
Derailing denotes a situation where autoregressive generation of trajectories from M at inference time results in bad
outcomes after M made a mistake on a choice point. Even if subsequently, M acts optimally given this choice, the final
outcome can be significantly worse than the outcome of the optimal trajectory.
Staying in the teacher-forced setting, we ask: What is the impact of training M with n-step prediction instead of next-
step prediction on this task? Say xt → xt+1 is a choice point in an optimal trajectory with the suboptimal choice
being xt → x̃t+1 (Figure S17 (a)). Assume that the trajectories preceding xt and succeeding xt+1 and x̃t+1 consist of
inconsequential transitions, the latter denoted by x̃t+j → x̃t+j+1 . We will compare the losses of a teacher-forced next-step
prediction model and a teacher-forced n-step prediction model on the partial trajectory (xt−n+1 , . . . xt ). For the next-step
prediction model, the predictions are (xt−n+2 , . . . , xt , x̃t+1 ) with a single wrong prediction. The predictions of an n-step
prediction model at time t − n + i, i = 1, . . . , n are (xt−n+i+1 , . . . , xt , x̃t+1 , . . . , x̃t+i ) with i wrong predictions. In other
words, an n-step prediction model receives 1 + 2 + . . . + n = n(n+1)/2 loss terms pertaining to such a choice point and its
consequences, while each inconsequential transition (Figure S17 (b)) is only reinforced n times as often as in a next-step
prediction model. In other words, choice points receive on average (n+1)/2 times more weight in the loss of n-step
prediction models than in next-step prediction models.
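
The bookkeeping behind these counts can be checked in a few lines; this is merely a counting sketch of the argument above, not an experiment.

def loss_terms(n):
    # The prediction window starting at time t - n + i (i = 1..n) contains i loss
    # terms that depend on the choice made at the choice point, while an
    # inconsequential transition contributes one term to each of the n windows covering it.
    choice_point = sum(range(1, n + 1))        # 1 + 2 + ... + n
    inconsequential = n
    return choice_point, inconsequential

for n in (1, 2, 4, 8):
    cp, inc = loss_terms(n)
    assert cp == n * (n + 1) // 2
    print(f"n = {n}: choice point gets {cp} loss terms, i.e. {cp / inc:.1f}x the weight")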


As argued in Section 5.1, we believe that this model captures important features of training and inference with language
models: choice points are semantically important turning points in the generated texts, such as the final answer to a question
or a specific line of code, while inconsequential decisions can be a choice among synonyms or of variable names in code.
Apart from this training dynamics point of view, we hypothesize that n-step prediction also allows the formation of circuits
that specifically spot inconsistencies between predictions for earlier and later steps. For instance, if in an early layer of
the model, it can be predicted that a decision xt → x̃t+1 leads to suboptimal outcomes x̃t+n (Figure S17 (c)), subsequent
layers can reduce the probability of xt → x̃t+1 in the model’s next-step prediction. Such behaviors also happen in next-step
prediction models given enough capacity, but our experiments in Section 4.2 point to the fact that circuits of this kind are
formed more easily in multi-step architectures that enforce the required information x̃t+n to be available to the model when
predicting x̃t+1 . We believe that this situation appears frequently in natural language and code modelling, for instance where
an initial answer to a question contradicts the results of the chain of thought brought forward with the intention to justify it.
In more general terms, this situation arises whenever first predicting x̃_{t+i} for some 1 < i ≤ n and then predicting x̃_{t+1}
based on x̃_{t+i} is easier than predicting x̃_{t+1} directly. We discuss this phenomenon of factorization orders in the next
section and present a specific instance of it that frequently appears in modelling natural language.

L.4. Factorization orders


Causal language modelling classically factorizes probabilities over text sequences x_t · · · x_1 as

P(x_t · · · x_1) = ∏_{i=1}^{t} P(x_i | x_{i−1} · · · x_1).

While moving forward in time is certainly the most natural choice of factorization order, there exist cases where it is
suboptimal. In inflectional languages, for instance, agreement between related sentence parts is a frequent pattern with one
word directing the grammatical forms of others. Consider the German sentence

Wie konnten auch Worte meiner durstenden Seele genügen?3


Friedrich Hölderlin, Fragment von Hyperion (1793)

where "genügen" requires a dative case object and then "Seele" requires the possessive pronoun "mein" to be in female
singular dative form "meiner" and the participle "durstend" to be in female singular dative form in weak declination
"durstenden" because it follows "meiner". In other words, the factorization order

Wie konnten auch Worte → genügen → Seele → meiner → durstenden?

is arguably an easier one for constructing the above sentence. Humans as well as language models therefore have to perform
this factorization (which deviates from the causal order in which predictions take place) within their latent activations, and
a 4-token prediction loss makes this easier as it explicitly encourages models to have all information about the 4 successive
tokens in their latent representations.

3 Roughly: How could words be enough for my thirsty soul?


M. Training hyperparameters

Table S13: Overview of all training hyperparameters used. We schedule all learning rates with a linear warmup and
cosine decay (Loshchilov and Hutter, 2017) to a fraction of the peak learning rate which is depicted in the last column
(“decay ratio”). All experiments use the Adam (Kingma and Ba, 2015) optimizer with β1 = 0.9, β2 = 0.95 and decoupled
L2 weight decay (Loshchilov and Hutter, 2019) coefficient 0.1. We clip gradients to a maximal Euclidean norm of 1.0 in all
experiments except CodeContests finetunings, where we use 0.1 instead. Summarization finetunings correspond to three
epochs on all datasets except BigPatent (1 epoch). Byte-level models use the architecture with replicated unembeddings
from Appendix B.

Model Batch size (220 ) Steps Tokens (B) Warmup steps Peak LR Context length Decay ratio
Model scaling (Section 3.1)
0.3B 8 10,850 91.0 1000 3 × 10−4 4096 0.03
0.6B 8 10,850 91.0 1000 3 × 10−4 4096 0.03
1.3B 8 10,850 91.0 1000 3 × 10−4 4096 0.03
3B 8 10,850 91.0 1000 3 × 10−4 4096 0.03
7B 8 25,000 209.7 2000 3 × 10−4 4096 0.03
13B 8 25,000 209.7 1000 3 × 10−4 4096 0.03
Code models (Section 3)
7B 200B 8 25,000 209.7 2000 3 × 10−4 4096 0.03
7B 500B 7 68,570 503.3 2000 3 × 10−4 4096 0.03
7B 1T 7 136,240 1000.0 2000 3 × 10−4 4096 0.03
Byte-level models (Section 3.3)
7B 314GB 12 25,000 314.6 2000 3 × 10−4 8192 0.03
Language models (Section 3.7)
7B 200B 8 25,000 209.7 2000 3 × 10−4 4096 0.10
7B 500B 8 60,000 503.3 2000 3 × 10−4 4096 0.10
Induction task (Section 4.1)
1M – 1B 0.25 100,000 26.2 2000 10−4 2048 0.03
1M – 1B (Appendix J) 0.5 50000 26.2 2000 10−4 2048 0.03
Arithmetic task (Section 4.2)
30M 0.25 100,000 26.2 2000 10−4 1024 0.03
100M 0.25 100,000 26.2 2000 10−4 2048 0.03
Summarization (Section 3.7)
BigPatent 0.125 76,680 10.1 100 3 × 10−5 4096 0.03
CNN/Dailymail 0.125 7,140 0.9 100 3 × 10−5 4096 0.03
Multi-News 0.125 3,330 0.4 100 3 × 10−5 4096 0.03
OrangeSum 0.125 360 0.0 100 3 × 10−5 4096 0.03
pn-summary 0.125 3,450 0.5 100 3 × 10−5 4096 0.03
SAMSum 0.125 60 0.0 100 3 × 10−5 4096 0.03
ThaiSum 0.125 23,640 3.1 100 3 × 10−5 4096 0.03
WikiSummary 0.125 2,550 0.3 100 3 × 10−5 4096 0.03
XSum 0.125 2,760 0.4 100 3 × 10−5 4096 0.03
CodeContests (Section 3.6)
7B 0.25 13,000 3.6 400 5 × 10−5 4096 0.004


Table S14: Overview of model architectures used for scaling analyses.

Name Dimension Layers Heads


1M 128 5 4
3M 256 4 8
10M 384 6 8
30M 512 10 8
100M 768 14 12
300M 1024 25 16
1B 1536 36 24
0.3B 1024 18 16
0.6B 1280 27 20
1.3B 2048 24 16
3B 2560 36 20
6.7B (“7B”) 4096 32 32
13B 5120 40 40

