From 8c5c348da1cd71ec0727f0d7867806cb57dfe46a Mon Sep 17 00:00:00 2001
From: Luis Pineda <4759586+luisenp@users.noreply.github.com>
Date: Fri, 2 Dec 2022 21:36:03 -0800
Subject: [PATCH] Add caching for diag(AtA).

---
 theseus/optimizer/sparse_linearization.py | 23 +++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/theseus/optimizer/sparse_linearization.py b/theseus/optimizer/sparse_linearization.py
index da7d015a8..f0a87b922 100644
--- a/theseus/optimizer/sparse_linearization.py
+++ b/theseus/optimizer/sparse_linearization.py
@@ -86,12 +86,18 @@ def __init__(
         # batched data
         self.A_val: torch.Tensor = None
         self.b: torch.Tensor = None
+
         # computed lazily by self._atb_impl() and reset to None by
         # self._linearize_jacobian_impl()
         self._Atb: torch.Tensor = None
 
+        # computed lazily by self.diagonal_scaling() and reset to None by
+        # self._linearize_jacobian_impl()
+        self._AtA_diag: torch.Tensor = None
+
     def _linearize_jacobian_impl(self):
         self._Atb = None
+        self._AtA_diag = None
 
         # those will be fully overwritten, no need to zero:
         self.A_val = torch.empty(
@@ -173,11 +179,12 @@ def Av(self, v: torch.Tensor) -> torch.Tensor:
     def diagonal_scaling(self, v: torch.Tensor) -> torch.Tensor:
         assert v.ndim == 2
         assert v.shape[1] == self.num_cols
-        A_val = self.A_val
-        diag = torch.zeros(A_val.shape[0], self.num_cols)
-        for row in range(self.num_rows):
-            start = self.A_row_ptr[row]
-            end = self.A_row_ptr[row + 1]
-            columns = self.A_col_ind[start:end]
-            diag[:, columns] += A_val[:, start:end] ** 2
-        return diag * v
+        if self._AtA_diag is None:
+            A_val = self.A_val
+            self._AtA_diag = torch.zeros(A_val.shape[0], self.num_cols)
+            for row in range(self.num_rows):
+                start = self.A_row_ptr[row]
+                end = self.A_row_ptr[row + 1]
+                columns = self.A_col_ind[start:end]
+                self._AtA_diag[:, columns] += A_val[:, start:end] ** 2
+        return self._AtA_diag * v
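
A minimal standalone sketch of the cached quantity, assuming the batched CSR layout the class stores (A_val of shape (batch, nnz) sharing one sparsity pattern, A_row_ptr, A_col_ind, num_cols). The helper name diag_AtA_from_csr and the toy matrix are illustrative only and not part of the patch; the loop mirrors diagonal_scaling() above.

import torch


def diag_AtA_from_csr(
    A_val: torch.Tensor,      # (batch, nnz) values sharing one CSR pattern
    A_row_ptr: torch.Tensor,  # (num_rows + 1,) row pointers
    A_col_ind: torch.Tensor,  # (nnz,) column indices
    num_cols: int,
) -> torch.Tensor:
    # Same accumulation as diagonal_scaling(): for each row, add the squares
    # of its stored values into the columns they occupy, which gives
    # (A^T A)[j, j] = sum_i A[i, j]^2 for every batch element.
    num_rows = A_row_ptr.numel() - 1
    diag = torch.zeros(A_val.shape[0], num_cols, dtype=A_val.dtype)
    for row in range(num_rows):
        start = int(A_row_ptr[row])
        end = int(A_row_ptr[row + 1])
        columns = A_col_ind[start:end]
        diag[:, columns] += A_val[:, start:end] ** 2
    return diag


if __name__ == "__main__":
    # Toy 2x3 matrix [[1, 0, 2], [0, 3, 4]] (batch of one), in CSR form.
    A_row_ptr = torch.tensor([0, 2, 4])
    A_col_ind = torch.tensor([0, 2, 1, 2])
    A_val = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
    diag = diag_AtA_from_csr(A_val, A_row_ptr, A_col_ind, num_cols=3)

    # Dense cross-check of diag(A^T A): expected [1, 9, 20].
    A = torch.tensor([[1.0, 0.0, 2.0], [0.0, 3.0, 4.0]])
    assert torch.allclose(diag[0], torch.diag(A.T @ A))
    print(diag)

With the patch, this quantity is stored in self._AtA_diag: it is computed on the first diagonal_scaling() call after a linearization and invalidated by resetting it to None in _linearize_jacobian_impl(), so repeated scaling calls against the same Jacobian reuse the cached diagonal instead of re-running the CSR loop.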