
Commit

Merge pull request rasbt#34 from rasbt/ch18-sr
ch18 fixes
rasbt authored Nov 3, 2019
2 parents b678cd0 + 553ccaf commit d246520
Showing 4 changed files with 109 additions and 46 deletions.
52 changes: 35 additions & 17 deletions ch18/cartpole/main.py
@@ -1,7 +1,24 @@
# coding: utf-8

# Python Machine Learning 3rd Edition by
# Sebastian Raschka (https://sebastianraschka.com) & Vahid Mirjalili (http://vahidmirjalili.com)
# Packt Publishing Ltd. 2019
#
# Code Repository: https://github.com/rasbt/python-machine-learning-book-3rd-edition
#
# Code License: MIT License (https://github.com/rasbt/python-machine-learning-book-3rd-edition/blob/master/LICENSE.txt)

############################################################################
# Chapter 18: Reinforcement Learning
############################################################################

# Script: cartpole/main.py

import gym
import numpy as np
import tensorflow as tf
import random
import matplotlib.pyplot as plt
from collections import namedtuple
from collections import deque

@@ -35,18 +52,18 @@ def __init__(
def _build_nn_model(self, n_layers=3):
self.model = tf.keras.Sequential()

## Hidden layers
# Hidden layers
for n in range(n_layers - 1):
self.model.add(tf.keras.layers.Dense(
units=32, activation='relu'))
self.model.add(tf.keras.layers.Dense(
units=32, activation='relu'))

## Last layer
# Last layer
self.model.add(tf.keras.layers.Dense(
units=self.action_size))

## Build & compile model
# Build & compile model
self.model.build(input_shape=(None, self.state_size))
self.model.compile(
loss='mse',
@@ -71,7 +88,7 @@ def _learn(self, batch_samples):
target = (r +
self.gamma * np.amax(
self.model.predict(next_s)[0]
)
)
)
target_all = self.model.predict(s)[0]
target_all[a] = target
@@ -92,19 +109,20 @@ def replay(self, batch_size):
history = self._learn(samples)
return history.history['loss'][0]

def plot_learning_history(history):
fig = plt.figure(1, figsize=(14, 5))
ax = fig.add_subplot(1, 1, 1)
episodes = np.arange(len(history[0])) + 1
plt.plot(episodes, history[0], lw=4,
marker='o', markersize=10)
ax.tick_params(axis='both', which='major', labelsize=15)
plt.xlabel('Episodes', size=20)
plt.ylabel('# Total Rewards', size=20)
plt.show()

def plot_learning_history(history):
fig = plt.figure(1, figsize=(14, 5))
ax = fig.add_subplot(1, 1, 1)
episodes = np.arange(len(history[0])) + 1
plt.plot(episodes, history[0], lw=4,
marker='o', markersize=10)
ax.tick_params(axis='both', which='major', labelsize=15)
plt.xlabel('Episodes', size=20)
plt.ylabel('# Total Rewards', size=20)
plt.show()


## General settings
# General settings
EPISODES = 200
batch_size = 32
init_replay_memory_size = 500
@@ -115,7 +133,7 @@ def plot_learning_history(history):
state = env.reset()
state = np.reshape(state, [1, agent.state_size])

## Filling up the replay-memory
# Filling up the replay-memory
for i in range(init_replay_memory_size):
action = agent.choose_action(state)
next_state, reward, done, _ = env.step(action)
@@ -151,4 +169,4 @@ def plot_learning_history(history):
break
loss = agent.replay(batch_size)
losses.append(loss)
plot_learning_history((total_rewards, losses))
plot_learning_history(total_rewards)
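
The _learn hunk above only re-indents the closing parentheses of the Q-target expression; the computation itself is the standard one-step DQN target. Below is a minimal sketch of that update for a single transition, assuming a compiled Keras model with one output unit per action. The helper name q_target_update and the done handling are illustrative, not copied verbatim from main.py.

import numpy as np

def q_target_update(model, s, a, r, next_s, done, gamma=0.95):
    # Bootstrap target: r + gamma * max_a' Q(next_s, a'), or just r at a terminal state
    if done:
        target = r
    else:
        target = r + gamma * np.amax(model.predict(next_s)[0])
    # Keep the network's current estimates for the other actions,
    # overwrite only the value of the action that was actually taken
    target_all = model.predict(s)[0]
    target_all[a] = target
    # One gradient step toward the corrected targets (s has shape (1, state_size))
    return model.fit(s, target_all[np.newaxis, :], epochs=1, verbose=0)
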
22 changes: 18 additions & 4 deletions ch18/gridworld/agent.py
@@ -1,4 +1,18 @@
## Script: agent.py
# coding: utf-8

# Python Machine Learning 3rd Edition by
# Sebastian Raschka (https://sebastianraschka.com) & Vahid Mirjalili (http://vahidmirjalili.com)
# Packt Publishing Ltd. 2019
#
# Code Repository: https://github.com/rasbt/python-machine-learning-book-3rd-edition
#
# Code License: MIT License (https://github.com/rasbt/python-machine-learning-book-3rd-edition/blob/master/LICENSE.txt)

############################################################################
# Chapter 18: Reinforcement Learning
############################################################################

# Script: agent.py

from collections import defaultdict
import numpy as np
@@ -19,7 +33,7 @@ def __init__(
self.epsilon_min = epsilon_min
self.epsilon_decay = epsilon_decay

## Define the q_table
# Define the q_table
self.q_table = defaultdict(lambda: np.zeros(self.env.nA))

def choose_action(self, state):
@@ -41,10 +55,10 @@ def _learn(self, transition):
else:
q_target = r + self.gamma*np.max(self.q_table[next_s])

## Update the q_table
# Update the q_table
self.q_table[s][a] += self.lr * (q_target - q_val)

## Adjust the epsilon
# Adjust the epsilon
self._adjust_epsilon()

def _adjust_epsilon(self):
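
The agent.py hunks change only comment markers around the tabular update, which is the usual Q-learning rule Q(s, a) <- Q(s, a) + lr * (r + gamma * max_a' Q(s', a') - Q(s, a)) followed by an epsilon decay. A small self-contained sketch follows, sized for the 4-action grid world; the constants and the learn helper are illustrative rather than the book's exact code.

from collections import defaultdict
import numpy as np

n_actions = 4
q_table = defaultdict(lambda: np.zeros(n_actions))
lr, gamma = 0.01, 0.9
epsilon, epsilon_min, epsilon_decay = 1.0, 0.1, 0.95

def learn(s, a, r, next_s, done):
    global epsilon
    q_val = q_table[s][a]
    # Terminal transitions do not bootstrap from the next state
    q_target = r if done else r + gamma * np.max(q_table[next_s])
    q_table[s][a] += lr * (q_target - q_val)
    # Decay exploration, but never below the floor
    epsilon = max(epsilon * epsilon_decay, epsilon_min)
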
63 changes: 39 additions & 24 deletions ch18/gridworld/gridworld_env.py
@@ -1,7 +1,25 @@
# coding: utf-8

# Python Machine Learning 3rd Edition by
# Sebastian Raschka (https://sebastianraschka.com) & Vahid Mirjalili (http://vahidmirjalili.com)
# Packt Publishing Ltd. 2019
#
# Code Repository: https://github.com/rasbt/python-machine-learning-book-3rd-edition
#
# Code License: MIT License (https://github.com/rasbt/python-machine-learning-book-3rd-edition/blob/master/LICENSE.txt)

############################################################################
# Chapter 18: Reinforcement Learning
############################################################################

# Script: gridworld_env.py

import numpy as np
from gym.envs.toy_text import discrete
from collections import defaultdict
import time
import pickle
import os

from gym.envs.classic_control import rendering

@@ -19,32 +37,29 @@ def get_coords(row, col, loc='center'):
xl, xr = xc - half_size, xc + half_size
yt, yb = yc - half_size, yc + half_size
return [(xl, yt), (xr, yt), (xr, yb), (xl, yb)]
elif loc=='interior_triangle':
elif loc == 'interior_triangle':
x1, y1 = xc, yc + CELL_SIZE//3
x2, y2 = xc + CELL_SIZE//3, yc - CELL_SIZE//3
x3, y3 = xc - CELL_SIZE//3, yc - CELL_SIZE//3
return [(x1, y1), (x2, y2), (x3, y3)]


def draw_object(coords_list):
if len(coords_list) == 1: # -> circle
if len(coords_list) == 1: # -> circle
obj = rendering.make_circle(int(0.45*CELL_SIZE))
obj_transform = rendering.Transform()
obj.add_attr(obj_transform)
obj_transform.set_translation(*coords_list[0])
obj.set_color(0.2, 0.2, 0.2) # -> black
elif len(coords_list) == 3: # -> triangle
obj.set_color(0.2, 0.2, 0.2) # -> black
elif len(coords_list) == 3: # -> triangle
obj = rendering.FilledPolygon(coords_list)
obj.set_color(0.9, 0.6, 0.2) # -> yellow
elif len(coords_list) > 3: # -> polygon
obj.set_color(0.9, 0.6, 0.2) # -> yellow
elif len(coords_list) > 3: # -> polygon
obj = rendering.FilledPolygon(coords_list)
obj.set_color(0.4, 0.4, 0.8) # -> blue
obj.set_color(0.4, 0.4, 0.8) # -> blue
return obj


import pickle
import os


class GridWorldEnv(discrete.DiscreteEnv):
def __init__(self, num_rows=4, num_cols=6, delay=0.05):
self.num_rows = num_rows
@@ -60,18 +75,18 @@ def __init__(self, num_rows=4, num_cols=6, delay=0.05):
self.action_defs = {0: move_up, 1: move_right,
2: move_down, 3: move_left}

## Number of states/actions
# Number of states/actions
nS = num_cols * num_rows
nA = len(self.action_defs)
self.grid2state_dict = {(s // num_cols, s % num_cols): s
for s in range(nS)}
self.state2grid_dict = {s: (s // num_cols, s % num_cols)
for s in range(nS)}

## Gold state
# Gold state
gold_cell = (num_rows // 2, num_cols - 2)

## Trap states
# Trap states
trap_cells = [((gold_cell[0] + 1), gold_cell[1]),
(gold_cell[0], gold_cell[1] - 1),
((gold_cell[0] - 1), gold_cell[1])]
@@ -82,7 +97,7 @@ def __init__(self, num_rows=4, num_cols=6, delay=0.05):
self.terminal_states = [gold_state] + trap_states
print(self.terminal_states)

## Build the transition probability
# Build the transition probability
P = defaultdict(dict)
for s in range(nS):
row, col = self.state2grid_dict[s]
@@ -91,7 +106,7 @@ def __init__(self, num_rows=4, num_cols=6, delay=0.05):
action = self.action_defs[a]
next_s = self.grid2state_dict[action(row, col)]

## Terminal state
# Terminal state
if self.is_terminal(next_s):
r = (1.0 if next_s == self.terminal_states[0]
else -1.0)
@@ -104,7 +119,7 @@ def __init__(self, num_rows=4, num_cols=6, delay=0.05):
done = False
P[s][a] = [(1.0, next_s, r, done)]

## Initial state distribution
# Initial state distribution
isd = np.zeros(nS)
isd[0] = 1.0

@@ -125,7 +140,7 @@ def _build_display(self, gold_cell, trap_cells):

all_objects = []

## List of border points' coordinates
# List of border points' coordinates
bp_list = [
(CELL_SIZE - MARGIN, CELL_SIZE - MARGIN),
(screen_width - CELL_SIZE + MARGIN, CELL_SIZE - MARGIN),
@@ -137,33 +152,33 @@ def _build_display(self, gold_cell, trap_cells):
border.set_linewidth(5)
all_objects.append(border)

## Vertical lines
# Vertical lines
for col in range(self.num_cols + 1):
x1, y1 = (col + 1) * CELL_SIZE, CELL_SIZE
x2, y2 = (col + 1) * CELL_SIZE, \
(self.num_rows + 1) * CELL_SIZE
line = rendering.PolyLine([(x1, y1), (x2, y2)], False)
all_objects.append(line)

## Horizontal lines
# Horizontal lines
for row in range(self.num_rows + 1):
x1, y1 = CELL_SIZE, (row + 1) * CELL_SIZE
x2, y2 = (self.num_cols + 1) * CELL_SIZE, \
(row + 1) * CELL_SIZE
line = rendering.PolyLine([(x1, y1), (x2, y2)], False)
all_objects.append(line)

## Traps: --> circles
# Traps: --> circles
for cell in trap_cells:
trap_coords = get_coords(*cell, loc='center')
all_objects.append(draw_object([trap_coords]))

## Gold: --> triangle
# Gold: --> triangle
gold_coords = get_coords(*gold_cell,
loc='interior_triangle')
all_objects.append(draw_object(gold_coords))

## Agent --> square or robot
# Agent --> square or robot
if (os.path.exists('robot-coordinates.pkl') and CELL_SIZE == 100):
agent_coords = pickle.load(
@@ -215,4 +230,4 @@ def close(self):
if res[2]:
break

env.close()
env.close()
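
Most of the gridworld_env.py hunk fills the transition table P consumed by gym's discrete.DiscreteEnv, whose expected format is P[state][action] = [(probability, next_state, reward, done), ...]; a deterministic grid world uses a single tuple per entry. A tiny hand-written example for one state is given below; the state indices and reward values are illustrative, assuming the default 4x6 grid and the action mapping shown in the diff (0: up, 1: right, 2: down, 3: left).

P = {
    0: {                               # state 0 = cell (row 0, col 0)
        0: [(1.0, 0, 0.0, False)],     # move_up:    blocked at the top edge, stays in place
        1: [(1.0, 1, 0.0, False)],     # move_right: to cell (0, 1) -> state 1
        2: [(1.0, 6, 0.0, False)],     # move_down:  to cell (1, 0) -> state 6
        3: [(1.0, 0, 0.0, False)],     # move_left:  blocked at the left edge, stays in place
    },
}
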
18 changes: 17 additions & 1 deletion ch18/gridworld/qlearning.py
@@ -1,4 +1,18 @@
## Script: qlearning.py
# coding: utf-8

# Python Machine Learning 3rd Edition by
# Sebastian Raschka (https://sebastianraschka.com) & Vahid Mirjalili (http://vahidmirjalili.com)
# Packt Publishing Ltd. 2019
#
# Code Repository: https://github.com/rasbt/python-machine-learning-book-3rd-edition
#
# Code License: MIT License (https://github.com/rasbt/python-machine-learning-book-3rd-edition/blob/master/LICENSE.txt)

############################################################################
# Chapter 18: Reinforcement Learning
############################################################################

# Script: qlearning.py

from gridworld_env import GridWorldEnv
from agent import Agent
@@ -11,6 +25,7 @@
Transition = namedtuple(
'Transition', ('state', 'action', 'reward', 'next_state', 'done'))


def run_qlearning(agent, env, num_episodes=50):
history = []
for episode in range(num_episodes):
@@ -34,6 +49,7 @@ def run_qlearning(agent, env, num_episodes=50):

return history


def plot_learning_history(history):
fig = plt.figure(1, figsize=(14, 10))
ax = fig.add_subplot(2, 1, 1)
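
The qlearning.py hunks add the file header and blank-line separators; the Transition namedtuple defined in the hunk above is what run_qlearning feeds to the agent each step. A compressed sketch of a single episode step using that tuple; the episode_step helper and the surrounding objects are illustrative, assuming the old gym API where env.step returns a 4-tuple.

from collections import namedtuple

Transition = namedtuple(
    'Transition', ('state', 'action', 'reward', 'next_state', 'done'))

def episode_step(agent, env, state):
    action = agent.choose_action(state)
    next_state, reward, done, _ = env.step(action)
    # Package the experience and let the agent update its Q-table from it
    agent._learn(Transition(state, action, reward, next_state, done))
    return next_state, reward, done
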
