Google Research

WHAT LEARNING ALGORITHM IS IN-CONTEXT LEARNING? INVESTIGATIONS WITH LINEAR MODELS

Ekin Akyürek (1,2,a)   Dale Schuurmans (1)   Jacob Andreas (∗,2)   Tengyu Ma (∗,1,3,b)   Denny Zhou (∗,1)

1 Google Research   2 MIT CSAIL   3 Stanford University   ∗ Collaborative advising

ABSTRACT
arXiv:2211.15661v2 [cs.LG] 29 Nov 2022

Neural sequence models, especially transformers, exhibit a remarkable capacity for in-context learning. They can construct new predictors from sequences of labeled examples (x, f(x)) presented in the input without further parameter updates. We investigate the hypothesis that transformer-based in-context learners implement standard learning algorithms implicitly, by encoding smaller models in their activations, and updating these implicit models as new examples appear in the context. Using linear regression as a prototypical problem, we offer three sources of evidence for this hypothesis. First, we prove by construction that transformers can implement learning algorithms for linear models based on gradient descent and closed-form ridge regression. Second, we show that trained in-context learners closely match the predictors computed by gradient descent, ridge regression, and exact least-squares regression, transitioning between different predictors as transformer depth and dataset noise vary, and converging to Bayesian estimators for large widths and depths. Third, we present preliminary evidence that in-context learners share algorithmic features with these predictors: learners’ late layers non-linearly encode weight vectors and moment matrices. These results suggest that in-context learning is understandable in algorithmic terms, and that (at least in the linear case) learners may rediscover standard estimation algorithms. Code and reference implementations are released at this https link.

1 INTRODUCTION
One of the most surprising behaviors observed in large neural sequence models is in-context learning (ICL; Brown et al., 2020). When trained appropriately, models can map from sequences of $(x, f(x))$ pairs to accurate predictions $f(x')$ on novel inputs $x'$. This behavior occurs both in models trained on collections of few-shot learning problems (Chen et al., 2021; Min et al., 2021) and, surprisingly, in large language models trained on open-domain text (Brown et al., 2020; Zhang et al., 2022; Chowdhery et al., 2022). ICL requires the neural network to implicitly construct a map from in-context examples to a predictor without any updates to the model’s parameters themselves. How can a neural network with fixed parameters learn a new function from a new dataset on the fly?
This paper investigates the hypothesis that some instances of ICL can be understood as implicit implementations of known learning algorithms: in-context learners encode an implicit, context-dependent model in their hidden activations, and train this model on in-context examples in the course of computing these internal activations. As in recent investigations of empirical properties of ICL (Garg et al., 2022; Xie et al., 2022), we study the behavior of transformer-based predictors (Vaswani et al., 2017) on a restricted class of learning problems, here linear regression. Unlike past work, our goal is not to understand what functions ICL can learn, but how it learns these functions: the specific inductive biases and algorithmic properties of transformer-based ICL.
In Section 3, we investigate theoretically what learning algorithms transformer decoders can implement. We prove by construction that they require only a modest number of layers and hidden units to train linear models: for $d$-dimensional regression problems, with $O(d)$ hidden size and constant depth, a transformer can implement a single step of gradient descent; and with $O(d^2)$ hidden size and constant depth, a transformer can update a ridge regression solution to include a single new observation. Intuitively, $n$ steps of these algorithms can be implemented with $n$ times more layers.

a Correspondence to akyurek@mit.edu. Ekin is a student at MIT, but did this work primarily while he was an intern at Google Research.
b This work was done while Tengyu Ma was a visiting researcher at Google Research.
In Section 4, we investigate empirical properties of trained in-context learners. We begin by constructing linear regression problems in which learner behavior is under-determined by training data (so different valid learning rules will give different predictions on held-out data). We show that model predictions are closely matched by existing predictors (including those studied in Section 3), and that they transition between different predictors as model depth and training set noise vary, behaving like Bayesian predictors at large hidden sizes and depths. Finally, in Section 5, we present preliminary experiments showing how model predictions are computed algorithmically. We show that important intermediate quantities computed by learning algorithms for linear models, including parameter vectors and moment matrices, can be decoded from in-context learners’ hidden activations.
A complete characterization of which learning algorithms are (or could be) implemented by deep
networks has the potential to improve both our theoretical understanding of their capabilities and
limitations, and our empirical understanding of how best to train them. This paper offers first steps
toward such a characterization: some in-context learning appears to involve familiar algorithms,
discovered and implemented by transformers from sequence modeling tasks alone.

2 PRELIMINARIES

Training a machine learning model involves many decisions, including the choice of model architecture, loss function and learning rule. Since the earliest days of the field, research has sought to understand whether these modeling decisions can be automated using the tools of machine learning itself. Such “meta-learning” approaches typically treat learning as a bi-level optimization problem (Schmidhuber et al., 1996; Andrychowicz et al., 2016; Finn et al., 2017): they define “inner” and “outer” models and learning procedures, then train an outer model to set parameters for an inner procedure (e.g. initializer or step size) to maximize inner model performance across tasks.
Recently, a more flexible family of approaches has gained popularity. In in-context learning (ICL), meta-learning is reduced to ordinary supervised learning: a large sequence model (typically implemented as a transformer network) is trained to map from sequences $[x_1, f(x_1), x_2, f(x_2), \ldots, x_n]$ to predictions $f(x_n)$ (Brown et al., 2020; Olsson et al., 2022; Laskin et al., 2022). ICL does not specify an explicit inner learning procedure; instead, this procedure exists only implicitly through the parameters of the sequence model. ICL has shown impressive results on synthetic tasks and on naturalistic language and vision problems (Garg et al., 2022; Min et al., 2021; Zhou et al., 2022).
Past work has characterized what kinds of functions ICL can learn (Garg et al., 2022; Laskin et al., 2022) and the distributional properties of pretraining that can elicit in-context learning (Xie et al., 2021; Chan et al., 2022), but how ICL learns these functions has remained unclear. What learning algorithms (if any) are implementable by deep network models? Which algorithms are actually discovered in the course of training? This paper takes first steps toward answering these questions, focusing on a widely used model architecture (the transformer) and an extremely well-understood class of learning problems (linear regression).

2.1 THE TRANSFORMER ARCHITECTURE

Transformers (Vaswani et al., 2017) are neural network models that map a sequence of input vectors $x = [x_1, \ldots, x_n]$ to a sequence of output vectors $y = [y_1, \ldots, y_n]$. Each layer in a transformer maps a matrix $H^{(l)}$ (interpreted as a sequence of vectors) to a sequence $H^{(l+1)}$. To do so, a transformer layer processes each column $h_i^{(l)}$ of $H^{(l)}$ in parallel. Here, we are interested in autoregressive (or “decoder-only”) transformer models in which each layer first computes a self-attention:

$$\begin{aligned} a_i &= \mathrm{Attention}(h_i^{(l)}; W^F, W^Q, W^K, W^V) && (1)\\ &= W^F [b_1, \ldots, b_m] && (2) \end{aligned}$$

where each $b_j$ is the response of an “attention head” defined by:

$$b_j = \mathrm{softmax}\left( (W_j^Q h_i)^\top (W_j^K H_{:i}) \right) (W_j^V H_{:i}), \tag{3}$$

then applies a feed-forward transformation:

$$\begin{aligned} h_i^{(l+1)} &= \mathrm{FF}(a_i; W_1, W_2) && (4)\\ &= W_1\, \sigma\left(W_2\, \lambda(a_i + h_i^{(l)})\right) + a_i + h_i^{(l)}. && (5) \end{aligned}$$

Here $\sigma$ denotes a nonlinearity, e.g. a Gaussian error linear unit (GeLU; Hendrycks & Gimpel, 2016):

$$\sigma(x) = \frac{x}{2}\left(1 + \mathrm{erf}\left(\frac{x}{\sqrt{2}}\right)\right), \tag{6}$$

and $\lambda$ denotes layer normalization (Ba et al., 2016):

$$\lambda(x) = \frac{x - \mathbb{E}[x]}{\sqrt{\mathrm{Var}[x]}}, \tag{7}$$

where the expectation and variance are computed across the entries of x. To map from x to y, a
transformer applies a sequence of such layers, each with its own parameters. We use θ to denote a
model’s full set of parameters (the complete collection of W matrices across layers). The three main
factors governing the computational capacity of a transformer are its depth (the number of layers),
its hidden size (the dimension of the vectors h), and the number of heads (denoted m above).
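As an illustration only, Eqs. (1) to (7) can be sketched as a single decoder layer in NumPy. This is a minimal sketch under stated assumptions, not the implementation used in the experiments: there are no biases, attention scores are not rescaled, token $i$ attends to positions $0, \ldots, i$ (our reading of $H_{:i}$), and a small `eps` is added inside the layer norm for numerical stability. The names `decoder_layer`, `gelu` and `layer_norm` are hypothetical.

```python
import math
import numpy as np

_erf = np.vectorize(math.erf)

def gelu(x):
    """Eq. (6): sigma(x) = (x / 2) * (1 + erf(x / sqrt(2)))."""
    return 0.5 * x * (1.0 + _erf(x / math.sqrt(2.0)))

def layer_norm(x, eps=1e-5):
    """Eq. (7), with a small eps (an assumption) to avoid dividing by zero."""
    return (x - x.mean()) / np.sqrt(x.var() + eps)

def softmax(z):
    e = np.exp(z - z.max())  # shift for numerical stability
    return e / e.sum()

def decoder_layer(H, WQ, WK, WV, WF, W1, W2):
    """Apply Eqs. (1)-(5) to H of shape (D, T), one column per token.

    WQ, WK, WV: lists of m per-head matrices of shape (D_head, D);
    WF: (D, m * D_head); W1: (D, D_ff); W2: (D_ff, D).
    """
    _, T = H.shape
    out = np.empty_like(H)
    for i in range(T):
        h_i = H[:, i]
        prefix = H[:, : i + 1]                      # causal: positions 0..i only
        heads = []
        for Wq, Wk, Wv in zip(WQ, WK, WV):
            scores = (Wk @ prefix).T @ (Wq @ h_i)   # one score per visible position
            heads.append((Wv @ prefix) @ softmax(scores))  # Eq. (3)
        a_i = WF @ np.concatenate(heads)            # Eqs. (1)-(2)
        r = a_i + h_i                               # residual input to the FF block
        out[:, i] = W1 @ gelu(W2 @ layer_norm(r)) + r   # Eqs. (4)-(5)
    return out
```

The explicit loop over positions trades speed for a one-to-one correspondence with the per-column equations; a practical implementation would batch all positions with a causal mask.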

2.2 TRAINING FOR IN-CONTEXT LEARNING

We study transformer models directly trained on an ICL objective. (Some past work has found that ICL also “emerges” in models trained on general text datasets; Brown et al., 2020.) To train a transformer $T$ with parameters $\theta$ to perform ICL, we first define a class of functions $\mathcal{F}$, a distribution $p(f)$ supported on $\mathcal{F}$, a distribution $p(x)$ over the domain of functions in $\mathcal{F}$, and a loss function $L$. We then choose $\theta$ to optimize:

$$\arg\min_{\theta}\; \mathbb{E}_{\substack{x_1, \ldots, x_n \sim p(x) \\ f \sim p(f)}} \left[ \sum_{i=1}^{n} L\left(f(x_i),\, T_\theta([x_1, f(x_1), \ldots, x_i])\right) \right] \tag{8}$$

We refer to the resulting Tθ as an in-context learner.
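The objective in Eq. (8) can be estimated by Monte Carlo for any fixed learner. The sketch below is hypothetical: it samples $w, x \sim \mathcal{N}(0, I)$ (the distributions used in the main experiments), averages rather than sums the per-prefix losses (which does not change the minimizer), and plugs in two stand-ins for $T_\theta$: a constant-zero predictor and a minimum-norm least-squares fit.

```python
import numpy as np

def sample_task(d, n, rng):
    """Draw f(x) = w^T x with w ~ N(0, I) and n inputs x ~ N(0, I)."""
    w = rng.normal(size=d)
    X = rng.normal(size=(n, d))
    return X, X @ w

def icl_objective(learner, d, n, num_tasks=1000, seed=0):
    """Monte Carlo estimate of Eq. (8), averaged over tasks and prefixes.

    learner(X_ctx, y_ctx, x_query) returns a scalar prediction for x_query
    given the i earlier (x, y) pairs; the context is empty when i = 0.
    """
    rng = np.random.default_rng(seed)
    total = 0.0
    for _ in range(num_tasks):
        X, y = sample_task(d, n, rng)
        for i in range(n):
            pred = learner(X[:i], y[:i], X[i])
            total += (y[i] - pred) ** 2   # squared-error loss L
    return total / (num_tasks * n)

def zero_learner(X_ctx, y_ctx, x_query):
    """Baseline that ignores the context."""
    return 0.0

def min_norm_ls_learner(X_ctx, y_ctx, x_query):
    """Stand-in for T_theta: minimum-norm least squares on the context."""
    if len(y_ctx) == 0:
        return 0.0
    w = np.linalg.pinv(X_ctx) @ y_ctx
    return float(w @ x_query)

print(icl_objective(zero_learner, d=4, n=10, num_tasks=200))
print(icl_objective(min_norm_ls_learner, d=4, n=10, num_tasks=200))
```

With $d = 4$ and $n = 10$, the zero predictor's loss should come out near $d$ (because $\mathbb{E}[(w^\top x)^2] = d$), while the least-squares stand-in should come out near 1, since its expected error on prefix $i$ is $d - i$ for $i < d$ and zero afterwards.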

2.3 LINEAR REGRESSION

Our experiments focus on linear regression problems. In these problems, $\mathcal{F}$ is the space of linear functions $f(x) = w^\top x$ where $w, x \in \mathbb{R}^d$, and the loss function is the squared error $L(y, y') = (y - y')^2$. Linear regression is a model problem in machine learning and statistical estimation, with diverse algorithmic solutions. It thus offers an ideal test-bed for understanding ICL. Given a dataset with inputs $X = [x_1, \ldots, x_n]$ and $y = [y_1, \ldots, y_n]$, the (regularized) linear regression objective

$$\sum_i L(w^\top x_i, y_i) + \lambda \|w\|_2^2 \tag{9}$$

is minimized by:

$$w^* = (X^\top X + \lambda I)^{-1} X^\top y \tag{10}$$
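A minimal NumPy sketch of Eq. (10), assuming $X$ is stored with one example per row (shape $n \times d$) so that $X^\top X$ is $d \times d$. The helper names are hypothetical; `np.linalg.pinv` is used for the $\lambda = 0$ case so that under-determined problems return the minimum-norm solution.

```python
import numpy as np

def ridge(X, y, lam):
    """Eq. (10): w* = (X^T X + lam I)^{-1} X^T y, with one example per row of X."""
    d = X.shape[1]
    # Solving the linear system is cheaper and more stable than forming the inverse.
    return np.linalg.solve(X.T @ X + lam * np.eye(d), X.T @ y)

def ols_min_norm(X, y):
    """lam = 0 (OLS); the pseudoinverse also covers under-determined n < d."""
    return np.linalg.pinv(X) @ y

rng = np.random.default_rng(0)
X = rng.normal(size=(20, 5))
y = X @ rng.normal(size=5) + 0.1 * rng.normal(size=20)
w = ridge(X, y, lam=0.7)
# At the minimizer, the gradient of Eq. (9), 2 X^T (X w - y) + 2 lam w, vanishes.
print(np.allclose(2 * X.T @ (X @ w - y) + 2 * 0.7 * w, 0.0))
```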

With $\lambda = 0$, this objective is known as ordinary least squares regression (OLS); with $\lambda > 0$, it is known as ridge regression (Hoerl & Kennard, 1970). (As discussed further in Section 4, ridge regression can also be assigned a Bayesian interpretation.) To present a linear regression problem to a transformer, we encode both $x$ and $f(x)$ as $(d+1)$-dimensional vectors: $\tilde{x}_i = [0, x_i]$, $\tilde{y}_i = [y_i, 0_d]$, where $0_d$ denotes the $d$-dimensional zero vector.
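The token encoding above can be sketched as follows. This illustration rests on an assumption about layout: tokens are stored as columns, alternating $\tilde{x}_i$ and $\tilde{y}_i$ and ending with the query input, which matches the column pattern shown in Eq. (12). The function name `encode_prompt` is hypothetical.

```python
import numpy as np

def encode_prompt(X, y, x_query):
    """Build a (d+1) x (2n+1) token matrix from n examples and one query.

    Inputs become x~_i = [0, x_i]; labels become y~_i = [y_i, 0_d].
    """
    n, d = X.shape
    cols = []
    for x_i, y_i in zip(X, y):
        cols.append(np.concatenate([[0.0], x_i]))          # x~_i
        cols.append(np.concatenate([[y_i], np.zeros(d)]))  # y~_i
    cols.append(np.concatenate([[0.0], x_query]))          # query input, label unknown
    return np.stack(cols, axis=1)

H0 = encode_prompt(np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([5.0, 6.0]), np.array([7.0, 8.0]))
print(H0)
```

Reserving the first coordinate for labels and the remaining $d$ for inputs keeps the two token types in disjoint subspaces, so a single linear map can read either one.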


3 WHAT LEARNING ALGORITHMS CAN A TRANSFORMER IMPLEMENT?


For a transformer-based model to solve Eq. (9) by implementing an explicit learning algorithm,
that learning algorithm must be implementable via Eq. (1) and Eq. (4) with some fixed choice of
transformer parameters θ. In this section, we prove constructively that such parameterizations exist,
giving concrete implementations of two standard learning algorithms. These proofs yield upper
bounds on how many layers and hidden units suffice to implement (though not necessarily learn)
each algorithm. Proofs are given in Appendices A and B.

3.1 PRELIMINARIES

It will be useful to first establish a few computational primitives with simple transformer implementations. Consider the following four functions from $\mathbb{R}^{H \times T} \to \mathbb{R}^{H \times T}$:
mov$(H; s, t, i, j, i', j')$: selects the entries of the $s$-th column of $H$ between rows $i$ and $j$, and copies them into the $t$-th column ($t \ge s$) of $H$ between rows $i'$ and $j'$, yielding the matrix:

$$\begin{bmatrix} H_{:,:t} & \begin{matrix} H_{:i'-1,\,t} \\ H_{i:j,\,s} \\ H_{j':,\,t} \end{matrix} & H_{:,t+1:} \end{bmatrix}.$$

mul$(H; a, b, c, (i, j), (i', j'), (i'', j''))$: in each column $h$ of $H$, interprets the entries between $i$ and $j$ as an $a \times b$ matrix $A_1$, and the entries between $i'$ and $j'$ as a $b \times c$ matrix $A_2$, multiplies these matrices together, and stores the result between rows $i''$ and $j''$, yielding a matrix in which each column has the form $[h_{:i''-1}, A_1 A_2, h_{j'':}]^\top$.
div$(H; (i, j), i', (i'', j''))$: in each column $h$ of $H$, divides the entries between $i$ and $j$ by the absolute value of the entry at $i'$, and stores the result between rows $i''$ and $j''$, yielding a matrix in which every column has the form $[h_{:i''-1}, h_{i:j}/|h_{i'}|, h_{j'':}]^\top$.
aff$(H; (i, j), (i', j'), (i'', j''), W_1, W_2, b)$: in each column $h$ of $H$, applies an affine transformation to the entries between $i$ and $j$ and between $i'$ and $j'$, then stores the result between rows $i''$ and $j''$, yielding a matrix in which every column has the form $[h_{:i''-1}, W_1 h_{i:j} + W_2 h_{i':j'} + b, h_{j'':}]^\top$.
Lemma 1. Each of mov, mul, div and aff can be implemented by a single transformer decoder layer: in Eq. (1) and Eq. (4), there exist matrices $W^Q, W^K, W^V, W^F, W_1$ and $W_2$ such that, given a matrix $H$ as input, the layer’s output has the form of the corresponding function output above.¹
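To make the index bookkeeping concrete, the four primitives can be written as ordinary array operations. This sketch gives only their input/output semantics, not the attention and feed-forward weights that realize them in Lemma 1. It assumes 0-based, half-open row ranges $[i, j)$ (the text uses looser inclusive notation) and row-major reshaping in `mul`; both conventions are our assumptions.

```python
import numpy as np

# Reference semantics only (NOT the transformer constructions of Lemma 1).
# Row ranges are 0-based and half-open: (i, j) means rows i, ..., j - 1.

def mov(H, s, t, i, j, i2, j2):
    """Copy rows [i, j) of column s into rows [i2, j2) of column t (t >= s)."""
    assert t >= s and j - i == j2 - i2
    out = H.copy()
    out[i2:j2, t] = H[i:j, s]
    return out

def mul(H, a, b, c, rows1, rows2, rows_out):
    """Per column: read an a x b and a b x c matrix, store their product (a*c entries)."""
    out = H.copy()
    for col in range(H.shape[1]):
        A1 = H[rows1[0]:rows1[1], col].reshape(a, b)
        A2 = H[rows2[0]:rows2[1], col].reshape(b, c)
        out[rows_out[0]:rows_out[1], col] = (A1 @ A2).ravel()
    return out

def div(H, rows, k, rows_out):
    """Per column: divide rows [i, j) by |h_k| and store them in rows_out."""
    out = H.copy()
    out[rows_out[0]:rows_out[1], :] = H[rows[0]:rows[1], :] / np.abs(H[k, :])
    return out

def aff(H, rows1, rows2, rows_out, W1, W2, b):
    """Per column: W1 h[rows1] + W2 h[rows2] + b, stored in rows_out."""
    out = H.copy()
    val = W1 @ H[rows1[0]:rows1[1], :] + W2 @ H[rows2[0]:rows2[1], :] + b[:, None]
    out[rows_out[0]:rows_out[1], :] = val
    return out
```

For example, `mul(H, 1, 2, 1, (0, 2), (2, 4), (4, 5))` stores, in row 4 of every column, the dot product of that column's rows 0 to 1 with its rows 2 to 3.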

With these operations, we can implement building blocks of two important learning algorithms.

3.2 GRADIENT DESCENT

Rather than directly solving linear regression problems by evaluating Eq. (10), a standard approach to learning exploits a generic loss minimization framework, and optimizes the ridge-regression objective in Eq. (9) via gradient descent on parameters $w$. This involves repeatedly computing updates:

$$w' = w - \alpha \frac{\partial}{\partial w}\left( L(w^\top x_i, y_i) + \lambda \|w\|_2^2 \right) = w - 2\alpha\left( x_i w^\top x_i - y_i x_i + \lambda w \right) \tag{11}$$

for different examples $(x_i, y_i)$, and finally predicting $w'^\top x_n$ on a new input $x_n$. A step of this gradient descent procedure can be implemented by a transformer:
Theorem 1. A transformer can compute Eq. (11) (i.e. the prediction resulting from a single step of gradient descent on an in-context example) with a constant number of layers and $O(d)$ hidden space, where $d$ is the problem dimension of the input $x$. Specifically, there exist transformer parameters $\theta$ such that, given an input matrix of the form:

$$H^{(0)} = \begin{bmatrix} \cdots & 0 & y_i & 0 \\ \cdots & x_i & 0 & x_n \end{bmatrix}, \tag{12}$$

the transformer’s output matrix $H^{(L)}$ contains an entry equal to $w'^\top x_n$ (Eq. (11)) at the column index where $x_n$ is input.
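The quantity in Theorem 1 can be checked numerically with a few lines of NumPy. This is a sketch of Eq. (11) itself, not of the transformer construction: it takes one gradient step on a single example and returns $w'^\top x_n$. Starting from $w = 0$ by default is an assumption; with that choice the prediction reduces to $2\alpha\, y_i\, x_i^\top x_n$.

```python
import numpy as np

def gd_step_prediction(x_i, y_i, x_n, alpha, lam=0.0, w=None):
    """Return w'^T x_n, where w' is one step of Eq. (11) on the example (x_i, y_i)."""
    if w is None:
        w = np.zeros_like(x_n)  # assumed initialization
    # Gradient of (w^T x_i - y_i)^2 + lam * ||w||^2 with respect to w.
    grad = 2.0 * (x_i * (w @ x_i) - y_i * x_i + lam * w)
    w_new = w - alpha * grad
    return float(w_new @ x_n)

x_i, y_i, x_n = np.array([1.0, 2.0]), 3.0, np.array([1.0, 1.0])
# Approximately 1.8, i.e. 2 * alpha * y_i * (x_i @ x_n), up to floating-point rounding.
print(gd_step_prediction(x_i, y_i, x_n, alpha=0.1))
```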
¹ We omit the trivial size preconditions, e.g. for mul: $j - i = a \cdot b$, $j' - i' = b \cdot c$, $j'' - i'' = a \cdot c$.

3.3 CLOSED-FORM REGRESSION

Another way to solve the linear regression problem is to directly compute the closed-form solution Eq. (10). This is somewhat challenging computationally, as it requires inverting the regularized covariance matrix $X^\top X + \lambda I$. However, one can exploit the Sherman–Morrison formula (Sherman & Morrison, 1950) to reduce the inverse to a sequence of rank-one updates performed example-by-example. For any invertible square $A$ and vectors $u, v$ with $1 + v^\top A^{-1} u \neq 0$,

$$\left(A + uv^\top\right)^{-1} = A^{-1} - \frac{A^{-1} u v^\top A^{-1}}{1 + v^\top A^{-1} u}. \tag{13}$$
Because the covariance matrix $X^\top X$ in Eq. (10) can be expressed as a sum of rank-one terms, $X^\top X = \sum_i x_i x_i^\top$, each involving a single training example $x_i$, this can be used to construct an iterative algorithm for computing the closed-form ridge-regression solution.
Theorem 2. A transformer can predict according to a single Sherman–Morrison update:

$$w' = \left(\lambda I + x_i x_i^\top\right)^{-1} x_i y_i = \left( \frac{I}{\lambda} - \frac{\frac{I}{\lambda} x_i x_i^\top \frac{I}{\lambda}}{1 + x_i^\top \frac{I}{\lambda} x_i} \right) x_i y_i \tag{14}$$

with constant layers and $O(d^2)$ hidden space. More precisely, there exists a set of transformer parameters $\theta$ such that, given an input matrix of the form in Eq. (12), the transformer’s output matrix $H^{(L)}$ contains an entry equal to $w'^\top x_n$ (Eq. (14)) at the column index where $x_n$ is input.
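The example-by-example use of Eq. (13) can be sketched as follows; Eq. (14) is the single-example case. The sketch assumes $\lambda > 0$, starts from $A^{-1} = I/\lambda$, and uses the symmetry of $A$ so that $A^{-1} x x^\top A^{-1} = (A^{-1}x)(A^{-1}x)^\top$. It is a plain NumPy reference, not the $O(d^2)$-width transformer of Theorem 2, and the function name is hypothetical.

```python
import numpy as np

def ridge_sherman_morrison(X, y, lam):
    """Ridge weights (X^T X + lam I)^{-1} X^T y built one example at a time via Eq. (13)."""
    _, d = X.shape
    A_inv = np.eye(d) / lam        # inverse of the initial A = lam I (requires lam > 0)
    Xty = np.zeros(d)
    for x_i, y_i in zip(X, y):
        Ax = A_inv @ x_i
        # Rank-one update with u = v = x_i; A_inv stays symmetric, so
        # A^{-1} x x^T A^{-1} equals outer(Ax, Ax).
        A_inv = A_inv - np.outer(Ax, Ax) / (1.0 + x_i @ Ax)
        Xty = Xty + y_i * x_i
    return A_inv @ Xty

rng = np.random.default_rng(0)
X, y, lam = rng.normal(size=(20, 5)), rng.normal(size=20), 0.5
w_closed = np.linalg.solve(X.T @ X + lam * np.eye(5), X.T @ y)
print(np.allclose(ridge_sherman_morrison(X, y, lam), w_closed))
```

Each update costs $O(d^2)$ operations and state, which mirrors the $O(d^2)$ hidden-size requirement stated in Theorem 2.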

Discussion. There are various existing universality results for transformers (Yun et al., 2019; Wei
et al., 2021), and for neural networks more generally (Hornik et al., 1989). These generally require
very high precision, very deep models, or the use of an external “tape”, none of which appear to be
important for in-context learning in the real world. Results in this section establish sharper upper
bounds on the capacity required to implement learning algorithms specifically, bringing
theory closer to the range where it can explain existing empirical findings. We emphasize that
Theorem 1 and Theorem 2 each show the implementation of a single step of an iterative algorithm;
these results can be straightforwardly generalized to the multi-step case by “stacking” groups of
transformer layers. As described next, it is these iterative algorithms that capture the behavior of
real learners.

4 WHAT COMPUTATION DOES AN IN-CONTEXT LEARNER PERFORM?


The previous section showed that the building blocks for two specific procedures (gradient descent on the least-squares objective and closed-form computation of its minimizer) are implementable by transformer networks. These constructions show that, in principle, fixed transformer parameterizations are expressive enough to simulate these learning algorithms. When trained on real datasets, however, in-context learners might implement other learning algorithms. In this section, we investigate the empirical properties of trained in-context learners in terms of their behavior. In the framework of Marr’s (2010) “levels of analysis”, we aim to explain ICL at the computational level by identifying the kind of algorithm for regression problems that transformer-based ICL implements.

4.1 BEHAVIORAL METRICS

Determining which learning algorithms best characterize ICL predictions requires first quantifying
the degree to which two predictors agree. We use two metrics to do so:

Squared prediction difference. Given any learning algorithm $\mathcal{A}$ that maps from a set of input–output pairs $D = [x_1, y_1, \ldots, x_n, y_n]$ to a predictor $f(x) = \mathcal{A}(D)(x)$, we define the squared prediction difference (SPD):

$$\mathrm{SPD}(\mathcal{A}_1, \mathcal{A}_2) = \mathbb{E}_{\substack{D = [x_1, \ldots] \sim p(D) \\ x' \sim p(x)}} \left(\mathcal{A}_1(D)(x') - \mathcal{A}_2(D)(x')\right)^2, \tag{15}$$

where $D$ is sampled as in Eq. (8). SPD measures agreement at the output level, regardless of the algorithm used to compute this output.
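A Monte Carlo estimate of Eq. (15) might look like the following. The sketch assumes noiseless data with $w, x \sim \mathcal{N}(0, I)$ and represents each algorithm as a hypothetical function `alg(X, y)` that returns a predictor; the dataset and query counts are arbitrary. Dividing by $d$ gives a dimension-normalized value like those plotted in Fig. 1.

```python
import numpy as np

def spd(alg1, alg2, d, n, num_datasets=500, num_queries=20, seed=0):
    """Monte Carlo estimate of Eq. (15) on noiseless data with w, x ~ N(0, I)."""
    rng = np.random.default_rng(seed)
    diffs = []
    for _ in range(num_datasets):
        w = rng.normal(size=d)
        X = rng.normal(size=(n, d))
        y = X @ w
        f1, f2 = alg1(X, y), alg2(X, y)                # predictors A(D)
        for x_q in rng.normal(size=(num_queries, d)):  # x' ~ p(x)
            diffs.append((f1(x_q) - f2(x_q)) ** 2)
    return float(np.mean(diffs))

def ols(X, y):
    """Minimum-norm least squares (OLS) as a dataset -> predictor map."""
    w = np.linalg.pinv(X) @ y
    return lambda x: float(w @ x)

def ridge_alg(lam):
    """Ridge(lam) as a dataset -> predictor map."""
    def fit(X, y):
        w = np.linalg.solve(X.T @ X + lam * np.eye(X.shape[1]), X.T @ y)
        return lambda x: float(w @ x)
    return fit

d = 8
print(spd(ols, ridge_alg(0.1), d=d, n=6) / d)   # dimension-normalized
```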


[Figure 1 plots. Panel (a): 1/d SPD(A1, A2) vs. #exemplars; panel (b): ILWD(A1, A2) vs. #exemplars. Compared pairs: (OLS, ICL), (Ridge(0.1), ICL), (GD(0.01), ICL), (SGD(0.01), ICL), (GD(0.02), ICL), (SGD(0.03), ICL), (KNN(3, weighted), ICL), (KNN(3, uniform), ICL), (OLS, Y), (Ridge(0.1), Y), (ICL, Y).]

(a) Predictor–ICL fit w.r.t. prediction differences. (b) Predictor–ICL fit w.r.t implicit weights.
Figure 1: Fit between ICL and standard learning algorithms: We plot (dimension-normalized) SPD and ILWD values between textbook algorithms and ICL on noiseless linear regression with $d = 8$. GD($\alpha$) denotes one step of batch gradient descent and SGD($\alpha$) denotes one pass of stochastic gradient descent with learning rate $\alpha$. Ridge($\lambda$) denotes ridge regression with regularization parameter $\lambda$. Under both evaluations, in-context learners agree closely with ordinary least squares, and are significantly less well approximated by other solutions to the linear regression problem.

Implicit linear weight difference. When ground-truth predictors all belong to a known, parametric function class (as with the linear functions here), we may also investigate the extent to which different learners agree on the parameters themselves. Given an algorithm $\mathcal{A}$, we sample a context dataset $D$ as above, and an additional collection of unlabeled test inputs $D_{X'} = \{x'_i\}$. We then compute $\mathcal{A}$’s prediction on each $x'_i$, yielding a predictor-specific dataset $D_{\mathcal{A}} = \{(x'_i, \hat{y}_i)\} = \{(x'_i, \mathcal{A}(D)(x'_i))\}$ encapsulating the function learned by $\mathcal{A}$. Next we compute the implied parameters:

$$\hat{w}_{\mathcal{A}} = \arg\min_{w} \sum_i \left(\hat{y}_i - w^\top x'_i\right)^2. \tag{16}$$

We can then quantify agreement between two predictors $\mathcal{A}_1$ and $\mathcal{A}_2$ by computing the distance between their implied weights in expectation over datasets:

$$\mathrm{ILWD}(\mathcal{A}_1, \mathcal{A}_2) = \mathbb{E}_{D}\, \mathbb{E}_{D_{X'}} \left\| \hat{w}_{\mathcal{A}_1} - \hat{w}_{\mathcal{A}_2} \right\|_2^2. \tag{17}$$

When the predictors are not linear, ILWD measures the difference between the closest linear predictors (in the sense of Eq. (16)) to each algorithm. For algorithms that have a linear hypothesis space (e.g. ridge regression), we use the actual value of $\hat{w}_{\mathcal{A}}$ instead of the estimated value.
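Eqs. (16) and (17) can be sketched in the same style: fit implied weights by least squares on each predictor's outputs, then average squared distances. Names and sample sizes are hypothetical, and this version always uses the least-squares estimate, whereas the text uses exact weights for linear-hypothesis algorithms; for such predictors the estimate recovers the exact weights whenever the query inputs span $\mathbb{R}^d$.

```python
import numpy as np

def implied_weights(predictor, X_query):
    """Eq. (16): least-squares linear fit to the predictor's outputs on unlabeled inputs."""
    y_hat = np.array([predictor(x) for x in X_query])
    w_hat, *_ = np.linalg.lstsq(X_query, y_hat, rcond=None)
    return w_hat

def ilwd(alg1, alg2, d, n, num_datasets=200, num_queries=64, seed=0):
    """Monte Carlo estimate of Eq. (17) on noiseless data with w, x ~ N(0, I)."""
    rng = np.random.default_rng(seed)
    dists = []
    for _ in range(num_datasets):
        w = rng.normal(size=d)
        X = rng.normal(size=(n, d))
        y = X @ w
        X_query = rng.normal(size=(num_queries, d))   # the unlabeled set D_X'
        w1 = implied_weights(alg1(X, y), X_query)
        w2 = implied_weights(alg2(X, y), X_query)
        dists.append(np.sum((w1 - w2) ** 2))
    return float(np.mean(dists))

def ols(X, y):
    w = np.linalg.pinv(X) @ y
    return lambda x: float(w @ x)

def ridge_alg(lam):
    def fit(X, y):
        w = np.linalg.solve(X.T @ X + lam * np.eye(X.shape[1]), X.T @ y)
        return lambda x: float(w @ x)
    return fit

print(ilwd(ols, ridge_alg(0.1), d=8, n=6))
```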

4.2 EXPERIMENTAL SETUP

We train a transformer decoder autoregressively on the objective in Eq. (8). For all experiments, we perform a hyperparameter search over depth $L \in \{1, 2, 4, 8, 12, 16\}$, hidden size $H \in \{16, 32, 64, 256, 512, 1024\}$ and heads $M \in \{1, 2, 4, 8\}$. Other hyperparameters are noted in Appendix D. For our main experiments, we found that $L = 16$, $H = 512$, $M = 4$ minimized loss on a validation set. We follow the training guidelines in Garg et al. (2022), and trained models for 500,000 iterations, with each in-context dataset consisting of 40 $(x, y)$ pairs. For the main experiments we generate data according to $p(w) = \mathcal{N}(0, I)$ and $p(x) = \mathcal{N}(0, I)$.

4.3 RESULTS

ICL matches ordinary least squares predictions on noiseless datasets. We begin by comparing
a (L = 16, H = 512, M = 4) transformer against a variety of reference predictors:

• k-nearest neighbors: In the uniform variant, models predict $\hat{y}_i = \frac{1}{3} \sum_j y_j$, where $j$ ranges over the top-3 data points closest to $x_i$ with $j < i$. In the weighted variant, a weighted average $\hat{y}_i \propto \frac{1}{3} \sum_j |x_i - x_j|^{-2} y_j$ is calculated, normalized by the total weights of the $y_j$s.
• One-pass stochastic gradient descent: $\hat{y}_i = w_i^\top x_i$ where $w_i$ is obtained by stochastic gradient descent on the previous examples with batch size equal to 1: $w_i = w_{i-1} - 2\alpha\left(x_{i-1} w_{i-1}^\top x_{i-1} - x_{i-1} y_{i-1} + \lambda w_{i-1}\right)$.

(A1, A2) \ σ²/τ²     0              1/16           1/9            1/4            4/9
                     (0.0/1.0)²     (0.25/1.0)²    (0.5/1.5)²     (0.5/1.0)²     (0.5/0.75)²
(Lstsq, ICL)         1.25e-05       1.34e-04       3.96e-04       1.51e-03       4.13e-03
(Ridge(1/16), ICL)   1.10e-04       3.29e-05       1.12e-04       8.24e-04       2.92e-03
(Ridge(1/9), ICL)    3.49e-04       9.65e-05       3.86e-05       4.50e-04       2.15e-03
(Ridge(1/4), ICL)    1.69e-03       8.64e-04       4.39e-04       3.30e-05       6.81e-04
(Ridge(4/9), ICL)    4.83e-03       3.09e-03       2.21e-03       7.52e-04       6.10e-05

Figure 2: ICL under uncertainty: With problem dimension $d = 8$, and for different values of prior variance $\tau^2$ and data noise $\sigma^2$, we display (dimension-normalized) MSPD values for each predictor pair, where MSPD is the average SPD value over the underdetermined region of the linear problem. Brightness is proportional to $1/\mathrm{MSPD}$. ICL most closely follows the minimum-Bayes-risk ridge regression output for all $\sigma^2/\tau^2$ values.

• One-step batch gradient descent: $\hat{y}_i = w_i^\top x_i$ where $w_i$ is obtained by one step of gradient descent on the batch of previous examples: $w_i = w_0 - 2\alpha\left(X^\top X w_0 - X^\top Y + \lambda w_0\right)$.
• Ridge regression: We compute $\hat{y}_i = w'^\top x_i$ where $w' = (X^\top X + \lambda I)^{-1} X^\top Y$. We denote the case of $\lambda = 0$ as OLS.
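The reference predictors listed above can be sketched as follows (ridge/OLS is omitted because it is just Eq. (10)). This NumPy version fills in details the text leaves open, all of them assumptions: SGD and batch GD start from $w_0 = 0$, k-NN uses Euclidean distance and all available points when fewer than $k$ precede $x_i$, and the weighted variant assumes no context point coincides with the query. The function names are hypothetical.

```python
import numpy as np

def knn_predict(X_ctx, y_ctx, x, k=3, weighted=False):
    """k-NN over the previous examples; the weighted variant uses inverse squared distances."""
    dist = np.linalg.norm(X_ctx - x, axis=1)
    idx = np.argsort(dist)[:k]
    if not weighted:
        return float(np.mean(y_ctx[idx]))
    wts = 1.0 / dist[idx] ** 2   # assumes no context point coincides with x
    return float(wts @ y_ctx[idx] / wts.sum())

def one_pass_sgd_weights(X_ctx, y_ctx, alpha, lam=0.0):
    """Batch size 1, one pass over the context, starting from w = 0 (assumed)."""
    w = np.zeros(X_ctx.shape[1])
    for x_j, y_j in zip(X_ctx, y_ctx):
        w = w - 2 * alpha * ((x_j @ w) * x_j - y_j * x_j + lam * w)
    return w

def one_step_gd_weights(X_ctx, y_ctx, alpha, lam=0.0, w0=None):
    """One full-batch step: w = w0 - 2 alpha (X^T X w0 - X^T y + lam w0)."""
    d = X_ctx.shape[1]
    w0 = np.zeros(d) if w0 is None else w0
    return w0 - 2 * alpha * (X_ctx.T @ X_ctx @ w0 - X_ctx.T @ y_ctx + lam * w0)
```

The prediction for $x_i$ is then `w @ x_i`, with the weights fit on the examples before position $i$ only.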

The agreement between the transformer-based ICL and these predictors is shown in Fig. 1. As can be
seen, there are clear differences in fit to predictors: for almost any number of examples, normalized
SPD and ILWD are small between the transformer and OLS predictor (with squared error less than
0.01), while other predictors (especially nearest neighbors) agree considerably less well.
When the number of examples is less than the input dimension d = 8, the linear regression problem
is under-determined, in the sense that multiple linear models can exactly fit the in-context training
dataset. In these cases, OLS regression selects the minimum-norm weight vector, and (as shown in
Fig. 1), the in-context learner’s predictions are reliably consistent with this minimum-norm predictor.
Why, when presented with an ambiguous dataset, should ICL behave like this particular predictor?
One possibility is that, because the weights used to generate the training data are sampled from a
Gaussian centered at zero, ICL learns to output the minimum Bayes risk solution when predicting
under uncertainty. Building on these initial findings, our next set of experiments investigates whether
ICL is behaviorally equivalent to Bayesian inference more generally.
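The minimum-norm behavior described above is easy to verify numerically. In this sketch (random data and sizes chosen arbitrarily), adding a null-space direction of $X$ to the pseudoinverse solution leaves the fit to the in-context examples unchanged but increases the norm, which is exactly the ambiguity that OLS resolves by choosing the minimum-norm vector.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 8, 5                      # fewer examples than dimensions: under-determined
X = rng.normal(size=(n, d))
y = X @ rng.normal(size=d)

w_min = np.linalg.pinv(X) @ y    # minimum-norm exact fit (OLS with lam = 0)

# Any null-space direction of X can be added without changing the fit.
_, _, Vt = np.linalg.svd(X)
null_dir = Vt[-1]                # X @ null_dir is ~0 because rank(X) = n < d
w_other = w_min + 3.0 * null_dir

print(np.allclose(X @ w_min, y), np.allclose(X @ w_other, y))   # True True
print(np.linalg.norm(w_min) < np.linalg.norm(w_other))          # True
```

Both weight vectors fit the context exactly, but they generally disagree on new inputs, so the choice among them is determined by the learner's inductive bias.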

ICL matches the minimum Bayes risk predictor on noisy datasets. To more closely examine the behavior of ICL algorithms under uncertainty, we add noise to the training data: now we present the in-context dataset as a sequence $[x_1, f(x_1) + \epsilon_1, \ldots, x_n, f(x_n) + \epsilon_n]$ where each $\epsilon_i \sim \mathcal{N}(0, \sigma^2)$. Recall that ground-truth weight vectors are themselves sampled from a Gaussian distribution; together, this choice of prior and noise means that the learner cannot be certain about the target function with any number of examples.
Standard Bayesian statistics gives that the optimal predictor for minimizing the loss in Eq. (8) is:

$$\hat{y} = \mathbb{E}[y \mid x, D]. \tag{18}$$

This is because, conditioned on $x$ and $D$, the scalar $\hat{y}(x, D) := \mathbb{E}[y \mid x, D]$ is the minimizer of the loss $\mathbb{E}[(y - \hat{y})^2 \mid x, D]$, and thus the estimator $\hat{y}$ is the minimizer of $\mathbb{E}[(y - \hat{y})^2] = \mathbb{E}_{x,D}\left[\mathbb{E}[(y - \hat{y})^2 \mid x, D]\right]$. For linear regression with Gaussian priors and Gaussian noise, the Bayesian estimator in Eq. (18) has a closed-form expression:

$$\hat{w} = \left(X^\top X + \frac{\sigma^2}{\tau^2} I\right)^{-1} X^\top Y; \qquad \hat{y} = \hat{w}^\top x. \tag{19}$$

Note that this predictor has the same form as the ridge predictor from Section 2.3, with the regularization parameter set to $\sigma^2/\tau^2$. In the presence of noisy labels, does ICL match this Bayesian predictor? We explore this by varying both the dataset noise $\sigma^2$ and the prior variance $\tau^2$ (sampling $w \sim \mathcal{N}(0, \tau^2 I)$). For these experiments, the SPD values between the in-context learner and various regularized linear models are shown in Fig. 2. As predicted, as the noise-to-prior variance ratio $\sigma^2/\tau^2$ increases, the value of the ridge parameter that best explains ICL behavior also increases. For all values of $\sigma^2, \tau^2$, the ridge parameter that gives the best fit to the transformer behavior is also the one that minimizes Bayes risk. These experiments clarify the finding above, showing that ICL in this setting behaviorally matches the minimum-Bayes-risk predictor.

[Figure 3 plots: MSPD(A1, A2) vs. L (depth) and H (hidden size). Panel (a): linear regression problem with d = 8; panel (b): linear regression problem with d = 16. Compared pairs: (OLS, ICL), (Ridge(0.1), ICL), (Ridge(0.5), ICL), (KNN(3, weighted), ICL), (KNN(3, uniform), ICL), (SGD(0.03), ICL), (GD(0.02), ICL), (SGD(0.01), ICL), (GD(0.01), ICL).]
Figure 3: Computational constraints on ICL: We show SPD averaged over the underdetermined region of the linear regression problem. In-context learners behaviorally match ordinary least squares predictors if they have enough layers and a large enough hidden size. When varying model depth (left background), algorithmic “phases” emerge: models transition between being closer to gradient descent (red background), ridge regression (green background), and OLS regression (blue).
We also note that when the noise level $\sigma \to 0^+$, the Bayes predictor converges to the ordinary least squares predictor. Therefore, the results on noiseless datasets studied at the beginning of this subsection can be viewed as corroborating the finding here in the setting with $\sigma \to 0^+$.
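The correspondence between Eqs. (18) and (19) and ridge regression, together with the $\sigma \to 0^+$ limit, can be checked numerically. The sketch below (function names hypothetical) computes the posterior mean of $w$ directly by Gaussian conditioning, without using the ridge formula, and compares it with ridge regression at $\lambda = \sigma^2/\tau^2$ and, for vanishing noise, with the minimum-norm OLS solution of an under-determined problem.

```python
import numpy as np

def bayes_posterior_mean(X, y, sigma2, tau2):
    """E[w | X, y] under w ~ N(0, tau2 I) and y = X w + eps, eps ~ N(0, sigma2 I).

    Gaussian conditioning: Cov(w, y) = tau2 X^T and Cov(y) = tau2 X X^T + sigma2 I,
    so the posterior mean is tau2 X^T (tau2 X X^T + sigma2 I)^{-1} y.
    """
    n = X.shape[0]
    return tau2 * X.T @ np.linalg.solve(tau2 * X @ X.T + sigma2 * np.eye(n), y)

def ridge(X, y, lam):
    """Eq. (19) form: (X^T X + lam I)^{-1} X^T y, with one example per row of X."""
    d = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(d), X.T @ y)

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 8))   # n = 5 < d = 8: under-determined
y = rng.normal(size=5)

sigma2, tau2 = 0.5, 2.0
print(np.allclose(bayes_posterior_mean(X, y, sigma2, tau2), ridge(X, y, sigma2 / tau2)))
# Vanishing noise: the Bayes estimate approaches the minimum-norm OLS solution.
print(np.allclose(bayes_posterior_mean(X, y, 1e-10, tau2), np.linalg.pinv(X) @ y, atol=1e-6))
```

The two forms agree by the push-through identity $(X^\top X + \lambda I)^{-1} X^\top = X^\top (X X^\top + \lambda I)^{-1}$; the conditioning form works in $n \times n$ space, which makes the $\sigma \to 0^+$ limit $X^\top (X X^\top)^{-1} y$ easy to see when $n < d$.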

ICL exhibits algorithmic phase transitions as model depth increases. The two experiments above evaluated extremely high-capacity models in which (given findings in Section 3) computational constraints are not likely to play a role in the choice of algorithm implemented by ICL. But what about smaller models? Does the size of an in-context learner play a role in determining the learning algorithm it implements? To answer this question, we run two final behavioral experiments: one in which we vary the hidden size (while optimizing the depth and number of heads as in Section 4.2), and one in which we vary the depth of the transformer (while optimizing the hidden size and number of heads). These experiments are conducted without dataset noise.
Results are shown in Fig. 3. When we vary the depth, learners occupy three distinct regimes: very shallow models (1L) are best approximated by a single step of gradient descent (though not well-approximated in an absolute sense). Slightly deeper models (2L-4L) are best approximated by ridge regression, while the deepest (8L and deeper) models match OLS as observed in Fig. 3. Similar phase shifts occur when we vary the hidden size in the 16D problem. Interestingly, we can read off the hidden size required to be close to ridge-regression-like solutions as $H \ge 16$ and $H \ge 32$ for 8D and 16D problems, respectively, suggesting that ICL discovers more efficient ways to use available hidden state than our theoretical constructions requiring $O(d^2)$. Together, these results show that ICL does not necessarily involve minimum-risk prediction. However, even in models too computationally constrained to perform Bayesian inference, alternative interpretable computations can emerge.

[Figure 4 plots. First row: $X^\top Y$ probe error (MSE) vs. layer for linear and MLP probes, and Layer-8's attention heatmap over time (target index vs. position). Second row: $w_{\mathrm{OLS}}$ probe errors vs. layer, and Layer-12's attention heatmap over time.]
Figure 4: Probing results on the d = 4 problem: both the moments $X^\top Y$ (top) and the least-squares solution $w_{\mathrm{OLS}}$ (middle) are recoverable from learner representations. Plots in the left column show the accuracy of the probe for each target in different model layers. Dashed lines show the best probe accuracies obtained on a control task featuring a fixed weight vector w = 1. Plots in the right column show the attention heatmap for the best layer's probe, with the number of input examples on the x-axis. The value of the target after n exemplars is decoded primarily from the representation of $y_n$, or, after n = d exemplars, uniformly from $y_{n \ge 4}$.

5 DOES ICL ENCODE MEANINGFUL INTERMEDIATE QUANTITIES?


Section 4 showed that transformers are a good fit to standard learning algorithms (including those
constructed in Section 3) at the computational level. But these experiments leave open the question
of how these computations are implemented at the algorithmic level. How do transformers arrive at
the solutions in Section 4, and what quantities do they compute along the way? Research on extracting precise algorithmic descriptions of learned models is still in its infancy (Cammarata et al., 2020;
Mu & Andreas, 2020). However, we can gain insight into ICL by inspecting learners’ intermediate
states: asking what information is encoded in these states, and where.
To do so, we identify two intermediate quantities that we expect to be computed by gradient descent and ridge-regression variants: the moment vector $X^\top Y$ and the (min-norm) least-squares estimated weight vector $w_{\mathrm{OLS}}$, each calculated after feeding n exemplars. We take a trained in-context learner, freeze its weights, then train an auxiliary probing model (Alain & Bengio, 2016) to attempt to recover the target quantities from the learner's hidden representations. Specifically, the probe model takes the hidden states at a layer, $H^{(l)}$, as input, then outputs a prediction for the target variable. We define a probe with position-attention that computes (Appendix E):
$$\alpha = \mathrm{softmax}(s_v) \qquad (20)$$
$$\hat{v} = \mathrm{FF}_v\big(\alpha^\top W_v H^{(l)}\big) \qquad (21)$$
We train this probe to minimize the squared error between the predictions and targets v: $\mathcal{L}(v, \hat{v}) = |v - \hat{v}|^2$. The probe performs two functions simultaneously: its prediction error on held-out representations determines the extent to which the target quantity is encoded, while its attention mask α identifies the location in which the target quantity is encoded. For the FF term, we can insert the function approximator of our choosing; by changing this term we can determine the manner in which the target quantity is encoded. For example, if FF is a linear model and the probe achieves low error, then we may infer that the target is encoded linearly.
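A minimal NumPy sketch of the probe's forward pass in Eqs. (20)-(21). It assumes the hidden states are stored as a (T, D) array with one row per position, so the pooled vector is the α-weighted sum of $W_v h_t$; `linear_ff` and `mlp_ff` are hypothetical stand-ins for the FF term (the ReLU in the MLP is our assumption), and training the parameters by minimizing the squared error is omitted.

```python
import numpy as np

def softmax(z):
    z = z - z.max()                      # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def probe_forward(H, s_v, W_v, ff):
    """H: (T, D) hidden states; s_v: (T,) position scores; W_v: (k, D) projection."""
    alpha = softmax(s_v)                 # Eq. (20): attention over positions
    pooled = alpha @ (H @ W_v.T)         # sum_t alpha_t * W_v h_t, shape (k,)
    return ff(pooled), alpha             # Eq. (21): FF applied to the pooled vector

def linear_ff(A, b):
    return lambda z: A @ z + b           # linear probe

def mlp_ff(A1, b1, A2, b2):
    # two-layer MLP probe; the ReLU activation is an assumption of this sketch
    return lambda z: A2 @ np.maximum(A1 @ z + b1, 0.0) + b2

def probe_loss(v, v_hat):
    return float(np.sum((v - v_hat) ** 2))   # squared error |v - v_hat|^2
```

When the learned scores $s_v$ concentrate on one position, the probe reads that position alone, which is how the attention mask localizes where a target is encoded.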
For each target, we train a separate probe for the value of the target on each prefix of the dataset: i.e. one probe to decode the value of w computed from a single training example, a second probe to decode the value for two examples, etc. Results are shown in Fig. 4. For both targets, a 2-layer MLP probe outperforms a linear probe, meaning that these targets are encoded nonlinearly (unlike the constructions in Section 3). However, probing also reveals similarities. Both targets are decoded accurately deep in the network (but inaccurately in the input layer, indicating that probe success is non-trivial). Probes attend to the correct timestamps when decoding them. As in both constructions, $X^\top Y$ appears to be computed first, becoming predictable by the probe relatively early in the computation (layer 7), while w becomes predictable later (around layer 12). For comparison,
we additionally report results on a control task in which the transformer predicts ys generated with
a fixed weight vector w = 1 (so no ICL is required). Probes applied to these models perform
significantly worse at recovering moment matrices (see Appendix E for details).
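The probe targets for every prefix can be computed directly from the in-context examples. A short sketch, assuming the exemplars are stacked as rows of `X` with responses `y` and using the pseudoinverse for the min-norm least-squares solution (the function name is hypothetical):

```python
import numpy as np

def probe_targets(X, y):
    """For each prefix n = 1..N, return X_n^T y_n and the min-norm OLS weights."""
    moments, w_ols = [], []
    for n in range(1, len(y) + 1):
        Xn, yn = X[:n], y[:n]
        moments.append(Xn.T @ yn)               # moment vector X^T Y after n exemplars
        w_ols.append(np.linalg.pinv(Xn) @ yn)   # min-norm least-squares solution
    return np.array(moments), np.array(w_ols)
```

On noiseless data with generic inputs, the $w_{\mathrm{OLS}}$ target stops changing once n ≥ d, because the first d exemplars already determine w uniquely.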

6 CONCLUSION

We have presented a set of experiments characterizing the computations underlying in-context learn-
ing of linear functions in transformer sequence models. We showed that these models are capable
in theory of implementing multiple linear regression algorithms, that they empirically implement
this range of algorithms (transitioning between algorithms depending on model capacity and dataset
noise), and finally that they can be probed for intermediate quantities computed by these algorithms.
While our experiments have focused on the linear case, they can be extended to many learning
problems over richer function classes—e.g. to a network whose initial layers perform a non-linear
feature computation. Even more generally, the experimental methodology here could be applied
to larger-scale examples of ICL, especially language models, to determine whether their behaviors
are also described by interpretable learning algorithms. While much work remains to be done, our
results offer initial evidence that the apparently mysterious phenomenon of in-context learning can
be understood with the standard ML toolkit, and that the solutions to learning problems discovered
by machine learning researchers may be discovered by gradient descent as well.

REFERENCES
Guillaume Alain and Yoshua Bengio. Understanding intermediate layers using linear classifier probes. arXiv preprint arXiv:1610.01644, 2016.

Marcin Andrychowicz, Misha Denil, Sergio Gomez Colmenarejo, Matthew W. Hoffman, David Pfau, Tom Schaul, and Nando de Freitas. Learning to learn by gradient descent by gradient descent. In NIPS, 2016.

Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E. Hinton. Layer normalization. arXiv preprint arXiv:1607.06450, 2016.

Tom B. Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, Sandhini Agarwal, Ariel Herbert-Voss, Gretchen Krueger, T. J. Henighan, Rewon Child, Aditya Ramesh, Daniel M. Ziegler, Jeff Wu, Clemens Winter, Christopher Hesse, Mark Chen, Eric Sigler, Mateusz Litwin, Scott Gray, Benjamin Chess, Jack Clark, Christopher Berner, Sam McCandlish, Alec Radford, Ilya Sutskever, and Dario Amodei. Language models are few-shot learners. arXiv preprint arXiv:2005.14165, 2020.

Nick Cammarata, Shan Carter, Gabriel Goh, Chris Olah, Michael Petrov, Ludwig Schubert, Chelsea Voss, Ben Egan, and Swee Kiat Lim. Thread: Circuits. Distill, 5(3):e24, 2020.

Stephanie C. Y. Chan, Adam Santoro, Andrew K. Lampinen, Jane X. Wang, Aaditya Singh, Pierre H. Richemond, Jay McClelland, and Felix Hill. Data distributional properties drive emergent few-shot learning in transformers. arXiv preprint arXiv:2205.05055, 2022.

Chi Chen, Maosong Sun, and Yang Liu. Mask-Align: Self-supervised neural word alignment. arXiv preprint arXiv:2012.07162, 2020.

Yanda Chen, Ruiqi Zhong, Sheng Zha, George Karypis, and He He. Meta-learning via language model in-context tuning. arXiv preprint arXiv:2110.07814, 2021.

Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, et al. PaLM: Scaling language modeling with Pathways. arXiv preprint arXiv:2204.02311, 2022.

Chelsea Finn, P. Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep networks. arXiv preprint arXiv:1703.03400, 2017.

Shivam Garg, Dimitris Tsipras, Percy Liang, and Gregory Valiant. What can transformers learn in-context? A case study of simple function classes. arXiv preprint arXiv:2208.01066, 2022.

Dan Hendrycks and Kevin Gimpel. Gaussian error linear units (GELUs). arXiv preprint arXiv:1606.08415, 2016.

Arthur E. Hoerl and Robert W. Kennard. Ridge regression: Biased estimation for nonorthogonal problems. Technometrics, 12(1):55–67, 1970.

Kurt Hornik, Maxwell Stinchcombe, and Halbert White. Multilayer feedforward networks are universal approximators. Neural Networks, 2(5):359–366, 1989.

Michael Laskin, Luyu Wang, Junhyuk Oh, Emilio Parisotto, Stephen Spencer, Richie Steigerwald, DJ Strouse, Steven Hansen, Angelos Filos, Ethan Brooks, et al. In-context reinforcement learning with algorithm distillation. arXiv preprint arXiv:2210.14215, 2022.

David Marr. Vision: A computational investigation into the human representation and processing of visual information. MIT Press, 2010.

Sewon Min, Mike Lewis, Luke Zettlemoyer, and Hannaneh Hajishirzi. MetaICL: Learning to learn in context. arXiv preprint arXiv:2110.15943, 2021.

Jesse Mu and Jacob Andreas. Compositional explanations of neurons. Advances in Neural Information Processing Systems, 33:17153–17163, 2020.

Catherine Olsson, Nelson Elhage, Neel Nanda, Nicholas Joseph, Nova DasSarma, T. J. Henighan, Benjamin Mann, Amanda Askell, Yushi Bai, Anna Chen, Tom Conerly, Dawn Drain, Deep Ganguli, Zac Hatfield-Dodds, Danny Hernandez, Scott Johnston, Andy Jones, John Kernion, Liane Lovitt, Kamal Ndousse, Dario Amodei, Tom B. Brown, Jack Clark, Jared Kaplan, Sam McCandlish, and Christopher Olah. In-context learning and induction heads. 2022.

Juergen Schmidhuber, Jieyu Zhao, and Marco A. Wiering. Simple principles of metalearning. 1996.

Jack Sherman and Winifred J. Morrison. Adjustment of an inverse matrix corresponding to a change in one element of a given matrix. The Annals of Mathematical Statistics, 21(1):124–127, 1950.

Ashish Vaswani, Noam M. Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NIPS, 2017.

Colin Wei, Yining Chen, and Tengyu Ma. Statistically meaningful approximation: A case study on approximating Turing machines with transformers. arXiv preprint arXiv:2107.13163, 2021.

Sang Michael Xie, Aditi Raghunathan, Percy Liang, and Tengyu Ma. An explanation of in-context learning as implicit Bayesian inference. arXiv preprint arXiv:2111.02080, 2021.

Sang Michael Xie, Aditi Raghunathan, Percy Liang, and Tengyu Ma. An explanation of in-context learning as implicit Bayesian inference. arXiv preprint arXiv:2111.02080, 2022.

Chulhee Yun, Srinadh Bhojanapalli, Ankit Singh Rawat, Sashank J. Reddi, and Sanjiv Kumar. Are transformers universal approximators of sequence-to-sequence functions? arXiv preprint arXiv:1912.10077, 2019.

Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen, Christopher Dewan, Mona Diab, Xian Li, Xi Victoria Lin, Todor Mihaylov, Myle Ott, Sam Shleifer, Kurt Shuster, Daniel Simig, Punit Singh Koura, Anjali Sridhar, Tianlu Wang, and Luke Zettlemoyer. OPT: Open pre-trained transformer language models. 2022.

Kaiyang Zhou, Jingkang Yang, Chen Change Loy, and Ziwei Liu. Learning to prompt for vision-language models. International Journal of Computer Vision, 130(9):2337–2348, 2022.

A THEOREM 1
The operations for one step of SGD with a single exemplar can be expressed as the following chain (please see Appendix C for proofs of the Transformer implementation of these operations (Lemma 1)):

• mov(; 1, 0, (1, 1 + d), (1, 1 + d)) (move x)
• aff(; (1, 1 + d), (), (1 + d, 2 + d), W1 = w) ($w^\top x$)
• aff(; (1 + d, 2 + d), (0, 1), (2 + d, 3 + d), W1 = I, W2 = −I) ($w^\top x - y$)
• mul(; d, 1, 1, (1, 1 + d), (2 + d, 3 + d), (3 + d, 3 + 2d)) ($x(w^\top x - y)$)
• aff(; (), (), (3 + 2d, 3 + 3d), b = w) (write w)
• aff(; (3 + d, 3 + 2d), (3 + 2d, 3 + 3d), (3 + 3d, 3 + 4d), W1 = I, W2 = −λ) ($x(w^\top x - y) - \lambda w$)
• aff(; (3 + 2d, 3 + 3d), (3 + 3d, 3 + 4d), (3 + 2d, 3 + 3d), W1 = I, W2 = −2α) ($w'$)
• mov(; 2, 1, (3 + 2d, 3 + 3d), (3 + 2d, 3 + 3d)) (move $w'$)
• mul(; 1, d, 1, (3 + 2d, 3 + 3d), (1, 1 + d), (3 + 3d, 4 + 3d)) ($w'^\top x_2$)

This will map:


$$
\begin{bmatrix} 0 & y_1 & 0 \\ x_1 & 0 & x_2 \end{bmatrix}
\mapsto
\begin{bmatrix}
0 & y_1 & 0 \\
x_1 & x_1 & x_2 \\
w^\top x_1 & w^\top x_1 & w^\top x_2 \\
w^\top x_1 & w^\top x_1 - y_1 & w^\top x_2 \\
x_1 w^\top x_1 & x_1(w^\top x_1 - y_1) & x_2 w^\top x_2 \\
w & w & w \\
x_1 w^\top x_1 - \lambda w & x_1(w^\top x_1 - y_1) - \lambda w & x_2 w^\top x_2 - \lambda w \\
w - 2\alpha(x_1 w^\top x_1 - \lambda w) & w' & w' \\
(w - 2\alpha(x_1 w^\top x_1 - \lambda w))^\top x_1 & w'^\top x_1 & w'^\top x_2
\end{bmatrix}
$$
where $w' = w - 2\alpha\big(x_1(w^\top x_1 - y_1) - \lambda w\big)$.

We can verify the chain of operators step by step. In each step, we show only the non-zero rows.

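• mov(; 1, 0, (1, 1 + d), (1, 1 + d)) (move x)

$$
\begin{bmatrix} 0 & y_1 & 0 \\ x_1 & 0 & x_2 \end{bmatrix}
\mapsto
\begin{bmatrix} 0 & y_1 & 0 \\ x_1 & x_1 & x_2 \end{bmatrix}
$$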

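• aff(; (1, 1 + d), (), (1 + d, 2 + d), W1 = w) ($w^\top x$)

$$
\begin{bmatrix} 0 & y_1 & 0 \\ x_1 & x_1 & x_2 \end{bmatrix}
\mapsto
\begin{bmatrix} 0 & y_1 & 0 \\ x_1 & x_1 & x_2 \\ w^\top x_1 & w^\top x_1 & w^\top x_2 \end{bmatrix}
$$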

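• aff(; (1 + d, 2 + d), (0, 1), (2 + d, 3 + d), W1 = I, W2 = −I) ($w^\top x - y$)

$$
\begin{bmatrix} 0 & y_1 & 0 \\ x_1 & x_1 & x_2 \\ w^\top x_1 & w^\top x_1 & w^\top x_2 \end{bmatrix}
\mapsto
\begin{bmatrix} 0 & y_1 & 0 \\ x_1 & x_1 & x_2 \\ w^\top x_1 & w^\top x_1 & w^\top x_2 \\ w^\top x_1 & w^\top x_1 - y_1 & w^\top x_2 \end{bmatrix}
$$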

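• mul(; d, 1, 1, (1, 1 + d), (2 + d, 3 + d), (3 + d, 3 + 2d)) ($x(w^\top x - y)$)

$$
\begin{bmatrix}
0 & y_1 & 0 \\
x_1 & x_1 & x_2 \\
w^\top x_1 & w^\top x_1 & w^\top x_2 \\
w^\top x_1 & w^\top x_1 - y_1 & w^\top x_2
\end{bmatrix}
\mapsto
\begin{bmatrix}
0 & y_1 & 0 \\
x_1 & x_1 & x_2 \\
w^\top x_1 & w^\top x_1 & w^\top x_2 \\
w^\top x_1 & w^\top x_1 - y_1 & w^\top x_2 \\
x_1 w^\top x_1 & x_1(w^\top x_1 - y_1) & x_2 w^\top x_2
\end{bmatrix}
$$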

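• aff(; (), (), (3 + 2d, 3 + 3d), b = w) (write w)

$$
\begin{bmatrix}
0 & y_1 & 0 \\
x_1 & x_1 & x_2 \\
w^\top x_1 & w^\top x_1 & w^\top x_2 \\
w^\top x_1 & w^\top x_1 - y_1 & w^\top x_2 \\
x_1 w^\top x_1 & x_1(w^\top x_1 - y_1) & x_2 w^\top x_2
\end{bmatrix}
\mapsto
\begin{bmatrix}
0 & y_1 & 0 \\
x_1 & x_1 & x_2 \\
w^\top x_1 & w^\top x_1 & w^\top x_2 \\
w^\top x_1 & w^\top x_1 - y_1 & w^\top x_2 \\
x_1 w^\top x_1 & x_1(w^\top x_1 - y_1) & x_2 w^\top x_2 \\
w & w & w
\end{bmatrix}
$$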

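• aff(; (3 + d, 3 + 2d), (3 + 2d, 3 + 3d), (3 + 3d, 3 + 4d), W1 = I, W2 = −λ) ($x(w^\top x - y) - \lambda w$)

$$
\begin{bmatrix}
0 & y_1 & 0 \\
x_1 & x_1 & x_2 \\
w^\top x_1 & w^\top x_1 & w^\top x_2 \\
w^\top x_1 & w^\top x_1 - y_1 & w^\top x_2 \\
x_1 w^\top x_1 & x_1(w^\top x_1 - y_1) & x_2 w^\top x_2 \\
w & w & w
\end{bmatrix}
\mapsto
\begin{bmatrix}
0 & y_1 & 0 \\
x_1 & x_1 & x_2 \\
w^\top x_1 & w^\top x_1 & w^\top x_2 \\
w^\top x_1 & w^\top x_1 - y_1 & w^\top x_2 \\
x_1 w^\top x_1 & x_1(w^\top x_1 - y_1) & x_2 w^\top x_2 \\
w & w & w \\
x_1 w^\top x_1 - \lambda w & x_1(w^\top x_1 - y_1) - \lambda w & x_2 w^\top x_2 - \lambda w
\end{bmatrix}
$$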

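• aff(; (3 + 2d, 3 + 3d), (3 + 3d, 3 + 4d), (3 + 2d, 3 + 3d), W1 = I, W2 = −2α) ($w'$)

$$
\begin{bmatrix}
0 & y_1 & 0 \\
x_1 & x_1 & x_2 \\
w^\top x_1 & w^\top x_1 & w^\top x_2 \\
w^\top x_1 & w^\top x_1 - y_1 & w^\top x_2 \\
x_1 w^\top x_1 & x_1(w^\top x_1 - y_1) & x_2 w^\top x_2 \\
w & w & w \\
x_1 w^\top x_1 - \lambda w & x_1(w^\top x_1 - y_1) - \lambda w & x_2 w^\top x_2 - \lambda w
\end{bmatrix}
\mapsto
\begin{bmatrix}
0 & y_1 & 0 \\
x_1 & x_1 & x_2 \\
w^\top x_1 & w^\top x_1 & w^\top x_2 \\
w^\top x_1 & w^\top x_1 - y_1 & w^\top x_2 \\
x_1 w^\top x_1 & x_1(w^\top x_1 - y_1) & x_2 w^\top x_2 \\
w & w & w \\
x_1 w^\top x_1 - \lambda w & x_1(w^\top x_1 - y_1) - \lambda w & x_2 w^\top x_2 - \lambda w \\
w - 2\alpha(x_1 w^\top x_1 - \lambda w) & w' & w - 2\alpha(x_2 w^\top x_2 - \lambda w)
\end{bmatrix}
$$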

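• mov(; 2, 1, (3 + 2d, 3 + 3d), (3 + 2d, 3 + 3d)) (move $w'$)

$$
\begin{bmatrix}
0 & y_1 & 0 \\
x_1 & x_1 & x_2 \\
w^\top x_1 & w^\top x_1 & w^\top x_2 \\
w^\top x_1 & w^\top x_1 - y_1 & w^\top x_2 \\
x_1 w^\top x_1 & x_1(w^\top x_1 - y_1) & x_2 w^\top x_2 \\
w & w & w \\
x_1 w^\top x_1 - \lambda w & x_1(w^\top x_1 - y_1) - \lambda w & x_2 w^\top x_2 - \lambda w \\
w - 2\alpha(x_1 w^\top x_1 - \lambda w) & w' & w - 2\alpha(x_2 w^\top x_2 - \lambda w)
\end{bmatrix}
\mapsto
\begin{bmatrix}
0 & y_1 & 0 \\
x_1 & x_1 & x_2 \\
w^\top x_1 & w^\top x_1 & w^\top x_2 \\
w^\top x_1 & w^\top x_1 - y_1 & w^\top x_2 \\
x_1 w^\top x_1 & x_1(w^\top x_1 - y_1) & x_2 w^\top x_2 \\
w & w & w \\
x_1 w^\top x_1 - \lambda w & x_1(w^\top x_1 - y_1) - \lambda w & x_2 w^\top x_2 - \lambda w \\
w - 2\alpha(x_1 w^\top x_1 - \lambda w) & w' & w'
\end{bmatrix}
$$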

• mul(; 1, d, 1, (3 + 2d, 3 + 3d), (1, 1 + d), (3 + 3d, 4 + 3d)) ($w'^\top x_2$)

$$
\begin{bmatrix}
0 & y_1 & 0 \\
x_1 & x_1 & x_2 \\
w^\top x_1 & w^\top x_1 & w^\top x_2 \\
w^\top x_1 & w^\top x_1 - y_1 & w^\top x_2 \\
x_1 w^\top x_1 & x_1(w^\top x_1 - y_1) & x_2 w^\top x_2 \\
w & w & w \\
x_1 w^\top x_1 - \lambda w & x_1(w^\top x_1 - y_1) - \lambda w & x_2 w^\top x_2 - \lambda w \\
w - 2\alpha(x_1 w^\top x_1 - \lambda w) & w' & w'
\end{bmatrix}
\mapsto
\begin{bmatrix}
0 & y_1 & 0 \\
x_1 & x_1 & x_2 \\
w^\top x_1 & w^\top x_1 & w^\top x_2 \\
w^\top x_1 & w^\top x_1 - y_1 & w^\top x_2 \\
x_1 w^\top x_1 & x_1(w^\top x_1 - y_1) & x_2 w^\top x_2 \\
w & w & w \\
x_1 w^\top x_1 - \lambda w & x_1(w^\top x_1 - y_1) - \lambda w & x_2 w^\top x_2 - \lambda w \\
w - 2\alpha(x_1 w^\top x_1 - \lambda w) & w' & w' \\
(w - 2\alpha(x_1 w^\top x_1 - \lambda w))^\top x_1 & w'^\top x_1 & w'^\top x_2
\end{bmatrix}
$$
We obtain the updated prediction in the last hidden unit of the third time-step.

Generalizing to multiple steps of SGD. Since $w'$ is written in the hidden states, we may repeat this iteration to obtain $\hat{y}_3 = w''^\top x_3$, where $w''$ is the one-step update $w' - 2\alpha\big(x_2(w'^\top x_2 - y_2) - \lambda w'\big)$, requiring a total of O(n) layers for a single pass through the dataset, where n is the number of exemplars.
As an empirical demonstration of this procedure, the accompanying code release contains a reference implementation of SGD defined in terms of the base primitive, provided at the anonymous links https://icl1.s3.us-east-2.amazonaws.com/theory/{primitives,sgd,ridge}.py (to preserve anonymity, we did not provide the library dependencies). This implementation predicts $\hat{y}_n = w_n^\top x_n$, where $w_n$ is the weight vector resulting from n − 1 consecutive SGD updates on previous examples. It can be verified there that the procedure requires O(n + d) hidden space. Note that it is not O(nd), because the space used for intermediate variables can be reused in the next iteration; an example is the ($w'$) step above, which writes $w'$ into the indices (3 + 2d, 3 + 3d) that previously held w.
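To make the chain concrete, here is a small NumPy interpreter that executes it on a 3-timestep hidden matrix. It is a hypothetical sketch, not the authors' primitives.py, and its semantics are our reading of the operator signatures: index pairs are half-open row ranges, the first two arguments of mov are the destination and then the source timestep (as causality requires in the chain), and mul(a, b, c, ...) multiplies an a × b block by a b × c block. The λ term follows the sign convention of the chain as written.

```python
import numpy as np

# h has shape (H, T): column t is the hidden state at timestep t.
# Index pairs (i, j) are half-open row ranges [i, j); () means "no operand".

def mov(h, dst_t, src_t, dst, src):
    """Copy rows `src` of timestep src_t into rows `dst` of timestep dst_t."""
    h = h.copy()
    h[slice(*dst), dst_t] = h[slice(*src), src_t]
    return h

def aff(h, r1, r2, dst, W1=None, W2=None, b=None):
    """At every timestep, write W1 h[r1] + W2 h[r2] + b into rows `dst`."""
    h = h.copy()
    out = np.zeros((dst[1] - dst[0], h.shape[1]))
    if r1:
        out += W1 @ h[slice(*r1)]
    if r2:
        out += W2 @ h[slice(*r2)]
    if b is not None:
        out += np.asarray(b, dtype=float).reshape(-1, 1)
    h[slice(*dst)] = out
    return h

def mul(h, a, b, c, r1, r2, dst):
    """At every timestep, view h[r1] as (a, b) and h[r2] as (b, c); write the product."""
    h = h.copy()
    for t in range(h.shape[1]):
        A = h[slice(*r1), t].reshape(a, b)
        B = h[slice(*r2), t].reshape(b, c)
        h[slice(*dst), t] = (A @ B).ravel()
    return h

def one_step_sgd_chain(x1, y1, x2, w, alpha, lam):
    """Run the Theorem 1 chain on the sequence (x1, y1, x2); return the prediction for x2."""
    d = len(w)
    h = np.zeros((3 + 4 * d, 3))
    h[1:1 + d, 0] = x1                       # timestep 0 holds x1
    h[0, 1] = y1                             # timestep 1 holds y1
    h[1:1 + d, 2] = x2                       # timestep 2 holds x2
    I = np.eye(d)
    h = mov(h, 1, 0, (1, 1 + d), (1, 1 + d))                              # move x
    h = aff(h, (1, 1 + d), (), (1 + d, 2 + d), W1=w[None, :])             # w^T x
    h = aff(h, (1 + d, 2 + d), (0, 1), (2 + d, 3 + d),
            W1=np.eye(1), W2=-np.eye(1))                                  # w^T x - y
    h = mul(h, d, 1, 1, (1, 1 + d), (2 + d, 3 + d), (3 + d, 3 + 2 * d))   # x (w^T x - y)
    h = aff(h, (), (), (3 + 2 * d, 3 + 3 * d), b=w)                       # write w
    h = aff(h, (3 + d, 3 + 2 * d), (3 + 2 * d, 3 + 3 * d), (3 + 3 * d, 3 + 4 * d),
            W1=I, W2=-lam * I)                                            # x (w^T x - y) - lam w
    h = aff(h, (3 + 2 * d, 3 + 3 * d), (3 + 3 * d, 3 + 4 * d), (3 + 2 * d, 3 + 3 * d),
            W1=I, W2=-2 * alpha * I)                                      # w', overwriting w
    h = mov(h, 2, 1, (3 + 2 * d, 3 + 3 * d), (3 + 2 * d, 3 + 3 * d))      # move w'
    h = mul(h, 1, d, 1, (3 + 2 * d, 3 + 3 * d), (1, 1 + d), (3 + 3 * d, 4 + 3 * d))  # w'^T x
    return h[3 + 3 * d, 2]
```

With λ = 0, the update reduces to one gradient step of size α on $(w^\top x_1 - y_1)^2$.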

B THEOREM 2
We provide a construction similar to that of Theorem 1 (please see Appendix C for proofs of the Transformer implementation of these operations, and Appendix C.6 specifically for div):

• mov(; 1, 0, (1, 1 + d), (1, 1 + d)) (move $x_1$)
• mul(; d, 1, 1, (1, 1 + d), (0, 1), (1 + d, 1 + 2d)) ($x_1 y$)
• aff(; (), (), (1 + 2d, 1 + 2d + d²), b = I/λ) ($A_0^{-1} = \frac{I}{\lambda}$)
• mul(; d, d, 1, (1 + 2d, 1 + 2d + d²), (1, 1 + d), (1 + 2d + d², 1 + 3d + d²)) ($A_0^{-1}u = \frac{I}{\lambda}x_1$)
• mul(; 1, d, d, (1, 1 + d), (1 + 2d, 1 + 2d + d²), (1 + 3d + d², 1 + 4d + d²)) ($v^\top A_0^{-1} = x_1^\top\frac{I}{\lambda}$)
• mul(; d, 1, d, (1 + 2d + d², 1 + 3d + d²), (1 + 3d + d², 1 + 4d + d²), (1 + 4d + d², 1 + 4d + 2d²)) ($A_0^{-1}uv^\top A_0^{-1} = \frac{I}{\lambda}x_1 x_1^\top\frac{I}{\lambda}$)
• mul(; 1, d, 1, (1 + 3d + d², 1 + 4d + d²), (1, 1 + d), (1 + 4d + 2d², 2 + 4d + 2d²)) ($v^\top A_0^{-1}u = x_1^\top\frac{I}{\lambda}x_1$)
• aff(; (1 + 4d + 2d², 2 + 4d + 2d²), (), (1 + 4d + 2d², 2 + 4d + 2d²), W1 = 1, b = 1) ($1 + v^\top A_0^{-1}u = 1 + x_1^\top\frac{I}{\lambda}x_1$)
• div(; (1 + 4d + d², 1 + 4d + 2d²), 1 + 4d + 2d², (2 + 4d + 2d², 2 + 4d + 3d²)) (right term)
• aff(; (1 + 2d, 1 + 2d + d²), (2 + 4d + 2d², 2 + 4d + 3d²), (1 + 2d, 1 + 2d + d²), W1 = I, W2 = −I) ($A_1^{-1}$)
• mul(; d, d, 1, (1 + 2d, 1 + 2d + d²), (1, 1 + d), (2 + 4d + 3d², 2 + 5d + 3d²)) ($A_1^{-1}x_1$)
• mul(; d, 1, 1, (2 + 4d + 3d², 2 + 5d + 3d²), (0, 1), (2 + 4d + 3d², 2 + 5d + 3d²)) ($A_1^{-1}x_1 y_1$)
• mov(; 2, 1, (2 + 4d + 3d², 2 + 5d + 3d²), (2 + 4d + 3d², 2 + 5d + 3d²)) (move $w'$)
• mul(; 1, d, 1, (2 + 4d + 3d², 2 + 5d + 3d²), (1, 1 + d), (2 + 5d + 3d², 3 + 5d + 3d²)) ($w'^\top x_2$)

Note that, in contrast to Appendix A, we need $O(d^2)$ space to implement matrix multiplications; therefore, the overall required hidden size is $O(d^2)$. As in Theorem 1, generalizing this construction to multiple iterations requires at least O(n) layers, as we repeat the process for each subsequent exemplar.
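The construction applies the Sherman-Morrison formula (Sherman & Morrison, 1950) with $A_0 = \lambda I$ and $u = v = x_1$. The NumPy sketch below checks that algebra outside the transformer. It applies one rank-one update per exemplar and multiplies by a running moment $X^\top y$; that multi-exemplar loop is our hypothetical extension of the single-exemplar chain above, and the function names are illustrative.

```python
import numpy as np

def sherman_morrison(A_inv, u, v):
    """Return (A + u v^T)^{-1}, given A^{-1}."""
    Au = A_inv @ u
    vA = v @ A_inv
    return A_inv - np.outer(Au, vA) / (1.0 + v @ Au)

def ridge_by_rank_one_updates(X, y, lam):
    """Ridge weights (X^T X + lam I)^{-1} X^T y, one rank-one update per exemplar."""
    d = X.shape[1]
    A_inv = np.eye(d) / lam                        # A_0^{-1} = I / lam
    moment = np.zeros(d)                           # running X^T y
    for x_i, y_i in zip(X, y):
        A_inv = sherman_morrison(A_inv, x_i, x_i)  # A_k = A_{k-1} + x_k x_k^T
        moment += x_i * y_i
    return A_inv @ moment
```

For a single exemplar, `sherman_morrison(np.eye(d) / lam, x1, x1) @ x1 * y1` is the $w' = A_1^{-1}x_1 y_1$ computed by the chain.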

C LEMMA 1

All of the operators mentioned in this lemma share a common computational structure, and can in
fact be implemented as special cases of a “base primitive” we call RAW (for Read-Arithmetic-Write).
This operator may also be useful for future work aimed at implementing other algorithms.
The structure of our proof of Lemma 1 is as follows:

1. Motivation of the base primitive RAW.
2. Formal definition of RAW.
3. Definition of dot, aff, mov in terms of RAW.
4. Implementation of RAW in terms of transformer parameters.
5. Brief discussion of how to parallelize RAW, making it possible to implement mul.
6. Separate proof for div by utilizing layer norm.

C.1 RAW OPERATOR: INTUITION

At a high level, all of the primitives in Lemma 1 involve a similar sequence of operations:

1) Operators read some hidden units from the current or previous timestep: dot and aff read from two subsets of indices in the current hidden state $h_t$ (see footnote 2), while mov reads from a previous hidden state $h_{t'}$. This selection is straightforwardly implemented using the attention component of a transformer layer.
We may notate this reading operation as follows:
$$\underbrace{\frac{W_a}{|K(i)|}\sum_{k \in K(i)} h_k^{(l)}[r]}_{\text{Read with Attention}} \qquad (22)$$

Here r denotes a list of indices to read from, and K denotes a map from current timesteps to target timesteps. For convenience, we use NumPy-like notation to denote indexing into a vector with another vector:
Definition C.1 (Bracket). x[·] denotes Python index notation: the resulting vector x' = x[r] has entries
$$x'_j = x_{r_j}, \qquad j = 1, \ldots, |r|.$$

The first step of our proof below shows that the attention output $a^{(l)}$ can compute the expression above.
2
For notational convenience, we use h to refer to the sequence of hidden states (instead of H in Eq. (1)); $h_{t'}$ denotes the hidden state at timestep t'.

2) Operators perform element-wise arithmetic between the quantity read in step 1 and another set of entries from the current timestep. This step takes different forms for aff and mul (mov ignores values at the current timestep altogether):
$$\underbrace{\frac{W_a}{|K(i)|}\sum_{k \in K(i)} h_k^{(l)}[r]}_{\text{Read with Attention}} \odot W h_i^{(l)}[s] \qquad \text{(multiplicative form)} \qquad (23)$$
$$\underbrace{\frac{W_a}{|K(i)|}\sum_{k \in K(i)} h_k^{(l)}[r]}_{\text{Read with Attention}} + W h_i^{(l)}[s] \qquad \text{(additive form)} \qquad (24)$$

The second step of the proof below computes these operations inside the MLP component of the
transformer layer.

3) Operators reduce, then write to the current hidden state. Once the underlying element-wise operation is calculated, the operator writes these values to some indices of the current hidden state, defined by a list of indices w. Writing may be preceded by a reduction step (e.g. for computing dot products), which can be expressed generically as a linear operator $W_o$. The final form of the computation is thus:
$$h_i^{(l+1)}[w] \leftarrow W_o\Bigg(\overbrace{\underbrace{\frac{W_a}{|K(i)|}\sum_{k \in K(i)} h_k^{(l)}[r]}_{\text{Read with Attention}} \star W h_i^{(l)}[s]}^{\text{Elementwise operation}}\Bigg) \qquad (25)$$

Here, ← means that the other indices $j \notin w$ are copied from $h^{(l)}$.

C.2 RAW OPERATOR DEFINITION

We denote this “master operator” as RAW:


Definition C.2. RAW(h; ⋆, s, r, w, $W_o$, $W_a$, W, K) is a function $\mathbb{R}^{H\times T} \to \mathbb{R}^{H\times T}$. It is parameterized by an elementwise operator ⋆ ∈ {+, ⊙}, three matrices $W \in \mathbb{R}^{d\times|s|}$, $W_a \in \mathbb{R}^{d\times|r|}$, $W_o \in \mathbb{R}^{|w|\times d}$, three index sets s, r, and w, and a timestep map $K: \mathbb{Z}^+ \to (\mathbb{Z}^+)^*$. Given an input matrix h, it outputs a matrix with entries:
$$h^{(l+1)}_{i,w} = W_o\Bigg(\frac{W_a}{|K(i)|}\sum_{k \in K(i)} h_k^{(l)}[r] \star W h_i^{(l)}[s]\Bigg), \qquad i = 1, \ldots, T; \qquad (26)$$
$$h^{(l+1)}_{i,j\notin w} = h^{(l)}_{i,j\notin w}, \qquad i = 1, \ldots, T. \qquad (27)$$
We additionally require that $j \in K(i) \implies j \le i$ (since self-attention is causal).

(For simplicity, we did not include possible bias terms in the linear projections $W_o$, $W_a$, W; we can always assume the accompanying bias parameters $b_o$, $b_a$, b when needed.)

C.3 REDUCING LEMMA 1 OPERATORS TO RAW OPERATOR

Given this operator, we can define each primitive in Lemma 1 using a single RAW operator, except mul and div. Instead of the matrix multiplication operator mul, we will first show the dot product dot (a special case of mul); later in the proof (Appendix C.5), we argue that these dot products can be parallelized to obtain mul. We show how to implement div separately in Appendix C.6.

Lemma 2. We can define the mov and aff operators, and the dot-product case of mul in Lemma 1, using a single RAW operator:
$$\begin{aligned}
\mathrm{dot}(h; (i,j), (i',j'), (i'',j'')) &= \mathrm{mul}(h; 1, |i-j|, 1, (i,j), (i',j'), (i'', i''+1)) \\
&= \mathrm{RAW}(h; \odot, W = I, W_a = I, W_o = \mathbf{1}^\top, s = (i,j), r = (i',j'), w = (i'', i''+1), K = \{(t,\{t\})\ \forall t\})
\end{aligned}$$
$$\begin{aligned}
\mathrm{aff}(h; (i,j), (i',j'), (i'',j''), W_1, W_2, b) = \mathrm{RAW}(h; +, &W = W_1, W_a = W_2, W_o = I, b_o = b, \\
&s = (i,j), r = (i',j'), w = (i'', j''), K = \{(t,\{t\})\ \forall t\})
\end{aligned}$$
$$\mathrm{mov}(h; s, t, (i,j), (i',j')) = \mathrm{RAW}(h; +, W = 0, W_a = I, W_o = I, s = (), r = (i',j'), w = (i,j), K = \{(t,\{s\})\})$$

Proof. Follows immediately by substituting parameters into Eq. (26).
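To make Definition C.2 and Lemma 2 concrete, here is a direct NumPy evaluation of RAW's input-output behavior (not its transformer implementation). Assumptions of this sketch: h is an (H, T) array whose columns are timesteps, index sets are Python lists, the string "*" stands for ⊙, and timesteps absent from the dict K are left unchanged; the function name `raw` is hypothetical.

```python
import numpy as np

def raw(h, op, s, r, w, W_o, W_a, W, K):
    """Evaluate Eqs. (26)-(27). op is "+" (additive) or "*" (elementwise product)."""
    out = h.copy()                                   # Eq. (27): rows outside w are copied
    for i, src in K.items():
        assert all(k <= i for k in src), "causal: sources cannot lie in the future"
        # Read with attention: uniform average over K(i), then project with W_a
        read = W_a @ h[np.ix_(r, src)].mean(axis=1) if src else np.zeros(W_a.shape[0])
        # Entries of the current timestep, projected with W
        local = W @ h[s, i] if len(s) else np.zeros(W_a.shape[0])
        elem = read + local if op == "+" else read * local
        out[w, i] = W_o @ elem                       # reduce with W_o, write to rows w
    return out
```

For example, the dot case of Lemma 2 uses `op="*"`, `W = W_a = I`, `W_o = np.ones((1, d))`, and `K = {t: [t] for t in range(T)}`, while mov uses `op="+"`, an empty `s`, and `K = {t: [s]}` for a single destination timestep.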

C.4 IMPLEMENTING RAW

It remains only to show:


Lemma 3. A single transformer layer can implement the RAW operator: there exist settings of transformer parameters such that, given an arbitrary hidden matrix h as input, the transformer computes h' (Eq. (26)) as output.

Our proof proceeds in stages. We begin by specifying the initial embedding and positional embedding layers, which construct inputs to the main transformer layer with the necessary positional information and scratch space. Next, we prove three useful procedures for bypassing (or exploiting) non-linearities in the feed-forward component of the transformer. Finally, we provide values for the remaining parameters, showing that we can implement the Elementwise and Reduction steps described above.

C.4.1 EMBEDDING LAYERS


Embedding Layer for Initialization: Rather than inserting the input matrix h directly into the
transformer layer, we assume (as is standard) the existence of a linear embedding layer. We can
set this layer to pad the input, providing extra scratch space that will be used by later steps of our
implementation.
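We define the embedding matrix $W_e$ as:
$$W_e = \begin{bmatrix} I^{(d+1)\times(d+1)} & 0 \\ 0 & 0 \end{bmatrix} \qquad (28)$$
Then, the embedded inputs will be
$$\tilde{x}_i = W_e x_i = [0, x_i, 0_{H-d-1}]^\top \qquad (29)$$
$$\tilde{y}_i = W_e y_i = [y_i, 0_{H-1}]^\top \qquad (30)$$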

Position Embeddings for Attention Manipulation: Implementing RAW ultimately requires controlling which position attends to which position in each layer. For example, we may wish to have layers in which each position attends to the previous position only, or in which even positions attend to other even positions. We can utilize position embeddings, $p_i$, to control attention weights. In a standard transformer, the position embedding matrix is a constant matrix that is added to the inputs of the transformer after the embedding layer (before the first layer), so the actual input to the transformer is:
$$h^0_i = \tilde{x}_i + p_i \qquad (31)$$
We will use these position embeddings to encode the timestep map K. To do this, we will use 2p units per layer (p will be defined momentarily): p units will be used to encode attention keys $k_i$, and the other p will be used to encode queries $q_i$.
We define the position embedding matrix as follows:
$$p_i = [0_{d+1}, k_i^{(0)}, q_i^{(0)}, \ldots, k_i^{(L)}, q_i^{(L)}, 0_{H-2pT-1}]^\top \qquad (32)$$

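With K encoded in positional embeddings, the transformer matrices $W_Q$ and $W_K$ are easy to define: they just need to retrieve the corresponding embedding values. For layer l, both are zero except for a single p × p identity block that selects that layer's key or query slot in Eq. (32):
$$W^l_K = \begin{bmatrix} 0 & \cdots & 0 & I^{p\times p} & 0^{p\times p} & 0 & \cdots & 0 \end{bmatrix}, \qquad W^l_Q = \begin{bmatrix} 0 & \cdots & 0 & 0^{p\times p} & I^{p\times p} & 0 & \cdots & 0 \end{bmatrix} \qquad (33)$$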

The constructions used in this paper rely on two specific timestep maps K, each of which can be
implemented compactly in terms of k and q:

Case 1: Attend to previous token. This can be constructed by setting:
$$k_i = e_i, \qquad q_i = N e_{i-1},$$
where N is a sufficiently large number. In this case, the output of the attention mechanism will be:
$$\begin{aligned}
\alpha &= \mathrm{softmax}\big((W^Q_j h_i)^\top (W^K_j h_{:i})\big) \\
&= \mathrm{softmax}\big(q_i^\top [k_1, \ldots, k_i]\big) \\
&= \mathrm{softmax}([0, \ldots, N, 0]) \\
&\approx [0, \ldots, \underbrace{1}_{(i-1)}, 0]
\end{aligned}$$

Case 2: Attend to a single token. For simpler patterns, such as attention to a specific token:
$$K(i) = \begin{cases} \{t\} & i \ge t \\ \{\} & i < t \end{cases} \qquad (34)$$
only 1 hidden unit is required. We set:
$$k_i = \begin{cases} -N & i \ne t \\ N & i = t \end{cases}, \qquad q_i = N,$$
from which it can be verified (using the same procedure as in Case 1) that the desired attention pattern is produced.
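Both attention patterns can be checked numerically. This sketch uses 0-indexed positions and plugs the key and query choices of Cases 1 and 2 directly into a softmax over the causal prefix; the default values of N are arbitrary illustrative choices, and the function names are hypothetical.

```python
import numpy as np

def softmax(z):
    z = z - z.max()                   # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def case1_weights(i, T, N=50.0):
    """Case 1 (attend to previous token): k_j = e_j, q_i = N e_{i-1}; requires i >= 1."""
    E = np.eye(T)
    keys = E[: i + 1]                 # k_0, ..., k_i (causal prefix)
    q = N * E[i - 1]
    return softmax(keys @ q)

def case2_weights(i, t, N=10.0):
    """Case 2 (attend to token t, for i >= t): scalar keys +-N, scalar query N."""
    keys = np.where(np.arange(i + 1) == t, N, -N)
    return softmax(N * keys)
```

Larger N sharpens both patterns toward exact one-hot attention, which is the sense in which the construction is an approximation.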

Intricacy: How can K be empty? We can also cause K(i) to attend to an empty set by assuming the softmax has an extra ("imaginary") timestep, obtained by prepending a 0 to the attention vector post hoc (Chen et al., 2020).
Cumulatively, the parameter matrices defined in this subsection implement the Read with Attention
component of the RAW operator.

C.4.2 HANDLING & UTILIZING NONLINEARITIES


The mul operator requires elementwise multiplication of quantities stored in hidden states. While
transformers are often thought of as only straightforwardly implementing affine transformations on
hidden vectors, their nonlinearities in fact allow elementwise multiplication to a high degree of
approximation. We begin by observing the following property of the GeLU activation function in
the MLP layers of the Transformer network:
Lemma 4. The GeLU nonlinearity can be used to perform multiplication: specifically,
$$\sqrt{\frac{\pi}{2}}\,\big(\mathrm{GeLU}(x + y) - \mathrm{GeLU}(x) - \mathrm{GeLU}(y)\big) = xy + O(x^3 + y^3) \qquad (35)$$

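Proof. A standard implementation of the GeLU nonlinearity is defined as follows:
$$\mathrm{GeLU}(x) = \frac{x}{2}\left(1 + \tanh\left(\sqrt{\frac{2}{\pi}}\big(x + 0.044715x^3\big)\right)\right). \qquad (36)$$
Thus
$$\mathrm{GeLU}(x) = \frac{x}{2} + \frac{1}{2}\sqrt{\frac{2}{\pi}}\,x^2 + O(x^3) \qquad (37)$$
$$\mathrm{GeLU}(x + y) - \mathrm{GeLU}(x) - \mathrm{GeLU}(y) = \sqrt{\frac{2}{\pi}}\,xy + O(x^3 + y^3) \qquad (38)$$
$$\implies xy \approx \sqrt{\frac{\pi}{2}}\,\big(\mathrm{GeLU}(x + y) - \mathrm{GeLU}(x) - \mathrm{GeLU}(y)\big) \qquad (39)$$
For small x and y, the third-order term becomes negligible. By scaling inputs down by a constant before the GeLU layer, and scaling them up afterwards, models may use the GeLU operator to perform elementwise multiplication.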
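A small numerical check of Lemma 4, using the tanh form of GeLU in Eq. (36). The parameter `scale` is a hypothetical knob illustrating the "scale down, then scale up" trick: both inputs are multiplied by `scale`, and the result is divided by `scale**2` because the product is quadratic in the inputs.

```python
import math

def gelu(x):
    # tanh approximation of GeLU, Eq. (36)
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def gelu_multiply(x, y, scale=1e-3):
    """Approximate x * y via Eq. (39) applied to inputs shrunk by `scale`."""
    a, b = x * scale, y * scale
    approx = math.sqrt(math.pi / 2.0) * (gelu(a + b) - gelu(a) - gelu(b))
    return approx / scale ** 2       # undo the scaling of both factors
```

With `scale=1.0`, the higher-order remainder dominates and the estimate for 3 × (−2) is far from −6, which is why the inputs are shrunk first.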

We can generalize this proof to other smooth functions as we discussed further in [TODO REF].
Previous work also shows that, in practice, transformers with ReLU activations exploit their nonlinearities to implement multiplication in other settings.
When implementing the aff operator, we have the opposite problem: we would like the output of addition to be transmitted without nonlinearities to the output of the transformer layer. Fortunately, for large inputs, the GeLU nonlinearity is very close to linear; to bypass it, it suffices to add a large N to the inputs:
Lemma 5. The GeLU nonlinearity can be bypassed: specifically,
$$\mathrm{GeLU}(N + x) - N \approx x, \qquad N \gg 1 \qquad (40)$$

Proof.
$$\mathrm{GeLU}(N + x) - N = \frac{N + x}{2}\left(1 + \tanh\left(\sqrt{\frac{2}{\pi}}\big(N + x + 0.044715(N + x)^3\big)\right)\right) - N \qquad (41)$$
$$\approx \frac{N + x}{2}(1 + 1) - N \qquad (42)$$
$$= x \qquad (43)$$

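For all versions of the RAW operator, it is additionally necessary to bypass the LayerNorm operation. The following formula will be helpful for this:
Lemma 6. Let N be a large number and λ the LayerNorm function over L entries. Then the following approximation holds:
$$\sqrt{\frac{2}{L}}\,N\,\lambda\Big(\big[x, N, -N - \textstyle\sum x, 0\big]\Big) \approx \big[x, N, -N - \textstyle\sum x, 0\big], \qquad N \gg 1 \qquad (44)$$

Proof. Writing $v = [x, N, -N - \sum x, 0]$,
$$\mathbb{E}[v] = 0 \qquad (45)$$
$$\mathrm{Var}[v] \approx \frac{1}{L}\big(N^2 + N^2 + x^2\big) \approx \frac{2N^2}{L} \qquad (46)$$
Then,
$$\sqrt{\frac{2}{L}}\,N\,\lambda(v) \approx \sqrt{\frac{2}{L}}\,N\,\sqrt{\frac{L}{2N^2}}\,v = v = \big[x, N, -N - \textstyle\sum x, 0\big].$$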

By writing N and $-N - \sum x$ into two padding locations, where $\sum x$ is the sum of the part of the hidden state that we want to pass through LayerNorm, we make x pass through LayerNorm unchanged up to a known scale. This addition can be done in the transformer's feed-forward computation (with parameter $W^F$) prior to LayerNorm. The multiplication by $\sqrt{2/L}\,N$ can then be done in the first layer of the MLP, after which the linear layer can output or use x. For convenience, we will henceforth omit the LayerNorm operation when it is not needed.
We may make each of these operations as precise as desired (or allowed by system precision). With
them defined, we are ready to specify the final components of the RAW operator.
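Lemmas 5 and 6 can be sanity-checked the same way. The sketch below assumes a LayerNorm without learned scale and shift, and pads x with N, $-N - \sum x$, and a few zeros as in Lemma 6; the values of N and the padding size are arbitrary illustrative choices.

```python
import math
import numpy as np

def gelu(x):
    # tanh approximation of GeLU, Eq. (36)
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def gelu_bypass(x, N=1e3):
    """Lemma 5: for large N, GeLU(N + x) - N is approximately x."""
    return gelu(N + x) - N

def layer_norm(v, eps=1e-5):
    # LayerNorm without learned scale/shift (an assumption of this sketch)
    return (v - v.mean()) / np.sqrt(v.var() + eps)

def layer_norm_bypass(x, N=1e4, n_pad=3):
    """Lemma 6: pad x with N, -N - sum(x) and zeros, normalize, rescale by sqrt(2/L) N."""
    v = np.concatenate([x, [N, -N - x.sum()], np.zeros(n_pad)])
    L = v.size
    return (math.sqrt(2.0 / L) * N * layer_norm(v))[: x.size]
```

The approximation error of the LayerNorm bypass shrinks roughly like $\sum x / N$, so a larger N trades numerical range for precision.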

C.4.3 PARAMETERIZING RAW


We want to show that a transformer layer as defined above, parameterized by θ = {$W^F$, $W_1$, $W_2$, $(W^Q, W^K, W^V)_m$}, can closely approximate the RAW operator defined in Eq. (25). We will provide step-by-step constructions and define the parameters in θ. Begin by recalling the transformer layer definition:
$$\alpha = \mathrm{softmax}\big((W^Q_j h_i)^\top (W^K_j H_{:i})\big) \qquad (48)$$
$$b_j = \alpha\,(W^V_j H_{:i}) \qquad (49)$$
$$a_i = W^F [b_1, \ldots, b_m] \qquad (50)$$
$$h^{(l+1)} = \mathrm{FF}(a; W_1, W_2) \qquad (51)$$
$$= W_1\,\sigma\big(W_2\,\lambda(a + h^{(l)})\big) + a + h^{(l)}. \qquad (52)$$

Attention Output We will only use m = 2 attention heads for this construction. We show in Eq. (32) that we can make the attention attend uniformly according to a pattern by setting the key and query matrices. Assume that the first head's parameters $W_1^Q, W_1^K$ have been set in the described way to obtain the pattern function K. Now we set the remaining attention parameters $W_1^V, W_2^Q, W_2^K, W_2^V$ and show that we can make the $a_i + h_i^{(l)}$ term in Eq. (4) contain the corresponding term in Eq. (25), in some unused indices t, such that:
$$(a_i + h_i^{(l)})_t = \frac{W_a}{|K(i)|}\sum_{k \in K(i)} h_k^{(l)}[r] \tag{53}$$
$$(a_i + h_i^{(l)})_{t' \notin t} = (h_i^{(l)})_{t' \notin t} \tag{54}$$

Then the term in the RAW operator can be obtained from the first head's output. To achieve that, we set $W_a$ as part of the attention value network, such that $W_1^V$ is a sparse matrix that is 0 everywhere except:
$$(W_1^V)_{t[m],\, r[n]} = (W_a)_{m,n}, \quad m \in 1, \ldots, |t|,\ n \in 1, \ldots, |r| \tag{55}$$

Now our first head stores the right-hand side of Eq. (53) in the indices t. However, adding the residual term $h_i^{(l)}$ would change this. To remove the residual term, we use another head to output $h_i^{(l)}$, by setting $W_2^Q, W_2^K$ such that K(i) = i, and setting $W_2^V$ (similar to Eq. (34)) to copy the t entries:
$$(W_2^V)_{t[m],\, t[m]} = 1, \quad m \in 1, \ldots, |t| \tag{56}$$
Then $W^f \in \mathbb{R}^{H \times 2H}$ is zero everywhere except:
$$(W^f)_{t[m],\, t[m]} = 1, \quad m \in 1, \ldots, |t| \tag{57}$$
$$(W^f)_{t[m],\, t[m]+H} = -1, \quad m \in 1, \ldots, |t| \tag{58}$$
so that the second head's copy of $h_i^{(l)}[t]$ is subtracted and cancels the residual. We have now defined $(W^Q, W^K, W^V)_{1,2}$ and $W^f$, and obtained the first term of Eq. (25) in $(a_i + h_i^{(l)})_{t}$.
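The two-head construction of Eqs. (53)-(58) can be sketched in plain Python as follows. This is a minimal sketch under two assumptions: the query/key matrices already produce uniform attention over K(i) for head 1 and attention only to position i for head 2 (so the softmax is replaced by plain averaging), and head 2's value matrix copies $h_i[t]$ with a +1 entry so that the −1 in $W^f$ cancels the residual. The function name `attention_step` is hypothetical.

```python
def attention_step(Hs, i, K, W_a, r, t):
    """Return a_i + h_i^{(l)} for the two-head construction.

    Hs: list of hidden-state vectors h_1..h_i (all of size H)
    K:  positions attended to uniformly by head 1
    W_a: |t| x |r| matrix; r, t: index blocks of the hidden state
    """
    H = len(Hs[i])
    # Head 1: value matrix is zero except (W1V)[t[m]][r[n]] = W_a[m][n]  (Eq. (55))
    b1 = [0.0] * H
    for k in K:
        for m, tm in enumerate(t):
            b1[tm] += sum(W_a[m][n] * Hs[k][rn] for n, rn in enumerate(r)) / len(K)
    # Head 2: attends only to position i and copies h_i[t]  (Eq. (56))
    b2 = [0.0] * H
    for tm in t:
        b2[tm] = Hs[i][tm]
    # Output projection W^f: +head 1, -head 2 on the t block, zero elsewhere  (Eqs. (57)-(58))
    a = [0.0] * H
    for tm in t:
        a[tm] = b1[tm] - b2[tm]
    # Residual connection
    return [ai + hi for ai, hi in zip(a, Hs[i])]

Hs = [[1.0, 2.0, 5.0, 0.0], [3.0, 4.0, -1.0, 0.0], [0.0, 0.0, 7.0, 9.0]]
out = attention_step(Hs, i=2, K=[0, 1], W_a=[[1.0, -1.0]], r=[0, 1], t=[2])
print(out)  # prints [0.0, 0.0, -1.0, 9.0]
```

Position t = 2 now holds $W_a$ applied to the average of $h_k[r]$ over K(i), i.e. 2 − 3 = −1, while every other entry equals $h_i$, matching Eqs. (53)-(54).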


Arithmetic term Now we want to calculate the term inside the parentheses of Eq. (25). We will calculate it through the MLP layer, store it in $m_i$, and subtract the first term. Let us denote the input to the MLP as $x_i = (a_i + h_i^{(l)})$, the output of the first layer as $u_i$, the output of the non-linearity as $v_i$, and the final output as $m_i$. The entries of $m_i$ will be:
$$(m_i)_{t' \in w} = W_o\Big(\frac{W_a}{|K(i)|}\sum_{k \in K(i)} h_k^{(l)}[r] \star W h_i^{(l)}[s]\Big) - x_i[w] \tag{60}$$
$$(m_i)_{t' \in t} = -x_i[t] \tag{61}$$
$$(m_i)_{t' \notin (t \cup w)} = 0 \tag{62}$$
We define the MLP layer to combine the attention term calculated above with a part of the current hidden state by choosing $W_1$ and $W_2$. We assume that we bypass the LayerNorm using Lemma 6. We show this separately for the + and ⊙ operators.

RAW(+, ·) If the operator is ⋆ = +, the first layer of the MLP will calculate the second term in Eq. (25), overwrite the space where the attention output term of Eq. (53) is written, and add a large positive bias term N to bypass the GeLU as explained in Lemma 4. We will use an available space $\hat t$ in $x_i$ of the same size as t.
$$(u_i)_{t' \in \hat t} = W h_i^{(l)}[s] + x_i[t] + N \tag{63}$$
$$(u_i)_{t' \in t} = -x_i[t] + N \tag{64}$$
$$(u_i)_{t' \in w} = -x_i[w] + N \tag{65}$$
$$(u_i)_{t' \notin (t \cup \hat t \cup w)} = -N \tag{66}$$
This can be done by setting $W_1$ (the weight matrix of the first MLP layer) to zero except at the indices below:
$$(W_1)_{\hat t[m],\, s[n]} = (W)_{m,n}, \quad m \in 1, \ldots, |\hat t|,\ n \in 1, \ldots, |s| \tag{67}$$
$$(W_1)_{\hat t[m],\, t[m]} = +1, \quad m \in 1, \ldots, |t| \tag{68}$$
$$(W_1)_{t[m],\, t[m]} = -1, \quad m \in 1, \ldots, |t| \tag{69}$$
$$(W_1)_{w[m],\, w[m]} = -1, \quad m \in 1, \ldots, |w| \tag{70}$$
and the bias vector $b_1$ to
$$(b_1)_{t' \in t} = N \tag{73}$$
$$(b_1)_{t' \in w} = N \tag{74}$$
$$(b_1)_{t' \in \hat t} = N \tag{75}$$
$$(b_1)_{t' \notin (t \cup w \cup \hat t)} = -N \tag{76}$$
Note that the −N bias makes the unused indices outside $t \cup w \cup \hat t$ become zero after the GeLU, which outputs approximately zero for large negative inputs. Since we added a large positive term to the used indices, the GeLU behaves like the identity there. Thus we have,
$$(v_i)_{t' \in \hat t} = W h_i^{(l)}[s] + x_i[t] + N \tag{78}$$
$$(v_i)_{t' \in t} = -x_i[t] + N \tag{79}$$
$$(v_i)_{t' \in w} = -x_i[w] + N \tag{80}$$
$$(v_i)_{t' \notin (t \cup w \cup \hat t)} = 0 \tag{81}$$
Now we set $W_2$ to simulate $W_o \in \mathbb{R}^{|w| \times |t|}$:
$$(W_2)_{w[m],\, \hat t[n]} = (W_o)_{m,n}, \quad m \in 1, \ldots, |w|,\ n \in 1, \ldots, |t| \tag{82}$$
$$(W_2)_{t[m],\, t[m]} = +1, \quad m \in 1, \ldots, |t| \tag{83}$$
$$(W_2)_{w[m],\, w[m]} = +1, \quad m \in 1, \ldots, |w| \tag{84}$$


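and the bias vector $b_2$ to
$$(b_2)_{w[m]} = -N\sum_j (W_o)_{m,j} - N, \quad m \in 1, \ldots, |w| \tag{86}$$
$$(b_2)_{t[m]} = -N, \quad m \in 1, \ldots, |t| \tag{87}$$
$$(b_2)_{t' \notin (t \cup w)} = 0 \tag{88}$$
Therefore, $m_i[w] = W_o x_i[t] + W_o W h_i^{(l)}[s] - x_i[w]$, which equals what we promised in Eq. (60) for the + case, and $m_i[t] = -x_i[t]$ as in Eq. (61). Summing this with the residual term $x_i$ (Eq. (53)), the output of this layer can be written as:
$$(h_i^{(l+1)})_{t' \in w} = W_o\Big(\frac{W_a}{|K(i)|}\sum_{k \in K(i)} h_k^{(l)}[r] + W h_i^{(l)}[s]\Big) \tag{90}$$
$$(h_i^{(l+1)})_{t' \notin w} = (h_i^{(l)})_{t' \notin w} \tag{91}$$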

RAW(⊙, ·) If the operator is ⋆ = ⊙, we need three extra hidden blocks, each of size |t|; let us name the extra indices $t_a, t_b, t_c$, and the output space w. The entries of $u_i$ below let us use the GeLU multiplication trick of Lemma 4, where N is a large number:
$$(u_i)_{t' \in t_a} = (W h_i^{(l)}[s] + x_i[t])/N \tag{92}$$
$$(u_i)_{t' \in t_b} = x_i[t]/N \tag{93}$$
$$(u_i)_{t' \in t_c} = W h_i^{(l)}[s]/N \tag{94}$$
$$(u_i)_{t' \in t} = -x_i[t] + N \tag{95}$$
$$(u_i)_{t' \in w} = -x_i[w] + N \tag{96}$$
$$(u_i)_{t' \notin (t \cup t_a \cup t_b \cup t_c \cup w)} = -N \tag{97}$$

All of these operations are linear and can be done by setting $W_1$ to zero except the entries below:
$$(W_1)_{t_a[m],\, s[n]} = (W)_{m,n}/N, \quad m \in 1, \ldots, |t_a|,\ n \in 1, \ldots, |s| \tag{99}$$
$$(W_1)_{t_a[m],\, t[m]} = 1/N, \quad m \in 1, \ldots, |t| \tag{100}$$
$$(W_1)_{t_b[m],\, t[m]} = 1/N, \quad m \in 1, \ldots, |t| \tag{101}$$
$$(W_1)_{t_c[m],\, s[n]} = (W)_{m,n}/N, \quad m \in 1, \ldots, |t_c|,\ n \in 1, \ldots, |s| \tag{102}$$
$$(W_1)_{w[m],\, w[m]} = -1, \quad m \in 1, \ldots, |w| \tag{103}$$
$$(W_1)_{t[m],\, t[m]} = -1, \quad m \in 1, \ldots, |t| \tag{104}$$
and $b_1$ to:
$$(b_1)_{t' \in (t_a \cup t_b \cup t_c)} = 0 \tag{106}$$
$$(b_1)_{t' \in (t \cup w)} = N \tag{107}$$
$$(b_1)_{t' \notin (t \cup t_a \cup t_b \cup t_c \cup w)} = -N \tag{108}$$
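The resulting $v_i$, with the approximations, becomes:
$$(v_i)_{t' \in t_a} = \mathrm{GeLU}\big((W h_i^{(l)}[s] + x_i[t])/N\big) \tag{110}$$
$$(v_i)_{t' \in t_b} = \mathrm{GeLU}(x_i[t]/N) \tag{111}$$
$$(v_i)_{t' \in t_c} = \mathrm{GeLU}(W h_i^{(l)}[s]/N) \tag{112}$$
$$(v_i)_{t' \in t} = -x_i[t] + N \tag{113}$$
$$(v_i)_{t' \in w} = -x_i[w] + N \tag{114}$$
$$(v_i)_{t' \notin (t \cup t_a \cup t_b \cup t_c \cup w)} = 0 \tag{115}$$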


Now we can use the GeLU trick in Lemma 4 by setting $W_2$:
$$(W_2)_{w[m],\, t_a[n]} = (W_o)_{m,n}\,N^2\sqrt{\tfrac{\pi}{2}}, \quad m \in 1, \ldots, |w|,\ n \in 1, \ldots, |t_a| \tag{117}$$
$$(W_2)_{w[m],\, t_b[n]} = -(W_o)_{m,n}\,N^2\sqrt{\tfrac{\pi}{2}}, \quad m \in 1, \ldots, |w|,\ n \in 1, \ldots, |t_b| \tag{118}$$
$$(W_2)_{w[m],\, t_c[n]} = -(W_o)_{m,n}\,N^2\sqrt{\tfrac{\pi}{2}}, \quad m \in 1, \ldots, |w|,\ n \in 1, \ldots, |t_c| \tag{119}$$
$$(W_2)_{w[m],\, w[m]} = 1, \quad m \in 1, \ldots, |w| \tag{120}$$
$$(W_2)_{t[m],\, t[m]} = 1, \quad m \in 1, \ldots, |t| \tag{121}$$
We then set $b_2$:
$$(b_2)_{t' \in (t \cup w)} = -N \tag{123}$$
$$(b_2)_{t' \notin (t \cup w)} = 0 \tag{124}$$

With this, $m_i[w] = W_o\big(x_i[t] \odot W h_i^{(l)}[s]\big) - x_i[w]$, and
$$(h_i^{(l+1)})_{t' \in w} = W_o\Big(\frac{W_a}{|K(i)|}\sum_{k \in K(i)} h_k^{(l)}[r] \odot W h_i^{(l)}[s]\Big) \tag{126}$$
$$(h_i^{(l+1)})_{t' \notin w} = (h_i^{(l)})_{t' \notin w} \tag{127}$$
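The core of the ⊙ case is the three-unit GeLU product of Eqs. (110)-(112) and (117)-(119). A minimal sketch for a single coordinate, assuming $W_o$ is the identity and N = 1000; `gelu_product` is a hypothetical name. The linear parts of the three GeLUs cancel, and scaling the remaining quadratic part by $N^2\sqrt{\pi/2}$ recovers the product.

```python
import math

def gelu(z):
    """Exact GeLU: z * Phi(z)."""
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))

def gelu_product(a, b, N=1e3):
    """Approximate a * b with the hidden units t_a, t_b, t_c and the output weights
    of Eqs. (117)-(119); W_o is taken as the identity here."""
    va = gelu((a + b) / N)   # unit in t_a
    vb = gelu(a / N)         # unit in t_b
    vc = gelu(b / N)         # unit in t_c
    return N ** 2 * math.sqrt(math.pi / 2) * (va - vb - vc)

x_t = [3.0, 0.5, -2.0]       # attention term x_i[t]
Wh_s = [-2.0, 0.5, 4.0]      # W h_i^{(l)}[s]
print([round(gelu_product(a, b), 4) for a, b in zip(x_t, Wh_s)])
```

Because GeLU has no cubic Taylor term, the error is of fourth order in the scaled inputs, so it shrinks like $1/N^2$ relative to the product.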

We have used 4|t| space for the internal computation of this operation and |w| space to write the final result. This shows that the RAW operator can be implemented by setting the parameters of a single Transformer layer.

C.5 PARALLELIZING THE RAW OPERATOR

Lemma 7. If K is constant, the operators are independent (i.e., $(r_i \cup s_i \cup w_i) \cap w_{j \neq i} = \emptyset$), and there is $\sum_k (4|t_k| + |w_k|)$ available space in the hidden state, then a Transformer layer can apply k such RAW operations in parallel by setting different regions of the $W_1$, $W_2$, $W^f$ and $(W^V)_k$ matrices.

Proof. From the construction above, it is straightforward to modify the definition of the RAW operator to perform k operations, since under the conditions of the lemma none of the matrix indices used in Appendix C.4.3 overlap across operations.

This makes it possible to construct a Transformer layer that implements not only vector-vector dot products but also general matrix-matrix products, as required by mul. With this, we show that mul can be implemented using a single layer of a Transformer.

C.6 LAYERNORM FOR DIVISION

Let us say we have the input $[c, y, 0]^\top$ calculated before the attention output in Eq. (53), and we want to divide y by c. This trick is very similar to the one in Lemma 6. We can use the following formula:
Lemma 8 (Using LayerNorm for division). Let N, M be large numbers and λ the LayerNorm function. Then the following approximation holds:
$$\sqrt{\tfrac{2}{L}}\,MN\,\lambda\Big(\Big[Nc,\ \frac{y}{M},\ -Nc - \sum\frac{y}{M},\ 0\Big]\Big) \approx \Big[MN,\ \frac{y}{c},\ -MN - \sum\frac{y}{c},\ 0\Big] \tag{128}$$


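Proof.
$$\mathbb{E}[x] = 0 \tag{129}$$
$$\mathrm{Var}[x] = \frac{1}{L}\Big(N^2c^2 + \sum\Big(\frac{y}{M}\Big)^2 + \Big(Nc + \sum\frac{y}{M}\Big)^2\Big) \approx \frac{2N^2c^2}{L} \tag{130}$$
Then, for c > 0,
$$\sqrt{\tfrac{2}{L}}\,MN\,\lambda\Big(\Big[Nc,\ \frac{y}{M},\ -Nc - \sum\frac{y}{M},\ 0\Big]\Big) \approx \sqrt{\tfrac{2}{L}}\,MN\Big[\sqrt{\tfrac{L}{2}},\ \sqrt{\tfrac{L}{2}}\,\frac{y}{MNc},\ -\sqrt{\tfrac{L}{2}} - \sqrt{\tfrac{L}{2}}\sum\frac{y}{MNc},\ 0\Big] = \Big[MN,\ \underbrace{\frac{y}{c}}_{\text{wanted result}},\ -MN - \sum\frac{y}{c},\ 0\Big]$$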

To get the input into the format used in this lemma, we can use $W^f$ to convert the head outputs. Then, after the LayerNorm, we can use $W_1$ to pull $y/c$ back and write it to the attention output. In this way, we can approximate scalar division in one layer.
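A minimal numerical sketch of Lemma 8, assuming a LayerNorm without learned gain or bias and with ε = 0, hidden size L = 16, and N = M = 1000; `layernorm_divide` is a hypothetical name. It also assumes c > 0: LayerNorm divides by a standard deviation proportional to |c|, so a negative c would flip the sign of the result.

```python
import math

def layer_norm(v):
    """LayerNorm without learned gain/bias and with epsilon = 0."""
    L = len(v)
    mean = sum(v) / L
    var = sum((a - mean) ** 2 for a in v) / L
    return [(a - mean) / math.sqrt(var) for a in v]

def layernorm_divide(y, c, N=1e3, M=1e3, L=16):
    """Approximate y / c (elementwise over y, for c > 0) via Eq. (128)."""
    v = [N * c] + [yi / M for yi in y] + [-N * c - sum(y) / M]
    v += [0.0] * (L - len(v))                     # zero padding up to hidden size L
    out = [math.sqrt(2.0 / L) * M * N * a for a in layer_norm(v)]
    return out[1:1 + len(y)]                      # the y/c block ("wanted result")

print([round(q, 4) for q in layernorm_divide([3.0, -1.5], 2.0)])
```

The relative error is of order $\sum y/(MNc)$, which is why both N and M need to be large.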

Proof of Lemma 1. By Lemmas 2, 3, 7 and 8, we constructed the operators in Lemma 1 using a single layer of a Transformer, which proves Lemma 1.

D DETAILS OF TRANSFORMER ARCHITECTURE AND TRAINING


We perform these experiments using the JAX framework on P100 GPUs. The major hyperparameters used in these experiments are presented in Table 1. The code repository used for reproducing these experiments will be open sourced at the time of publication. Most of the hyperparameters are adapted from previous work (Garg et al., 2022) for compatibility, and we adapted the Transformer architecture details. We use the Adam optimizer with a cosine learning rate scheduler with warmup, where the number of warmup steps is set to 1/5 of the total iterations. We use learned absolute position embeddings.
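The learning-rate schedule described above can be sketched as a plain function of the step. The text only states cosine decay with warmup over 1/5 of the iterations; the linear warmup shape and the decay to exactly zero are our assumptions, and `lr_schedule` is a hypothetical name.

```python
import math

def lr_schedule(step, total_steps, lr_init=1e-4, warmup_frac=0.2):
    """Linear warmup over the first 20% of steps, then cosine decay to zero (assumed shape)."""
    warmup = int(warmup_frac * total_steps)
    if step < warmup:
        return lr_init * step / warmup
    progress = (step - warmup) / max(1, total_steps - warmup)
    return 0.5 * lr_init * (1.0 + math.cos(math.pi * min(progress, 1.0)))

for step in (0, 100, 200, 600, 1000):
    print(step, lr_schedule(step, total_steps=1000))
```

The peak value is the initial learning rate $lr_i$ from Table 1 (1e-4 or 2.5e-4), reached at the end of warmup.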

Parameter                           Search Range
Number of heads                     1, 2, 4, 8
Number of layers                    1, 2, 12, 16
Hidden size                         16, 32, 64, 256, 512, 1024
Batch size                          64
Maximum number of epochs            500,000
Initial learning rate (lr_i)        1e-4, 2.5e-4
Weight decay                        0, 1e-5
Bias initialization                 uniform scaling, normal(1.0)
Weight initialization               uniform scaling, normal(1.0)
Position embedding initialization   uniform scaling, normal(1.0)

Table 1: Hyperparameters used in the ICL experiments. The best value for each hyperparameter is highlighted.

In the phase shift plots in Fig. 3, we keep the value on the x-axis constant and use the best setting over the parameters {number of layers, hidden size, number of heads, learning rate}.

E DETAILS OF PROBE
We will use the terms probe model and task model to distinguish the probe from the ICL model. Our probe is defined as:
$$\alpha = \mathrm{softmax}(s_v) \tag{132}$$
$$\hat v = \mathrm{FF}_v(\alpha^\top W_v h) \tag{133}$$


[Figure 5 panels: $X^\top Y$ probe error (MSE) vs. layer, for linear and MLP probes; Layer-8's attention heatmap over time (target index vs. position); $w_{\mathrm{OLS}}$ probe errors (MSE) vs. layer, for linear and MLP probes; Layer-12's attention heatmap over time (target index vs. position).]
Figure 5: Detailed error values of the control probe displayed in Fig. 4.

[Figure 6 plot: R² (0 to 1) vs. number of exemplars (1 to 12).]

Figure 6: R² of linear weight estimation on the d = 8 problem.

The position scores $s_v \in \mathbb{R}^T$ are learned parameters, where T is the maximum input sequence length (T = 80 in our experiments). The softmax of the position scores gives attention weights α over positions for each target variable. This enables us to learn input-independent, optimal target locations for each target (displayed on the right side of Fig. 4). We then average hidden states using these attention weights. A linear projection, $W_v \in \mathbb{R}^{T \times H'}$, is applied before averaging. FF is either a linear layer or a 2-layer MLP (hidden size = 512) with a GeLU activation function. For each layer, we train a separate probe with its own parameters using stochastic gradient descent; H' equals 512. The probe is trained using the Adam optimizer with a learning rate of 0.001 (chosen from {0.01, 0.001, 0.0001} on validation data).
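A minimal plain-Python sketch of the probe forward pass in Eqs. (132)-(133), for one target variable. We assume the projection maps each H-dimensional hidden state to H' dimensions before the α-weighted average, and pass FF in as a callable (a linear layer or MLP in the paper; `sum` below is only a stand-in). The names `softmax` and `probe` are hypothetical; training is omitted.

```python
import math

def softmax(scores):
    mx = max(scores)                           # subtract max for numerical stability
    exps = [math.exp(s - mx) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def probe(hidden, s_v, W_v, ff):
    """v_hat = FF(alpha^T W_v h) for one target variable.

    hidden: T hidden-state vectors of size H (one per position)
    s_v:    T learned, input-independent position scores
    W_v:    H' x H projection applied to each hidden state (assumed shape)
    ff:     callable implementing the linear or MLP head FF_v
    """
    alpha = softmax(s_v)
    projected = [[sum(wr * hv for wr, hv in zip(row, h)) for row in W_v] for h in hidden]
    pooled = [sum(a * p[j] for a, p in zip(alpha, projected)) for j in range(len(W_v))]
    return ff(pooled)

I2 = [[1.0, 0.0], [0.0, 1.0]]
hidden = [[1.0, 0.0], [2.0, 0.0], [5.0, 0.0]]
print(probe(hidden, [0.0, 0.0, 50.0], I2, sum))  # attention concentrates on the last position
```

Because $s_v$ does not depend on the input, the learned α directly reveals which positions the probe reads each target from.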

Control Experiments In Fig. 4, dashed lines show probing results with a task model trained on a control task, in which w is always the all-ones vector 1. This problem structurally resembles our main experimental setup but does not require in-context learning. During probing, we feed this model data generated with w sampled from a normal distribution, as in the original task model. We observe that the control probe has a significantly higher error rate, showing that the probing accuracy obtained with the actual task model is non-trivial. We present detailed error values of the control probe in Fig. 5.

F LINEARITY OF ICL

In Fig. 1b, we compare the implicit linear weight of ICL against the linear algorithms using the ILWD measure. Note that this measure does not assume the predictors to be linear: when the predictors are not linear, ILWD measures the difference between the closest linear predictors (in the sense of Eq. (16)) to each algorithm.


To gain more insight into ICL's algorithm, we can measure how linear ICL is in different regimes of the linear problem (underdetermined, determined) by using the R² (coefficient of determination) measure. So, instead of asking what the best linear fit in Eq. (16) is, we can ask how good the linear fit is, which is the R² of the estimator. Interestingly, even though our model matches the min-norm least-squares solution in both metrics in Section 4.3, we show in Fig. 6 that ICL becomes linear only gradually in the underdetermined regime. This is an important result: it enables us to say that the in-context learner's hypothesis class is not purely linear.
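The R² measure can be sketched as follows: fit the best linear predictor $w^\top x$ to a model's outputs on a set of query inputs by least squares, then report $1 - \mathrm{SS}_{res}/\mathrm{SS}_{tot}$. This is a minimal plain-Python sketch; fitting without an intercept and solving the normal equations directly are our assumptions, and `solve` and `linear_r2` are hypothetical names.

```python
def solve(A, b):
    """Solve A z = b by Gaussian elimination with partial pivoting (A square, nonsingular)."""
    n = len(A)
    M = [row[:] + [bi] for row, bi in zip(A, b)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= f * M[col][c]
    z = [0.0] * n
    for r in range(n - 1, -1, -1):
        z[r] = (M[r][n] - sum(M[r][c] * z[c] for c in range(r + 1, n))) / M[r][r]
    return z

def linear_r2(X, y_hat):
    """R^2 of the best linear fit w^T x (no intercept) to a predictor's outputs y_hat."""
    d = len(X[0])
    XtX = [[sum(x[i] * x[j] for x in X) for j in range(d)] for i in range(d)]
    Xty = [sum(x[i] * y for x, y in zip(X, y_hat)) for i in range(d)]
    w = solve(XtX, Xty)
    resid = sum((y - sum(wi * xi for wi, xi in zip(w, x))) ** 2 for x, y in zip(X, y_hat))
    mean = sum(y_hat) / len(y_hat)
    total = sum((y - mean) ** 2 for y in y_hat)
    return 1.0 - resid / total

X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0], [-1.0, 3.0]]
print(linear_r2(X, [2 * a - b for a, b in X]))      # a linear predictor gives R^2 = 1
print(linear_r2(X, [a * a + b * b for a, b in X]))  # a non-linear predictor gives R^2 < 1
```

An R² well below 1 means no single weight vector explains the model's predictions, which is the sense in which the hypothesis class is "not purely linear".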

G MULTIPLICATIVE INTERACTIONS WITH OTHER NON-LINEARITIES


We can show that for a real-valued and smooth non-linearity f(x), we can apply the same trick as in the paper body. In particular, we can write the Taylor expansion as:
$$f(x) = \sum_{i=0}^{\infty} a_i x^i = a_0 + a_1 x + a_2 x^2 + \ldots \tag{134}$$
which converges in some sufficiently small neighborhood $X = [-\epsilon, \epsilon]$. First, assume that the second-order term $a_2$ dominates the higher-order terms in this domain, such that
$$a_2 x^2 \gg a_{i>2}\, x^i, \quad x \in X.$$
It is easy to verify that the following is true:
$$\frac{1}{2a_2}\big(f(x+y) - f(x) - f(y) + a_0\big) = xy + O(x^3 + y^3) \tag{135}$$
So, given the expansion for GeLU in Eq. (37), we can use this generic formula to obtain the multiplication approximation:
$$\sqrt{\tfrac{\pi}{2}}\,\big(\mathrm{GeLU}(x+y) - \mathrm{GeLU}(x) - \mathrm{GeLU}(y)\big) \approx xy \tag{136}$$
2

We plot this approximation against $x^2$ over the range $[-0.1, 0.1]$ in Fig. 7a.
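Eq. (135) can be checked for any smooth f once $a_0$ and $a_2$ are known. A minimal sketch using exp ($a_0 = 1$, $a_2 = 1/2$) and GeLU ($a_0 = 0$, $a_2 = 1/\sqrt{2\pi}$, so that $1/(2a_2) = \sqrt{\pi/2}$ as in Eq. (136)); `smooth_product` is a hypothetical name.

```python
import math

def smooth_product(f, a0, a2, x, y):
    """Eq. (135): (f(x+y) - f(x) - f(y) + a0) / (2 a2) = xy + O(x^3 + y^3) for small x, y."""
    return (f(x + y) - f(x) - f(y) + a0) / (2.0 * a2)

def gelu(z):
    """Exact GeLU: z * Phi(z); its Taylor series is z/2 + z^2/sqrt(2*pi) + ..."""
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))

x, y = 0.05, -0.08
print(x * y)
print(smooth_product(math.exp, 1.0, 0.5, x, y))                         # exp: a0 = 1, a2 = 1/2
print(smooth_product(gelu, 0.0, 1.0 / math.sqrt(2.0 * math.pi), x, y))  # GeLU: Eq. (136)
```

For exp the formula reduces to $(e^x - 1)(e^y - 1)$, whose error is third order; for GeLU the cubic Taylor term vanishes, so the error is smaller still.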
If $a_2$ is zero, we cannot get any second-order term, and if $a_2$ is negligible, $O(x^3 + y^3)$ will dominate Eq. (135), so we cannot obtain a good approximation of xy. In this case, we can resort to numerical derivatives and utilize the $a_3$ term:
$$f'(x) = a_1 + 2a_2 x + 3a_3 x^2 + \ldots \tag{137}$$
If $a_3$ is not negligible, i.e., $a_3 x^2 \gg a_{i>3}\, x^i$ in the same domain, we can use numerical derivatives to get a multiplication term:
$$\frac{1}{6a_3}\Big(\frac{f(x+y+\delta) - f(x+y)}{\delta} - \frac{f(x+\delta) - f(x)}{\delta} - \frac{f(y+\delta) - f(y)}{\delta} + a_1\Big) = xy + O(x^3 + y^3) \tag{138}$$
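For example, tanh has no second-order term in its Taylor expansion:
$$\tanh x = x - \frac{x^3}{3} + \frac{2x^5}{15} + \ldots \tag{139}$$
Using the above formula with $a_1 = 1$ and $a_3 = -1/3$, we obtain the following expression:
$$-\frac{1}{2}\Big(\frac{\tanh(x+y+\delta) - \tanh(x+y)}{\delta} - \frac{\tanh(x+\delta) - \tanh(x)}{\delta} - \frac{\tanh(y+\delta) - \tanh(y)}{\delta} + 1\Big) \approx xy \tag{140}$$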
Similar to our construction in Eq. (110), we can construct a Transformer layer that calculates these quantities (noting that δ is a small, input-independent scalar). We plot this approximation against $x^2$ over the range $[-0.1, 0.1]$ in Fig. 7b. Note that if we use this approximation in our constructions, we will need more hidden space, as there are 6 different tanh terms as opposed to 3 GeLU terms in Eq. (110).
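Eq. (140) can be evaluated directly with forward differences. A minimal sketch with δ = 1e−3 (the value used for Fig. 7b); `tanh_product` is a hypothetical name. The six `math.tanh` calls correspond to the six tanh terms mentioned above.

```python
import math

def tanh_product(x, y, delta=1e-3):
    """Eq. (140): multiplication from forward-difference derivatives of tanh.

    tanh has a1 = 1 and a3 = -1/3, so the prefactor 1 / (6 a3) of Eq. (138) is -1/2.
    """
    def d_tanh(z):
        return (math.tanh(z + delta) - math.tanh(z)) / delta   # approximates 1 - tanh(z)^2
    return -0.5 * (d_tanh(x + y) - d_tanh(x) - d_tanh(y) + 1.0)

for x in (0.02, 0.05, 0.1):
    print(x * x, tanh_product(x, x))
```

The error grows with |x| because the fifth-order Taylor term of tanh is not negligible at x = 0.1 (about 5% relative error there).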


(a) Approximating $x^2$ using GeLU, Eq. (136). (b) Approximating $x^2$ using tanh, Eq. (140), where δ = 1e−3. (c) A piece-wise linear approximation to $x^2$ using ReLU, Eq. (141).

Figure 7: Approximations of multiplication via various non-linearities.

Non-smooth non-linearities ReLU is another commonly used non-linearity, and it is not differentiable at zero. With ReLU, we can only hope to get piece-wise linear approximations. For example, we can try to approximate $x^2$ with the following function:
$$0.0375\,\mathrm{ReLU}(x) + 0.0375\,\mathrm{ReLU}(-x) + \mathrm{ReLU}(0.05(x - 0.05)) + \mathrm{ReLU}(-0.05(x + 0.05)) + \mathrm{ReLU}(0.025(x - 0.025)) + \mathrm{ReLU}(-0.025(x + 0.025)) \approx x^2 \tag{141}$$
We plot this approximation against $x^2$ over the range $[-0.1, 0.1]$ in Fig. 7c.
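Eq. (141) can be evaluated directly to see how close the piece-wise linear fit gets. A minimal sketch that computes the worst absolute error on a grid over [−0.1, 0.1]; `relu_square` is a hypothetical name.

```python
def relu(z):
    return max(0.0, z)

def relu_square(x):
    """Piece-wise linear approximation of x**2 on [-0.1, 0.1] from Eq. (141)."""
    return (0.0375 * relu(x) + 0.0375 * relu(-x)
            + relu(0.05 * (x - 0.05)) + relu(-0.05 * (x + 0.05))
            + relu(0.025 * (x - 0.025)) + relu(-0.025 * (x + 0.025)))

grid = [i / 1000 for i in range(-100, 101)]
max_err = max(abs(relu_square(x) - x * x) for x in grid)
print(relu_square(0.05), max_err)
```

This particular fit is exact at x = 0 and x = ±0.05; its worst error on the range is at the endpoints ±0.1, where it gives 0.008125 instead of 0.01.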

H EMPIRICAL SCALING ANALYSIS WITH DIMENSIONALITY


In Figs. 3a and 3b, we showed that ICL needs different hidden sizes to enter the "Ridge regression phase" (orange background) or the "OLS phase" (green background) depending on the dimensionality d of the inputs x. However, we cannot reliably infer the actual relation between size requirements and problem dimension from only two dimensions. To better understand size requirements, we ask the following empirical question for each dimension: how many layers, how large a hidden size, and how many heads are needed to fit the least-squares solution better than the Ridge(λ = ε) regression solution (the green phase in Figs. 3a and 3b)?
To answer this question, we experimented with d ∈ {1, 2, 4, 8, 12, 16, 20} and ran an experiment sweep for each dimension over:

• number of layers (L): {1, 2, 4, 8, 12, 16},
• hidden size (H): {16, 32, 64, 256, 512, 1024},
• number of heads (M): {1, 2, 4, 8},
• learning rate: {1e-4, 2.5e-4}.


[Figure 8 panels: (a) hidden size requirements, hidden size (H) vs. problem dimension (d); (b) layer requirements, layers (L) vs. problem dimension (d); (c) number of head requirements, heads (M) vs. problem dimension (d); d ∈ {1, 2, 4, 8, 16, 20} on the x-axes.]

Figure 8: Empirical requirements on model parameters to satisfy SPD(Ridge(λ = 0.1), ICL) > SPD(OLS, ICL) when the other parameters are optimized.

For each feature that affects the computational capacity of the Transformer (L, H, M), we optimize the other features and find the minimum value of the feature that satisfies SPD(OLS, ICL) < SPD(Ridge(λ = ε), ICL). We plot our experiment with ε = 0.1 in Fig. 8. We find that a single head is enough for all problem dimensions, while the other parameters exhibit a step-function-like dependence on input size.
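The sweep and the selection rule above can be sketched as follows. We read "optimize the other features" as: a value of the feature qualifies if at least one setting of the remaining hyperparameters satisfies SPD(OLS, ICL) < SPD(Ridge(ε), ICL). The result format and the name `min_requirement` are hypothetical, and the toy SPD values below are made up purely to exercise the function.

```python
from itertools import product

LAYERS = [1, 2, 4, 8, 12, 16]
HIDDEN = [16, 32, 64, 256, 512, 1024]
HEADS = [1, 2, 4, 8]
LRS = [1e-4, 2.5e-4]

def min_requirement(results, feature):
    """Smallest value of `feature` for which some setting of the other hyperparameters
    gives SPD(OLS, ICL) < SPD(Ridge(eps), ICL).

    results: dict mapping (L, H, M, lr) -> (spd_ols, spd_ridge)  (hypothetical format)
    feature: 0 for L, 1 for H, 2 for M
    """
    ok = [cfg[feature] for cfg, (spd_ols, spd_ridge) in results.items() if spd_ols < spd_ridge]
    return min(ok) if ok else None

configs = list(product(LAYERS, HIDDEN, HEADS, LRS))
print(len(configs))  # runs per problem dimension d

# Toy (made-up) results: only two configurations reach the OLS phase.
toy = {cfg: (1.0, 0.5) for cfg in configs}
toy[(4, 64, 1, 1e-4)] = (0.1, 0.3)
toy[(8, 32, 2, 2.5e-4)] = (0.2, 0.4)
print(min_requirement(toy, 0), min_requirement(toy, 1), min_requirement(toy, 2))
```

Each problem dimension d gets its own 288-run sweep, and the three minima are read off independently for L, H, and M.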
Please note that other hyperparameters discussed in Appendix D (e.g., weight initialization) were not optimized for each dimension independently.
