StratifiedStandardize OutcomeTransform (#2671)
Summary:
Pull Request resolved: #2671

See title. This allows applying stratified standardization at the model level, which will enable selecting between a single-task and a multi-task model in Ax while using the appropriate transform; i.e., one could specify ModelConfigs that use 1) `SingleTaskGP` + `Standardize`, or 2) `MultiTaskGP` + `StratifiedStandardize`.
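
As a rough usage sketch (the shapes, values, and the placement of the task id in the last input column are illustrative assumptions, not taken from this diff), the new transform can also be applied directly to training data:

    import torch
    from botorch.models.transforms.outcome import StratifiedStandardize

    # Task id lives in the last column of X (two tasks: 0 and 1).
    X = torch.rand(10, 3, dtype=torch.double)
    X[:, -1] = (torch.arange(10) % 2).to(torch.double)
    Y = torch.randn(10, 1, dtype=torch.double)

    tf = StratifiedStandardize(
        task_values=torch.tensor([0, 1]),
        stratification_idx=2,  # index of the task column in X
    )
    Y_tf, _ = tf(Y=Y, X=X)  # each task's observations are standardized separately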

Reviewed By: saitcakmak

Differential Revision: D67728920

fbshipit-source-id: ad6ee2bbed3e484288e94dcfb7b1555fbd4395e4
sdaulton authored and facebook-github-bot committed Jan 16, 2025
1 parent be8ec7b commit 831ea5d
Showing 5 changed files with 407 additions and 58 deletions.
35 changes: 1 addition & 34 deletions botorch/models/multitask.py
@@ -39,6 +39,7 @@
from botorch.models.model import FantasizeMixin
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
from botorch.models.utils.assorted import get_task_value_remapping
from botorch.models.utils.gpytorch_modules import (
get_covar_module_with_dim_scaled_prior,
get_gaussian_likelihood_with_lognormal_prior,
@@ -82,40 +83,6 @@
from torch import Tensor


def get_task_value_remapping(task_values: Tensor, dtype: torch.dtype) -> Tensor | None:
"""Construct an mapping of discrete task values to contiguous int-valued floats.
Args:
task_values: A sorted long-valued tensor of task values.
dtype: The dtype of the model inputs (e.g. `X`), to which the new
task values should be mapped (e.g. float, double).
Returns:
A tensor of shape `task_values.max() + 1` that maps task values
to new task values. The indexing operation `mapper[task_value]`
will produce a tensor of new task values, of the same shape as
the original. The elements of the `mapper` tensor that do not
appear in the original `task_values` are mapped to `nan`. The
return value will be `None`, when the task values are contiguous
integers starting from zero.
"""
task_range = torch.arange(
len(task_values), dtype=task_values.dtype, device=task_values.device
)
mapper = None
if not torch.equal(task_values, task_range):
# Create a tensor that maps task values to new task values.
# The number of tasks should be small, so this should be quite efficient.
mapper = torch.full(
(int(task_values.max().item()) + 1,),
float("nan"),
dtype=dtype,
device=task_values.device,
)
mapper[task_values] = task_range.to(dtype=dtype)
return mapper
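# Editor's illustration (not part of this commit): for non-contiguous task values,
# e.g. get_task_value_remapping(torch.tensor([1, 3]), dtype=torch.double), the
# returned mapper is tensor([nan, 0., nan, 1.]), so mapper[3] == 1.0; for task
# values that are already contiguous integers starting at zero, None is returned.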


class MultiTaskGP(ExactGP, MultiTaskGPyTorchModel, FantasizeMixin):
r"""Multi-Task exact GP model using an ICM (intrinsic co-regionalization model)
kernel. See [Bonilla2007MTGP]_ and [Swersky2013MTBO]_ for a reference on the
274 changes: 250 additions & 24 deletions botorch/models/transforms/outcome.py
@@ -27,9 +27,11 @@

import torch
from botorch.models.transforms.utils import (
nanstd,
norm_to_lognorm_mean,
norm_to_lognorm_variance,
)
from botorch.models.utils.assorted import get_task_value_remapping
from botorch.posteriors import GPyTorchPosterior, Posterior, TransformedPosterior
from botorch.utils.transforms import normalize_indices
from linear_operator.operators import CholLinearOperator, DiagLinearOperator
@@ -259,6 +261,46 @@ def __init__(
self._batch_shape = batch_shape
self._min_stdv = min_stdv

def _get_per_input_means_stdvs(
self, X: Tensor, include_stdvs_sq: bool
) -> tuple[Tensor, Tensor, Tensor | None]:
r"""Get per-input means and stdvs.
Args:
X: A `batch_shape x n x d`-dim tensor of input parameters.
include_stdvs_sq: Whether to include the stdvs squared.
This parameter is not used by this method.
Returns:
A three-tuple with the means and stdvs:
- The per-input means.
- The per-input stdvs.
- The per-input stdvs squared.
"""
return self.means, self.stdvs, self._stdvs_sq

def _validate_training_inputs(self, Y: Tensor, Yvar: Tensor | None = None) -> None:
"""Validate training inputs.
Args:
Y: A `batch_shape x n x m`-dim tensor of training targets.
Yvar: A `batch_shape x n x m`-dim tensor of observation noises.
"""
if Y.shape[:-2] != self._batch_shape:
raise RuntimeError(
f"Expected Y.shape[:-2] to be {self._batch_shape}, matching "
f"the `batch_shape` argument to `{self.__class__.__name__}`, but got "
f"Y.shape[:-2]={Y.shape[:-2]}."
)
elif Y.shape[-2] < 1:
raise ValueError(f"Can't standardize with no observations. {Y.shape=}.")
elif Y.size(-1) != self._m:
raise RuntimeError(
f"Wrong output dimension. Y.size(-1) is {Y.size(-1)}; expected "
f"{self._m}."
)

def forward(
self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
) -> tuple[Tensor, Tensor | None]:
@@ -283,21 +325,8 @@ def forward(
- The transformed observation noise (if applicable).
"""
if self.training:
if Y.shape[:-2] != self._batch_shape:
raise RuntimeError(
f"Expected Y.shape[:-2] to be {self._batch_shape}, matching "
"the `batch_shape` argument to `Standardize`, but got "
f"Y.shape[:-2]={Y.shape[:-2]}."
)
if Y.size(-1) != self._m:
raise RuntimeError(
f"Wrong output dimension. Y.size(-1) is {Y.size(-1)}; expected "
f"{self._m}."
)
if Y.shape[-2] < 1:
raise ValueError(f"Can't standardize with no observations. {Y.shape=}.")

elif Y.shape[-2] == 1:
self._validate_training_inputs(Y=Y, Yvar=Yvar)
if Y.shape[-2] == 1:
stdvs = torch.ones(
(*Y.shape[:-2], 1, Y.shape[-1]), dtype=Y.dtype, device=Y.device
)
@@ -313,9 +342,12 @@ def forward(
self.stdvs = stdvs
self._stdvs_sq = stdvs.pow(2)
self._is_trained = torch.tensor(True)

Y_tf = (Y - self.means) / self.stdvs
Yvar_tf = Yvar / self._stdvs_sq if Yvar is not None else None
include_stdvs_sq = Yvar is not None
means, stdvs, stdvs_sq = self._get_per_input_means_stdvs(
X=X, include_stdvs_sq=include_stdvs_sq
)
Y_tf = (Y - means) / stdvs
Yvar_tf = Yvar / stdvs_sq if include_stdvs_sq else None
return Y_tf, Yvar_tf

def subset_output(self, idcs: list[int]) -> OutcomeTransform:
@@ -376,9 +408,12 @@ def untransform(
"(e.g. `transform(Y)`) before calling `untransform`, since "
"means and standard deviations need to be computed."
)

Y_utf = self.means + self.stdvs * Y
Yvar_utf = self._stdvs_sq * Yvar if Yvar is not None else None
include_stdvs_sq = Yvar is not None
means, stdvs, stdvs_sq = self._get_per_input_means_stdvs(
X=X, include_stdvs_sq=include_stdvs_sq
)
Y_utf = means + stdvs * Y
Yvar_utf = stdvs_sq * Yvar if include_stdvs_sq else None
return Y_utf, Yvar_utf

@property
@@ -433,8 +468,9 @@ def untransform_posterior(
)
# GPyTorchPosterior (TODO: Should we Lazy-evaluate the mean here as well?)
mvn = posterior.distribution
offset = self.means
scale_fac = self.stdvs
offset, scale_fac, _ = self._get_per_input_means_stdvs(
X=X, include_stdvs_sq=False
)
if not posterior._is_mt:
mean_tf = offset.squeeze(-1) + scale_fac.squeeze(-1) * mvn.mean
scale_fac = scale_fac.squeeze(-1).expand_as(mean_tf)
@@ -449,7 +485,6 @@

if (
not mvn.islazy
# TODO: Figure out attribute naming weirdness here
or mvn._MultivariateNormal__unbroadcasted_scale_tril is not None
):
# if already computed, we can save a lot of time using scale_tril
@@ -465,6 +500,197 @@
return GPyTorchPosterior(mvn_tf)


class StratifiedStandardize(Standardize):
r"""Standardize outcomes (zero mean, unit variance) along stratification dimension.
This module is stateful: If in train mode, calling forward updates the
module state (i.e. the mean/std normalizing constants). If in eval mode,
calling forward simply applies the standardization using the current module
state.
"""

def __init__(
self,
task_values: Tensor,
stratification_idx: int,
batch_shape: torch.Size = torch.Size(), # noqa: B008
min_stdv: float = 1e-8,
# dtype: torch.dtype = torch.double,
) -> None:
r"""Standardize outcomes (zero mean, unit variance) along stratification dim.
Note: This currently only supports single-output models
(including multi-task models that have a single output).
Args:
task_values: `t`-dim tensor of task values.
stratification_idx: The index of the stratification dimension.
batch_shape: The batch_shape of the training targets.
min_stdv: The minimum standard deviation for which to perform
standardization (if lower, only de-mean the data).
"""
OutcomeTransform.__init__(self)
self._stratification_idx = stratification_idx
task_values = task_values.unique(sorted=True)
self.strata_mapping = get_task_value_remapping(task_values, dtype=torch.long)
if self.strata_mapping is None:
self.strata_mapping = task_values
n_strata = self.strata_mapping.shape[0]
self._min_stdv = min_stdv
self.register_buffer("means", torch.zeros(*batch_shape, n_strata, 1))
self.register_buffer("stdvs", torch.ones(*batch_shape, n_strata, 1))
self.register_buffer("_stdvs_sq", torch.ones(*batch_shape, n_strata, 1))
self.register_buffer("_is_trained", torch.tensor(False))
self._batch_shape = batch_shape
self._m = 1 # TODO: support multiple outputs
self._outputs = None

def forward(
self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
) -> tuple[Tensor, Tensor | None]:
r"""Standardize outcomes.
If the module is in train mode, this updates the module state (i.e. the
mean/std normalizing constants). If the module is in eval mode, simply
applies the normalization using the module state.
Args:
Y: A `batch_shape x n x m`-dim tensor of training targets.
Yvar: A `batch_shape x n x m`-dim tensor of observation noises
associated with the training targets (if applicable).
X: A `batch_shape x n x d`-dim tensor of input parameters.
Returns:
A two-tuple with the transformed outcomes:
- The transformed outcome observations.
- The transformed observation noise (if applicable).
"""
if X is None:
raise ValueError("X is required for StratifiedStandardize.")
if self.training:
self._validate_training_inputs(Y=Y, Yvar=Yvar)
self.means = self.means.to(dtype=X.dtype, device=X.device)
self.stdvs = self.stdvs.to(dtype=X.dtype, device=X.device)
self._stdvs_sq = self._stdvs_sq.to(dtype=X.dtype, device=X.device)
strata = X[..., self._stratification_idx].long()
unique_strata = strata.unique()
for s in unique_strata:
mapped_strata = self.strata_mapping[s]
mask = strata != s
Y_strata = Y.clone()
Y_strata[..., mask, :] = float("nan")
stdvs = (
torch.ones_like(Y_strata)
if Y.shape[-2] == 1
else nanstd(X=Y_strata, dim=-2)
)
stdvs = stdvs.where(
stdvs >= self._min_stdv, torch.full_like(stdvs, 1.0)
)
means = Y_strata.nanmean(dim=-2)
self.means[..., mapped_strata, :] = means
self.stdvs[..., mapped_strata, :] = stdvs
self._stdvs_sq[..., mapped_strata, :] = stdvs.pow(2)
self._is_trained = torch.tensor(True)
training = self.training
self.training = False
tf_Y, tf_Yvar = super().forward(Y=Y, Yvar=Yvar, X=X)
self.training = training
return tf_Y, tf_Yvar
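# Editor's illustration (not part of this commit): with a stratification column
# of [0., 0., 1., 1.] and Y = [[1.], [3.], [5.], [9.]], the loop above stores the
# mean (2.0) and std of [1, 3] for stratum 0 and the mean (7.0) and std of
# [5, 9] for stratum 1; the eval-mode super().forward() call then standardizes
# each row using its own stratum's statistics.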

def _get_per_input_means_stdvs(
self, X: Tensor, include_stdvs_sq: bool
) -> tuple[Tensor, Tensor, Tensor | None]:
r"""Get per-input means and stdvs.
Args:
X: A `batch_shape x n x d`-dim tensor of input parameters.
include_stdvs_sq: Whether to include the stdvs squared.
Returns:
A three-tuple with the per-input means and stdvs:
- The per-input means.
- The per-input stdvs.
- The per-input stdvs squared.
"""
strata = X[..., self._stratification_idx].long()
mapped_strata = self.strata_mapping[strata].unsqueeze(-1)
# get means and stdvs for each strata
n_extra_batch_dims = mapped_strata.ndim - 2 - len(self._batch_shape)
expand_shape = mapped_strata.shape[:n_extra_batch_dims] + self.means.shape
means = torch.gather(
input=self.means.expand(expand_shape),
dim=-2,
index=mapped_strata,
)
stdvs = torch.gather(
input=self.stdvs.expand(expand_shape),
dim=-2,
index=mapped_strata,
)
if include_stdvs_sq:
stdvs_sq = torch.gather(
input=self._stdvs_sq.expand(expand_shape),
dim=-2,
index=mapped_strata,
)
else:
stdvs_sq = None
return means, stdvs, stdvs_sq
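# Editor's illustration (not part of this commit): if the task values are 0 and 1,
# `self.means` has shape `n_strata x 1`, and the stratification column of X is
# [0., 1., 1.], then `mapped_strata` is [[0], [1], [1]] and the gathers above
# return `3 x 1` tensors holding each input row's stratum mean / stdv.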

def subset_output(self, idcs: list[int]) -> OutcomeTransform:
r"""Subset the transform along the output dimension.
Args:
idcs: The output indices to subset the transform to.
Returns:
The current outcome transform, subset to the specified output indices.
"""
raise NotImplementedError

def untransform(
self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
) -> tuple[Tensor, Tensor | None]:
r"""Un-standardize outcomes.
Args:
Y: A `batch_shape x n x m`-dim tensor of standardized targets.
Yvar: A `batch_shape x n x m`-dim tensor of standardized observation
noises associated with the targets (if applicable).
X: A `batch_shape x n x d`-dim tensor of input parameters.
Returns:
A two-tuple with the un-standardized outcomes:
- The un-standardized outcome observations.
- The un-standardized observation noise (if applicable).
"""
if X is None:
raise ValueError("X is required for StratifiedStandardize.")
return super().untransform(Y=Y, Yvar=Yvar, X=X)

def untransform_posterior(
self, posterior: Posterior, X: Tensor | None = None
) -> GPyTorchPosterior | TransformedPosterior:
r"""Un-standardize the posterior.
Args:
posterior: A posterior in the standardized space.
X: A `batch_shape x n x d`-dim tensor of training inputs (if applicable).
Returns:
The un-standardized posterior. If the input posterior is a
`GPyTorchPosterior`, return a `GPyTorchPosterior`. Otherwise, return a
`TransformedPosterior`.
"""
if X is None:
raise ValueError("X is required for StratifiedStandardize.")
return super().untransform_posterior(posterior=posterior, X=X)


class Log(OutcomeTransform):
r"""Log-transform outcomes.