Skip to content

Commit

Permalink
Merge pull request UM-ARM-Lab#2 from UM-ARM-Lab/add_external_mesh
Browse files Browse the repository at this point in the history
Allow construction with existing mesh
  • Loading branch information
LemonPi authored Apr 17, 2024
2 parents 32899da + 43c3f27 commit 387d907
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
20 changes: 10 additions & 10 deletions src/pytorch_volumetric/sdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ class SDFQuery(NamedTuple):


class ObjectFactory(abc.ABC):
def __init__(self, name, scale=1.0, vis_frame_pos=(0, 0, 0), vis_frame_rot=(0, 0, 0, 1),
plausible_suboptimality=0.001, **kwargs):
def __init__(self, name='', scale=1.0, vis_frame_pos=(0, 0, 0), vis_frame_rot=(0, 0, 0, 1),
plausible_suboptimality=0.001, mesh=None, **kwargs):
self.name = name
self.scale = scale if scale is not None else 1.0
# frame from model's base frame to the simulation's use of the model
Expand All @@ -39,10 +39,11 @@ def __init__(self, name, scale=1.0, vis_frame_pos=(0, 0, 0), vis_frame_rot=(0, 0
self.plausible_suboptimality = plausible_suboptimality

# use external mesh library to compute closest point for non-convex meshes
self._mesh = None
self._mesh = mesh
self._mesht = None
self._raycasting_scene = None
self._face_normals = None
self.precompute_sdf()

def __reduce__(self):
return partial(self.__class__, scale=self.scale, vis_frame_pos=self.vis_frame_pos,
Expand All @@ -67,10 +68,7 @@ def draw_mesh(self, dd, name, pose, rgba, object_id=None):
return dd.draw_mesh(name, self.get_mesh_resource_filename(), pose, scale=self.scale, rgba=rgba,
object_id=object_id, vis_frame_pos=frame_pos, vis_frame_rot=self.vis_frame_rot)

def bounding_box(self, padding=0., padding_ratio=0.):
if self._mesh is None:
self.precompute_sdf()

def bounding_box(self, padding=0.):
aabb = self._mesh.get_axis_aligned_bounding_box()
world_min = aabb.get_min_bound()
world_max = aabb.get_max_bound()
Expand All @@ -88,7 +86,10 @@ def center(self):
return self._mesh.get_center()

def precompute_sdf(self):
if self._mesh is not None:
return
# scale mesh the approrpiate amount

full_path = self.get_mesh_high_poly_resource_filename()
full_path = os.path.expanduser(full_path)
if not os.path.exists(full_path):
Expand All @@ -98,6 +99,7 @@ def precompute_sdf(self):
scale_transform = np.eye(4)
np.fill_diagonal(scale_transform[:3, :3], self.scale)
self._mesh.transform(scale_transform)

# convert from mesh object frame to simulator object frame
x, y, z, w = self.vis_frame_rot
self._mesh = self._mesh.rotate(o3d.geometry.get_rotation_matrix_from_quaternion((w, x, y, z)),
Expand All @@ -112,8 +114,6 @@ def precompute_sdf(self):

@tensor_utils.handle_batch_input(n=2)
def _do_object_frame_closest_point(self, points_in_object_frame, compute_normal=False):
if self._mesh is None:
self.precompute_sdf()

if torch.is_tensor(points_in_object_frame):
dtype = points_in_object_frame.dtype
Expand Down Expand Up @@ -183,7 +183,7 @@ def object_frame_closest_point(self, points_in_object_frame, compute_normal=Fals


class MeshObjectFactory(ObjectFactory):
def __init__(self, mesh_name, path_prefix='', **kwargs):
def __init__(self, mesh_name='', path_prefix='', **kwargs):
self.path_prefix = path_prefix
# whether to strip the package:// prefix from the mesh name, for example if we are loading a mesh manually
# with a path prefix
Expand Down
1 change: 1 addition & 0 deletions src/pytorch_volumetric/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def draw_sdf_slice(s: sdf.ObjectFrameSDF, query_range, resolution=0.01, interior
sdf_grad_uv[::subsample_n, ::subsample_n, shown_dims[0]],
sdf_grad_uv[::subsample_n, ::subsample_n, shown_dims[1]], color='g')
ax.clabel(cset2, cset2.levels, inline=True, fontsize=13, fmt=fmt)
plt.colorbar(cset1)
# fig = plt.gcf()
# fig.canvas.draw()
plt.draw()
Expand Down

0 comments on commit 387d907

Please sign in to comment.