Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Differentiable CEM solver #329

Merged
merged 35 commits into from
Mar 23, 2023
Merged
Changes from 1 commit
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
31c277d
first implementation of dcem solver
dishank-b Oct 13, 2022
245d80c
dcem optimizer working
dishank-b Oct 14, 2022
11962c6
minor changes in dcem, added tests
dishank-b Oct 17, 2022
d069ee6
online calculatino of n_batch
dishank-b Oct 18, 2022
24d2fa9
initializing LML layer
dishank-b Oct 20, 2022
5632588
dcem working backwards tutorial 2
dishank-b Oct 28, 2022
655ba19
better vectorization in solver
dishank-b Oct 31, 2022
63a04cf
vectoriztion in solve method for itr loop in optimizer class
dishank-b Oct 31, 2022
fd93941
forward pass working perfectly with current set of hyperparams with b…
dishank-b Nov 3, 2022
976fb08
dcem backward unit test passed for one setting
dishank-b Nov 3, 2022
7547504
DCEM backward unit test working, not tested with leo, insanely slow w…
dishank-b Nov 4, 2022
3a22e09
refactoring, removed DcemSolver in favour of solve method in DCEM opt…
dishank-b Nov 4, 2022
89c8b39
correcting circle ci errors
dishank-b Nov 7, 2022
9bf03c4
corrected lml url for requirements.txt
dishank-b Nov 7, 2022
a1064a2
corrected reuirements.txt for lml
dishank-b Nov 7, 2022
c21dc69
removing -e from requirements
dishank-b Nov 8, 2022
c809e94
changing setup.py to install lml
dishank-b Nov 8, 2022
2bbf2db
changing setup.py to add lml
dishank-b Nov 8, 2022
4f3788c
commented dcem_test
dishank-b Nov 8, 2022
0f69df6
unit test working with both gpu, cpu with even less 10-2 error thres …
dishank-b Nov 9, 2022
0592f6f
testing with lml_eps=10-4
dishank-b Nov 10, 2022
7d68639
Revert "testing with lml_eps=10-4"
dishank-b Nov 10, 2022
044e881
reverting the common.py file
dishank-b Nov 10, 2022
769d483
dcem working, name changed from DCem to DCEM
dishank-b Mar 6, 2023
126ee47
removed _all_solve function and chnaged _solve name to _CEM_step
dishank-b Mar 6, 2023
f2ccf4b
changed dcem objective to use error_metric and edit __init files
dishank-b Mar 6, 2023
74a2a5d
dcem working, added dcem tutorial
dishank-b Mar 6, 2023
679dc3b
add lml as third party
dishank-b Mar 6, 2023
fab68be
or black pre-commit hook
dishank-b Mar 6, 2023
f4b345a
removeing abs in loss function since model chnaged test_theseus layer
dishank-b Mar 7, 2023
97c94b6
changes in test_theseus to make it compatible with DCEM
dishank-b Mar 9, 2023
bc32139
minor changes:styling, nits, typehinting, etc.
dishank-b Mar 17, 2023
50ea767
reverted minor changes, corrected test_theseus_layer argument logic f…
dishank-b Mar 20, 2023
ee75e95
using scatter for indexes with temp=None in dcem
dishank-b Mar 21, 2023
947e026
final changes, removing half-complete changes before merge
dishank-b Mar 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
dcem working, name changed from DCem to DCEM
  • Loading branch information
dishank-b committed Mar 9, 2023
commit 769d48395ff61e47ea924f9f26fe1df54cd234d6
185 changes: 50 additions & 135 deletions theseus/optimizer/nonlinear/dcem.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,44 @@
import math
from typing import Callable, Dict, NoReturn, Optional, Union, List
from typing import Callable, NoReturn, Optional, Union, List

import numpy as np
import torch
from lml import LML
from torch.distributions import Normal

import theseus.constants
from theseus.core.objective import Objective
from theseus.optimizer import Optimizer, OptimizerInfo
from theseus.optimizer import OptimizerInfo
from theseus.optimizer.variable_ordering import VariableOrdering

from .nonlinear_optimizer import (
NonlinearOptimizer,
BackwardMode,
NonlinearOptimizerInfo,
NonlinearOptimizerParams,
NonlinearOptimizerStatus,
)

EndIterCallbackType = Callable[
dishank-b marked this conversation as resolved.
Show resolved Hide resolved
["DCem", NonlinearOptimizerInfo, torch.Tensor, int], NoReturn
["DCEM", NonlinearOptimizerInfo, torch.Tensor, int], NoReturn
]


class DCem(Optimizer):
class DCEM(NonlinearOptimizer):
"""
DCEM optimizer for nonlinear optimization using sampling based techniques.
The optimizer can be really sensitive to hypermeter tuning. Here are few tuning
hints:
1. If have to lower the max_iterations, then increase the n_sample.
2. The higher the n_sample, the slowly with variance of samples will decrease.
3. The higher the n_sample, more the chances of optimum being in the elite set.
4. The higher the n_elite, the slower is convergence, but more accurate it might
be, but would need more iterations. n_elite= 5 is good enough for most cases.
"""

def __init__(
self,
objective: Objective,
vectorize: bool = False,
max_iterations: int = 100,
n_sample: int = 50, # 20
max_iterations: int = 50, # 50 for test_theseus
dishank-b marked this conversation as resolved.
Show resolved Hide resolved
n_sample: int = 100, # 100 for test_theseus
n_elite: int = 5, # 5
temp: float = 1.0,
init_sigma: Union[float, torch.Tensor, List[float]] = 1.0,
Expand All @@ -38,12 +47,18 @@ def __init__(
lml_verbose: bool = False,
lml_eps: float = 1e-3,
normalize: bool = True,
iter_eps: float = 1e-7,
abs_err_tolerance: float = 1e-6,
rel_err_tolerance: float = 1e-4,
**kwargs,
) -> None:
super().__init__(objective, vectorize=vectorize, **kwargs)

self.params = NonlinearOptimizerParams(iter_eps, iter_eps, max_iterations, 1e-2)
super().__init__(
objective,
vectorize=vectorize,
abs_err_tolerance=abs_err_tolerance,
rel_err_tolerance=rel_err_tolerance,
max_iterations=max_iterations,
**kwargs,
)

self.objective = objective
self.ordering = VariableOrdering(objective)
Expand All @@ -62,117 +77,8 @@ def __init__(
)
dishank-b marked this conversation as resolved.
Show resolved Hide resolved
self.lml_eps = lml_eps
self.lml_verbose = lml_verbose
self.iter_eps = iter_eps
self.init_sigma = init_sigma

dishank-b marked this conversation as resolved.
Show resolved Hide resolved
def set_params(self, **kwargs):
self.params.update(kwargs)

def _maybe_init_best_solution(
self, do_init: bool = False
) -> Optional[Dict[str, torch.Tensor]]:
if not do_init:
return None
solution_dict = {}
for var in self.ordering:
solution_dict[var.name] = var.tensor.detach().clone().cpu()
return solution_dict

def _init_info(
self,
track_best_solution: bool,
track_err_history: bool,
track_state_history: bool,
) -> NonlinearOptimizerInfo:
with torch.no_grad():
last_err = self.objective.error_squared_norm() / 2
best_err = last_err.clone() if track_best_solution else None
if track_err_history:
err_history = (
torch.ones(self.objective.batch_size, self.params.max_iterations + 1)
* math.inf
)
assert last_err.grad_fn is None
err_history[:, 0] = last_err.clone().cpu()
else:
err_history = None

if track_state_history:
state_history = {}
for var in self.objective.optim_vars.values():
state_history[var.name] = (
torch.ones(
self.objective.batch_size,
*var.shape[1:],
self.params.max_iterations + 1,
)
* math.inf
)
state_history[var.name][..., 0] = var.tensor.detach().clone().cpu()
else:
state_history = None

return NonlinearOptimizerInfo(
best_solution=self._maybe_init_best_solution(do_init=track_best_solution),
last_err=last_err,
best_err=best_err,
status=np.array(
[NonlinearOptimizerStatus.START] * self.objective.batch_size
),
converged_iter=torch.zeros_like(last_err, dtype=torch.long),
best_iter=torch.zeros_like(last_err, dtype=torch.long),
err_history=err_history,
state_history=state_history,
)

def _check_convergence(self, err: torch.Tensor, last_err: torch.Tensor):
assert not torch.is_grad_enabled()
if err.abs().mean() < theseus.constants.EPS:
return torch.ones_like(err).bool()

abs_error = (last_err - err).abs()
rel_error = abs_error / last_err
return (abs_error < self.params.abs_err_tolerance).logical_or(
rel_error < self.params.rel_err_tolerance
)

def _update_state_history(self, iter_idx: int, info: NonlinearOptimizerInfo):
for var in self.objective.optim_vars.values():
info.state_history[var.name][..., iter_idx + 1] = (
var.tensor.detach().clone().cpu()
)

def _update_info(
self,
info: NonlinearOptimizerInfo,
current_iter: int,
err: torch.Tensor,
converged_indices: torch.Tensor,
):
info.converged_iter += 1 - converged_indices.long()
if info.err_history is not None:
assert err.grad_fn is None
info.err_history[:, current_iter + 1] = err.clone().cpu()
if info.state_history is not None:
self._update_state_history(current_iter, info)

if info.best_solution is not None:
# Only copy best solution if needed (None means track_best_solution=False)
assert info.best_err is not None
good_indices = err < info.best_err
info.best_iter[good_indices] = current_iter
for var in self.ordering:
info.best_solution[var.name][good_indices] = (
var.tensor.detach().clone()[good_indices].cpu()
)

info.best_err = torch.minimum(info.best_err, err)

converged_indices = self._check_convergence(err, info.last_err)
info.status[
np.array(converged_indices.detach().cpu())
] = NonlinearOptimizerStatus.CONVERGED

def _mu_vec_to_dict(self, mu):
dishank-b marked this conversation as resolved.
Show resolved Hide resolved
dishank-b marked this conversation as resolved.
Show resolved Hide resolved
idx = 0
mu_dic = {}
Expand Down Expand Up @@ -313,7 +219,6 @@ def _solve(self):
mu = torch.cat([var.tensor for var in self.ordering], dim=-1)

X = Normal(mu, self.sigma).rsample((self.n_samples,))
# X = Normal(mu, self.sigma + 1e-5).rsample((self.n_samples,))

X_samples = []
dishank-b marked this conversation as resolved.
Show resolved Hide resolved
for sample in X:
Expand Down Expand Up @@ -354,13 +259,30 @@ def _solve(self):
for v in indexes_vals[j]:
indexes[j, v] = 1.0
indexes = indexes.unsqueeze(2)
# I.shape should be (n_batch, n_sample, 1)
# indexes.shape should be (n_batch, n_sample, 1)

X = X.transpose(0, 1)

assert indexes.shape[:2] == X.shape[:2]
# print("Samples:", X)
X_I = indexes * X
# top_k_idx_11 = np.argsort(indexes[11].squeeze(1).cpu().numpy())[::-1][
# : self.n_elite
# ]
# top_k_idx_12 = np.argsort(indexes[12].squeeze(1).cpu().numpy())[::-1][
# : self.n_elite
# ]
# print(indexes[11][:50].squeeze(1), indexes[12][:50].squeeze(1))
# print("top K indices:", top_k_idx_11, top_k_idx_12)
# print(
# indexes[11].squeeze(1).cpu().numpy()[top_k_idx_11],
# indexes[12].squeeze(1).cpu().numpy()[top_k_idx_12],
# )
# print(X[11][top_k_idx_11.copy()])
# print(
# mu[11].cpu().numpy(),
# self.sigma[11].cpu().numpy(),
# )

mu = torch.sum(X_I, dim=1) / self.n_elite
self.sigma = (
Expand All @@ -380,32 +302,25 @@ def _optimize_loop(
**kwargs,
) -> int:

# itr_done = self._all_solve(num_iter, info=info)
# # self.objective.update(mu)
# # with torch.no_grad():
# # info.best_solution = mu
# return itr_done

converged_indices = torch.zeros_like(info.last_err).bool()
iters_done = 0
for it_ in range(num_iter):
iters_done += 1
try:
mu = self._solve()
# mu = self.linear_solver.solve()
except RuntimeError as error:
raise RuntimeError(f"There is an error in update {error}")
dishank-b marked this conversation as resolved.
Show resolved Hide resolved

self.objective.update(mu)

# check for convergence
with torch.no_grad():
err = self.objective.error_squared_norm() / 2
err = self.objective.error_metric()
self._update_info(info, it_, err, converged_indices)
if verbose:
print(
f"Nonlinear optimizer. Iteration: {it_+1}. "
f"Error: {err.mean().item()}"
f"Error: {err.mean().item()} "
)
converged_indices = self._check_convergence(err, info.last_err)
info.status[
Expand All @@ -416,8 +331,8 @@ def _optimize_loop(
# Doesn't work with lml_eps = 1e-5.
# and with lml_eps= 1e-4, gives suboptimal solution

# if converged_indices.all():
# break # nothing else will happen at this point
if converged_indices.all():
break # nothing else will happen at this point
info.last_err = err

if end_iter_callback is not None:
Expand Down