Commit

use most efficient dot-product method based on variable shape (conjugate gradient)
rfeinman committed Feb 3, 2022
1 parent 424800d commit 94116c8
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions torchmin/newton.py
@@ -17,9 +17,16 @@ def _cg_iters(grad, hess, max_iter, normp=1):
     Derived from Algorithm 7.1 of "Numerical Optimization (2nd Ed.)"
     (Nocedal & Wright, 2006; pp. 169)
     """
-    # generalized dot product that supports batch inputs
-    # TODO: let the user specify dot fn?
-    dot = lambda u,v: u.mul(v).sum(-1, keepdim=True)
+    # Get the most efficient dot product method for this problem
+    if grad.dim() == 1:
+        # standard dot product
+        dot = torch.dot
+    elif grad.dim() == 2:
+        # batched dot product
+        dot = lambda u,v: torch.bmm(u.unsqueeze(1), v.unsqueeze(2)).view(-1,1)
+    else:
+        # generalized dot product that supports batch inputs
+        dot = lambda u,v: u.mul(v).sum(-1, keepdim=True)

     g_norm = grad.norm(p=normp)
     tol = g_norm * g_norm.sqrt().clamp(0, 0.5)
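
The shape-based dispatch in this change can be checked in isolation. The following is a minimal sketch, not part of the commit: `select_dot` and `generic_dot` are hypothetical helper names introduced here for illustration, and the tensor sizes are arbitrary. It assumes only that `torch` is installed, and compares each specialised dot product against the generic `mul`/`sum` form that the commit replaces.

```python
import torch


def select_dot(grad):
    """Hypothetical helper: pick a dot-product function from grad.dim(),
    mirroring the three branches in the diff."""
    if grad.dim() == 1:
        # standard dot product; returns a 0-dim tensor
        return torch.dot
    elif grad.dim() == 2:
        # batched dot product: (B, 1, N) @ (B, N, 1) -> (B, 1, 1) -> (B, 1)
        return lambda u, v: torch.bmm(u.unsqueeze(1), v.unsqueeze(2)).view(-1, 1)
    else:
        # generalized dot product over the last dimension
        return lambda u, v: u.mul(v).sum(-1, keepdim=True)


def generic_dot(u, v):
    """The single code path used before the commit."""
    return u.mul(v).sum(-1, keepdim=True)


torch.manual_seed(0)
# float64 keeps rounding differences between the code paths negligible
u1, v1 = torch.randn(5, dtype=torch.float64), torch.randn(5, dtype=torch.float64)
u2, v2 = torch.randn(3, 5, dtype=torch.float64), torch.randn(3, 5, dtype=torch.float64)
u3, v3 = torch.randn(2, 3, 5, dtype=torch.float64), torch.randn(2, 3, 5, dtype=torch.float64)

d1 = select_dot(u1)(u1, v1)  # 1-D input: torch.dot
d2 = select_dot(u2)(u2, v2)  # 2-D input: torch.bmm
d3 = select_dot(u3)(u3, v3)  # higher-rank input: mul + sum

print(tuple(d1.shape), tuple(d2.shape), tuple(d3.shape))
```

One detail the sketch makes visible: for 1-D input, `torch.dot` returns a 0-dim tensor, whereas the generic form with `keepdim=True` returns a tensor of shape `(1,)`. The values agree, and the two shapes broadcast against each other, but code that relies on the exact output shape would see a difference.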