Add Differentiable CEM solver #329

Merged
merged 35 commits on Mar 23, 2023

Commits
31c277d
first implementation of dcem solver
dishank-b Oct 13, 2022
245d80c
dcem optimizer working
dishank-b Oct 14, 2022
11962c6
minor changes in dcem, added tests
dishank-b Oct 17, 2022
d069ee6
online calculatino of n_batch
dishank-b Oct 18, 2022
24d2fa9
initializing LML layer
dishank-b Oct 20, 2022
5632588
dcem working backwards tutorial 2
dishank-b Oct 28, 2022
655ba19
better vectorization in solver
dishank-b Oct 31, 2022
63a04cf
vectoriztion in solve method for itr loop in optimizer class
dishank-b Oct 31, 2022
fd93941
forward pass working perfectly with current set of hyperparams with b…
dishank-b Nov 3, 2022
976fb08
dcem backward unit test passed for one setting
dishank-b Nov 3, 2022
7547504
DCEM backward unit test working, not tested with leo, insanely slow w…
dishank-b Nov 4, 2022
3a22e09
refactoring, removed DcemSolver in favour of solve method in DCEM opt…
dishank-b Nov 4, 2022
89c8b39
correcting circle ci errors
dishank-b Nov 7, 2022
9bf03c4
corrected lml url for requirements.txt
dishank-b Nov 7, 2022
a1064a2
corrected reuirements.txt for lml
dishank-b Nov 7, 2022
c21dc69
removing -e from requirements
dishank-b Nov 8, 2022
c809e94
changing setup.py to install lml
dishank-b Nov 8, 2022
2bbf2db
changing setup.py to add lml
dishank-b Nov 8, 2022
4f3788c
commented dcem_test
dishank-b Nov 8, 2022
0f69df6
unit test working with both gpu, cpu with even less 10-2 error thres …
dishank-b Nov 9, 2022
0592f6f
testing with lml_eps=10-4
dishank-b Nov 10, 2022
7d68639
Revert "testing with lml_eps=10-4"
dishank-b Nov 10, 2022
044e881
reverting the common.py file
dishank-b Nov 10, 2022
769d483
dcem working, name changed from DCem to DCEM
dishank-b Mar 6, 2023
126ee47
removed _all_solve function and chnaged _solve name to _CEM_step
dishank-b Mar 6, 2023
f2ccf4b
changed dcem objective to use error_metric and edit __init files
dishank-b Mar 6, 2023
74a2a5d
dcem working, added dcem tutorial
dishank-b Mar 6, 2023
679dc3b
add lml as third party
dishank-b Mar 6, 2023
fab68be
or black pre-commit hook
dishank-b Mar 6, 2023
f4b345a
removeing abs in loss function since model chnaged test_theseus layer
dishank-b Mar 7, 2023
97c94b6
changes in test_theseus to make it compatible with DCEM
dishank-b Mar 9, 2023
bc32139
minor changes:styling, nits, typehinting, etc.
dishank-b Mar 17, 2023
50ea767
reverted minor changes, corrected test_theseus_layer argument logic f…
dishank-b Mar 20, 2023
ee75e95
using scatter for indexes with temp=None in dcem
dishank-b Mar 21, 2023
947e026
final changes, removing half-complete changes before merge
dishank-b Mar 22, 2023
2 changes: 1 addition & 1 deletion requirements/dev.txt
@@ -8,4 +8,4 @@ types-PyYAML==5.4.3
mock>=4.0.3
types-mock>=4.0.8
Sphinx==5.0.2
sphinx-rtd-theme==1.0.0
sphinx-rtd-theme==1.0.0
1 change: 1 addition & 0 deletions requirements/main.txt
@@ -4,3 +4,4 @@ scikit-sparse>=0.4.5
pytest>=6.2.1
pybind11>=2.7.1
functorch==0.2.1 # > 0.2.1 will install torch1.13, which breaks CUDA 10.2
semantic-version==2.10.0
27 changes: 21 additions & 6 deletions tests/test_theseus_layer.py
@@ -138,7 +138,9 @@ def error_fn(optim_vars, aux_vars):
linear_solver_cls=linear_solver_cls,
max_iterations=max_iterations,
)
assert isinstance(optimizer.linear_solver, linear_solver_cls)

if hasattr(optimizer, "linear_solver"):
assert isinstance(optimizer.linear_solver, linear_solver_cls)
assert not objective.vectorized

if force_vectorization:
@@ -203,7 +205,7 @@ def _run_optimizer_test(
print(
f"testing for optimizer {nonlinear_optimizer_cls.__name__}, "
f"cost weight modeled as {cost_weight_model}, "
f"linear solver {linear_solver_cls.__name__} "
f"linear solver {linear_solver_cls.__name__ if linear_solver_cls is not None else None} "
f"learning method {learning_method}"
)

@@ -236,7 +238,9 @@ def _run_optimizer_test(
max_iterations=max_iterations,
)
layer_ref.to(device)
initial_coefficients = torch.ones(batch_size, 2, device=device) * 0.75
initial_coefficients = torch.ones(batch_size, 2, device=device) * torch.tensor(
[0.75, 7], device=device
)
with torch.no_grad():
input_values = {"coefficients": initial_coefficients}
target_vars, _ = layer_ref.forward(
@@ -306,6 +310,7 @@ def cost_weight_fn():
pred_vars, info = layer_to_learn.forward(
input_values, optimizer_kwargs=optimizer_kwargs
)

loss0 = F.mse_loss(
pred_vars["coefficients"], target_vars["coefficients"]
).item()
@@ -335,6 +340,7 @@ def cost_weight_fn():
},
},
)

assert not (
(info.status == th.NonlinearOptimizerStatus.START)
| (info.status == th.NonlinearOptimizerStatus.FAIL)
@@ -378,7 +384,7 @@ def cost_weight_fn():
optimizer.step()

loss_ratio = mse_loss.item() / loss0
print("Loss: ", mse_loss.item(), ". Loss ratio: ", loss_ratio)
print("Iteration: ", i, "Loss: ", mse_loss.item(), ". Loss ratio: ", loss_ratio)
if loss_ratio < loss_ratio_target:
solved = True
break
@@ -404,7 +410,7 @@ def _solver_can_be_run(lin_solver_cls):


@pytest.mark.parametrize(
"nonlinear_optim_cls", [th.Dogleg, th.GaussNewton, th.LevenbergMarquardt]
"nonlinear_optim_cls", [th.Dogleg, th.GaussNewton, th.LevenbergMarquardt, th.DCEM]
)
@pytest.mark.parametrize(
"lin_solver_cls",
@@ -436,15 +442,24 @@ def test_backward(
and learning_method not in "leo",
},
th.Dogleg: {},
th.DCEM: {},
}[nonlinear_optim_cls]
if learning_method == "leo":
if lin_solver_cls not in [th.CholeskyDenseSolver, th.LUDenseSolver]:
# other solvers don't support sampling from system's covariance
return
if nonlinear_optim_cls == th.Dogleg:
return # LEO not working with Dogleg
if nonlinear_optim_cls == th.DCEM:
return
if nonlinear_optim_cls == th.Dogleg and lin_solver_cls != th.CholeskyDenseSolver:
return
if nonlinear_optim_cls == th.DCEM:
if lin_solver_cls != th.CholeskyDenseSolver:
return
else:
lin_solver_cls = None

# test both vectorization on/off
force_vectorization = torch.rand(1).item() > 0.5
_run_optimizer_test(
@@ -455,7 +470,7 @@
use_learnable_error=use_learnable_error,
force_vectorization=force_vectorization,
learning_method=learning_method,
max_iterations=10,
max_iterations=10 if nonlinear_optim_cls != th.DCEM else 50,
lr=1.0
if nonlinear_optim_cls == th.Dogleg and not torch.cuda.is_available()
else 0.075,
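For readability, the DCEM-specific branching added to `test_backward` above can be summarized as a small sketch. This is only a paraphrase of the diff, not additional test code: the helper name `_select_dcem_config` is made up for illustration, it covers only the DCEM cases (the Dogleg and LEO branches for other optimizers are omitted), and it assumes `import theseus as th` as in the test module.

```python
import theseus as th

def _select_dcem_config(nonlinear_optim_cls, lin_solver_cls, learning_method):
    # Illustrative paraphrase of the DCEM branches in test_backward.
    if nonlinear_optim_cls == th.DCEM:
        if learning_method == "leo":
            return None  # LEO is skipped for DCEM
        if lin_solver_cls != th.CholeskyDenseSolver:
            return None  # run DCEM under a single linear-solver parametrization only
        # DCEM is sampling-based and uses no linear solver, but needs more iterations.
        return None, 50
    return lin_solver_cls, 10
```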
1 change: 1 addition & 0 deletions theseus/__init__.py
@@ -86,6 +86,7 @@
)
from .optimizer.nonlinear import ( # usort: skip
BackwardMode,
DCEM,
Dogleg,
GaussNewton,
LevenbergMarquardt,
2 changes: 2 additions & 0 deletions theseus/optimizer/nonlinear/__init__.py
@@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .dcem import DCEM
from .dogleg import Dogleg
from .gauss_newton import GaussNewton
from .levenberg_marquardt import LevenbergMarquardt
243 changes: 243 additions & 0 deletions theseus/optimizer/nonlinear/dcem.py
@@ -0,0 +1,243 @@
from typing import Optional, Union, List, Dict

import numpy as np
import torch
from torch.distributions import Normal

from theseus.third_party.lml import LML
from theseus.core.objective import Objective
from theseus.optimizer import OptimizerInfo
from theseus.optimizer.variable_ordering import VariableOrdering

from .nonlinear_optimizer import (
NonlinearOptimizer,
BackwardMode,
NonlinearOptimizerInfo,
NonlinearOptimizerStatus,
EndIterCallbackType,
)


class DCEM(NonlinearOptimizer):
"""
DCEM optimizer for nonlinear optimization using sampling based techniques.
The optimizer can be really sensitive to hypermeter tuning. Here are few tuning
hints:
1. If have to lower the max_iterations, then increase the n_sample.
2. The higher the n_sample, the slowly with variance of samples will decrease.
3. The higher the n_sample, more the chances of optimum being in the elite set.
4. The higher the n_elite, the slower is convergence, but more accurate it might
be, but would need more iterations. n_elite= 5 is good enough for most cases.
"""

def __init__(
self,
objective: Objective,
vectorize: bool = False,
max_iterations: int = 50,
n_sample: int = 100,
n_elite: int = 5,
temp: float = 1.0,
init_sigma: Union[float, torch.Tensor] = 1.0,
lb: Optional[float] = None,
ub: Optional[float] = None,
lml_verbose: bool = False,
lml_eps: float = 1e-3,
normalize: bool = True,
abs_err_tolerance: float = 1e-6,
rel_err_tolerance: float = 1e-4,
**kwargs,
) -> None:
super().__init__(
objective,
vectorize=vectorize,
abs_err_tolerance=abs_err_tolerance,
rel_err_tolerance=rel_err_tolerance,
max_iterations=max_iterations,
**kwargs,
)

self.objective = objective
self.ordering = VariableOrdering(objective)
self.n_samples = n_sample
self.n_elite = n_elite
self.lb = lb
self.ub = ub
self.temp = temp
self.normalize = normalize
self._tot_dof = sum([x.dof() for x in self.ordering])
self.lml_eps = lml_eps
self.lml_verbose = lml_verbose
self.init_sigma = init_sigma

def _mu_vec_to_dict(self, mu: torch.Tensor) -> Dict[str, torch.Tensor]:
idx = 0
mu_dic = {}
for var in self.ordering:
mu_dic[var.name] = mu[:, slice(idx, idx + var.dof())]
idx += var.dof()
return mu_dic

def reset_sigma(self, init_sigma: Union[float, torch.Tensor]) -> None:
self.sigma = (
torch.ones(
(self.objective.batch_size, self._tot_dof), device=self.objective.device
)
* init_sigma
)

def _CEM_step(self):
"""
Performs one iteration of CEM.
Updates self.sigma and returns the new mu.
"""
device = self.objective.device
n_batch = self.ordering[0].shape[0]

mu = torch.cat([var.tensor for var in self.ordering], dim=-1)

X = Normal(mu, self.sigma).rsample((self.n_samples,))

X_samples: List[Dict[str, torch.Tensor]] = []
for sample in X:
X_samples.append(self._mu_vec_to_dict(sample))

fX = torch.stack(
[self.objective.error_metric(X_samples[i]) for i in range(self.n_samples)],
dim=1,
)

assert fX.shape == (n_batch, self.n_samples)

if self.temp is not None and self.temp < np.infty:
if self.normalize:
fX_mu = fX.mean(dim=1).unsqueeze(1)
fX_sigma = fX.std(dim=1).unsqueeze(1)
_fX = (fX - fX_mu) / (fX_sigma + 1e-6)
else:
_fX = fX

if self.n_elite == 1:
# indexes = LML(N=n_elite, verbose=lml_verbose, eps=lml_eps)(-_fX*temp)
indexes = torch.softmax(-_fX * self.temp, dim=1)
else:
indexes = LML(
N=self.n_elite, verbose=self.lml_verbose, eps=self.lml_eps
)(-_fX * self.temp)
indexes = indexes.unsqueeze(2)
eps = 0

else:
indexes_vals = fX.argsort(dim=1)[:, : self.n_elite]
# Scatter 1.0 to the indexes using indexes_vals
indexes = torch.zeros(n_batch, self.n_samples, device=device).scatter_(
1, indexes_vals, 1.0
)
indexes = indexes.unsqueeze(2)
eps = 1e-10
# indexes.shape should be (n_batch, n_sample, 1)

X = X.transpose(0, 1)

assert indexes.shape[:2] == X.shape[:2]

X_I = indexes * X

mu = torch.sum(X_I, dim=1) / self.n_elite
self.sigma = (
(indexes * (X - mu.unsqueeze(1)) ** 2).sum(dim=1) / self.n_elite
).sqrt() + eps  # adding eps to avoid sigma=0, which happens when temp=None

assert self.sigma.shape == (n_batch, self._tot_dof)

return self._mu_vec_to_dict(mu)

def _optimize_loop(
self,
num_iter: int,
info: NonlinearOptimizerInfo,
verbose: bool,
end_iter_callback: Optional[EndIterCallbackType] = None,
**kwargs,
) -> int:
converged_indices = torch.zeros_like(info.last_err).bool()
iters_done = 0
for it_ in range(num_iter):
iters_done += 1
try:
mu = self._CEM_step()
except RuntimeError as error:
raise RuntimeError(f"There is an error in update {error}.")

self.objective.update(mu)

# check for convergence
with torch.no_grad():
err = self.objective.error_metric()
self._update_info(info, it_, err, converged_indices)
if verbose:
print(
f"Nonlinear optimizer. Iteration: {it_+1}. "
f"Error: {err.mean().item()} "
)
converged_indices = self._check_convergence(err, info.last_err)
info.status[
np.array(converged_indices.cpu().numpy())
] = NonlinearOptimizerStatus.CONVERGED

if converged_indices.all():
break # nothing else will happen at this point
info.last_err = err

if end_iter_callback is not None:
end_iter_callback(self, info, mu, it_)

info.status[
info.status == NonlinearOptimizerStatus.START
] = NonlinearOptimizerStatus.MAX_ITERATIONS

return iters_done

def _optimize_impl(
self,
track_best_solution: bool = False,
track_err_history: bool = False,
track_state_history: bool = False,
verbose: bool = False,
backward_mode: Union[str, BackwardMode] = BackwardMode.UNROLL,
end_iter_callback: Optional[EndIterCallbackType] = None,
**kwargs,
) -> OptimizerInfo:
backward_mode = BackwardMode.resolve(backward_mode)
init_sigma = kwargs.get("init_sigma", self.init_sigma)
self.reset_sigma(init_sigma)

with torch.no_grad():
info = self._init_info(
track_best_solution, track_err_history, track_state_history
)

if verbose:
print(
f"DCEM optimizer. Iteration: 0. "
f"Error: {info.last_err.mean().item()}"
)

if backward_mode in [BackwardMode.UNROLL, BackwardMode.DLM]:
self._optimize_loop(
num_iter=self.params.max_iterations,
info=info,
verbose=verbose,
end_iter_callback=end_iter_callback,
**kwargs,
)
# If it didn't converge, remove the misleading converged_iter value
info.converged_iter[
info.status == NonlinearOptimizerStatus.MAX_ITERATIONS
] = -1
return info

else:
raise NotImplementedError(
"DCEM currently only supports 'unroll' backward mode."
)
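For context on how the new optimizer is meant to be used, here is a minimal usage sketch going through a `TheseusLayer`. Each `_CEM_step` samples candidates from `Normal(mu, sigma)`, scores them with `objective.error_metric`, selects `n_elite` samples (softly via the LML layer when `temp` is set, or via hard top-k scatter otherwise), and refits `mu` and `sigma` to the elites. The toy cost function, variable names, and hyperparameter values below are illustrative assumptions and not part of this PR; only `th.DCEM` and its constructor arguments come from the diff above.

```python
import torch
import theseus as th

# Toy problem: pull a 2-DoF vector toward a target value.
x = th.Vector(2, name="x")
target = th.Variable(torch.zeros(1, 2), name="target")

def err_fn(optim_vars, aux_vars):
    # Residual is the difference between the variable and the target.
    (x_var,) = optim_vars
    (target_var,) = aux_vars
    return x_var.tensor - target_var.tensor

objective = th.Objective()
objective.add(th.AutoDiffCostFunction([x], err_fn, 2, aux_vars=[target], name="fit"))

# DCEM is sampling-based, so no linear solver is configured (cf. the test changes above).
optimizer = th.DCEM(objective, max_iterations=50, n_sample=100, n_elite=5, init_sigma=1.0)
layer = th.TheseusLayer(optimizer)

inputs = {"x": torch.ones(1, 2), "target": torch.zeros(1, 2)}
solution, info = layer.forward(inputs)
print(solution["x"])  # approaches the target as CEM iterations proceed
```

Per `_optimize_impl` above, DCEM currently supports only the unroll backward mode, so gradients with respect to upstream parameters flow through the unrolled CEM iterations.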
8 changes: 2 additions & 6 deletions theseus/optimizer/nonlinear/nonlinear_least_squares.py
@@ -5,7 +5,7 @@

import abc
import warnings
from typing import Any, Callable, Dict, NoReturn, Optional, Tuple, Type, Union
from typing import Any, Dict, Optional, Tuple, Type, Union

import torch

@@ -22,14 +22,10 @@
NonlinearOptimizer,
NonlinearOptimizerInfo,
NonlinearOptimizerStatus,
EndIterCallbackType,
)


EndIterCallbackType = Callable[
["NonlinearOptimizer", NonlinearOptimizerInfo, torch.Tensor, int], NoReturn
]


# Base class for all optimizers for NLLS problems,
# providing the skeleton of the
# optimization loop. Subclasses need to implement the following method: