Skip to content

Commit

Permalink
Refactor sensor test
Browse files Browse the repository at this point in the history
  • Loading branch information
hang-yin committed Oct 4, 2024
1 parent 89a45af commit 22b482c
Showing 1 changed file with 31 additions and 49 deletions.
80 changes: 31 additions & 49 deletions tests/test_sensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@ def test_segmentation_modalities(env):
robot = env.scene.robots[0]
place_obj_on_floor_plane(breakfast_table)
dishtowel.set_position_orientation(position=[-0.4, 0.0, 0.55], orientation=[0, 0, 0, 1])
robot.set_position_orientation(
position=[0.0, 0.8, 0.0], orientation=T.euler2quat(th.tensor([0, 0, -math.pi / 2], dtype=th.float32))
)
robot.reset()

og.sim.viewer_camera.set_position_orientation(position=[-0.0017, -0.1072, 1.4969], orientation=[0.0, 0.0, 0.0, 1.0])

modalities_required = ["seg_semantic", "seg_instance", "seg_instance_id"]
for modality in modalities_required:
robot.add_obs_modality(modality)
og.sim.viewer_camera.add_modality(modality)

systems = [env.scene.get_system(system_name) for system_name in SYSTEM_EXAMPLES.keys()]
for i, system in enumerate(systems):
Expand All @@ -37,17 +35,14 @@ def test_segmentation_modalities(env):
system.generate_group_particles(
group=system.get_group_name(breakfast_table),
positions=[pos, pos + th.tensor([0.1, 0.0, 0.0])],
link_prim_paths=[breakfast_table.root_link.prim_path],
link_prim_paths=[breakfast_table.root_link.prim_path] * 2,
)

og.sim.step()
for _ in range(3):
og.sim.render()

sensors = [s for s in robot.sensors.values() if isinstance(s, VisionSensor)]
assert len(sensors) > 0
vision_sensor = sensors[0]
all_observation, all_info = vision_sensor.get_obs()
all_observation, all_info = og.sim.viewer_camera.get_obs()

seg_semantic = all_observation["seg_semantic"]
seg_semantic_info = all_info["seg_semantic"]
Expand All @@ -57,7 +52,6 @@ def test_segmentation_modalities(env):
825831922: "floors",
884110082: "stain",
1949122937: "breakfast_table",
2814990211: "agent",
3051938632: "white_rice",
3330677804: "water",
4207839377: "dishtowel",
Expand All @@ -68,38 +62,31 @@ def test_segmentation_modalities(env):
seg_instance_info = all_info["seg_instance"]
assert set(int(x.item()) for x in th.unique(seg_instance)) == set(seg_instance_info.keys())
expected_dict = {
1: "unlabelled",
2: env.robots[0].name,
3: "groundPlane",
4: "dishtowel",
5: "breakfast_table",
6: "stain",
# 7: "water",
# 8: "white_rice",
9: "diced__apple",
2: "groundPlane",
3: "water",
4: "diced__apple",
5: "stain",
6: "white_rice",
7: "breakfast_table",
8: "dishtowel",
}
assert set(seg_instance_info.values()) == set(expected_dict.values())

seg_instance_id = all_observation["seg_instance_id"]
seg_instance_id_info = all_info["seg_instance_id"]
assert set(int(x.item()) for x in th.unique(seg_instance_id)) == set(seg_instance_id_info.keys())
expected_dict = {
3: f"/World/{env.robots[0].name}/gripper_link/visuals",
4: f"/World/{env.robots[0].name}/wrist_roll_link/visuals",
5: f"/World/{env.robots[0].name}/forearm_roll_link/visuals",
6: f"/World/{env.robots[0].name}/wrist_flex_link/visuals",
8: "/World/groundPlane/geom",
9: "/World/dishtowel/base_link_cloth",
10: f"/World/{env.robots[0].name}/r_gripper_finger_link/visuals",
11: f"/World/{env.robots[0].name}/l_gripper_finger_link/visuals",
12: "/World/breakfast_table/base_link/visuals",
13: "stain",
14: "white_rice",
15: "diced__apple",
16: "water",
1: "/World/ground_plane/geom",
2: "/World/scene_0/breakfast_table/base_link/visuals",
3: "/World/scene_0/dishtowel/base_link_cloth",
4: "/World/scene_0/water/waterInstancer0/prototype0",
5: "/World/scene_0/white_rice/white_riceInstancer0/prototype0",
6: "/World/scene_0/diced__apple/particles/diced__appleParticle1",
7: "/World/scene_0/breakfast_table/base_link/stainParticle1",
8: "/World/scene_0/breakfast_table/base_link/stainParticle0",
9: "/World/scene_0/diced__apple/particles/diced__appleParticle0",
}
# Temporarily disable this test because og_assets are outdated on CI machines
# assert set(seg_instance_id_info.values()) == set(expected_dict.values())
assert set(seg_instance_id_info.values()) == set(expected_dict.values())

for system in systems:
env.scene.clear_system(system.name)
Expand All @@ -112,34 +99,29 @@ def test_bbox_modalities(env):
robot = env.scene.robots[0]
place_obj_on_floor_plane(breakfast_table)
dishtowel.set_position_orientation(position=[-0.4, 0.0, 0.55], orientation=[0, 0, 0, 1])
robot.set_position_orientation(
position=[0, 0.8, 0.0], orientation=T.euler2quat(th.tensor([0, 0, -math.pi / 2], dtype=th.float32))
)
robot.reset()

og.sim.viewer_camera.set_position_orientation(position=[-0.0017, -0.1072, 1.4969], orientation=[0.0, 0.0, 0.0, 1.0])

modalities_required = ["bbox_2d_tight", "bbox_2d_loose", "bbox_3d"]
for modality in modalities_required:
robot.add_obs_modality(modality)
og.sim.viewer_camera.add_modality(modality)

og.sim.step()
for _ in range(3):
og.sim.render()

sensors = [s for s in robot.sensors.values() if isinstance(s, VisionSensor)]
assert len(sensors) > 0
vision_sensor = sensors[0]
all_observation, all_info = vision_sensor.get_obs()
all_observation, all_info = og.sim.viewer_camera.get_obs()

bbox_2d_tight = all_observation["bbox_2d_tight"]
bbox_2d_loose = all_observation["bbox_2d_loose"]
bbox_3d = all_observation["bbox_3d"]

assert len(bbox_2d_tight) == 4
assert len(bbox_2d_loose) == 4
assert len(bbox_3d) == 3
assert len(bbox_2d_tight) == 3
assert len(bbox_2d_loose) == 3
assert len(bbox_3d) == 2

bbox_2d_expected_objs = set(["floors", "agent", "breakfast_table", "dishtowel"])
bbox_3d_expected_objs = set(["agent", "breakfast_table", "dishtowel"])
bbox_2d_expected_objs = set(["floors", "breakfast_table", "dishtowel"])
bbox_3d_expected_objs = set(["breakfast_table", "dishtowel"])

bbox_2d_objs = set([semantic_class_id_to_name()[bbox[0]] for bbox in bbox_2d_tight])
bbox_3d_objs = set([semantic_class_id_to_name()[bbox[0]] for bbox in bbox_3d])
Expand Down

0 comments on commit 22b482c

Please sign in to comment.