diff --git a/affordances_theory/AffordancesInContinuousEnvironment.ipynb b/affordances_theory/AffordancesInContinuousEnvironment.ipynb new file mode 100644 index 00000000..651d23e6 --- /dev/null +++ b/affordances_theory/AffordancesInContinuousEnvironment.ipynb @@ -0,0 +1,1746 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ya9j9pyzkyBZ" + }, + "source": [ + "Copyright 2020 The \"What Can I do Here? A Theory of Affordances In Reinforcement Learning\" Authors. All rights reserved.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "you may not use this file except in compliance with the License.\n", + "You may obtain a copy of the License at\n", + "\n", + " https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "Unless required by applicable law or agreed to in writing, software\n", + "distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "See the License for the specific language governing permissions and\n", + "limitations under the License." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "colab_type": "code", + "id": "LbWb35G9UHLO", + "outputId": "280cac1e-76e0-4960-a271-24d351f249bc" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Populating the interactive namespace from numpy and matplotlib\n" + ] + } + ], + "source": [ + "%tensorflow_version 2.x\n", + "%pylab inline\n", + "\n", + "# System imports\n", + "import copy\n", + "import dataclasses\n", + "import enum\n", + "import itertools\n", + "import numpy as np\n", + "import operator\n", + "import random\n", + "import time\n", + "from typing import Optional, List, Tuple, Any, Dict, Union, Callable\n", + "\n", + "\n", + "# Library imports.\n", + "from google.colab import files\n", + "from matplotlib import colors\n", + "import matplotlib.animation as animation\n", + "import matplotlib.pylab as plt\n", + "import tensorflow as tf\n", + "\n", + "import tensorflow_probability as tfp" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "UZV6OS_BUklD" + }, + "source": [ + "# Environment Specification" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "Mz4KtBOVUpOV" + }, + "outputs": [], + "source": [ + "#@title Point Class\n", + "@dataclasses.dataclass(order=True, frozen=True)\n", + "class Point:\n", + " \"\"\"A class representing a point in 2D space.\n", + "\n", + " Comes with some convenience functions.\n", + " \"\"\"\n", + " x: float\n", + " y: float\n", + "\n", + " def sum(self):\n", + " return self.x + self.y\n", + "\n", + " def l2norm(self):\n", + " \"\"\"Computes the L2 norm of the point.\"\"\"\n", + " return np.sqrt(self.x * self.x + self.y * self.y)\n", + "\n", + " def __add__(self, other: 'Point'):\n", + " return Point(self.x + other.x, self.y + other.y)\n", + "\n", + " def __sub__(self, other: 'Point'):\n", + " return Point(self.x - other.x, self.y - other.y)\n", + "\n", + " def normal_sample_around(self, scale: float):\n", + " \"\"\"Samples a point around the current point based on some noise.\"\"\"\n", + " new_coords = np.random.normal(dataclasses.astuple(self), scale)\n", + " new_coords = new_coords.astype(np.float32)\n", + " return Point(*new_coords)\n", + "\n", + " def is_close_to(self, other: 'Point', diff: float = 1e-4):\n", + " \"\"\"Determines if one point is close to another.\"\"\"\n", + " point_diff = self - other\n", + " if abs(point_diff.x) \u003c diff and abs(point_diff.y) \u003c diff:\n", + " return True\n", + " else:\n", + " return False\n", + "\n", + "# Test the points.\n", + "z1 = Point(0.4, 0.1)\n", + "assert z1.is_close_to(z1)\n", + "assert z1.is_close_to(Point(0.5, 0.0), 1.0)\n", + "assert not z1.is_close_to(Point(5.0, 0.0), 1.0)\n", + "z2 = Point(0.1, 0.1)\n", + "z3 = z1 - z2\n", + "assert isinstance(z3, Point)\n", + "assert z3.is_close_to(Point(0.3, 0.0))\n", + "assert isinstance(z3.normal_sample_around(0.1), Point)\n", + "\n", + "class Force(Point):\n", + " pass\n", + "\n", + "\n", + "# # Intersection code.\n", + "# See Sedgewick, Robert, and Kevin Wayne. Algorithms. , 2011.\n", + "# Chapter 6.1 on Geometric Primitives\n", + "# https://algs4.cs.princeton.edu/91primitives/\n", + "def _check_counter_clockwise(a: Point, b: Point, c: Point):\n", + " \"\"\"Checks if 3 points are counter clockwise to each other.\"\"\"\n", + " slope_AB_numerator = (b.y - a.y)\n", + " slope_AB_denominator = (b.x - a.x)\n", + " slope_AC_numerator = (c.y - a.y)\n", + " slope_AC_denominator = (c.x - a.x)\n", + " return (slope_AC_numerator * slope_AB_denominator \u003e= \\\n", + " slope_AB_numerator * slope_AC_denominator)\n", + "\n", + "def intersect(segment_1: Tuple[Point, Point], segment_2: Tuple[Point, Point]):\n", + " \"\"\"Checks if two line segments intersect.\"\"\"\n", + " a, b = segment_1\n", + " c, d = segment_2\n", + "\n", + " # Checking if there is an intersection is equivalent to:\n", + " # Exactly one counter clockwise path to D (from A or B) via C.\n", + " AC_ccw_CD = _check_counter_clockwise(a, c, d)\n", + " BC_ccw_CD = _check_counter_clockwise(b, c, d)\n", + " toD_via_C = AC_ccw_CD != BC_ccw_CD\n", + "\n", + " # AND\n", + " # Exactly one counterclockwise path from A (to C or D) via B.\n", + " AB_ccw_BC = _check_counter_clockwise(a, b, c)\n", + " AB_ccw_BD = _check_counter_clockwise(a, b, d)\n", + "\n", + " fromA_via_B = AB_ccw_BC != AB_ccw_BD\n", + "\n", + " return toD_via_C and fromA_via_B\n", + "\n", + "# Some simple tests to ensure everything is working.\n", + "assert not intersect((Point(1, 0), Point(1, 1)), (Point(0,0), Point(0, 1))), \\\n", + " 'Parallel lines detected as intersecting.'\n", + "assert not intersect((Point(0, 0), Point(1, 0)), (Point(0,1), Point(1, 1))), \\\n", + " 'Parallel lines detected as intersecting.'\n", + "assert intersect((Point(3, 5), Point(1, 1)), (Point(2, 2), Point(0, 1))), \\\n", + " 'Lines that intersect not detected.'\n", + "assert not intersect((Point(0, 0), Point(2, 2)), (Point(3, 3), Point(5, 1))), \\\n", + " 'Lines that do not intersect detected as intersecting'\n", + "assert intersect((Point(0, .5), Point(0, -.5)), (Point(.5, 0), Point(-.5, 0.))), \\\n", + " 'Lines that intersect not detected.'" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "IaC8khoBVZ2a" + }, + "outputs": [], + "source": [ + "#@title ContinuousWorld environment.\n", + "\n", + "class ContinuousWorld(object):\n", + " r\"\"\"The ContinuousWorld Environment.\n", + "\n", + " An agent can be anywhere in the grid. The agent provides Forces to move. When\n", + " the agent provides a force, it is applied and the final position is jittered.\n", + "\n", + " When the agent is reset, its location is drawn from a global start position\n", + " given by `drift_between`. This start position is non-stationary and drifts\n", + " toward the target start position as the environment resets with the speed\n", + " `drift_speed`.\n", + "\n", + " For example the start position is (0., 0.). After reseting once, the start\n", + " positon might drift toward (0.5, 0.5). After resetting again it may drift\n", + " again to (0., 0.). This happens smoothly according to the drifting speed.\n", + "\n", + " Walls can be specified in this environment. Detection works by checking if the\n", + " agents action forces it to go in a direction which collides with a wall.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " size: float,\n", + " wall_pairs: Optional[List[Tuple[Point, Point]]] = None,\n", + " drift_between: Optional[List[Tuple[Point, Point]]] = None,\n", + " movement_noise: float = 0.1,\n", + " seed: int = 1,\n", + " drift_speed: float = 0.5,\n", + " reset_noise: Optional[float] = None,\n", + " max_episode_length: int = 10,\n", + " max_action_force: float = 0.5,\n", + " verbose_reset: bool = False\n", + " ):\n", + " \"\"\"Initializes the Continuous World Environment.\n", + "\n", + " Args:\n", + " size: The size of the world.\n", + " wall_pairs: A list of tuple of points representing the start and end\n", + " positions of the wall.\n", + " drift_between: A list of tuple of points representing how the starting\n", + " distrubiton should change. If None, it will drift between the four\n", + " corners of the room.\n", + " movement_noise: The noise around each position after movement.\n", + " seed: The seed for the random number generator.\n", + " drift_speed: How quickly to move in the drift direction.\n", + " reset_noise: The noise around the reset position. Defaults to\n", + " movement_noise if not specified.\n", + " max_episode_length: The maximum length of the episode before resetting.\n", + " max_action_force: If using random_step() this will be the maximum random\n", + " force applied in the x and y direction.\n", + " verbose_reset: Prints out every time the global starting position is\n", + " reset.\n", + " \"\"\"\n", + " self._size = size\n", + " self._wall_pairs = wall_pairs or []\n", + " self._verbose_reset = verbose_reset\n", + "\n", + " # Points to drift the start position between.\n", + " if drift_between is None:\n", + " self._drift_between = (\n", + " Point((1/4) * size, (1/4) * size),\n", + " Point((1/4) * size, (3/4) * size),\n", + " Point((3/4) * size, (1/4) * size),\n", + " Point((3/4) * size, (3/4) * size),\n", + " )\n", + " else:\n", + " self._drift_between = drift_between\n", + "\n", + " self._noise = movement_noise\n", + " self._reset_noise = reset_noise or movement_noise\n", + " self._rng = np.random.RandomState(seed)\n", + " random.seed(seed)\n", + "\n", + " # The current and target starting positions.\n", + " # Internal to this class mu is used to refer to mean \"start position\".\n", + " # Therefore mu = current start position and end_mu is the target start\n", + " # position.\n", + " self._mu, self._end_mu = random.sample(self._drift_between, 2)\n", + " # The speed at which we will move toward the target position.\n", + " self._drift_speed = drift_speed\n", + " self.update_agent_position()\n", + " self._decide_new_target_mu()\n", + " self._max_episode_length = max_episode_length\n", + " self._current_episode_length = 0\n", + " self._terminated = True\n", + " self._max_action_force = max_action_force\n", + " self._recent_mu_updated = False\n", + "\n", + " def _decide_new_target_mu(self):\n", + " \"\"\"Decide a new target direction to move toward.\"\"\"\n", + " # The direction should be toward the \"target ending mu.\"\n", + " (new_end_mu,) = random.sample(self._drift_between, 1)\n", + " while new_end_mu == self._end_mu:\n", + " (new_end_mu,) = random.sample(self._drift_between, 1)\n", + "\n", + " self._end_mu = new_end_mu\n", + " self._decide_drift_direction()\n", + " if self._verbose_reset:\n", + " print(f'Target mu has been updated to: {self._end_mu}')\n", + " self._recent_mu_updated = True\n", + "\n", + " def _decide_drift_direction(self):\n", + " \"\"\"Decide the drifting direction to move in.\"\"\"\n", + " direction = self._end_mu - self._mu\n", + " l2 = direction.l2norm()\n", + " drift_direction = Point(direction.x / l2, direction.y / l2)\n", + " self._drift_direction = Point(\n", + " drift_direction.x * self._drift_speed,\n", + " drift_direction.y * self._drift_speed\n", + " )\n", + "\n", + " def _should_update_target_mu(self) -\u003e bool:\n", + " \"\"\"Decide if the drift direction should change.\"\"\"\n", + " # Condition 1: We are past the edge of the environment.\n", + " if self._past_edge(self._mu.x)[0] or self._past_edge(self._mu.y)[0]:\n", + " return True\n", + "\n", + " # Condition 2: Check if the current mu is close to the end mu.\n", + " return self._mu.is_close_to(self._end_mu, self._drift_speed)\n", + "\n", + " def update_current_start_position(self):\n", + " \"\"\"Update the current mu to drift toward mu_end. Change mu_end if needed.\"\"\"\n", + " if self._should_update_target_mu():\n", + " self._decide_new_target_mu()\n", + " self._decide_drift_direction()\n", + " proposed_mu = self._mu + self._drift_direction\n", + " self._mu = self._wrap_coordinate(proposed_mu)\n", + "\n", + " def _past_edge(self, x: float) -\u003e Tuple[bool, float]:\n", + " \"\"\"Checks if coordinate is beyond the edges.\"\"\"\n", + " if x \u003e= self._size:\n", + " return True, self._size\n", + " elif x \u003c= 0.0:\n", + " return True, 0.0\n", + " else:\n", + " return False, x\n", + "\n", + " def _wrap_coordinate(self, point: Point) -\u003e Point:\n", + " \"\"\"Wraps coordinates that are beyond edges.\"\"\"\n", + " wrapped_coordinates = map(self._past_edge, dataclasses.astuple(point))\n", + " return Point(*map(operator.itemgetter(1), wrapped_coordinates))\n", + "\n", + " def update_agent_position(self):\n", + " self._current_position = self._wrap_coordinate(\n", + " self._mu.normal_sample_around(self._noise))\n", + "\n", + " def set_agent_position(self, new_position: Point):\n", + " self._current_position = self._wrap_coordinate(new_position)\n", + "\n", + " def reset(self) -\u003e Tuple[float, float]:\n", + " \"\"\"Reset the current position of the agent and move the global mu.\"\"\"\n", + " self.update_current_start_position()\n", + " self.update_agent_position()\n", + " self._current_episode_length = 0\n", + " self._terminated = False\n", + " return self._current_position\n", + "\n", + " def get_random_force(self) -\u003e Force:\n", + " return Force(*self._rng.uniform(\n", + " -self._max_action_force, self._max_action_force, 2))\n", + "\n", + " def random_step(self):\n", + " random_action = self.get_random_force()\n", + " to_be_returned = self.step(random_action)\n", + " to_be_returned[-1]['action_taken'] = random_action\n", + " return to_be_returned\n", + "\n", + " @property\n", + " def agent_position(self):\n", + " return dataclasses.astuple(self._current_position)\n", + "\n", + " @property\n", + " def start_position(self):\n", + " return dataclasses.astuple(self._mu)\n", + "\n", + " @property\n", + " def size(self):\n", + " return self._size\n", + "\n", + " @property\n", + " def walls(self):\n", + " return self._wall_pairs\n", + "\n", + " def _check_goes_through_wall(self, start: Point, end: Point):\n", + " if not self._wall_pairs: return False\n", + "\n", + " for pair in self._wall_pairs:\n", + " if intersect((start, end), pair):\n", + " return True\n", + " return False\n", + "\n", + " def step(\n", + " self,\n", + " action: Force\n", + " ) -\u003e Tuple[Tuple[float, float], Optional[float], bool, Dict[str, Any]]:\n", + " \"\"\"Does a step in the environment using the action.\n", + "\n", + " Args:\n", + " action: Force applied by the agent.\n", + "\n", + " Returns:\n", + " Agent position: A tuple of two floats.\n", + " The reward.\n", + " An indicator if the episode terminated.\n", + " A dictionary containing any information about the step.\n", + " \"\"\"\n", + " if self._terminated:\n", + " raise ValueError('Episode is over. Please reset the environment.')\n", + " perturbed_action = action.normal_sample_around(self._noise)\n", + "\n", + " proposed_position = self._wrap_coordinate(\n", + " self._current_position + perturbed_action)\n", + "\n", + " goes_through_wall = self._check_goes_through_wall(\n", + " self._current_position, proposed_position)\n", + "\n", + " if not goes_through_wall:\n", + " self._current_position = proposed_position\n", + "\n", + " self._current_episode_length += 1\n", + "\n", + " if self._current_episode_length \u003e self._max_episode_length:\n", + " self._terminated = True\n", + "\n", + " recent_mu_updated = self._recent_mu_updated\n", + " self._recent_mu_updated = False\n", + " return (\n", + " self._current_position,\n", + " None,\n", + " self._terminated,\n", + " {\n", + " 'goes_through_wall': goes_through_wall,\n", + " 'proposed_position': proposed_position,\n", + " 'recent_start_position_updated': recent_mu_updated\n", + " }\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "O43gq5h_YS2I" + }, + "outputs": [], + "source": [ + "#@title Visualization suite.\n", + "\n", + "def visualize_environment(\n", + " world,\n", + " ax,\n", + " scaling=1.0,\n", + " agent_color='r',\n", + " agent_size=0.2,\n", + " start_color='g',\n", + " draw_agent=True,\n", + " draw_start_mu=True,\n", + " draw_target_mu=True,\n", + " draw_walls=True,\n", + " write_text=True):\n", + " \"\"\"Visualize the continuous grid world.\n", + "\n", + " The agent will be drawn as a circle. The start and target\n", + " locations will be drawn by a cross. Walls will be drawn in\n", + " black.\n", + "\n", + " Args:\n", + " world: The continuous gridworld to visualize.\n", + " ax: The matplotlib axes to draw the gridworld.\n", + " scaling: Scale the plot by this factor.\n", + " agent_color: Color of the agent.\n", + " agent_size: Size of the agent in the world.\n", + " start_color: Color of the start marker.\n", + " draw_agent: Boolean that controls drawing agent.\n", + " draw_start_mu: Boolean that controls drawing starting position.\n", + " draw_target_mu: Boolean that controls drawing ending position.\n", + " draw_walls: Boolean that controls drawing walls.\n", + " write_text: Boolean to write text for each component being drawn.\n", + " \"\"\"\n", + " scaled_size = scaling * world.size\n", + "\n", + " # Draw the outer walls.\n", + " ax.hlines(0, 0, scaled_size)\n", + " ax.hlines(scaled_size, 0, scaled_size)\n", + " ax.vlines(scaled_size, 0, scaled_size)\n", + " ax.vlines(0, 0, scaled_size)\n", + "\n", + " for wall_pair in world.walls:\n", + " ax.plot(\n", + " [p.x * scaling for p in wall_pair],\n", + " [p.y * scaling for p in wall_pair],\n", + " color='k')\n", + "\n", + " if draw_start_mu:\n", + " # Draw the position of the start dist.\n", + " x, y = [p * scaling for p in world.mu_start_position]\n", + " ax.scatter([x], [y], marker='x', c=start_color)\n", + " if write_text: ax.text(x, y, 'Starting position.')\n", + "\n", + " if draw_target_mu:\n", + " # Draw the target position.\n", + " x, y = [p * scaling for p in dataclasses.astuple(world._end_mu)]\n", + " ax.scatter([x], [y], marker='x', c='k')\n", + " if write_text: ax.text(x, y,'Target position.')\n", + "\n", + " if draw_agent:\n", + " # Draw the position of the agent as a circle.\n", + " x, y = [scaling * p for p in world.agent_position]\n", + " agent_circle = plt.Circle((x, y), agent_size, color=agent_color)\n", + " ax.add_artist(agent_circle)\n", + " if write_text: ax.text(x, y, 'Agent position.')\n", + "\n", + " return ax\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "eSfUwRtoY_aN" + }, + "source": [ + "# Affordance specification" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "IVTYi4YVZAfR" + }, + "outputs": [], + "source": [ + "#@title Intent detection and plotting code.\n", + "\n", + "IntentName = enum.IntEnum(\n", + " 'IntentName', 'delta_pos_x delta_neg_x delta_pos_y delta_neg_y')\n", + "\n", + "class IntentStatus(enum.IntEnum):\n", + " complete = 1\n", + " incomplete = 0\n", + "\n", + "@dataclasses.dataclass(eq=False)\n", + "class Intent:\n", + " name: 'IntentName'\n", + " status: 'IntentStatus'\n", + "\n", + "\n", + "PointOrFloatTuple = Union[Point, Tuple[float, float]]\n", + "\n", + "def _get_intent_completed(\n", + " s_t: PointOrFloatTuple,\n", + " a_t: Force,\n", + " s_tp1: PointOrFloatTuple,\n", + " intent_name: IntentName,\n", + " threshold: float = 0.0):\n", + " r\"\"\"Determines if the intent was completed in the transition.\n", + "\n", + " The available intents are based on significant movement on the x-y plane:\n", + "\n", + " Intent is 1 if:\n", + " `s_tp1.{{x,y}} - s_t.{{x,y}} {{\u003e,\u003c}} threshold`\n", + " else: 0.\n", + "\n", + " Args:\n", + " s_t: The current position of the agent.\n", + " a_t: The force for the action.\n", + " s_tp1: The position after executing action of the agent.\n", + " intent_name: The intent that needs to be detected.\n", + " threshold: The significance threshold for the intent to be detected.\n", + " \"\"\"\n", + " if not isinstance(s_t, Point):\n", + " s_t = Point(*s_t)\n", + " if not isinstance(s_tp1, Point):\n", + " s_tp1 = Point(*s_tp1)\n", + " IntentName(intent_name) # Check if valid intent_name.\n", + "\n", + " diff = s_tp1 - s_t # Find the positional difference.\n", + "\n", + " if intent_name == IntentName.delta_pos_x:\n", + " if diff.x \u003e threshold:\n", + " return IntentStatus.complete\n", + " if intent_name == IntentName.delta_pos_y:\n", + " if diff.y \u003e threshold:\n", + " return IntentStatus.complete\n", + " if intent_name == IntentName.delta_neg_x:\n", + " if diff.x \u003c -threshold:\n", + " return IntentStatus.complete\n", + " if intent_name == IntentName.delta_neg_y:\n", + " if diff.y \u003c -threshold:\n", + " return IntentStatus.complete\n", + "\n", + " return IntentStatus.incomplete\n", + "\n", + "# Some simple test cases.\n", + "assert not _get_intent_completed(\n", + " Point(0, 0), None, Point(0.5, 0.0), IntentName.delta_neg_y)\n", + "assert not _get_intent_completed(\n", + " Point(0, 0), None, Point(0.5, 0.0), IntentName.delta_pos_y)\n", + "assert _get_intent_completed(\n", + " Point(0, 0), None, Point(0.5, 0.0), IntentName.delta_pos_x)\n", + "assert not _get_intent_completed(\n", + " Point(0, 0), None, Point(0.5, 0.0), IntentName.delta_neg_x)\n", + "assert _get_intent_completed(\n", + " Point(0, 0), None, Point(0.5, 0.5), IntentName.delta_pos_x)\n", + "assert _get_intent_completed(\n", + " Point(0, 0), None, Point(0.5, 0.5), IntentName.delta_pos_y)\n", + "assert not _get_intent_completed(\n", + " Point(0, 0), None, Point(0.5, 0.5), IntentName.delta_pos_y, 0.6)\n", + "assert not _get_intent_completed(\n", + " Point(0, 0), None, Point(-0.5, -0.5), IntentName.delta_neg_x, 0.6)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "OtWqqse4Zutx" + }, + "outputs": [], + "source": [ + "#@title Data Collection code\n", + "def get_transitions(\n", + " world: ContinuousWorld,\n", + " max_num_transitions: int = 500,\n", + " max_trajectory_length: Optional[int] = None,\n", + " policy: Optional[Callable[[np.ndarray], int]] = None,\n", + " intent_threshold: float = 0.0):\n", + " \"\"\"Samples transitions from an environment.\n", + "\n", + " Args:\n", + " world: The environment to collect trajectories from.\n", + " max_num_transitions: The total number of transitions to sample.\n", + " max_trajectory_length: The maximum length of the trajectory. If None\n", + " trajectories will naturally reset during episode end.\n", + " policy: The data collection policy. If None is given a random policy\n", + " is used. The policy must take a single argument, the one hot\n", + " representation of the state. If using a tensorflow function make sure to\n", + " handle batching within the policy itself.\n", + " intent_threshold: The threshold to use for the intent.\n", + "\n", + " Returns:\n", + " The transitions collected from the environment:\n", + " This is a 4-tuple containing the batch of state, action, state' and intent\n", + " target.\n", + " Human Readable transitions:\n", + " A set containing the unique transitions in the batch and if the intent was\n", + " completed.\n", + " Infos:\n", + " A list containing the info dicts sampled during the batch.\n", + " \"\"\"\n", + " max_trajectory_length = max_trajectory_length or float('inf')\n", + " trajectory = []\n", + " s_t = world.reset()\n", + " trajectory_length = 0\n", + " human_readable = set()\n", + " if policy is None:\n", + " def policy(_):\n", + " return world.get_random_force()\n", + "\n", + " infos = []\n", + "\n", + " for _ in range(max_num_transitions):\n", + " action = policy(s_t)\n", + " s_tp1, _, done, info = world.step(action)\n", + " infos.append(info)\n", + " reward = 0\n", + "\n", + " all_intents = []\n", + " intent_status_only = []\n", + " for intent_name in IntentName:\n", + " intent_status = _get_intent_completed(\n", + " s_t, action, s_tp1, intent_name, intent_threshold)\n", + " all_intents.append((intent_name, intent_status))\n", + " intent_status_only.append(intent_status)\n", + "\n", + " # Human readable vesion:\n", + " human_readable.add((s_t, action, s_tp1, tuple(all_intents)))\n", + "\n", + " # Prepare things for tensorflow:\n", + " s_t_tf = tf.constant(dataclasses.astuple(s_t), dtype=tf.float32)\n", + " s_tp1_tf = tf.constant(dataclasses.astuple(s_tp1), dtype=tf.float32)\n", + " a_t_tf = tf.constant(dataclasses.astuple(action), dtype=tf.float32)\n", + " intent_statuses_tf = tf.constant(intent_status_only)\n", + " trajectory.append((s_t_tf, a_t_tf, s_tp1_tf, reward, intent_statuses_tf))\n", + "\n", + " trajectory_length += 1\n", + " if done or trajectory_length \u003e max_trajectory_length:\n", + " s_t = world.reset()\n", + " trajectory_length = 0\n", + " else:\n", + " s_t = s_tp1\n", + "\n", + " batch = list(map(tf.stack, zip(*trajectory)))\n", + " return batch, human_readable, infos\n", + "\n", + "# Integration test.\n", + "world = ContinuousWorld(\n", + " size=2,\n", + " drift_speed=0.1,\n", + " max_action_force=2.0,\n", + " max_episode_length=100)\n", + "data, _, _ = get_transitions(world, max_num_transitions=2)\n", + "assert data is not None" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "bvYwGLyubcpa" + }, + "outputs": [], + "source": [ + "#@title Probabilistic transition model\n", + "\n", + "hidden_nodes = 32\n", + "input_size = 2\n", + "\n", + "class TransitionModel(tf.keras.Model):\n", + " def __init__(self, hidden_nodes, output_size):\n", + " super().__init__()\n", + " self._net1 = tf.keras.layers.Dense(\n", + " hidden_nodes, activation=tf.keras.activations.relu)\n", + " self._net2 = tf.keras.layers.Dense(\n", + " hidden_nodes, activation=tf.keras.activations.relu)\n", + " # Multiply by 2 for means and variances.\n", + " self._output = tf.keras.layers.Dense(2*output_size)\n", + "\n", + " def __call__(self, st, at):\n", + " net_inputs = tf.concat((st, at), axis=1)\n", + " means_logstd = self._output(self._net2(self._net1(net_inputs)))\n", + " means, logstd = tf.split(means_logstd, 2, axis=1)\n", + " std = tf.exp(logstd)\n", + " return tfp.distributions.Normal(loc=means, scale=std)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "5S6uqTVAdoKD" + }, + "outputs": [], + "source": [ + "#@title Training algorithm.\n", + "\n", + "MACHINE_EPSILON = np.finfo(float).eps.item()\n", + "\n", + "def train_networks(\n", + " world: ContinuousWorld,\n", + " model_network: Optional[tf.keras.Model] = None,\n", + " model_optimizer: Optional[tf.keras.optimizers.Optimizer] = None,\n", + " affordance_network: Optional[tf.keras.Model] = None,\n", + " affordance_optimizer: Optional[tf.keras.optimizers.Optimizer] = None,\n", + " use_affordance_to_mask_model: bool = False,\n", + " affordance_mask_threshold: float = 0.9,\n", + " num_train_steps: int =10,\n", + " fresh_data: bool = True,\n", + " max_num_transitions: int = 1,\n", + " max_trajectory_length: Optional[int] = None,\n", + " optimize_performance: bool = False,\n", + " intent_threshold: float = 1.0,\n", + " debug: bool = False,\n", + " print_losses: bool = False,\n", + " print_every: int = 10):\n", + " \"\"\"Trains an affordance network.\n", + "\n", + " Args:\n", + " world: The gridworld to collect training data from.\n", + " model_network: The network for the transition model.\n", + " model_optimizer: The optimizer for the transition model.\n", + " affordance_network: The affordance network.\n", + " affordance_optimizer: The optimizer for the affordance network.\n", + " use_affordance_to_mask_model: Uses affordances to mask the losses of the\n", + " transition model.\n", + " affordance_mask_threshold: The threshold at which the mask should be\n", + " applied.\n", + " num_train_steps: The total number of training steps.\n", + " fresh_data: Collect fresh data before every training step.\n", + " max_num_transitions: The number of rollout trajectories per training step.\n", + " max_trajectory_length: The maximum length of each trajectory. If None then\n", + " there is no artifically truncated trajectory length.\n", + " optimizer_performance: Use `tf.function` to speed up training steps.\n", + " intent_threshold: The threshold to consider as a signficant completion of\n", + " the intent.\n", + " debug: Debug mode prints out the human readable transitions and disables\n", + " tf.function.\n", + " print_losses: Prints out the losses during training.\n", + " print_every: Indicates how often things should be printed out.\n", + " \"\"\"\n", + " all_aff_losses = []\n", + " all_model_losses = []\n", + "\n", + " # Error checking to make sure the correct combinations of model/affordance\n", + " # nets and optimizers are given or none at all.\n", + " if (affordance_network is None) != (affordance_optimizer is None):\n", + " raise ValueError('Both affordance network and optimizer have to be given.')\n", + " else:\n", + " use_affordances = affordance_network is not None\n", + "\n", + " if (model_network is None) != (model_optimizer is None):\n", + " raise ValueError('Both model network and optimizer have to be given.')\n", + " else:\n", + " use_model = model_network is not None\n", + "\n", + " # At least one of affordance network or model network must be specified.\n", + " if model_network is None and (\n", + " (model_network is None) == (affordance_network is None)):\n", + " raise ValueError(\n", + " 'This code does not do anything without models or affordances.')\n", + "\n", + " # Check if both are specified if use_affordance_to_mask_model is True.\n", + " if use_affordance_to_mask_model and (\n", + " model_network is None and affordance_network is None):\n", + " raise ValueError(\n", + " 'Cannot use_affordance_to_mask model if affordance and model networks'\n", + " ' are not given!')\n", + "\n", + " # User friendly print outs indicate what is happening.\n", + " print(\n", + " f'Using model? {use_model}. Using affordances? {use_affordances}. Using'\n", + " f' affordances to mask model? {use_affordance_to_mask_model}.')\n", + "\n", + " def _train_step_affordances(trajectory):\n", + " \"\"\"Train affordance network.\"\"\"\n", + " # Note: Please make sure you understand the shapes here before editing to\n", + " # prevent accidental broadcast.\n", + " with tf.GradientTape() as tape:\n", + " s_t, a_t, _, _, intent_target = trajectory\n", + " concat_input = tf.concat((s_t, a_t), axis=1)\n", + " preds = affordance_network(concat_input)\n", + "\n", + " intent_target = tf.reshape(intent_target, (-1, 1))\n", + " unshaped_preds = preds\n", + " preds = tf.reshape(preds, (-1, 1))\n", + "\n", + " loss = tf.keras.losses.binary_crossentropy(intent_target, preds)\n", + " total_loss = tf.reduce_mean(loss)\n", + " grads = tape.gradient(total_loss, affordance_network.trainable_variables)\n", + " affordance_optimizer.apply_gradients(\n", + " zip(grads, affordance_network.trainable_variables))\n", + "\n", + " return total_loss, unshaped_preds\n", + "\n", + " def _train_step_model(trajectory, affordances):\n", + " \"\"\"Train model network.\"\"\"\n", + " with tf.GradientTape() as tape:\n", + " s_t, a_t, s_tp1, _, _ = trajectory\n", + " transition_model = model_network(s_t, a_t)\n", + " log_prob = tf.reduce_sum(transition_model.log_prob(s_tp1), -1)\n", + " num_examples = s_t.shape[0]\n", + "\n", + " if use_affordance_to_mask_model:\n", + " # Check if at least one intent is affordable.\n", + " masks_per_intent = tf.math.greater_equal(\n", + " affordances, affordance_mask_threshold)\n", + " masks_per_transition = tf.reduce_any(masks_per_intent, 1)\n", + " # Explicit reshape to prevent accidental broadcasting.\n", + " batch_size = len(s_t)\n", + " log_prob = tf.reshape(log_prob, (batch_size, 1))\n", + " masks_per_transition = tf.reshape(masks_per_transition, (batch_size, 1))\n", + " log_prob = log_prob * tf.cast(masks_per_transition, dtype=tf.float32)\n", + " # num_examples changes if there is masking so take that into account:\n", + " num_examples = tf.reduce_sum(\n", + " tf.cast(masks_per_transition, dtype=tf.float32))\n", + " num_examples = tf.math.maximum(num_examples, tf.constant(1.0))\n", + "\n", + " # Negate log_prob here because we want to maximize this via minimization.\n", + " total_loss = -tf.reduce_sum(log_prob) / num_examples\n", + " grads = tape.gradient(total_loss, model_network.trainable_variables)\n", + " model_optimizer.apply_gradients(\n", + " zip(grads, model_network.trainable_variables))\n", + "\n", + " return total_loss\n", + "\n", + " # Optimize performance using tf.function.\n", + " if optimize_performance and not debug:\n", + " _train_step_affordances = tf.function(_train_step_affordances)\n", + " _train_step_model = tf.function(_train_step_model)\n", + " print('Training step has been optimized.')\n", + "\n", + " initial_data_collected = False\n", + " infos = []\n", + " for i in range(num_train_steps):\n", + " # Step 1: Collect data.\n", + " if not initial_data_collected or fresh_data:\n", + " initial_data_collected = True\n", + " running_time = time.time()\n", + " trajectories, unique_transitions, infos_i = get_transitions(\n", + " world,\n", + " max_num_transitions=max_num_transitions,\n", + " max_trajectory_length=max_trajectory_length,\n", + " intent_threshold=intent_threshold)\n", + " collection_running_time = time.time() - running_time\n", + " if debug: print('unique_transitions:', unique_transitions)\n", + " running_time = time.time()\n", + "\n", + " # Check if the start state was updated:\n", + " infos.append(\n", + " any([info['recent_start_position_updated'] for info in infos_i]))\n", + "\n", + " # Step 2: Train affordance model.\n", + " if use_affordances:\n", + " aff_loss, affordance_predictions = _train_step_affordances(trajectories)\n", + " aff_loss = aff_loss.numpy().item()\n", + " else:\n", + " affordance_predictions = tf.constant(0.0) # Basically a none.\n", + " aff_loss = None\n", + " all_aff_losses.append(aff_loss)\n", + "\n", + " # Step 3: Train transition model and mask predictions if necessary.\n", + " if use_model:\n", + " model_loss = _train_step_model(trajectories, affordance_predictions)\n", + " model_loss = model_loss.numpy().item()\n", + " else:\n", + " model_loss = None\n", + " all_model_losses.append(model_loss)\n", + "\n", + " if debug or print_losses:\n", + " if i % print_every == 0:\n", + " train_loop_time = time.time() - running_time\n", + " print(f'i: {i}, aff_loss: {aff_loss}, model_loss: {model_loss}, '\n", + " f'collection_loop_time: {collection_running_time:.2f}, '\n", + " f'train_loop_time: {train_loop_time:.2f}')\n", + "\n", + " return all_model_losses, all_aff_losses, infos" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "6hXi7Iy0b50_" + }, + "source": [ + "# Plotting utilities" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "gMNIsCLEb3Ie" + }, + "outputs": [], + "source": [ + "#@title Learning curve smoothing\n", + "# From https://github.com/google-research/policy-learning-landscape/blob/master/analysis_tools/data_processing.py#L82\n", + "\n", + "DEFAULT_SMOOTHING_WEIGHT = 0.9\n", + "def apply_linear_smoothing(data, smoothing_weight=DEFAULT_SMOOTHING_WEIGHT):\n", + " \"\"\"Smooth curves using a exponential linear weight.\n", + "\n", + " This smoothing algorithm is the same as the one used in tensorboard.\n", + "\n", + " Args:\n", + " data: The sequence or list containing the data to smooth.\n", + " smoothing_weight: A float representing the weight to place on the moving\n", + " average.\n", + "\n", + " Returns:\n", + " A list containing the smoothed data.\n", + " \"\"\"\n", + " if len(data) == 0: # pylint: disable=g-explicit-length-test\n", + " raise ValueError('No data to smooth.') \n", + " if smoothing_weight \u003c= 0:\n", + " return data\n", + " last = data[0]\n", + " smooth_data = []\n", + " for x in data:\n", + " if not np.isfinite(last):\n", + " smooth_data.append(x)\n", + " else:\n", + " smooth_data.append(last * smoothing_weight + (1 - smoothing_weight) * x)\n", + " last = smooth_data[-1]\n", + " return smooth_data\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "T5cvk1XUb-M0" + }, + "outputs": [], + "source": [ + "#@title Intent plotting code.\n", + "\n", + "def plot_intents(\n", + " world: ContinuousWorld,\n", + " affordance_predictions: np.ndarray,\n", + " eval_action: Tuple[float, float],\n", + " num_world_ticks: int = 3,\n", + " intent_collection: IntentName = IntentName,\n", + " subplot_configuration: Tuple[int, int] = (2, 2),\n", + " figsize: Tuple[int, int] = (5, 5)):\n", + " \"\"\"Plots the intents as a heatmap.\n", + "\n", + " Given the predictions from the affordance network, we plot a heatmap for each\n", + " intent indicating how likely the `eval_action` can be used to complete it.\n", + "\n", + " Args:\n", + " world: The gridworld to use.\n", + " affordance_predictions: Predictions from the affordance classifier. The last\n", + " dimension should be of the same len as intent_collection.\n", + " eval_action: The eval action being used (For plotting the title).\n", + " num_world_ticks: The number of ticks on the axes of the world.\n", + " subplot_configuration: The arrangement of the subplots on the plot.\n", + " figsize: The size of the matplotlib figure.\n", + " \"\"\"\n", + " fig = plt.figure(figsize=figsize)\n", + "\n", + " # Since we are predicting probabilities, normalize between 0 and 1.\n", + " norm = mpl.colors.Normalize(vmin=0.0, vmax=1.0)\n", + "\n", + " # The colorbar axes.\n", + " cax = fig.add_axes([1.0, 0.1, 0.075, 0.8])\n", + "\n", + " for intent in intent_collection:\n", + " ax = fig.add_subplot(*subplot_configuration, intent)\n", + " afford_sliced = affordance_predictions[:, :, intent-1]\n", + " afford_sliced = np.transpose(afford_sliced)\n", + " ax_ = ax.imshow(afford_sliced, origin='lower')\n", + "\n", + " # This code will handle num_world_ticks=0 gracefully.\n", + " ax.set_xticks(np.linspace(0, afford_sliced.shape[0], num_world_ticks))\n", + " ax.set_yticks(np.linspace(0, afford_sliced.shape[0], num_world_ticks))\n", + " ax.set_xticklabels(\n", + " np.linspace(0, world.size, num_world_ticks), fontsize='x-small')\n", + " ax.set_yticklabels(\n", + " np.linspace(0, world.size, num_world_ticks), fontsize='x-small')\n", + "\n", + " ax.set_xlabel('x')\n", + " ax.set_ylabel('y', rotation=0)\n", + " plt.title('Intent: {}'.format(intent.__repr__()[-10:-2]))\n", + " ax_.set_norm(norm)\n", + " if intent == len(intent_collection):\n", + " plt.colorbar(ax_, cax)\n", + " cax.set_ylabel('Probability of intent completion')\n", + "\n", + " plt.suptitle('Evaluating Action: {}'.format(eval_action))\n", + " plt.tight_layout(rect=[0, 0.03, 1, 0.95])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "pHarhdiSdixz" + }, + "source": [ + "# Main Experiment (Training)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Zq9TXAyDfaP_" + }, + "outputs": [], + "source": [ + "# Storing the losses and models in a global list.\n", + "all_losses_global = []\n", + "all_models_global = []\n", + "all_affordance_global = []" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "colab_type": "code", + "id": "OP50CV_YdkLi", + "outputId": "58a4677b-5dcf-4015-d795-2e7a4b9384fc" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Experiments that will be run: [False, True]\n", + "Resetting seed to 0.\n", + "Target mu has been updated to: Point(x=0.5, y=0.5)\n", + "Using model? True. Using affordances? False. Using affordances to mask model? False.\n", + "Training step has been optimized.\n", + "Target mu has been updated to: Point(x=1.5, y=1.5)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:107: RuntimeWarning: invalid value encountered in double_scalars\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "i: 0, aff_loss: None, model_loss: 3.079850912094116, collection_loop_time: 0.20, train_loop_time: 0.28\n", + "i: 1000, aff_loss: None, model_loss: -1.3993345499038696, collection_loop_time: 0.17, train_loop_time: 0.00\n", + "Target mu has been updated to: Point(x=0.5, y=0.5)\n", + "i: 2000, aff_loss: None, model_loss: -1.4746335744857788, collection_loop_time: 0.17, train_loop_time: 0.00\n", + "Target mu has been updated to: Point(x=1.5, y=1.5)\n", + "i: 3000, aff_loss: None, model_loss: -1.926070213317871, collection_loop_time: 0.29, train_loop_time: 0.00\n", + "i: 4000, aff_loss: None, model_loss: -2.129803419113159, collection_loop_time: 0.17, train_loop_time: 0.00\n", + "Target mu has been updated to: Point(x=0.5, y=0.5)\n", + "i: 5000, aff_loss: None, model_loss: -2.212693452835083, collection_loop_time: 0.18, train_loop_time: 0.00\n", + "Target mu has been updated to: Point(x=1.5, y=1.5)\n", + "i: 6000, aff_loss: None, model_loss: -2.104724168777466, collection_loop_time: 0.18, train_loop_time: 0.00\n", + "i: 7000, aff_loss: None, model_loss: -2.153459310531616, collection_loop_time: 0.19, train_loop_time: 0.00\n", + "Target mu has been updated to: Point(x=0.5, y=0.5)\n", + "Resetting seed to 0.\n", + "Target mu has been updated to: Point(x=0.5, y=0.5)\n", + "Using model? True. Using affordances? True. Using affordances to mask model? True.\n", + "Training step has been optimized.\n", + "Target mu has been updated to: Point(x=1.5, y=1.5)\n", + "i: 0, aff_loss: 0.7048008441925049, model_loss: 3.4398090839385986, collection_loop_time: 0.17, train_loop_time: 0.66\n", + "i: 1000, aff_loss: 0.19042730331420898, model_loss: -1.7565302848815918, collection_loop_time: 0.19, train_loop_time: 0.00\n", + "Target mu has been updated to: Point(x=0.5, y=0.5)\n", + "i: 2000, aff_loss: 0.23882852494716644, model_loss: -1.7699742317199707, collection_loop_time: 0.18, train_loop_time: 0.00\n", + "Target mu has been updated to: Point(x=1.5, y=1.5)\n", + "i: 3000, aff_loss: 0.207383394241333, model_loss: -1.983075737953186, collection_loop_time: 0.17, train_loop_time: 0.00\n", + "i: 4000, aff_loss: 0.1946447789669037, model_loss: -2.051856517791748, collection_loop_time: 0.19, train_loop_time: 0.00\n", + "Target mu has been updated to: Point(x=0.5, y=0.5)\n", + "i: 5000, aff_loss: 0.21516098082065582, model_loss: -1.9648913145065308, collection_loop_time: 0.17, train_loop_time: 0.00\n", + "Target mu has been updated to: Point(x=1.5, y=1.5)\n", + "i: 6000, aff_loss: 0.20555077493190765, model_loss: -1.8827911615371704, collection_loop_time: 0.18, train_loop_time: 0.00\n", + "i: 7000, aff_loss: 0.17789226770401, model_loss: -2.0646157264709473, collection_loop_time: 0.28, train_loop_time: 0.00\n", + "Target mu has been updated to: Point(x=0.5, y=0.5)\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUIAAAEvCAYAAAAwx8gYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAdy0lEQVR4nO3dfZQV1Z3u8e9jiwq2y6DgKyAaMQYVWmwxoxBajYo6kajJXIhmNKPDjRMzzk3GpUYvjuSuxJusuXFlNDGMYZxkjOJ7mBscg1GuooPSYCsvxojERNAJCAZFUAR/94+q1kPTL6fp6j6nez+ftc7qOrte9q4+zUNV7VO1FRGYmaVsl0o3wMys0hyEZpY8B6GZJc9BaGbJcxCaWfIchGaWvF0r3YDWDBo0KIYPH17pZphZH7No0aI3ImJwy/KqDMLhw4fT2NhY6WaYWR8j6fetlfvU2MyS5yA0s+Q5CM0seQ5CM0ueg9DMkucgNLPkOQjNLHkdBqGkoZIek7Rc0jJJV7SyjCT9QNIKSc9LGlMy7yJJL+Wvi4reATOzrirnC9VbgW9ExGJJewGLJM2NiOUly5wJjMhfJwA/Ak6QtA9wPVAPRL7u7Ih4s9C9MDPrgg6DMCJeB17Pp9+W9AJwMFAahJOAn0b2uOsFkj4m6UCgAZgbEesBJM0FJgJ3FrkTDQ0NRW7OeokVK1YAcPjhh1e4JVYJ8+bNK2xbnbpGKGk4cCzwdItZBwOvlrxflZe1Vd7atqdKapTUuHbt2rLb1NDQQFNTU9nLW9+xceNGNm7cWOlmWAU0NTUVegBU9r3GkmqB+4C/i4i3CmtBLiJmADMA6uvrOzWQSl1dXaH/O1jv0PwPwZ99eoo+CyzriFBSP7IQvCMi7m9lkdXA0JL3Q/KytsrNzKpGOb3GAn4CvBAR/6eNxWYDf5n3Hn8K2JBfW3wYOF3SQEkDgdPzMjOzqlHOqfFJwJeAJZKaL8Z9ExgGEBG3AnOAs4AVwCbgy/m89ZK+BSzM15ve3HFiZlYtyuk1ng+og2UC+Gob82YCM3eqdWZmPcB3lphZ8hyEZpY8B6GZJc9BaGbJcxCaWfIchGaWPAehmSXPQWhmyXMQmlnyHIRmljwHoZklz0FoZslzEJpZ8hyEZpY8B6GZJc9BaGbJcxCaWfIchGaWPAehmSXPQWhmyXMQmlnyOhzFTtJM4M+BNRFxdCvzrwQuKNneJ4HB+VCerwBvA9uArRFRX1TDzcyKUs4R4e3AxLZmRsT3IqIuIuqAa4D/12Ls4pPz+Q5BM6tKHQZhRDwOlDso+xTgzi61yMyshxV2jVDSALIjx/tKigP4laRFkqYWVZeZWZE6vEbYCZ8FnmxxWjwuIlZL2g+YK+k3+RHmDvKgnAowbNiwAptlZta+InuNJ9PitDgiVuc/1wAPAGPbWjkiZkREfUTUDx48uMBmmZm1r5AglLQ3MAH4RUnZnpL2ap4GTgeWFlGfmVmRyvn6zJ1AAzBI0irgeqAfQETcmi92LvCriHinZNX9gQckNdfz84j4j+KabmZWjA6DMCKmlLHM7WRfsyktWwmM3tmGmZn1FN9ZYmbJcxCaWfIchGaWPAehmSXPQWhmyXMQmlnyHIRmljwHoZklz0FoZslzEJpZ8hyEZpY8B6GZJc9BaGbJcxCaWfIchGaWPAehmSXPQWhmyXMQmlnyHIRmljwHoZklz0FoZslzEJpZ8joMQkkzJa2R1Org7JIaJG2Q1JS/ppXMmyjpRUkrJF1dZMPNzIpSzhHh7cDEDpZ5IiLq8td0AEk1wC3AmcBIYIqkkV1prJlZd+gwCCPicWD9Tmx7LLAiIlZGxBbgLmDSTmzHzKxbFXWN8M8kPSfpIUlH5WUHA6+WLLMqLzMzqyq7FrCNxcAhEbFR0lnAg8CIzm5E0lRgKsCwYcMKaJaZWXm6fEQYEW9FxMZ8eg7QT9IgYDUwtGTRIXlZW9uZERH1EVE/ePDgrjbLzKxsXQ5CSQdIUj49Nt/mOmAhMELSoZJ2AyYDs7tan5lZ0To8NZZ0J9AADJK0Crge6AcQEbcCnwcuk7QV2AxMjogAtkq6HHgYqAFmRsSybtkLM7Mu6DAII2JKB/NvBm5uY94cYM7ONc3MrGf4zhIzS56D0MyS5yA0s+Q5CM0seQ5CM0ueg9DMkucgNLPkOQjNLHkOQjNLnoPQzJLnIDSz5DkIzSx5DkIzS56D0MyS5yA0s+Q5CM0seQ5CM0ueg9DMkucgNLPkOQjNLHkOQjNLnoPQzJLXYRBKmilpjaSlbcy/QNLzkpZIekrS6JJ5r+TlTZIai2y4mVlRyjkivB2Y2M783wETIuIY4FvAjBbzT46Iuoio37kmmpl1r3IGeH9c0vB25j9V8nYBMKTrzTIz6zlFXyO8BHio5H0Av5K0SNLUgusyMytEh0eE5ZJ0MlkQjispHhcRqyXtB8yV9JuIeLyN9acCUwGGDRtWVLPMzDpUyBGhpFHAbcCkiFjXXB4Rq/Ofa4AHgLFtbSMiZkREfUTUDx48uIhmmZmVpctBKGkYcD/wpYj4bUn5npL2ap4GTgda7Xk2M6ukDk+NJd0JNACDJK0Crgf6AUTErcA0YF/gh5IAtuY9xPsDD+RluwI/j4j/6IZ9MDPrknJ6jad0MP9S4NJWylcCo3dcw8ysuvjOEjNLnoPQzJLnIDSz5DkIzSx5DkIzS56D0MyS5yA0s+Q5CM0seQ5CM0ueg9DMkucgNLPkOQjNLHkOQjNLnoPQzJLnIDSz5DkIzSx5DkIzS56D0MyS5yA0s+Q5CM0seQ5CM0ueg9DMkldWEEqaKWmNpFYHaFfmB5JWSHpe0piSeRdJeil/XVRUw83MilLuEeHtwMR25p8JjMhfU4EfAUjah2xA+BOAscD1kgbubGPNzLpDWUEYEY8D69tZZBLw08gsAD4m6UDgDGBuRKyPiDeBubQfqGZmPa6oa4QHA6+WvF+Vl7VVbmZWNaqms0TSVEmNkhrXrl1b6eaYWUKKCsLVwNCS90PysrbKdxARMyKiPiLqBw8eXFCzzMw6VlQQzgb+Mu89/hSwISJeBx4GTpc0MO8kOT0vMzOrGruWs5CkO4EGYJCkVWQ9wf0AIuJWYA5wFrAC2AR8OZ+3XtK3gIX5pqZHRHudLmZmPa6sIIyIKR3MD+CrbcybCczsfNPMzHpG1XSWmJlVioPQzJLnIDSz5DkIzSx5DkIzS56D0MyS5yA0s+Q5CM0seQ5CM0ueg9DMkucgNLPkOQjNLHkOQjNLnoPQzJLnIDSz5DkIzSx5DkIzS56D0MyS5yA0s+Q5CM0seQ5CM0ueg9DMkldWEEqaKOlFSSskXd3K/O9Laspfv5X0p5J520rmzS6y8WZmRehwXGNJNcAtwGnAKmChpNkRsbx5mYj4HyXLfw04tmQTmyOirrgmm5kVq5wjwrHAiohYGRFbgLuASe0sPwW4s4jGmZn1hHKC8GDg1ZL3q/KyHUg6BDgUeLSkeA9JjZIWSPrcTrfUzKybdHhq3EmTgXsjYltJ2SERsVrSYcCjkpZExMstV5Q0FZgKMGzYsIKbZR3avBmWLIGlS+Htt+HddyEC9tgDamth5EgYPRr23LPSLTUrXDlBuBoYWvJ+SF7WmsnAV0sLImJ1/nOlpHlk1w93CMKImAHMAKivr48y2mVd8eabMGsWPPIILFwIr70GAwbAtm2wdWv2AqipgV13zV6bNsEBB8CYMXDqqTB5Muy3X2X3w6wA5QThQmCEpEPJAnAy8MWWC0k6EhgI/GdJ2UBgU0S8J2kQcBLw3SIabjshAhYsgJtugtmzYZddsnBr9tZbO66zbRts2fLR+1WrstfcuXDVVfCZz8A3vgETJoDU/ftg1g06vEYYEVuBy4GHgReAuyNimaTpks4pWXQycFdElB7NfRJolPQc8BhwY2lvs/WQCLjvPvj4x+G00+Dee7NT39IQ7KzNm7Nt/PKX8NnPwtCh8C//ktVl1suUdY0wIuYAc1qUTWvx/h9aWe8p4JgutM+66vXX4eKLYf78rgVfWyJg48bs9bWvwW23wR13wPDhxddl1k18Z0lfFQE/+QkccQQ89lj3hGBL77wDTz8NRx2VnX5/8EH312lWAAdhX/Tuu3DWWXDFFdmR2vvv91zd27ZloXvddTBuXNYDbVblHIR9zYYNWQDNm5cdoVXKO+/A4sVw/PGwZk3l2mFWBgdhX/L221kILl2aHRVW2nvvwcqV8KlPwRtvVLo1Zm1yEPYV770Hp5wCL72UTVeL99+H1at9mmxVzUHYV1x9NSxbVl0h2GzLFnjlFbjsskq3xKxVDsK+4Mkn4cc/zr7bV63eew8eeCD73qFZlXEQ9nYbN8IXvlDdIdhs0ya48EJYt67SLTHbjoOwt7viiuy+4d5i06bsC95mVcRB2Jv94Q/w859XRw9xubZsgV//OnvSjVmVcBD2Zj/4Qe+8e2PLFvje9yrdCrMPOQh7q3ffzTpISp8M01ts2wb33NO7TumtT3MQ9lZ33VXpFnTNLrtkD2gwqwIOwt7qppuyHuMe8CAg4DdFbnTTJvinfwKgqamJOXPmdLBC11x66aUsX549Ae7b3/72dvNOPPHEbq3bqp+DsDd6/31Y3nOPdbwTGEc3jMj1X/8FGzb0SBDedtttjBw5EtgxCJ966qlurduqn4OwN1q2LBtLpAdsBOYDPyEbvrDZB8DfAEeSjfN6FnBvPm8RMAE4DjgDeD0vbwCuIhsW8Qjgid12Y8uCBUybNo1Zs2ZRV1fHrFmztqv/9ttvZ9KkSTQ0NDBixAhuuOGGD+e9+uqrLFy4kKOPPpqbbroJgHfeeYezzz6b0aNHc/TRR3+4vYaGBhobG7n66qvZvHkzdXV1XHDBBQDU1tYCEBFceeWVHH300RxzzDEfrjtv3jwaGhr4/Oc/z5FHHskFF1xA+AG0fUrRgzdZT2hszDocesAvgIlkwbUvWcgdB9wPvAIsB9aQPYr8r4D3ga/l6w0GZgHXAjPz7W0FniF7yu8NmzbxSFMT06dPp7GxkZtvvrnVNjzzzDMsXbqUAQMGcPzxx3P22WcjiT/+8Y+MGTOGhx56iBNOOIEJEyawcuVKDjroIH6Z38GyYcOG7bZ14403cvPNN9PU1LRDPffffz9NTU0899xzvPHGGxx//PF8+tOfBuDZZ59l2bJlHHTQQZx00kk8+eSTjBs3bmd+pVaFfETYGz3+eM88aJXsdHhyPj2Zj06P5wNfIPsDOgA4OS9/EVhKdpRYB/wvsvFfm52X/zwOeCUie2hsB0477TT23Xdf+vfvz3nnncf8+fOZP38+++67LzU1NdTW1nLeeefxxBNPcMwxxzB37lyuuuoqnnjiCfbee++y93X+/PlMmTKFmpoa9t9/fyZMmMDChQsBGDt2LEOGDGGXXXahrq6OV155peztWvXzEWFv1ENfRl5PNkD1ErLOkm35z/a+ARjAUZSM4NXC7vnPGrKjQ37TcReMWgwK1fJ9qSOOOILFixczZ84crrvuOk499VSmTZvW5vLl2n333T+crqmpYWvzKH/WJ/iIsDfqoaPBe4EvAb8nOw1+FTgUeIJsOML7yK4V/hGYl6/zCWAtHwXh+8Cy9irZvJm99tqLt9t5RNfcuXNZv349mzdv5sEHH+Skk05i/PjxrFu3jm3btvHOO+/wwAMPMH78eF577TUGDBjAhRdeyJVXXsnixYt32F6/fv14v5Wndo8fP55Zs2axbds21q5dy+OPP87YsWPba731EQ7C3qiHHrV1J3Bui7Lz8/LzyQa4HglcCIwB9gZ2IwvQq4DRZKfH7fbJbtnCySefzPLly1vtLIHstPT8889n1KhRnH/++dTX1zNmzBj2339/Fi9ezAknnMCll17Ksccey5IlSxg7dix1dXXccMMNXHfddTtsb+rUqYwaNerDzpJm5557LqNGjWL06NGccsopfPe73+WAAw5o93c0bdo0Zs+e3e4yVv1Ujb1f9fX10djYWNayDQ0NQNazl4zDD4eXX650K9gI1ALryHqCnyS7Xtgp++7b7tOrb7/99jY7UpL87A3Y+c9e0qKIqG9ZXtYRoaSJkl6UtELS1a3Mv1jSWklN+evSknkXSXopf13UqVZb60quV1XSn5Md8Y0H/ic7EYJQNftiaeuws0RSDXALWUfgKmChpNmtDNQ+KyIub7HuPsD1QD3ZdfRF+bq+ybQrDjigR79Q3ZZ5RWxk0KB2Z1988cVc7Md2WTcr54hwLLAiIlZGxBay79VOKnP7ZwBzI2J9Hn5zyb6WZl0xYQLs2gc6/CUYP77SrTArKwgPJuswbLYqL2vpfEnPS7pX0tBOrmudMXYsDBhQ6VZ0XW0t+D5fqwJF9Rr/OzA8IkaRHfX9a2c3IGmqpEZJjWvXri2oWX1UfX3vehhrWz74IBv32KzCygnC1cDQkvdD8rIPRcS6iGj+TsdtZDcOlLVuyTZmRER9RNQPHjy4nLana9Ag6MQdE1Xrgw/g4x+vdCvMygrChcAISYdK2o3sTqvtvjgl6cCSt+cAL+TTDwOnSxooaSBwel5mXXXuub37OqEEZ5yRPZfQrMI6/CuMiK3A5WQB9gJwd0QskzRd0jn5Yn8raZmk54C/BS7O110PfIssTBcC0/My66qvfx369at0K3begAFw1VWVboUZUOa9xhExh+yBIaVl00qmrwGuaWPdmXz08BEryic+AcceC731WXpDhsAJJ1S6FWaAb7Hr3a65Bvbaq9Kt6LzaWvjmN7PTY7Mq4CDszc48s3d2mvTrB3/xF5VuhdmHHIS9WU1NNohT//6Vbkn5BgyAn/2sx56wbVYOB2Fvd9JJ8JWv9I4vWO+xB5x3Hpx9dqVbYrYdB2Ff8J3vZPcfV7u994Yf/rDSrTDbgYOwL9h9d7jvvuo+KuzfH+6+u3d27lif5yDsK+rq4J57qvN6Yf/+2WDu+UBIZtXGQdiXnHUW/Nu/VVcY9u+fDeT+xS9WuiVmbXIQ9jXnnQe/+EV1nCb3758F8yWXVLolZu1yEPZFp50GTz6ZPdCgEoHYvz8cdBA88kgWzGZVzkHYV9XVwQsvwN//fRZMPXUXR//+cNllsGKFnzVovYaDsC/r1w9uuAEaG+GYY7Jb27pLbW02qNQTT8A//mN1Xac064CDMAUjR8Kzz2ZfsTnttOzrNkXc2bHbbtl2Pv1puOOO7Aj0uOM6Xs+syvTiB9pZp+yyC5x+evZ67TX453+GH/0I/vSnLMw2bYJWBj3fzq67wp57ZuMqDxgAf/3X2WnwIYf0zD6YdRMHYYoOOgiuvz57rVmTHS0uWpSd1i5Z8lEofvDBR0ePI0dmR3719dnjvw48sON6zHoJB2Hq9tsve1L0GWdUuiVmFeNrhGaWPAehmSXPQWhmyXMQmlnyHIRmljwHoZklr6wglDRR0ouSVki6upX5X5e0XNLzkn4t6ZCSedskNeWv2S3XNTOrtA6/RyipBrgFOA1YBSyUNDsilpcs9ixQHxGbJF0GfBf4b/m8zRFRV3C7zcwKU84R4VhgRUSsjIgtwF3ApNIFIuKxiNiUv10ADCm2mWZm3aecIDwYeLXk/aq8rC2XAA+VvN9DUqOkBZI+txNtNDPrVoXeYifpQqAemFBSfEhErJZ0GPCopCUR8XIr604FpgIMGzasyGaZmbWrnCPC1cDQkvdD8rLtSPoMcC1wTkS811weEavznyuBecCxrVUSETMioj4i6gcPHlz2DpiZdVU5QbgQGCHpUEm7AZOB7Xp/JR0L/JgsBNeUlA+UtHs+PQg4CSjtZDEzq7gOT40jYquky4GHgRpgZkQskzQdaIyI2cD3gFrgHmWPhP9DRJwDfBL4saQPyEL3xha9zWZmFVfWNcKImAPMaVE2rWT6M22s9xRwTFcaaGbW3XxniZklz0FoZslzEJpZ8hyEZpY8B6GZJc9BaGbJcxCaWfIchGaWPAehmSXPQWhmyXMQmlnyHIRmljwHoZklz0FoZslzEJpZ8hyEZpY8B6GZJc9BaGbJcxCaWfIchGaWPAehmSXPQWhmySsrCCVNlPSipBWSrm5l/u6SZuXzn5Y0vGTeNXn5i5LOKK7pZmbF6DAIJdUAtwBnAiOBKZJGtljsEuDNiDgc+D7wv/N1RwKTgaOAicAP8+2ZmVWNcgZ4HwusiIiVAJLuAiYBy0uWmQT8Qz59L3CzJOXld0XEe8DvJK3It/efxTQ/09TURENDQ5GbtF6gqakJwJ99gpqamqirqytse+WcGh8MvFryflVe1uoyEbEV2ADsW+a6AEiaKqlRUuPatWvLaz0wb968Qn8h1nvU1tZSW1tb6WZYBdTV1TFv3rzCtlfOEWGPiIgZwAyA+vr66My6Rf5CzCw95RwRrgaGlrwfkpe1uoykXYG9gXVlrmtmVlHlBOFCYISkQyXtRtb5MbvFMrOBi/LpzwOPRkTk5ZPzXuVDgRHAM8U03cysGB2eGkfEVkmXAw8DNcDMiFgmaTrQGBGzgZ8AP8s7Q9aThSX5cneTdaxsBb4aEdu6aV/MzHaKsgO36lJfXx+NjY2VboaZ9TGSFkVEfcty31liZslzEJpZ8hyEZpY8B6GZJc9BaGbJcxCaWfIchGaWvKr8HqGktcDvO7HKIOCNbmqO66/eulOvP+V939n6D4mIwS0LqzIIO0tSY2tfknT9fbvu1OtPed+Lrt+nxmaWPAehmSWvrwThDNefZN2p15/yvhdaf5+4Rmhm1hV95YjQzGynVX0QVnIo0TLq/rqk5ZKel/RrSYeUzNsmqSl/tXyQbVH1XyxpbUk9l5bMu0jSS/nropbrFlT/90vq/q2kP5XM69L+S5opaY2kpW3Ml6Qf5G17XtKYknlF7HtH9V+Q17tE0lOSRpfMeyUvb5LU6efJlVF3g6QNJb/faSXz2v3MCqr/ypK6l+af9T75vC7te76NoZIey/9tLZN0RSvLFPv5R0TVvsgeBPsycBiwG/AcMLLFMn8D3JpPTwZm5dMj8+V3Bw7Nt1NTcN0nAwPy6cua687fb+yBfb8YuLmVdfcBVuY/B+bTA4uuv8XyXyN7aG9R+/9pYAywtI35ZwEPAQI+BTxd1L6XWf+JzdslG+r26ZJ5rwCDunHfG4D/29XPbGfrb7HsZ8meSF/IvufbOBAYk0/vBfy2lb/9Qj//aj8i/HAo0YjYAjQPJVpqEvCv+fS9wKnS9kOJRsTvgOahRAurOyIei4hN+dsFZGOyFKWcfW/LGcDciFgfEW8Cc8nGle7O+qcAd3ayjjZFxONkTztvyyTgp5FZAHxM0oEUs+8d1h8RT+Xbh4I/+zL2vS1d+ZvZ2foL/dzz+l+PiMX59NvAC+w4+mWhn3+1B2GPDCXahbpLXUL2P1SzPZQNT7pA0uc6UW9n6z8/PzW4V1LzQFld3fdObSO/JHAo8GhJcVf3f2fbV8S+d1bLzz6AX0laJGlqN9X5Z5Kek/SQpKPysh7dd0kDyELmvpLiQvdd2aWuY4GnW8wq9POvmuE8ezNJFwL1wISS4kMiYrWkw4BHJS2JiJcLrvrfgTsj4j1J/53syPiUgusox2Tg3th+PJqe2P+Kk3QyWRCOKykel+/7fsBcSb/Jj7KKspjs97tR0lnAg2QDo/W0zwJPRkTp0WNh+y6plixk/y4i3iqgvW2q9iPCSg4lWtb6kj4DXAucExHvNZdHxOr850pgHtn/ap3RYf0Rsa6kztuA4zrT9q7WX2IyLU6PCtj/nW1fjw0hK2kU2e99UkSsay4v2fc1wAN07pJMhyLirYjYmE/PAfpJGkTPD5/b3ufepX2X1I8sBO+IiPtbWaTYz78rFzW7+0V2xLqS7LSr+eLvUS2W+Srbd5bcnU8fxfadJSvpXGdJOXUfS3ZxekSL8oHA7vn0IOAlOnnRusz6DyyZPhdYEB9dMP5d3o6B+fQ+RdefL3ck2QVyFbn/+brDabvD4Gy2v1j+TFH7Xmb9w8iuO5/YonxPYK+S6aeAiQXXfUDz75ssaP6Q/x7K+sy6Wn8+f2+y64h7dsO+C/gpcFM7yxT6+Xf6F9TTL7Leod+SBc61edl0siMwgD2Ae/I/ymeAw0rWvTZf70XgzG6o+xHgj0BT/pqdl58ILMn/EJcAl3TTvn8HWJbX8xhwZMm6f5X/TlYAX+6O+vP3/wDc2GK9Lu8/2ZHG68D7ZNd5LgG+Anyl5B/LLXnblgD1Be97R/XfBrxZ8tk35uWH5fv9XP7ZXNsNdV9e8rkvoCSMW/vMiq4/X+Ziss7I0vW6vO/5dsaRXWt8vuT3e1Z3fv6+s8TMklft1wjNzLqdg9DMkucgNLPkOQjNLHkOQjNLnoPQzJLnIDSz5DkIzSx5/x/wgbu3KUQYlAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "\u003cFigure size 360x360 with 1 Axes\u003e" + ] + }, + "metadata": { + "needs_background": "light", + "tags": [] + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUIAAAEvCAYAAAAwx8gYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAdy0lEQVR4nO3dfZQV1Z3u8e9jiwq2y6DgKyAaMQYVWmwxoxBajYo6kajJXIhmNKPDjRMzzk3GpUYvjuSuxJusuXFlNDGMYZxkjOJ7mBscg1GuooPSYCsvxojERNAJCAZFUAR/94+q1kPTL6fp6j6nez+ftc7qOrte9q4+zUNV7VO1FRGYmaVsl0o3wMys0hyEZpY8B6GZJc9BaGbJcxCaWfIchGaWvF0r3YDWDBo0KIYPH17pZphZH7No0aI3ImJwy/KqDMLhw4fT2NhY6WaYWR8j6fetlfvU2MyS5yA0s+Q5CM0seQ5CM0ueg9DMkucgNLPkOQjNLHkdBqGkoZIek7Rc0jJJV7SyjCT9QNIKSc9LGlMy7yJJL+Wvi4reATOzrirnC9VbgW9ExGJJewGLJM2NiOUly5wJjMhfJwA/Ak6QtA9wPVAPRL7u7Ih4s9C9MDPrgg6DMCJeB17Pp9+W9AJwMFAahJOAn0b2uOsFkj4m6UCgAZgbEesBJM0FJgJ3FrkTDQ0NRW7OeokVK1YAcPjhh1e4JVYJ8+bNK2xbnbpGKGk4cCzwdItZBwOvlrxflZe1Vd7atqdKapTUuHbt2rLb1NDQQFNTU9nLW9+xceNGNm7cWOlmWAU0NTUVegBU9r3GkmqB+4C/i4i3CmtBLiJmADMA6uvrOzWQSl1dXaH/O1jv0PwPwZ99eoo+CyzriFBSP7IQvCMi7m9lkdXA0JL3Q/KytsrNzKpGOb3GAn4CvBAR/6eNxWYDf5n3Hn8K2JBfW3wYOF3SQEkDgdPzMjOzqlHOqfFJwJeAJZKaL8Z9ExgGEBG3AnOAs4AVwCbgy/m89ZK+BSzM15ve3HFiZlYtyuk1ng+og2UC+Gob82YCM3eqdWZmPcB3lphZ8hyEZpY8B6GZJc9BaGbJcxCaWfIchGaWPAehmSXPQWhmyXMQmlnyHIRmljwHoZklz0FoZslzEJpZ8hyEZpY8B6GZJc9BaGbJcxCaWfIchGaWPAehmSXPQWhmyXMQmlnyOhzFTtJM4M+BNRFxdCvzrwQuKNneJ4HB+VCerwBvA9uArRFRX1TDzcyKUs4R4e3AxLZmRsT3IqIuIuqAa4D/12Ls4pPz+Q5BM6tKHQZhRDwOlDso+xTgzi61yMyshxV2jVDSALIjx/tKigP4laRFkqYWVZeZWZE6vEbYCZ8FnmxxWjwuIlZL2g+YK+k3+RHmDvKgnAowbNiwAptlZta+InuNJ9PitDgiVuc/1wAPAGPbWjkiZkREfUTUDx48uMBmmZm1r5AglLQ3MAH4RUnZnpL2ap4GTgeWFlGfmVmRyvn6zJ1AAzBI0irgeqAfQETcmi92LvCriHinZNX9gQckNdfz84j4j+KabmZWjA6DMCKmlLHM7WRfsyktWwmM3tmGmZn1FN9ZYmbJcxCaWfIchGaWPAehmSXPQWhmyXMQmlnyHIRmljwHoZklz0FoZslzEJpZ8hyEZpY8B6GZJc9BaGbJcxCaWfIchGaWPAehmSXPQWhmyXMQmlnyHIRmljwHoZklz0FoZslzEJpZ8joMQkkzJa2R1Org7JIaJG2Q1JS/ppXMmyjpRUkrJF1dZMPNzIpSzhHh7cDEDpZ5IiLq8td0AEk1wC3AmcBIYIqkkV1prJlZd+gwCCPicWD9Tmx7LLAiIlZGxBbgLmDSTmzHzKxbFXWN8M8kPSfpIUlH5WUHA6+WLLMqLzMzqyq7FrCNxcAhEbFR0lnAg8CIzm5E0lRgKsCwYcMKaJaZWXm6fEQYEW9FxMZ8eg7QT9IgYDUwtGTRIXlZW9uZERH1EVE/ePDgrjbLzKxsXQ5CSQdIUj49Nt/mOmAhMELSoZJ2AyYDs7tan5lZ0To8NZZ0J9AADJK0Crge6AcQEbcCnwcuk7QV2AxMjogAtkq6HHgYqAFmRsSybtkLM7Mu6DAII2JKB/NvBm5uY94cYM7ONc3MrGf4zhIzS56D0MyS5yA0s+Q5CM0seQ5CM0ueg9DMkucgNLPkOQjNLHkOQjNLnoPQzJLnIDSz5DkIzSx5DkIzS56D0MyS5yA0s+Q5CM0seQ5CM0ueg9DMkucgNLPkOQjNLHkOQjNLnoPQzJLXYRBKmilpjaSlbcy/QNLzkpZIekrS6JJ5r+TlTZIai2y4mVlRyjkivB2Y2M783wETIuIY4FvAjBbzT46Iuoio37kmmpl1r3IGeH9c0vB25j9V8nYBMKTrzTIz6zlFXyO8BHio5H0Av5K0SNLUgusyMytEh0eE5ZJ0MlkQjispHhcRqyXtB8yV9JuIeLyN9acCUwGGDRtWVLPMzDpUyBGhpFHAbcCkiFjXXB4Rq/Ofa4AHgLFtbSMiZkREfUTUDx48uIhmmZmVpctBKGkYcD/wpYj4bUn5npL2ap4GTgda7Xk2M6ukDk+NJd0JNACDJK0Crgf6AUTErcA0YF/gh5IAtuY9xPsDD+RluwI/j4j/6IZ9MDPrknJ6jad0MP9S4NJWylcCo3dcw8ysuvjOEjNLnoPQzJLnIDSz5DkIzSx5DkIzS56D0MyS5yA0s+Q5CM0seQ5CM0ueg9DMkucgNLPkOQjNLHkOQjNLnoPQzJLnIDSz5DkIzSx5DkIzS56D0MyS5yA0s+Q5CM0seQ5CM0ueg9DMkldWEEqaKWmNpFYHaFfmB5JWSHpe0piSeRdJeil/XVRUw83MilLuEeHtwMR25p8JjMhfU4EfAUjah2xA+BOAscD1kgbubGPNzLpDWUEYEY8D69tZZBLw08gsAD4m6UDgDGBuRKyPiDeBubQfqGZmPa6oa4QHA6+WvF+Vl7VVbmZWNaqms0TSVEmNkhrXrl1b6eaYWUKKCsLVwNCS90PysrbKdxARMyKiPiLqBw8eXFCzzMw6VlQQzgb+Mu89/hSwISJeBx4GTpc0MO8kOT0vMzOrGruWs5CkO4EGYJCkVWQ9wf0AIuJWYA5wFrAC2AR8OZ+3XtK3gIX5pqZHRHudLmZmPa6sIIyIKR3MD+CrbcybCczsfNPMzHpG1XSWmJlVioPQzJLnIDSz5DkIzSx5DkIzS56D0MyS5yA0s+Q5CM0seQ5CM0ueg9DMkucgNLPkOQjNLHkOQjNLnoPQzJLnIDSz5DkIzSx5DkIzS56D0MyS5yA0s+Q5CM0seQ5CM0ueg9DMkldWEEqaKOlFSSskXd3K/O9Laspfv5X0p5J520rmzS6y8WZmRehwXGNJNcAtwGnAKmChpNkRsbx5mYj4HyXLfw04tmQTmyOirrgmm5kVq5wjwrHAiohYGRFbgLuASe0sPwW4s4jGmZn1hHKC8GDg1ZL3q/KyHUg6BDgUeLSkeA9JjZIWSPrcTrfUzKybdHhq3EmTgXsjYltJ2SERsVrSYcCjkpZExMstV5Q0FZgKMGzYsIKbZR3avBmWLIGlS+Htt+HddyEC9tgDamth5EgYPRr23LPSLTUrXDlBuBoYWvJ+SF7WmsnAV0sLImJ1/nOlpHlk1w93CMKImAHMAKivr48y2mVd8eabMGsWPPIILFwIr70GAwbAtm2wdWv2AqipgV13zV6bNsEBB8CYMXDqqTB5Muy3X2X3w6wA5QThQmCEpEPJAnAy8MWWC0k6EhgI/GdJ2UBgU0S8J2kQcBLw3SIabjshAhYsgJtugtmzYZddsnBr9tZbO66zbRts2fLR+1WrstfcuXDVVfCZz8A3vgETJoDU/ftg1g06vEYYEVuBy4GHgReAuyNimaTpks4pWXQycFdElB7NfRJolPQc8BhwY2lvs/WQCLjvPvj4x+G00+Dee7NT39IQ7KzNm7Nt/PKX8NnPwtCh8C//ktVl1suUdY0wIuYAc1qUTWvx/h9aWe8p4JgutM+66vXX4eKLYf78rgVfWyJg48bs9bWvwW23wR13wPDhxddl1k18Z0lfFQE/+QkccQQ89lj3hGBL77wDTz8NRx2VnX5/8EH312lWAAdhX/Tuu3DWWXDFFdmR2vvv91zd27ZloXvddTBuXNYDbVblHIR9zYYNWQDNm5cdoVXKO+/A4sVw/PGwZk3l2mFWBgdhX/L221kILl2aHRVW2nvvwcqV8KlPwRtvVLo1Zm1yEPYV770Hp5wCL72UTVeL99+H1at9mmxVzUHYV1x9NSxbVl0h2GzLFnjlFbjsskq3xKxVDsK+4Mkn4cc/zr7bV63eew8eeCD73qFZlXEQ9nYbN8IXvlDdIdhs0ya48EJYt67SLTHbjoOwt7viiuy+4d5i06bsC95mVcRB2Jv94Q/w859XRw9xubZsgV//OnvSjVmVcBD2Zj/4Qe+8e2PLFvje9yrdCrMPOQh7q3ffzTpISp8M01ts2wb33NO7TumtT3MQ9lZ33VXpFnTNLrtkD2gwqwIOwt7qppuyHuMe8CAg4DdFbnTTJvinfwKgqamJOXPmdLBC11x66aUsX549Ae7b3/72dvNOPPHEbq3bqp+DsDd6/31Y3nOPdbwTGEc3jMj1X/8FGzb0SBDedtttjBw5EtgxCJ966qlurduqn4OwN1q2LBtLpAdsBOYDPyEbvrDZB8DfAEeSjfN6FnBvPm8RMAE4DjgDeD0vbwCuIhsW8Qjgid12Y8uCBUybNo1Zs2ZRV1fHrFmztqv/9ttvZ9KkSTQ0NDBixAhuuOGGD+e9+uqrLFy4kKOPPpqbbroJgHfeeYezzz6b0aNHc/TRR3+4vYaGBhobG7n66qvZvHkzdXV1XHDBBQDU1tYCEBFceeWVHH300RxzzDEfrjtv3jwaGhr4/Oc/z5FHHskFF1xA+AG0fUrRgzdZT2hszDocesAvgIlkwbUvWcgdB9wPvAIsB9aQPYr8r4D3ga/l6w0GZgHXAjPz7W0FniF7yu8NmzbxSFMT06dPp7GxkZtvvrnVNjzzzDMsXbqUAQMGcPzxx3P22WcjiT/+8Y+MGTOGhx56iBNOOIEJEyawcuVKDjroIH6Z38GyYcOG7bZ14403cvPNN9PU1LRDPffffz9NTU0899xzvPHGGxx//PF8+tOfBuDZZ59l2bJlHHTQQZx00kk8+eSTjBs3bmd+pVaFfETYGz3+eM88aJXsdHhyPj2Zj06P5wNfIPsDOgA4OS9/EVhKdpRYB/wvsvFfm52X/zwOeCUie2hsB0477TT23Xdf+vfvz3nnncf8+fOZP38+++67LzU1NdTW1nLeeefxxBNPcMwxxzB37lyuuuoqnnjiCfbee++y93X+/PlMmTKFmpoa9t9/fyZMmMDChQsBGDt2LEOGDGGXXXahrq6OV155peztWvXzEWFv1ENfRl5PNkD1ErLOkm35z/a+ARjAUZSM4NXC7vnPGrKjQ37TcReMWgwK1fJ9qSOOOILFixczZ84crrvuOk499VSmTZvW5vLl2n333T+crqmpYWvzKH/WJ/iIsDfqoaPBe4EvAb8nOw1+FTgUeIJsOML7yK4V/hGYl6/zCWAtHwXh+8Cy9irZvJm99tqLt9t5RNfcuXNZv349mzdv5sEHH+Skk05i/PjxrFu3jm3btvHOO+/wwAMPMH78eF577TUGDBjAhRdeyJVXXsnixYt32F6/fv14v5Wndo8fP55Zs2axbds21q5dy+OPP87YsWPba731EQ7C3qiHHrV1J3Bui7Lz8/LzyQa4HglcCIwB9gZ2IwvQq4DRZKfH7fbJbtnCySefzPLly1vtLIHstPT8889n1KhRnH/++dTX1zNmzBj2339/Fi9ezAknnMCll17Ksccey5IlSxg7dix1dXXccMMNXHfddTtsb+rUqYwaNerDzpJm5557LqNGjWL06NGccsopfPe73+WAAw5o93c0bdo0Zs+e3e4yVv1Ujb1f9fX10djYWNayDQ0NQNazl4zDD4eXX650K9gI1ALryHqCnyS7Xtgp++7b7tOrb7/99jY7UpL87A3Y+c9e0qKIqG9ZXtYRoaSJkl6UtELS1a3Mv1jSWklN+evSknkXSXopf13UqVZb60quV1XSn5Md8Y0H/ic7EYJQNftiaeuws0RSDXALWUfgKmChpNmtDNQ+KyIub7HuPsD1QD3ZdfRF+bq+ybQrDjigR79Q3ZZ5RWxk0KB2Z1988cVc7Md2WTcr54hwLLAiIlZGxBay79VOKnP7ZwBzI2J9Hn5zyb6WZl0xYQLs2gc6/CUYP77SrTArKwgPJuswbLYqL2vpfEnPS7pX0tBOrmudMXYsDBhQ6VZ0XW0t+D5fqwJF9Rr/OzA8IkaRHfX9a2c3IGmqpEZJjWvXri2oWX1UfX3vehhrWz74IBv32KzCygnC1cDQkvdD8rIPRcS6iGj+TsdtZDcOlLVuyTZmRER9RNQPHjy4nLana9Ag6MQdE1Xrgw/g4x+vdCvMygrChcAISYdK2o3sTqvtvjgl6cCSt+cAL+TTDwOnSxooaSBwel5mXXXuub37OqEEZ5yRPZfQrMI6/CuMiK3A5WQB9gJwd0QskzRd0jn5Yn8raZmk54C/BS7O110PfIssTBcC0/My66qvfx369at0K3begAFw1VWVboUZUOa9xhExh+yBIaVl00qmrwGuaWPdmXz08BEryic+AcceC731WXpDhsAJJ1S6FWaAb7Hr3a65Bvbaq9Kt6LzaWvjmN7PTY7Mq4CDszc48s3d2mvTrB3/xF5VuhdmHHIS9WU1NNohT//6Vbkn5BgyAn/2sx56wbVYOB2Fvd9JJ8JWv9I4vWO+xB5x3Hpx9dqVbYrYdB2Ff8J3vZPcfV7u994Yf/rDSrTDbgYOwL9h9d7jvvuo+KuzfH+6+u3d27lif5yDsK+rq4J57qvN6Yf/+2WDu+UBIZtXGQdiXnHUW/Nu/VVcY9u+fDeT+xS9WuiVmbXIQ9jXnnQe/+EV1nCb3758F8yWXVLolZu1yEPZFp50GTz6ZPdCgEoHYvz8cdBA88kgWzGZVzkHYV9XVwQsvwN//fRZMPXUXR//+cNllsGKFnzVovYaDsC/r1w9uuAEaG+GYY7Jb27pLbW02qNQTT8A//mN1Xac064CDMAUjR8Kzz2ZfsTnttOzrNkXc2bHbbtl2Pv1puOOO7Aj0uOM6Xs+syvTiB9pZp+yyC5x+evZ67TX453+GH/0I/vSnLMw2bYJWBj3fzq67wp57ZuMqDxgAf/3X2WnwIYf0zD6YdRMHYYoOOgiuvz57rVmTHS0uWpSd1i5Z8lEofvDBR0ePI0dmR3719dnjvw48sON6zHoJB2Hq9tsve1L0GWdUuiVmFeNrhGaWPAehmSXPQWhmyXMQmlnyHIRmljwHoZklr6wglDRR0ouSVki6upX5X5e0XNLzkn4t6ZCSedskNeWv2S3XNTOrtA6/RyipBrgFOA1YBSyUNDsilpcs9ixQHxGbJF0GfBf4b/m8zRFRV3C7zcwKU84R4VhgRUSsjIgtwF3ApNIFIuKxiNiUv10ADCm2mWZm3aecIDwYeLXk/aq8rC2XAA+VvN9DUqOkBZI+txNtNDPrVoXeYifpQqAemFBSfEhErJZ0GPCopCUR8XIr604FpgIMGzasyGaZmbWrnCPC1cDQkvdD8rLtSPoMcC1wTkS811weEavznyuBecCxrVUSETMioj4i6gcPHlz2DpiZdVU5QbgQGCHpUEm7AZOB7Xp/JR0L/JgsBNeUlA+UtHs+PQg4CSjtZDEzq7gOT40jYquky4GHgRpgZkQskzQdaIyI2cD3gFrgHmWPhP9DRJwDfBL4saQPyEL3xha9zWZmFVfWNcKImAPMaVE2rWT6M22s9xRwTFcaaGbW3XxniZklz0FoZslzEJpZ8hyEZpY8B6GZJc9BaGbJcxCaWfIchGaWPAehmSXPQWhmyXMQmlnyHIRmljwHoZklz0FoZslzEJpZ8hyEZpY8B6GZJc9BaGbJcxCaWfIchGaWPAehmSXPQWhmySsrCCVNlPSipBWSrm5l/u6SZuXzn5Y0vGTeNXn5i5LOKK7pZmbF6DAIJdUAtwBnAiOBKZJGtljsEuDNiDgc+D7wv/N1RwKTgaOAicAP8+2ZmVWNcgZ4HwusiIiVAJLuAiYBy0uWmQT8Qz59L3CzJOXld0XEe8DvJK3It/efxTQ/09TURENDQ5GbtF6gqakJwJ99gpqamqirqytse+WcGh8MvFryflVe1uoyEbEV2ADsW+a6AEiaKqlRUuPatWvLaz0wb968Qn8h1nvU1tZSW1tb6WZYBdTV1TFv3rzCtlfOEWGPiIgZwAyA+vr66My6Rf5CzCw95RwRrgaGlrwfkpe1uoykXYG9gXVlrmtmVlHlBOFCYISkQyXtRtb5MbvFMrOBi/LpzwOPRkTk5ZPzXuVDgRHAM8U03cysGB2eGkfEVkmXAw8DNcDMiFgmaTrQGBGzgZ8AP8s7Q9aThSX5cneTdaxsBb4aEdu6aV/MzHaKsgO36lJfXx+NjY2VboaZ9TGSFkVEfcty31liZslzEJpZ8hyEZpY8B6GZJc9BaGbJcxCaWfIchGaWvKr8HqGktcDvO7HKIOCNbmqO66/eulOvP+V939n6D4mIwS0LqzIIO0tSY2tfknT9fbvu1OtPed+Lrt+nxmaWPAehmSWvrwThDNefZN2p15/yvhdaf5+4Rmhm1hV95YjQzGynVX0QVnIo0TLq/rqk5ZKel/RrSYeUzNsmqSl/tXyQbVH1XyxpbUk9l5bMu0jSS/nropbrFlT/90vq/q2kP5XM69L+S5opaY2kpW3Ml6Qf5G17XtKYknlF7HtH9V+Q17tE0lOSRpfMeyUvb5LU6efJlVF3g6QNJb/faSXz2v3MCqr/ypK6l+af9T75vC7te76NoZIey/9tLZN0RSvLFPv5R0TVvsgeBPsycBiwG/AcMLLFMn8D3JpPTwZm5dMj8+V3Bw7Nt1NTcN0nAwPy6cua687fb+yBfb8YuLmVdfcBVuY/B+bTA4uuv8XyXyN7aG9R+/9pYAywtI35ZwEPAQI+BTxd1L6XWf+JzdslG+r26ZJ5rwCDunHfG4D/29XPbGfrb7HsZ8meSF/IvufbOBAYk0/vBfy2lb/9Qj//aj8i/HAo0YjYAjQPJVpqEvCv+fS9wKnS9kOJRsTvgOahRAurOyIei4hN+dsFZGOyFKWcfW/LGcDciFgfEW8Cc8nGle7O+qcAd3ayjjZFxONkTztvyyTgp5FZAHxM0oEUs+8d1h8RT+Xbh4I/+zL2vS1d+ZvZ2foL/dzz+l+PiMX59NvAC+w4+mWhn3+1B2GPDCXahbpLXUL2P1SzPZQNT7pA0uc6UW9n6z8/PzW4V1LzQFld3fdObSO/JHAo8GhJcVf3f2fbV8S+d1bLzz6AX0laJGlqN9X5Z5Kek/SQpKPysh7dd0kDyELmvpLiQvdd2aWuY4GnW8wq9POvmuE8ezNJFwL1wISS4kMiYrWkw4BHJS2JiJcLrvrfgTsj4j1J/53syPiUgusox2Tg3th+PJqe2P+Kk3QyWRCOKykel+/7fsBcSb/Jj7KKspjs97tR0lnAg2QDo/W0zwJPRkTp0WNh+y6plixk/y4i3iqgvW2q9iPCSg4lWtb6kj4DXAucExHvNZdHxOr850pgHtn/ap3RYf0Rsa6kztuA4zrT9q7WX2IyLU6PCtj/nW1fjw0hK2kU2e99UkSsay4v2fc1wAN07pJMhyLirYjYmE/PAfpJGkTPD5/b3ufepX2X1I8sBO+IiPtbWaTYz78rFzW7+0V2xLqS7LSr+eLvUS2W+Srbd5bcnU8fxfadJSvpXGdJOXUfS3ZxekSL8oHA7vn0IOAlOnnRusz6DyyZPhdYEB9dMP5d3o6B+fQ+RdefL3ck2QVyFbn/+brDabvD4Gy2v1j+TFH7Xmb9w8iuO5/YonxPYK+S6aeAiQXXfUDz75ssaP6Q/x7K+sy6Wn8+f2+y64h7dsO+C/gpcFM7yxT6+Xf6F9TTL7Leod+SBc61edl0siMwgD2Ae/I/ymeAw0rWvTZf70XgzG6o+xHgj0BT/pqdl58ILMn/EJcAl3TTvn8HWJbX8xhwZMm6f5X/TlYAX+6O+vP3/wDc2GK9Lu8/2ZHG68D7ZNd5LgG+Anyl5B/LLXnblgD1Be97R/XfBrxZ8tk35uWH5fv9XP7ZXNsNdV9e8rkvoCSMW/vMiq4/X+Ziss7I0vW6vO/5dsaRXWt8vuT3e1Z3fv6+s8TMklft1wjNzLqdg9DMkucgNLPkOQjNLHkOQjNLnoPQzJLnIDSz5DkIzSx5/x/wgbu3KUQYlAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "\u003cFigure size 360x360 with 1 Axes\u003e" + ] + }, + "metadata": { + "needs_background": "light", + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "#@title Train affordance and transition model.\n", + "# Trains num_repeats affordance and model networks.\n", + "#@markdown Experiments to run.\n", + "run_model = True #@param {type:\"boolean\"}\n", + "run_model_with_affordances = True #@param {type:\"boolean\"}\n", + "num_repeats = 1#@param {type:\"integer\"}\n", + "\n", + "#@markdown Training arguments\n", + "# use_affordance_to_mask_model = True #@param {type:\"boolean\"}\n", + "optimize_performance = True #@param {type:\"boolean\"}\n", + "model_learning_rate = 1e-2 #@param {type:\"number\"}\n", + "affordance_learning_rate = 1e-1 #@param {type:\"number\"}\n", + "max_num_transitions = 1000 #@param {type:\"integer\"}\n", + "num_train_steps = 8000 #@param {type:\"integer\"}\n", + "affordance_mask_threshold = 0.5 #@param {type:\"number\"}\n", + "seed = 0 #@param {type:\"integer\"}\n", + "intent_threshold = 0.05 #@param {type:\"number\"}\n", + "\n", + "#@markdown Environment arguments\n", + "drift_speed = 0.001 #@param {type:\"number\"}\n", + "max_action_force = 0.5 #@param {type:\"number\"}\n", + "movement_noise = 0.1 #@param {type:\"number\"}\n", + "max_episode_length = 100000 #@param {type:\"integer\"}\n", + "\n", + "input_size = 2\n", + "action_size = 2\n", + "intent_size = len(IntentName)\n", + "hidden_nodes = 32\n", + "world_size = 2\n", + "\n", + "affordance_mask_params = []\n", + "if run_model:\n", + " affordance_mask_params.append(False)\n", + "if run_model_with_affordances:\n", + " affordance_mask_params.append(True)\n", + "\n", + "print(f'Experiments that will be run: {affordance_mask_params}')\n", + "\n", + "for repeat_number in range(num_repeats):\n", + " all_losses = {}\n", + " model_networks = {}\n", + " affordance_networks = {}\n", + " new_seed = seed + repeat_number\n", + " for use_affordance_to_mask_model in affordance_mask_params:\n", + " print(f'Resetting seed to {new_seed}.')\n", + " np.random.seed(new_seed)\n", + " random.seed(new_seed)\n", + " tf.random.set_seed(new_seed)\n", + "\n", + " affordance_network = tf.keras.Sequential([\n", + " tf.keras.layers.Dense(\n", + " hidden_nodes, activation=tf.keras.activations.relu),\n", + " tf.keras.layers.Dense(\n", + " hidden_nodes, activation=tf.keras.activations.relu),\n", + " tf.keras.layers.Dense(\n", + " intent_size, activation=tf.keras.activations.sigmoid),\n", + " ])\n", + "\n", + " affordance_sgd = tf.keras.optimizers.Adam(\n", + " learning_rate=affordance_learning_rate)\n", + " model_sgd = tf.keras.optimizers.Adam(learning_rate=model_learning_rate)\n", + " model_network = TransitionModel(hidden_nodes, input_size)\n", + "\n", + " # Store models for later use.\n", + " model_networks[use_affordance_to_mask_model] = model_network\n", + " affordance_networks[use_affordance_to_mask_model] = affordance_network\n", + "\n", + " world = ContinuousWorld(\n", + " size=world_size,\n", + " # Slow drift speed to make the transition from L -\u003e R slow.\n", + " drift_speed=drift_speed,\n", + " drift_between=(\n", + " # Drift between the two sides around the wall.\n", + " Point((1 / 4) * world_size, (1 / 4) * world_size),\n", + " Point((3 / 4) * world_size, (3 / 4) * world_size),\n", + " ),\n", + " max_action_force=max_action_force,\n", + " max_episode_length=max_episode_length,\n", + " movement_noise=movement_noise,\n", + " wall_pairs=[\n", + " (Point(1.0, 0.0), Point(1.0, 2.0)),\n", + " ],\n", + " verbose_reset=True)\n", + "\n", + " fig = plt.figure(figsize=(5, 5))\n", + " ax = fig.add_subplot(1, 1, 1)\n", + "\n", + " visualize_environment(\n", + " world, ax, scaling=1.0, draw_start_mu=False, draw_target_mu=False)\n", + "\n", + " def _use_affordance_or_none(model):\n", + " if use_affordance_to_mask_model:\n", + " return model\n", + " else:\n", + " return None\n", + "\n", + " model_loss, aff_loss, infos = train_networks(\n", + " world,\n", + " model_network=model_network,\n", + " model_optimizer=model_sgd,\n", + " affordance_network=_use_affordance_or_none(affordance_network),\n", + " affordance_optimizer=_use_affordance_or_none(affordance_sgd),\n", + " print_losses=True,\n", + " fresh_data=True,\n", + " affordance_mask_threshold=affordance_mask_threshold,\n", + " use_affordance_to_mask_model=use_affordance_to_mask_model,\n", + " max_num_transitions=max_num_transitions,\n", + " max_trajectory_length=None,\n", + " optimize_performance=optimize_performance,\n", + " num_train_steps=num_train_steps,\n", + " intent_threshold=intent_threshold,\n", + " print_every=1000)\n", + "\n", + " all_losses[use_affordance_to_mask_model] = (model_loss, aff_loss, infos)\n", + "\n", + " all_models_global.append(model_networks)\n", + " all_affordance_global.append(affordance_networks)\n", + " all_losses_global.append(all_losses)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "90KLuKumiRcB" + }, + "outputs": [], + "source": [ + "#@title Save weights\n", + "\n", + "for seed, model_networks in enumerate(all_models_global):\n", + " model_networks[True].save_weights(\n", + " f'./affordances/seed_{seed}_model_networks_True/keras.weights')\n", + " model_networks[False].save_weights(\n", + " f'./affordances/seed_{seed}_model_networks_False/keras.weights')\n", + "for seed, affordance_networks in enumerate(all_models_global):\n", + " affordance_networks[True].save_weights(\n", + " f'./affordances/seed_{seed}_affordance_networks_true/keras.weights')\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "g1iO3h6Ffucu" + }, + "source": [ + "# Visualizations\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "OQtxS_ovhJxt" + }, + "source": [ + "## Learning curves" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "06NPUpXDfyFf" + }, + "outputs": [], + "source": [ + "#@title Code to collect results from a list of lists into a single array.\n", + "def _collect_results(\n", + " all_losses_g,\n", + " using_affordances,\n", + " save_to_disk=False,\n", + " smooth_weight=0.99,\n", + " skip_first=10):\n", + " \"\"\"Collects results from the list of losses.\"\"\"\n", + " smoothed_curves = []\n", + " for seed, trace in enumerate(all_losses_g):\n", + " if save_to_disk:\n", + " np.save(f'./affordances/curve_seed_{seed}_{using_affordances}.npy',\n", + " np.array(trace[using_affordances][0]))\n", + " # Smooth the curves for plotting.\n", + " smoothed_curves.append(\n", + " apply_linear_smoothing(\n", + " trace[using_affordances][0][skip_first:], smooth_weight))\n", + " all_curves_stacked = np.stack(smoothed_curves)\n", + " mean_curve = np.mean(all_curves_stacked, 0)\n", + " std_curve = np.std(all_curves_stacked, 0)\n", + " return mean_curve, std_curve" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "R4elUiamgc-I" + }, + "outputs": [], + "source": [ + "#@title Plot averaged learning curves.\n", + "\n", + "mean_curve_aff, std_curve_aff = _collect_results(\n", + " all_losses_global, True)\n", + "mean_curve_normal, std_curve_normal = _collect_results(\n", + " all_losses_global, False)\n", + "colors = ['r', 'k']\n", + "plt.plot(\n", + " mean_curve_aff, color=colors[0], linewidth=4, label='With Affordance')\n", + "plt.plot(\n", + " mean_curve_normal, color=colors[1], linewidth=4, label='without Affordance')\n", + "\n", + "plt.fill_between(\n", + " range(len(mean_curve_aff)),\n", + " mean_curve_aff+std_curve_aff,\n", + " mean_curve_aff-std_curve_aff,\n", + " alpha=0.25,\n", + " color=colors[0])\n", + "\n", + "\n", + "plt.fill_between(\n", + " range(len(mean_curve_normal)),\n", + " mean_curve_normal+std_curve_normal/np.sqrt(num_repeats),\n", + " mean_curve_normal-std_curve_normal/np.sqrt(num_repeats),\n", + " alpha=0.25,\n", + " color=colors[1])\n", + "\n", + "plt.ylim([-2.2, -1.2])\n", + "plt.xticks(fontsize=15)\n", + "plt.yticks(fontsize=15)\n", + "plt.xticks([0, 2500, 5000, 7500],[0, 2500, 5000, 7500], fontsize=15)\n", + "plt.legend(fontsize=15)\n", + "plt.xlabel('Updates', fontsize=20)\n", + "plt.ylabel(r'$-\\log \\hat{P}(s^\\prime|s,a)$',fontsize=20)\n", + "plt.tight_layout()\n", + "plt.savefig('./affordances/model_learning_avg_plot.pdf')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "OSAm0ukchLqr" + }, + "source": [ + "## Intent heatmap plots" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 399 + }, + "colab_type": "code", + "id": "1NeTHiZkhXKj", + "outputId": "2428756a-8685-4312-f3f3-1bf4be01188b" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:56: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZ4AAAFZCAYAAACokUkDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3debwcVZn/8c+3+ya5SUgCAUESdgGRYUQQcV9QccB9lBEQQRwUEUHUGUdxHEEUUcYNAX8YAUVFxUHBqKC4geLG4oKAIoggmyCQBEhI7r3dz++Pc/qmbqe6b/VeXfd559WvdNd66vbpeuqcOnWOzAznnHOuX0qDToBzzrmZxQOPc865vvLA45xzrq888DjnnOsrDzzOOef6ygOPc865vvLAk1OSLpf0xh5t+72Szu7FtrtJ0qWSXt+H/Vwuaa2kn/Z6X70k6QhJj0gySTsOOj3ONeKBp0OSbpP0aPzB115nDDpdNZKeJ+nO5DQz+7CZ9SSoxX1K0q2SbmxhnRMlfTk5zcz2N7Pzup/CVMeY2XMS6Vks6SJJqyXdLum1jVaU9C5J10t6WNJfJb2rlR1Lem3cx2pJF0ta3GC5nSV9S9I/JD0o6fuSHl+bb2bnmNlGrezbuUHwwNMdLzOzjRKvYwadoAF7DrA5sIOkpww6MW06ExgDtgAOAf6fpH9qsKyAw4BNgP2AYyQdlGUncZufBQ6N+1oDfKbB4hsDy4HHx2WvAr6VZT/O5YkHnh6RNEfSSkm7JaY9JpaONpe0iaTvxKvXFfH9Vg22NaU0IGm7WJ0yEj+/QdIf4xX3rZLeHKfPBy4FliRKY0uS20ts6/WS/ibpfkn/ndjXXEnnxTT+UdJ/1ZegUryecEK8JL5PHss/SfpBvGK/N1b77Qe8FzgwpvH3cdnJ6kZJJUnviyWD+yR9UdKiLMfQqvh3ezXwP2b2iJldSTjhH5q2vJmdama/MbMJM7spHvszM+7uEODbZvZTM3sE+B/gVZIWpOznqliqedDMxoFPAo+XtGnrR+nc4Hjg6REzWwd8Ezg4Mfk1wBVmdh/hb/95YFtgG+BRoN0quvuAlwILgTcAn5S0p5mtBvYH7k6Uxu5usI1nEa6kXwC8X9IT4vQTgO2AHYB9gdc1S4ikecABwPnxdZCk2XHeAuCHwPeAJcCOwI/M7HvAh4ELYhp3T9n04fG1T0zLRmz490o9BknPkrSyWbrr7AxMmNmfE9N+DzQq8UySJODZwA0Z9/VPcdsAmNlfCCWtnTOs+xzg72b2QMZ9OZcLHni64+JYuqm93hSnfwVIVrm8Nk7DzB4ws2+Y2Rozexg4GXhuOzs3s++a2V8suAK4jHDya8UHzOxRM/s94URYO/m/Bviwma0wszuBT0+znVcB62IavgvMAl4S572UcKL8uJmtNbOHzezXGdN3CPAJM7s1lgyOJwS1kemOwcyuNLONM+4HQlB7qG7aKmCDUkiKE1l/UZF1X6ta3VcsHZ8JvDPjfpzLjZHpF3EZvNLMfpgy/SfAPElPBe4FngRcBJMlg08S7glsEpdfIKlsZpVWdi5pf0LJZGfCSW8e8IcWj+HvifdrCCdECCWTOxLzku/TvB74uplNABOSvhGnXQRsDfylxXTVLAFuT3y+nZB/t0hMa3QMrXqEUHpMWgg83GwlSccQ7vU8O5Z4e7IvSY8hBPbPmNlXM+7HudzwEk8PxQDydUJ128HAd2LpBuA/CNVCTzWzhYRqEwg3quutJgSTmsfW3kiaA3wD+BiwRbyyvySxnU67H78HSN572rrRgvEq/PnA6yT9XdLfCdVuL5a0GSFo7dBg9enSeTehWrJmG2CCENC77c/AiKSdEtN2p0n1maR/B94DvCCWDLO6gfWlSyTtAMyJaUjbzyaEoLPczE5uYT/O5YYHnt77CnAgoaroK4npCwj3dVbG5rMnNNnG74DnSNom3lA/PjFvNuFE9Q9CCWN/4EWJ+fcCm9ZuxLfh68DxsTHEUqBZi71DCSfMxxNKd08ilMLuJAZeYEtJb4+NLxbE0mAtndtJapQnvwq8Q9L2kjZi/T2hiTaPq6F4b+ybwEmS5kt6JvAK4Etpy0s6JKZnXzO7NWX+5ZJObLC784GXSXp2bNRwEvDNxAVKcjsLge8DPzez97RxaM7lggee7vi2pj7Hc1FtRryHsZpQVXRpYp1PAXOB+4FfEW64pzKzHwAXANcB1xJO4LV5DwNvIwSIFYT7SMsT8/9EOGnfGu8/LWnx2E4iBI6/EhoGXEi4h5Pm9YTqn78nX8BZwOtjWvcFXkaoFruZ0FgA4P/i/w9I+k3Kts8lnPh/GtOyFjg2ywHEk/ojWZZNOJrw/dxH+Pu9xcxuaLC9DwGbAlcn8sBZiflbAz9P20nc5lGEAHQf4YLk6ETaL5X03vjxX4GnAG+oy2/btHhszg2UfCA41wpJbwEOMrO2GkLkkaTLgKcD15jZPtMt3+K2tyLc83pGN7fbYF9vINw3HAV2TSt9OZcHHnhcU5K2JNyX+SWwE6Gl2hlm9qmBJsw5N7S8VZubzmzCk/XbAyuBr9H4yXrnnJuWl3icc871lTcucM4511ceeJxzzvWVBx7nnHN95YHHOedcX3ngcc4511ceeJxzzvWVBx7nnHN95YHHOedcX3ngcc4511dDEXgk3SbphRmXvVzSG7u4b5O0Y7e252YOz7fOpRuKwOOcc644hi7wSDpc0pWSPiZphaS/xsHPkHQy8GzgjDhOyRlx+i6SfiDpQUk3SXpNYntfkHSmpO9KeljSryU9Ls77aVzs93F7B3aY7p9LOkPSKkl/kvSCxPwlkpbHNN4i6U2JeXtLukbSQ5LulfSJafZ1YPy7LIyf948jgj6m3fS7zgx5vk1Nd5y/SNI5ku6RdJekD0kqx3llSR+XdH9c75hYEmvYObGkxZLulPSy+Hmj+Hs4rN1jcCDpXEn3Sbq+wXxJ+nT8W18nac+eJsjMcv8CbgNeGN8fDowDbwLKwFsIwyLXOjy9HHhjYt35hCGX30DojXsPwuBru8b5XwAeAPaO888HvpZY34Ad69KzEnhWi8dwOGGo5ncAswijkq4CFsf5PyX0+jxKGLnzH8Dz47xfAofG9xsBT8uwv/PjsW0a/z4vTcz7DvCeQX+vRX8VKN82S/dFhN7L5wObA1cBb47zjgJuJAydvglhIEEDRqbZ54sIAwVuDnwOuDAx77XAdYP+boftBTwH2BO4vsH8FxMGqhTwNODXPU3PoP8gGf9o9T/gWxLz5sXM/Nj4uf4HfCDws7rtfRY4Ib7/AnB23Rfwp8TnDX7AbR7D4ckfbJx2FWG46K2BCrAgMe8U4Avx/U+BDwCbtbC/jYG/AX8APjvo73AmvgqUb1PTDWxBGI12bmL+wcBP4vsfE4NQ/PzCLIEnLnt6zLt3AZsO+rsswgvYrkng+SxwcOLzTcCWvUrLsI7H8/faGzNbIwlCSSDNtsBTJa1MTBshDKO8wfaANU221am7LH6r0e2EIbGXAA9aGBo6OW+v+P4IwhDUf5L0V+ADZvYdmjCzlZL+D3gn8OpuHYDryLDm20bpXkwovd8Tp0Govr8jvl+SeE/d++ksA44BPmxmD7SX7MH6l33m2wMPVvqyr2uvW3cDYTj4mmVmtqyFTSxl6vdzZ5x2TxeSt4FhDTzN1A8wdAdwhZntO4jE1FkqSYngsw2wnFASWixpQSL4bEO42sPMbgYOllQCXgVcKGlTM1vdaEeSngT8O/BV4NPAfj05Itctec63jdxBKPFsZmYTKfPvIVSz1WydZaPxHtEy4IvA0ZI+b2a3dJrYfnvgwQpXfX+bvuyrvOXNa81sr+mXzIeha1yQwb2EoZprvgPsLOlQSbPi6ymSntDm9jqxOfC2mIZ/A54AXGJmdwC/AE6RNCrpiYRSzpcBJL1O0mPMrEqopweoNtqJpNG47nsJ9wiWSjq6S8fgeiPP+TaVmd0DXAZ8XNJCSSVJj5P03LjI14HjJC2VtDHw7oybfi8hEP878L/AF2sNFoaJAdU+/euCu5h6YbBVnNYTRQw8pwEHxBY4n44liBcBBxFKFn8HPgrMybi9E4HzJK2stSqKLYWe3Ubafg3sRLhJfDJwQKIa4WBCHezdhBu2J5jZD+O8/YAbJD0Sj+8gM3u0yX5OAe4ws/9nZuuA1wEfkrRTTP+lkt7bRvpd7+Q53zZzGGF49BuBFcCFwJZx3ucIgek64LfAJYQGNg3rnyQ9mVA9fJiZVQjHbMB74vxDJN3Q5WPoEaNi1b68umA5cFhs3fY0YFW8sOgJH/q6TyQdTrh5/KxBp8W5QYjNsM8ys20HnZZ+2HP3Ofbz7y3py77mLbnt2mZVbZK+CjwP2IxQGj6BcH8OMztL4SbdGYSL3DXAG8zsml6lt4j3eJxzOSBpLrAPodSzBeFkd9FAE9VnXaoG65iZHTzNfAPe2qfkFLKqbWAknRWrM+pfZ/VgX4c02NeQVEO4vOhhvhXhMYAVhKq2PwLvj/tM218vqgJdDnVU1SbplcBLgIXAOWZ2WZy+D6H9/wjwLjO7u/OkOtcdnm9dP+yx+2y74tLH9mVfi5be0bSqLW86qmozs4uBiyVtAnyMUKSG8MTywcCuhNZZH+xkP851k+db1y/VDVrJO+jePZ73AWcmPsvMqpJuZ2o7/jBTOhI4EmD+/PlP3mWXXbqUDNdL11577f1mVqT+3lrKt+B5dxgNKt8aUPHAk6qjwBNbQnwEuNTMfpOYVY0PO25DeAJ2ivhE7TKAXZ84287+5r08XJ3NuI1QUuObceVpbtSVm3zJzbYLsNZmsbY6i5vWLeHBiflcu3IbHnh0HnffvZjSqhHm3lti9AFj9MEqc+8fo7x6nNLK1WjdOLZmDUhodJSxx23Ove9Yy9OX3MZLNvk9S0dWNt3vdJodUy+UZIxbiXEr85E7Xszq9y1h1j0r+d6fTyWekIdeu/kWpubdhVpsm/xmezQyAmrxdml3msC2sLts+UglTb/Q0Ll2YPnWSzzpOi3xHEvof2mRwtgfzzSzQwk/zLMJzfWyPjTWU1UrNQw+Vct40hhAHqow9UTQ60BUtboTTzGb2w8+36rU1+CjkjIHH9cdBlSK+fvpWKf3eD5N6I6l5qw4/ceEDgIzqZqoxAZ2aUGgFjAqDRrh1UpC9Sfp9fOt4banpIMSFUQVUbXwwlj/AhQeRwYzVP9DjpnMTFStNLm9biljXd1eIxXEGGXGqmVUNVTJR5PQbulWvu1YllJSF4NTsjTjQag/ivXL6Z5cPcdTQalX9LWA0ajEUgtIjarikifrjkoMA/6tZgk63SgRVa0UX7XA6yepgUkGpy4HofrgU8xqtsExzO/xNJCrwNPraqTptl+iShmjhFFSeFGy8LSTgBJY/L/pdmTT3lPqlW5VzY1ZGTOF0k7BSjxDq1kJqcOg5EGnBwwqHndS5SrwDFKttFRSlbKqlLDJU7gpBhxgcqJSfqgSFqeXlY8cVwtErQSgUN1YoopCjPUSz/SSJ+5BVGP14p5Ro0DX54YRrng88EyjFl8m48jk/yknlzitpFhqGuIa3kq8z1X2mDM8aoEiQ2Cor2azqm1Y6rFqevBptQVfxjQVTe12sNuQB54mSo1KLcncZBaucMuNGzb0o0FAM+1Wt4V7POYlnlaVNJhST561EqwKE6Q08N9+XnngmYbF2zzhw0CT0leVrE3M3YY86DhiicezQioPPE1s8EyLc663ClPaCbzEk25oAs+gWom1JK3BwRAzD7zOtS10meO/oTRDE3hce9q5v1MhPPwKxHtYQxD0ncshrzVJ54GnRTlpJe2cyzkv8TRWmMAzXQei/aQBPkCa1OkDuX615vqqYPd3DDXs5mum87+Kc865vipMiWc60139JzspTW1KnLZ6gZ9vmfI3KPBxupwoWGmnxmsN0s2YwONaVzX5o9fOtcnv8TTmgSeLtAv+gp+Q/QfjXKfkD2I34IFnOsmistc4OecyCn21eeBJ44GngWoXrvjLA2zZ1kmLttr4Rx5nneuM1xykK0Tg6VVT6pl84k2O1mreuMD1UkEbFph5VVsjQxF48vBMzBQ5PxF3c0A95fxYncuzbtScFNFQBJ5ho4J0b+BNQZ1rX2jV5iWeNP5X6ZGGY/kMCf/BtGnQ/eAPW7XVsKXXdYWXeDLaYARS55xryu/xNOKBp9sGXDvVrfs7FVMYFqE2wqpzriXenLqxGRF4Oj4Z+3nXue6bAdVsFb9PmmpGBJ6ODFHQ6WZrNudcZ7x36sY88LRgyNsLZFZBVCl5qzbnOlT1ezypPPA451wPeHPqxjzwJNRuBFbRjL/ar1hpfeOCGVAX7/psBuQpQ36PpwEPPE3M9ODjnOuMt2pL53+VOt3o1K+k/g/F7Q0LcsCbnTuXiZd4XKoKJT+Put6YAdVsEGqp/QHSdLkPPHnoIFRGtmbVGkzVXE9LO4P/8zs3pOSdhDaQ+8CTR7JEr83ee7NzLoXhJZ5GPPBMpy6uzIRhAioWnuGxWuOKGXDMrk9mSDVbjTenTjf0gaffN/Gdcy4L88cyGspN4ClTzX/Twxxe+Pfq/k6F0voSjysU81YjfeMlnnS5CTzOuYKbYdVshneZ00jhA48/39Ke2g9mJtzTcq431JXnAovIw3ET01Y1Ffyk7PXTzrleKHyJp13tnnRNQrK+DH3dy9LclCs1vycwHGZYVVbeeVVbYx54ssj6AGlB5L6Rhxs+MzQoelVbuqEOPN6UurdmUKx1ruvM5CWeBpr+VSSdJOntic8nSzqu98nqn7QueYbhaeNeN5qY/MEM6X2smZB3Xf5VrNSXVxaS9pN0k6RbJL0nZf42kn4i6beSrpP04q7/QaLpUnwucFhMVAk4CPhyrxKTSzP4BvuQP8fjeTcvZmg1mxHH9urDazqSysCZwP7ArsDBknatW+x9wNfNbA/C7+Uz3f2LrNe0qs3MbpP0gKQ9gC2A35rZA71KTL1+dhDqReKphr1uetB51zlQnmpP9gZuMbNbASR9DXgFcGNiGQMWxveLgLt7lZgs93jOBg4HHku4ipzZhrPmaabyvOsGJrRq69sF3GaSrkl8XmZmyxKflwJ3JD7fCTy1bhsnApdJOhaYD7ywFwmFbIHnIuAkYBbw2l4lJC+alXxabSHdq/swvb6/UyHcFF3fSejQVpX0L+8Wtcm5VUG5uWofOn3sMud+M9urw20cDHzBzD4u6enAlyTtZtb9E8C0fxUzGwN+Qqj7qyTnSdpB0jmSLqybvpuk8+Nrt+4mObuenaCH9Ib7TDPMebcwhveipWjuArZOfN4qTks6Avg6gJn9EhgFNutFYqYNPPHG7NOAc+rnmdmtZnZEymrHAW8FjgaO7TSRbr1+dQFUsVKIr0McZD3vukGq9U7dj1cGVwM7Sdpe0mxC44Hldcv8DXgBgKQnEALPP7r4J5nUtKottnr4DnCRmd3cwnYXmdnKuI0FKds9EjgSYMul5RY221+TX2iL9bT96LXANdePvDvKvG4ktbi8tJObh7HNbELSMcD3gTJwrpndIOkk4BozWw78B/A5Se8g3KI63Kw3V57TtWq7Edihje2ukrSIkPiHU7a7DFgGsOsTZ+fmLJ1syVV/FbFBLEl+H1alvvA47MGnioa6OXU/8u5CLR7uL9n1lBlUcvQbMrNLgEvqpr0/8f5G4Jn9SEtHPRdI2hQ4GdhD0vHArmZ2KHAacHpc7NTOkuhc93nedf3gHe2m6yjwxOcijkqZfj3x4b2hN0OvaSvJVm0FbLE1I/KuGyjDu8xpJBd9tZVkM/YE34p+NiyYVMCg41y/DPuD2L2Si8CTpp+9FrgNZemGw7mGvGFBvx8gHSq5DTxDqVSsTDbELamdy4HiVrVJeiahp4NtCXFEgJlZpgY9HnhakeVE3KPY0+8hvHPUx5QbNl7amQnOAd4BXAtUpll2Ax54UjQrHs+kGsBhbk7tXB4UuMp6lZld2u7KQxt4ujEIXNb7SEP+SE5bvG7auc7k7TmeLvuJpP8Fvgmsq000s99kWXloA08/TLnibyH4aMgjVe1p62HvMse5QSvqPR7W92yd7JjUgOdnWdkDT490c1juft/fgam96vao1wznCq3WV1sRmdk+nazvgaeB9UMCDDYdA1XQH43rIW9YMEVR7/HEbqVOAJ4TJ10BnGRmq7KsX9hyYLvy9sDXIEo74Pd4XBs86ExRe44nJ71Td9u5hL4MXxNfDwGfz7pyYUs8PT1hJzdd9R+bcy5dge/xPM7MXp34/AFJv8u6cmEDT18U+N5H1Ya7d2rnBm5wpZF+eFTSs8zsSph8oPTRrCt74OmDMpa7KrxmKlZa37igwMHVOde2twDnxXs9Ah4EDs+68lAGnm62GHONechxrn1GcRsXmNnvgN0lLYyfH2pl/aEMPN3QUiek8Qzc78dzBtWwALxxgWuRNyxIVbTfkaTXmdmXJb2zbjoAZvaJLNuZsYHHZVCwH41z/VTQ3qnnx/83GBaeFipJchl48jIkQnIomqZDX9fpxrDXgyztQF1rHG+555rx0k5DRQs8ZvbZ+PaHZvbz5LzYwCCTwrb165kZdOPD2xU4175azwUFfY7n9IzTUuWyxNOpTksLye5i2q1u6kapZ1AqaP1NUY8+zrWtaI0LJD0deAbwmLr7PAuBctbtFDLwDFKhStYec9yw0wArdax4VW3AbGAjQuxI3ud5CDgg60Y88EQVNGXws6JdqbSqwN25u27y+zsNFbFxgZldAVwh6QtmdrukeWa2ptXtDPweT7G+lu4YdMOCGu+5wDnXwBJJNwJ/ApC0u6TPZF3ZSzwuVYH7mJrRrJqPi5q+GGQ1W1S0Ek/Cp4B/AZYDmNnvJT2n+SrreeDJoo+/1TyUdqpWCj+YwSfF5ZlXszVV5PF4AMzsjtqDo1El67oeeFox3YlYxc1kzrnWFbi6+g5JzwBM0izgOOCPWVfOXeAZ9MOjDauYpjxMmq0oUBriPuWqCEyoaj4CqXNtKnAjpaOA04ClwF3AZcBbs66cu8DjnHNFYMVsTg2Amd0PHNLu+h54mmi3mNxuVsvD/Z2aismfHXWN+f2dTIpW1SbpdJrcdDCzt2XZjgeeHinnpL+5dlQoeas25zpWyMYF13RjI0MXeLoxFk8n95Gy3t8phOL9aJzrq6KVeMzsvOTnOB6PmdnDrWzHL2tzIk/VbFDcumnXBV7NNuNJ2kvSH4DrgOsl/V7Sk7OuP3Qlnn5Jnnhn4ik4tGrDOwkdFh4McqeIXeYknAscbWY/A5D0LODzwBOzrOyBxznnesEKfd1WqQUdADO7UtJE1pU98KQo8FVKJpXaGB/F/dG4dnnJqiUFfo7nCkmfBb5KOFMcCFwuaU8AM/tNs5U98DjnXA+EmurCBp7d4/8n1E3fg3Doz2+2cuECT89u0ltdi7ZGF36l1jNa3hoWQCz1FfdH41wfFLI5NQBmtk8n6xcu8LjumPKDqXr1inPtKOo9HkkbA4cB25GII/4AaRdYH+5z5LG0MynHSXMD4Pd3WlbgqrZLgF8Bf6Bx/U9DQxV4uvHwaD+UNNxn7Kr3XOBcx8wKHXhGzeyd7a48484uXe39usmgWqUhLy74czzOuSa+JOlNkraUtLj2yrryUJV4BsbPvW6m82q2thS1cQEwBvwv8N+sP0MasEOWlT3w1KlkqWJqYfjgMkalQVv+PN/f8VZtznWuwBUG/wHsGIdHaFmuAk+n1WDdPJE3vFIpbkaaosBXas71TYHv8dwCrGl35VwFnkGrzrxbXqkylfqcyzMNPg8bKnLgWQ38TtJPgHW1id6cul/aLEvnuZoN8C5z3Hp+f6dtBf4JXRxfbfHAk5GMQuci51yXFbg5tZmdJ2k2sHOcdJOZjWddf/Dl0Ryq3d9o+apfxclkkw/PFvjuaGF4iSS/rE+vDCTtJ+kmSbdIek+DZV4j6UZJN0j6SpNtPQ+4GTgT+AzwZ0nPyZaSDks8kubHnY4Bl5vZ+XH6bsDxcbFTzOz6TvbTD2WMUY2zaCTcL9ty3kPMKU+wbmyEdaNzWK1ZjM8vMbZAzN1oPrNWV5i9Yg6lRycor1qNzZnF2GYbsXqLWWw5uobFI6uZrUrDFm15NaUKcLiSnlmR8q1rwKr5uM+TkxKPpDIhSOwL3AlcLWm5md2YWGYnQv5/ppmtkLR5k01+HHiRmd0U192Z0FN1psHgOq1qexVwoZl9W9IFwPlx+nHAWwmx+FTgzVk2VrVSRy3bKqiteyclVZmndYzaOE8avZ0KJfaadytVSjywdCNWV+dw86NbcM/aRfz1ocXc9+BCJlbOZu4985n1MGx09wLG5ouHdoCxzSc47DE38qTR25nVJPC0m9ZeSaalrCrlUjUEngKV4hK6mm+HiWIntlb3SIDa6Nw297wkmLQ3cIuZ3Qog6WvAK4AbE8u8CTjTzFYAmNl9TbY3qxZ04rJ/ljQra2I6DTxbEfrqAagkpi8ys5UAkhbUryTpSOBIgC2XljtMQmuaBbeSQru2WVSYXx6jghjVOGNWZtzKzClNsGZiFo+OzWLleImJeWU0ISbmiMocqMwzSvMmmF9ax6gmhq60M4O0lW/j9Mm8O8q8XqYxGz+55lqOaqqXAnckPt8JPLVumZ0BJP0cKAMnmtn3GmzvGklnA1+Onw8BrsmamE4Dz52EH/HvmHq/aJWkRYQrx4frVzKzZcAyAEn/2GPbO1cDbT2I1D9/zLzkEeG/zcj9MTVyB/ALAP4Wp0hfBnj8gBLUbW3lW9gw7/7QLlzN+LB+zwmVKZ+GOO+mGki+7fN4PJtJSp74l8W82ooRYCfgeYTfx08l/XPtYqzOWwi1A7Xm0z8jVF9n3lEnvgmcIeklwLclfcnMDgVOA06Py5zabANm9hhJ15jZXh2mJVeKekyDTkOXdJxvwfPusBhYvjX62fvH/dN8Z3cBWyc+bxWnJd0J/Dq2TvurpD8TAtHVKdsbAU4zs0/A5D2kOVkT21HgMbPVwBsSk86P068njNXgXO54vnX9kqOqtquBnSRtTwg4BwGvrVvmYuBg4POSNiNUvd3aYHs/Al4IPBI/zwUuA56RJTH+HI9zzvVKTgKPmU1IOgb4PuH+zblmdoOkk4BrzJ83vksAACAASURBVGx5nPciSTcSKl/fZWYPNNjkqJnVgg5m9oikzDc98xJ4Wq2LHAZ+TDNDEf8mRTumAR1PvrrMMbNLCAO4Jae9P/HegHfG13RWS9rTzH4DIOnJwKNZ05KLwNPGTbDc82OaGYr4NynaMQ30eHJS4umBtwP/J+luwoMXjwUOzLpyLgKPc84VTrG7zLla0i6sbzGY7y5zJM2XdJ6kz0k6JDF9N0nnx9du/U5XJyTtIOkcSRfWTR/mY3pl/I4ukPSixPR94vd3vqQlg0xjv3nezT/Pt/1jZuNmdn18ZQ46MJi+2mpPjb8JeHlieu2p8aOBYweQrraZ2a1mdkTKrGE+povjd3QUU4vQRxFahJ3C5CNLM4bn3ZzLXb7NUV9teTKIqra2nxofQkU4pvcR+niqkZlVJd1O+C5nEs+7wyMn+baYVW2dGkSJp/bUeP3+V0laJGkhDZ4aH0JDe0wKPgpcWmu5ElUllYBtCN/lTOJ5N+dyl28LWuKR9KMs0xoZRImnK0+N54mkTYGTgT0kHQ/sOuzHRKheeSGwSNKOhB5rDyU0TT0bmAW8e4DpGwTPu/mXr3w7hNVgzUgaBeYRuujZhPVFuoWE/uCybcdy9Gitc84VxZzttrItT8g0EnTHbv/3d1/bj26OJB1HaEq9hNADQi3wPAR8zszOyLIdb07tnHM9UrTrejM7DThN0rFmdvq0KzTggcc553qlYIGnxsxOl/QMYDsSccTMvphlfQ88zjnXKwV9gFTSl4DHEYYWqbXwNMADj3POuZ7Yi9AQpa0ynQce55zrERW0qg24ntA/2z3trOyBxznnemFIexXIaDPgRklXAetqE83s5Y1XWc8DTxdJegpwDrA3YcyLq4AD4wBjzuWW591eUGHv8QAndrKyB54uij22Lgc+RBiR78v+w3XDwPNujxS0xGNmV0jaFtjJzH4YB4ErZ13fA0/3nUQYZnYt0J+nx5zrDs+73VbQwCPpTcCRwGJC67alwFnAC7KsP4i+2opuU2AjYAEwOuC0ONcKz7vdVtC+2gg9lz+T0GMBZnYzsHnWlT3wdN9ngf8Bzgc+OuC0ONcKz7vdZIR7PP149d86MxurfZA0Qgsh0KvaukjSYcC4mX1FUhn4haTnm9mPB50255rxvNsbBW5OfYWk9wJzJe1LGLfp21lX9k5CnXOuB+Zss7Ut+a+392Vftx37n33pJLQmDjFxBPAiQkeh3zezz2Vd30s8zjnXK8W9rj82dhg6GWwkHRenTcvv8TjnnGvV61OmHZ51ZS/xOOdcjxTtHo+kg4HXAtvH575qFgAPZt2OBx7nnOuV4vVc8AtC/2ybAR9PTH8YuC7rRjzwOOdcLxSwrzYzux24HXh6J9vxezzOOdcrBX2AVNKrJN0saZWkhyQ9LOmhrOt7icc553qkaPd4Ek4FXmZmf2xnZQ88zjnXK8UNPPe2G3TAA49zzrnWXSPpAuBipo7H880sK3vgcc65XiluiWchsIbQc0GNAR54nHNuUGTFvcdjZm/oZH0PPM451ysFe45H0n+Z2amSTielPGdmmcZx8sDjnHO9UrwST61BwTWdbGQoAo+k24A3mtkPMyx7OWHY3rO7tG8jDO96Sze252YOz7euaFVtZvbt+P95nWxnKAKPc84NpYIFnm4Zup4LJB0u6UpJH5O0QtJfJe0f550MPBs4Q9Ijks6I03eR9ANJD0q6SdJrEtv7gqQzJX03Pn37a0mPi/N+Ghf7fdzegR2m++eSzohP+/5J0gsS85dIWh7TeEsc07w2b29J18QnhO+V9Ilp9vVdScfWTbtO0r+2m37XmSHPt6npjvMXSTpH0j2S7pL0oTiQHJLKkj4u6f643jGSLI5W2Wh//ybp2rpp75T0rXaPYWBsfQODXr+GzdAFnuipwE2EjupOBc6RJDP7b+BnwDFmtpGZHSNpPvAD4CuEMcEPAj4jadfE9g4CPgBsAtwCnAxgZs+J83eP27sAQNJKSc9qM91/iek+AfimpMVx3teAO4ElwAHAhyU9P847DTjNzBYCjwO+Ps1+zgNeV/sgaXdgKfBdSZ+R9Jk20u46N8z5doN0x3lfACaAHYE9CM1r3xjnvQnYH3gSsCfwygz7Wk7o+fgJiWmHAl+U9CxJK9tI/+AUrMscSR+N//9bJ9sZ1sBzu5l9zswqhJPslsAWDZZ9KXCbmX3ezCbM7LfAN4DkH+4iM7vKzCYI480/qdnOzWxjM7uyjXTfB3zKzMbjyeAm4CWStgaeCbzbzNaa2e+As4HD4nrjwI6SNjOzR8zsV9PsZzmws6Sd4udDgQvMbMzMjjazo9tIu+vcsObb1HRL2gJ4MfB2M1ttZvcBnyQERIDXEC6Y7jSzFcBHptuRma0DLiBeOEn6J2A74DtmdqWZbdxG+l33vDhedBzfyUaGNfD8vfbGzNbEtxs1WHZb4Knxam9lvGI6BHhs2vYID0U12lan7rKpY43fTijhLAEeNLOH6+Ytje+PAHYG/iTpakkvbbYTM1tL/PEqDFF7MPClLh2Da9+w5ttG6d4WmAXck0jjZwklNAj5+o7EdpLvmzkPeG08wR0KfD0GpOFTsBIP8D1gBfBEJToHlXcSusHXcAdwhZntO4jE1Fkaq1ZqadyGUDq5G1gsaUEi+GwD3AVgZjcDB8cg8irgQkmbmtnqJvs6jxBsrgTWmNkve3A8rnvynG8buYPQXcpmsdRV7x5gq8TnrbNs1Mx+JWmMcN/rtfE1lIbx/kszZvYu4F2SvmVmr2h3O8Na4mnmXmCHxOfvEKqdDpU0K76eUleH3Mr2OrE58LaYhn8DngBcYmZ3EAZYOkXSqKQnEko5XwaQ9DpJjzGzKlCr464221EMNFXCYE1e2sm/POfbVGZ2D3AZ8HFJCyWVJD1O0nPjIl8HjpO0VNLGwLtb2PwXgTOA8TarB10PmdkrJG0h6aXx9ZhW1i9i4DkNOCC2wPl0LEG8iFDvfDeh2uCjwJyM2zsROC9WJbwGILYUenYbafs1sBNwP+FG8AFm9kCcdzChLvtu4CLghMTzH/sBN0h6JB7fQWb2aIb9fRH4Z2IAi2k/S9JZbaTd9Vae820zhwGzgRsJVTAXEu4BAXyOEJiuA34LXEJoiFDJsN0vAbsxNe8+O/4GhkfxqtqAycYFVxHuOb4GuErSAZnXn3rLwfWKpMMJDxO206qo3X0eBhzZz30610hshn2WmW2bYdm5hMY4e8aq5qEzumRr2+7N7+zLvm468Z3XmtlefdkZIOn3wL6xQQmxxPNDM9s9y/pFLPE4QNI84Ghg2aDT4mYmSXMlvVjSiKSlhEcILsq4+luAq4c16EwqaIkHKNWCTvQALcQTDzxdFKuxHkl5db1qS9IhDfZ1g6R/Af5BqOf/Srf37Yqlh/lWhOeMVhCq2v4IvD/uM21/j8TqtNuA44D/6HD/g1fcwPM9Sd9XeMD4cOC7hKrUTDqqapP0SuAlhLEZzjGzy+L0fYDDCa3m3mVmd7e9E+e6zPOt64e5S7a27d7Yn6q2P32wv1VtAJJeBdSq8X9mZllLs501pzazi4GLJW0CfIxwIxHgKMLN8l0JrbM+2Ml+nOsmz7fOdS6ONppp4Ld63XqO533AmYnPMrOqpNuZ2o4/zJSOBI4EmD9//pN32WWXLiXD9dK11157v5m11Gwy51rKt+B5dxgNNN96261UHQWe+GTxR4BLzew3iVnV+LDjNoT+x6Yws2XEm94Ltdg2+W183EADvuVkTR6NaaVKUonBn9o9pmZpGYAfVL5OPCEPvXbzLaTk3d/t2OvktqbTPFzLuw3yrUodDmzWzu+hhX1KacteO5h8a/l6gFTSfoRm+2XgbDNL7cJI0qsJzeKfYmYdjbvTSKclnmOBFwKLJO0IPNPMDiX8MM8mdKcx/UNjKmXO0Fbt4TepUuMfrtRa8Kltr1U5CzgF1Z18m0fN8nBHmx3GgJMDOQk8Cj2GnwnsS7ioulrScjO7sW65BYSGHb+eZnsvA74bH2pvWaf3eD4NfDox6aw4/cfAj7NsQ+Uy5YUpXUwlAkyyAYQAqnXHmhYQ6pdJ2Vbq/hJD1U4JcladWpJptO9prhgb8oDTN93It0kdn5S7ZDK/Ngo+tbzZ6DcgTebbpsfUat7udgApTbP/6Y6zn3KQhGhv4BYzuxVA0teAVxAe/E36IOFB5XdNs70DgU9J+gZwrpn9qZXEDL6vtpEyWrwJAFYSmgwA6/+fzIq1aVMCQoP3ycBTlwGnBJ/64GIWptV+uFULy1dr86rxc1jGassmfrCaPRvK5Q1+RBsEvbTgON2PpUFAbbiP1G3k59fgeqTbJZ9Wgs2gAk3W6X3Ux6q2zSQlq8WWxWrhmqVM7aT1TsJwF5Mk7QlsbWbfldQ08JjZ6yQtJDTG+YLCiLefB75a19lxqoEHnnVLy9z84YUbXqSYMNP6CwZTmGdgtVJJVevXMSXatde9h8nPmjIvzJcBVVAVSmNCFSivFapCeR2UJmDkUaM0BqOrKsx5YJzymjHKK1bD2Di2eg2MjKD5c6kums+Du2/M+HyojIpqOWy39sJq7239ewNVmPp5g/kWp9eWsynbmZwWl13/2aAS/oqT24yBM3y2cOwW39d/dsXS7Ds1g0bn6mQQmy4IJS9spglCZtZZNVkOgktT/fsJ3d9Jc+p4b/MThMcJMjGzhyRdCMwF3g78K6ED0U+b2enN1h144FkybyUf2vNblOM3VIk5f9xC0iomqpSoWIlxK0+ZN25lKoiq1c8vUzVRoUTVNPl5wsphe1ZiwkpUTExUy1QRY5UyY9URVqydy9jECI88Oofx8TITj47AeInyw2XKa2HsvhHmzy4x+6ER5lShtHYMVSowexbVhfMY22weq3aEsY2r2LwKGjGsohAkKyHoqaIY6EKQI/4vQBOAQWmiFJdNBK7EZ6pQqhC3Y1OCmyrpwao0GZDiOpVaEEoPWm6I9bLqNlHCn1bVOg8+1er0pZ48GtzDnWnuYmrv4FvFaTULCH3jXR6/i8cCyyW9PK2BgaRXEILUjoQ+Ifc2s/tijyk3AvkOPCWMUY0zW1P7DawFoIqFDFdNdLJQm1Zbpjr5Of4fg9XUZUMQWv9+/TaqVqKCGK+OsKoyl4lqiZXj83i0Mot7H13A6rHZ3L9iAWMPzULVEUoTojpSpjQ2ysiaMuWJEHjGF89jzeazYJdH2GnTFWw6upq55XHG475qQa5q618TMVhOVEP6zDRlmUo1cUy1+ZP/h9JfbZna8VWrtf9LUz7bZAkxliTrS5GJEmKtVJn3C0oXZGp0k6UEmzWo5CX4mOU6k+aoVdvVwE6SticEnINIDDdhZqsII8wCIOly4D+btGp7FfBJM/tpcqKZrZF0xHSJGXjgUYPAk1igoUqzmdOoBaSkKiVWV+cwbuXJ//8yd3MeHJvPH0w8yEaMP1pibE0JVWD2vBA0So/MxkZHGN9ohHWLxBOX3M0zNvkLS2atYH5p3Qb7qtb1VJScX39M1SnzksF3/XKNtlepm54MvOnb15TpJVUnS6LwHtwQG3S1aa+DD+Q6AA2amU1IOgb4PqE59blmdoOkk4BrzGx5i5v8e33QkfRRM3u3mf1oupUHHnjyrP7EnUaD/kE75/IrR6cHM7uEuv7UzOz9DZZ93jSb25cNHznYP2VaKg88LUjGGE15b3nKX865nMhRVVtXSHoLodf7x0m6LjFrAfDzrNvxwJPR5D2PRk8je8nHOVeveKeFrwCXAqcwtf79YTN7MOtGPPBkUO3gXpJzbobKV6u2bjEzu03SW+tnSFqcNfh44HHOuR4QTdtGDauvAC8FriWE1eQhGrBDlo144JlGNdF6rP7qxRsWOOeaKtgpwsxeGv/fvpPteODpkXLRcpxzrmUFbFywZ7P5db29N+SBJ6PJ5/OS3e2kKWDZ2jnXpoIFHuDjTeYZ8PwsG/HA06niZSznnEtlZvt0YzseeDJI9hKQRalo5WvnXHsKdiqQ9Hwz+7GkV6XNj8NhT8sDj3PO9UKjZ/6G23MJY1a9LGWeAR54us1M64dVgMT/xctdzrkuKNipwcxOiP+/oZPteODJyOqr2wqWoZxz3VfAEg8AkjYFTgCeRTgbXgmcZGYPZFl/CAe56I36Hp5rsnQU6pxzqaxPr/77GvAP4NXAAfH9BVlX9hKPc871SFFLPMCWZvbBxOcPSTow68p+OZ/B5Jg1KZnIey9wzqXqV2lnMKegyyQdJKkUX68hjPWTiZd4WtSwd2rwh0edc1MV7LpU0sOs76Pt7cCX46wS8Ajwn1m244EnodmIpvWNCwpchHbOuVRmtqAb2/HA0wkPPs65BkSxL1AlbQLsBIzWptUPh92IBx7nnOuVggYeSW8EjgO2An4HPA34JRn7avPGBa1IPjyawTD3UF1SddBJcG7oyawvrwE4DngKcHvsv20PYGXWlb3Ek0E1y8Ojfp52ziUVcwTSmrVmtlYSkuaY2Z8kPT7ryh54plELOg2DT5UpXeaYvGmbcy4o8D2eOyVtDFwM/EDSCuD2rCt74HHOuV4paOAxs3+Nb0+U9BNgEfC9rOt74GlBfVVqga9mnBsuOa1pKPI5Io5GWuur7edmNpZ13Vw0LmjUT1ouFTgjOee6rKA9F0h6P3AesCmwGfB5Se/Lur6XeNow3VVMCfNWYc65IjsE2N3M1gJI+gihWfWHsqzsgSeDqveF45xrVTEHgqu5m/Dg6Nr4eQ5wV9aVPfBkVJ9/CpyhnHPdUrDzhKTTCUe1CrhB0g/i532Bq7JuxwNPK+qbVDvnXAMF7TLnmvj/tcBFiemXt7IRDzx16jsKnRwSIRJMvYrxYRGcc40U7PxgZufV3kuaDewcP95kZuNZt+OBpxNdyFQVG6IWfc65lhSwxAOApOcRWrXdRrge31rS672T0A7VD3ldPyyCc841Vewucz4OvMjMbgKQtDPwVeDJWVb2wJPBBt3lbLhAfxLinBsqBX6qYlYt6ACY2Z8lzcq6sgeeVliD97VJXmvmnEsq7jXptZLOZv0IpIewvuHBtDzwtKrYxWfnnMviKOCtwNvi558Bn8m6sgeeJupbtGXit4Kcc1ERGxdIKgO/N7NdgE+0s42mZ1ZJJ0l6e+LzyZKOa2dHw84bFwwXz7tu4IzQ8rUfr34ellkFuEnSNu1uY7oSz7nAN4FPSSoBBwF7t7uzYTSlu5yU3qnzdkXjzbMnzfi86wYvb+eHLtqE0HPBVcDq2kQze3mWlZsGHjO7TdIDkvYAtgB+a2YPdJLaYTSlVVvMSAXOUIXgedflQnHPE//TycpZ7vGcDRwOPJZwFTnjedAZGp533cAUscscSaOEhgU7An8AzjGziVa3k6Ve5iJgP+ApwPdb3cFMoYJ1jVEQnnfd4PTr/k5/zz3nAXsRgs7+hAdJWzZticfMxuLQpivjTaVJknYA/htYZGYHJKbvBhwfP55iZte3kzjnOjFT8q75A8y5VbQSD7Crmf0zgKRzaKFH6qRpSzzxxuzTgHPq55nZrWZ2RMpqxxHaeB8NHNtOwvLGDO+desh43nUDV7wRSCc7Am2niq1muubUuwK3AD8ys5tb2O4iM1tpZquABSnbPVLSNZKuWfVgJWX1/Eq9gqkr6pYKdpkzjKOp9iPvjrOuW8l1bljsLumh+HoYeGLtvaSHsm5kulZtNwI7tJG4VZIWEWLxwynbXQYsA9jpn+fm/iy9QV9ttasMr+LIrX7k3YVa7BnANVWwa1DMrNyN7XTUc4GkTYGTgT0kHU+o/zsUOA04PS52amdJzDlvVDCUPO+6nvOL04Y6CjzxuYijUqZfDxzWybbzohJLO2ZqqS61zPBVT80kMyHvuhzwuJPK+2pzzrkeKVpVW7fkpn+V+iGnc8szknMuqxw9xyNpP0k3SbpF0ntS5r9T0o2SrpP0I0nbdv3vEeUm8ORZfS/VoY+2Bl/2kMRP51zv1fpz7PVr2nSEHqXPJDz0uStwcGz5mfRbYC8zeyJwIT28x+mBJ0X9sNeNbPCF5zjoTDuKqnOuu/r1DE+2As/ewC3x+bUx4GvAK6Yk1+wnZrYmfvwVsFXrB52N3+NphVezuWFg3rAlD0Jfbbk5aSwF7kh8vhN4apPljwAu7VViPPBk5OPxOOdybDNJyaGnl8Vnzlom6XWE/tie25WUpfDA06r6C5jcXNA453Knf4XP+81srybz7wK2TnzeKk6bQtILCX0YPtfMetY1hweeaaQOf+3BxjmXQY6q2q4GdpK0PSHgHAS8NrlAHLvqs8B+ZnZfLxMzIwNPq6N0To5CmjIgXHifm8zlnMuL/nfg2ZCZTUg6hjA8SBk418xukHQScI2ZLQf+F9gI+D9JAH/LOqJoq2Zk4HHOud7L/oxNP5jZJcAlddPen3j/wn6lxQNPRrX8M6UJdX7ylHMuh7zngnQeeJxzrldyVOLJEw887fC85JybjsEQDmXVFx542uRFaOfctLzEk2pou8zpV6eiVRNVU8MHSNOaSw7jiJ3OOdcvXuJpVaMeDHp8ZZO1/zjnXI54gSeVB54u8551nHM1OXqANFc88LRhJmUmrzZ0rgMz6FzRCg88rajLQzMpADnnWmT0s6+2oeKBJ4PQuKBuoscc51wTwvzitAEPPC1KbUYdM5fJb/A45xI88KTywNMjJX/QxznngSeVB55pVEnpkTqZl6qesVyByJvtd43f42nIc1mOpY4F5JxzQ85LPE3UekeomvwBHedcy7xxQToPPK1KDo/geco514wHnlQeeJxzrifyNRBcnnjgaYFidZs3WHPOTcvwwNOABx7nnOsVb9WWygNPRpPDIvgFjHMuI29ckM4DzzQqTcbicc65pjzwpPLAk1CZ7rkZ2/D9Bvd7PEY55yA+QOqBJ40/odiOyaAT3qguc5W99YFzzjXkJZ4Gql695pzriDenbsQDT0a17OOFGedcZh54UnngyaDqN26cc+3wwJPKA08rEnnISz7Ouaa8cUFDHniiSoNSTVoP0Q2DjheMnHOTDMyfIE3jgaeOD0XgnOsar2pL5YEnq7SeC6qsz1iewZxzSV7V1pAHniY2KP3MsDxUnmkH7Fy3+QVpKq9XysCf6XGuQyX/Dbn1vMTTDu/u3DmXhZ8nUnngmUa7pZ2y94fu3AznPRc04oGngUpdLeRk/knLR4nM5bVyzjkgNi7wC9A0Hnha4VHFucEpDeEtaS/xpPLAk4E3LnDOtcUDT6qOAo+k+cBngDHgcjM7P07fDTg+LnaKmV3faBuGGLMyAJWM90WqbTbGa9Q7AYSxeGrbrZgoU6WkKnPL48yfNcbIrApjc6pU54jK7PCqziqhiTKlkTI2qxynQSnWx1UQa6uzm6Yh2WQ7Wb1XqQt29cdcabBeMkjWVxfW9pVMw9T91+/TKBWwO4Zu5Nsky9OzGmlPyrdy8rMqqLTBMSmtVVorT+VXs/1mrVaXnba/SmVqmpT3vGn+HE8DnZZ4XgVcaGbflnQBcH6cfhzwVkIt56nAmxttYMJKPFjZiFGNU1LGwNNC7wL1J9+GyyVO2FVKjJbGmWUVtpq7gvkj63hkbA5/lzFWmUdpvAylEiNrR5g1u0T50VEmFszh0cUlxhYZc8vjVCixsjKfMtXUNNSXohoFifXz04NF/bz6+fWD2yU7PG02rxY8Z2nqj70gOs63k6oVyMMJsFtX1mZgGx6TdZwNup+Pcn9KNzDvMidVp4FnK+AP8X0yZy0ys5UAkhbUryTpSOBIgMVL5lCxEhWVMuekrMEENiw5NJNWkpqlSniVK4yMVFlXNqwM1TJYKb7KwkrCymzwZFSF0rRBBloLJvXzN5jXZrCZMi9eBBS0Z+628m2cPpl3R5kXJhaxOqWIxzQIXuJJ1WnguZPwI/4dU0+5qyQtIoSSh+tXMrNlwDIASf84cpcrVwP3d5iWwftR+O+W9wKwGUU4puiEEIAeP+h0dElb+RY2zLs/tAuLkXenKlTepTj5tjA6DTzfBM6Q9BLg25K+ZGaHAqcBp8dlTm22ATN7jKRrzGyvDtOSK0U9pkGnoUs6zrfgeXdYDDTfeskxVUeBx8xWA29ITDo/Tr8eOKyTbTvXK55vXV+Y+XM8DXhzauec6xUv8aTKS+BZNugE9IAf08xQxL9J0Y5pYMdjXuJJJfOI7JxzXbeovKk9bfQlfdnXZWu+dO0w3ZfLS4nHOeeKxQeCa6jvnR9Jmi/pPEmfk3RIYvpuks6Pr936na5OSNpB0jmSLqybPszH9Mr4HV0g6UWJ6fvE7+98SUsGmcZ+87ybf7nLt1btz2vIDKLXvdpT428CXp6YXntq/Gjg2AGkq21mdquZHZEya5iP6eL4HR0FHJiYdRShRdgpQNoxF5nn3ZzLU741QndK/XgNm0EEnq2AO+L7DZ4aN7NVQOpT40OoCMf0PuDMxGdZ6AfkdsJ3OZN43h0enm/rSNpP0k2SbpH0npT5c2JJ8RZJv5a0Xa/SMojAU3tqvH7/qyQtkrSQBk+ND6GhPSYFHwUuNbPfJGZVJZWAbQjf5UzieTfncpVvzXJT1SapTAjE+wO7AgdL2rVusSOAFWa2I/BJ4KNd/otMGkTjgq48NZ4nkjYFTgb2kHQ8sOuwHxOheuWFwCJJOwLPjMe0DDgbmAW8e4DpGwTPu/mXq3ybo2qwvYFbzOxWAElfA14B3JhY5hXAifH9hYS8LutB02dvTu2ccz0g6XuEfu/6YRRYm/i8LPYrWEvLAcB+ZvbG+PlQ4KlmdkximevjMnfGz3+Jy3S93z5vTu2ccz1gZvsNOg15NYRjyTrnnGvRXcDWic9bxWmpy0gaARYBD/QiMR54nHOu+K4GdpK0vaTZwEHA8rpllgOvj+8PAH7ci/s74FVtzjlXeGY2IekY4PtAGTjXzG6QdBJwjZktB84BviTpFuBBQnDqCW9c4Jxzrq+8qs0551xfeeDpIklPkXSdpNHYOBBt/gAAAMxJREFUr9cNw9TPlZu5PO+6fvKqti6T9CFCm/q5wJ1mdsqAk+RcJp53Xb944Omy2GLkasLDXM8ws8o0qziXC553Xb94VVv3bQpsROhYcXTAaXGuFZ53XV94iafLJC0HvgZsD2yZ7JLCuTzzvOv6xZ/j6SJJhwHjZvaV2BvsLyQ938x+POi0OdeM513XT17icc4511d+j8c551xfeeBxzjnXVx54nHPO9ZUHHuecc33lgcc551xfeeBxzjnXVx54nHPO9ZUHHuecc331/wESNYpMPvYMPwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "\u003cFigure size 360x360 with 5 Axes\u003e" + ] + }, + "metadata": { + "needs_background": "light", + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "#@title Evaluating action and plotting intent heatmaps.\n", + "#@markdown What is the action?\n", + "action_x_dir = 0.2 #@param {type:\"number\"}\n", + "action_y_dir = 0.2 #@param {type:\"number\"}\n", + "network_seed = 0 #@param {type:\"integer\"}\n", + "\n", + "# Cover the x-y grid.\n", + "xs = np.linspace(0, world.size)\n", + "ys = np.linspace(0, world.size)\n", + "xy_coords = tf.constant(list(itertools.product(xs, ys)), dtype=tf.float32)\n", + "\n", + "eval_action = [action_x_dir, action_y_dir]\n", + "fixed_action = tf.constant([eval_action], dtype=tf.float32)\n", + "fixed_action = tf.repeat(fixed_action, 2500, axis=0)\n", + "\n", + "concat_matrix = tf.concat((xy_coords, fixed_action), axis=1)\n", + "affordance_network = all_affordance_global[network_seed][True]\n", + "afford_predictions = affordance_network(concat_matrix)\n", + "affordance_predictions = tf.reshape(\n", + " afford_predictions,\n", + " (len(xs), len(ys), intent_size)).numpy()\n", + "\n", + "plot_intents(world, affordance_predictions, eval_action)\n", + "\n", + "plt.savefig(\n", + " f'intent_eval_FX{action_x_dir}_FY{action_y_dir}.pdf', bbox_inches='tight')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "mOK_e3oqhrjI" + }, + "source": [ + "## Model Predictions" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "61Gb_tXKh3je" + }, + "outputs": [], + "source": [ + "#@title Round annotation plotting code.\n", + "ROUND_BOX = dict(boxstyle='round', facecolor='wheat', alpha=1.0)\n", + "\n", + "def add_annotation(\n", + " ax,\n", + " start: Tuple[float, float],\n", + " end: Tuple[float, float],\n", + " connectionstyle, text):\n", + " x1, y1 = start\n", + " x2, y2 = end\n", + "\n", + " # ax.plot([x1, x2], [y1, y2], \".\")\n", + " ax.annotate(\n", + " \"\",\n", + " xy=(x1, y1),\n", + " xycoords='data',\n", + " xytext=(x2 + 0.25, y2),\n", + " textcoords='data',\n", + " size=30.0,\n", + " arrowprops=dict(arrowstyle=\"-\u003e\", color=\"0.0\",\n", + " shrinkA=5, shrinkB=5,\n", + " patchA=None, patchB=None,\n", + " connectionstyle=connectionstyle,),)\n", + "\n", + " ax.text(*end, text, size=15,\n", + " #transform=ax.transAxes,\n", + " ha=\"left\", va=\"top\", bbox=ROUND_BOX)\n", + "\n", + "connection_styles = [\n", + " \"arc3,rad=-0.3\",\n", + " \"arc3,rad=0.3\",\n", + " \"arc3,rad=0.0\",\n", + " \"arc3,rad=0.5\",\n", + " \"arc3,rad=-0.5\"\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "both", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 689 + }, + "colab_type": "code", + "id": "ILHEEbHShwCb", + "outputId": "9bac0134-d9d1-46da-9283-975ede52250e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Figures show the predicted position of the transition distribution.\n", + "Gray circle shows what would have been predicted but was masked by affordance model. \n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAATwAAAE/CAYAAADbkX+oAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAaaElEQVR4nO3de3xU9Z3/8deHEEAIci+FRAgUjYRbBKzIbSN1oVAFZeuK8NMHLRTLurUUVqvYQluR9dGtNMKyIgpaEZH9Ae0Df+KqP9cIlcUQu1mQcJGrBsL9JmAgl+/+MSfpBJKQgSFh+L6fj0ceZM45c853mPDinDOTOeacQ0TEB3VqewAiIjVFwRMRbyh4IuINBU9EvKHgiYg3FDwR8YaC5wEzG2BmW6uYn2xmzszq1uS4qsNCXjGzY2aWFUybaGYHzOyUmbWIwjYyzWz85Y9WrnYKXgwysyfN7J3zpn1eybRRzrk1zrmUsOm7zezOGhrrq2Y24zJW0R/4WyDJOfdtM4sHZgGDnXMJzrkjURmoeEHBi02rgb5mFgdgZm2AeOCW86Z1CpaNZe2B3c6508Ht1kADYFOkKwr2FvUz7zE9+bFpPaHApQW3BwAfAlvPm7bDObfPzNLNLA/AzBYB7YC3gkPCx8PWO8bMvjCzw2b2VOlEM6tvZhlmti/4yjCz+sG8sWb25/DBBYfHncxsAjAGeDzY1lsVPRgze97MvjSzk2b2qZkNCKaPA14Gbg/uvyR4jADHzew/g+X6mtl6MzsR/Nk3bN2ZZvaMmX0MnAE6mtnfmtmWYPl/BSxs+W+Z2X+a2ZHg72GxmTUNm7/bzP7JzDYE919qZg3C5o8ws5zgsewws+8G05uY2QIzyzezvWY2I+w/p05m9lGwvsNmtrSivyeJAuecvmLwi1DgfhZ8/6/AD4Fnzpu2MPg+HcgLu+9u4M6w28mAA14CrgN6AGeBzsH83wDrgG8ArYC1wNPBvLHAn88bmwM6Bd+/Csy4yGP5P0ALoC4wBdgPNKho/WFjrRvcbg4cAx4M7v9AcLtFMD8T+ALoEsxvBXwFfJ/Qfxo/A4qA8cHynQgdQtcPll0NZJz3d5cFtA22vRn4cTDv28CJ4P51gETg5mDeH4EXgUbB32MW8HAwbwnwVHCfBkD/2v75ula/tIcXuz4CBgbfDwDWBF/h0z6KcJ2/ds597Zz7H+B/CIUPQntpv3HOHXTOHQJ+TSgwUeGce905d8Q5V+Sce45QbFIudr/A94DPnXOLgvsvAbYAd4ct86pzbpNzrggYCmxyzi1zzhUCGYQCWzqW7c65951zZ4PHOgv4m/O2Ods5t885dxR4i7/uVY8j9J/M+865EufcXufcFjNrDQwDJjnnTjvnDgK/B0YF9yskdOje1jlX4Jz7M3JFKHixazXQ38yaA62cc58T2vPqG0zrSuTn7/aHfX8GSAi+bwvsCZu3J5gWFcEh4ubgkO440ARoWc27nz+20vElht3+8rzly2670C5W2W0za21mbwaHnSeB1ysYS2V/TzcAOyoYY3tCe5P5ZnY8eIwvEtrTA3ic0GF1lpltMrMfVvZg5fIoeLHrvwiF4UfAxwDOuZPAvmDaPufcrkruG+lH5Owj9I+2VLtgGsBpoGHpDDP7ZiTbCs7XPQ78PdDMOdeU0GGhVXW/KsZWOr69lYwhn1CYSrdv4beBmcHy3Zxz1xM63K7uWL4EvlXJ9LNAS+dc0+DreudcFwDn3H7n3I+cc22Bh4F/M7NO1dymREDBi1HOua+BbGAyoUPZUn8OplW1d3cA6BjB5pYAvzCzVmbWEphGaM8HQoe+XcwsLTh5/6sIt9WY0Dm0Q0BdM5sGXB/B2FYBN5nZaDOra2b3A6nA/6tk+beD8Y600PsOHwXCI90YOAWcMLNE4LEIxrIA+IGZfcfM6phZopnd7JzLB94DnjOz64N53zKzvwEws/vMLClYxzFCwS2JYLtSTQpebPuI0GFR+DmfNcG0qoL3z4QCdtzM/qka25lBKK4bgI3AX4JpOOe2EXpR4/8Dn583FghFIDXY1p8qWPe7wH8A2wgdihZQ/hC0Si70Pry7CL3YcYTQ3uJdzrnDlSx/GLgPeDZY/kaCPeTAr4GehPYy3wZWRDCWLOAHhM7PnSD0/JTufT4E1ANyCUVtGdAmmHcr8ImZnQJWAj91zu2s7nal+ix0CkNE5NqnPTwR8YaCJyLeUPBExBsKnoh4Q8ETEW/U2ueftWzZ0iUnJ9fW5kXkGvXpp58eds61qmherQUvOTmZ7Ozs2tq8iFyjzOz8XzUso0NaEfGGgici3lDwRMQbV91FW0QqUlhYSF5eHgUFBbU9FLlKNGjQgKSkJOLj46t9HwVPYkJeXh6NGzcmOTmZ0Cc6ic+ccxw5coS8vDw6dOhQ7fvpkFZiQkFBAS1atFDsBAAzo0WLFhHv8St4EjMUOwl3KT8PFw2emd1gZh+aWW7w8dM/rWAZM7PZZrY9uJpTz4hHInKVi4uLIy0tja5du3Lfffdx5syZS17X2LFjWbZsGQDjx48nNze30mUzMzNZu3ZtxNtITk7m8OELPxZw5syZEa8rUvv27eP73/8+ADk5Oaxataps3sqVK3n22Wev+BgqUp1zeEXAFOfcX8ysMfCpmb3vnAt/hoYS+iDFG4HbgBeCP0WuiHGvro/q+haMvfWiy1x33XXk5OQAMGbMGObNm8fkyZPL5hcVFVG3buSnxV9++eUq52dmZpKQkEDfvn2rXK66Zs6cydSpUy+YXnZlrzqXf+DXtm3bsqDn5OSQnZ3NsGHDABg+fDjDhw+/7G1cios+O8HHU+cH339lZpsJXSAlPHgjgNeCC6KsM7OmZtYmuG/UpKenR3N1EiO2b9/OK6+8Uu4f4qnTp6K6ja1bt150Gedc2XI33XQT2dnZvPbaa8yePZvrr7+enTt3smrVKp577jmysrI4d+4co0ePZtSoUTjnePrpp1m7di1t2rQhPj6evXv3snXrVh588EEef/xxunXrxpo1a/j9739PcXExzZo1Y8aMGcydO5c6deqwYMECfvGLX9CxY0emT59Ofn7on9fUqVPp2bMnx44dY8qUKRw8eJC0tDQKCwvZvn07R44cKXsMzz33HF9//TWdO3emU6dOTJo0ifHjx9OjRw82bdrEiy++yEsvvcTGjRs5e/YsgwcP5tFHHwVg0KBB3HPPPWRmZlJYWMjzzz9Px44dycrKKttrNDMWLVrE8ePHmThxIsuXL2fq1KkUFBTwwQcfMGHCBAoKCvjss8+YNm0aeXl5PPXUUxw7dozmzZszc+ZM2rZtyxNPPEFCQgI7duxg//79/Pa3vy3bY7wcEf13ZGbJwC3AJ+fNSqT8x3LnBdPKBS+4MPMEgHbt2kU00PT0dHJyckhLS7v4wnJNOXXqFCUlV88lHoqKili9ejUDBgwAIDc3l7feeoukpCSWLl1K48aNWbZsGefOneOBBx6gf//+5Obmsnv3bt5++20OHz7MXXfdxciRI8ut9+jRo/zyl7/k9ddfJykpiePHj9O0aVPuv/9+GjZsyLhx4wCYMmUKY8eOpVevXuzbt4/x48ezatUq5s6dS69evXjkkUfIzMws28MKN2XKFBYvXsyf/hT6tP28vDz27NnDs88+W/Zva9KkSTRt2pTi4mLGjh3L1q1bSUkJXTWzWbNmrFixgjfeeIOFCxcyY8YMFi5cyLRp0+jZsyenT5+mfv36ZdurV68eP/nJT8oCB7BixV8/NX/GjBncc8893HvvvSxfvpxnnnmGuXPnArB//34WLFiAc47hw4fXbPDMLAFYTujamicvZWPOufnAfIDevXtH/NnyaWlpZGZmXsqmJYalp6dTr169sn90AAn/dUk/gpUKX3dlCgoKuP/++wEYMGAAU6dOZe3atdx222185zvfAWDDhg1s2LCh7Of09OnTFBcXs2PHDsaNG0dqaioAd955J4mJiaSkpNCwYUOSk5PJz89n0KBBZesq1bJlSxISEsrGmJWVRV5eXrlxJSYmsnHjRlasWEHHjh1JSUnhySefpFOnTrRsWf4qk2ZWtq769evTvn37sscFMG/ePObPn09RURH5+fmcOXOGlJQU4uPjmThxIomJiQwbNoyPP/6YlJQUhgwZQkZGBmPGjGHkyJEkJSXRqFGjsuesTZs27N27t2yb4bc3btzIe++9R3x8PI899hizZs0iJSWFJk2a0K9fP+rUqUNKSgoHDhyo9nNZlWoFz8ziCcVusXOuooua7KX8pe6SKH+ZPJGYF34OL1yjRo3KvnfOMWfOHIYMGVJumfCT9perpKSEdevW0aBBg6isL3z8u3bt4ne/+x3r16+nWbNmjB07ttxbP0r33uLi4igqKgLgiSee4Hvf+x6rVq2iX79+vPvuu1EZW/gbiqN17Z3qvEprhK48tdk5N6uSxVYCDwWv1vYBTkT7/J1ILBgyZAgvvPAChYWFAGzbto3Tp08zcOBAli5dSnFxMfn5+Xz44YcX3LdPnz6sXr2aXbtClxM+evQoAI0bN+arr74qW27w4MHMmTOn7HZphAcOHMgbb7wBwDvvvMOxY8cqHGN8fHzZ+M538uRJGjVqRJMmTThw4ADvvPPORR/zjh076NatGz//+c+59dZb2bJlS7n5548/XN++fXnzzTcBWLx4cdlpgiulOi/H9AMeBAaZWU7wNczMfmxmPw6WWQXsBLYDLwH/cGWGK3J1Gz9+PKmpqfTs2ZOuXbvy8MMPU1RUxL333suNN95IamoqDz30ELfffvsF923VqhXz589n5MiR9OjRo+ww8+677+aPf/wjaWlprFmzhtmzZ5OdnU337t1JTU1l3rx5AEyfPp3Vq1fTpUsXVqxYUel58gkTJtC9e3fGjBlzwbwePXpwyy23cPPNNzN69Gj69et30ceckZFB165d6d69O/Hx8QwdOrTc/DvuuIPc3FzS0tJYunRpuXlz5szhlVdeoXv37ixatIjnn3/+otu7HLV2mcbevXu7SD4Pr/QVWp3D8096ejrTp0/njjvuqO2hSA0qfUW8qvOrmzdvpnPnzuWmmdmnzrneFS2v37QQEW8oeCLiDQVPRLyh4ElMKP21J5FSl/LzoOBJTMjPz+fIkSOKngB//Ty8SN/vpw8AlZiwZMkS+vbty6FDh2p7KFJD9u/fD1DprxWWfuJxJBQ8iQmnTp2K6JNtJfZNnDgRiO5b0XRIKyLeUPBExBsKnoh4Q8ETEW8oeCLiDQVPRLyh4ImINxQ8EfGGgici3lDwRMQbCp6IeEPBExFvKHgi4g0FT0S8oeCJiDcUPBHxhoInIt5Q8ETEGwqeiHhDwRMRbyh4IuINBU9EvKHgiYg3FDwR8YaCJyLeUPBExBsKnoh4Q8ETEW8oeCLiDQVPRLyh4ImINxQ8EfGGgici3lDwRMQbCp6IeEPBExFvKHgi4g0FT0S8oeCJiDcUPBHxhoInIt5Q8ETEGwqeiHhDwRMRbyh4IuINBU9EvKHgiYg3FDwR8YaCJyLeUPBExBsKnoh4Q8ETEW8oeCLiDQVPRLyh4ImINxQ8EfGGgici3lDwRMQbCp6IeEPBExFvKHgi4g0FT0S8oeCJiDcUPBHxhoInIt5Q8ETEGwqeiHhDwRMRbyh4IuINBU9EvKHgiYg3FDwR8YaCJyLeUPBExBsKnoh4Q8ETEW8oeCLiDQVPRLyh4ImINxQ8EfGGgici3lDwRMQbCp6IeEPBExFvKHgi4g0FT0S8oeCJiDcUPBHxhoInIt5Q8ETEGwqeiHhDwRMRbyh4IuINBU9EvFG3tgcgseuzvSd4b9N+Dpw8yw3Nr2NotzZ8q1VCbQ9LpFIKnlySdzft5/9mf4lzoduHT51lQ94JxvXvwG0dW9Tu4EQqoUNaidhH2w7x7+v/GrtSxSWOl9bsZGPeidoZmMhFKHgSkdNni1j+aV6l852DN7K+oKi4pAZHJVI9Cp5E5IMtBzl9tqjKZQ6eLGDdzqM1NCKR6lPwpNqKikv4cMvBai37Xu7+KzwakcgpeFJtu4+c4eTXhdVadu+xrzl6+twVHpFIZBQ8qbZdh09HuPypKzQSkUuj4Em17TwUWcB2HooskCJXmoIn1bb7SGQB2xnhHqHIlabgSTnnis/hzn+DHVBQWMzBk2fLTXNAiav87SdfHj0T7eGJXBYFT8qcKz7H8CXDmfzu5AuiV1BYXO62A3IPbSJr3/pKo1dQqPfiydVFwZMy8XXi6dyyMxmfZFwQvXNhbyQujd3OY7toXC8Bs4p/jJxzegOyXFUuGjwzW2hmB83ss0rmm5nNNrPtZrbBzHpGf5hSE8yMWUNmMem2SRVGD8rHrmOzDqS26oLVznBrxIkTJ3h00mRWrFjBmTM6RI911dnDexX4bhXzhwI3Bl8TgBcuf1hSWyqLXr24OhHHzsyoGxfbBxGbN2/m5VcX8fCTM2nxjdYMvfselixZwsmTJ2t7aHIJLvppKc651WaWXMUiI4DXXGhXYJ2ZNTWzNs65/CiNUWpYafQAMj7JAOCfB/1LxHt2DeJjO3alGrVoTaN7f02DMydYv/0Tsp+ezQ/HT+D2/gN46IG/Z8SIETRr1qy2hynVEI2Ph0oEvgy7nRdMU/BiWLno/eF5MoZm0HzIPxKX0Jyd7GInuy66jsJDe5g75g9RHVNtaQTENWxCQvfB0H0w1509zYbtWTz2uwX8aMLDDBx0Jx/8x9u1Nj6pnhr9PDwzm0DosJd27drV5KblEpRGL2NlBiRCUfE+mjfpUO371yuIo2PPyz+lu23bNgBuuummy15XpPbs2cPJM+e9HaekmHMHd8GhHRTs3843E2+gz629anxsErloBG8vcEPY7aRg2gWcc/OB+QC9e/e+8M1eclVxzjH53cnQGvgRFBbn0aJx82q/UPEPd4ylV/vmlz2O9PR0ADIzMy97XZFat24dd48ZjysppuCLjRTvXEfB5+to3fobjBl1H/ffN4PU1NRa3fuU6ovGSZaVwEPBq7V9gBM6fxf7SmOX8UkGk26bRMm0Ekb3uJOdx3aRe2gT1fnfqkPLa+Pj3k/m7+bQi2NpvmkZP7unL/+dtZYdWzbxm1/9ii5duih2MeSie3hmtgRIB1qaWR4wHYgHcM7NA1YBw4DtwBngB1dqsFIzzo/drCGzMDNeGPEbtn7xClsPhc7fVbWnl9jsOpo3qldzg75CevfuzR8Wvsztt99O+/bta3s4cpmq8yrtAxeZ74BHojYiqVWVxQ4gvm4cj90xlJnvv8/OY1VHb3DqN2tw1FdO3bp1GTVqVG0PQ6JEF/GRMlXFrtSdqa35YHMPgEqj943rG9Cn4+WfuxOJNgVPyhSWFLL58OZKYwfQsF5d/q5XEl8FH/P+1blTOFdS9utlZjD62+1i/g3Hcm1S8KRMvbh6rHxgJfF14qs8ET/wplYUFBazNBtKSkqoE8Quro4xrn8HuiU1qakhi0REwZNy6sVV74WGwV2+Sdum1/Fe7gEOniwgqZkuxC1XPwVPLlnXxCZ0TdTenMQOnWgREW8oeCLiDQVPRLyh4ImINxQ8EfGGgici3lDwRMQbCp6IeEPBExFvKHgi4g0FT0S8oeCJiDcUPBHxhoInIt5Q8ETEGwqeiHhDwRMRbyh4IuINBU9EvKHgiYg3FDwR8YaCJyLeUPBExBsKnoh4Q8ETEW8oeCLiDQVPRLyh4ImINxQ8EfGGgici3lDwRMQbCp6IeEPBExFvKHgi4g0FT0S8oeCJiDcUPBHxhoInIt5Q8ETEGwqeiHhDwRMRbyh4IuINBU9EvKHgiYg3FDwR8YaCJyLeUPBExBsKnoh4Q8ETEW8oeCLiDQVPRLyh4ImINxQ8EfGGgici3lDwRMQbCp6IeEPBExFvKHgi4g0FT0S8oeCJiDcUPBHxhoInIt5Q8ETEGwqeiHhDwRMRbyh4IuINBU9EvKHgiYg3FDwR8YaCJyLeUPBExBsKnoh4Q8ETEW8oeCLiDQVPRLyh4ImINxQ8EfGGgici3lDwRMQbCp6IeEPBExFvKHgi4g0FT0S8oeCJiDcUPBHxhoInIt5Q8ETEGwqeiHhDwRMRbyh4IuINBU9EvKHgiYg3FDwR8YaCJyLeUPBExBsKnoh4Q8ETEW8oeCLiDQVPRLyh4ImINxQ8EfGGgici3lDwRMQbCp6IeEPBExFvKHgi4g0FT0S8oeCJiDcUPBHxhoInIt6IWvDM7LtmttXMtpvZE9Far4hItEQleGYWB8wFhgKpwANmlhqNdYuIREvdKK3n28B259xOADN7ExgB5EZp/QDk5OSQnp4ezVVKDMjJyQHQc++ZnJwc0tLSorrOaB3SJgJfht3OC6aVY2YTzCzbzLIPHToU0QYyMzOj/uAlNiQkJJCQkFDbw5AalpaWRmZmZlTXGa09vGpxzs0H5gP07t3bRXr/aD94EfFLtPbw9gI3hN1OCqaJiFw1ohW89cCNZtbBzOoBo4CVUVq3iEhUROWQ1jlXZGb/CLwLxAELnXOborFuEZFoido5POfcKmBVtNYnIhJt+k0LEfGGgici3lDwRMQbCp6IeEPBExFvKHgi4g0FT0S8Yc5F/Cut0dmw2SFgT4R3awkcvgLDkaufnns/Xcrz3t4516qiGbUWvEthZtnOud61PQ6peXru/RTt512HtCLiDQVPRLwRa8GbX9sDkFqj595PUX3eY+ocnojI5Yi1PTwRkUsWE8HTJSD9YGYLzeygmX1WyXwzs9nBz8EGM+tZ02OU6DOzG8zsQzPLNbNNZvbTCpaJynN/1QdPl4D0yqvAd6uYPxS4MfiaALxQA2OSK68ImOKcSwX6AI9U8G88Ks/9VR88wi4B6Zw7B5ReAlKuMc651cDRKhYZAbzmQtYBTc2sTc2MTq4U51y+c+4vwfdfAZu58KqHUXnuYyF41boEpHhBPwvXODNLBm4BPjlvVlSe+1gInoh4wMwSgOXAJOfcySuxjVgIni4BKaX0s3CNMrN4QrFb7JxbUcEiUXnuYyF4ugSklFoJPBS8YtcHOOGcy6/tQcnlMTMDFgCbnXOzKlksKs991K5adqXoEpD+MLMlQDrQ0szygOlAPIBzbh6hq+INA7YDZ4Af1M5IJcr6AQ8CG80sJ5g2FWgH0X3u9ZsWIuKNWDikFRGJCgVPRLyh4ImINxQ8EfGGgici3lDwRMQbCp6IeEPBExFv/C+4GIGhTOMrBgAAAABJRU5ErkJggg==\n", + "text/plain": [ + "\u003cFigure size 360x360 with 1 Axes\u003e" + ] + }, + "metadata": { + "needs_background": "light", + "tags": [] + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAATwAAAE/CAYAAADbkX+oAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3dd3hUVf7H8fdJSEgVkCYQBKQpqZBkpSgBkR4RWLEsuFRd/LEiggUroqyrqwsslkVdgQWxLKyuKKiIBAEXhABBmihISZCSUIZUMkm+vz9mMpuQQhImpNzviyfPk7n3zj3nzp18OOe2Y0QEpZSyAo+qroBSSl0pGnhKKcvQwFNKWYYGnlLKMjTwlFKWoYGnlLIMDTzlYozZY4zpVQXldjTGJBhjUo0xk40xvsaYz4wxNmPMMjesv7UxRowxddxRX1Vz6RegljHGCNBeRA4UmPYc0E5ERpX2XhEJruTqleQxIE5EIgCMMfcCTYGGIpJTRXVStZC28FR10ArYc9HrnyoSdtqKU6XRwLMYY0wjY8znxphzxpgzxpgNxhgP57zDxphbnb8/Z4z5lzFmsbOruccYE1VgPV2MMTuc85YZYz4yxswqocy2xpi1xpjTxpgUY8xSY0x957y1QG/gdWNMmjHmA+BZ4C7n6/HGGA9jzNPGmCPGmFPOOtVzvj+/uzreGHMUWGuM8TTGvOos6xdg8EX1GWuM2ees+y/GmD8UmNfLGJNkjJnmLOu4MWZsgfm+xpi/OutiM8ZsNMb4Oud1Ncb81/nZ7ix4eMAYM8ZZVqox5pAxZuRl7UhVMSKiP7XoBxAc3deC054D3nP+/mdgPuDl/LkZMM55h4FbC7wnCxgEeDrft9k5zxs4AjzkXMdwIBuYVUKd2gF9gbpAY2A9MLfA/HXAhOLq63w9DjgAXAcEAB8DS5zzWju3eTHgD/gCE4EfgZbA1UCcc5k6zvcMBtoCBogBMoAuznm9gBzgeee2DXLOb+Cc/4azvi2cn0t353a1AE47l/dwbu9p5/b6A+eBjs51NAOCq/q7YsUfbeFZjx3HH1wrEbGLyAZx/hUWY6OIrBKRXGAJEO6c3hXH8d95znV8DGwpqUAROSAiX4vIBRFJBmbjCJqyGgnMFpFfRCQNeAK4+6Lu63Miki4imcCdOAI1UUTO4AjrgvVZKSIHxeFbYDWO4M9nB553btsqIA3o6GwJjwMeEpFjIpIrIv8VkQvAKGCV8/PKE5GvgXgcAQiQB4QYY3xF5LiIFOzCqytEA6/2ycXRMinIC8cfMcArOFpLq51drOmlrOtEgd8zAB9nyDQHjl0UlIklrcQY09QY86Ex5pgx5jzwHtCobJsDzvKOFHh9BEfgNi2h/OYXvS74XowxA40xm51d+nM4QqlgfU5L4eOHGThalo0AH+BgMXVsBYxwdmfPOdd7E9BMRNKBu3C0PI8bY1YaY66/5FYrt9PAq32O4ujmFdQG5x+9iKSKyDQRuQ4YAkw1xvQpZxnHgRbGGFNgWstSln8RR5cyVESuwtEaMqUsf7FfcQRKvmtxdDtPFphWMHyPX1Sfa/N/McbUBf4NvAo0FZH6wKoy1icFRze/bTHzEnF0s+sX+PEXkZcAROQrEemLo3X9I/BOGcpTbqaBV/t8BDxtjAlyHuy/FbgNWA5gjIk1xrRzhpUNR4swr5xlbHK+74/GmDrGmNuB35SyfCCObqHNGNMCeLSc5X0APGyMaWOMCcARoB9JyWdx/wVMdn4GDYCCrVhvHMfckoEcY8xAoF9ZKiEiecACYLYxprnz5Eg3Z4i+B9xmjOnvnO7jPAES5Gzh3m6M8QcuOD+L8n7myg008Gqf54H/AhuBs8BfgJEists5vz2wBscf3SbgTRGJK08BIpKN40TFeOAcjhbb5zj+mIszE+iCI2BX4jjpUB4LcBxDXA8cwtHKerCU5d8BvgJ2AtsLliciqcBkHKF4FvgdsKIcdXkE2AVsBc4ALwMeIpII3A48iSNME3EEu4fzZyqOluoZHMcvHyhHmcpNjJR4vFqpsjPGfA/MF5GFVV0XpUqiLTxVIcaYGGPMNc4u7WggDPiyquulVGn0qnRVUR1xdAv9gV+AO0TkeNVWSanSaZdWKWUZ2qVVSlmGBp5SyjKq7Bheo0aNpHXr1lVVvFKqltq2bVuKiDQubl6VBV7r1q2Jj4+vquKVUrWUMeZISfO0S6uUsgwNPKWUZWjgKaUsQy88tgC73U5SUhJZWVlVXRWl3MbHx4egoCC8vC5+GlrJNPAsICkpicDAQFq3bk3hJzopVTOJCKdPnyYpKYk2bdqU+X3apbWArKwsGjZsqGGnag1jDA0bNix3r0UDzyI07FRtU5Hv9CUDzxjT0hgTZ4zZ6xy56qFiljHGmHnGmAPGmB+MMV3KXROllKpkZWnh5QDTRKQTjsFbJhljOl20zEAcD5ZsD9wP/N2ttVRudfDgQXbv3u22n4MHixviobDMzExiYmLIzc3l8OHDGGN4+umnXfNTUlLw8vLij3/8Y7m3Z926dcTGxpb7fQCHDx8mJCQEgF27djFmzJgSy6hXrx4RERHccMMNzJw5s0Ll5WvdujUpKSkAdO/evdRlFy1axK+//lqu9Rfcrounv//+++VaV0WsWLGCl156CYD//Oc/7N271zXv2WefZc2aNZVeh+Jc8qSF85E/x52/pxpj9uEYkm5vgcVuBxY7B3XZbIypb4xp5u7HBfXq1cudq7OMGTNm4OHxv//bDh8+jK+vr9vWn5mZSU5O6WNmL126lJtuuokDBw6QlJREUFAQH3/8Mffeey8AH3zwAe3atePs2bPs37+/0Hvzj9P4+PgUu+6jR4+SlpZW5H1lkZSURHZ2Nvv378fb25uffvqJuLg4mjdvXqSMzp0789Zbb5GRkcGwYcMIDQ0lODjYtUxOTg516pTtPKDdbufAgQOcPn2ahQsXllr3N998k4CAAEJDQyu0XQV9//33LFiwgMjIyCLvKU/9L6Vjx4507NiR/fv3889//pNevXrh6ekJwMiRjiF5y7K/Onbs6Jb65CvX1hljWgOdge8vmtWCwqNEJTmnFQo8Y8z9OFqAXHvttZRHr169SEhIICIiolzvU9XDZ599xquvvup67evry3XXXceuXbsIDQ1l1apVDBgwgFOnTgGwdu1a5s+fj91uJzAwkD/96U+0bNmSLVu28OKLLwKOYzhLliwpVM6uXbt49tln+dvf/sb58+d56aWXyMjIoEGDBvz5z3+mSZMm7N69m6eeegqAHj16FHp/7969WbVqFRMmTChxW/z8/AgODubo0aOsXbuWxMREEhMTadasGU8//TQzZszg+HHHV//JJ5+kS5cunD17lmnTpnHq1Kki3+EuXbqwfft2AN555x1WrFiBh4cHPXv2JDg4mD179vDoo4/i4+PDhx9+yIEDB8q9Xflmz57NwYMHGTp0KEOHDuWqq67i66+/JiMjg9zcXN566y0mTZrE+fPnsdvtTJkyhT59+pCUlMT9999PZGQkO3bsoEmTJrz55pv4+PiwePFiPvroIzw9PWnXrh2zZ8/m448/Zvfu3cTGxhIXF8fWrVuZP38+8+bN480336RXr14MGDCATZs28Ze//IWcnBxCQ0N57rnn8Pb25pZbbmHw4MF89913eHh4sGzZMq6/3g0DvZV1AFscw9RtA4YXM+9z4KYCr78BokpbX2RkpJRHTEyMxMTElOs9ymHv3r2FXu/atUsOHTrktp9du3aVWv6FCxekadOmrteHDh2S4OBg+fTTT2XatGly9OhRueWWW2ThwoUyadIkERE5c+aM5OXliYjICy+8IGPGjBERkdjYWNm4caOIiKSmpordbpe4uDgZPHiwfPfdd9KlSxc5cuSIZGdnS7du3eTUqVMiIvLhhx/K2LFjRUQkNDRUvv32WxEReeSRRyQ4ONhVt40bN0psbGyRbcgvQ0QkJSVFWrVqJbt375YZM2ZIly5dJCMjQ0RE7rnnHtmwYYOIiBw5ckSuv/56ERF58MEHZebMmSIi8vnnnwsgycnJIiLi7+8vIiKrVq2Sbt26SXp6uoiInD59WkQc3/2tW7eKiFR4u4rbDhGRhQsXSosWLVxl2e12sdlsIiKSnJwsbdu2lby8PDl06JB4enrKjh07RERkxIgRsmTJEhERadasmWRlZYmIyNmzZ13rzd+Xo0ePlmXLlrnKzH+dmZkpQUFBsn//fhERuffee2XOnDkiItKqVSt56qmn5Mcff5Q33nhDxo8fX2RbRIp+t0VEgHgpIXfK1MIzxnjhGNpuqTgGXb7YMQoPixfknKYUKSkp1K9fv8j0AQMG8Mwzz9C0aVPuuuuuQvOSkpK46667OH78OGlpaQQFBQGOlsvUqVMZOXIkw4cPd03ft28f999/P6tXr6Z58+au44t9+/YFIDc3l2bNmnHu3DnOnTtHz549Abj33nv54osvXOU2adKkxONlGzZsoHPnznh4eDB9+nSCg4NZtmwZQ4YMcR0iWLNmTaHjVefPnyctLY3169fz8ceOP53BgwfToEGDIutfs2YNY8eOxc/PD4Crr766yDL79++v0HaVpm/fvq6yRIQnn3yS9evX4+HhwbFjxzh50jEaZps2bVyt08jISA4fPgxAWFgYI0eOdLUay2r//v20adOGDh06ADB69GjeeOMNpkyZAkC/fv1cZeV/dpfrkoHnHM7vXWCfiMwuYbEVOIbs+xC4EbCJPu5bOfn6+hZ7vZS3tzeRkZH89a9/Ze/evaxY8b/Bwx588EGmTp3KkCFDWLx4Ma+//joA06dPZ/DgwaxatYoePXrw1VdfAdCsWTOysrLYsWMHzZs3R0QIDg5m06ZNhco8d+5cqXXNysoq8fjmzTffzOeff15kur+/v+v3vLw8Nm/eXOLxxstV0e0qTcH6L126lOTkZLZt24aXlxetW7d27bu6deu6lvP09CQzMxOAlStXsn79ej777DP+9Kc/sWvXrgrXpSBvb29XWZc6RlxWZTlL2wO4F7jFGJPg/BlkjJlojJnoXGYVjnENDuAYIu//3FI7VSs0aNCA3NzcYkNv2rRpvPzyy0VaMzabjRYtWgCOs3z5Dh48SGhoKI8//jjR0dH8+OOPANSvX5+VK1fyxBNPsG7dOjp27EhycrIrGOx2O3v27KF+/frUr1+fjRs3Ao4/8IJ++umnYs9ullW/fv147bXXXK8TEhIA6Nmzp+vs6BdffMHZs2eLvLdv374sXLiQjIwMAM6cOQNAYGAgqampABXernwF11Ucm81GkyZN8PLyIi4ujiNHSnzSEuAI+MTERHr37s3LL7+MzWYjLS2tTGV27NiRw4cPc+DAAQCWLFlCTExMqeVdrrKcpd3IJUZld/abJ7mrUqpy+fr6FvlSXu76LqVfv35s3LiRW2+9tdD04ODgQmc68z333HOMGDGCBg0aEBERQVJSEgBz584lLi4ODw8PgoODGThwoOuPv2nTpnz++ecMHDiQBQsWsHz5ciZPnozNZiMnJ4cpU6YQHBzMwoULGTduHMYYV7cpX1xcHIMHD67oR8G8efOYNGkSYWFh5OTk0LNnT+bPn8+MGTO45557CA4Opnv37sWetBswYAAJCQlERUXh7e3NoEGDePHFFxkzZgwTJ07E19eXTZs2VWi78oWFheHp6Ul4eDhjxowp0rUeOXIkt912G6GhoURFRV3yREFubi6jRo3CZrMhIkyePLnI4Yu7776b++67j3nz5rF8+XLXdB8fHxYuXMiIESPIyckhOjqaiRMnXlyEW1XZID5RUVFSngeA5l+Ssm7dusqpUC22b98+brjhhiqtw/bt25kzZ06Rs6plkX/5grsvUbjYhQsXiImJYePGjW67PENVXFn2e3HfbWPMNhGJKm55vbVMXRFdunShd+/e5ObmVnVVSnT06FFeeuklDbtaTPesumLGjRtX1VUoVfv27Wnfvn1VV0NVIm3hWURVHbpQqrJU5DutgWcBPj4+nD59WkNP1RrifB5eeS//0S6tBQQFBZGUlERycnJVV6VCTpw4ATgugVDWcan9nv/E4/LQwLMALy+vcj0Vtrp54IEHAD1DbzWVsd+1S6uUsgwNPKWUZWjgKaUsQwNPKWUZGnhKKcvQwFNKWYYGnlLKMjTwlFKWoYGnlLIMDTyllGVo4CmlLEMDTyllGRp4SinL0MBTSlmGBp5SyjI08JRSlqGBp5SyDA08pZRlaOAppSxDA08pZRkaeEopy9DAU0pZhgaeUsoyNPCUUpahgaeUsgwNPKWUZWjgKaUsQwNPKWUZGnhKKcvQwFNKWYYGnlLKMjTwlFKWoYGnlLIMDTyllGVo4CmlLEMDTyllGRp4SinL0MBTSlmGBp5SyjI08JRSlqGBp5SyDA08pZRlaOAppSxDA08pZRkaeEopy9DAU0pZhgaeUsoyNPCUUpahgaeUsgwNPKWUZWjgKaUsQwNPKWUZGnhKKcvQwFNKWYYGnlLKMjTwlFKWoYGnlLIMDTyllGVo4CmlLEMDTyllGRp4SinL0MBTSlmGBp5SyjI08JRSlqGBp5SyDA08pZRlaOAppSxDA08pZRkaeEopy9DAU0pZhgaeUsoyNPCUUpahgaeUsgwNPKWUZWjgKaUsQwNPKWUZGnhKKcvQwFNKWYYGnlLKMjTwlFKWoYGnlLIMDTyllGVo4CmlLEMDTyllGRp4SinL0MBTSlmGBp5SyjLqVHUFlAIQEXJycrhw4QJZWVlkZGSQk5ODiJCVlQXAiRMn8PPzw9vbG29vbzw89P9rVT4aeKpKZWdnY7PZSElJITc31zXdy8vLFWgiAsCZM2dITk7GGANAvXr1aNiwIb6+vq5pSpVGA09VifT0dE6fPo3NZsMYg6+vL56ensUumx9mfn5+rmkiQlpaGufOnaNu3bo0btyYevXqaatPlUoDT11ROTk5nDx5kjNnzuDl5UVAQECFWmf5IQlgt9tJSkoiJSWFoKAg13SlLqaBp66YtLQ0EhMTyc3NrXDQFcfLywsvLy8uXLjAzz//zDXXXEOjRo20taeK0MBTlU5ESE5O5sSJE/j6+lZaC6xu3bp4eXlx8uRJUlNTadWqFXXq6Fdc/Y/+F6gqlYhw4sQJTpw4QUBAAF5eXpVanoeHB4GBgWRlZXHo0CHsdnullqdqFg08VWlEhFOnTpGcnExgYOAV7WL6+/tjt9s5cuQIOTk5V6xcVb1p4KlKc+7cOU6ePOnW43Xl4efnR3Z2NomJia5LW5S1aeCpSpGdnc2xY8fw9/ev0pMHfn5+pKamcvbs2Sqrg6o+NPCU24kIx48fx8PDo8Rr664kf39/fv31V7Kzs6u6KqqKaeApt7PZbNhstkIXClclT09PPD09OXbsmHZtLU4DT7lVfuuuuoRdPl9fX9LS0sjMzKzqqqgqpIGn3Co9PR273V4tr3+rU6cOZ86cqepqqCqkgafcKiUlBW9v76quRrF8fHw4d+6cXptnYRp4ym2ys7NJTU2lbt26VV2VYuVfGmOz2aq4JqqqaOApt8k/PladH9Xk7e1NampqVVdDVRENPFVIdm52mc9kigjZuf+71CMzM7NaHrsrqE6dOmRkZOjZWovSwFMu2bnZDPlgCFO/mnrJQBARpn41lSEfDHGFXlpaWrUPPA8PD0dQ6zV5lqSBp1y8PLy4odENzP1+bqmhlx92c7+fyw2NbsDLw4u8vDyysrKqfeABGngWdsnAM8YsMMacMsbsLmG+McbMM8YcMMb8YIzp4v5qqivBGMPs/rOZcuOUEkOvYNhNuXEKs/vPxhjjejx7dT5+l69gfS/FZrMxecpUPv74YzIyMiq5ZqqylaWFtwgYUMr8gUB758/9wN8vv1qqqpQWeiWFXf68miQvL69My+3bt49/LFrCH554kYZNmjLwtqF88MEHnD9/vpJrqCrDJfsfIrLeGNO6lEVuBxaL4xu/2RhT3xjTTESOu6mO6grLDz2Aud/PBWB2/9klhl1NVNbAA/Bv2BT/YTPxybCx9cD3xL8wj3ET7qfbTTfz+3vu5Pbbb6dBgwaVWFvlLu444NICSCzwOsk5TQOvBisUev/8G3MHznXNm+v8V5CXlxcdOnQgPT3d7XU5ftzxVerdu7db1hcQEEBSUhLnzp0r83v8AU+/egSE9YOwfvheSOeHA1t49NV3ue/+P9Dzllv55suVbqmfqjxX9AizMeZ+HN1err322itZtKqA/NCbu2Ku47+wPOjSrPhDtMYYmjVrxoULF9xej/wLhdu3b++W9dWtW5errrrKNd5taY4cOcL5jMLbJHm5ZJ86BMkHyTpxgGtatKRrdKRb6qYqlzsC7xjQssDrIOe0IkTkbeBtgKioqJp10MeC8o/Z0RS4zzGt5409S+zO/vTTT3h4eLj9TO3dd98NwNtvv+2W9aWmptK+fXt8fHwuuezmzZu5beQEJC+XrKO7yP1lM1k/b6Zp0yaMvHsEd42YRadOnWp0995K3PHNXAH80RjzIXAjYNPjdzVfcSco8l8DxYaev78/qamp1frSFBHBGFOu+33PHz9M+ltjaHVtK0bdcyd3jpjtttamurIu+c00xnwA9AIaGWOSgBmAF4CIzAdWAYOAA0AGMLayKquujJLOxhZ3IqNg6Pn5+VX7Jwvn5OTg4+NT5qcwR0VF8c8F/6Bbt260atWqkmunKltZztLec4n5AkxyW41UlSrt0pNLhZ6Pj0+1vzwlOzubhg0blnn5OnXquLrUquarvn0PdcWVFnb5Sgs9Hx8ffHx8sNvtlT4cY0WICHl5edSvX7+qq6KqiAaecrHn2dmXsu+S19kVDL19Kfuw59nx9vTGGEPjxo1JTEysloGXnZ2Nv79/mU5WqNpJA0+5eHt6s+KeFXh5eF3yrGN+6OWHXb788Wfz8vKqdLSy4mRnZ9O8efOqroaqQtXrG6mqXH5LrSyMMYXCDhwD5jRu3Lja3XeanZ2Nt7c3/v7+VV0VVYU08JTbNWzYEG9v72rzRBIRISsri6CgoGrX6lRXlu595Xaenp4EBQWRlZVVrntWK0t6ejqNGzfW1p3SwFOVw8/PjyZNmlTKvbXlceHCBby8vGjSpEmV1kNVDxp4qtI0adKEq666irS0tCopPzs7m5ycHFq1aoWnp2eV1EFVLxp4qtJ4eHgQFBSEv7//FQ+97OxssrOzadOmjV6Golw08FSl8vT05NprryUwMJDU1NQrckwvMzOTnJwc2rZti5+fX6WXp2oODTxV6Tw9PWnZsiVNmzYlLS2tUh4hBZCbm8v58+fx8fGhXbt2+Pr6Vko5qubSC4/VFeHh4UGTJk0IDAwkKSmJ1NRU/Pz83HJsTUTIzMwkLy+PoKAgGjRooI9rUsXSwFNXlK+vL23btuXMmTMkJydjt9upW7cu3t5lv+A5X05ODpmZmRhjqF+/Po0bN6Zu3bqVVHNVG2jgqSvOw8ODRo0acfXVV5Oenk5ycjJpaWkYYzDG4OXlRZ06dYpcJJyTk4PdbneNOObp6UmzZs2oV69etbx3V1U/Gniqynh4eBAYGEhgYCDZ2dlcuHCBzMxM0tPTycjIcJ3gKDikYr169fD398fb25u6devqnROqXDTwVLXg7e2Nt7c3gYGBrmkigoi47pBo165dVVVP1RIaeKrayu/iKuUu2h9QSlmGBp5SyjI08JRSlqGBp5SyDA08pZRlaOAppSxDA08pZRkaeEopy9DAU0pZhgaeUsoyNPCUUpahgaeUsgwNPKWUZWjgKaUsQwNPKWUZGnhKKcvQwFNKWYYGnlLKMjTwlFKWoYGnlLIMDTyllGVo4CmlLEMDTyllGRp4SinL0MBTSlmGBp5SyjI08JRSlqGBp5SyDA08pZRlaOAppSxDA08pZRkaeEopy6jRgde6dWtSUlJKXWbRokX8+uuvFS4jISGBVatWles9vXr1Ij4+vsJlFhQfH8/kyZMBuHDhArfeeisRERF89NFHTJgwgb17917WOtetW8d///tft9RVqequTlVXoLItWrSIkJAQmjdvXqH3JyQkEB8fz6BBg9xcs7KJiooiKioKgB07drjqBHDXXXdd9jrXrVtHQEAA3bt3d0NtlareanQLL9/hw4e54YYbuO+++wgODqZfv35kZmayfPly4uPjGTlyJBEREWRmZrJt2zZiYmKIjIykf//+HD9+HHC0yh5//HF+85vf0KFDBzZs2EB2djbPPvssH330katVVVBubi6PPPIIISEhhIWF8dprrxWp2wMPPEBUVBTBwcHMmDHDNX369Ol06tSJsLAwHnnkEQCWLVtGSEgI4eHh9OzZE3AEUmxsLKdOnWLUqFFs3bqViIgIDh48WKgl+eWXX9KlSxfCw8Pp06cPAFu2bKFbt2507tyZ7t27s3///kLrPHz4MPPnz2fOnDlERESwYcMGDh8+zC233EJYWBh9+vTh6NGjAIwZM4bJkyfTvXt3rrvuOpYvX+7OXajUlSEiVfITGRkp5RETEyMxMTGFprVq1UqSk5Pl0KFD4unpKTt27BARkREjRsiSJUtc79u6dauIiGRnZ0u3bt3k1KlTIiLy4YcfytixY13LTZ06VUREVq5cKX369BERkYULF8qkSZOKrdObb74pv/3tb8Vut4uIyOnTp4uUmT8tJydHYmJiZOfOnZKSkiIdOnSQvLw8ERE5e/asiIiEhIRIUlJSoWlxcXEyePDgIr8XLOfUqVMSFBQkv/zyS6EybTabq25ff/21DB8+vMh6ZsyYIa+88oprnbGxsbJo0SIREXn33Xfl9ttvFxGR0aNHyx133CG5ubmyZ88eadu2bbGfSWUobt+r2q+i+x2IlxJyp9Z0adu0aUNERAQAkZGRHD58uMgy+/fvZ/fu3fTt2xdwtNCaNWvmmj98+PBS33+xNWvWMHHiROrUcXyMV199dZFl/vWvf/H222+Tk5PD8ePH2bt3L506dcLHx4fx48cTGxtLbGwsAD169GDMmDHceeedrrqUxebNm+nZsydt2rQpVA+bzcbo0aP5+eefMcZgt9svua5Nmzbx8ccfA3Dvvffy2GOPueYNHToUDw8POnXqxMmTJ8tcP6Wqi1oTeHXr1nX97unpSWZmZm6hBQEAABFpSURBVJFlRITg4GA2bdpU6jo8PT3Jycm57DodOnSIV199la1bt9KgQQPGjBlDVlYWderUYcuWLXzzzTcsX76c119/nbVr1zJ//ny+//57Vq5cSWRkJNu2bbus8p955hl69+7NJ598wuHDh+nVq9dlra/gZ+z4j1SpmqVWHMMrTWBgIKmpqQB07NiR5ORkV+DZ7Xb27NlT5vdfrG/fvrz11luucDxz5kyh+efPn8ff35969epx8uRJvvjiCwDS0tKw2WwMGjSIOXPmsHPnTgAOHjzIjTfeyPPPP0/jxo1JTEws0zZ27dqV9evXc+jQoUL1sNlstGjRAnCcvCnL9nXv3p0PP/wQgKVLl3LzzTeXqQ5K1QS1PvDGjBnDxIkTiYiIIDc3l+XLl/P4448THh5ORETEJS/J6N27N3v37i32pMWECRO49tprCQsLIzw8nPfff7/Q/PDwcDp37sz111/P7373O3r06AFAamoqsbGxhIWFcdNNNzF79mwAHn30UUJDQwkJCaF79+6Eh4eXaRsbN27M22+/zfDhwwkPD3edvX3sscd44okn6Ny5c4kt1ttuu41PPvnEddLitddeY+HChYSFhbFkyRL+9re/lakOStUEpqq6JlFRUVKea9Xyu2Pr1q2rnAqpakv3vTVVdL8bY7aJSFRx82p9C08ppfJp4CmlLEMDTyllGRp4SinL0MCroHXr1lGvXj0iIiKIiIjg1ltvLXXZ/IuLlVJVp9ZceFwVbr75Zj7//POqroZSqoxqdAsvICCAhx9+mODgYPr06UNycjLgeJpI165dCQsLY9iwYZw9exaAefPmuW7Yv/vuu0tc70MPPcTzzz8PwFdffUXPnj3Jy8u7ZH1Kulm/oG+//dbVKuzcubProt9XXnmF6OhowsLCCj1kQCnlRiXdZFvZP+54eAAg7733noiIzJw503WTf2hoqKxbt05ERJ555hl56KGHRESkWbNmkpWVJSL/uzm/OOnp6dKpUydZu3atdOjQQQ4cOCAiIp9++qk888wzIuK4Af+qq66S8PBwCQ8Pl1mzZpXpZv3Y2FjZuHGjiIikpqaK3W6Xr776Su677z7Jy8uT3NxcGTx4sHz77bfl+nxqM314gDXpwwMu4uHh4bqrYNSoUQwfPhybzca5c+eIiYkBYPTo0YwYMQKAsLAwRo4cydChQxk6dGiJ6/Xz8+Odd96hZ8+ezJkzh7Zt2wIwZMgQhgwZ4lru4i5tYmLiJW/W79GjB1OnTmXkyJEMHz6coKAgVq9ezerVq+ncuTPguPXs559/dj0iSinlHjW6S3sxY0yp81euXMmkSZPYvn070dHRpT4gYNeuXTRs2LBcT0vOv1l/9+7dfPbZZ2RlZRVZZvr06fzjH/8gMzOTHj168OOPPyIiPPHEEyQkJJCQkMCBAwcYP358mctVSpVNjQ68vLw814Mo33//fW666Sbq1atHgwYN2LBhAwBLliwhJiaGvLw8EhMT6d27Ny+//DI2m420tLRi13vkyBH++te/smPHDr744gu+//77MtWnLDfrHzx4kNDQUB5//HGio6P58ccf6d+/PwsWLHDV59ixY5w6dao8H4VSqgxqdJfW39+fLVu2MGvWLJo0aeK6uf+f//wnEydOJCMjg+uuu46FCxeSm5vLqFGjsNlsiAiTJ0+mfv36RdYpIowfP55XX32V5s2b8+677zJmzBi2bt3K6tWriY+Pd53QuNhjjz3G6NGjmTVrFoMHDy52mblz5xIXF4eHhwfBwcEMHDiQunXrsm/fPrp16wY4Tsa89957NGnSxE2flFIKavjDAwICAkpspanaQx8eYE368ACllLoMNbpLe7mtu4ULFxZ53luPHj144403Lmu9SqnqqUYH3uUaO3YsY8eOrepqKKWuEO3SXqbWrVsTGhrqunuitCcol2XgcKVU5bF0C89d4uLiaNSoUVVXQyl1CTW6hRcQEMBTTz1FeHg4Xbt2dQ0dmJyczG9/+1uio6OJjo7mu+++c03v27cvwcHBTJgwgVatWpXY4tq6dSthYWFkZWWRnp5OcHAwu3fvLlO9hg4dSmRkJMHBwbz99ttF5qenpzN48GDCw8MJCQlxXU5T0iDhSin3qNGBl56eTteuXdm5cyc9e/bknXfeARw3/z/88MNs3bqVf//730yYMAGAmTNncsstt7Bnzx7uuOMOjh49WuK6o6OjGTJkCE8//TSPPfYYo0aNIiQkBMA1/m2+3r17ExERwY033gjAggUL2LZtG/Hx8cybN4/Tp08XWv7LL7+kefPm7Ny5k927dzNgwADsdjsPPvggy5cvZ9u2bYwbN46nnnrKbZ+VUqqGd2m9vb1dz5mLjIzk66+/BhwDZO/du9e13Pnz50lLS2Pjxo188sknAAwYMIAGDRqUuv5nn32W6OhofHx8mDdvnmt6QkJCoeUu7tLOmzfPVU5iYiI///wzDRs2dM0PDQ1l2rRpPP7448TGxnLzzTeze/fuUgcJV0pdvhodeF5eXq77ZwsOnp2Xl8fmzZvx8fG5rPWfPn2atLQ07HY7WVlZ+Pv7X/I969atY82aNWzatAk/Pz969epV5J7aDh06sH37dlatWsXTTz9Nnz59GDZsWKmDhCulLl+N7tKWpF+/frz22muu1/ktsh49evCvf/0LgNWrV7uek1eSP/zhD7zwwguMHDmSxx9/vExl22w2GjRogJ+fHz/++CObN28ussyvv/6Kn58fo0aN4tFHH2X79u0VGiRcKVU+tTLw5s2bR3x8PGFhYXTq1In58+cDMGPGDFavXk1ISAjLli3jmmuuITAwsNh1LF68GC8vL373u98xffp0tm7dytq1a4Gix/AKGjBgADk5Odxwww1Mnz6drl27Fllm165d/OY3vyEiIoKZM2fy9NNP4+3tXe5BwpVS5VOj76UtrwsXLuDp6UmdOnXYtGkTDzzwQJHjcar60Xtpraky7qWt0cfwyuvo0aPceeed5OXl4e3t7Tqrq5SyBksFXvv27dmxY0ehaadPn6ZPnz5Flv3mm28KnVlVStV8lgq84jRs2FC7tUpZRK08aaGUUsWxfAuvojw9PQkNDXW9/s9//kPr1q2LXVYfVKpU9aCBV0G+vr7aFVaqhqnRXdrKGIg7Ly+P9u3bu9aVl5dHu3btXK9LkpaWRp8+fejSpQuhoaF8+umnRZY5fvw4PXv2JCIigpCQENdAQ6tXr6Zbt2506dKFESNGaGtQqUpSowMvPT2dqKgo9uzZQ0xMDDNnzgTg97//PS+//DI//PADoaGhrukvvfQSO3bs4IcffnBdjHwxDw8PRo0axdKlSwHHfbnh4eHY7XYGDRrkWi4zM9P1DLxhw4bh4+PDJ598wvbt24mLi2PatGlcfI3j+++/T//+/UlISGDnzp1ERESQkpLCrFmzWLNmDdu3bycqKorZs2dXxsellOXV6C5tZQ3EPW7cOG6//XamTJnCggULGDt2LM2bN2fVqlWuZS7u0trtdp588knWr1+Ph4cHx44d4+TJk1xzzTWuZaKjoxk3bhx2u52hQ4cSERHBt99+y969e+nRowcA2dnZrtHLlFLuVaNbeBdz10DcLVu2pGnTpqxdu5YtW7YwcODAS5a9dOlSkpOT2bZtGwkJCTRt2rTIQwN69uzJ+vXradGiBWPGjGHx4sWICH379nUNwr13717efffdsm+0UqrManTgVdZA3AATJkxg1KhRjBgxAk9Pz0vWxWaz0aRJE7y8vIiLi+PIkSNFljly5AhNmzblvvvuY8KECWzfvp2uXbvy3XffceDAAcDRTf/pp58q8nEopS6hRndpK2Mg7nxDhgwpNMjPr7/+yoQJEwp1awsaOXIkt912G6GhoURFRXH99dcXWWbdunW88soreHl5ERAQwOLFi2ncuDGLFi3innvu4cKFCwDMmjWLDh06XO7Ho5S6SI1+eEBlXt8WHx/Pww8/7GopqqqjDw+wJn14wBXy0ksv8fe//911plYpVTvU6MCrzIG4p0+fflnrVkpVPzU68C6XDsStlLXU6LO0VW3MmDG0adPGdQFywYF+ils2/4yyUqpqWLqF5w6vvPIKd9xxR1VXQylVBjW6hVeZA3E/++yzzJ071/X6qaeeKnK8rzjPP/880dHRhISEcP/99xe5vQxg+vTprnt6H3nkkVLrrJRyIxGpkp/IyEgpj5iYGImJiSk0DZAVK1aIiMijjz4qL7zwgoiI3HPPPbJhwwYRETly5Ihcf/31IiIyadIkefHFF0VE5IsvvhBAkpOTiy3v0KFD0rlzZxERyc3Nleuuu05SUlJk/PjxsnXrVhERGT16tLRu3VrCw8MlPDxcfvjhBzl9+rRrHaNGjXLVb/To0bJs2TJJSUmRDh06SF5enoiInD17ttQ6q+L3var9KrrfgXgpIXdqdJe2Mgfibt26NQ0bNmTHjh2cPHmSzp0707BhQ/7xj38UWu7iLu2///1v/vKXv5CRkcGZM2cIDg7mtttuc82vV68ePj4+jB8/ntjYWFf9S6pzQEBART8epdRFanTgVfZA3BMmTGDRokWcOHGCcePGXXL5rKws/u///o/4+HhatmzJc889V+R+2jp16rBlyxa++eYbli9fzuuvv87atWvdVmelVMlq9DG8krhrIO5hw4bx5ZdfsnXrVvr373/JcvPDrVGjRqSlpRV7VjYtLQ2bzcagQYOYM2cOO3fuLLXOSin3qZWB546BuMHRZe7duzd33nmn6wECEyZMoKRb4urXr899991HSEgI/fv3Jzo6usgyqampxMbGEhYWxk033eR69l1JdVZKuU+Nvpe2vMo7EHdeXh5dunRh2bJltG/fvsLlqsuj99Jak95Le5nKMxD33r17iY2NZdiwYRp2StUSlgq88g7E/csvv1ypqimlrgBLBV5xdCBupayjVp60UEqp4mjgKaUsQwNPKWUZGnhKKctwW+AZYwYYY/YbYw4YY/RxwUqpasctgWeM8QTeAAYCnYB7jDGd3LFupZRyF3ddlvIb4ICI/AJgjPkQuB3YW+q7yikhIcF19bWyjvzLhnTfW0tCQgIRERFuXae7urQtgMQCr5Oc0woxxtxvjIk3xsQnJyeXq4B169a5feNVzRAQEKCPybKgiIgIt99OeEUvPBaRt4G3wXEvbXnfr/dSKqUuh7taeMeAlgVeBzmnKaVUteGuwNsKtDfGtDHGeAN3AyvctG6llHILt3RpRSTHGPNH4CvAE1ggInvcsW6llHIXtx3DE5FVwCp3rU8ppdxN77RQSlmGBp5SyjI08JRSlqGBp5SyDA08pZRlaOAppSxDA08pZRlVNi6tMSYZOFLOtzUCUiqhOqr6031vTRXZ761EpHFxM6os8CrCGBNf0gC7qnbTfW9N7t7v2qVVSlmGBp5SyjJqWuC9XdUVUFVG9701uXW/16hjeEopdTlqWgtPKaUqrEYEng4BaQ3GmAXGmFPGmN0lzDfGmHnO78EPxpguV7qOyv2MMS2NMXHGmL3GmD3GmIeKWcYt+77aB54OAWkpi4ABpcwfCLR3/twP/P0K1ElVvhxgmoh0AroCk4r5G3fLvq/2gUeBISBFJBvIHwJS1TIish44U8oitwOLxWEzUN8Y0+zK1E5VFhE5LiLbnb+nAvsoOuqhW/Z9TQi8Mg0BqSxBvwu1nDGmNdAZ+P6iWW7Z9zUh8JRSFmCMCQD+DUwRkfOVUUZNCDwdAlLl0+9CLWWM8cIRdktF5ONiFnHLvq8JgadDQKp8K4DfO8/YdQVsInK8qiulLo8xxgDvAvtEZHYJi7ll37tt1LLKokNAWocx5gOgF9DIGJMEzAC8AERkPo5R8QYBB4AMYGzV1FS5WQ/gXmCXMSbBOe1J4Fpw777XOy2UUpZRE7q0SinlFhp4SinL0MBTSlmGBp5SyjI08JRSlqGBp5SyDA08pZRlaOAppSzj/wEtGlA8tgNcyQAAAABJRU5ErkJggg==\n", + "text/plain": [ + "\u003cFigure size 360x360 with 1 Axes\u003e" + ] + }, + "metadata": { + "needs_background": "light", + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "#@title Plotting Model predictions.\n", + "\n", + "#@markdown Use affordance based model?\n", + "affordance_mask_threshold = 0.5 #@param {type:\"number\"}\n", + "network_seed = 0 #@param {type:\"integer\"}\n", + "\n", + "#@markdown What is the action?\n", + "action_x_dir = +0.5 #@param {type:\"number\"}\n", + "action_y_dir = 0.0 #@param {type:\"number\"}\n", + "\n", + "#@markdown Where is the agent?\n", + "agent_x = 0.75 #@param {type:\"number\"}\n", + "agent_y = 1.0 #@param {type:\"number\"}\n", + "\n", + "action = tf.constant([[action_x_dir, action_y_dir]])\n", + "pos = tf.constant([[agent_x, agent_y]])\n", + "\n", + "affordance_networks = all_affordance_global[network_seed]\n", + "model_networks = all_models_global[network_seed]\n", + "\n", + "scale_scale = 2.0\n", + "\n", + "for i, use_affordance_to_mask_model in enumerate([False, True]):\n", + " fig = plt.figure(figsize=(5, 5))\n", + " ax = fig.add_subplot(1, 1, 1)\n", + " transition_dist = model_networks[use_affordance_to_mask_model](pos, action)\n", + " transition_loc = tuple(transition_dist.loc[0].numpy())\n", + " transition_scale = tuple(transition_dist.scale[0].numpy() * scale_scale)\n", + "\n", + " if use_affordance_to_mask_model:\n", + " aff_network = affordance_networks[use_affordance_to_mask_model]\n", + " AF = aff_network(tf.concat([pos, action], axis=1))\n", + " intents_completable = (AF \u003e affordance_mask_threshold)[0].numpy()\n", + "\n", + " visualize_environment(\n", + " world,\n", + " ax,\n", + " scaling=1.0,\n", + " draw_start_mu=False,\n", + " draw_target_mu=False,\n", + " draw_agent=False,\n", + " agent_size=0.1,\n", + " write_text=False)\n", + " ax.scatter([agent_x], [agent_y], s=150.0, c='green', marker='x')\n", + " ax.arrow(agent_x, agent_y, action_x_dir, action_y_dir, head_width=0.05)\n", + "\n", + " if use_affordance_to_mask_model and not np.any(intents_completable):\n", + " color = 'gray'\n", + " alpha = 0.25\n", + " ellipse_text = '(Masked) '\n", + " else:\n", + " color = None\n", + " alpha = 0.7\n", + " ellipse_text = ''\n", + "\n", + " elipse = mpl.patches.Ellipse(\n", + " transition_loc, *transition_scale, alpha=alpha, color=color)\n", + " ax.add_artist(elipse)\n", + "\n", + " if use_affordance_to_mask_model:\n", + " string_built = ' Intent classificaiton\\n'\n", + " for a in list(zip(IntentName, intents_completable)):\n", + " string_built += ' ' + str(a[0])[-5:] + ':' + str(a[1])\n", + " string_built += '\\n'\n", + " ax.text(\n", + " 0,\n", + " 0,\n", + " string_built,\n", + " )\n", + "\n", + " ax.set_xticks([0.0, 1.0, 2.0])\n", + " ax.set_xticklabels([0, 1.0, 2.0])\n", + "\n", + " ax.set_yticks([0.0, 1.0, 2.0])\n", + " ax.set_yticklabels([0, 1.0, 2.0])\n", + "\n", + " if use_affordance_to_mask_model:\n", + " title = 'Using affordances'\n", + " else:\n", + " title = 'Without affordances'\n", + " ax.set_title(title)\n", + " ax.legend([elipse], [ellipse_text + 'Predicted transition'])\n", + " file_name = (f'./empirical_demo{movement_noise}_P{agent_x}_{agent_y}_'\n", + " f'F{action_x_dir}_{action_x_dir}.pdf')\n", + " fig.savefig(file_name)\n", + "\n", + "print(\n", + " 'Figures show the predicted position of the transition distribution.'\n", + " '\\nGray circle shows what would have been predicted but was masked by '\n", + " 'affordance model. ')" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "5GHbTbNClfL-" + }, + "outputs": [], + "source": [ + "" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "AffordancesInContinuousEnvironment.ipynb", + "provenance": [ + { + "file_id": "1W86NFSHwhnx-UEmAY_mhJJzUxPXY3JC4", + "timestamp": 1591715576521 + } + ] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/affordances_theory/README.md b/affordances_theory/README.md new file mode 100644 index 00000000..77eb8fda --- /dev/null +++ b/affordances_theory/README.md @@ -0,0 +1,5 @@ +# Code for "What can I do here? A theory of affordances in reinforcement Learning. + +This iPython notebook accompanies the paper "What can I do here? A theory of +affordances in reinforcmenet learning" and covers the experiments in Section 8. + diff --git a/affordances_theory/requirements.txt b/affordances_theory/requirements.txt new file mode 100644 index 00000000..089a23ce --- /dev/null +++ b/affordances_theory/requirements.txt @@ -0,0 +1,4 @@ +tensorflow==2.1.0 +tensorflow_probability=0.7.0 +matplotlib==3.1.2 +numpy==1.17.5