Skip to content

Commit

Permalink
Add caching for diag(AtA).
Browse files Browse the repository at this point in the history
  • Loading branch information
luisenp committed Dec 5, 2022
1 parent 83b19b0 commit 8c5c348
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions theseus/optimizer/sparse_linearization.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,18 @@ def __init__(
# batched data
self.A_val: torch.Tensor = None
self.b: torch.Tensor = None

# computed lazily by self._atb_impl() and reset to None by
# self._linearize_jacobian_impl()
self._Atb: torch.Tensor = None

# computed lazily by self.diagonal_scaling() and reset to None by
# self._linearize_jacobian_impl()
self._AtA_diag: torch.Tensor = None

def _linearize_jacobian_impl(self):
self._Atb = None
self._AtA_diag = None

# those will be fully overwritten, no need to zero:
self.A_val = torch.empty(
Expand Down Expand Up @@ -173,11 +179,12 @@ def Av(self, v: torch.Tensor) -> torch.Tensor:
def diagonal_scaling(self, v: torch.Tensor) -> torch.Tensor:
    """Scale `v` elementwise by the diagonal of A^T A.

    The diagonal of A^T A is the vector of squared column norms of A,
    computed batch-wise from the CSR representation
    (`self.A_row_ptr`, `self.A_col_ind`, `self.A_val`). The result is
    cached in `self._AtA_diag`; the cache is reset to None by
    `self._linearize_jacobian_impl()` whenever A changes.

    Args:
        v: batched tensor of shape (batch_size, num_cols).

    Returns:
        Tensor of the same shape as `v`: diag(A^T A) * v (elementwise).
    """
    assert v.ndim == 2
    assert v.shape[1] == self.num_cols
    if self._AtA_diag is None:
        A_val = self.A_val
        # Allocate on A_val's device with A_val's dtype; the default
        # torch.zeros(...) would be float32-on-CPU and break (device
        # mismatch / dtype downcast) for CUDA or float64 batches.
        self._AtA_diag = torch.zeros(
            A_val.shape[0],
            self.num_cols,
            dtype=A_val.dtype,
            device=A_val.device,
        )
        # Accumulate squared entries column-wise, one CSR row at a time.
        # Within a CSR row the column indices are unique, so the fancy-
        # indexed += accumulates correctly.
        for row in range(self.num_rows):
            start = self.A_row_ptr[row]
            end = self.A_row_ptr[row + 1]
            columns = self.A_col_ind[start:end]
            self._AtA_diag[:, columns] += A_val[:, start:end] ** 2
    return self._AtA_diag * v

0 comments on commit 8c5c348

Please sign in to comment.