Remove softmax from end to end test and do some clean up. (#389)
luisenp authored Nov 30, 2022
1 parent adb7c70 commit 36b38c9
Showing 1 changed file with 15 additions and 11 deletions.
tests/test_theseus_layer.py
@@ -194,6 +194,7 @@ def _run_optimizer_test(
     verbose=False,
     learning_method="default",
     force_vectorization=False,
+    max_iterations=10,
 ):
     device = "cuda:0" if torch.cuda.is_available() else "cpu"
     print(f"_run_test_for: {device}")
@@ -229,10 +230,12 @@ def _run_optimizer_test(
         linear_solver_cls=linear_solver_cls,
         use_learnable_error=use_learnable_error,
         force_vectorization=force_vectorization,
+        max_iterations=max_iterations,
     )
     layer_ref.to(device)
+    initial_coefficients = torch.ones(batch_size, 2, device=device) * 0.75
     with torch.no_grad():
-        input_values = {"coefficients": torch.ones(batch_size, 2, device=device) * 0.75}
+        input_values = {"coefficients": initial_coefficients}
         target_vars, _ = layer_ref.forward(
             input_values, optimizer_kwargs={**optimizer_kwargs, **{"verbose": verbose}}
         )
@@ -252,14 +255,13 @@ def _run_optimizer_test(
     )

     # Here we create the outer loop models and optimizers for the cost weight
-    if cost_weight_model == "softmax":
-
+    if cost_weight_model == "direct":
         cost_weight_params = nn.Parameter(
             torch.randn(num_points, generator=rng, device=device)
         )

         def cost_weight_fn():
-            return F.softmax(cost_weight_params, dim=0).view(1, -1)
+            return cost_weight_params.clone().view(1, -1)

         optimizer = torch.optim.Adam([cost_weight_params], lr=0.075)

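This hunk is the substantive change: the "softmax" cost-weight model becomes a "direct" one, so the learned weights are no longer squashed onto the probability simplex. A self-contained comparison of the two parameterizations (num_points and the seed are arbitrary):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    num_points = 5
    rng = torch.Generator().manual_seed(0)
    cost_weight_params = nn.Parameter(torch.randn(num_points, generator=rng))

    # Old "softmax" model: weights are positive and sum to one, so every
    # entry is coupled to all the others through the normalization.
    softmax_weights = F.softmax(cost_weight_params, dim=0).view(1, -1)

    # New "direct" model: the raw parameters are used as-is, cloned so the
    # layer input does not alias the leaf nn.Parameter; each weight is
    # unconstrained.
    direct_weights = cost_weight_params.clone().view(1, -1)

Dropping the softmax also removes the need for the LEO skip deleted further down, which existed only because LEO failed, for unknown reasons, with the softmax model.
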
@@ -280,6 +282,7 @@ def cost_weight_fn():
         linear_solver_cls=linear_solver_cls,
         use_learnable_error=use_learnable_error,
         force_vectorization=force_vectorization,
+        max_iterations=max_iterations,
     )
     layer_to_learn.to(device)

@@ -291,7 +294,7 @@ def cost_weight_fn():
         "learnable_err_param" if use_learnable_error else "cost_weight_values"
     )
     input_values = {
-        "coefficients": torch.ones(batch_size, 2, device=device) * 0.75,
+        "coefficients": initial_coefficients,
         cost_weight_param_name: cost_weight_fn(),
     }

@@ -313,7 +316,7 @@ def cost_weight_fn():
     for i in range(200):
         optimizer.zero_grad()
         input_values = {
-            "coefficients": torch.ones(batch_size, 2, device=device) * 0.75,
+            "coefficients": initial_coefficients,
             cost_weight_param_name: cost_weight_fn(),
         }
         pred_vars, info = layer_to_learn.forward(
@@ -374,7 +377,7 @@ def cost_weight_fn():
     [th.CholeskyDenseSolver, th.LUDenseSolver, th.CholmodSparseSolver],
 )
 @pytest.mark.parametrize("use_learnable_error", [True, False])
-@pytest.mark.parametrize("cost_weight_model", ["softmax", "mlp"])
+@pytest.mark.parametrize("cost_weight_model", ["direct", "mlp"])
 @pytest.mark.parametrize("learning_method", ["default", "leo"])
 def test_backward(
     nonlinear_optim_cls,
@@ -383,14 +386,14 @@ def test_backward(
     cost_weight_model,
     learning_method,
 ):
-    optim_kwargs = {} if nonlinear_optim_cls == th.GaussNewton else {"damping": 0.01}
+    optim_kwargs = {
+        th.GaussNewton: {},
+        th.LevenbergMarquardt: {"damping": 0.01},
+    }[nonlinear_optim_cls]
     if learning_method == "leo":
         # CholmodSparseSolver doesn't support sampling from system's covariance
         if lin_solver_cls == th.CholmodSparseSolver:
             return
-        # LEO fails to work in this case, not sure why
-        if cost_weight_model == "softmax":
-            return
     # test both vectorization on/off
     force_vectorization = torch.rand(1).item() > 0.5
     _run_optimizer_test(
@@ -401,6 +404,7 @@ def test_backward(
         use_learnable_error=use_learnable_error,
         force_vectorization=force_vectorization,
         learning_method=learning_method,
+        max_iterations=10,
     )
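
The remaining cleanup swaps a ternary for a dict lookup when choosing per-optimizer keyword arguments. The old expression handed LevenbergMarquardt's damping kwarg to any class that was not GaussNewton; the dict names the supported optimizers explicitly and fails fast on anything else. The same dispatch, written as a sketch with kwargs_for as a hypothetical helper name:

    import theseus as th

    # Same dispatch as the diff, as a hypothetical helper. An optimizer
    # class missing from the table raises KeyError instead of silently
    # inheriting LevenbergMarquardt's damping setting.
    OPTIM_KWARGS = {
        th.GaussNewton: {},
        th.LevenbergMarquardt: {"damping": 0.01},
    }

    def kwargs_for(optim_cls):
        return OPTIM_KWARGS[optim_cls]
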

