Moved linear solver usage from NonlinearOptim class to NLLS subclass.
luisenp committed Jan 19, 2023
1 parent 3004b6d, commit ff1dcb1
Showing 2 changed files with 14 additions and 22 deletions.
theseus/optimizer/nonlinear/nonlinear_least_squares.py (14 changes: 10 additions & 4 deletions)
@@ -97,20 +97,26 @@ def __init__(
        step_size: float = 1.0,
        **kwargs,
    ):
        linear_solver_cls = linear_solver_cls or CholeskyDenseSolver
        super().__init__(
            objective,
            linear_solver_cls=linear_solver_cls,
            vectorize=vectorize,
            linearization_cls=linearization_cls,
            linearization_kwargs=linearization_kwargs,
            linear_solver_kwargs=linear_solver_kwargs,
            abs_err_tolerance=abs_err_tolerance,
            rel_err_tolerance=rel_err_tolerance,
            max_iterations=max_iterations,
            step_size=step_size,
            **kwargs,
        )
        linear_solver_cls = linear_solver_cls or CholeskyDenseSolver
        linear_solver_kwargs = linear_solver_kwargs or {}
        self.linear_solver = linear_solver_cls(
            objective,
            linearization_cls=linearization_cls,
            linearization_kwargs=linearization_kwargs,
            **linear_solver_kwargs,
        )
        self.ordering = self.linear_solver.linearization.ordering
        self._tmp_optim_vars = tuple(v.copy(new_name=v.name) for v in self.ordering)

    # Modifies the (no grad) info in place to add data of grad loop info
    def _merge_infos(
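For reference, a minimal usage sketch of the API after this change (not part of the commit; the toy objective, the names x, target, err_fn, and sq_err, the tensor shapes, and the optimize keyword are illustrative assumptions about the theseus API at this point in time). Callers still pass linear_solver_cls to an NLLS subclass such as GaussNewton, which now builds self.linear_solver itself instead of delegating that to NonlinearOptimizer:

    import torch
    import theseus as th

    # Illustrative toy setup, not part of the commit: a single 2-dof variable.
    x = th.Vector(2, name="x")
    target = torch.ones(1, 2)

    def err_fn(optim_vars, aux_vars):
        # Residual pulling x toward the target; its dim matches the 2 below.
        return optim_vars[0].tensor - target

    objective = th.Objective()
    objective.add(th.AutoDiffCostFunction([x], err_fn, 2, name="sq_err"))

    # GaussNewton subclasses NonlinearLeastSquares; after this commit the
    # subclass resolves the default solver class (CholeskyDenseSolver) and
    # instantiates it, so optimizer.linear_solver and optimizer.ordering
    # now live on the NLLS side of the class hierarchy.
    optimizer = th.GaussNewton(
        objective,
        linear_solver_cls=th.CholeskyDenseSolver,
        max_iterations=20,
        step_size=1.0,
    )
    objective.update({"x": torch.zeros(1, 2)})
    info = optimizer.optimize(track_best_solution=True)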
theseus/optimizer/nonlinear/nonlinear_optimizer.py (22 changes: 4 additions & 18 deletions)
@@ -7,14 +7,13 @@
import math
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, Type
from typing import Dict, Optional

import numpy as np
import torch

from theseus.core import Objective
from theseus.optimizer import Linearization, Optimizer, OptimizerInfo
from theseus.optimizer.linear import LinearSolver
from theseus.optimizer import Optimizer, OptimizerInfo


@dataclass
@@ -61,31 +60,18 @@ class NonlinearOptimizer(Optimizer, abc.ABC):
    def __init__(
        self,
        objective: Objective,
        linear_solver_cls: Type[LinearSolver],
        *args,
        vectorize: bool = False,
        linearization_cls: Optional[Type[Linearization]] = None,
        linearization_kwargs: Optional[Dict[str, Any]] = None,
        linear_solver_kwargs: Optional[Dict[str, Any]] = None,
        abs_err_tolerance: float = 1e-8,
        rel_err_tolerance: float = 1e-5,
        max_iterations: int = 20,
        step_size: float = 1.0,
        **kwargs,
    ):
        super().__init__(objective, vectorize=vectorize, **kwargs)
        linear_solver_kwargs = linear_solver_kwargs or {}
        self.linear_solver = linear_solver_cls(
            objective,
            linearization_cls=linearization_cls,
            linearization_kwargs=linearization_kwargs,
            **linear_solver_kwargs,
        )
        self.ordering = self.linear_solver.linearization.ordering
        self.params = NonlinearOptimizerParams(
            abs_err_tolerance, rel_err_tolerance, max_iterations, step_size
        )
        self._tmp_optim_vars = tuple(v.copy(new_name=v.name) for v in self.ordering)

    def set_params(self, **kwargs):
        self.params.update(kwargs)
@@ -115,7 +101,7 @@ def _maybe_init_best_solution(
        if not do_init:
            return None
        solution_dict = {}
        for var in self.ordering:
        for var in self.objective.optim_vars.values():
            solution_dict[var.name] = var.tensor.detach().clone().cpu()
        return solution_dict

@@ -191,7 +177,7 @@ def _update_info(
        assert info.best_err is not None
        good_indices = err < info.best_err
        info.best_iter[good_indices] = current_iter
        for var in self.ordering:
        for var in self.objective.optim_vars.values():
            info.best_solution[var.name][good_indices] = (
                var.tensor.detach().clone()[good_indices].cpu()
            )
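Both loops above now read optimization variables straight from objective.optim_vars, since self.ordering moved to the NonlinearLeastSquares subclass. A standalone sketch of this best-solution bookkeeping pattern (plain PyTorch with made-up names and shapes, not theseus code):

    import torch

    # Stand-ins for the objective's optimization variables (batch size 4).
    optim_vars = {
        "x": torch.randn(4, 3, requires_grad=True),
        "y": torch.randn(4, 2, requires_grad=True),
    }

    # Initialization: snapshot every variable on CPU, detached from the
    # autograd graph, mirroring _maybe_init_best_solution.
    best_solution = {
        name: t.detach().clone().cpu() for name, t in optim_vars.items()
    }

    # Update step: overwrite only the batch entries whose error improved,
    # mirroring _update_info.
    err = torch.tensor([0.5, 2.0, 0.1, 3.0])
    best_err = torch.ones(4)
    good_indices = err < best_err
    for name, t in optim_vars.items():
        best_solution[name][good_indices] = t.detach().clone()[good_indices].cpu()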
