Skip to content

Commit

Permalink
Merge pull request #11 from calvinmccarter/master
Browse files Browse the repository at this point in the history
Add very basic tests
  • Loading branch information
rfeinman authored Apr 11, 2022
2 parents 72ae847 + 033c818 commit 78be8f3
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 0 deletions.
Empty file added tests/__init__.py
Empty file.
5 changes: 5 additions & 0 deletions tests/test_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def test_import_packages():
"""Test that importing works."""
import torchmin
from torchmin import minimize
from torchmin import Minimizer
Empty file added tests/torchmin/__init__.py
Empty file.
27 changes: 27 additions & 0 deletions tests/torchmin/test_leastsquares.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest
import torch

from torchmin import minimize

torch.manual_seed(42)
N = 100
D = 7
M = 5
X = torch.randn(N, D)
Y = torch.randn(N, M)
trueB = torch.linalg.inv(X.T @ X) @ X.T @ Y
all_methods = [
'bfgs', 'l-bfgs', 'cg', 'newton-cg', 'newton-exact',
'trust-ncg', 'trust-krylov', 'trust-exact', 'dogleg']


@pytest.mark.parametrize('method', all_methods)
def test_minimize(method):
"""Test least-squares problem on unconstrained minimizers."""
B0 = torch.zeros(D, M)

def leastsquares_obj(B):
return torch.sum((Y - X @ B) ** 2)

result = minimize(leastsquares_obj, B0, method=method)
torch.testing.assert_close(trueB, result.x, rtol=1e-4, atol=1e-4)

0 comments on commit 78be8f3

Please sign in to comment.