# coding: utf-8

# Python Machine Learning 3rd Edition by
# Sebastian Raschka (https://sebastianraschka.com) & Vahid Mirjalili (http://vahidmirjalili.com)
# Packt Publishing Ltd. 2019
#
# Code Repository: https://github.com/rasbt/python-machine-learning-book-3rd-edition
#
# Code License: MIT License (https://github.com/rasbt/python-machine-learning-book-3rd-edition/blob/master/LICENSE.txt)

#################################################################################
# Chapter 18 - Reinforcement Learning for Decision Making in Complex Environments
#################################################################################

# Script: qlearning.py
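#
# This script trains a Q-learning agent (agent.py) on the grid-world
# environment defined in gridworld_env.py and plots the resulting
# learning history (moves per episode and final reward per episode).
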
from gridworld_env import GridWorldEnv
from agent import Agent
from collections import namedtuple
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(1)

Transition = namedtuple(
    'Transition', ('state', 'action', 'reward', 'next_state', 'done'))
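

# Train for `num_episodes` episodes: in each episode the agent repeatedly
# chooses an action for the current state, the environment returns the next
# state, reward, and done flag, and the resulting Transition is handed to
# agent._learn() (implemented in agent.py) for the learning update.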
def run_qlearning(agent, env, num_episodes=50):
    history = []
    for episode in range(num_episodes):
        state = env.reset()
        env.render(mode='human')
        final_reward, n_moves = 0.0, 0
        while True:
            action = agent.choose_action(state)
            next_s, reward, done, _ = env.step(action)
            agent._learn(Transition(state, action, reward,
                                    next_s, done))
            env.render(mode='human', done=done)
            state = next_s
            n_moves += 1
            # record the most recent reward before checking for termination,
            # so the terminal reward is the one kept for this episode
            final_reward = reward
            if done:
                break
        history.append((n_moves, final_reward))
        print('Episode %d: Reward %.1f #Moves %d'
              % (episode, final_reward, n_moves))
    return history
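

# Plot the learning history: number of moves per episode (top panel) and the
# final reward obtained in each episode (bottom panel); the figure is also
# saved as 'q-learning-history.png'.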
def plot_learning_history(history):
    fig = plt.figure(1, figsize=(14, 10))
    ax = fig.add_subplot(2, 1, 1)
    episodes = np.arange(len(history))
    moves = np.array([h[0] for h in history])
    plt.plot(episodes, moves, lw=4,
             marker="o", markersize=10)
    ax.tick_params(axis='both', which='major', labelsize=15)
    plt.xlabel('Episodes', size=20)
    plt.ylabel('# moves', size=20)

    ax = fig.add_subplot(2, 1, 2)
    rewards = np.array([h[1] for h in history])
    plt.step(episodes, rewards, lw=4)
    ax.tick_params(axis='both', which='major', labelsize=15)
    plt.xlabel('Episodes', size=20)
    plt.ylabel('Final rewards', size=20)
    plt.savefig('q-learning-history.png', dpi=300)
    plt.show()


if __name__ == '__main__':
    env = GridWorldEnv(num_rows=5, num_cols=6)
    agent = Agent(env)
    history = run_qlearning(agent, env)
    env.close()

    plot_learning_history(history)