ch18 fixes
rasbt committed Nov 3, 2019
1 parent 13369d9 commit 553ccaf
Showing 3 changed files with 74 additions and 29 deletions.
22 changes: 18 additions & 4 deletions ch18/gridworld/agent.py
@@ -1,4 +1,18 @@
-## Script: agent.py
+# coding: utf-8
+
+# Python Machine Learning 3rd Edition by
+# Sebastian Raschka (https://sebastianraschka.com) & Vahid Mirjalili (http://vahidmirjalili.com)
+# Packt Publishing Ltd. 2019
+#
+# Code Repository: https://github.com/rasbt/python-machine-learning-book-3rd-edition
+#
+# Code License: MIT License (https://github.com/rasbt/python-machine-learning-book-3rd-edition/blob/master/LICENSE.txt)
+
+############################################################################
+# Chapter 18: Reinforcement Learning
+############################################################################
+
+# Script: agent.py
+
from collections import defaultdict
import numpy as np
@@ -19,7 +33,7 @@ def __init__(
self.epsilon_min = epsilon_min
self.epsilon_decay = epsilon_decay

-## Define the q_table
+# Define the q_table
self.q_table = defaultdict(lambda: np.zeros(self.env.nA))

def choose_action(self, state):
@@ -41,10 +55,10 @@ def _learn(self, transition):
else:
q_target = r + self.gamma*np.max(self.q_table[next_s])

-## Update the q_table
+# Update the q_table
self.q_table[s][a] += self.lr * (q_target - q_val)

-## Adjust the epislon
+# Adjust the epsilon
self._adjust_epsilon()

def _adjust_epsilon(self):
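Editor's note on this hunk: choose_action (truncated above) performs epsilon-greedy selection, and _learn implements the tabular Q-learning update Q(s, a) <- Q(s, a) + lr * (r + gamma * max_a' Q(s', a') - Q(s, a)). A minimal, self-contained sketch of both, assuming only the names visible in the diff (lr, gamma, epsilon, q_table) plus an illustrative fixed action count; this is not the book's exact code:

    from collections import defaultdict
    import numpy as np

    n_actions = 4                     # assumption: the gridworld has 4 moves
    q_table = defaultdict(lambda: np.zeros(n_actions))
    lr, gamma, epsilon = 0.01, 0.9, 0.1

    def choose_action(state):
        # epsilon-greedy: explore with probability epsilon, else act greedily
        if np.random.uniform() < epsilon:
            return np.random.choice(n_actions)
        return int(np.argmax(q_table[state]))

    def q_update(s, a, r, next_s, done):
        # terminal transitions use r alone; otherwise bootstrap from next state
        q_target = r if done else r + gamma * np.max(q_table[next_s])
        q_table[s][a] += lr * (q_target - q_table[s][a])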
63 changes: 39 additions & 24 deletions ch18/gridworld/gridworld_env.py
@@ -1,7 +1,25 @@
+# coding: utf-8
+
+# Python Machine Learning 3rd Edition by
+# Sebastian Raschka (https://sebastianraschka.com) & Vahid Mirjalili (http://vahidmirjalili.com)
+# Packt Publishing Ltd. 2019
+#
+# Code Repository: https://github.com/rasbt/python-machine-learning-book-3rd-edition
+#
+# Code License: MIT License (https://github.com/rasbt/python-machine-learning-book-3rd-edition/blob/master/LICENSE.txt)
+
+############################################################################
+# Chapter 18: Reinforcement Learning
+############################################################################
+
+# Script: gridworld_env.py
+
import numpy as np
from gym.envs.toy_text import discrete
from collections import defaultdict
import time
+import pickle
+import os

from gym.envs.classic_control import rendering

@@ -19,32 +37,29 @@ def get_coords(row, col, loc='center'):
xl, xr = xc - half_size, xc + half_size
yt, yb = yc - half_size, yc + half_size
return [(xl, yt), (xr, yt), (xr, yb), (xl, yb)]
-elif loc=='interior_triangle':
+elif loc == 'interior_triangle':
x1, y1 = xc, yc + CELL_SIZE//3
x2, y2 = xc + CELL_SIZE//3, yc - CELL_SIZE//3
x3, y3 = xc - CELL_SIZE//3, yc - CELL_SIZE//3
return [(x1, y1), (x2, y2), (x3, y3)]


def draw_object(coords_list):
-if len(coords_list) == 1: # -> circle
+if len(coords_list) == 1:  # -> circle
obj = rendering.make_circle(int(0.45*CELL_SIZE))
obj_transform = rendering.Transform()
obj.add_attr(obj_transform)
obj_transform.set_translation(*coords_list[0])
-obj.set_color(0.2, 0.2, 0.2) # -> black
-elif len(coords_list) == 3: # -> triangle
+obj.set_color(0.2, 0.2, 0.2)  # -> black
+elif len(coords_list) == 3:  # -> triangle
obj = rendering.FilledPolygon(coords_list)
-obj.set_color(0.9, 0.6, 0.2) # -> yellow
-elif len(coords_list) > 3: # -> polygon
+obj.set_color(0.9, 0.6, 0.2)  # -> yellow
+elif len(coords_list) > 3:  # -> polygon
obj = rendering.FilledPolygon(coords_list)
-obj.set_color(0.4, 0.4, 0.8) # -> blue
+obj.set_color(0.4, 0.4, 0.8)  # -> blue
return obj


-import pickle
-import os
-
class GridWorldEnv(discrete.DiscreteEnv):
def __init__(self, num_rows=4, num_cols=6, delay=0.05):
self.num_rows = num_rows
@@ -60,18 +75,18 @@ def __init__(self, num_rows=4, num_cols=6, delay=0.05):
self.action_defs = {0: move_up, 1: move_right,
2: move_down, 3: move_left}

-## Number of states/actions
+# Number of states/actions
nS = num_cols * num_rows
nA = len(self.action_defs)
self.grid2state_dict = {(s // num_cols, s % num_cols): s
for s in range(nS)}
self.state2grid_dict = {s: (s // num_cols, s % num_cols)
for s in range(nS)}

-## Gold state
+# Gold state
gold_cell = (num_rows // 2, num_cols - 2)

-## Trap states
+# Trap states
trap_cells = [((gold_cell[0] + 1), gold_cell[1]),
(gold_cell[0], gold_cell[1] - 1),
((gold_cell[0] - 1), gold_cell[1])]
@@ -82,7 +97,7 @@ def __init__(self, num_rows=4, num_cols=6, delay=0.05):
self.terminal_states = [gold_state] + trap_states
print(self.terminal_states)

-## Build the transition probability
+# Build the transition probability
P = defaultdict(dict)
for s in range(nS):
row, col = self.state2grid_dict[s]
@@ -91,7 +106,7 @@ def __init__(self, num_rows=4, num_cols=6, delay=0.05):
action = self.action_defs[a]
next_s = self.grid2state_dict[action(row, col)]

-## Terminal state
+# Terminal state
if self.is_terminal(next_s):
r = (1.0 if next_s == self.terminal_states[0]
else -1.0)
@@ -104,7 +119,7 @@ def __init__(self, num_rows=4, num_cols=6, delay=0.05):
done = False
P[s][a] = [(1.0, next_s, r, done)]

-## Initial state distribution
+# Initial state distribution
isd = np.zeros(nS)
isd[0] = 1.0

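Editor's note: discrete.DiscreteEnv, the Gym base class used here, consumes the transition table in the format P[state][action] = [(probability, next_state, reward, done), ...]. This gridworld is deterministic, so each list holds exactly one tuple with probability 1.0, and states are the row-major flattening s = row * num_cols + col. A tiny illustrative check under the diff's default 4x6 grid (everything below is a sketch, not code from the commit):

    num_rows, num_cols = 4, 6
    grid2state = {(s // num_cols, s % num_cols): s
                  for s in range(num_rows * num_cols)}
    assert grid2state[(0, 0)] == 0   # top-left corner is state 0
    assert grid2state[(1, 2)] == 8   # 1 * num_cols + 2

    # A deterministic entry in DiscreteEnv's expected format, e.g. for
    # action 1 (move_right) taken from state 8 = cell (1, 2):
    # P[8][1] = [(1.0, grid2state[(1, 3)], 0.0, False)]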
@@ -125,7 +140,7 @@ def _build_display(self, gold_cell, trap_cells):

all_objects = []

-## List of border points' coordinates
+# List of border points' coordinates
bp_list = [
(CELL_SIZE - MARGIN, CELL_SIZE - MARGIN),
(screen_width - CELL_SIZE + MARGIN, CELL_SIZE - MARGIN),
@@ -137,33 +152,33 @@ def _build_display(self, gold_cell, trap_cells):
border.set_linewidth(5)
all_objects.append(border)

-## Vertical lines
+# Vertical lines
for col in range(self.num_cols + 1):
x1, y1 = (col + 1) * CELL_SIZE, CELL_SIZE
x2, y2 = (col + 1) * CELL_SIZE, \
(self.num_rows + 1) * CELL_SIZE
line = rendering.PolyLine([(x1, y1), (x2, y2)], False)
all_objects.append(line)

-## Horizontal lines
+# Horizontal lines
for row in range(self.num_rows + 1):
x1, y1 = CELL_SIZE, (row + 1) * CELL_SIZE
x2, y2 = (self.num_cols + 1) * CELL_SIZE, \
(row + 1) * CELL_SIZE
line = rendering.PolyLine([(x1, y1), (x2, y2)], False)
all_objects.append(line)

-## Traps: --> circles
+# Traps: --> circles
for cell in trap_cells:
trap_coords = get_coords(*cell, loc='center')
all_objects.append(draw_object([trap_coords]))

-## Gold: --> triangle
+# Gold: --> triangle
gold_coords = get_coords(*gold_cell,
loc='interior_triangle')
all_objects.append(draw_object(gold_coords))

-## Agent --> square or robot
+# Agent --> square or robot
if (os.path.exists('robot-coordinates.pkl') and CELL_SIZE == 100):
agent_coords = pickle.load(
@@ -215,4 +230,4 @@ def close(self):
if res[2]:
break

-env.close()
+env.close()
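Editor's note: the tail above belongs to the script's demo block, where res is the tuple returned by Gym's env.step(action) and res[2] is the done flag. The loop body is cut off in this view; a hedged reconstruction of such a smoke-test rollout, assuming the env exposes Gym's usual reset/step/render interface (illustrative only, not the commit's exact code):

    import numpy as np

    env = GridWorldEnv(num_rows=4, num_cols=6)
    env.reset()
    for _ in range(100):
        action = np.random.choice(env.nA)   # random policy, just to exercise the env
        res = env.step(action)              # (next_state, reward, done, info)
        env.render()
        if res[2]:                          # episode ended in the gold or a trap state
            break
    env.close()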
18 changes: 17 additions & 1 deletion ch18/gridworld/qlearning.py
@@ -1,4 +1,18 @@
-## Script: qlearning.py
+# coding: utf-8
+
+# Python Machine Learning 3rd Edition by
+# Sebastian Raschka (https://sebastianraschka.com) & Vahid Mirjalili (http://vahidmirjalili.com)
+# Packt Publishing Ltd. 2019
+#
+# Code Repository: https://github.com/rasbt/python-machine-learning-book-3rd-edition
+#
+# Code License: MIT License (https://github.com/rasbt/python-machine-learning-book-3rd-edition/blob/master/LICENSE.txt)
+
+############################################################################
+# Chapter 18: Reinforcement Learning
+############################################################################
+
+# Script: qlearning.py
+
from gridworld_env import GridWorldEnv
from agent import Agent
@@ -11,6 +25,7 @@
Transition = namedtuple(
'Transition', ('state', 'action', 'reward', 'next_state', 'done'))

+
def run_qlearning(agent, env, num_episodes=50):
history = []
for episode in range(num_episodes):
@@ -34,6 +49,7 @@ def run_qlearning(agent, env, num_episodes=50):

return history

+
def plot_learning_history(history):
fig = plt.figure(1, figsize=(14, 10))
ax = fig.add_subplot(2, 1, 1)
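Editor's note on how the three changed files fit together: qlearning.py builds the environment and the agent, runs the training loop, and plots the reward history. A minimal driver in that spirit, run from a separate script or notebook; the Agent constructor arguments are assumed to have defaults, since the exact signatures and the script's __main__ block are not shown in this diff:

    from gridworld_env import GridWorldEnv
    from agent import Agent
    from qlearning import run_qlearning, plot_learning_history

    env = GridWorldEnv(num_rows=4, num_cols=6)
    agent = Agent(env)   # assumption: remaining hyperparameters default
    history = run_qlearning(agent, env, num_episodes=50)
    plot_learning_history(history)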
