Added a fallback for implicit mode in case Gauss-Newton fails in the last step. (facebookresearch#579)
luisenp authored Jul 18, 2023
1 parent abc8201 commit e910656
Showing 2 changed files with 36 additions and 2 deletions.
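Background for the change: with backward_mode="implicit", the optimizer's final iteration takes an undamped Gauss-Newton step so that gradients can be computed implicitly at the fixed point. If the Gauss-Newton system is singular there, that last solve raises a RuntimeError even when the damped Levenberg-Marquardt iterations succeeded. Below is a minimal standalone sketch of the failure mode and of the fallback this commit adds, written in plain PyTorch rather than against Theseus internals (J, r, and damping are illustrative names, not the library's):

import torch

# Rank-deficient Jacobian, analogous to the [1, 0] diagonal cost weight in the
# test below: the second residual dimension carries no information.
J = torch.tensor([[1.0, 0.0], [0.0, 0.0]])
r = torch.ones(2)
JTJ = J.T @ J  # singular (rank 1), so the undamped Gauss-Newton solve fails

try:
    # Gauss-Newton step: solve (J^T J) delta = -J^T r
    delta = torch.linalg.solve(JTJ, -J.T @ r)
except RuntimeError:
    # Fallback: the damped (Levenberg-Marquardt-style) system stays invertible.
    damping = 0.1
    delta = torch.linalg.solve(JTJ + damping * torch.eye(2), -J.T @ r)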
29 changes: 29 additions & 0 deletions tests/theseus_tests/optimizer/nonlinear/test_backwards.py
@@ -212,3 +212,32 @@ def error_fn(optim_vars, aux_vars):
     # Equality should hold exactly even in floating point
     # because of how the derivatives cancel
     assert da_dx.item() == 1.5
+
+
+def test_implicit_fallback_linear_solver():
+    # Create a singular system that can only be solved if damping is added
+    x = th.Vector(2, name="x")
+    t = th.Vector(2, name="t")
+
+    o = th.Objective()
+    w = th.DiagonalCostWeight(torch.FloatTensor([1, 0]).view(1, 2))
+    o.add(th.Difference(x, t, w, name="cost"))
+    opt = th.TheseusLayer(th.LevenbergMarquardt(o, max_iterations=5))
+
+    input_dict = {"x": torch.ones(1, 2), "t": torch.zeros(1, 2)}
+
+    # __strict_implicit_final_gn__ = True shows that this problem raises an error
+    with pytest.raises(RuntimeError):
+        opt.forward(
+            input_dict,
+            optimizer_kwargs={
+                "damping": 0.1,
+                "backward_mode": "implicit",
+                "__strict_implicit_final_gn__": True,
+            },
+        )
+    # No error is raised by default
+    opt.forward(
+        input_dict,
+        optimizer_kwargs={"damping": 0.1, "backward_mode": "implicit"},
+    )
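Why this objective is singular: the DiagonalCostWeight of [1, 0] zeroes out the second residual dimension, so the Gauss-Newton normal matrix J^T J drops to rank 1 and the undamped final solve fails, while the damping of 0.1 keeps every Levenberg-Marquardt step well posed (the same situation as the standalone sketch above).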
9 changes: 7 additions & 2 deletions theseus/optimizer/nonlinear/nonlinear_least_squares.py
@@ -124,10 +124,15 @@ def _optimize_loop(
                # Well, technically full Newton, this is hard to implement and GN
                # is working well so far.
                #
-               # We also need to detach the hessian when computing
+               # As shown above, we also need to detach the hessian when computing
                # linearization above, as higher order terms introduce errors
                # in the derivative if the fixed point is not accurate enough.
-               delta = self.linear_solver.solve()
+               try:
+                   delta = self.linear_solver.solve()
+               except RuntimeError as e:  # fall back to a regular step if GN fails
+                   if kwargs.get("__strict_implicit_final_gn__", False):
+                       raise e
+                   delta = self.compute_delta(**kwargs)
            else:
                delta = self.compute_delta(**kwargs)
        except RuntimeError as run_err:
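By default the fallback is silent. Callers who want the previous fail-fast behavior, as the test above does, can opt in through the intentionally private-looking keyword; a usage sketch, where layer and inputs stand for any TheseusLayer and its input dictionary:

layer.forward(
    inputs,
    optimizer_kwargs={
        "backward_mode": "implicit",
        "__strict_implicit_final_gn__": True,  # re-raise if the final GN solve fails
    },
)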
