Added batching support to tactile pushing example. (facebookresearch#132)

* Added batching support to tactile pushing example.

* Changed outer loss so that it uses SE2.local().

* Added more options for the inner optimizer.

* Added options to control backward mode and step size.

* Added results logging.

* Added cfg values for lr and early stopping.
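
For context, the reworked outer loss (second bullet) measures the tracking error in the SE(2) tangent space rather than as an MSE over raw x-y-theta values. A minimal sketch, mirroring the change in examples/tactile_pose_estimation.py below; the function name is illustrative only:

import torch
import theseus as th

def se2_tracking_loss(obj_poses_opt: torch.Tensor, obj_poses_gt: torch.Tensor) -> torch.Tensor:
    # Both inputs are assumed to be x-y-theta tensors reshapeable to (N, 3).
    se2_opt = th.SE2(x_y_theta=obj_poses_opt.view(-1, 3))
    se2_gt = th.SE2(x_y_theta=obj_poses_gt.view(-1, 3))
    # SE2.local() gives the tangent-space difference between the two pose batches.
    return se2_opt.local(se2_gt).norm()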
luisenp authored Mar 29, 2022
1 parent 1a2d2df commit 29646ad
Showing 5 changed files with 249 additions and 81 deletions.
23 changes: 16 additions & 7 deletions examples/configs/tactile_pose_estimation.yaml
@@ -1,10 +1,20 @@
seed: 0
save_all: true

dataset_name: "rectangle-pushing-corners-keypoints"
sdf_name: "rect"

episode: 0
episode_length: 100
max_steps: 100
max_episodes: 1

inner_optim:
  max_iters: 3
  optimizer: GaussNewton
  reg_w: 1e-4
  backward_mode: IMPLICIT
  step_size: 1.0
  keep_step_size: true

# 0: disc, 1: rect-edges, 2: rect-corners, 3: ellip
class_label: 2
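
A minimal sketch of how the new inner_optim block above is consumed by the example (it mirrors the Python changes later in this commit); cfg stands in for the Hydra config loaded from this file:

import theseus as th

def build_inner_optim(cfg):
    # Resolve the optimizer class and backward mode by name, as the example does.
    optimizer_cls = getattr(th, cfg.inner_optim.optimizer)  # e.g. th.GaussNewton
    optimizer_kwargs = {
        "verbose": True,
        "track_err_history": True,
        "backward_mode": getattr(th.BackwardMode, cfg.inner_optim.backward_mode),
        "__keep_final_step_size__": cfg.inner_optim.keep_step_size,
    }
    return optimizer_cls, optimizer_kwargs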
@@ -24,13 +34,12 @@ tactile_cost:

train:
  # options: "weights_only" or "weights_and_measurement_nn"
  mode: "weights_only"

  batch_size: 1
  num_batches: 1
  mode: "weights_and_measurement_nn"

  num_epochs: 100
  eps_tracking_loss: 1e-5
  batch_size: 4
  num_epochs: 50
  lr: 1e-3 # 5.0 for weights_only
  eps_tracking_loss: 1e-10

options:
  vis_traj: True
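
Similarly, a small sketch of how train.lr and train.eps_tracking_loss are used by the outer loop (mirroring the Python file below); learnable_params and cfg are assumed to come from theg.create_tactile_models and the Hydra config:

import numpy as np
import torch.optim as optim

def make_outer_optim(learnable_params, cfg):
    # Outer-loop Adam optimizer over the learnable cost-function parameters.
    return optim.Adam(learnable_params, lr=cfg.train.lr)

def should_stop(epoch_losses, cfg):
    # Early stopping once the mean tracking loss over an epoch is small enough.
    return np.mean(epoch_losses) < cfg.train.eps_tracking_loss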
150 changes: 117 additions & 33 deletions examples/tactile_pose_estimation.py
@@ -2,20 +2,26 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import pathlib
import random
import time
from typing import Dict

import hydra
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import theseus as th
import theseus.utils.examples as theg

# Logger
logger = logging.getLogger(__name__)

# To run this example, you will need a tactile pushing dataset available at
# https://dl.fbaipublicfiles.com/theseus/tactile_pushing_data.tar.gz
#
@@ -61,10 +67,41 @@
# 2021 (https://arxiv.org/abs/1705.10664)


def pack_batch_results(
    theseus_outputs: Dict[str, torch.Tensor],
    qsp_state_dict: torch.Tensor,
    mfb_state_dict: torch.Tensor,
    meas_state_dict: torch.Tensor,
    info: th.optimizer.OptimizerInfo,
    loss_value: float,
    total_time: float,
) -> Dict:
    def _clone(t_):
        return t_.detach().cpu().clone()

    return {
        "theseus_outputs": dict((s, _clone(t)) for s, t in theseus_outputs.items()),
        "qsp_state_dict": qsp_state_dict,
        "mfb_state_dict": mfb_state_dict,
        "meas_state_dict": meas_state_dict,
        "err_history": info.err_history,  # type: ignore
        "loss": loss_value,
        "total_time": total_time,
    }


def run_learning_loop(cfg):
    root_path = pathlib.Path(os.getcwd())
    dataset_path = EXP_PATH / "datasets" / f"{cfg.dataset_name}.json"
    sdf_path = EXP_PATH / "sdfs" / f"{cfg.sdf_name}.json"
    dataset = theg.TactilePushingDataset(dataset_path, sdf_path, cfg.episode, device)
    dataset = theg.TactilePushingDataset(
        dataset_path,
        sdf_path,
        cfg.episode_length,
        cfg.train.batch_size,
        cfg.max_episodes,
        device,
    )

    # -------------------------------------------------------------------- #
    # Create pose estimator (which wraps a TheseusLayer)
@@ -77,6 +114,10 @@ def run_learning_loop(cfg):
        step_window_moving_frame=cfg.tactile_cost.step_win_mf,
        rectangle_shape=(cfg.shape.rect_len_x, cfg.shape.rect_len_y),
        device=device,
        optimizer_cls=getattr(th, cfg.inner_optim.optimizer),
        max_iterations=cfg.inner_optim.max_iters,
        step_size=cfg.inner_optim.step_size,
        regularization_w=cfg.inner_optim.reg_w,
    )
    time_steps = pose_estimator.time_steps

@@ -94,29 +135,33 @@ def run_learning_loop(cfg):
        qsp_model,
        mf_between_model,
        learnable_params,
        hyperparameters,
    ) = theg.create_tactile_models(
        cfg.train.mode, device, measurements_model_path=measurements_model_path
    )
    eps_tracking_loss = hyperparameters.get(
        "eps_tracking_loss", cfg.train.eps_tracking_loss
    )
    outer_optim = optim.Adam(learnable_params, lr=hyperparameters["learning_rate"])
    eps_tracking_loss = cfg.train.eps_tracking_loss
    outer_optim = optim.Adam(learnable_params, lr=cfg.train.lr)

    # -------------------------------------------------------------------- #
    # Main learning loop
    # -------------------------------------------------------------------- #
    # Use theseus_layer in an outer learning loop to learn different cost
    # function parameters:
    measurements = dataset.get_measurements(
        cfg.train.batch_size, cfg.train.num_batches, time_steps
    )
    obj_poses_gt = dataset.obj_poses[0:time_steps, :].clone().requires_grad_(True)
    eff_poses_gt = dataset.eff_poses[0:time_steps, :].clone().requires_grad_(True)
    theseus_inputs = {}
    for _ in range(cfg.train.num_epochs):
    measurements = dataset.get_measurements(time_steps)
    results = {}
    for epoch in range(cfg.train.num_epochs):
        results[epoch] = {}
logger.info(" ********************* EPOCH {epoch} *********************")
        losses = []
        image_idx = 0
        for batch_idx, batch in enumerate(measurements):
            pose_and_motion_batch = dataset.get_start_pose_and_motion_for_batch(
                batch_idx, time_steps
            )  # x_y_theta format
            pose_estimator.update_start_pose_and_motion_from_batch(
                pose_and_motion_batch
            )
            theseus_inputs = {}
            # Updates the above with measurement factor data
            theg.update_tactile_pushing_inputs(
                dataset=dataset,
                batch=batch,
@@ -129,28 +174,49 @@ def run_learning_loop(cfg):
                theseus_inputs=theseus_inputs,
            )

            theseus_inputs, _ = pose_estimator.forward(
                theseus_inputs, optimizer_kwargs={"verbose": True}
            start_time = time.time_ns()
            theseus_outputs, info = pose_estimator.forward(
                theseus_inputs,
                optimizer_kwargs={
                    "verbose": True,
                    "track_err_history": True,
                    "backward_mode": getattr(
                        th.BackwardMode, cfg.inner_optim.backward_mode
                    ),
                    "__keep_final_step_size__": cfg.inner_optim.keep_step_size,
                },
            )
            end_time = time.time_ns()

            obj_poses_opt, eff_poses_opt = theg.get_tactile_poses_from_values(
                values=theseus_inputs, time_steps=time_steps
                values=theseus_outputs, time_steps=time_steps
            )
            obj_poses_gt, eff_poses_gt = dataset.get_gt_data_for_batch(
                batch_idx, time_steps
            )

            loss = F.mse_loss(obj_poses_opt[batch_idx, :], obj_poses_gt)
            se2_opt = th.SE2(x_y_theta=obj_poses_opt.view(-1, 3))
            se2_gt = th.SE2(x_y_theta=obj_poses_gt.view(-1, 3))
            loss = se2_opt.local(se2_gt).norm()
            loss.backward()

            nn.utils.clip_grad_norm_(qsp_model.parameters(), 100, norm_type=2)
            nn.utils.clip_grad_norm_(mf_between_model.parameters(), 100, norm_type=2)
            nn.utils.clip_grad_norm_(measurements_model.parameters(), 100, norm_type=2)

            with torch.no_grad():
                for name, param in qsp_model.named_parameters():
                    print(name, param.data)
                    logger.info(f"{name} {param.data}")
                for name, param in mf_between_model.named_parameters():
                    print(name, param.data)
                    logger.info(f"{name} {param.data}")

            print(" grad qsp", qsp_model.param.grad.norm().item())
            print(" grad mfb", mf_between_model.param.grad.norm().item())
            def _print_grad(msg_, param_):
                logger.info(f"{msg_} {param_.grad.norm().item()}")

            _print_grad(" grad qsp", qsp_model.param)
            _print_grad(" grad mfb", mf_between_model.param)
            _print_grad(" grad nn_weight", measurements_model.fc1.weight)
            _print_grad(" grad nn_bias", measurements_model.fc1.bias)

            outer_optim.step()

@@ -162,17 +228,35 @@

            losses.append(loss.item())

            if cfg.options.vis_traj:
                theg.visualize_tactile_push2d(
                    obj_poses=obj_poses_opt[0, :],
                    eff_poses=eff_poses_opt[0, :],
                    obj_poses_gt=obj_poses_gt,
                    eff_poses_gt=eff_poses_gt,
                    rect_len_x=cfg.shape.rect_len_x,
                    rect_len_y=cfg.shape.rect_len_y,
                )
            if cfg.save_all:
                results[epoch][batch_idx] = pack_batch_results(
                    theseus_outputs,
                    qsp_model.state_dict(),
                    mf_between_model.state_dict(),
                    measurements_model.state_dict(),
                    info,
                    loss.item(),
                    end_time - start_time,
                )
                torch.save(results, root_path / "results.pt")

            if cfg.options.vis_traj:
                for i in range(len(obj_poses_gt)):
                    save_dir = root_path / f"img_{image_idx}"
                    save_dir.mkdir(parents=True, exist_ok=True)
                    save_fname = save_dir / f"epoch{epoch}.png"
                    theg.visualize_tactile_push2d(
                        obj_poses=obj_poses_opt[i],
                        eff_poses=eff_poses_opt[i],
                        obj_poses_gt=obj_poses_gt[i],
                        eff_poses_gt=eff_poses_gt[i],
                        rect_len_x=cfg.shape.rect_len_x,
                        rect_len_y=cfg.shape.rect_len_y,
                        save_fname=save_fname,
                    )
                    image_idx += 1

        print(f"AVG. LOSS: {np.mean(losses)}")
        logger.info(f"AVG. LOSS: {np.mean(losses)}")

        if np.mean(losses) < eps_tracking_loss:
            break
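
The new logging saves a nested dictionary to results.pt each batch; a minimal sketch of reading it back, assuming the layout produced by pack_batch_results above:

import torch

results = torch.load("results.pt")
for epoch, batches in results.items():
    for batch_idx, r in batches.items():
        # total_time was measured with time.time_ns(), so convert to seconds.
        print(epoch, batch_idx, r["loss"], r["total_time"] / 1e9)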
