Moved back BackwardMode, _merge_infos, and _split_backward_iters to base class.
luisenp committed Feb 2, 2023
1 parent d2fc18b commit cf9ae43
Showing 3 changed files with 105 additions and 104 deletions.
3 changes: 2 additions & 1 deletion theseus/optimizer/nonlinear/__init__.py
@@ -5,8 +5,9 @@
from .dogleg import Dogleg
from .gauss_newton import GaussNewton
from .levenberg_marquardt import LevenbergMarquardt
from .nonlinear_least_squares import BackwardMode, NonlinearLeastSquares
from .nonlinear_least_squares import NonlinearLeastSquares
from .nonlinear_optimizer import (
BackwardMode,
NonlinearOptimizer,
NonlinearOptimizerInfo,
NonlinearOptimizerParams,
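Net effect of this change: `BackwardMode` is now re-exported from `nonlinear_optimizer` rather than `nonlinear_least_squares`, so subpackage imports keep working. A minimal sketch (illustrative only, not part of this commit):

from theseus.optimizer.nonlinear import BackwardMode, NonlinearLeastSquares

# Both names still resolve after the move; only BackwardMode's defining module changed.
print(BackwardMode.__module__)  # expected: theseus.optimizer.nonlinear.nonlinear_optimizer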
102 changes: 1 addition & 101 deletions theseus/optimizer/nonlinear/nonlinear_least_squares.py
@@ -5,10 +5,8 @@

import abc
import warnings
from enum import Enum
from typing import Any, Callable, Dict, NoReturn, Optional, Tuple, Type, Union

import numpy as np
import torch

from theseus.core import Objective
@@ -19,36 +17,13 @@
LUCudaSparseSolver,
)
from .nonlinear_optimizer import (
BackwardMode,
NonlinearOptimizer,
NonlinearOptimizerInfo,
NonlinearOptimizerStatus,
)


class BackwardMode(Enum):
UNROLL = 0
IMPLICIT = 1
TRUNCATED = 2
DLM = 3

@staticmethod
def resolve(key: Union[str, "BackwardMode"]) -> "BackwardMode":
if isinstance(key, BackwardMode):
return key

if not isinstance(key, str):
raise ValueError("Backward mode must be th.BackwardMode or string.")

try:
backward_mode = BackwardMode[key.upper()]
except KeyError:
raise ValueError(
f"Unrecognized backward mode {key}. "
f"Valid choices are unroll, implicit, truncated, dlm."
)
return backward_mode


EndIterCallbackType = Callable[
["NonlinearOptimizer", NonlinearOptimizerInfo, torch.Tensor, int], NoReturn
]
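For reference, a callback matching the `EndIterCallbackType` alias above could look like the sketch below (assuming, as the alias suggests, it receives the optimizer, the current info, the last step tensor, and the iteration index; illustrative only):

def log_step(optimizer, info, delta, it):
    # Print the norm of the step applied at this iteration.
    print(f"iteration {it}: step norm = {delta.norm().item():.4e}")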
@@ -118,58 +93,6 @@ def __init__(
self.ordering = self.linear_solver.linearization.ordering
self._tmp_optim_vars = tuple(v.copy(new_name=v.name) for v in self.ordering)

# Modifies the (no grad) info in place to add data of grad loop info
def _merge_infos(
self,
grad_loop_info: NonlinearOptimizerInfo,
num_no_grad_iters: int,
num_grad_iters: int,
info: NonlinearOptimizerInfo,
):
total_iters = num_no_grad_iters + num_grad_iters
# we add + 1 to all indices to account for the initial values
info_idx = slice(num_no_grad_iters + 1, total_iters + 1)
grad_info_idx = slice(1, num_grad_iters + 1)
# Concatenate error histories
if info.err_history is not None:
info.err_history[:, info_idx] = grad_loop_info.err_history[:, grad_info_idx]
if info.state_history is not None:
for var in self.objective.optim_vars.values():
info.state_history[var.name][
..., info_idx
] = grad_loop_info.state_history[var.name][..., grad_info_idx]

# Merge best solution and best error
if info.best_solution is not None:
best_solution = {}
best_err_no_grad = info.best_err
best_err_grad = grad_loop_info.best_err
idx_no_grad = (best_err_no_grad < best_err_grad).cpu().view(-1, 1)
best_err = torch.minimum(best_err_no_grad, best_err_grad)
for var_name in info.best_solution:
sol_no_grad = info.best_solution[var_name]
sol_grad = grad_loop_info.best_solution[var_name]
best_solution[var_name] = torch.where(
idx_no_grad, sol_no_grad, sol_grad
)
info.best_solution = best_solution
info.best_err = best_err

# Merge the converged status into the info from the detached loop,
M = info.status == NonlinearOptimizerStatus.MAX_ITERATIONS
assert np.all(
(grad_loop_info.status[M] == NonlinearOptimizerStatus.MAX_ITERATIONS)
| (grad_loop_info.status[M] == NonlinearOptimizerStatus.CONVERGED)
)
info.status[M] = grad_loop_info.status[M]
info.converged_iter[M] = (
info.converged_iter[M] + grad_loop_info.converged_iter[M]
)
# If it didn't converge in either loop, remove the misleading converged_iter value
info.converged_iter[
M & (grad_loop_info.status == NonlinearOptimizerStatus.MAX_ITERATIONS)
] = -1

def _error_metric(
self,
input_tensors: Optional[Dict[str, torch.Tensor]] = None,
@@ -295,29 +218,6 @@ def _optimize_loop(
] = NonlinearOptimizerStatus.MAX_ITERATIONS
return iters_done

# Returns how many iterations to do with and without autograd
def _split_backward_iters(self, **kwargs) -> Tuple[int, int]:
if kwargs["backward_mode"] == BackwardMode.TRUNCATED:
if "backward_num_iterations" not in kwargs:
raise ValueError("backward_num_iterations expected but not received.")
if kwargs["backward_num_iterations"] > self.params.max_iterations:
warnings.warn(
f"Input backward_num_iterations "
f"(={kwargs['backward_num_iterations']}) > "
f"max_iterations (={self.params.max_iterations}). "
f"Using backward_num_iterations=max_iterations."
)
backward_num_iters = min(
kwargs["backward_num_iterations"], self.params.max_iterations
)
else:
backward_num_iters = {
BackwardMode.UNROLL: self.params.max_iterations,
BackwardMode.DLM: self.params.max_iterations,
BackwardMode.IMPLICIT: 1,
}[kwargs["backward_mode"]]
return backward_num_iters, self.params.max_iterations - backward_num_iters

# `track_best_solution` keeps a **detached** copy (as in no gradient info)
# of the best variables found, but it is optional to avoid unnecessary copying
# if this is not needed
104 changes: 102 additions & 2 deletions theseus/optimizer/nonlinear/nonlinear_optimizer.py
@@ -5,9 +5,10 @@

import abc
import math
import warnings
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Optional
from typing import Dict, Optional, Tuple, Union

import numpy as np
import torch
@@ -16,6 +17,30 @@
from theseus.optimizer import Optimizer, OptimizerInfo


class BackwardMode(Enum):
UNROLL = 0
IMPLICIT = 1
TRUNCATED = 2
DLM = 3

@staticmethod
def resolve(key: Union[str, "BackwardMode"]) -> "BackwardMode":
if isinstance(key, BackwardMode):
return key

if not isinstance(key, str):
raise ValueError("Backward mode must be th.BackwardMode or string.")

try:
backward_mode = BackwardMode[key.upper()]
except KeyError:
raise ValueError(
f"Unrecognized backward mode {key}. "
f"Valid choices are unroll, implicit, truncated, dlm."
)
return backward_mode
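A quick usage sketch of the resolver above (illustrative, not part of the diff); it accepts enum members or case-insensitive strings:

mode = BackwardMode.resolve("truncated")       # BackwardMode.TRUNCATED (case-insensitive lookup)
same = BackwardMode.resolve(BackwardMode.DLM)  # enum members are returned unchanged
# BackwardMode.resolve("newton") raises ValueError (unrecognized mode)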


@dataclass
class NonlinearOptimizerParams:
abs_err_tolerance: float
@@ -52,7 +77,7 @@ class NonlinearOptimizerInfo(OptimizerInfo):
# Base class for all nonlinear optimizers.
# Provides methods useful for bookkeeping during optimization.
# Subclasses need to provide `error_metric()`, which computes the error to
# to optimized (e.g., sum of squared costs for NLLS), in addition to
# to be optimized (e.g., sum of squared costs for NLLS), in addition to
# `_optimize_impl()` as defined by the base `Optimizer` class.
class NonlinearOptimizer(Optimizer, abc.ABC):
_MAX_ALL_REJECT_ATTEMPTS = 3
@@ -192,3 +217,78 @@ def _update_info(
@abc.abstractmethod
def _optimize_impl(self, **kwargs) -> OptimizerInfo:
pass

# Modifies the (no grad) info in place to add data of grad loop info
def _merge_infos(
self,
grad_loop_info: NonlinearOptimizerInfo,
num_no_grad_iters: int,
num_grad_iters: int,
info: NonlinearOptimizerInfo,
):
total_iters = num_no_grad_iters + num_grad_iters
# we add + 1 to all indices to account for the initial values
info_idx = slice(num_no_grad_iters + 1, total_iters + 1)
grad_info_idx = slice(1, num_grad_iters + 1)
# Concatenate error histories
if info.err_history is not None:
info.err_history[:, info_idx] = grad_loop_info.err_history[:, grad_info_idx]
if info.state_history is not None:
for var in self.objective.optim_vars.values():
info.state_history[var.name][
..., info_idx
] = grad_loop_info.state_history[var.name][..., grad_info_idx]

# Merge best solution and best error
if info.best_solution is not None:
best_solution = {}
best_err_no_grad = info.best_err
best_err_grad = grad_loop_info.best_err
idx_no_grad = (best_err_no_grad < best_err_grad).cpu().view(-1, 1)
best_err = torch.minimum(best_err_no_grad, best_err_grad)
for var_name in info.best_solution:
sol_no_grad = info.best_solution[var_name]
sol_grad = grad_loop_info.best_solution[var_name]
best_solution[var_name] = torch.where(
idx_no_grad, sol_no_grad, sol_grad
)
info.best_solution = best_solution
info.best_err = best_err

# Merge the converged status into the info from the detached loop,
M = info.status == NonlinearOptimizerStatus.MAX_ITERATIONS
assert np.all(
(grad_loop_info.status[M] == NonlinearOptimizerStatus.MAX_ITERATIONS)
| (grad_loop_info.status[M] == NonlinearOptimizerStatus.CONVERGED)
)
info.status[M] = grad_loop_info.status[M]
info.converged_iter[M] = (
info.converged_iter[M] + grad_loop_info.converged_iter[M]
)
# If it didn't converge in either loop, remove the misleading converged_iter value
info.converged_iter[
M & (grad_loop_info.status == NonlinearOptimizerStatus.MAX_ITERATIONS)
] = -1
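To make the index bookkeeping in `_merge_infos` concrete, a small worked example with hypothetical iteration counts:

# Hypothetical counts: 8 detached (no-grad) iterations followed by 2 grad iterations.
num_no_grad_iters, num_grad_iters = 8, 2
total_iters = num_no_grad_iters + num_grad_iters          # 10
info_idx = slice(num_no_grad_iters + 1, total_iters + 1)  # columns 9..10 of the merged history
grad_info_idx = slice(1, num_grad_iters + 1)              # columns 1..2 of the grad-loop history
# Column 0 of each history holds the initial values, hence the +1 offsets.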

# Returns how many iterations to do with and without autograd
def _split_backward_iters(self, **kwargs) -> Tuple[int, int]:
if kwargs["backward_mode"] == BackwardMode.TRUNCATED:
if "backward_num_iterations" not in kwargs:
raise ValueError("backward_num_iterations expected but not received.")
if kwargs["backward_num_iterations"] > self.params.max_iterations:
warnings.warn(
f"Input backward_num_iterations "
f"(={kwargs['backward_num_iterations']}) > "
f"max_iterations (={self.params.max_iterations}). "
f"Using backward_num_iterations=max_iterations."
)
backward_num_iters = min(
kwargs["backward_num_iterations"], self.params.max_iterations
)
else:
backward_num_iters = {
BackwardMode.UNROLL: self.params.max_iterations,
BackwardMode.DLM: self.params.max_iterations,
BackwardMode.IMPLICIT: 1,
}[kwargs["backward_mode"]]
return backward_num_iters, self.params.max_iterations - backward_num_iters
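As an illustration of the resulting split (hypothetical max_iterations = 10; the first element is the number of iterations run with autograd, the second the number run detached):

# backward_mode=UNROLL                               -> (10, 0)
# backward_mode=DLM                                  -> (10, 0)
# backward_mode=IMPLICIT                             -> (1, 9)
# backward_mode=TRUNCATED, backward_num_iterations=3 -> (3, 7)
# TRUNCATED requests above max_iterations are capped, with a warning.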

