forked from google-deepmind/deepmind-research
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 15 changed files with 1,342 additions and 0 deletions.
@@ -0,0 +1,58 @@
# The Option Keyboard: Combining Skills in Reinforcement Learning

This directory contains an implementation of the Option Keyboard framework.

From the [abstract](http://papers.nips.cc/paper/9463-the-option-keyboard-combining-skills-in-reinforcement-learning):

> The ability to combine known skills to create new ones may be crucial in the
> solution of complex reinforcement learning problems that unfold over extended
> periods. We argue that a robust way of combining skills is to define and
> manipulate them in the space of pseudo-rewards (or “cumulants”). Based on this
> premise, we propose a framework for combining skills using the formalism of
> options. We show that every deterministic option can be unambiguously
> represented as a cumulant defined in an extended domain. Building on this
> insight and on previous results on transfer learning, we show how to
> approximate options whose cumulants are linear combinations of the cumulants
> of known options. This means that, once we have learned options associated
> with a set of cumulants, we can instantaneously synthesise options induced by
> any linear combination of them, without any learning involved. We describe how
> this framework provides a hierarchical interface to the environment whose
> abstract actions correspond to combinations of basic skills. We demonstrate
> the practical benefits of our approach in a resource management problem and a
> navigation task involving a quadrupedal simulated robot.
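
The central operation described in the abstract, synthesising a new behaviour
from a linear combination of known cumulants, can be sketched in a few lines.
The snippet below is not the repository's implementation (the actual keyboard
applies generalised policy improvement over the learned options; see the
paper). It is a minimal numpy illustration that simply acts greedily with
respect to a weighted combination of per-cumulant action values for the
current state, with all numbers made up.

```
import numpy as np

# Made-up per-cumulant action values for one state: q[c, a] is the value of
# action a under cumulant c (e.g. "collect object type c").
q = np.array([[0.2, 0.9, 0.1],   # cumulant 0
              [0.8, 0.1, 0.3]])  # cumulant 1

def combined_greedy_action(q_per_cumulant, w):
  """Greedy action under the linearly combined cumulant w . q."""
  combined_q = np.dot(w, q_per_cumulant)  # shape: (num_actions,)
  return int(np.argmax(combined_q))

print(combined_greedy_action(q, w=np.array([1.0, 1.0])))   # value both skills
print(combined_greedy_action(q, w=np.array([1.0, -1.0])))  # pursue 0, avoid 1
```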

If you use the code here please cite this paper:

> Andre Barreto, Diana Borsa, Shaobo Hou, Gheorghe Comanici, Eser Aygün, Philippe Hamel, Daniel Toyama, Jonathan Hunt, Shibl Mourad, David Silver, Doina Precup. *The Option Keyboard: Combining Skills in Reinforcement Learning*. NeurIPS 2019. [\[paper\]](https://papers.nips.cc/paper/9463-the-option-keyboard-combining-skills-in-reinforcement-learning).

## Running the code

### Setup
```
python3 -m venv ok_venv
source ok_venv/bin/activate
pip install -r option_keyboard/requirements.txt
```

### Scavenger Task
All agents are trained on a simple grid-world resource collection task. There
are two types of collectible objects in the world: if the agent collects an
object of the type that is currently less abundant, it receives a reward of
-1; otherwise it receives a reward of +1. See section 5.1 in the paper for
more details.
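
For concreteness, here is a minimal sketch of that reward rule. It is not the
environment's code (the task uses the `BalancedCollectionRewarder` configured
elsewhere in this commit); it only illustrates the +1/-1 logic above, under
the assumption that abundance means the number of objects of each type
currently remaining in the world.

```
def collection_reward(collected_type, remaining_counts):
  """Illustrative reward: -1 for collecting the scarcer of the two types,
  +1 otherwise. `remaining_counts[i]` is the number of objects of type i
  still present when the collection happens (an assumption of this sketch)."""
  other_type = 1 - collected_type
  if remaining_counts[collected_type] < remaining_counts[other_type]:
    return -1
  return 1

print(collection_reward(0, remaining_counts=[3, 7]))  # -1: type 0 is scarcer.
print(collection_reward(1, remaining_counts=[3, 7]))  # +1: type 1 is abundant.
```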

### Train the DQN baseline
```
python3 -m option_keyboard.run_dqn
```
This trains a DQN agent on the scavenger task.

### Train the Option Keyboard and agent
```
python3 -m option_keyboard.run_ok
```
This first trains an Option Keyboard on the cumulants in the task environment.
It then trains a DQN agent on the true task reward using the high-level
abstract actions provided by the keyboard.

## Disclaimer
This is not an official Google or DeepMind product.
@@ -0,0 +1,58 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Auto-resetting environment base class.

The environment API states that stepping an environment after a LAST timestep
should return the first timestep of a new episode. However, environment
authors sometimes don't spot this part or find it awkward to implement. This
module contains a class that helps implement the reset behaviour.
"""

import abc

import dm_env


class Base(dm_env.Environment):
  """This class implements the required `step()` and `reset()` methods.

  It instead requires users to implement `_step()` and `_reset()`. This class
  handles the reset behaviour automatically when it detects a LAST timestep.
  """

  def __init__(self):
    self._reset_next_step = True

  @abc.abstractmethod
  def _reset(self):
    """Returns a `timestep` namedtuple as per the regular `reset()` method."""

  @abc.abstractmethod
  def _step(self, action):
    """Returns a `timestep` namedtuple as per the regular `step()` method."""

  def reset(self):
    self._reset_next_step = False
    return self._reset()

  def step(self, action):
    if self._reset_next_step:
      return self.reset()
    timestep = self._step(action)
    self._reset_next_step = timestep.last()
    return timestep
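
As a usage illustration for the `Base` class above, here is a toy concrete
environment. It is not part of the repository; it is a sketch that relies only
on standard `dm_env` helpers (`restart`, `transition`, `termination`, `specs`)
and shows that `step()` transparently starts a new episode after a LAST
timestep.

```
import dm_env
from dm_env import specs
import numpy as np


class CountdownEnv(Base):
  """Toy environment: episode ends after `length` steps, reward 1 at the end."""

  def __init__(self, length=3):
    super().__init__()
    self._length = length
    self._t = 0

  def _reset(self):
    self._t = 0
    return dm_env.restart(np.int32(self._t))

  def _step(self, action):
    del action  # Unused in this toy example.
    self._t += 1
    if self._t >= self._length:
      return dm_env.termination(reward=1.0, observation=np.int32(self._t))
    return dm_env.transition(reward=0.0, observation=np.int32(self._t))

  def observation_spec(self):
    return specs.Array(shape=(), dtype=np.int32, name="t")

  def action_spec(self):
    return specs.DiscreteArray(num_values=1, name="action")


env = CountdownEnv()
ts = env.step(0)   # First call auto-resets and returns a FIRST timestep.
while not ts.last():
  ts = env.step(0)
ts = env.step(0)   # Stepping after LAST starts a new episode automatically.
assert ts.first()
```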
@@ -0,0 +1,41 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Environment configurations."""


def get_task_config():
  return dict(
      arena_size=11,
      num_channels=2,
      max_num_steps=50,  # 50 for the actual task.
      num_init_objects=10,
      object_priors=[0.5, 0.5],
      egocentric=True,
      rewarder="BalancedCollectionRewarder",
  )


def get_pretrain_config():
  return dict(
      arena_size=11,
      num_channels=2,
      max_num_steps=40,  # 40 for pretraining.
      num_init_objects=10,
      object_priors=[0.5, 0.5],
      egocentric=True,
      default_w=(1, 1),
  )
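
A brief note on how the two configurations above differ: the task config runs
50-step episodes and scores collections with the `BalancedCollectionRewarder`,
while the pretraining config runs 40-step episodes and instead sets
`default_w=(1, 1)`, which appears to be the default weighting over the two
cumulants (an inference, not stated in this file). The snippet below is an
illustration only; it just inspects the dicts returned by the functions above
and is not part of the repository.

```
# Illustration only: compare the two configs defined above.
task = get_task_config()
pretrain = get_pretrain_config()

print(task["max_num_steps"], pretrain["max_num_steps"])   # 50 vs 40
print(task["rewarder"])        # Only the task config names a rewarder.
print(pretrain["default_w"])   # Only the pretrain config sets cumulant weights.
print(task["arena_size"] == pretrain["arena_size"])       # Shared layout: True
```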
@@ -0,0 +1,160 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DQN agent."""

import numpy as np
import sonnet as snt
import tensorflow.compat.v1 as tf


class Agent():
  """A DQN Agent."""

  def __init__(
      self,
      obs_spec,
      action_spec,
      network_kwargs,
      epsilon,
      additional_discount,
      batch_size,
      optimizer_name,
      optimizer_kwargs,
  ):
    """A simple DQN agent.

    Args:
      obs_spec: The observation spec.
      action_spec: The action spec.
      network_kwargs: Keyword arguments for snt.nets.MLP.
      epsilon: Exploration probability.
      additional_discount: Discount on returns used by the agent.
      batch_size: Size of update batch.
      optimizer_name: Name of an optimizer from tf.train.
      optimizer_kwargs: Keyword arguments for the optimizer.
    """

    self._epsilon = epsilon
    self._additional_discount = additional_discount
    self._batch_size = batch_size

    self._n_actions = action_spec.num_values
    self._network = ValueNet(self._n_actions, network_kwargs=network_kwargs)

    self._replay = []

    obs_spec = self._extract_observation(obs_spec)

    # Placeholders for policy.
    o = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype)
    q = self._network(tf.expand_dims(o, axis=0))

    # Placeholders for update.
    o_tm1 = tf.placeholder(shape=(None,) + obs_spec.shape, dtype=obs_spec.dtype)
    a_tm1 = tf.placeholder(shape=(None,), dtype=tf.int32)
    r_t = tf.placeholder(shape=(None,), dtype=tf.float32)
    d_t = tf.placeholder(shape=(None,), dtype=tf.float32)
    o_t = tf.placeholder(shape=(None,) + obs_spec.shape, dtype=obs_spec.dtype)

    # Compute values over all options.
    q_tm1 = self._network(o_tm1)
    q_t = self._network(o_t)

    a_t = tf.cast(tf.argmax(q_t, axis=-1), tf.int32)
    qa_tm1 = _batched_index(q_tm1, a_tm1)
    qa_t = _batched_index(q_t, a_t)

    # TD error.
    g = additional_discount * d_t
    td_error = tf.stop_gradient(r_t + g * qa_t) - qa_tm1
    loss = tf.reduce_sum(tf.square(td_error) / 2)

    with tf.variable_scope("optimizer"):
      self._optimizer = getattr(tf.train, optimizer_name)(**optimizer_kwargs)
      train_op = self._optimizer.minimize(loss)

    # Make session and callables.
    session = tf.Session()
    self._update_fn = session.make_callable(train_op,
                                            [o_tm1, a_tm1, r_t, d_t, o_t])
    self._value_fn = session.make_callable(q, [o])
    session.run(tf.global_variables_initializer())

  def _extract_observation(self, obs):
    return obs["arena"]

  def step(self, timestep, is_training=False):
    """Select actions according to epsilon-greedy policy."""

    if is_training and np.random.rand() < self._epsilon:
      return np.random.randint(self._n_actions)

    q_values = self._value_fn(
        self._extract_observation(timestep.observation))
    return int(np.argmax(q_values))

  def update(self, step_tm1, action, step_t):
    """Takes in a transition from the environment."""

    transition = [
        self._extract_observation(step_tm1.observation),
        action,
        step_t.reward,
        step_t.discount,
        self._extract_observation(step_t.observation),
    ]
    self._replay.append(transition)

    if len(self._replay) == self._batch_size:
      batch = list(zip(*self._replay))
      self._update_fn(*batch)
      self._replay = []  # Just a queue.


class ValueNet(snt.AbstractModule):
  """Value Network."""

  def __init__(self,
               n_actions,
               network_kwargs,
               name="value_network"):
    """Construct a value network sonnet module.

    Args:
      n_actions: Number of actions.
      network_kwargs: Network arguments.
      name: Name.
    """
    super(ValueNet, self).__init__(name=name)
    self._n_actions = n_actions
    self._network_kwargs = network_kwargs

  def _build(self, observation):
    flat_obs = snt.BatchFlatten()(observation)
    net = snt.nets.MLP(**self._network_kwargs)(flat_obs)
    net = snt.Linear(output_size=self._n_actions)(net)

    return net

  @property
  def num_actions(self):
    return self._n_actions


def _batched_index(values, indices):
  one_hot_indices = tf.one_hot(indices, values.shape[-1], dtype=values.dtype)
  return tf.reduce_sum(values * one_hot_indices, axis=-1)
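
To show how this agent is meant to be wired up, here is a construction sketch.
It is not taken from the repository's training scripts; the spec shapes,
network sizes, and optimizer settings are illustrative assumptions. It relies
only on the documented constructor arguments: a dict observation spec with an
"arena" entry, a discrete action spec, keyword arguments for `snt.nets.MLP`,
and the name of an optimizer from `tf.train`.

```
import numpy as np
from dm_env import specs

# Illustrative specs: an 11x11 arena with 2 channels and 4 discrete actions.
obs_spec = {"arena": specs.Array(shape=(11, 11, 2), dtype=np.float32)}
action_spec = specs.DiscreteArray(num_values=4)

agent = Agent(
    obs_spec=obs_spec,
    action_spec=action_spec,
    network_kwargs=dict(output_sizes=(64, 64)),  # passed to snt.nets.MLP
    epsilon=0.1,
    additional_discount=0.9,
    batch_size=10,
    optimizer_name="AdamOptimizer",      # looked up via getattr(tf.train, ...)
    optimizer_kwargs=dict(learning_rate=3e-4),
)

# Acting: pass a dm_env timestep; epsilon-greedy only when is_training=True.
# action = agent.step(timestep, is_training=True)
# Learning: feed consecutive timesteps and the action taken between them.
# agent.update(step_tm1, action, step_t)
```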