Artificial Intelligence II (CS4442 & CS9542)
Overfitting, Cross-Validation, and Regularization
Boyu Wang
Department of Computer Science
University of Western Ontario
Motivation examples: polynomial regression
I As the degree of the polynomial increases, there are more degrees of
freedom, and the (training) error approaches zero.
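As a small illustration of this point (not from the original slides), the following Python sketch fits polynomials of increasing degree M to a toy data set and prints the training error; the sine-plus-noise data and the specific degrees are arbitrary assumptions.

import numpy as np

# Hypothetical setup (assumption): 10 noisy samples from a sine curve.
rng = np.random.default_rng(0)
x = np.linspace(0, 1, 10)
y = np.sin(2 * np.pi * x) + 0.1 * rng.standard_normal(x.size)

for degree in [0, 1, 3, 9]:
    w = np.polyfit(x, y, degree)                      # least-squares polynomial fit
    train_err = np.mean((np.polyval(w, x) - y) ** 2)  # training MSE
    print(f"degree M = {degree}: training MSE = {train_err:.2e}")
# As M grows, the training error approaches zero (with M = 9 and 10 points the fit
# interpolates the data), even though the curve oscillates wildly between the points.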
Figure credit: Christopher Bishop
1
Motivation examples: polynomial regression
I Minimizing the training/empirical loss does NOT guarantee good
test/generalization performance.
I Overfitting: Very low training error, very high test error.
Figure credit: Christopher Bishop
2
Overfitting – general phenomenon
I Too simple (e.g., small M) → underfitting
I Too complex (e.g., large M) → overfitting
Figure credit: Ian Goodfellow
3
Overfitting
I Training loss and test loss are different
I The larger the hypothesis class, the easier it is to find a hypothesis that
fits the training data
- but it may have a large test error (overfitting)
I Prevent overfitting:
- Use a larger data set
- Choose an appropriate hypothesis class (model selection)
- Control model complexity (regularization)
4
Larger data set
I Overfitting is mostly due to sparsity of the data.
I For the same model complexity, more data ⇒ less overfitting. With more
data, more complex (i.e., more flexible) models can be used.
Figure credit: Christopher Bishop
5
Model selection
I How to choose the optimal model complexity/hyper-parameter
(e.g., choose the best degree for polynomial regression)
I Cannot be done by training data alone
I We can use our prior knowledge or expertise (e.g., somehow we
know that the degree should not exceed 4)
I Create a held-out data set to approximate the test error (i.e., mimic
the test data)
- called the validation data set
6
Model selection: cross-validation
For each order of polynomial M:
1. Randomly split the training data into K groups, and repeat the following
procedure K times:
i. Leave out the k-th group from the training set as a validation set
ii. Use the other K − 1 groups to find the best parameter vector wk
iii. Measure the error of wk on the validation set; call this Jk
2. Compute the average error: J = (1/K) ∑_{k=1}^{K} Jk
Choose the order of polynomial M with the lowest average error J (a code
sketch follows).
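A minimal Python/NumPy sketch of this procedure, assuming np.polyfit as the fitting routine, squared error as the loss, and K = 5; the toy data set is a placeholder.

import numpy as np

def cv_error(x, y, degree, K=5, seed=0):
    """Average validation MSE of a degree-`degree` polynomial under K-fold CV."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(x))                  # 1. random split into K groups
    folds = np.array_split(idx, K)
    errors = []
    for k in range(K):
        val = folds[k]                             # i. leave out the k-th group
        trn = np.concatenate([folds[j] for j in range(K) if j != k])
        w_k = np.polyfit(x[trn], y[trn], degree)   # ii. fit on the other K-1 groups
        J_k = np.mean((np.polyval(w_k, x[val]) - y[val]) ** 2)  # iii. validation error
        errors.append(J_k)
    return np.mean(errors)                         # 2. average error J

# Placeholder data (assumption): noisy samples of a sine curve.
rng = np.random.default_rng(1)
x = np.linspace(0, 1, 30)
y = np.sin(2 * np.pi * x) + 0.2 * rng.standard_normal(x.size)

best_M = min(range(10), key=lambda M: cv_error(x, y, M))  # choose M with lowest J
print("selected polynomial order M =", best_M)

np.array_split is used so that the K groups may differ in size by one element when the data do not divide evenly.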
7
Model selection: cross-validation
Figure: K-fold cross-validation for the case of K = 4
Figure credit: Christopher Bishop
8
General learning procedure
Given a training set and a test set
1. Use cross-validation to choose the hyper-parameter/hypothesis
class.
2. Once the hyper-parameter is selected, use the entire training set
to find the best model parameters w.
3. Evaluate the performance of w on the test set.
These sets must be disjoint! – you should never touch the test data
until the final evaluation of your model (a sketch of this workflow follows).
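For illustration only, the same three steps expressed with scikit-learn (a library not mentioned in the slides); Ridge is scikit-learn's ℓ2-regularized linear regression (introduced a few slides later), and the data, parameter grid, and split size are placeholder assumptions. By default GridSearchCV refits the best model on the entire training set, which corresponds to step 2.

import numpy as np
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.linear_model import Ridge

# Placeholder data (assumption): 200 samples, 5 features, linear ground truth plus noise.
rng = np.random.default_rng(0)
X = rng.standard_normal((200, 5))
y = X @ np.array([1.0, -2.0, 0.0, 0.5, 3.0]) + 0.1 * rng.standard_normal(200)

# Disjoint training and test sets -- the test set is not touched until the end.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)

# Step 1: cross-validation on the training set only, to pick the hyper-parameter.
search = GridSearchCV(Ridge(), {"alpha": [0.01, 0.1, 1.0, 10.0]}, cv=5)
search.fit(X_train, y_train)   # step 2: the best model is refit on all training data

# Step 3: a single evaluation on the held-out test set.
test_mse = np.mean((search.predict(X_test) - y_test) ** 2)
print("chosen alpha:", search.best_params_["alpha"], " test MSE:", test_mse)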
9
Summary of cross-validation
I Can also be used for selecting other hyper-parameters of a
model/algorithm (e.g., the number of hidden layers of a neural
network, the learning rate of gradient descent, or even to choose
between different machine learning models)
I Very straightforward algorithm to implement
I Provides a good estimate of the true (generalization) error of a model
I Leave-one-out cross-validation: number of groups = number of
training instances
I Computationally expensive; even more so when there are more
hyper-parameters to tune
10
Regularization
I Intuition: complicated hypotheses lead to overfitting
I Idea: penalize model complexity (e.g., large values of wj):
L(w) = J(w) + λR(w)
where J(w) is the training loss, R(w) is the regularization
function/regularizer, and λ ≥ 0 is a regularization parameter that
controls the tradeoff between data fitting and model complexity.
11
ℓ2-norm regularization for linear regression
Objective function:
L(w) = (1/2) ∑_{i=1}^{m} ( w0 + ∑_{j=1}^{n} wj · xi,j − yi )^2 + (λ/2) ∑_{j=1}^{n} wj^2
I No regularization on w0!
Equivalently, we have
L(w) = (1/2) ||Xw − y||_2^2 + (λ/2) w^T Î w
where w = [w0, w1, . . . , wn]^T and Î = diag(0, 1, . . . , 1) is the
(n+1)×(n+1) identity matrix with its top-left entry set to 0.
12
ℓ2-norm regularization for linear regression
Objective function:
L(w) = (1/2) ||Xw − y||_2^2 + (λ/2) w^T Î w
     = (1/2) ( w^T (X^T X + λÎ) w − w^T X^T y − y^T X w + y^T y )
Optimal solution (by solving ∇L(w) = 0):
w = (X^T X + λÎ)^{-1} X^T y
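A direct NumPy sketch of this closed-form solution, under the assumption that X already contains a leading column of ones for the bias w0 (so that Î removes the penalty on w0); the toy data are placeholders.

import numpy as np

def ridge_solution(X, y, lam):
    """w = (X^T X + lam * I_hat)^{-1} X^T y, with no penalty on the bias w0.

    Assumes X has a leading column of ones (the bias feature)."""
    I_hat = np.eye(X.shape[1])
    I_hat[0, 0] = 0.0                      # do not regularize w0
    return np.linalg.solve(X.T @ X + lam * I_hat, X.T @ y)

# Placeholder data (assumption): 50 samples, 3 features, plus a bias column of ones.
rng = np.random.default_rng(0)
A = rng.standard_normal((50, 3))
X = np.hstack([np.ones((50, 1)), A])
y = X @ np.array([0.5, 1.0, -2.0, 3.0]) + 0.1 * rng.standard_normal(50)

for lam in [0.0, 1.0, 100.0]:
    w = ridge_solution(X, y, lam)
    print(f"lambda = {lam:6.1f}  ||w[1:]||_2 = {np.linalg.norm(w[1:]):.3f}")
# Larger lambda shrinks the weights towards 0 (the bias w0 is not penalized).

np.linalg.solve is used rather than forming the inverse explicitly; this is mathematically equivalent to the formula above but numerically more stable.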
13
More on ℓ2-norm regularization
arg min_w (1/2) ||Xw − y||_2^2 + (λ/2) w^T Î w = (X^T X + λÎ)^{-1} X^T y
I ℓ2-norm regularization pushes the parameters towards 0.
I λ = 0 ⇒ same as regular linear regression
I λ → ∞ ⇒ w → 0
I 0 < λ < ∞ ⇒ the magnitude of the weights will be smaller than in
regular linear regression
Figure credit: Christopher Bishop
14
Another view of ℓ2-norm regularization
I From optimization theory¹, we know that
min_w J(w) + λR(w)
is equivalent to
min_w J(w)
such that R(w) ≤ η
for some η ≥ 0 (for convex J and R).
I Hence, ℓ2-regularized linear regression can be re-formulated as (we
only consider wj, j > 0 here)
min_w ||Xw − y||_2^2
such that ||w||_2^2 ≤ η
¹ e.g., Boyd and Vandenberghe. Convex Optimization. 2004.
15
Visualizing ℓ2-norm regularization (2 features)
Figure: w* = (X^T X + λÎ)^{-1} X^T y
Figure credit: Christopher Bishop
16
ℓ1-norm regularization
I Instead of using the ℓ2-norm, we use the ℓ1-norm to control the model
complexity:
min_w (1/2) ∑_{i=1}^{m} ( w0 + ∑_{j=1}^{n} wj · xi,j − yi )^2 + λ ∑_{j=1}^{n} |wj|
which is equivalent to
min_w (1/2) ∑_{i=1}^{m} ( w0 + ∑_{j=1}^{n} wj · xi,j − yi )^2
such that ∑_{j=1}^{n} |wj| ≤ η
I Also called LASSO (least absolute shrinkage and selection operator).
I No analytical solution anymore! The problem must be solved numerically
(see the sketch below).
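One common numerical approach (a choice made here for illustration, not prescribed by the slides) is proximal gradient descent, also known as ISTA, which alternates a gradient step on the squared-error term with soft-thresholding of the penalized weights; the toy data, λ, and iteration count are placeholder assumptions.

import numpy as np

def lasso_ista(X, y, lam, n_iters=500):
    """Minimize (1/2)||Xw - y||^2 + lam * sum_{j>=1} |wj| by ISTA.

    Assumes X has a leading column of ones; the bias w0 is not penalized."""
    step = 1.0 / np.linalg.norm(X, 2) ** 2       # 1/L, L = largest eigenvalue of X^T X
    w = np.zeros(X.shape[1])
    for _ in range(n_iters):
        w = w - step * X.T @ (X @ w - y)         # gradient step on the squared error
        # soft-thresholding (proximal operator of lam*|.|), skipping the bias w0
        w[1:] = np.sign(w[1:]) * np.maximum(np.abs(w[1:]) - step * lam, 0.0)
    return w

# Placeholder data (assumption): only 2 of the 8 features are truly relevant.
rng = np.random.default_rng(0)
A = rng.standard_normal((100, 8))
X = np.hstack([np.ones((100, 1)), A])
y = X @ np.array([0.5, 3.0, 0, 0, -2.0, 0, 0, 0, 0]) + 0.1 * rng.standard_normal(100)

w_hat = lasso_ista(X, y, lam=5.0)
print(np.round(w_hat, 2))   # several coefficients are driven exactly to 0

The soft-thresholding step is what sets small coefficients exactly to 0, which is the feature-selection behaviour discussed on the next slides.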
17
Visualizing ℓ1-norm regularization (2 features)
I If λ is large enough, the circle (the error contour) is very likely to
intersect the diamond (the ℓ1 constraint region) at one of its corners.
I This makes ℓ1-norm regularization much more likely to make some
weights exactly 0.
I In other words, we essentially perform feature selection!
18
Comparison of ℓ2 and ℓ1
Figure credit: Bishop; Hastie, Tibshirani & Friedman
19
Summary of regularization
I Both are commonly used approaches to avoid overfitting.
I Both push the weights towards 0.
I ℓ2 produces small but non-zero weights, while ℓ1 is likely to make some
weights exactly 0.
I ℓ1 optimization is computationally more expensive than ℓ2.
I To choose an appropriate λ, cross-validation is often used.
20