Skip to content

Commit

Permalink
mod: inference step
Browse files Browse the repository at this point in the history
  • Loading branch information
kelvin34501 committed Oct 11, 2024
1 parent 0b0f07a commit 6b1becc
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 12 deletions.
45 changes: 34 additions & 11 deletions lib/models/POEM.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,26 +298,36 @@ def _forward_impl(self, batch, **kwargs):
ref_joints.append(ref_joints_sub)
ref_joints = torch.cat(ref_joints, dim=0)

gt_J3d = batch["master_joints_3d"].reshape(-1, 21, 3)
gt_V3d = batch["master_verts_3d"].reshape(-1, 778, 3)
gt_mesh = torch.cat([gt_J3d, gt_V3d], dim=1) # (B, 799, 3)
if mode != "inference":
gt_J3d = batch["master_joints_3d"].reshape(-1, 21, 3)
gt_V3d = batch["master_verts_3d"].reshape(-1, 778, 3)
gt_mesh = torch.cat([gt_J3d, gt_V3d], dim=1) # (B, 799, 3)

# prepare image_metas
img_metas = {
"inp_img_shape": inp_img_shape, # h, w
"cam_intr": batch["target_cam_intr"].reshape(-1, 3, 3), # tensor (BN, 3, 3)
"cam_extr": batch["target_cam_extr"].reshape(-1, 4, 4), # tensor (BN, 4, 4)
"master_id": batch["master_id"], # lst (B, )
"ref_mesh_gt": gt_mesh,
# "ref_mesh_gt": gt_mesh,
"cam_view_num": batch["cam_view_num"]
}

debug_metas = {"img": batch["image"], "2d_joints_gt": batch["target_joints_2d"]}

preds = self.ptEmb_head(mlvl_feat=mlvl_feat,
img_metas=img_metas,
reference_joints=ref_joints,
debug_metas=debug_metas)
if mode != "inference":
img_metas["ref_mesh_gt"] = gt_mesh
img_metas["master_joints_3d"] = gt_J3d
img_metas["master_verts_3d"] = gt_V3d

debug_metas = {"img": batch["image"], "2d_joints_gt": batch["target_joints_2d"]}

extra_kwargs = {}
if mode != "inference":
extra_kwargs["debug_metas"] = debug_metas
preds = self.ptEmb_head(
mlvl_feat=mlvl_feat,
img_metas=img_metas,
reference_joints=ref_joints,
**extra_kwargs,
)

# last decoder's output
pred_joints_3d = preds["all_coords_preds"][-1, :, :self.num_joints, :] # (B, 21, 3)
Expand Down Expand Up @@ -670,6 +680,19 @@ def format_metric(self, mode="train"):

return " | ".join([str(me) for me in metric_toshow])

def inference_step(self, batch, step_idx, **kwargs):
# img = batch["image"] # (BN, 3, H, W) 4 channels
# batch_size = len(batch["cam_view_num"])

preds = self._forward_impl(batch, mode="inference", **kwargs)

if "callback" in kwargs:
callback = kwargs.pop("callback")
if callable(callback):
callback(preds, batch, step_idx, **kwargs)

return preds

def forward(self, inputs, step_idx, mode="train", **kwargs):
if mode == "train":
return self.training_step(inputs, step_idx, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion lib/models/heads/ptEmb_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def forward(self, mlvl_feat, img_metas, reference_points, template_mesh, **kwarg
batch_size, num_cams = x.size(0), x.size(1)

inp_img_w, inp_img_h = img_metas["inp_img_shape"] # (256, 256)
ref_mesh_gt = img_metas["ref_mesh_gt"]
# ref_mesh_gt = img_metas["ref_mesh_gt"]
inp_res = torch.Tensor([inp_img_w, inp_img_h]).to(x.device).float()
masks = x.new_zeros((batch_size, num_cams, inp_img_h, inp_img_w))
x = self.input_proj(x.flatten(0, 1))
Expand Down

0 comments on commit 6b1becc

Please sign in to comment.