diff --git a/examples/homography_estimation.py b/examples/homography_estimation.py
index 60a8fbf54..454565d06 100644
--- a/examples/homography_estimation.py
+++ b/examples/homography_estimation.py
@@ -21,7 +21,7 @@
 from torch.utils.data import DataLoader, Dataset
 
 import theseus as th
-from theseus.core.cost_function import AutogradMode, ErrFnType
+from theseus.core.cost_function import ErrFnType
 from theseus.third_party.easyaug import GeoAugParam, RandomGeoAug, RandomPhotoAug
 from theseus.third_party.utils import grid_sample
 
@@ -286,7 +286,7 @@ def run(
     outer_lr: float = 1e-4,
     max_iterations: int = 50,
     step_size: float = 0.1,
-    autograd_mode: AutogradMode = AutogradMode.VMAP,
+    autograd_mode: str = "vmap",
 ):
     logger.info(
         "==============================================================="
@@ -479,18 +479,12 @@ def run(
 
 @hydra.main(config_path="./configs/", config_name="homography_estimation")
 def main(cfg):
-    autograd_modes = {
-        "dense": AutogradMode.DENSE,
-        "loop_batch": AutogradMode.LOOP_BATCH,
-        "vmap": AutogradMode.VMAP,
-    }
-
     num_epochs: int = cfg.outer_optim.num_epochs
     batch_size: int = cfg.outer_optim.batch_size
     outer_lr: float = cfg.outer_optim.lr
     max_iterations: int = cfg.inner_optim.max_iters
     step_size: float = cfg.inner_optim.step_size
-    autograd_mode = autograd_modes[cfg.autograd_mode]
+    autograd_mode = cfg.autograd_mode
 
     run(
         batch_size=batch_size,
diff --git a/theseus/core/cost_function.py b/theseus/core/cost_function.py
index 2db9408e8..8ba1376dc 100644
--- a/theseus/core/cost_function.py
+++ b/theseus/core/cost_function.py
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import abc
-from typing import Callable, List, Optional, Sequence, Tuple, cast
+from typing import Callable, List, Optional, Sequence, Tuple, Union, cast
 from enum import Enum
 
 import torch
@@ -101,6 +101,22 @@ class AutogradMode(Enum):
     LOOP_BATCH = 1
     VMAP = 2
 
+    @staticmethod
+    def resolve(key: Union[str, "AutogradMode"]) -> "AutogradMode":
+        if isinstance(key, AutogradMode):
+            return key
+        if not isinstance(key, str):
+            raise ValueError("Autograd mode must be of type th.AutogradMode or string.")
+
+        try:
+            mode = AutogradMode[key.upper()]
+        except KeyError:
+            raise ValueError(
+                f"Invalid autograd mode {key}. "
+                "Valid options are dense, loop_batch, and vmap."
+            )
+        return mode
+
 
 # The error function is assumed to receive variables in the format
 # err_fn(
@@ -120,7 +136,7 @@ def __init__(
         name: Optional[str] = None,
         autograd_strict: bool = False,
         autograd_vectorize: bool = False,
-        autograd_mode: AutogradMode = AutogradMode.DENSE,
+        autograd_mode: Union[str, AutogradMode] = AutogradMode.DENSE,
     ):
         if cost_weight is None:
             cost_weight = ScaleCostWeight(1.0)
@@ -147,7 +163,7 @@ def __init__(
         self._tmp_optim_vars_for_loop = None
         self._tmp_aux_vars_for_loop = None
 
-        self._autograd_mode = autograd_mode
+        self._autograd_mode = AutogradMode.resolve(autograd_mode)
 
         if self._autograd_mode == AutogradMode.LOOP_BATCH:
             self._tmp_optim_vars_for_loop = tuple(v.copy() for v in optim_vars)
diff --git a/theseus/core/tests/test_cost_function.py b/theseus/core/tests/test_cost_function.py
index 742d3cbd2..483aae0c7 100644
--- a/theseus/core/tests/test_cost_function.py
+++ b/theseus/core/tests/test_cost_function.py
@@ -57,9 +57,8 @@ def test_default_name_and_ids():
     assert len(seen_ids) == reps
 
 
-@pytest.mark.parametrize(
-    "autograd_mode", [AutogradMode.DENSE, AutogradMode.LOOP_BATCH, AutogradMode.VMAP]
-)
+# Adding three formatting options to include coverage for autograd mode resolution
+@pytest.mark.parametrize("autograd_mode", ["DENSE", "loop_batch", AutogradMode.VMAP])
 def test_autodiff_cost_function_error_and_jacobians_shape(autograd_mode):
     rng = torch.Generator()
     rng.manual_seed(0)
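Usage sketch for reviewers (not part of the patch): a minimal example of the new string-based `autograd_mode` API. The toy `err_fn`, the `Vector` variables, and their names are illustrative placeholders; only `AutogradMode.resolve` and the `Union[str, AutogradMode]` constructor argument come from this diff.

```python
import torch

import theseus as th
from theseus.core.cost_function import AutogradMode

# Strings resolve case-insensitively to the enum; enum members pass through unchanged.
assert AutogradMode.resolve("vmap") is AutogradMode.VMAP
assert AutogradMode.resolve("LOOP_BATCH") is AutogradMode.LOOP_BATCH
assert AutogradMode.resolve(AutogradMode.DENSE) is AutogradMode.DENSE
# AutogradMode.resolve("typo") raises ValueError listing dense, loop_batch, and vmap.


# Toy residual (illustrative): difference between a 2D point and a fixed target.
def err_fn(optim_vars, aux_vars):
    return optim_vars[0].tensor - aux_vars[0].tensor


x = th.Vector(2, name="x")
target = th.Vector(tensor=torch.zeros(1, 2), name="target")

# The constructor now accepts a plain string in place of an AutogradMode member.
cost_fn = th.AutodiffCostFunction(
    [x], err_fn, 2, aux_vars=[target], autograd_mode="vmap"
)
assert cost_fn.error().shape == (1, 2)
```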