Initial implicit/truncated backward modes #29

Merged: 14 commits, Jan 19, 2022
move converged_indices from the info back into the optimization loop
bamos committed Jan 11, 2022
commit 58c5610dde7ae320c02e1f3813bd5fff644f2a51
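In short, this commit removes the per-batch convergence mask from NonlinearOptimizerInfo and instead threads it through the optimization loop as a local tensor. Below is a minimal sketch of the resulting control flow, not the actual Theseus implementation; `error_fn`, `step_fn`, and `check_convergence` are hypothetical stand-ins for the real objective and optimizer methods:

```python
import torch

def optimize_loop_sketch(error_fn, step_fn, check_convergence, max_iters):
    # error_fn() -> per-batch squared-error tensor of shape (batch_size,)
    last_err = error_fn()
    # converged_indices is now a loop-local mask, not a field on the info object
    converged_indices = torch.zeros_like(last_err).bool()
    converged_iter = torch.zeros_like(last_err, dtype=torch.long)
    for _ in range(max_iters):
        step_fn(converged_indices)  # entries already converged can be skipped
        err = error_fn()
        # count one more iteration only for entries that had not yet converged
        converged_iter += 1 - converged_indices.long()
        converged_indices = check_convergence(err, last_err)
        if converged_indices.all():
            break  # nothing else will happen at this point
        last_err = err
    return converged_indices, converged_iter
```

Keeping the mask local avoids stale convergence state on the info object between calls; the info retains only aggregate bookkeeping such as `converged_iter`.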
theseus/optimizer/nonlinear/nonlinear_optimizer.py (23 changes: 12 additions & 11 deletions)
@@ -45,7 +45,6 @@ class NonlinearOptimizerStatus(Enum):
 @dataclass
 class NonlinearOptimizerInfo(OptimizerInfo):
     converged_iter: torch.Tensor
-    converged_indices: torch.Tensor
     best_iter: torch.Tensor
     err_history: Optional[torch.Tensor]
     last_err: torch.Tensor
@@ -115,7 +114,6 @@ def _init_info(
         with torch.no_grad():
             last_err = self.objective.error_squared_norm() / 2
         best_err = last_err.clone() if track_best_solution else None
-        converged_indices = torch.zeros_like(last_err).bool()
         if verbose:
             err_history = (
                 torch.ones(self.objective.batch_size, self.params.max_iterations + 1)
@@ -127,7 +125,6 @@
             err_history = None
         return NonlinearOptimizerInfo(
             best_solution=self._maybe_init_best_solution(do_init=track_best_solution),
-            converged_indices=converged_indices,
             last_err=last_err,
             best_err=best_err,
             status=np.array(
@@ -143,8 +140,9 @@ def _update_info(
         info: NonlinearOptimizerInfo,
         current_iter: int,
         err: torch.Tensor,
+        converged_indices: torch.Tensor,
     ):
-        info.converged_iter += 1 - info.converged_indices.long()
+        info.converged_iter += 1 - converged_indices.long()
         if info.err_history is not None:
             assert err.grad_fn is None
             info.err_history[:, current_iter + 1] = err.clone().cpu()
@@ -161,9 +159,9 @@ def _update_info(

         info.best_err = torch.minimum(info.best_err, err)

-        info.converged_indices = self._check_convergence(err, info.last_err)
+        converged_indices = self._check_convergence(err, info.last_err)
         info.status[
-            np.array(info.converged_indices.detach().cpu())
+            np.array(converged_indices.detach().cpu())
         ] = NonlinearOptimizerStatus.CONVERGED

     # loop for the iterative optimizer
@@ -177,7 +175,7 @@ def _optimize_loop(
         force_update,
         **kwargs,
     ):
-
+        converged_indices = torch.zeros_like(info.last_err).bool()
         for it_ in range(start_iter, start_iter + num_iter):
             # do optimizer step
             self.linear_solver.linearization.linearize()
@@ -200,19 +198,23 @@
                 return info

             self.retract_and_update_variables(
-                delta, info.converged_indices, force_update=force_update
+                delta, converged_indices, force_update=force_update
             )

             # check for convergence
             with torch.no_grad():
                 err = self.objective.error_squared_norm() / 2
-                self._update_info(info, it_, err)
+                self._update_info(info, it_, err, converged_indices)
                 if verbose:
                     print(
                         f"Nonlinear optimizer. Iteration: {it_+1}. "
                         f"Error: {err.mean().item()}"
                     )
-                if info.converged_indices.all():
+                converged_indices = self._check_convergence(err, info.last_err)
+                info.status[
+                    converged_indices.cpu().numpy()
+                ] = NonlinearOptimizerStatus.CONVERGED
+                if converged_indices.all():
                     break  # nothing else will happen at this point
                 info.last_err = err
@@ -293,7 +295,6 @@ def _optimize_impl(
             | (grad_loop_info.status[M] == NonlinearOptimizerStatus.CONVERGED)
         )
         info.status[M] = grad_loop_info.status[M]
-        info.converged_indices[M] = grad_loop_info.converged_indices[M]
         info.converged_iter[M] = (
             info.converged_iter[M] + grad_loop_info.converged_iter[M]
         )
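For reference, the status update in the new loop indexes a NumPy array of per-batch statuses with the boolean convergence mask, moving the mask off the GPU first. A self-contained toy example of that pattern, where `Status` is a hypothetical stand-in for NonlinearOptimizerStatus:

```python
import numpy as np
import torch
from enum import Enum

class Status(Enum):  # stand-in for NonlinearOptimizerStatus
    START = 0
    CONVERGED = 1

# one status per batch entry, stored as a NumPy object array
status = np.array([Status.START] * 4)
# boolean convergence mask, possibly computed on GPU tensors
converged_indices = torch.tensor([True, False, True, False])
# move the mask to CPU/NumPy before indexing the object array
status[converged_indices.cpu().numpy()] = Status.CONVERGED
print(status)  # entries 0 and 2 are now Status.CONVERGED
```

The `.cpu().numpy()` hop is needed because `info.status` lives in NumPy while the convergence mask is a torch tensor that may reside on an accelerator.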