Release of option_keyboard code.
PiperOrigin-RevId: 310175317
shaobohou authored and diegolascasas committed May 7, 2020
1 parent 6f14cb5 commit 391bc47
Showing 15 changed files with 1,342 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/

## Projects

* [The Option Keyboard: Combining Skills in Reinforcement Learning](option_keyboard), NeurIPS 2019
* [VISR - Fast Task Inference with Variational Intrinsic Successor Features](visr), ICLR 2020
* [Unveiling the predictive power of static structure in glassy systems](glassy_dynamics), Nature Physics 2020
* [Multi-Object Representation Learning with Iterative Variational Inference (IODINE)](iodine)
58 changes: 58 additions & 0 deletions option_keyboard/README.md
@@ -0,0 +1,58 @@
# The Option Keyboard: Combining Skills in Reinforcement Learning

This directory contains an implementation of the Option Keyboard framework.

From the [abstract](http://papers.nips.cc/paper/9463-the-option-keyboard-combining-skills-in-reinforcement-learning):

> The ability to combine known skills to create new ones may be crucial in the
solution of complex reinforcement learning problems that unfold over extended
periods. We argue that a robust way of combining skills is to define and manipulate
them in the space of pseudo-rewards (or “cumulants”). Based on this premise, we
propose a framework for combining skills using the formalism of options. We show
that every deterministic option can be unambiguously represented as a cumulant
defined in an extended domain. Building on this insight and on previous results
on transfer learning, we show how to approximate options whose cumulants are
linear combinations of the cumulants of known options. This means that, once we
have learned options associated with a set of cumulants, we can instantaneously
synthesise options induced by any linear combination of them, without any learning
involved. We describe how this framework provides a hierarchical interface to the
environment whose abstract actions correspond to combinations of basic skills.
We demonstrate the practical benefits of our approach in a resource management
problem and a navigation task involving a quadrupedal simulated robot.

If you use this code, please cite the following paper:

> Andre Barreto, Diana Borsa, Shaobo Hou, Gheorghe Comanici, Eser Aygün, Philippe Hamel, Daniel Toyama, Jonathan Hunt, Shibl Mourad, David Silver, Doina Precup. *The Option Keyboard: Combining Skills in Reinforcement Learning*. NeurIPS 2019. [\[paper\]](https://papers.nips.cc/paper/9463-the-option-keyboard-combining-skills-in-reinforcement-learning).

## Running the code

### Setup
```
python3 -m venv ok_venv
source ok_venv/bin/activate
pip install -r option_keyboard/requirements.txt
```

### Scavenger Task
All agents are trained on a simple grid-world resource collection task. There
are two types of collectible objects in the world: collecting an object of the
less abundant type yields a reward of -1, while collecting an object of the
more abundant type yields a reward of +1. See Section 5.1 in the paper for more
details.
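
The reward rule can be sketched as follows (illustrative only: the actual logic
lives in the environment's `BalancedCollectionRewarder`, and measuring
abundance by the number of objects of each type currently in the world is an
assumption made here):
```
def collection_reward(collected_type, object_counts):
  # object_counts[i] is the number of objects of type i currently in the world.
  other_type = 1 - collected_type
  # Collecting the less abundant type is penalised with -1, otherwise +1.
  if object_counts[collected_type] < object_counts[other_type]:
    return -1.0
  return 1.0
```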

### Train the DQN baseline
```
python3 -m option_keyboard.run_dqn
```
This trains a DQN agent on the scavenger task.

### Train the Option Keyboard and agent
```
python3 -m option_keyboard.run_ok
```
This first trains an Option Keyboard on the cumulants defined in the task
environment, and then trains a DQN agent on the true task reward using the
high-level abstract actions provided by the keyboard.
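
Conceptually, each abstract action is a weight vector over the cumulants, and
the keyboard picks primitive actions greedily with respect to the corresponding
linear combination of the per-cumulant action values. A minimal sketch of that
selection step (names are illustrative and do not match the repository's API):
```
import numpy as np

def keyboard_act(q_per_cumulant, w):
  # q_per_cumulant: array of shape [num_cumulants, num_actions] holding the
  # action values learned for each cumulant; w: abstract action, i.e. a weight
  # vector over the cumulants chosen by the player agent.
  combined_q = np.dot(w, q_per_cumulant)  # Shape [num_actions].
  return int(np.argmax(combined_q))
```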

## Disclaimer
This is not an official Google or DeepMind product.
58 changes: 58 additions & 0 deletions option_keyboard/auto_reset_environment.py
@@ -0,0 +1,58 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Auto-resetting environment base class.
The environment API states that stepping an environment after a LAST timestep
should return the first timestep of a new episode.
However, environment authors sometimes don't spot this part or find it awkward
to implement. This module contains a class that helps implement the reset
behaviour.
"""

import abc
import dm_env


class Base(dm_env.Environment):
  """Implements the required `step()` and `reset()` methods.

  Subclasses implement `_step()` and `_reset()` instead; this class then
  handles the reset behaviour automatically when it detects a LAST timestep.
  """

  def __init__(self):
    self._reset_next_step = True

  @abc.abstractmethod
  def _reset(self):
    """Returns a `timestep` namedtuple as per the regular `reset()` method."""

  @abc.abstractmethod
  def _step(self, action):
    """Returns a `timestep` namedtuple as per the regular `step()` method."""

  def reset(self):
    self._reset_next_step = False
    return self._reset()

  def step(self, action):
    if self._reset_next_step:
      return self.reset()
    timestep = self._step(action)
    self._reset_next_step = timestep.last()
    return timestep
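

# Usage sketch: a minimal, hypothetical subclass showing how `_reset()` and
# `_step()` are implemented; `Base.step()` then starts a new episode
# automatically after a LAST timestep. The environment below is purely
# illustrative and is not part of the framework.
from dm_env import specs


class CountdownEnv(Base):
  """Toy environment that counts down from 3 and then ends the episode."""

  def _reset(self):
    self._count = 3
    return dm_env.restart(observation=self._count)

  def _step(self, action):
    del action  # This toy environment ignores actions.
    self._count -= 1
    if self._count > 0:
      return dm_env.transition(reward=0.0, observation=self._count)
    return dm_env.termination(reward=1.0, observation=self._count)

  def observation_spec(self):
    return specs.Array(shape=(), dtype=int, name="count")

  def action_spec(self):
    return specs.DiscreteArray(num_values=1, name="action")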
41 changes: 41 additions & 0 deletions option_keyboard/configs.py
@@ -0,0 +1,41 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Environment configurations."""


def get_task_config():
  return dict(
      arena_size=11,
      num_channels=2,
      max_num_steps=50,  # 50 for the actual task.
      num_init_objects=10,
      object_priors=[0.5, 0.5],
      egocentric=True,
      rewarder="BalancedCollectionRewarder",
  )


def get_pretrain_config():
  return dict(
      arena_size=11,
      num_channels=2,
      max_num_steps=40,  # 40 for pretraining.
      num_init_objects=10,
      object_priors=[0.5, 0.5],
      egocentric=True,
      default_w=(1, 1),
  )
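

# Usage sketch (illustrative only): these dictionaries are keyword-argument
# bundles for constructing the scavenger environments; the constructor name
# below is an assumption made for illustration, see the run scripts in this
# directory for how the configs are actually consumed.
#
#   task_config = get_task_config()
#   assert task_config["max_num_steps"] == 50
#   env = scavenger.Scavenger(**task_config)  # Hypothetical constructor.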
160 changes: 160 additions & 0 deletions option_keyboard/dqn_agent.py
@@ -0,0 +1,160 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DQN agent."""

import numpy as np
import sonnet as snt
import tensorflow.compat.v1 as tf


class Agent():
  """A DQN Agent."""

  def __init__(
      self,
      obs_spec,
      action_spec,
      network_kwargs,
      epsilon,
      additional_discount,
      batch_size,
      optimizer_name,
      optimizer_kwargs,
  ):
    """A simple DQN agent.

    Args:
      obs_spec: The observation spec.
      action_spec: The action spec.
      network_kwargs: Keyword arguments for snt.nets.MLP.
      epsilon: Exploration probability.
      additional_discount: Discount on returns used by the agent.
      batch_size: Size of update batch.
      optimizer_name: Name of an optimizer from tf.train.
      optimizer_kwargs: Keyword arguments for the optimizer.
    """

    self._epsilon = epsilon
    self._additional_discount = additional_discount
    self._batch_size = batch_size

    self._n_actions = action_spec.num_values
    self._network = ValueNet(self._n_actions, network_kwargs=network_kwargs)

    self._replay = []

    obs_spec = self._extract_observation(obs_spec)

    # Placeholders for policy.
    o = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype)
    q = self._network(tf.expand_dims(o, axis=0))

    # Placeholders for update.
    o_tm1 = tf.placeholder(shape=(None,) + obs_spec.shape, dtype=obs_spec.dtype)
    a_tm1 = tf.placeholder(shape=(None,), dtype=tf.int32)
    r_t = tf.placeholder(shape=(None,), dtype=tf.float32)
    d_t = tf.placeholder(shape=(None,), dtype=tf.float32)
    o_t = tf.placeholder(shape=(None,) + obs_spec.shape, dtype=obs_spec.dtype)

    # Compute values over all actions.
    q_tm1 = self._network(o_tm1)
    q_t = self._network(o_t)

    a_t = tf.cast(tf.argmax(q_t, axis=-1), tf.int32)
    qa_tm1 = _batched_index(q_tm1, a_tm1)
    qa_t = _batched_index(q_t, a_t)

    # TD error.
    g = additional_discount * d_t
    td_error = tf.stop_gradient(r_t + g * qa_t) - qa_tm1
    loss = tf.reduce_sum(tf.square(td_error) / 2)

    with tf.variable_scope("optimizer"):
      self._optimizer = getattr(tf.train, optimizer_name)(**optimizer_kwargs)
      train_op = self._optimizer.minimize(loss)

    # Make session and callables.
    session = tf.Session()
    self._update_fn = session.make_callable(train_op,
                                            [o_tm1, a_tm1, r_t, d_t, o_t])
    self._value_fn = session.make_callable(q, [o])
    session.run(tf.global_variables_initializer())

  def _extract_observation(self, obs):
    return obs["arena"]

  def step(self, timestep, is_training=False):
    """Selects actions according to an epsilon-greedy policy."""

    if is_training and np.random.rand() < self._epsilon:
      return np.random.randint(self._n_actions)

    q_values = self._value_fn(
        self._extract_observation(timestep.observation))
    return int(np.argmax(q_values))

  def update(self, step_tm1, action, step_t):
    """Takes in a transition from the environment."""

    transition = [
        self._extract_observation(step_tm1.observation),
        action,
        step_t.reward,
        step_t.discount,
        self._extract_observation(step_t.observation),
    ]
    self._replay.append(transition)

    if len(self._replay) == self._batch_size:
      batch = list(zip(*self._replay))
      self._update_fn(*batch)
      self._replay = []  # Simple buffer, flushed after every update.


class ValueNet(snt.AbstractModule):
  """Value Network."""

  def __init__(self,
               n_actions,
               network_kwargs,
               name="value_network"):
    """Constructs a value network sonnet module.

    Args:
      n_actions: Number of actions.
      network_kwargs: Keyword arguments for snt.nets.MLP.
      name: Name of the module.
    """
    super(ValueNet, self).__init__(name=name)
    self._n_actions = n_actions
    self._network_kwargs = network_kwargs

  def _build(self, observation):
    flat_obs = snt.BatchFlatten()(observation)
    net = snt.nets.MLP(**self._network_kwargs)(flat_obs)
    net = snt.Linear(output_size=self._n_actions)(net)

    return net

  @property
  def num_actions(self):
    return self._n_actions


def _batched_index(values, indices):
  """Selects values[i, indices[i]] for each batch row i via a one-hot product."""
  one_hot_indices = tf.one_hot(indices, values.shape[-1], dtype=values.dtype)
  return tf.reduce_sum(values * one_hot_indices, axis=-1)
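

# Usage sketch (hypothetical environment and hyperparameter values, shown only
# to illustrate how `step()` and `update()` are driven; see run_dqn.py for the
# actual training script):
#
#   env = ...  # An environment whose observations contain an "arena" array.
#   agent = Agent(
#       obs_spec=env.observation_spec(),
#       action_spec=env.action_spec(),
#       network_kwargs=dict(output_sizes=(64, 64)),
#       epsilon=0.1,
#       additional_discount=0.9,
#       batch_size=10,
#       optimizer_name="AdamOptimizer",
#       optimizer_kwargs=dict(learning_rate=3e-4),
#   )
#   timestep = env.reset()
#   while not timestep.last():
#     action = agent.step(timestep, is_training=True)
#     next_timestep = env.step(action)
#     agent.update(timestep, action, next_timestep)
#     timestep = next_timestep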