This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

The Differentiable Cross-Entropy Method

This repository is by Brandon Amos and Denis Yarats. It contains the PyTorch library and the source code to reproduce the experiments in our ICML 2020 paper, The Differentiable Cross-Entropy Method, and it depends on the Limited Multi-Label Projection Layer. Our code implements the vanilla cross-entropy method (CEM) for optimization and our differentiable extension (DCEM). The core library source code is in dcem/. Our experiments are in exp/, including the regression notebook and the action embedding notebook that produced most of the plots in our paper. Basic usage examples that are not published in our paper are in examples.ipynb. Our slides are available in pptx and pdf formats, and the full LaTeX source code for our paper is in paper/.

Setup

Once you have PyTorch set up, you can install our core code as a package with pip:

pip install git+https://github.com/facebookresearch/dcem.git

This should automatically install the Limited Multi-Label Projection Layer dependency.

Basic usage

Our core cross-entropy method implementation with the differentiable extension is available in dcem. We provide a lightweight wrapper for using CEM and DCEM in the control setting in dcem_ctrl. These can be imported as:

from dcem import dcem, dcem_ctrl

The interface for DCEM is:

dcem(
    f, # Objective to optimize
    nx, # Number of dimensions to optimize over
    n_batch, # Number of elements in the batch
    init_mu, # Initial mean
    init_sigma, # Initial variance
    n_sample, # Number of samples CEM uses in each iteration
    n_elite, # Number of elite CEM candidates in each iteration
    n_iter, # Number of CEM iterations
    temp, # DCEM temperature parameter, set to None for vanilla CEM
    iter_cb, # Iteration callback
)
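To make the parameters above concrete, here is a minimal NumPy sketch of the vanilla CEM loop that `temp=None` corresponds to. It reuses the interface's parameter names for readability, but it is a hypothetical re-implementation, not the dcem library code: the function name `cem`, the `seed` argument, the cost-array shapes, and treating `init_sigma` as a per-dimension standard deviation are all assumptions of this sketch, and the iteration callback is omitted.

```python
import numpy as np


def cem(f, nx, n_batch=1, init_mu=0.0, init_sigma=1.0,
        n_sample=100, n_elite=10, n_iter=10, seed=0):
    """Vanilla CEM with hard elite selection (the temp=None case).

    Hypothetical NumPy sketch, not the dcem implementation. `f` maps
    candidates of shape (n_batch, n_sample, nx) to costs of shape
    (n_batch, n_sample); lower cost is better.
    """
    rng = np.random.default_rng(seed)
    mu = np.full((n_batch, nx), init_mu, dtype=float)
    # Assumption of this sketch: sigma is a per-dimension standard deviation.
    sigma = np.full((n_batch, nx), init_sigma, dtype=float)
    for _ in range(n_iter):
        # Sample n_sample candidates around each batch element's current mean.
        noise = rng.standard_normal((n_batch, n_sample, nx))
        X = mu[:, None, :] + sigma[:, None, :] * noise
        fX = f(X)
        # Hard top-k: keep the n_elite lowest-cost candidates per batch element.
        elite_idx = np.argsort(fX, axis=1)[:, :n_elite]
        elites = np.take_along_axis(X, elite_idx[:, :, None], axis=1)
        # Refit the Gaussian sampling distribution to the elites.
        mu = elites.mean(axis=1)
        sigma = elites.std(axis=1)
    return mu


# Usage: minimize a 2-D quadratic whose minimum is at `target`.
target = np.array([1.0, -2.0])


def f(X):
    return ((X - target) ** 2).sum(axis=-1)


mu = cem(f, nx=2, n_batch=1, init_sigma=3.0, n_sample=200, n_elite=20, n_iter=20)
```

The hard `argsort` selection is the step that blocks gradients; a finite temperature replaces it with a soft top-k.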

And our control interface is:

dcem_ctrl(
    obs, # Initial state
    plan_horizon, # Planning horizon for the control problem
    init_mu, # Initial control sequence mean, warm-starting can be done here
    init_sigma, # Initial variance around the control sequence
    n_sample, # Number of samples CEM uses in each iteration
    n_elite, # Number of elite CEM candidates in each iteration
    n_iter, # Number of CEM iterations
    n_ctrl, # Number of control dimensions
    lb, # Lower-bound of the control signal
    ub, # Upper-bound of the control signal
    temp, # DCEM temperature parameter, set to None for vanilla CEM
    rollout_cost, # Function that returns the cost of rolling out a control sequence
    iter_cb, # CEM iteration callback
)
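The control wrapper optimizes over a whole control sequence rather than a single point. As a hedged illustration of what that involves, the NumPy sketch below runs vanilla CEM over a `(plan_horizon, n_ctrl)` sequence, clips samples to `[lb, ub]`, and scores them with a rollout cost. The names `cem_plan` and `make_rollout_cost` and the toy single-integrator dynamics are hypothetical; this is not the code behind `dcem_ctrl`.

```python
import numpy as np


def cem_plan(rollout_cost, plan_horizon, n_ctrl, lb, ub, init_mu=None,
             init_sigma=1.0, n_sample=200, n_elite=20, n_iter=20, seed=0):
    """Hypothetical CEM planner over a control sequence (not dcem_ctrl itself).

    `rollout_cost` maps controls of shape (n_sample, plan_horizon, n_ctrl)
    to costs of shape (n_sample,).
    """
    rng = np.random.default_rng(seed)
    # A warm start (for example, the previous plan) can be passed as init_mu.
    if init_mu is None:
        mu = np.zeros((plan_horizon, n_ctrl))
    else:
        mu = np.array(init_mu, dtype=float)
    sigma = np.full((plan_horizon, n_ctrl), init_sigma)
    for _ in range(n_iter):
        U = mu + sigma * rng.standard_normal((n_sample, plan_horizon, n_ctrl))
        U = np.clip(U, lb, ub)  # keep every candidate inside the control bounds
        elites = U[np.argsort(rollout_cost(U))[:n_elite]]
        mu, sigma = elites.mean(axis=0), elites.std(axis=0)
    return mu  # in receding-horizon control, typically only mu[0] is executed


def make_rollout_cost(obs):
    # Hypothetical toy system: x_{t+1} = x_t + u_t, cost = sum_t x_{t+1}^2.
    def rollout_cost(U):
        states = obs + np.cumsum(U[:, :, 0], axis=1)
        return (states ** 2).sum(axis=1)
    return rollout_cost


cost_fn = make_rollout_cost(obs=0.8)
plan = cem_plan(cost_fn, plan_horizon=5, n_ctrl=1, lb=-1.0, ub=1.0)
```

Because every sample is clipped before the refit, the returned mean also stays within the bounds.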

Simple examples

examples.ipynb provides a light introduction to using our interface on simple optimization and control problems.

2d optimization

We first show how to use DCEM to optimize a 2-dimensional objective.

Next, we parameterize that objective and show how DCEM can update the objective to move its minimum to a desired location.
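The idea behind the parameterized example can be sketched outside the notebook: if the solver's output is a smooth function of the objective's parameters, an outer loss on that output can be minimized by gradient descent. The NumPy code below is a hypothetical stand-in that assumes a quadratic objective f_theta(x) = ||x - theta||^2, soft elite weights from an LML-style bisection on standardized costs, fixed sampling noise, and central finite differences in place of PyTorch backpropagation. The names and hyperparameters are ours, not the notebook's.

```python
import numpy as np


def soft_weights(costs, k, tau):
    """LML-style soft top-k weights sigmoid(x + nu) summing to k, with
    x = -z / tau for standardized costs z (lower cost -> larger weight)."""
    z = (costs - costs.mean()) / (costs.std() + 1e-6)
    x = -z / tau
    lo, hi = -x.max() - 50.0, -x.min() + 50.0
    for _ in range(60):  # bisection on the offset nu
        nu = 0.5 * (lo + hi)
        if np.exp(-np.logaddexp(0.0, -(x + nu))).sum() < k:
            lo = nu
        else:
            hi = nu
    return np.exp(-np.logaddexp(0.0, -(x + 0.5 * (lo + hi))))


def dcem_argmin(theta, eps, k=10, tau=0.5, init_sigma=2.0):
    """Soft-CEM estimate of argmin_x ||x - theta||^2 using fixed noise eps of
    shape (n_iter, n_sample, 2), so the output is a smooth function of theta."""
    mu, sigma = np.zeros(2), np.full(2, init_sigma)
    for noise in eps:
        X = mu + sigma * noise
        w = soft_weights(((X - theta) ** 2).sum(axis=1), k, tau)
        mu = (w[:, None] * X).sum(axis=0) / k
        sigma = np.sqrt((w[:, None] * (X - mu) ** 2).sum(axis=0) / k)
    return mu


rng = np.random.default_rng(0)
eps = rng.standard_normal((10, 100, 2))  # n_iter=10, n_sample=100, nx=2
target = np.array([1.0, -1.0])           # where the solver's output should land
theta = np.array([-2.0, 2.0])            # objective parameter being learned


def loss(theta):
    return ((dcem_argmin(theta, eps) - target) ** 2).sum()


initial_loss = loss(theta)
for _ in range(30):
    # Central finite differences stand in for backpropagating through the solver.
    grad = np.array([(loss(theta + h) - loss(theta - h)) / 2e-4
                     for h in 1e-4 * np.eye(2)])
    theta = theta - 0.25 * grad
final_loss = loss(theta)
```

Freezing the noise is what makes the solver's output a deterministic function of theta here; in PyTorch, drawing samples with the reparameterization trick lets gradients flow through them instead.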

Pendulum control

We show how to use CEM to solve a pendulum control problem. The solver can be made differentiable by setting a non-zero temperature for the soft top-k operation.
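The soft top-k that the temperature controls can be sketched as follows. DCEM replaces CEM's hard elite selection with the Limited Multi-Label (LML) projection onto {y in (0,1)^n : sum(y) = k}, whose solution has the form y_i = sigmoid(x_i + nu). The NumPy code below finds nu by bisection. It is a hypothetical re-implementation for illustration, not the LML package or dcem's internals; `soft_top_k`, `soft_cem_update`, and dividing costs by `tau` are assumptions of this sketch.

```python
import numpy as np


def sigmoid(z):
    return np.exp(-np.logaddexp(0.0, -z))  # numerically stable logistic function


def soft_top_k(x, k, n_bisect=100):
    """LML-style projection of 1-D scores x onto {y in (0,1)^n : sum(y) = k}.

    The solution has the form y = sigmoid(x + nu); nu is found by bisection.
    Requires 0 < k < len(x). Higher scores receive weights closer to 1.
    """
    lo, hi = -x.max() - 50.0, -x.min() + 50.0  # sum(y) is ~0 at lo and ~n at hi
    for _ in range(n_bisect):
        nu = 0.5 * (lo + hi)
        if sigmoid(x + nu).sum() < k:
            lo = nu
        else:
            hi = nu
    return sigmoid(x + 0.5 * (lo + hi))


def soft_cem_update(X, costs, k, tau):
    """One CEM refit with soft elite weights instead of a hard argsort.

    X: (n_sample, nx) candidates, costs: (n_sample,). A smaller tau gives
    weights closer to the hard top-k (an assumption of this sketch).
    """
    w = soft_top_k(-costs / tau, k)  # low cost -> weight near 1
    mu = (w[:, None] * X).sum(axis=0) / k  # weights sum to k: a weighted mean
    sigma = np.sqrt((w[:, None] * (X - mu) ** 2).sum(axis=0) / k)
    return mu, sigma


costs = np.array([3.0, 1.0, 2.0, 0.5, 5.0])
X = np.arange(5.0)[:, None]  # five 1-D candidates: 0, 1, 2, 3, 4
w_sharp = soft_top_k(-costs / 0.01, k=2)   # nearly the hard top-2 (indices 1 and 3)
w_smooth = soft_top_k(-costs / 10.0, k=2)  # weight spread across all candidates
mu, sigma = soft_cem_update(X, costs, k=2, tau=0.01)
```

Every weight is a smooth function of the costs, so the refit mean and standard deviation are differentiable in them; the hard top-k is recovered only in the low-temperature limit.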

Reproducing our experimental results

We provide the source code for our cartpole and regression experiments in the exp/ directory. We do not plan to open source our PlaNet and PPO experiments. One starting point is to combine an existing PyTorch PlaNet implementation, such as cross32768/PlaNet_PyTorch, with a PyTorch PPO implementation, such as ikostrikov/pytorch-a2c-ppo-acktr-gai, or a SAC implementation, such as denisyarats/pytorch_sac.

1D energy-based regression

The base experimental code for our 1D energy-based regression experiment is in regression.py. After running it, the results can be analyzed with regression-analysis.ipynb, which produces the plots for this experiment.

Embedding actions in the cartpole

The base experimental code for our cartpole action embedding experiment is in cartpole_emb.py. After running it, the results can be analyzed with cartpole_emb-analysis.ipynb, which produces the plots for this experiment.

Citations

If you find this repository helpful in your publications, please consider citing our paper.

@inproceedings{amos2020differentiable,
  title={{The Differentiable Cross-Entropy Method}},
  author={Brandon Amos and Denis Yarats},
  booktitle={ICML},
  year={2020}
}

Licensing

This repository is licensed under the CC BY-NC 4.0 License.