Time Series Forecasting Using Deep Learning - MATLAB & Simulink
To forecast the values of future time steps of a sequence, you can train a
sequence-to-sequence regression LSTM network, where the responses are the training sequences with values shifted
by one time step. That is, at each time step of the input sequence, the LSTM network learns to predict the value of the
next time step.
To forecast the values of multiple time steps in the future, use the predictAndUpdateState function to predict time
steps one at a time and update the network state at each prediction.
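This example uses the data set chickenpox_dataset, which contains a single time series. Its time steps correspond to months and its values correspond to the number of chickenpox cases. The example trains an LSTM network to forecast the number of cases given the number of cases in previous months. Load the data. The output is a cell array in which each element is a single time step, so concatenate the elements into a row vector.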
data = chickenpox_dataset;
data = [data{:}];
figure
plot(data)
xlabel("Month")
ylabel("Cases")
title("Monthly Cases of Chickenpox")
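The reshaping step data = [data{:}] relies on comma-separated-list expansion. data{:} expands the cell array into its individual elements, and the square brackets concatenate them into one numeric row vector. Below is a minimal sketch on a made-up four-element cell array. The values are illustrative and not taken from chickenpox_dataset.

```matlab
% Toy stand-in for chickenpox_dataset: a 1-by-4 cell array with one
% scalar case count per month (made-up values).
c = {12, 30, 25, 41};

% c{:} expands to the comma-separated list 12, 30, 25, 41;
% the brackets concatenate that list into a numeric row vector.
v = [c{:}];
disp(v)
```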
Partition the training and test data. Train on the first 90% of the sequence and test on the last 10%.
Prepare the predictors and responses. Specify the responses to be the training sequences with values shifted by one time step, so that at each time step of the input sequence the LSTM network learns to predict the value of the next time step. Split the test data into predictors and responses in the same way.
https://www.mathworks.com/help/nnet/examples/time-series-forecasting-using-deep-learning.html 1/6
7/25/2018 Time Series Forecasting Using Deep Learning - MATLAB & Simulink
numTimeStepsTrain = floor(0.9*numel(data));
XTrain = data(1:numTimeStepsTrain);
YTrain = data(2:numTimeStepsTrain+1);
XTest = data(numTimeStepsTrain+1:end-1);
YTest = data(numTimeStepsTrain+2:end);
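The indexing above is easiest to check on a toy series whose values equal their own month index, so every element shows where it came from. The 25-month series is made up for illustration. It uses the same 90% split and the same one-step shift as the code above.

```matlab
% Toy series: 25 "months" whose values equal the month index.
data = 1:25;

numTimeStepsTrain = floor(0.9*numel(data));   % floor(22.5) = 22
XTrain = data(1:numTimeStepsTrain);           % months 1..22
YTrain = data(2:numTimeStepsTrain+1);         % months 2..23, one step ahead
XTest  = data(numTimeStepsTrain+1:end-1);     % months 23..24
YTest  = data(numTimeStepsTrain+2:end);       % months 24..25

% Each response is the value one time step after its predictor.
assert(isequal(YTrain, XTrain + 1))
assert(isequal(YTest, XTest + 1))

% The last training response is the first test predictor, so a forecast
% started from YTrain(end) targets YTest(1).
assert(YTrain(end) == XTest(1))
```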
Standardize Data
For a better fit and to prevent the training from diverging, standardize the training data to have zero mean and unit
variance. Standardize the test data using the same parameters as the training data.
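mu = mean(XTrain);
sig = std(XTrain);

XTrain = (XTrain - mu) / sig;
YTrain = (YTrain - mu) / sig;
XTest = (XTest - mu) / sig;
% YTest stays on the original scale; the RMSE is computed after
% unstandardizing the predictions.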
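Two details of standardization are easy to get wrong. First, the test data is shifted and scaled with the training statistics mu and sig, not with its own. Second, forecasts made on the standardized scale are mapped back with the inverse transform sig*Y + mu. Below is a minimal sketch with made-up values. Note that MATLAB's std normalizes by N-1 by default.

```matlab
% Made-up training and test values (not from chickenpox_dataset).
XTrain = [100 140 180 220];
XTest  = [160 260];

% Statistics come from the training data only.
mu  = mean(XTrain);    % 160
sig = std(XTrain);     % sample standard deviation (normalized by N-1)

XTrainStd = (XTrain - mu) / sig;
XTestStd  = (XTest  - mu) / sig;   % same mu and sig as the training data

% Values on the standardized scale map back with the inverse transform.
XTestBack = sig*XTestStd + mu;
disp(XTestBack)
```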
Create an LSTM regression network. Specify the LSTM layer to have 200 hidden units.
inputSize = 1;
numResponses = 1;
numHiddenUnits = 200;
layers = [ ...
    sequenceInputLayer(inputSize)
    lstmLayer(numHiddenUnits)
    fullyConnectedLayer(numResponses)
    regressionLayer];
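For a sense of scale, you can count the learnable parameters of this architecture by hand. The sketch below assumes the standard LSTM parameterization, in which lstmLayer stores input weights, recurrent weights, and a bias, each with 4*numHiddenUnits rows (one block per gate). If your release provides analyzeNetwork, you can compare these totals against its report.

```matlab
inputSize = 1;
numHiddenUnits = 200;
numResponses = 1;

% lstmLayer learnables: input weights (4H-by-D), recurrent weights (4H-by-H)
% and bias (4H-by-1), one block of H rows per gate
% (input, forget, cell candidate, output).
numLstm = 4*numHiddenUnits*inputSize ...
        + 4*numHiddenUnits*numHiddenUnits ...
        + 4*numHiddenUnits;

% fullyConnectedLayer learnables: weights (numResponses-by-H) and bias.
numFc = numResponses*numHiddenUnits + numResponses;

fprintf('LSTM: %d, fully connected: %d, total: %d\n', numLstm, numFc, numLstm + numFc)
```

Almost all of the parameters sit in the recurrent weights (200-by-200 per gate), which is why the hidden-unit count dominates the model size.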
Specify the training options. Set the solver to 'adam' and train for 250 epochs. To prevent the gradients from exploding, set the gradient threshold to 1. Set the initial learn rate to 0.005, and drop the learn rate after 125 epochs by multiplying it by a factor of 0.2. Then train the LSTM network using trainNetwork.
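opts = trainingOptions('adam','MaxEpochs',250,'GradientThreshold',1, ...
    'InitialLearnRate',0.005,'LearnRateSchedule','piecewise', ...
    'LearnRateDropPeriod',125,'LearnRateDropFactor',0.2,'Plots','training-progress');
net = trainNetwork(XTrain,YTrain,layers,opts);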
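The 'piecewise' learn rate schedule keeps the rate constant and multiplies it by the drop factor every LearnRateDropPeriod epochs. Below is a sketch of the resulting per-epoch rate for the settings described above. It assumes the drop takes effect at the start of epoch 126, that is, after 125 full epochs.

```matlab
% Settings from the training options described above.
initialLearnRate = 0.005;
dropPeriod = 125;
dropFactor = 0.2;
epochs = 1:250;

% Assumed piecewise rule: one drop per completed block of dropPeriod epochs.
learnRate = initialLearnRate * dropFactor.^floor((epochs - 1)/dropPeriod);

fprintf('epoch 125: %g, epoch 126: %g\n', learnRate(125), learnRate(126))
```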
To initialize the network state, first predict on the training data XTrain. Next, make the first prediction using the last time
step of the training response YTrain(end). Loop over the remaining predictions and input the previous prediction to
predictAndUpdateState.
net = predictAndUpdateState(net,XTrain);
[net,YPred] = predictAndUpdateState(net,YTrain(end));

numTimeStepsTest = numel(XTest);
for i = 2:numTimeStepsTest
    [net,YPred(1,i)] = predictAndUpdateState(net,YPred(i-1));
end

% Unstandardize the predictions using the parameters calculated earlier.
YPred = sig*YPred + mu;
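The recursion in the forecasting loop does not depend on the LSTM itself: each step's output is fed back as the next step's input. The sketch below replaces the trained network with a hypothetical stateless one-step function, stepModel, so it runs without Deep Learning Toolbox. stepModel is an assumption for illustration only. A real LSTM also carries a hidden state that predictAndUpdateState updates at every call.

```matlab
% Hypothetical one-step model standing in for the trained LSTM:
% it maps the current value to a prediction of the next value.
% (Unlike an LSTM, it keeps no hidden state.)
stepModel = @(x) 0.8*x + 1;

lastObserved = 10;          % plays the role of YTrain(end)
numTimeStepsTest = 4;

YPred = zeros(1, numTimeStepsTest);
YPred(1) = stepModel(lastObserved);
for i = 2:numTimeStepsTest
    % Closed loop: the previous prediction becomes the next input.
    YPred(i) = stepModel(YPred(i-1));
end
disp(YPred)
```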
The training progress plot reports the root-mean-square error (RMSE) calculated from the standardized data. Calculate
the RMSE from the unstandardized predictions.
rmse = sqrt(mean((YPred-YTest).^2))
rmse = single
211.1609
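The RMSE on the standardized scale and on the original scale differ only by the factor sig. Both the predictions and the targets are shifted by the same mu, which cancels in their difference, and both are scaled by the same 1/sig. Below is a sketch with made-up values and assumed standardization parameters mu = 130 and sig = 40.

```matlab
% Made-up values on the original (unstandardized) scale.
YTest = [120 150 90];
YPred = [110 160 100];
mu = 130;  sig = 40;   % assumed standardization parameters

% Errors are -10, 10, 10.
rmse = sqrt(mean((YPred - YTest).^2));

% The same errors on the standardized scale, as the training plot sees them.
rmseStd = sqrt(mean(((YPred - mu)/sig - (YTest - mu)/sig).^2));

fprintf('rmse = %g, standardized rmse = %g\n', rmse, rmseStd)

% mu cancels in the difference, so the two differ only by the factor sig.
assert(abs(rmse - sig*rmseStd) < 1e-9)
```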
Plot the training time series with the forecasted values.
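figure
plot(data(1:numTimeStepsTrain+1))
hold on
% The first forecast is made from YTrain(end) = data(numTimeStepsTrain+1),
% so YPred(1) corresponds to month numTimeStepsTrain+2.
idx = (numTimeStepsTrain+1):(numTimeStepsTrain+1+numTimeStepsTest);
plot(idx,[data(numTimeStepsTrain+1) YPred],'.-')
hold off
xlabel("Month")
ylabel("Cases")
title("Forecast")
legend(["Observed" "Forecast"])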
figure
subplot(2,1,1)
plot(YTest)
hold on
plot(YPred,'.-')
hold off
legend(["Observed" "Forecast"])
ylabel("Cases")
title("Forecast")
subplot(2,1,2)
stem(YPred - YTest)
xlabel("Month")
ylabel("Error")
title("RMSE = " + rmse)
If you have access to the actual values of the time steps between predictions, you can update the network state with the observed values instead of the predicted values. First, initialize the network state. To make predictions on a new sequence, reset the network state using resetState. Resetting prevents previous predictions from affecting the predictions on the new data. Reset the network state, and then initialize it by predicting on the training data.
net = resetState(net);
net = predictAndUpdateState(net,XTrain);
Predict on each time step. For each prediction, predict the next time step using the observed value of the previous time
step.
YPred = [];
numTimeStepsTest = numel(XTest);
for i = 1:numTimeStepsTest
    [net,YPred(1,i)] = predictAndUpdateState(net,XTest(i));
end

% Unstandardize the predictions.
YPred = sig*YPred + mu;
rmse = sqrt(mean((YPred-YTest).^2))
rmse = 116.8079
Compare the forecasted values with the test data.
figure
subplot(2,1,1)
plot(YTest)
hold on
plot(YPred,'.-')
hold off
legend(["Observed" "Predicted"])
ylabel("Cases")
title("Forecast with Updates")
subplot(2,1,2)
stem(YPred - YTest)
xlabel("Month")
ylabel("Error")
title("RMSE = " + rmse)
Here, the predictions are more accurate when the network state is updated with the observed values instead of the predicted values: the RMSE drops from about 211 to about 117 cases.
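The benefit of updating the state with observed values comes from avoiding error accumulation. In the closed-loop forecast, every error is fed back as the next input. The observed-value loop instead restarts from the true value at each step. The sketch below contrasts the two loops using a hypothetical stateless stand-in, stepModel, and a made-up five-month series. It illustrates the mechanism only and does not reproduce the RMSE values above.

```matlab
% Hypothetical stateless stand-in for the trained network.
stepModel = @(x) 0.8*x + 1;

% Made-up observed series: XTest(i) is observed and YTest(i) = XTest(i+1).
observed = [10 12 9 11 13];
XTest = observed(1:end-1);
YTest = observed(2:end);
n = numel(XTest);

% Closed loop: feed each prediction back in as the next input.
YClosed = zeros(1, n);
YClosed(1) = stepModel(XTest(1));
for i = 2:n
    YClosed(i) = stepModel(YClosed(i-1));
end

% Observed-value updates: feed the observed value of the previous time step.
YOpen = zeros(1, n);
for i = 1:n
    YOpen(i) = stepModel(XTest(i));
end

rmseClosed = sqrt(mean((YClosed - YTest).^2));
rmseOpen   = sqrt(mean((YOpen   - YTest).^2));
fprintf('closed loop: %g, observed updates: %g\n', rmseClosed, rmseOpen)
```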
See Also
lstmLayer | sequenceInputLayer | trainNetwork | trainingOptions
Related Topics
• Sequence Classification Using Deep Learning
• Sequence-to-Sequence Classification Using Deep Learning
• Sequence-to-Sequence Regression Using Deep Learning
• Long Short-Term Memory Networks
• Deep Learning in MATLAB