Skip to content

Commit 553ccaf

Browse files
committed
ch18 fixes
1 parent 13369d9 commit 553ccaf

File tree

3 files changed

+74
-29
lines changed

3 files changed

+74
-29
lines changed

ch18/gridworld/agent.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,18 @@
1-
## Script: agent.py
1+
# coding: utf-8
2+
3+
# Python Machine Learning 3rd Edition by
4+
# Sebastian Raschka (https://sebastianraschka.com) & Vahid Mirjalili](http://vahidmirjalili.com)
5+
# Packt Publishing Ltd. 2019
6+
#
7+
# Code Repository: https://github.com/rasbt/python-machine-learning-book-3rd-edition
8+
#
9+
# Code License: MIT License (https://github.com/rasbt/python-machine-learning-book-3rd-edition/blob/master/LICENSE.txt)
10+
11+
############################################################################
12+
# Chapter 18: Reinforcement Learning
13+
############################################################################
14+
15+
# Script: agent.py
216

317
from collections import defaultdict
418
import numpy as np
@@ -19,7 +33,7 @@ def __init__(
1933
self.epsilon_min = epsilon_min
2034
self.epsilon_decay = epsilon_decay
2135

22-
## Define the q_table
36+
# Define the q_table
2337
self.q_table = defaultdict(lambda: np.zeros(self.env.nA))
2438

2539
def choose_action(self, state):
@@ -41,10 +55,10 @@ def _learn(self, transition):
4155
else:
4256
q_target = r + self.gamma*np.max(self.q_table[next_s])
4357

44-
## Update the q_table
58+
# Update the q_table
4559
self.q_table[s][a] += self.lr * (q_target - q_val)
4660

47-
## Adjust the epislon
61+
# Adjust the epislon
4862
self._adjust_epsilon()
4963

5064
def _adjust_epsilon(self):

ch18/gridworld/gridworld_env.py

+39-24
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,25 @@
1+
# coding: utf-8
2+
3+
# Python Machine Learning 3rd Edition by
4+
# Sebastian Raschka (https://sebastianraschka.com) & Vahid Mirjalili](http://vahidmirjalili.com)
5+
# Packt Publishing Ltd. 2019
6+
#
7+
# Code Repository: https://github.com/rasbt/python-machine-learning-book-3rd-edition
8+
#
9+
# Code License: MIT License (https://github.com/rasbt/python-machine-learning-book-3rd-edition/blob/master/LICENSE.txt)
10+
11+
############################################################################
12+
# Chapter 18: Reinforcement Learning
13+
############################################################################
14+
15+
# Script: gridworld_env.py
16+
117
import numpy as np
218
from gym.envs.toy_text import discrete
319
from collections import defaultdict
420
import time
21+
import pickle
22+
import os
523

624
from gym.envs.classic_control import rendering
725

@@ -19,32 +37,29 @@ def get_coords(row, col, loc='center'):
1937
xl, xr = xc - half_size, xc + half_size
2038
yt, yb = xc - half_size, xc + half_size
2139
return [(xl, yt), (xr, yt), (xr, yb), (xl, yb)]
22-
elif loc=='interior_triangle':
40+
elif loc == 'interior_triangle':
2341
x1, y1 = xc, yc + CELL_SIZE//3
2442
x2, y2 = xc + CELL_SIZE//3, yc - CELL_SIZE//3
2543
x3, y3 = xc - CELL_SIZE//3, yc - CELL_SIZE//3
2644
return [(x1, y1), (x2, y2), (x3, y3)]
2745

46+
2847
def draw_object(coords_list):
29-
if len(coords_list) == 1: # -> circle
48+
if len(coords_list) == 1: # -> circle
3049
obj = rendering.make_circle(int(0.45*CELL_SIZE))
3150
obj_transform = rendering.Transform()
3251
obj.add_attr(obj_transform)
3352
obj_transform.set_translation(*coords_list[0])
34-
obj.set_color(0.2, 0.2, 0.2) # -> black
35-
elif len(coords_list) == 3: # -> triangle
53+
obj.set_color(0.2, 0.2, 0.2) # -> black
54+
elif len(coords_list) == 3: # -> triangle
3655
obj = rendering.FilledPolygon(coords_list)
37-
obj.set_color(0.9, 0.6, 0.2) # -> yellow
38-
elif len(coords_list) > 3: # -> polygon
56+
obj.set_color(0.9, 0.6, 0.2) # -> yellow
57+
elif len(coords_list) > 3: # -> polygon
3958
obj = rendering.FilledPolygon(coords_list)
40-
obj.set_color(0.4, 0.4, 0.8) # -> blue
59+
obj.set_color(0.4, 0.4, 0.8) # -> blue
4160
return obj
4261

4362

44-
import pickle
45-
import os
46-
47-
4863
class GridWorldEnv(discrete.DiscreteEnv):
4964
def __init__(self, num_rows=4, num_cols=6, delay=0.05):
5065
self.num_rows = num_rows
@@ -60,18 +75,18 @@ def __init__(self, num_rows=4, num_cols=6, delay=0.05):
6075
self.action_defs = {0: move_up, 1: move_right,
6176
2: move_down, 3: move_left}
6277

63-
## Number of states/actions
78+
# Number of states/actions
6479
nS = num_cols * num_rows
6580
nA = len(self.action_defs)
6681
self.grid2state_dict = {(s // num_cols, s % num_cols): s
6782
for s in range(nS)}
6883
self.state2grid_dict = {s: (s // num_cols, s % num_cols)
6984
for s in range(nS)}
7085

71-
## Gold state
86+
# Gold state
7287
gold_cell = (num_rows // 2, num_cols - 2)
7388

74-
## Trap states
89+
# Trap states
7590
trap_cells = [((gold_cell[0] + 1), gold_cell[1]),
7691
(gold_cell[0], gold_cell[1] - 1),
7792
((gold_cell[0] - 1), gold_cell[1])]
@@ -82,7 +97,7 @@ def __init__(self, num_rows=4, num_cols=6, delay=0.05):
8297
self.terminal_states = [gold_state] + trap_states
8398
print(self.terminal_states)
8499

85-
## Build the transition probability
100+
# Build the transition probability
86101
P = defaultdict(dict)
87102
for s in range(nS):
88103
row, col = self.state2grid_dict[s]
@@ -91,7 +106,7 @@ def __init__(self, num_rows=4, num_cols=6, delay=0.05):
91106
action = self.action_defs[a]
92107
next_s = self.grid2state_dict[action(row, col)]
93108

94-
## Terminal state
109+
# Terminal state
95110
if self.is_terminal(next_s):
96111
r = (1.0 if next_s == self.terminal_states[0]
97112
else -1.0)
@@ -104,7 +119,7 @@ def __init__(self, num_rows=4, num_cols=6, delay=0.05):
104119
done = False
105120
P[s][a] = [(1.0, next_s, r, done)]
106121

107-
## Initial state distribution
122+
# Initial state distribution
108123
isd = np.zeros(nS)
109124
isd[0] = 1.0
110125

@@ -125,7 +140,7 @@ def _build_display(self, gold_cell, trap_cells):
125140

126141
all_objects = []
127142

128-
## List of border points' coordinates
143+
# List of border points' coordinates
129144
bp_list = [
130145
(CELL_SIZE - MARGIN, CELL_SIZE - MARGIN),
131146
(screen_width - CELL_SIZE + MARGIN, CELL_SIZE - MARGIN),
@@ -137,33 +152,33 @@ def _build_display(self, gold_cell, trap_cells):
137152
border.set_linewidth(5)
138153
all_objects.append(border)
139154

140-
## Vertical lines
155+
# Vertical lines
141156
for col in range(self.num_cols + 1):
142157
x1, y1 = (col + 1) * CELL_SIZE, CELL_SIZE
143158
x2, y2 = (col + 1) * CELL_SIZE, \
144159
(self.num_rows + 1) * CELL_SIZE
145160
line = rendering.PolyLine([(x1, y1), (x2, y2)], False)
146161
all_objects.append(line)
147162

148-
## Horizontal lines
163+
# Horizontal lines
149164
for row in range(self.num_rows + 1):
150165
x1, y1 = CELL_SIZE, (row + 1) * CELL_SIZE
151166
x2, y2 = (self.num_cols + 1) * CELL_SIZE, \
152167
(row + 1) * CELL_SIZE
153168
line = rendering.PolyLine([(x1, y1), (x2, y2)], False)
154169
all_objects.append(line)
155170

156-
## Traps: --> circles
171+
# Traps: --> circles
157172
for cell in trap_cells:
158173
trap_coords = get_coords(*cell, loc='center')
159174
all_objects.append(draw_object([trap_coords]))
160175

161-
## Gold: --> triangle
176+
# Gold: --> triangle
162177
gold_coords = get_coords(*gold_cell,
163178
loc='interior_triangle')
164179
all_objects.append(draw_object(gold_coords))
165180

166-
## Agent --> square or robot
181+
# Agent --> square or robot
167182
if (os.path.exists('robot-coordinates.pkl') and CELL_SIZE == 100):
168183
agent_coord
169184
s = pickle.load(
@@ -215,4 +230,4 @@ def close(self):
215230
if res[2]:
216231
break
217232

218-
env.close()
233+
env.close()

ch18/gridworld/qlearning.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,18 @@
1-
## Script: qlearning.py
1+
# coding: utf-8
2+
3+
# Python Machine Learning 3rd Edition by
4+
# Sebastian Raschka (https://sebastianraschka.com) & Vahid Mirjalili](http://vahidmirjalili.com)
5+
# Packt Publishing Ltd. 2019
6+
#
7+
# Code Repository: https://github.com/rasbt/python-machine-learning-book-3rd-edition
8+
#
9+
# Code License: MIT License (https://github.com/rasbt/python-machine-learning-book-3rd-edition/blob/master/LICENSE.txt)
10+
11+
############################################################################
12+
# Chapter 18: Reinforcement Learning
13+
############################################################################
14+
15+
# Script: qlearning.py
216

317
from gridworld_env import GridWorldEnv
418
from agent import Agent
@@ -11,6 +25,7 @@
1125
Transition = namedtuple(
1226
'Transition', ('state', 'action', 'reward', 'next_state', 'done'))
1327

28+
1429
def run_qlearning(agent, env, num_episodes=50):
1530
history = []
1631
for episode in range(num_episodes):
@@ -34,6 +49,7 @@ def run_qlearning(agent, env, num_episodes=50):
3449

3550
return history
3651

52+
3753
def plot_learning_history(history):
3854
fig = plt.figure(1, figsize=(14, 10))
3955
ax = fig.add_subplot(2, 1, 1)

0 commit comments

Comments
 (0)