Initial implicit/truncated backward modes (facebookresearch#29)
* Initial WIP commit of implicit/truncated backward modes

* spacing

* add numdifftools requirement

* fix mypy and GPU issues

* import BackwardMode as part of the main theseus module

* add ValueError messages

* add comments to backward_modes and add it to examples/README

* Remove error_increase_induces

* move converged_indices from the info back into the optimization loop

* fix gradient scaling for facebookresearch#39

* update backward tests

* add type hints/remove unused track_best_solution

* remove erroneous update
bamos authored Jan 19, 2022
1 parent a6f847d commit a150672
Showing 7 changed files with 474 additions and 51 deletions.
4 changes: 3 additions & 1 deletion examples/README.md
@@ -7,17 +7,19 @@ learn the cost weight as a function of pose.
problem, inspired by [Bhardwaj et al. 2020](https://arxiv.org/pdf/1907.09591.pdf).
- tactile_pose_estimation.py: Is an example of how to set up learning models for
tactile pose estimation, as described in [Sodhi et al. 2021](https://arxiv.org/abs/1705.10664)
- backward_modes.py: Shows how to compute derivatives through Theseus solves and switch between backward modes.

These can be run from your root `theseus` directory by doing

python examples/state_estimation_2d.py
python examples/motion_planning_2d.py
python examples/tactile_pose_estimation.py
python examples/backward_modes.py

The motion planning and tactile estimation examples require a `hydra` installation, which you can
obtain by running:

pip install hydra-core

Any outputs generated by these scripts will be saved under `examples/outputs`. You can
change this directory by passing the CLI option `hydra.run.dir=<your_directory>`
204 changes: 204 additions & 0 deletions examples/backward_modes.py
@@ -0,0 +1,204 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# This example illustrates the three backward modes (FULL, IMPLICIT, and TRUNCATED)
# on a problem fitting a quadratic to data.

import torch
import theseus as th

import numpy as np
import numdifftools as nd

from collections import defaultdict
import time

torch.manual_seed(0)


# Sample data points from the quadratic y = a * x^2 + b, with additive Gaussian noise
def generate_data(num_points=10, a=1.0, b=0.5, noise_factor=0.01):
data_x = torch.rand((1, num_points))
noise = torch.randn((1, num_points)) * noise_factor
data_y = a * data_x.square() + b + noise
return data_x, data_y


num_points = 10
data_x, data_y = generate_data(num_points)
x = th.Variable(data_x.requires_grad_(), name="x")
y = th.Variable(data_y.requires_grad_(), name="y")

# We now attempt to recover the quadratic from the data with
# theseus by formulating it as a non-linear least squares
# optimization problem.
# We write the model as \hat y = \hat a x^2 + \hat b,
# where the parameters \hat a and \hat b are just `a` and `b`
# in the code here.
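#
# Concretely, a sketch of the objective (assuming the default unit cost weight,
# since none is passed below): Theseus will minimize the sum of squared residuals
#     min_{a, b}  sum_i ( y_i - (a * x_i^2 + b) )^2 .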
a = th.Vector(1, name="a")
b = th.Vector(1, name="b")


# The error is y - \hat y
def quad_error_fn(optim_vars, aux_vars):
a, b = optim_vars
x, y = aux_vars
est = a.data * x.data.square() + b.data
err = y.data - est
return err
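
# Note: the residual returned by quad_error_fn has shape (batch_size, num_points),
# which is why num_points is passed as the cost dimension to AutoDiffCostFunction
# below.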


# We then use Theseus to optimize \hat a and \hat b so that
# y = \hat y for all data points
optim_vars = [a, b]
aux_vars = [x, y]
cost_function = th.AutoDiffCostFunction(
optim_vars, # type: ignore
quad_error_fn,
num_points,
aux_vars=aux_vars,
name="quadratic_cost_fn",
)
objective = th.Objective()
objective.add(cost_function)
optimizer = th.GaussNewton(
objective,
max_iterations=15,
step_size=0.5,
)

theseus_inputs = {
"a": 2 * torch.ones((1, 1)).requires_grad_(),
"b": torch.ones((1, 1)).requires_grad_(),
"x": data_x,
"y": data_y,
}
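
# TheseusLayer wraps the optimizer as a differentiable torch module: calling its
# forward() runs the inner optimization, and the returned tensors stay connected
# to the autograd graph according to the chosen backward mode.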
theseus_optim = th.TheseusLayer(optimizer)
updated_inputs, info = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.FULL,
)

# The quadratic \hat y is now fit, and we can also use Theseus
# to obtain the adjoint derivatives of \hat a with respect
# to other inputs or hyper-parameters, such as the data itself.
# Here we compute the derivative of \hat a with respect to the data,
# i.e. \partial \hat a / \partial x, using the full backward mode.
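# In FULL mode the solve is differentiated by unrolling, i.e. by backpropagating
# through every iteration the optimizer ran in the forward pass, so the backward
# cost grows with the number of forward iterations.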
da_dx = torch.autograd.grad(updated_inputs["a"], data_x, retain_graph=True)[0].squeeze()

print("--- backward_mode=FULL")
print(da_dx.numpy())

# We can also compute this using implicit differentiation by calling
# forward again and changing the backward_mode flag.
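# Roughly, IMPLICIT mode differentiates the optimality conditions at the returned
# solution (an implicit-function-theorem argument), so the backward cost does not
# depend on how many forward iterations were run.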
updated_inputs, info = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.IMPLICIT,
)

da_dx = torch.autograd.grad(updated_inputs["a"], data_x, retain_graph=True)[0].squeeze()
print("\n--- backward_mode=IMPLICIT")
print(da_dx.numpy())

# We can also use truncated unrolling to compute the derivative:
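# Roughly, TRUNCATED mode backpropagates through only the last
# backward_num_iterations optimizer steps, trading some gradient accuracy for a
# cheaper backward pass.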
updated_inputs, info = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.TRUNCATED,
backward_num_iterations=5,
)

da_dx = torch.autograd.grad(updated_inputs["a"], data_x, retain_graph=True)[0].squeeze()

print("\n--- backward_mode=TRUNCATED, backward_num_iterations=5")
print(da_dx.numpy())


# Next we numerically check the derivative
def fit_x(data_x_np):
theseus_inputs["x"] = (
torch.from_numpy(data_x_np).float().clone().requires_grad_().unsqueeze(0)
)
updated_inputs, info = theseus_optim.forward(
theseus_inputs, track_best_solution=True, verbose=False
)
return updated_inputs["a"].item()
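
# nd.Gradient from numdifftools approximates d fit_x / d data_x with finite
# differences; fit_x re-runs the full Theseus solve for each perturbed input, so
# this check is slow but independent of the backward mode.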


data_x_np = data_x.detach().clone().numpy()
dfit_x = nd.Gradient(fit_x)
g = dfit_x(data_x_np)

print("\n--- Numeric derivative")
print(g)
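
# As an optional sanity check (a sketch; assumes g and da_dx flatten to compatible
# 1-D arrays as produced above), compare the numeric gradient with the
# truncated-mode gradient computed earlier:
print("max |numeric - analytic| =", np.abs(np.asarray(g).ravel() - da_dx.numpy().ravel()).max())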

theseus_inputs["x"] = data_x

# Next we run several trials (n_trials + 1 in total) of these computations
# and report the runtime of the forward and backward passes.
n_trials = 10
times = defaultdict(list)
for trial in range(n_trials + 1):
start = time.time()
updated_inputs, info = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.FULL,
)
times["fwd"].append(time.time() - start)

start = time.time()
da_dx = torch.autograd.grad(updated_inputs["a"], data_x, retain_graph=True)[
0
].squeeze()
times["bwd"].append(time.time() - start)

updated_inputs, info = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.IMPLICIT,
)
start = time.time()
da_dx = torch.autograd.grad(updated_inputs["a"], data_x, retain_graph=True)[
0
].squeeze()
times["bwd_impl"].append(time.time() - start)

updated_inputs, info = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.TRUNCATED,
backward_num_iterations=5,
)
start = time.time()
da_dx = torch.autograd.grad(updated_inputs["a"], data_x, retain_graph=True)[
0
].squeeze()
times["bwd_trunc"].append(time.time() - start)


print("\n=== Runtimes")
k = "fwd"
print(f"Forward: {np.mean(times[k]):.2e} s +/- {np.std(times[k]):.2e} s")

k = "bwd"
print(f"Backward (FULL): {np.mean(times[k]):.2e} s +/- {np.std(times[k]):.2e} s")

k = "bwd_impl"
print(f"Backward (IMPLICIT) {np.mean(times[k]):.2e} s +/- {np.std(times[k]):.2e} s")

k = "bwd_trunc"
print(
f"Backward (TRUNCATED, 5 steps) {np.mean(times[k]):.2e} s +/- {np.std(times[k]):.2e} s"
)
3 changes: 2 additions & 1 deletion requirements/main.txt
@@ -3,4 +3,5 @@ scipy>=1.5.3
scikit-sparse>=0.4.5
# torch>=1.7.1 will do separate install instructions for now (CUDA dependent)
pytest>=6.2.1
numdifftools>=0.9.40
pybind11>=2.7.1
1 change: 1 addition & 0 deletions theseus/__init__.py
@@ -41,6 +41,7 @@
NonlinearLeastSquares,
NonlinearOptimizerParams,
NonlinearOptimizerStatus,
BackwardMode,
)
from .theseus_layer import TheseusLayer

1 change: 1 addition & 0 deletions theseus/optimizer/nonlinear/__init__.py
@@ -7,6 +7,7 @@
from .levenberg_marquardt import LevenbergMarquardt
from .nonlinear_least_squares import NonlinearLeastSquares
from .nonlinear_optimizer import (
BackwardMode,
NonlinearOptimizer,
NonlinearOptimizerParams,
NonlinearOptimizerStatus,