Initial implicit/truncated backward modes #29

Merged
merged 14 commits on Jan 19, 2022
add comments to backward_modes and add it to examples/README
bamos committed Jan 11, 2022
commit 656587de9dd1dd4fe49c838d42a6bd062f23e4f7
4 changes: 3 additions & 1 deletion examples/README.md
@@ -7,17 +7,19 @@ learn the cost weight as a function of pose.
problem, inspired by [Bhardwaj et al. 2020](https://arxiv.org/pdf/1907.09591.pdf).
- tactile_pose_estimation.py: Is an example of how to set up learning models for
tactile pose estimation, as described in [Sodhi et al. 2021](https://arxiv.org/abs/1705.10664)
- backward_modes.py: Shows how to compute derivatives through Theseus solves and switch between backward modes.

These can be run from your root `theseus` directory by doing

python examples/state_estimation_2d.py
python examples/motion_planning_2d.py
python examples/tactile_pose_estimation.py
python examples/backward_modes.py

The motion planning and tactile estimation examples require a `hydra` installation, which you can
obtain by running:

pip install hydra-core

Any outputs generated by these scripts will be saved under `examples/outputs`. You can
change this directory by passing the CLI option `hydra.run.dir=<your_directory>`
29 changes: 23 additions & 6 deletions examples/backward_modes.py
@@ -1,12 +1,11 @@
#!/usr/bin/env python3
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# This example illustrates the three backward modes (FULL, IMPLICIT, and TRUNCATED)
# on a problem fitting a quadratic to data.

import torch
import theseus as th
@@ -20,6 +19,7 @@
torch.manual_seed(0)


# Sample from a quadratic y = ax^2 + b*noise
def generate_data(num_points=10, a=1.0, b=0.5, noise_factor=0.01):
    data_x = torch.rand((1, num_points))
    noise = torch.randn((1, num_points)) * noise_factor
@@ -29,13 +29,20 @@ def generate_data(num_points=10, a=1.0, b=0.5, noise_factor=0.01):

num_points = 10
data_x, data_y = generate_data(num_points)

x = th.Variable(data_x.requires_grad_(), name="x")
y = th.Variable(data_y.requires_grad_(), name="y")

# We now attempt to recover the quadratic from the data with
# theseus by formulating it as a non-linear least squares
# optimization problem.
# We write the model as \hat y = \hat a x^2 + \hat b,
# where the parameters \hat a and \hat b are just `a` and `b`
# in the code here.
a = th.Vector(1, name="a")
b = th.Vector(1, name="b")
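For reference, with data points (x_i, y_i) and the model \hat y_i = \hat a x_i^2 + \hat b, the non-linear least-squares problem being set up here can be written as

\min_{\hat a,\, \hat b} \; \sum_{i=1}^{N} \big( y_i - (\hat a x_i^2 + \hat b) \big)^2

up to a constant scaling of the cost.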


# The error is y - \hat y
def quad_error_fn(optim_vars, aux_vars):
    a, b = optim_vars
    x, y = aux_vars
@@ -44,6 +51,8 @@ def quad_error_fn(optim_vars, aux_vars):
    return err


# We then use Theseus to optimize \hat a and \hat b so that
# y = \hat y for all datapoints
optim_vars = [a, b]
aux_vars = [x, y]
cost_function = th.AutoDiffCostFunction(
@@ -75,12 +84,18 @@ def quad_error_fn(optim_vars, aux_vars):
    backward_mode=th.BackwardMode.FULL,
)
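The objective and optimizer construction feeding this forward call sits in the region collapsed above. A minimal sketch of that setup, assuming Theseus' standard Objective/GaussNewton/TheseusLayer pattern; the iteration count, step size, and initial guesses below are illustrative assumptions, not the PR's actual values.

# Hedged sketch of the collapsed setup: wrap the cost function in an
# objective, choose a non-linear least-squares optimizer, and build the
# differentiable layer that the forward calls in this file use.
objective = th.Objective()
objective.add(cost_function)
optimizer = th.GaussNewton(
    objective,
    max_iterations=15,  # assumed value
    step_size=0.5,      # assumed value
)
theseus_optim = th.TheseusLayer(optimizer)

# Inputs are keyed by the names given to the Variables above.
theseus_inputs = {
    "a": torch.ones((1, 1)),  # assumed initial guess
    "b": torch.ones((1, 1)),  # assumed initial guess
    "x": data_x,
    "y": data_y,
}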

# The quadratic \hat y is now fit and we can also use Theseus
# to obtain the adjoint derivatives of \hat a with respect
# to other inputs or hyper-parameters, such as the data itself.
# Here we compute the derivative of \hat a with respect to the data,
# i.e. \partial a / \partial x using the full backward mode.
da_dx = torch.autograd.grad(updated_inputs["a"], data_x, retain_graph=True)[0].squeeze()

print("--- backward_mode=FULL")
print(da_dx.numpy())


# We can also compute this using implicit differentiation by calling
# forward again and changing the backward_mode flag.
updated_inputs, info = theseus_optim.forward(
    theseus_inputs,
    track_best_solution=True,
@@ -92,6 +107,7 @@ def quad_error_fn(optim_vars, aux_vars):
print("\n--- backward_mode=IMPLICIT")
print(da_dx.numpy())

# We can also use truncated unrolling to compute the derivative:
updated_inputs, info = theseus_optim.forward(
    theseus_inputs,
    track_best_solution=True,
@@ -106,6 +122,7 @@ def quad_error_fn(optim_vars, aux_vars):
print(da_dx.numpy())
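The tail of the TRUNCATED call is collapsed above. A sketch of how the complete call could look, assuming a backward_num_iterations keyword controls the truncation length; the keyword name and the value 5 are assumptions, not confirmed by the visible diff.

# Hedged sketch of the collapsed TRUNCATED call and the derivative recompute.
updated_inputs, info = theseus_optim.forward(
    theseus_inputs,
    track_best_solution=True,
    verbose=False,
    backward_mode=th.BackwardMode.TRUNCATED,
    backward_num_iterations=5,  # assumed keyword name and value
)
da_dx = torch.autograd.grad(updated_inputs["a"], data_x, retain_graph=True)[0].squeeze()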


# Next we numerically check the derivative
def fit_x(data_x_np):
theseus_inputs["x"] = (
torch.from_numpy(data_x_np).float().clone().requires_grad_().unsqueeze(0)
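The body of fit_x and the numerical comparison are cut off here. A sketch of how such a check could be finished with central differences, assuming fit_x returns the fitted value of a as a float; numeric_grad is a hypothetical helper, not code from the original file.

import numpy as np  # assumed available for the numerical check


def numeric_grad(f, x0, eps=1e-4):
    # Central-difference estimate of the gradient of a scalar function f at x0.
    grad = np.zeros_like(x0)
    for i in range(x0.size):
        x_plus, x_minus = x0.copy(), x0.copy()
        x_plus[i] += eps
        x_minus[i] -= eps
        grad[i] = (f(x_plus) - f(x_minus)) / (2.0 * eps)
    return grad


# Compare against the analytic da/dx values printed above.
da_dx_numeric = numeric_grad(fit_x, data_x.detach().numpy().squeeze())
print("\n--- numerical check of da/dx")
print(da_dx_numeric)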