Initial implicit/truncated backward modes #29

Merged
merged 14 commits on Jan 19, 2022
Remove error_increase_indices
bamos committed Jan 11, 2022
commit cd5a311dda5de0abc0ae76c704acf7f9b7994876
7 changes: 0 additions & 7 deletions in theseus/optimizer/nonlinear/nonlinear_optimizer.py
@@ -46,7 +46,6 @@ class NonlinearOptimizerStatus(Enum):
 class NonlinearOptimizerInfo(OptimizerInfo):
     converged_iter: torch.Tensor
     converged_indices: torch.Tensor
-    error_increase_indices: torch.Tensor
     best_iter: torch.Tensor
     err_history: Optional[torch.Tensor]
     last_err: torch.Tensor
@@ -117,7 +116,6 @@ def _init_info(
         last_err = self.objective.error_squared_norm() / 2
         best_err = last_err.clone() if track_best_solution else None
         converged_indices = torch.zeros_like(last_err).bool()
-        error_increase_indices = torch.zeros_like(last_err).bool()
         if verbose:
             err_history = (
                 torch.ones(self.objective.batch_size, self.params.max_iterations + 1)
@@ -132,7 +130,6 @@ def _init_info(
             converged_indices=converged_indices,
             last_err=last_err,
             best_err=best_err,
-            error_increase_indices=error_increase_indices,
             status=np.array(
                 [NonlinearOptimizerStatus.START] * self.objective.batch_size
             ),
@@ -165,7 +162,6 @@ def _update_info(
         info.best_err = torch.minimum(info.best_err, err)

         info.converged_indices = self._check_convergence(err, info.last_err)
-        info.converged_indices &= ~info.error_increase_indices
         info.status[
             np.array(info.converged_indices.detach().cpu())
         ] = NonlinearOptimizerStatus.CONVERGED
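
For context, after this hunk the per-batch convergence mask in `_update_info` comes from the error test alone, with no separate error-increase mask to clear. A minimal, self-contained sketch of that behavior; `check_convergence` and its tolerance are hypothetical stand-ins for `self._check_convergence`, not the actual Theseus code:

import torch

# Hypothetical stand-in for self._check_convergence: a sample counts as
# converged when its squared error stopped decreasing meaningfully.
def check_convergence(err: torch.Tensor, last_err: torch.Tensor,
                      rel_tol: float = 1e-8) -> torch.Tensor:
    return (last_err - err).abs() < rel_tol * last_err

last_err = torch.tensor([1.0, 1.0, 1.0])
err = torch.tensor([1.0 - 1e-10, 0.5, 1.2])
converged_indices = check_convergence(err, last_err)
# With error_increase_indices gone, a sample whose error increased
# (third entry) simply fails the test instead of being tracked separately.
print(converged_indices)  # tensor([ True, False, False])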
@@ -289,9 +285,6 @@ def _optimize_impl(
                 **kwargs,
             )

-            if any(grad_loop_info.error_increase_indices):
-                raise RuntimeError("Grad loop error increased")
-
             # Merge the converged status into the info from the detached loop,
             # and for now, don't update the best err tracking or best solution.
             M = info.status == NonlinearOptimizerStatus.MAX_ITERATIONS
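
A hedged sketch of the status merge that the surviving comment describes: only samples the detached loop left at MAX_ITERATIONS pick up the status computed by the gradient loop. The enum values mirror the diff; `merge_status` itself is illustrative, not the Theseus implementation:

from enum import Enum
import numpy as np

class NonlinearOptimizerStatus(Enum):
    START = 0
    CONVERGED = 1
    MAX_ITERATIONS = 2

# Illustrative merge: samples already CONVERGED in the detached loop keep
# their status; those at MAX_ITERATIONS are overwritten by the grad loop.
def merge_status(detached: np.ndarray, grad_loop: np.ndarray) -> np.ndarray:
    M = detached == NonlinearOptimizerStatus.MAX_ITERATIONS
    merged = detached.copy()
    merged[M] = grad_loop[M]
    return merged

detached = np.array([NonlinearOptimizerStatus.CONVERGED,
                     NonlinearOptimizerStatus.MAX_ITERATIONS])
grad_loop = np.array([NonlinearOptimizerStatus.MAX_ITERATIONS,
                      NonlinearOptimizerStatus.CONVERGED])
print(merge_status(detached, grad_loop))
# Both entries end up CONVERGED: the first kept it, the second inherited it.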