Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

error and errorsquaredNorm optional data #105

Merged
merged 10 commits into from
Mar 14, 2022
Prev Previous commit
unit test changes after review
  • Loading branch information
jeffin07 committed Mar 14, 2022
commit 1da20bc6743649d7b7ee6f867c311e4a9a17f576
45 changes: 29 additions & 16 deletions theseus/core/tests/test_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,6 @@ def _check_error_for_data(v1_data_, v2_data_, error_, error_type):
else:
assert error_.allclose(expected_error.norm(dim=1) ** 2)

# def _check_error_for_data(v1_data_, v2_data_, error_, error_norm_2_):
# expected_error = torch.cat([v1_data_, v2_data_], dim=1) * w
# assert error_.allclose(expected_error)
# assert error_norm_2_.allclose(expected_error.norm(dim=1) ** 2)

def _check_variables(objective, input_data, v1_data, v2_data, also_update):

if also_update:
Expand All @@ -273,7 +268,7 @@ def _check_variables(objective, input_data, v1_data, v2_data, also_update):
assert objective.optim_vars["v1"].data is v1_data
assert objective.optim_vars["v2"].data is v2_data

def _check_error(
def _check_error_and_variables(
v1_data_, v2_data_, error_, error_type, objective, input_data, also_update
):

Expand Down Expand Up @@ -313,9 +308,10 @@ def _check_error(
v2_data_new = torch.ones(batch_size, dof) * f2 * 0.1

input_data = {"v1": v1_data_new, "v2": v2_data_new}

error = objective.error(input_data=input_data, also_update=False)

_check_error(
_check_error_and_variables(
v1_data_new,
v2_data_new,
error,
Expand All @@ -325,11 +321,16 @@ def _check_error(
also_update=False,
)

v1_data_new = torch.ones(batch_size, dof) * f1 * 0.3
v2_data_new = torch.ones(batch_size, dof) * f2 * 0.3

input_data = {"v1": v1_data_new, "v2": v2_data_new}

error_norm_2 = objective.error_squared_norm(
input_data=input_data, also_update=False
jeffin07 marked this conversation as resolved.
Show resolved Hide resolved
)

_check_error(
_check_error_and_variables(
v1_data_new,
v2_data_new,
error_norm_2,
Expand All @@ -339,23 +340,35 @@ def _check_error(
also_update=False,
)

v1_data = torch.ones(batch_size, dof) * f1 * 0.4
v2_data = torch.ones(batch_size, dof) * f2 * 0.4
v1_data_new = torch.ones(batch_size, dof) * f1 * 0.4
v2_data_new = torch.ones(batch_size, dof) * f2 * 0.4

input_data = {"v1": v1_data_new, "v2": v2_data_new}

input_data = {"v1": v1_data, "v2": v2_data}
error = objective.error(input_data=input_data, also_update=True)

_check_error(
v1_data, v2_data, error, "error", objective, input_data, also_update=True
_check_error_and_variables(
v1_data_new,
v2_data_new,
error,
"error",
objective,
input_data,
also_update=True,
)

v1_data_new = torch.ones(batch_size, dof) * f1 * 0.4
v2_data_new = torch.ones(batch_size, dof) * f2 * 0.4

input_data = {"v1": v1_data_new, "v2": v2_data_new}

error_norm_2 = objective.error_squared_norm(
mhmukadam marked this conversation as resolved.
Show resolved Hide resolved
input_data=input_data, also_update=True
)

_check_error(
v1_data,
v2_data,
_check_error_and_variables(
v1_data_new,
v2_data_new,
error_norm_2,
"error_norm_2",
objective,
Expand Down