Skip to content

Commit

Permalink
minor refactor and improved code comments
Browse files Browse the repository at this point in the history
  • Loading branch information
rfeinman committed Mar 10, 2022
1 parent 35c311b commit 83602fa
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions torchmin/newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ def _minimize_newton_exact(
elif handle_npd == 'grad':
d = g.neg()
elif handle_npd == 'cauchy':
# cauchy point for a trust radius of delta=1.
# equivalent to 'grad' method with scaled lr
gnorm = g.norm(p=2)
scale = 1 / gnorm
gHg = g.dot(hess.mv(g))
Expand All @@ -327,13 +329,12 @@ def _minimize_newton_exact(
d = scale * g.neg()
elif handle_npd == 'eig':
# this setting is experimental! use with caution
# TODO: why chose the factor 1.5 here? Seems to work best
eig0 = eigsh(hess.cpu().numpy(), k=1, which="SA", tol=1e-4,
return_eigenvectors=False).item()
# TODO: why use the factor 1.5 here? Seems to work best
eig0 = eigsh(hess.cpu().numpy(), k=1, which="SA", tol=1e-4)[0].item()
tau = max(1e-3 - 1.5 * eig0, 0)
hess.diagonal().add_(tau)
d = torch.cholesky_solve(g.neg().unsqueeze(1),
torch.linalg.cholesky(hess)).squeeze(1)
L = torch.linalg.cholesky(hess)
d = torch.cholesky_solve(g.neg().unsqueeze(1), L).squeeze(1)
else:
raise RuntimeError('invalid handle_npd encountered.')

Expand Down

0 comments on commit 83602fa

Please sign in to comment.