diff --git a/ch18/ch18.ipynb b/ch18/ch18.ipynb index 7e12ec2f..cfb34307 100644 --- a/ch18/ch18.ipynb +++ b/ch18/ch18.ipynb @@ -387,12 +387,28 @@ "metadata": {}, "source": [ "```python\n", - "## Script: gridworld_env.py\n", + "# coding: utf-8\n", + "\n", + "# Python Machine Learning 3rd Edition by\n", + "# Sebastian Raschka (https://sebastianraschka.com) & Vahid Mirjalili (http://vahidmirjalili.com)\n", + "# Packt Publishing Ltd. 2019\n", + "#\n", + "# Code Repository: https://github.com/rasbt/python-machine-learning-book-3rd-edition\n", + "#\n", + "# Code License: MIT License (https://github.com/rasbt/python-machine-learning-book-3rd-edition/blob/master/LICENSE.txt)\n", + "\n", + "############################################################################\n", + "# Chapter 18: Reinforcement Learning\n", + "############################################################################\n", + "\n", + "# Script: gridworld_env.py\n", "\n", "import numpy as np\n", "from gym.envs.toy_text import discrete\n", "from collections import defaultdict\n", "import time\n", + "import pickle\n", + "import os\n", "\n", "from gym.envs.classic_control import rendering\n", "\n", @@ -410,32 +426,29 @@ " xl, xr = xc - half_size, xc + half_size\n", " yt, yb = xc - half_size, xc + half_size\n", " return [(xl, yt), (xr, yt), (xr, yb), (xl, yb)]\n", - " elif loc=='interior_triangle':\n", + " elif loc == 'interior_triangle':\n", " x1, y1 = xc, yc + CELL_SIZE//3\n", " x2, y2 = xc + CELL_SIZE//3, yc - CELL_SIZE//3\n", " x3, y3 = xc - CELL_SIZE//3, yc - CELL_SIZE//3\n", " return [(x1, y1), (x2, y2), (x3, y3)]\n", "\n", + "\n", "def draw_object(coords_list):\n", - " if len(coords_list) == 1: # -> circle\n", + " if len(coords_list) == 1: # -> circle\n", " obj = rendering.make_circle(int(0.45*CELL_SIZE))\n", " obj_transform = rendering.Transform()\n", " obj.add_attr(obj_transform)\n", " obj_transform.set_translation(*coords_list[0])\n", - " obj.set_color(0.2, 0.2, 0.2) # -> black\n", - " elif len(coords_list) == 3: # -> triangle\n", + " obj.set_color(0.2, 0.2, 0.2) # -> black\n", + " elif len(coords_list) == 3: # -> triangle\n", " obj = rendering.FilledPolygon(coords_list)\n", - " obj.set_color(0.9, 0.6, 0.2) # -> yellow\n", - " elif len(coords_list) > 3: # -> polygon\n", + " obj.set_color(0.9, 0.6, 0.2) # -> yellow\n", + " elif len(coords_list) > 3: # -> polygon\n", " obj = rendering.FilledPolygon(coords_list)\n", - " obj.set_color(0.4, 0.4, 0.8) # -> blue\n", + " obj.set_color(0.4, 0.4, 0.8) # -> blue\n", " return obj\n", "\n", "\n", - "import pickle\n", - "import os\n", - "\n", - "\n", "class GridWorldEnv(discrete.DiscreteEnv):\n", " def __init__(self, num_rows=4, num_cols=6, delay=0.05):\n", " self.num_rows = num_rows\n", @@ -451,7 +464,7 @@ " self.action_defs = {0: move_up, 1: move_right,\n", " 2: move_down, 3: move_left}\n", "\n", - " ## Number of states/actions\n", + " # Number of states/actions\n", " nS = num_cols * num_rows\n", " nA = len(self.action_defs)\n", " self.grid2state_dict = {(s // num_cols, s % num_cols): s\n", @@ -459,10 +472,10 @@ " self.state2grid_dict = {s: (s // num_cols, s % num_cols)\n", " for s in range(nS)}\n", "\n", - " ## Gold state\n", + " # Gold state\n", " gold_cell = (num_rows // 2, num_cols - 2)\n", "\n", - " ## Trap states\n", + " # Trap states\n", " trap_cells = 
[((gold_cell[0] + 1), gold_cell[1]),\n", " (gold_cell[0], gold_cell[1] - 1),\n", " ((gold_cell[0] - 1), gold_cell[1])]\n", @@ -473,7 +486,7 @@ " self.terminal_states = [gold_state] + trap_states\n", " print(self.terminal_states)\n", "\n", - " ## Build the transition probability\n", + " # Build the transition probability\n", " P = defaultdict(dict)\n", " for s in range(nS):\n", " row, col = self.state2grid_dict[s]\n", @@ -482,7 +495,7 @@ " action = self.action_defs[a]\n", " next_s = self.grid2state_dict[action(row, col)]\n", "\n", - " ## Terminal state\n", + " # Terminal state\n", " if self.is_terminal(next_s):\n", " r = (1.0 if next_s == self.terminal_states[0]\n", " else -1.0)\n", @@ -495,7 +508,7 @@ " done = False\n", " P[s][a] = [(1.0, next_s, r, done)]\n", "\n", - " ## Initial state distribution\n", + " # Initial state distribution\n", " isd = np.zeros(nS)\n", " isd[0] = 1.0\n", "\n", @@ -516,7 +529,7 @@ "\n", " all_objects = []\n", "\n", - " ## List of border points' coordinates\n", + " # List of border points' coordinates\n", " bp_list = [\n", " (CELL_SIZE - MARGIN, CELL_SIZE - MARGIN),\n", " (screen_width - CELL_SIZE + MARGIN, CELL_SIZE - MARGIN),\n", @@ -528,7 +541,7 @@ " border.set_linewidth(5)\n", " all_objects.append(border)\n", "\n", - " ## Vertical lines\n", + " # Vertical lines\n", " for col in range(self.num_cols + 1):\n", " x1, y1 = (col + 1) * CELL_SIZE, CELL_SIZE\n", " x2, y2 = (col + 1) * CELL_SIZE, \\\n", @@ -536,7 +549,7 @@ " line = rendering.PolyLine([(x1, y1), (x2, y2)], False)\n", " all_objects.append(line)\n", "\n", - " ## Horizontal lines\n", + " # Horizontal lines\n", " for row in range(self.num_rows + 1):\n", " x1, y1 = CELL_SIZE, (row + 1) * CELL_SIZE\n", " x2, y2 = (self.num_cols + 1) * CELL_SIZE, \\\n", @@ -544,17 +557,17 @@ " line = rendering.PolyLine([(x1, y1), (x2, y2)], False)\n", " all_objects.append(line)\n", "\n", - " ## Traps: --> circles\n", + " # Traps: --> circles\n", " for cell in trap_cells:\n", " trap_coords = get_coords(*cell, loc='center')\n", " all_objects.append(draw_object([trap_coords]))\n", "\n", - " ## Gold: --> triangle\n", + " # Gold: --> triangle\n", " gold_coords = get_coords(*gold_cell,\n", " loc='interior_triangle')\n", " all_objects.append(draw_object(gold_coords))\n", "\n", - " ## Agent --> square or robot\n", + " # Agent --> square or robot\n", " if (os.path.exists('robot-coordinates.pkl') and CELL_SIZE == 100):\n", " agent_coords = pickle.load(\n", @@ -651,7 +664,21 @@ "metadata": {}, "source": [ "```python\n", - "## Script: agent.py\n", + "# coding: utf-8\n", + "\n", + "# Python Machine Learning 3rd Edition by\n", + "# Sebastian Raschka (https://sebastianraschka.com) & Vahid Mirjalili (http://vahidmirjalili.com)\n", + "# Packt Publishing Ltd. 
2019\n", "#\n", + "# Code Repository: https://github.com/rasbt/python-machine-learning-book-3rd-edition\n", + "#\n", + "# Code License: MIT License (https://github.com/rasbt/python-machine-learning-book-3rd-edition/blob/master/LICENSE.txt)\n", + "\n", + "############################################################################\n", + "# Chapter 18: Reinforcement Learning\n", + "############################################################################\n", + "\n", + "# Script: agent.py\n", "\n", "from collections import defaultdict\n", "import numpy as np\n", @@ -672,7 +699,7 @@ " self.epsilon_min = epsilon_min\n", " self.epsilon_decay = epsilon_decay\n", "\n", - " ## Define the q_table\n", + " # Define the q_table\n", " self.q_table = defaultdict(lambda: np.zeros(self.env.nA))\n", "\n", " def choose_action(self, state):\n", @@ -694,15 +721,16 @@ " else:\n", " q_target = r + self.gamma*np.max(self.q_table[next_s])\n", "\n", - " ## Update the q_table\n", + " # Update the q_table\n", " self.q_table[s][a] += self.lr * (q_target - q_val)\n", "\n", - " ## Adjust the epislon\n", + " # Adjust the epsilon\n", " self._adjust_epsilon()\n", "\n", " def _adjust_epsilon(self):\n", " if self.epsilon > self.epsilon_min:\n", " self.epsilon *= self.epsilon_decay\n", + "\n", "```" ] }, @@ -711,10 +739,24 @@ "metadata": {}, "source": [ "```python\n", - "## Script: qlearning.py\n", + "# coding: utf-8\n", "\n", - "#from gridworld_env import GridWorldEnv\n", - "#from agent import Agent\n", + "# Python Machine Learning 3rd Edition by\n", + "# Sebastian Raschka (https://sebastianraschka.com) & Vahid Mirjalili (http://vahidmirjalili.com)\n", + "# Packt Publishing Ltd. 
2019\n", "#\n", + "# Code Repository: https://github.com/rasbt/python-machine-learning-book-3rd-edition\n", + "#\n", + "# Code License: MIT License (https://github.com/rasbt/python-machine-learning-book-3rd-edition/blob/master/LICENSE.txt)\n", + "\n", + "############################################################################\n", + "# Chapter 18: Reinforcement Learning\n", + "############################################################################\n", + "\n", + "# Script: qlearning.py\n", + "\n", + "from gridworld_env import GridWorldEnv\n", + "from agent import Agent\n", "from collections import namedtuple\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", @@ -724,6 +766,7 @@ "Transition = namedtuple(\n", " 'Transition', ('state', 'action', 'reward', 'next_state', 'done'))\n", "\n", + "\n", "def run_qlearning(agent, env, num_episodes=50):\n", " history = []\n", " for episode in range(num_episodes):\n", @@ -747,6 +790,7 @@ "\n", " return history\n", "\n", + "\n", "def plot_learning_history(history):\n", " fig = plt.figure(1, figsize=(14, 10))\n", " ax = fig.add_subplot(2, 1, 1)\n", @@ -776,6 +820,7 @@ "\n", " plot_learning_history(history)\n", "\n", + "\n", "```" ] }, @@ -915,16 +960,33 @@ "cell_type": "markdown", "metadata": {}, "source": [ + "\n", "```python\n", + "\n", + "# coding: utf-8\n", + "\n", + "# Python Machine Learning 3rd Edition by\n", + "# Sebastian Raschka (https://sebastianraschka.com) & Vahid Mirjalili (http://vahidmirjalili.com)\n", + "# Packt Publishing Ltd. 2019\n", + "#\n", + "# Code Repository: https://github.com/rasbt/python-machine-learning-book-3rd-edition\n", + "#\n", + "# Code License: MIT License (https://github.com/rasbt/python-machine-learning-book-3rd-edition/blob/master/LICENSE.txt)\n", + "\n", + "############################################################################\n", + "# Chapter 18: Reinforcement Learning\n", + "############################################################################\n", + "\n", + "# Script: cartpole/main.py\n", + "\n", "import gym\n", "import numpy as np\n", "import tensorflow as tf\n", "import random\n", + "import matplotlib.pyplot as plt\n", "from collections import namedtuple\n", "from collections import deque\n", "\n", - "import matplotlib.pyplot as plt\n", - "\n", "np.random.seed(1)\n", "tf.random.set_seed(1)\n", "\n", @@ -933,7 +995,6 @@ " 'next_state', 'done'))\n", "\n", "\n", - "\n", "class DQNAgent:\n", " def __init__(\n", " self, env, discount_factor=0.95,\n", @@ -956,18 +1017,18 @@ " def _build_nn_model(self, n_layers=3):\n", " self.model = tf.keras.Sequential()\n", "\n", - " ## Hidden layers\n", + " # Hidden layers\n", " for n in range(n_layers - 1):\n", " self.model.add(tf.keras.layers.Dense(\n", " units=32, activation='relu'))\n", " self.model.add(tf.keras.layers.Dense(\n", " units=32, activation='relu'))\n", "\n", - " ## Last layer\n", + " # Last layer\n", " self.model.add(tf.keras.layers.Dense(\n", " units=self.action_size))\n", "\n", - " ## Build & compile model\n", + " # Build & compile model\n", " self.model.build(input_shape=(None, self.state_size))\n", " self.model.compile(\n", " loss='mse',\n", @@ -992,7 +1053,7 @@ " target = (r +\n", " self.gamma * np.amax(\n", 
" self.model.predict(next_s)[0]\n", - " )\n", + " )\n", " )\n", " target_all = self.model.predict(s)[0]\n", " target_all[a] = target\n", @@ -1013,19 +1074,20 @@ " history = self._learn(samples)\n", " return history.history['loss'][0]\n", "\n", - " def plot_learning_history(history):\n", - " fig = plt.figure(1, figsize=(14, 5))\n", - " ax = fig.add_subplot(1, 1, 1)\n", - " episodes = np.arange(len(history[0])) + 1\n", - " plt.plot(episodes, history[0], lw=4,\n", - " marker='o', markersize=10)\n", - " ax.tick_params(axis='both', which='major', labelsize=15)\n", - " plt.xlabel('Episodes', size=20)\n", - " plt.ylabel('# Total Rewards', size=20)\n", - " plt.show()\n", + "\n", + "def plot_learning_history(history):\n", + " fig = plt.figure(1, figsize=(14, 5))\n", + " ax = fig.add_subplot(1, 1, 1)\n", + " episodes = np.arange(len(history)) + 1\n", + " plt.plot(episodes, history, lw=4,\n", + " marker='o', markersize=10)\n", + " ax.tick_params(axis='both', which='major', labelsize=15)\n", + " plt.xlabel('Episodes', size=20)\n", + " plt.ylabel('# Total Rewards', size=20)\n", + " plt.show()\n", "\n", "\n", - "## General settings\n", + "# General settings\n", "EPISODES = 200\n", "batch_size = 32\n", "init_replay_memory_size = 500\n", @@ -1036,7 +1098,7 @@ " state = env.reset()\n", " state = np.reshape(state, [1, agent.state_size])\n", "\n", - " ## Filling up the replay-memory\n", + " # Filling up the replay-memory\n", " for i in range(init_replay_memory_size):\n", " action = agent.choose_action(state)\n", " next_state, reward, done, _ = env.step(action)\n", @@ -1051,12 +1113,9 @@ "\n", " total_rewards, losses = [], []\n", " for e in range(EPISODES):\n", - " \n", " state = env.reset()\n", - " \n", " if e % 10 == 0:\n", " env.render()\n", - " \n", " state = np.reshape(state, [1, agent.state_size])\n", " for i in range(500):\n", " action = agent.choose_action(state)\n", @@ -1068,17 +1127,14 @@ " state = next_state\n", " if e % 10 == 0:\n", " env.render()\n", - " \n", " if done:\n", " total_rewards.append(i)\n", " print('Episode: %d/%d, Total reward: %d'\n", " % (e, EPISODES, i))\n", " break\n", - " \n", " loss = agent.replay(batch_size)\n", " losses.append(loss)\n", - " \n", - " plot_learning_history((total_rewards, losses))\n", + " plot_learning_history(total_rewards)\n", "\n", "```" ]