# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loss functions to be used by LayerCollection."""
import abc
from typing import Tuple, Optional, Union, Sequence

import jax
import jax.numpy as jnp

from kfac_ferminet_alpha import distributions
from kfac_ferminet_alpha import layers_and_loss_tags as tags
from kfac_ferminet_alpha import utils

ArrayPair = Tuple[jnp.ndarray, jnp.ndarray]
FloatArray = Union[float, jnp.ndarray]
# Index of an entry within one (non-batch) slice of a factor's inner shape.
# Its length depends on the loss (it must equal the rank of the inner shape
# minus one), so this is a variable-length tuple rather than a 1-tuple.
Index = Tuple[int, ...]


class LossFunction(abc.ABC):
  """Abstract base class for loss functions.

  Note that unlike typical loss functions used in neural networks these are
  neither summed nor averaged over the batch and hence the output of evaluate()
  will not be a scalar. It is up to the user to then correctly manipulate them
  as needed.
  """

  def __init__(self, weight: FloatArray):
    self._weight = weight

  @property
  def weight(self) -> FloatArray:
    """Coefficient applied to the loss and to its GGN (see `multiply_ggn`)."""
    return self._weight

  @property
  @abc.abstractmethod
  def targets(self) -> Optional[jnp.ndarray]:
    """The targets being predicted by the model.

    Returns:
      None or Tensor of appropriate shape for calling self._evaluate() on.
    """
    pass

  @property
  @abc.abstractmethod
  def inputs(self) -> Sequence[jnp.ndarray]:
    """The inputs to the loss function (excluding the targets)."""
    pass

  @abc.abstractmethod
  def copy_with_different_inputs(self, inputs: Sequence[jnp.ndarray]):
    """Returns a copy of this loss whose inputs are replaced by `inputs`.

    `grad_of_evaluate` differentiates through this copy, so implementations
    should carry over all other state (e.g. targets and weight) unchanged.

    Args:
      inputs: The new inputs, structured like the `inputs` property.
    """
    pass

  def evaluate(
      self,
      targets: Optional[jnp.ndarray] = None,
      coefficient_mode: str = "regular",
  ) -> jnp.ndarray:
    """Evaluate the loss function on the targets.

    Args:
      targets: The targets to evaluate the loss on. If `None`, the `targets`
        property of the loss is used instead.
      coefficient_mode: How the weight is applied to the result: "regular"
        multiplies by `weight`, "sqrt" by `sqrt(weight)` and "off" applies no
        weight at all.

    Returns:
      `self._evaluate(targets)` scaled according to `coefficient_mode`. As
      noted in the class docstring, this is not reduced over the batch.

    Raises:
      ValueError: If both `targets` and the `targets` property are `None`, or
        if `coefficient_mode` is not one of the recognized values.
    """
    if targets is None and self.targets is None:
      raise ValueError("Cannot evaluate losses with unspecified targets.")
    elif targets is None:
      targets = self.targets

    if coefficient_mode == "regular":
      multiplier = self.weight
    elif coefficient_mode == "sqrt":
      multiplier = jnp.sqrt(self.weight)
    elif coefficient_mode == "off":
      multiplier = 1.0
    else:
      raise ValueError(f"Unrecognized coefficient_mode={coefficient_mode}.")
    return self._evaluate(targets) * multiplier

  @abc.abstractmethod
  def _evaluate(self, targets: jnp.ndarray) -> jnp.ndarray:
    """Evaluates the negative log probability of the targets.

    Args:
      targets: Tensor that distribution can calculate log_prob() of.

    Returns:
      The negative log probability of the targets. As noted in the class
      docstring, the result is neither summed nor averaged over the batch.
    """
    pass

  def grad_of_evaluate(
      self,
      targets: Optional[jnp.ndarray],
      coefficient_mode: str,
  ) -> Sequence[jnp.ndarray]:
    """Evaluates the gradient of the loss function.

    The (unreduced) loss is summed over all of its entries before being
    differentiated with respect to the inputs.

    Args:
      targets: The potential targets on which to evaluate the gradient. If
        `None`, the targets stored on the loss are used, and those must then
        not be `None`.
      coefficient_mode: The coefficient mode to use for evaluation.

    Returns:
      The gradient of the summed loss with respect to the inputs, with the
      same structure as the `inputs` property.
    """
    def evaluate_sum(inputs: Sequence[jnp.ndarray]) -> jnp.ndarray:
      # Rebuild the loss around `inputs` so jax.grad can trace through it.
      instance = self.copy_with_different_inputs(inputs)
      return jnp.sum(instance.evaluate(targets, coefficient_mode))

    return jax.grad(evaluate_sum)(self.inputs)

  def multiply_ggn(self, vector: jnp.ndarray) -> jnp.ndarray:
    """Right-multiply a vector by the GGN.

    Here the 'GGN' is the GGN matrix (whose definition is slightly flexible)
    of the loss function with respect to its inputs.

    Args:
      vector: The vector to multiply. Must be the same shape(s) as the 'inputs'
        property.

    Returns:
      The vector right-multiplied by the GGN. Will be of the same shape(s)
      as the 'inputs' property.
    """
    return utils.scalar_mul(self.multiply_ggn_unweighted(vector), self.weight)

  @abc.abstractmethod
  def multiply_ggn_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray:
    """Same as `multiply_ggn`, but without taking into account the weight."""
    pass

  def multiply_ggn_factor(self, vector: jnp.ndarray) -> jnp.ndarray:
    """Right-multiply a vector by a factor B of the GGN.

    Here the 'GGN' is the GGN matrix (whose definition is slightly flexible)
    of the loss function with respect to its inputs. Typically this will be
    block-diagonal across different cases in the batch, since the loss function
    is typically summed across cases.

    Note that B can be any matrix satisfying B * B^T = G where G is the GGN,
    but will agree with the one used in the other methods of this class.

    Args:
      vector: The vector to multiply. Must be of the shape given by the
        'ggn_factor_inner_shape' property.

    Returns:
      The vector right-multiplied by B. Will be of the same shape(s) as the
      'inputs' property.
    """
    # B is scaled by sqrt(weight) so that B * B^T carries the full weight.
    return utils.scalar_mul(
        self.multiply_ggn_factor_unweighted(vector), jnp.sqrt(self.weight))

  @abc.abstractmethod
  def multiply_ggn_factor_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray:
    """Same as `multiply_ggn_factor`, but without taking into account the weight.
    """
    pass

  def multiply_ggn_factor_transpose(self, vector: jnp.ndarray) -> jnp.ndarray:
    """Right-multiply a vector by the transpose of a factor B of the GGN.

    Here the 'GGN' is the GGN matrix (whose definition is slightly flexible)
    of the loss function with respect to its inputs. Typically this will be
    block-diagonal across different cases in the batch, since the loss function
    is typically summed across cases.

    Note that B can be any matrix satisfying B * B^T = G where G is the GGN,
    but will agree with the one used in the other methods of this class.

    Args:
      vector: The vector to multiply. Must be the same shape(s) as the 'inputs'
        property.

    Returns:
      The vector right-multiplied by B^T. Will be of the shape given by the
      'ggn_factor_inner_shape' property.
    """
    return utils.scalar_mul(
        self.multiply_ggn_factor_transpose_unweighted(vector),
        jnp.sqrt(self.weight))

  @abc.abstractmethod
  def multiply_ggn_factor_transpose_unweighted(
      self,
      vector: jnp.ndarray
  ) -> jnp.ndarray:
    """Same as `multiply_ggn_factor_transpose`, but without the weight."""
    pass

  def multiply_ggn_factor_replicated_one_hot(self, index: Index) -> jnp.ndarray:
    """Right-multiply a replicated-one-hot vector by a factor B of the GGN.

    Here the 'GGN' is the GGN matrix (whose definition is slightly flexible)
    of the loss function with respect to its inputs. Typically this will be
    block-diagonal across different cases in the batch, since the loss function
    is typically summed across cases.

    A 'replicated-one-hot' vector means a tensor which, for each slice along the
    batch dimension (assumed to be dimension 0), is 1.0 in the entry
    corresponding to the given index and 0 elsewhere.

    Note that B can be any matrix satisfying B * B^T = G where G is the GGN,
    but will agree with the one used in the other methods of this class.

    Args:
      index: A tuple giving the index of the entry in each slice that is 1.0.
        Note that len(index) must be equal to
        len(self.ggn_factor_inner_shape) - 1, i.e. the rank of the inner shape
        without the batch dimension.

    Returns:
      The vector right-multiplied by B. Will be of the same shape(s) as the
      'inputs' property.
    """
    return utils.scalar_mul(
        self.multiply_ggn_factor_replicated_one_hot_unweighted(index),
        jnp.sqrt(self.weight))

  @abc.abstractmethod
  def multiply_ggn_factor_replicated_one_hot_unweighted(
      self,
      index: Index
  ) -> jnp.ndarray:
    """Same as `multiply_ggn_factor_replicated_one_hot`, but without weight."""
    pass

  @property
  @abc.abstractmethod
  def ggn_factor_inner_shape(self) -> Sequence[int]:
    """The shape of the vector accepted by `multiply_ggn_factor`.

    This is also the shape returned by `multiply_ggn_factor_transpose`.
    """
    pass


class NegativeLogProbLoss(LossFunction):
  """Abstract base class for loss functions that are negative log probs."""

  @property
  def inputs(self):
    return self.params

  @property
  @abc.abstractmethod
  def params(self):
    """Parameters to the underlying distribution."""
    pass

  def multiply_fisher(self, vector: jnp.ndarray) -> jnp.ndarray:
    """Right-multiply a vector by the Fisher.

    Args:
      vector: The vector to multiply. Must be the same shape(s) as the 'inputs'
        property.

    Returns:
      The vector right-multiplied by the Fisher. Will be of the same shape(s)
      as the 'inputs' property.
    """
    return utils.scalar_mul(
        self.multiply_fisher_unweighted(vector), self.weight)

  @abc.abstractmethod
  def multiply_fisher_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray:
    """Same as `multiply_fisher`, but without taking into account the weight."""
    pass

  def multiply_fisher_factor(self, vector: jnp.ndarray) -> jnp.ndarray:
    """Right-multiply a vector by a factor B of the Fisher.

    Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
    product of gradients) with respect to the parameters of the underlying
    probability distribution (whose log-prob defines the loss). Typically this
    will be block-diagonal across different cases in the batch, since the
    distribution is usually (but not always) conditionally iid across different
    cases.

    Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
    but will agree with the one used in the other methods of this class.

    Args:
      vector: The vector to multiply. Must be of the shape given by the
        'fisher_factor_inner_shape' property.

    Returns:
      The vector right-multiplied by B. Will be of the same shape(s) as the
      'inputs' property.
    """
    return utils.scalar_mul(
        self.multiply_fisher_factor_unweighted(vector), jnp.sqrt(self.weight))

  @abc.abstractmethod
  def multiply_fisher_factor_unweighted(
      self,
      vector: jnp.ndarray
  ) -> jnp.ndarray:
    """Same as `multiply_fisher_factor`, but without the weight."""
    pass

  def multiply_fisher_factor_transpose(
      self,
      vector: jnp.ndarray
  ) -> jnp.ndarray:
    """Right-multiply a vector by the transpose of a factor B of the Fisher.

    Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
    product of gradients) with respect to the parameters of the underlying
    probability distribution (whose log-prob defines the loss). Typically this
    will be block-diagonal across different cases in the batch, since the
    distribution is usually (but not always) conditionally iid across different
    cases.

    Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
    but will agree with the one used in the other methods of this class.

    Args:
      vector: The vector to multiply. Must be the same shape(s) as the 'inputs'
        property.

    Returns:
      The vector right-multiplied by B^T. Will be of the shape given by the
      'fisher_factor_inner_shape' property.
    """
    return utils.scalar_mul(
        self.multiply_fisher_factor_transpose_unweighted(vector),
        jnp.sqrt(self.weight))

  @abc.abstractmethod
  def multiply_fisher_factor_transpose_unweighted(
      self,
      vector: jnp.ndarray
  ) -> jnp.ndarray:
    """Same as `multiply_fisher_factor_transpose`, but without the weight."""
    pass

  def multiply_fisher_factor_replicated_one_hot(
      self,
      index: Index
  ) -> jnp.ndarray:
    """Right-multiply a replicated-one-hot vector by a factor B of the Fisher.

    Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
    product of gradients) with respect to the parameters of the underlying
    probability distribution (whose log-prob defines the loss). Typically this
    will be block-diagonal across different cases in the batch, since the
    distribution is usually (but not always) conditionally iid across different
    cases.

    A 'replicated-one-hot' vector means a tensor which, for each slice along the
    batch dimension (assumed to be dimension 0), is 1.0 in the entry
    corresponding to the given index and 0 elsewhere.

    Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
    but will agree with the one used in the other methods of this class.

    Args:
      index: A tuple giving the index of the entry in each slice that is 1.0.
        Note that len(index) must be equal to
        len(self.fisher_factor_inner_shape) - 1, i.e. the rank of the inner
        shape without the batch dimension.

    Returns:
      The vector right-multiplied by B. Will be of the same shape(s) as the
      'inputs' property.
    """
    return utils.scalar_mul(
        self.multiply_fisher_factor_replicated_one_hot_unweighted(index),
        jnp.sqrt(self.weight))

  @abc.abstractmethod
  def multiply_fisher_factor_replicated_one_hot_unweighted(
      self,
      index: Index
  ) -> jnp.ndarray:
    """Same as `multiply_fisher_factor_replicated_one_hot`, without weight."""
    pass

  @property
  @abc.abstractmethod
  def fisher_factor_inner_shape(self) -> Sequence[int]:
    """The shape of the vector accepted by `multiply_fisher_factor`.

    This is also the shape returned by `multiply_fisher_factor_transpose`.
    """
    pass

  @abc.abstractmethod
  def sample(self, rng_key: jnp.ndarray) -> jnp.ndarray:
    """Sample 'targets' from the underlying distribution."""
    pass

  def grad_of_evaluate_on_sample(
      self,
      rng_key: jnp.ndarray,
      coefficient_mode: str,
  ) -> Sequence[jnp.ndarray]:
    """Evaluates the gradient of the loss on targets sampled from the model.

    The loss here is the negative log probability of the sampled targets.

    Args:
      rng_key: Jax PRNG key for sampling.
      coefficient_mode: The coefficient mode to use for evaluation.

    Returns:
      The gradient, with respect to the inputs, of the negative log
      probability of targets sampled from the distribution.
    """
    return self.grad_of_evaluate(self.sample(rng_key), coefficient_mode)


class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss, abc.ABC):
  """Base class for neg log prob losses whose inputs are 'natural' parameters.

  We will take the GGN of the loss to be the Fisher associated with the
  distribution, which also happens to be equal to the Hessian for this class
  of loss functions. See here: https://arxiv.org/abs/1412.1193

  'Natural parameters' are defined for exponential-family models. See for
  example: https://en.wikipedia.org/wiki/Exponential_family
  """

  # For natural parameters GGN == Fisher, so every GGN operation simply
  # delegates to its Fisher counterpart.

  def multiply_ggn_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray:
    return self.multiply_fisher_unweighted(vector)

  def multiply_ggn_factor_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray:
    return self.multiply_fisher_factor_unweighted(vector)

  def multiply_ggn_factor_transpose_unweighted(
      self,
      vector: jnp.ndarray
  ) -> jnp.ndarray:
    return self.multiply_fisher_factor_transpose_unweighted(vector)

  def multiply_ggn_factor_replicated_one_hot_unweighted(
      self,
      index: Index
  ) -> jnp.ndarray:
    return self.multiply_fisher_factor_replicated_one_hot_unweighted(index)

  @property
  def ggn_factor_inner_shape(self) -> Sequence[int]:
    return self.fisher_factor_inner_shape


class DistributionNegativeLogProbLoss(NegativeLogProbLoss):
  """Base class for neg log prob losses that use the distribution classes."""

  @property
  @abc.abstractmethod
  def dist(self):
    """The underlying distribution instance."""
    pass

  def _evaluate(self, targets: jnp.ndarray):
    return -self.dist.log_prob(targets)

  def sample(self, rng_key: jnp.ndarray):
    return self.dist.sample(seed=rng_key)

  @property
  def fisher_factor_inner_shape(self) -> Sequence[int]:
    return self.dist.mean().shape


class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss,
                                    NaturalParamsNegativeLogProbLoss):
  """Neg log prob loss for a normal distribution parameterized by a mean vector.

  The covariance is fixed to `variance * I`; with the default variance of 0.5
  this is the identity divided by 2. Also note that the Fisher for such a
  normal distribution with respect to the mean parameter is given by:

     F = (1 / variance) * I

  See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf.
  """

  def __init__(
      self,
      mean: jnp.ndarray,
      targets: Optional[jnp.ndarray] = None,
      variance: float = 0.5,
      weight: float = 1.0,
  ):
    # Validate eagerly, before storing any state on the instance.
    if not isinstance(variance, float):
      raise ValueError("The `variance` argument should be python float.")
    super().__init__(weight=weight)
    self._mean = mean
    self._targets = targets
    self._variance = variance

  @property
  def targets(self) -> Optional[jnp.ndarray]:
    return self._targets

  @property
  def dist(self):
    # Diagonal scale sqrt(variance) gives covariance variance * I.
    scale_diag = jnp.full_like(self._mean, jnp.sqrt(self._variance))
    return distributions.MultivariateNormalDiag(self._mean, scale_diag)

  @property
  def params(self):
    return self._mean,

  def copy_with_different_inputs(self, inputs: Sequence[jnp.ndarray]):
    [mean] = inputs
    return NormalMeanNegativeLogProbLoss(
        mean=mean,
        targets=self.targets,
        variance=self._variance,
        weight=self.weight,
    )

  def multiply_fisher_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray:
    return vector / self._variance

  def multiply_fisher_factor_unweighted(
      self,
      vector: jnp.ndarray,
  ) -> jnp.ndarray:
    # B = I / sqrt(variance), so that B * B^T = I / variance = F.
    return vector / jnp.sqrt(self._variance)

  def multiply_fisher_factor_transpose_unweighted(
      self,
      vector: jnp.ndarray,
  ) -> jnp.ndarray:
    return self.multiply_fisher_factor_unweighted(vector)  # it's symmetric

  def multiply_fisher_factor_replicated_one_hot_unweighted(
      self,
      index: Index,
  ) -> jnp.ndarray:
    # The mean is indexed as (batch, dim) below, so exactly one index entry is
    # expected.
    assert len(index) == 1, f"Length of index was {len(index)}."
    position = index[0]
    # Column `position` of B = I / sqrt(variance), replicated for every case
    # in the batch: 1 / sqrt(variance) at `position`, zeros elsewhere.
    ones_slice = jnp.ones([self._mean.shape[0]])[..., None]
    output_slice = ones_slice / jnp.sqrt(self._variance)
    return insert_slice_in_zeros(output_slice, 1, self._mean.shape[1], position)


def insert_slice_in_zeros(
    slice_to_insert: jnp.ndarray,
    dim: int,
    dim_size: int,
    position: int,
) -> jnp.ndarray:
  """Inserts slice into a larger tensor of zeros.

  Forms a new tensor which is the same shape as slice_to_insert, except that
  the dimension given by 'dim' is expanded to the size given by 'dim_size'.
  'position' determines the position (index) at which to insert the slice
  within that dimension.

  Assumes slice_to_insert.shape[dim] = 1.

  Args:
    slice_to_insert: The slice to insert.
    dim: The dimension which to expand with zeros.
    dim_size: The new size of the 'dim' dimension.
    position: The position of 'slice_to_insert' in the new tensor. Must be a
      concrete integer in [0, dim_size).

  Returns:
    The new tensor.

  Raises:
    ValueError: If the slice's shape at the given dim is not 1, or if
      'position' lies outside [0, dim_size).
  """
  slice_shape = slice_to_insert.shape
  if slice_shape[dim] != 1:
    raise ValueError(f"Expected slice_to_insert.shape to have {dim} dim of 1,"
                     f" but was {slice_to_insert.shape[dim]}.")
  if not 0 <= position < dim_size:
    raise ValueError(f"Expected position to be in [0, {dim_size}), but was "
                     f"{position}.")

  # Pad only along `dim`: `position` zeros before the slice and the remaining
  # `dim_size - position - 1` zeros after it.
  before = [0] * len(slice_shape)
  after = list(before)
  before[dim] = position
  after[dim] = dim_size - position - 1
  return jnp.pad(slice_to_insert, list(zip(before, after)))


# -----------------------------------------------------------------------------
#                               Tag Registration
# -----------------------------------------------------------------------------


NormalMeanNegativeLogProbLoss_tag = tags.LossTag(
    NormalMeanNegativeLogProbLoss, num_inputs=1)


def register_normal_predictive_distribution(
    mean: jnp.ndarray,
    targets: Optional[jnp.ndarray] = None,
    variance: float = 0.5,
    weight: float = 1.0,
):
  """Registers a normal predictive distribution.

  This corresponds to a squared error loss of the form
     weight/(2*var) * ||target - mean||^2

  Args:
    mean: A tensor defining the mean vector of the distribution. The first
      dimension must be the batch size.
    targets: (OPTIONAL) The targets for the loss function.  Only required if
      one wants to use the "empirical Fisher" instead of the true Fisher
      (which is controlled by the 'estimation_mode' to the optimizer).
      (Default: None)
    variance: float. The variance of the distribution. Note that the default
      value of 0.5 corresponds to a standard squared error loss weight *
      ||target - prediction||^2. If you want your squared error loss to be of
      the form 0.5*coeff*||target - prediction||^2 you should use
      variance=1.0.
      (Default: 0.5)
    weight: A scalar coefficient to multiply the log prob loss associated with
      this distribution. The Fisher will be multiplied by the corresponding
      factor. In general this is NOT equivalent to changing the temperature
      of the distribution, but in the case of normal distributions it may be.
      (Default: 1.0)

  Returns:
    The mean and targets, as returned by binding them to the loss tag (so that
    downstream computations depend on the tag).
  """
  if targets is None:
    # Bind a zero placeholder so the tag always receives a targets argument.
    targets = jnp.zeros_like(mean)
  return NormalMeanNegativeLogProbLoss_tag.bind(
      mean, targets, variance=variance, weight=weight, return_loss=False)


def register_squared_error_loss(
    prediction: jnp.ndarray,
    targets: Optional[jnp.ndarray] = None,
    weight: float = 1.0,
):
  """Registers a squared error loss function.

  This assumes the squared error loss of the form ||target - prediction||^2,
  averaged across the mini-batch. If your loss uses a coefficient of 0.5
  you need to set the "weight" argument to reflect this.

  Args:
    prediction: The prediction made by the network (i.e. its output). The first
      dimension must be the batch size.
    targets: (OPTIONAL) The targets for the loss function.  Only required if
      one wants to use the "empirical Fisher" instead of the true Fisher
      (which is controlled by the 'estimation_mode' to the optimizer).
      (Default: None)
    weight: A float coefficient to multiply the loss function by.
      (Default: 1.0)

  Returns:
    The mean and targets, as returned by binding them to the loss tag (so that
    downstream computations depend on the tag).
  """
  # variance=0.5 makes the normal neg-log-prob equal ||target - prediction||^2
  # (up to a constant), per `register_normal_predictive_distribution`.
  return register_normal_predictive_distribution(
      prediction, targets=targets, variance=0.5, weight=weight)