Merge pull request #23 from calvinmccarter/master
Supports constrained minimization with a constraint function whose gradient is constant
rfeinman authored Mar 14, 2023
2 parents 1017e97 + 9238dac commit e1af11f
Showing 2 changed files with 62 additions and 1 deletion.
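The underlying issue is that for a constraint whose gradient does not depend on x (e.g. the linear constraint lambda x: x.sum() in the new test), autograd returns a constant gradient tensor with no grad_fn, so a second torch.autograd.grad call for the Hessian-vector product would fail instead of producing the correct zero Hessian. A minimal sketch (not part of the commit) illustrating the two cases:

import torch

x = torch.ones(3, requires_grad=True)

# Linear constraint: the gradient is a constant tensor, so grad_fn is None
# and the constraint Hessian is identically zero.
grad, = torch.autograd.grad(x.sum(), x, create_graph=True)
print(grad.grad_fn is None)  # True -> differentiating grad again would fail

# Nonlinear constraint: the gradient still depends on x, so autograd can
# compute Hessian-vector products from it as before.
grad, = torch.autograd.grad(x.square().sum(), x, create_graph=True)
print(grad.grad_fn is None)  # False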
57 changes: 57 additions & 0 deletions tests/torchmin/test_minimize_constr.py
@@ -0,0 +1,57 @@
import pytest
import torch

from torchmin import minimize, minimize_constr
from torchmin.benchmarks import rosen


def test_rosen():
    """Test Rosenbrock problem with constraints."""

    x0 = torch.tensor([1., 8.])
    res = minimize(
        rosen, x0,
        method='l-bfgs',
        options=dict(line_search='strong-wolfe'),
        max_iter=50,
        disp=0
    )


    # Test inactive constraints

    res_constrained_sum = minimize_constr(
        rosen, x0,
        constr=dict(fun=lambda x: x.sum(), ub=10.),
        max_iter=50,
        disp=0
    )
    torch.testing.assert_close(
        res.x, res_constrained_sum.x, rtol=1e-2, atol=1e-2)

    res_constrained_norm = minimize_constr(
        rosen, x0,
        constr=dict(fun=lambda x: x.square().sum(), ub=10.),
        max_iter=50,
        disp=0
    )
    torch.testing.assert_close(
        res.x, res_constrained_norm.x, rtol=1e-2, atol=1e-2)


    # Test active constraints

    res_constrained_sum = minimize_constr(
        rosen, x0,
        constr=dict(fun=lambda x: x.sum(), ub=1.),
        max_iter=50,
        disp=0
    )
    assert res_constrained_sum.x.sum() <= 1.
    res_constrained_norm = minimize_constr(
        rosen, x0,
        constr=dict(fun=lambda x: x.square().sum(), ub=1.),
        max_iter=50,
        disp=0
    )
    assert res_constrained_norm.x.square().sum() <= 1.
6 changes: 5 additions & 1 deletion torchmin/minimize_constr.py
@@ -85,7 +85,11 @@ def matvec(p):
         grad, = torch.autograd.grad(f_(x), x, create_graph=True)
         def matvec(p):
             p = to_tensor(p)
-            hvp, = torch.autograd.grad(grad, x, p, retain_graph=True)
+            if grad.grad_fn is None:
+                # If grad_fn is None, then grad is constant wrt x, and hess is 0.
+                hvp = torch.zeros_like(grad)
+            else:
+                hvp, = torch.autograd.grad(grad, x, p, retain_graph=True)
             return v[0] * hvp.view(-1).cpu().numpy()
         return LinearOperator((numel, numel), matvec=matvec)
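
For context, this matvec backs the SciPy LinearOperator that supplies constraint Hessian-vector products to the trust-constr solver minimize_constr delegates to. Below is a self-contained sketch of the patched pattern; the wrapper name constr_hvp_operator and the scalar multiplier v0 are illustrative, not names from the file:

import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator


def constr_hvp_operator(fun, x, v0):
    # Hessian-vector-product operator for a scalar constraint `fun` at `x`,
    # scaled by the multiplier `v0` (mirrors the patched matvec above).
    x = x.detach().clone().requires_grad_(True)
    grad, = torch.autograd.grad(fun(x), x, create_graph=True)
    numel = x.numel()

    def matvec(p):
        p = torch.as_tensor(p, dtype=x.dtype).view_as(x)
        if grad.grad_fn is None:
            # Constant gradient (e.g. a linear constraint): the Hessian is zero.
            hvp = torch.zeros_like(grad)
        else:
            hvp, = torch.autograd.grad(grad, x, p, retain_graph=True)
        return v0 * hvp.view(-1).cpu().numpy()

    return LinearOperator((numel, numel), matvec=matvec)


# A linear constraint previously raised inside the second autograd.grad call;
# with the zero-Hessian branch it yields an all-zero HVP.
op = constr_hvp_operator(lambda t: t.sum(), torch.tensor([1., 8.]), v0=1.0)
print(op.matvec(np.ones(2)))  # [0. 0.]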

