fix symmetric option in JacobianLinearOperator
rfeinman committed Nov 10, 2022
commit 1017e97 (1 parent: 8f9cfdc)
Showing 1 changed file with 16 additions and 2 deletions.
torchmin/function.py: 18 changes (16 additions, 2 deletions)
@@ -64,6 +64,20 @@ def rmv(self, v: Tensor) -> Tensor:
        return vjp


+def jacobian_linear_operator(x, f, symmetric=False):
+    if symmetric:
+        # Symmetric Jacobian (e.g. a Hessian): the vector-Jacobian product
+        # coincides with the Jacobian-vector product, so no extra graph is
+        # needed here (more efficient).
+        gf = gx = None
+    else:
+        # Apply the "double backwards" trick to get a true
+        # Jacobian-vector product.
+        with torch.enable_grad():
+            gf = torch.zeros_like(f, requires_grad=True)
+            gx = autograd.grad(f, x, gf, create_graph=True)[0]
+    return JacobianLinearOperator(x, f, gf, gx, symmetric)



class ScalarFunction(object):
"""Scalar-valued objective function with autograd backend.
@@ -114,7 +128,7 @@ def closure(self, x):
        hessp = None
        hess = None
        if self._hessp:
-            hessp = JacobianLinearOperator(x, grad, symmetric=self._twice_diffable)
+            hessp = jacobian_linear_operator(x, grad, symmetric=self._twice_diffable)
        if self._hess:
            if self._I is None:
                self._I = torch.eye(x.numel(), dtype=x.dtype, device=x.device)
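
At this call site, grad is the gradient of a scalar objective, so the Jacobian being wrapped is the Hessian. When the objective is twice differentiable the Hessian is symmetric, so H @ v equals H.T @ v and a single VJP stands in for the JVP with no extra graph; that is what symmetric=self._twice_diffable enables. A sketch of the shortcut (the quartic objective is an illustrative assumption):

import torch
from torch import autograd

x = torch.randn(3, requires_grad=True)
loss = (x ** 4).sum()

# Build the gradient with create_graph=True so it can be differentiated again.
g = autograd.grad(loss, x, create_graph=True)[0]

# One VJP through the gradient gives H.T @ v, which equals H @ v
# because the Hessian of a twice-differentiable objective is symmetric.
v = torch.randn(3)
hvp = autograd.grad(g, x, v)[0]

assert torch.allclose(hvp, 12 * x.detach() ** 2 * v)  # H = diag(12 * x**2)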
@@ -169,7 +183,7 @@ def closure(self, x):
        jacp = None
        jac = None
        if self._jacp:
-            jacp = JacobianLinearOperator(x, f)
+            jacp = jacobian_linear_operator(x, f)
        if self._jac:
            if self._I is None:
                self._I = torch.eye(f.numel(), dtype=x.dtype, device=x.device)
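
In this second closure, f is vector-valued, so its Jacobian is generally non-square and non-symmetric and the symmetric shortcut is unavailable: a plain VJP computes J.T @ u, while the solver needs the forward map J @ v, hence the default symmetric=False and the double-backwards setup. A shape-level sketch of the distinction (the residual function is an illustrative assumption):

import torch
from torch import autograd

x = torch.randn(3, requires_grad=True)
f = torch.stack([x[0] * x[1], x[1] - x[2] ** 2])  # 2 outputs, 3 parameters

# J is 2 x 3: a single VJP maps output space to parameter space
# (J.T @ u), so the forward map J @ v needs the double-backwards trick.
gf = torch.zeros_like(f, requires_grad=True)
gx = autograd.grad(f, x, gf, create_graph=True)[0]

v = torch.randn(3)
jv = autograd.grad(gx, gf, v)[0]
assert jv.shape == f.shape  # maps parameter space to output space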
