Changed TheseusLayer.forward() to receive optimizer_kwargs as a single dict #45

Merged (4 commits, Jan 24, 2022)
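In short, optimizer-specific options are no longer passed to `TheseusLayer.forward()` as individual keyword arguments; they are grouped into a single `optimizer_kwargs` dict. A minimal before/after sketch, assuming a `theseus_optim = th.TheseusLayer(...)` and a `theseus_inputs` dict set up as in examples/backward_modes.py:

```python
import theseus as th

# Before this change: optimizer options were separate keyword arguments.
updated_inputs, info = theseus_optim.forward(
    theseus_inputs,
    track_best_solution=True,
    verbose=False,
    backward_mode=th.BackwardMode.FULL,
)

# After this change: the same options are grouped into one optimizer_kwargs dict,
# so forward() itself only takes the input data plus this single argument.
updated_inputs, info = theseus_optim.forward(
    theseus_inputs,
    optimizer_kwargs={
        "track_best_solution": True,
        "verbose": False,
        "backward_mode": th.BackwardMode.FULL,
    },
)
```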
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -15,7 +15,7 @@ repos:
rev: v0.910
hooks:
- id: mypy
additional_dependencies: [torch==1.9.0, tokenize-rt==3.2.0, types-PyYAML]
additional_dependencies: [torch==1.9.0, tokenize-rt==3.2.0, types-PyYAML, types-mock]
args: [--no-strict-optional, --ignore-missing-imports]
exclude: setup.py

56 changes: 34 additions & 22 deletions examples/backward_modes.py
@@ -79,9 +79,11 @@ def quad_error_fn(optim_vars, aux_vars):
theseus_optim = th.TheseusLayer(optimizer)
updated_inputs, info = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.FULL,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.FULL,
},
)

# The quadratic \hat y is now fit and we can also use Theseus
@@ -98,9 +100,11 @@ def quad_error_fn(optim_vars, aux_vars):
# forward again and changing the backward_mode flag.
updated_inputs, info = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.IMPLICIT,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.IMPLICIT,
},
)

da_dx = torch.autograd.grad(updated_inputs["a"], data_x, retain_graph=True)[0].squeeze()
@@ -110,10 +114,12 @@ def quad_error_fn(optim_vars, aux_vars):
# We can also use truncated unrolling to compute the derivative:
updated_inputs, info = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.TRUNCATED,
backward_num_iterations=5,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.TRUNCATED,
"backward_num_iterations": 5,
},
)

da_dx = torch.autograd.grad(updated_inputs["a"], data_x, retain_graph=True)[0].squeeze()
@@ -127,8 +133,8 @@ def fit_x(data_x_np):
theseus_inputs["x"] = (
torch.from_numpy(data_x_np).float().clone().requires_grad_().unsqueeze(0)
)
updated_inputs, info = theseus_optim.forward(
theseus_inputs, track_best_solution=True, verbose=False
updated_inputs, _ = theseus_optim.forward(
theseus_inputs, optimizer_kwargs={"track_best_solution": True, "verbose": False}
)
return updated_inputs["a"].item()

@@ -150,9 +156,11 @@ def fit_x(data_x_np):
start = time.time()
updated_inputs, info = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.FULL,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.FULL,
},
)
times["fwd"].append(time.time() - start)

@@ -164,9 +172,11 @@ def fit_x(data_x_np):

updated_inputs, info = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.IMPLICIT,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.IMPLICIT,
},
)
start = time.time()
da_dx = torch.autograd.grad(updated_inputs["a"], data_x, retain_graph=True)[
@@ -176,10 +186,12 @@ def fit_x(data_x_np):

updated_inputs, info = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.TRUNCATED,
backward_num_iterations=5,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.TRUNCATED,
"backward_num_iterations": 5,
},
)
start = time.time()
da_dx = torch.autograd.grad(updated_inputs["a"], data_x, retain_graph=True)[
10 changes: 7 additions & 3 deletions examples/motion_planning_2d.py
@@ -149,9 +149,13 @@ def run_learning_loop(cfg):

_, info = motion_planner.layer.forward(
planner_inputs,
track_best_solution=True,
verbose=cfg.verbose,
**cfg.optim_params.kwargs,
optimizer_kwargs={
**{
"track_best_solution": True,
"verbose": cfg.verbose,
},
**cfg.optim_params.kwargs,
},
)
if cfg.do_learning and cfg.include_imitation_loss:
solution_trajectory = motion_planner.get_trajectory()
8 changes: 5 additions & 3 deletions examples/state_estimation_2d.py
@@ -293,10 +293,12 @@ def cost_weights_model():
print("Initial error:", objective.error_squared_norm().mean().item())

for i in range(inner_loop_iters):
theseus_inputs, info = state_estimator.forward(
theseus_inputs, _ = state_estimator.forward(
theseus_inputs,
track_best_solution=True,
verbose=epoch % 10 == 0,
optimizer_kwargs={
"track_best_solution": True,
"verbose": epoch % 10 == 0,
},
)
theseus_inputs = run_model(
mode_,
4 changes: 3 additions & 1 deletion examples/tactile_pose_estimation.py
@@ -325,7 +325,9 @@ def run_learning_loop(cfg):
(sdf_tensor.data).repeat(batch_size, 1, 1).to(device)
)

theseus_inputs, _ = theseus_layer.forward(theseus_inputs, verbose=True)
theseus_inputs, _ = theseus_layer.forward(
theseus_inputs, optimizer_kwargs={"verbose": True}
)

obj_poses_opt = theg.get_tactile_poses_from_values(
batch_size=batch_size,
1 change: 1 addition & 0 deletions requirements/dev.txt
@@ -5,3 +5,4 @@ nox==2020.8.22
pre-commit>=2.9.2
isort>=5.6.4
types-PyYAML==5.4.3
types-mock>=4.0.8
3 changes: 2 additions & 1 deletion requirements/main.txt
@@ -4,4 +4,5 @@ scikit-sparse>=0.4.5
# torch>=1.7.1 will do separate install instructions for now (CUDA dependent)
pytest>=6.2.1
numdifftools>=0.9.40
pybind11>=2.7.1
pybind11>=2.7.1
mock>=4.0.3
1 change: 1 addition & 0 deletions theseus/core/objective.py
@@ -422,6 +422,7 @@ def _get_batch_size(batch_sizes: Sequence[int]) -> int:
return max_bs
raise ValueError("Provided data tensors must be broadcastable.")

input_data = input_data or {}
for var_name, data in input_data.items():
if data.ndim < 2:
raise ValueError(
40 changes: 22 additions & 18 deletions theseus/optimizer/nonlinear/tests/test_backwards.py
@@ -69,8 +69,9 @@ def fit_x(data_x_np):
theseus_inputs["x"] = (
torch.from_numpy(data_x_np).float().clone().requires_grad_().unsqueeze(0)
)
updated_inputs, info = theseus_optim.forward(
theseus_inputs, track_best_solution=True, verbose=False
updated_inputs, _ = theseus_optim.forward(
theseus_inputs,
optimizer_kwargs={"track_best_solution": True, "verbose": False},
)
return updated_inputs["a"].item()

@@ -79,39 +80,42 @@ def fit_x(data_x_np):
da_dx_numeric = torch.from_numpy(dfit_x(data_x_np)).float()

theseus_inputs["x"] = data_x
updated_inputs, info = theseus_optim.forward(
updated_inputs, _ = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.FULL,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.FULL,
},
)
da_dx_full = torch.autograd.grad(updated_inputs["a"], data_x, retain_graph=True)[
0
].squeeze()
assert torch.allclose(da_dx_numeric, da_dx_full, atol=1e-3)

updated_inputs, info = theseus_optim.forward(
updated_inputs, _ = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.IMPLICIT,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.IMPLICIT,
},
)
da_dx_implicit = torch.autograd.grad(
updated_inputs["a"], data_x, retain_graph=True
)[0].squeeze()
assert torch.allclose(da_dx_numeric, da_dx_implicit, atol=1e-4)

updated_inputs, info = theseus_optim.forward(
updated_inputs, _ = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.TRUNCATED,
backward_num_iterations=5,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.TRUNCATED,
"backward_num_iterations": 5,
},
)
da_dx_truncated = torch.autograd.grad(
updated_inputs["a"], data_x, retain_graph=True
)[0].squeeze()
assert torch.allclose(da_dx_numeric, da_dx_truncated, atol=1e-4)


test_backwards()
91 changes: 84 additions & 7 deletions theseus/tests/test_theseus_layer.py
@@ -5,6 +5,7 @@

import math

import mock
import pytest # noqa: F401
import torch
import torch.nn as nn
@@ -214,7 +215,7 @@ def _run_optimizer_test(
with torch.no_grad():
input_values = {"coefficients": torch.ones(batch_size, 2, device=device) * 0.75}
target_vars, _ = layer_ref.forward(
input_values, verbose=verbose, **optimizer_kwargs
input_values, optimizer_kwargs={**optimizer_kwargs, **{"verbose": verbose}}
)

# Now create another that starts with a random cost weight and use backpropagation to
@@ -275,7 +276,9 @@ def cost_weight_fn():
}

with torch.no_grad():
pred_vars, info = layer_to_learn.forward(input_values, **optimizer_kwargs)
pred_vars, info = layer_to_learn.forward(
input_values, optimizer_kwargs=optimizer_kwargs
)
loss0 = F.mse_loss(
pred_vars["coefficients"], target_vars["coefficients"]
).item()
@@ -294,7 +297,7 @@ def cost_weight_fn():
cost_weight_param_name: cost_weight_fn(),
}
pred_vars, info = layer_to_learn.forward(
input_values, verbose=verbose, **optimizer_kwargs
input_values, optimizer_kwargs={**optimizer_kwargs, **{"verbose": verbose}}
)
assert not (
(info.status == th.NonlinearOptimizerStatus.START)
@@ -433,14 +436,14 @@ def test_send_to_device():
xs = torch.linspace(0, 10, num_points).repeat(batch_size, 1)
ys = model(xs, torch.ones(batch_size, 2))

objective = create_qf_theseus_layer(xs, ys)
layer = create_qf_theseus_layer(xs, ys)
input_values = {"coefficients": torch.ones(batch_size, 2, device=device) * 0.5}
with torch.no_grad():
if device != "cpu":
with pytest.raises(RuntimeError):
objective.forward(input_values)
objective.to(device)
output_values, _ = objective.forward(input_values)
layer.forward(input_values)
layer.to(device)
output_values, _ = layer.forward(input_values)
for k, v in output_values.items():
assert v.device == input_values[k].device

@@ -470,3 +473,77 @@ def _do_check(layer_, optimizer_):
optimizer = th.GaussNewton(objective, th.CholeskyDenseSolver)
objective.erase(cost_functions[0].name)
_do_check(layer, optimizer)


def test_pass_optimizer_kwargs():
# Create the dataset to fit, model(x) is the true data generation process
batch_size = 16
num_points = 10
xs = torch.linspace(0, 10, num_points).repeat(batch_size, 1)
ys = model(xs, torch.ones(batch_size, 2))

layer = create_qf_theseus_layer(
xs,
ys,
nonlinear_optimizer_cls=th.GaussNewton,
linear_solver_cls=th.CholmodSparseSolver,
)
layer.to("cpu")
input_values = {"coefficients": torch.ones(batch_size, 2) * 0.5}
for tbs in [True, False]:
_, info = layer.forward(
input_values, optimizer_kwargs={"track_best_solution": tbs}
)
if tbs:
assert (
isinstance(info.best_solution, dict)
and "coefficients" in info.best_solution
)
else:
assert info.best_solution is None

# Pass invalid backward mode to trigger exception
with pytest.raises(ValueError):
layer.forward(input_values, optimizer_kwargs={"backward_mode": -1})

# Now test that compute_delta() args are passed correctly
# Patch compute_delta() to receive args we control
def _mock_compute_delta(cls, fake_arg=None, **kwargs):
if fake_arg is not None:
raise ValueError
return layer.optimizer.linear_solver.solve()

with mock.patch.object(th.GaussNewton, "compute_delta", _mock_compute_delta):
layer_2 = create_qf_theseus_layer(xs, ys)
layer_2.forward(input_values)
# If fake_arg is passed through correctly, the mocked compute_delta will raise ValueError
with pytest.raises(ValueError):
layer_2.forward(input_values, {"fake_arg": True})


def test_no_layer_kwargs():
# Create the dataset to fit, model(x) is the true data generation process
batch_size = 16
num_points = 10
xs = torch.linspace(0, 10, num_points).repeat(batch_size, 1)
ys = model(xs, torch.ones(batch_size, 2))

layer = create_qf_theseus_layer(
xs,
ys,
nonlinear_optimizer_cls=th.GaussNewton,
linear_solver_cls=th.CholmodSparseSolver,
)
layer.to("cpu")
input_values = {"coefficients": torch.ones(batch_size, 2) * 0.5}

# Trying a few variations of aux_vars. In general, no kwargs should be accepted
# beyond input_data and optimizer_kwargs, but I'm not sure how to test for this
with pytest.raises(TypeError):
layer.forward(input_values, aux_vars=None)

with pytest.raises(TypeError):
layer.forward(input_values, aux_variables=None)

with pytest.raises(TypeError):
layer.forward(input_values, auxiliary_vars=None)