From 033c818c2e0eea1a21b20ea57adfa97d8a57cec1 Mon Sep 17 00:00:00 2001 From: Calvin McCarter Date: Thu, 7 Apr 2022 23:14:47 -0400 Subject: [PATCH] very basic tests --- tests/__init__.py | 0 tests/test_imports.py | 5 +++++ tests/torchmin/__init__.py | 0 tests/torchmin/test_leastsquares.py | 27 +++++++++++++++++++++++++++ 4 files changed, 32 insertions(+) create mode 100644 tests/__init__.py create mode 100644 tests/test_imports.py create mode 100644 tests/torchmin/__init__.py create mode 100644 tests/torchmin/test_leastsquares.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_imports.py b/tests/test_imports.py new file mode 100644 index 0000000..e1bde14 --- /dev/null +++ b/tests/test_imports.py @@ -0,0 +1,5 @@ +def test_import_packages(): + """Test that importing works.""" + import torchmin + from torchmin import minimize + from torchmin import Minimizer diff --git a/tests/torchmin/__init__.py b/tests/torchmin/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/torchmin/test_leastsquares.py b/tests/torchmin/test_leastsquares.py new file mode 100644 index 0000000..46a1899 --- /dev/null +++ b/tests/torchmin/test_leastsquares.py @@ -0,0 +1,27 @@ +import pytest +import torch + +from torchmin import minimize + +torch.manual_seed(42) +N = 100 +D = 7 +M = 5 +X = torch.randn(N, D) +Y = torch.randn(N, M) +trueB = torch.linalg.inv(X.T @ X) @ X.T @ Y +all_methods = [ + 'bfgs', 'l-bfgs', 'cg', 'newton-cg', 'newton-exact', + 'trust-ncg', 'trust-krylov', 'trust-exact', 'dogleg'] + + +@pytest.mark.parametrize('method', all_methods) +def test_minimize(method): + """Test least-squares problem on unconstrained minimizers.""" + B0 = torch.zeros(D, M) + + def leastsquares_obj(B): + return torch.sum((Y - X @ B) ** 2) + + result = minimize(leastsquares_obj, B0, method=method) + torch.testing.assert_close(trueB, result.x, rtol=1e-4, atol=1e-4)