diff --git a/src/pytorch_volumetric/sdf.py b/src/pytorch_volumetric/sdf.py index c1221eb..230086f 100644 --- a/src/pytorch_volumetric/sdf.py +++ b/src/pytorch_volumetric/sdf.py @@ -28,8 +28,8 @@ class SDFQuery(NamedTuple): class ObjectFactory(abc.ABC): - def __init__(self, name, scale=1.0, vis_frame_pos=(0, 0, 0), vis_frame_rot=(0, 0, 0, 1), - plausible_suboptimality=0.001, **kwargs): + def __init__(self, name='', scale=1.0, vis_frame_pos=(0, 0, 0), vis_frame_rot=(0, 0, 0, 1), + plausible_suboptimality=0.001, mesh=None, **kwargs): self.name = name self.scale = scale if scale is not None else 1.0 # frame from model's base frame to the simulation's use of the model @@ -39,10 +39,11 @@ def __init__(self, name, scale=1.0, vis_frame_pos=(0, 0, 0), vis_frame_rot=(0, 0 self.plausible_suboptimality = plausible_suboptimality # use external mesh library to compute closest point for non-convex meshes - self._mesh = None + self._mesh = mesh self._mesht = None self._raycasting_scene = None self._face_normals = None + self.precompute_sdf() def __reduce__(self): return partial(self.__class__, scale=self.scale, vis_frame_pos=self.vis_frame_pos, @@ -67,10 +68,7 @@ def draw_mesh(self, dd, name, pose, rgba, object_id=None): return dd.draw_mesh(name, self.get_mesh_resource_filename(), pose, scale=self.scale, rgba=rgba, object_id=object_id, vis_frame_pos=frame_pos, vis_frame_rot=self.vis_frame_rot) - def bounding_box(self, padding=0., padding_ratio=0.): - if self._mesh is None: - self.precompute_sdf() - + def bounding_box(self, padding=0.): aabb = self._mesh.get_axis_aligned_bounding_box() world_min = aabb.get_min_bound() world_max = aabb.get_max_bound() @@ -88,7 +86,10 @@ def center(self): return self._mesh.get_center() def precompute_sdf(self): + if self._mesh is not None: + return # scale mesh the approrpiate amount + full_path = self.get_mesh_high_poly_resource_filename() full_path = os.path.expanduser(full_path) if not os.path.exists(full_path): @@ -98,6 +99,7 @@ def precompute_sdf(self): scale_transform = np.eye(4) np.fill_diagonal(scale_transform[:3, :3], self.scale) self._mesh.transform(scale_transform) + # convert from mesh object frame to simulator object frame x, y, z, w = self.vis_frame_rot self._mesh = self._mesh.rotate(o3d.geometry.get_rotation_matrix_from_quaternion((w, x, y, z)), @@ -112,8 +114,6 @@ def precompute_sdf(self): @tensor_utils.handle_batch_input(n=2) def _do_object_frame_closest_point(self, points_in_object_frame, compute_normal=False): - if self._mesh is None: - self.precompute_sdf() if torch.is_tensor(points_in_object_frame): dtype = points_in_object_frame.dtype @@ -183,7 +183,7 @@ def object_frame_closest_point(self, points_in_object_frame, compute_normal=Fals class MeshObjectFactory(ObjectFactory): - def __init__(self, mesh_name, path_prefix='', **kwargs): + def __init__(self, mesh_name='', path_prefix='', **kwargs): self.path_prefix = path_prefix # whether to strip the package:// prefix from the mesh name, for example if we are loading a mesh manually # with a path prefix diff --git a/src/pytorch_volumetric/visualization.py b/src/pytorch_volumetric/visualization.py index b3abcb8..3f1f3ad 100644 --- a/src/pytorch_volumetric/visualization.py +++ b/src/pytorch_volumetric/visualization.py @@ -68,6 +68,7 @@ def draw_sdf_slice(s: sdf.ObjectFrameSDF, query_range, resolution=0.01, interior sdf_grad_uv[::subsample_n, ::subsample_n, shown_dims[0]], sdf_grad_uv[::subsample_n, ::subsample_n, shown_dims[1]], color='g') ax.clabel(cset2, cset2.levels, inline=True, fontsize=13, fmt=fmt) + plt.colorbar(cset1) # fig = plt.gcf() # fig.canvas.draw() plt.draw()