
Commit 50105bf

[Feature] Handle more flexible observation modes, support users to check if ground truth state data is requested (#835)

* work

* tests

* update tests, update example push cube code

* w

* refactors

* update docs

* Update intro.md
StoneT2000 authored Feb 6, 2025
1 parent 5d61a2c commit 50105bf
Showing 25 changed files with 332 additions and 151 deletions.
4 changes: 2 additions & 2 deletions docs/source/user_guide/concepts/observation.md
@@ -3,10 +3,10 @@
## Observation mode

**The observation mode defines the observation space.**
All ManiSkill tasks take the observation mode (`obs_mode`) as one of the input arguments of `__init__`.
All ManiSkill tasks take the observation mode (`obs_mode`) as one of the input arguments of `gym.make(env_id, obs_mode=...)`.
In general, the observation is organized as a dictionary (with an observation space of `gym.spaces.Dict`).

There are three raw observations modes: `state_dict` (privileged states), `sensor_data` (raw sensor data like visual data without postprocessing) and `state+sensor_data` for both. `state` is a flat version of `state_dict`. `rgb+depth`, `rgb+depth+segmentation` (or any combination of `rgb`, `depth`, `segmentation`), and `pointcloud` apply post-processing on `sensor_data` to give convenient representations of visual data. `state+rgb` would return privileged states and visual data, you can mix and match the different modalities however you like.
There are three raw observation modes: `state_dict` (privileged states), `sensor_data` (raw sensor data such as visual data without post-processing), and `state+sensor_data` for both. `state` is a flat version of `state_dict`. `rgb+depth`, `rgb+depth+segmentation` (or any combination of `rgb`, `depth`, `segmentation`), and `pointcloud` apply post-processing on `sensor_data` to give convenient representations of visual data. `state_dict+rgb` returns both privileged unflattened states and visual data; you can mix and match the different modalities however you like.

The details here show the unbatched shapes. In general, returned data always has a batch dimension and is returned as torch tensors unless you are using CPU simulation. Moreover, we annotate the dtype of some values.
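A minimal usage sketch of selecting a mode (the environment id `PushCube-v1` and the printed keys are illustrative assumptions; the `obs_mode` argument to `gym.make` is described above):

```python
import gymnasium as gym
import mani_skill.envs  # importing this module registers the ManiSkill environments

# mix privileged unflattened state with RGB camera data
env = gym.make("PushCube-v1", obs_mode="state_dict+rgb")
obs, _ = env.reset(seed=0)
# the observation is a dictionary; the exact keys depend on the task and its sensors
print(obs.keys())
env.close()
```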

8 changes: 4 additions & 4 deletions docs/source/user_guide/tutorials/custom_tasks/intro.md
@@ -260,7 +260,7 @@ By default, the observations returned to users through calls to `env.reset` and

To better support state based observations, you need to augment the observations given to users by implementing the `_get_obs_extra` function. This function takes the `info` object generated via the earlier defined `evaluate` function as input and returns the augmented observation data as a dictionary.

Generally you want to ensure you do not provide any ground-truth information that should not be available unless the observation mode is "state" or "state_dict", such as the pose of the cube you are pushing. There are some data like `self.agent.tcp.pose` which are always available for single-arm robots and given all the time, and also critical information like the goal position to direct the agent where to push the cube.
Generally you want to ensure you do not provide ground-truth information, such as the pose of the cube you are pushing, unless the observation mode the user requests asks for it. Some data, like `self.agent.tcp.pose`, is always available for single-arm robots and is given all the time, as is critical information such as the goal in tasks like PickCube, which tells the agent where to move the cube. To check whether a user is requesting state data, check if `self.obs_mode_struct.use_state` is true; it is true if the `obs_mode` provided when creating an environment includes "state" or "state_dict". More details on observation modes are available in the [separate observation page](../../concepts/observation.md).

```python
class PushCubeEnv(BaseEnv):
@@ -270,12 +270,12 @@ class PushCubeEnv(BaseEnv):
# grippers of the robot
obs = dict(
tcp_pose=self.agent.tcp.pose.raw_pose,
goal_pos=self.goal_region.pose.p,
)
if self._obs_mode in ["state", "state_dict"]:
# if the observation mode is state/state_dict, we provide ground truth information about where the cube is.
if self.obs_mode_struct.use_state:
# if the observation mode requests to use state, we provide ground truth information about where the cube is.
# for visual observation modes one should rely on the sensed visual data to determine where the cube is
obs.update(
goal_pos=self.goal_region.pose.p,
obj_pose=self.obj.pose.raw_pose,
)
return obs
24 changes: 15 additions & 9 deletions mani_skill/envs/sapien_env.py
@@ -20,7 +20,7 @@
from mani_skill.agents.multi_agent import MultiAgent
from mani_skill.envs.scene import ManiSkillScene
from mani_skill.envs.utils.observations import (
parse_visual_obs_mode_to_struct,
parse_obs_mode_to_struct,
sensor_data_to_pointcloud,
)
from mani_skill.envs.utils.randomization.batched_rng import BatchedRNG
@@ -277,7 +277,8 @@ def __init__(
else:
raise NotImplementedError(f"Unsupported obs mode: {obs_mode}. Must be one of {self.SUPPORTED_OBS_MODES}")
self._obs_mode = obs_mode
self._visual_obs_mode_struct = parse_visual_obs_mode_to_struct(self._obs_mode)
self.obs_mode_struct = parse_obs_mode_to_struct(self._obs_mode)
"""dataclass describing what observation data is being requested by the user, detailing if state data is requested and what visual data is requested"""

# Reward mode
if reward_mode is None:
@@ -490,14 +491,19 @@ def get_obs(self, info: Optional[Dict] = None):
elif self._obs_mode == "state_dict":
obs = self._get_obs_state_dict(info)
elif self._obs_mode == "pointcloud":
# TODO support more flexible pcd obs mode with new render system
obs = self._get_obs_with_sensor_data(info)
obs = sensor_data_to_pointcloud(obs, self._sensors)
elif self._obs_mode == "sensor_data":
# return raw texture data dependent on choice of shader
obs = self._get_obs_with_sensor_data(info, apply_texture_transforms=False)
else:
obs = self._get_obs_with_sensor_data(info)

# flatten parts of the state observation if requested
if self.obs_mode_struct.state:
if isinstance(obs, dict):
data = dict(agent=obs.pop("agent"), extra=obs.pop("extra"))
obs["state"] = common.flatten_state_dict(data, use_torch=True, device=self.device)
return obs

def _get_obs_state_dict(self, info: Dict):
@@ -546,12 +552,12 @@ def _get_obs_sensor_data(self, apply_texture_transforms: bool = True) -> dict:
sensor_obs[name] = sensor.get_obs(position=False, segmentation=False, apply_texture_transforms=apply_texture_transforms)
else:
sensor_obs[name] = sensor.get_obs(
rgb=self._visual_obs_mode_struct.rgb,
depth=self._visual_obs_mode_struct.depth,
position=self._visual_obs_mode_struct.position,
segmentation=self._visual_obs_mode_struct.segmentation,
normal=self._visual_obs_mode_struct.normal,
albedo=self._visual_obs_mode_struct.albedo,
rgb=self.obs_mode_struct.visual.rgb,
depth=self.obs_mode_struct.visual.depth,
position=self.obs_mode_struct.visual.position,
segmentation=self.obs_mode_struct.visual.segmentation,
normal=self.obs_mode_struct.visual.normal,
albedo=self.obs_mode_struct.visual.albedo,
apply_texture_transforms=apply_texture_transforms
)
# explicitly synchronize and wait for cuda kernels to finish
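To illustrate what the flattening logic in `get_obs` above means for users, a hedged sketch (the environment id and the assumption that only `agent` and `extra` are folded into `state` are inferred from this hunk):

```python
import gymnasium as gym
import torch
import mani_skill.envs  # assumed import that registers the ManiSkill environments

# "state" (rather than "state_dict") requests a flat privileged state vector
env = gym.make("PushCube-v1", obs_mode="state+rgb")
obs, _ = env.reset(seed=0)

# per get_obs, the "agent" and "extra" entries are popped and flattened into obs["state"]
assert isinstance(obs["state"], torch.Tensor)
assert "agent" not in obs and "extra" not in obs
env.close()
```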
2 changes: 1 addition & 1 deletion mani_skill/envs/tasks/control/ant.py
@@ -245,7 +245,7 @@ def evaluate(self) -> Dict:

def _get_obs_extra(self, info: Dict):
obs = super()._get_obs_extra(info)
if self._obs_mode in ["state", "state_dict"]:
if self.obs_mode_struct.use_state:
obs.update(
cmass=info["cmass_linvel"],
link_angvels=info["link_angvels"],
@@ -219,7 +219,7 @@ def _initialize_agent(self, env_idx: torch.Tensor):
def _get_obs_extra(self, info: Dict):
with torch.device(self.device):
obs = dict(rotate_dir=self.rot_dir)
if self._obs_mode in ["state", "state_dict"]:
if self.obs_mode_struct.use_state:
obs.update(
obj_pose=vectorize_pose(self.obj.pose),
obj_tip_vec=info["obj_tip_vec"].view(self.num_envs, 12),
2 changes: 1 addition & 1 deletion mani_skill/envs/tasks/dexterity/rotate_valve.py
@@ -194,7 +194,7 @@ def _get_obs_extra(self, info: Dict):
valve_x=torch.cos(valve_qpos[:, 0]),
valve_y=torch.sin(valve_qpos[:, 0]),
)
if self._obs_mode in ["state", "state_dict"]:
if self.obs_mode_struct.use_state:
obs.update(
valve_pose=vectorize_pose(self.valve.pose),
)
6 changes: 5 additions & 1 deletion mani_skill/envs/tasks/digital_twins/base_env.py
@@ -140,7 +140,11 @@ def get_obs(self, info: dict = None):
obs = super().get_obs(info)

# "greenscreen" process
if self._obs_mode == "rgb+segmentation" and self.rgb_overlay_paths is not None:
if (
self.obs_mode_struct.visual.rgb
and self.obs_mode_struct.visual.segmentation
and self.rgb_overlay_paths is not None
):
# get the actor ids of objects to manipulate; note that objects here are not articulated
for camera_name in self._rgb_overlay_images.keys():
# obtain overlay mask based on segmentation info
2 changes: 1 addition & 1 deletion mani_skill/envs/tasks/fmb/fmb.py
@@ -180,7 +180,7 @@ def evaluate(self):

def _get_obs_extra(self, info: Dict):
obs = dict(tcp_pose=self.agent.tcp.pose.raw_pose)
if self.obs_mode in ["state", "state_dict"]:
if self.obs_mode_struct.use_state:
obs.update(
board_pos=self.board.pose.p,
bridge_pose=self.bridge.pose.raw_pose,
2 changes: 1 addition & 1 deletion mani_skill/envs/tasks/quadruped/quadruped_reach.py
@@ -116,7 +116,7 @@ def _get_obs_extra(self, info: Dict):
root_angular_velocity=self.agent.robot.root_angular_velocity,
reached_goal=info["success"],
)
if self.obs_mode in ["state", "state_dict"]:
if self.obs_mode_struct.use_state:
obs.update(
goal_pos=self.goal.pose.p[:, :2],
robot_to_goal=self.goal.pose.p[:, :2] - self.agent.robot.pose.p[:, :2],
2 changes: 1 addition & 1 deletion mani_skill/envs/tasks/rotate_cube.py
@@ -255,7 +255,7 @@ def _get_obs_extra(self, info: Dict):
goal_pos=self.obj_goal.pose.p,
goal_q=self.obj_goal.pose.q,
)
if self._obs_mode in ["state", "state_dict"]:
if self.obs_mode_struct.use_state:
obs.update(
obj_p=self.obj.pose.p,
obj_q=self.obj.pose.q,
2 changes: 1 addition & 1 deletion mani_skill/envs/tasks/tabletop/assembling_kits.py
@@ -282,7 +282,7 @@ def _get_obs_extra(self, info: Dict):
obs = dict(
tcp_pose=self.agent.tcp.pose.raw_pose,
)
if self._obs_mode in ["state", "state_dict"]:
if self.obs_mode_struct.use_state:
obs.update(
obj_pose=self.obj.pose.raw_pose,
tcp_to_obj_pos=self.obj.pose.p - self.agent.tcp.pose.p,
2 changes: 1 addition & 1 deletion mani_skill/envs/tasks/tabletop/lift_peg_upright.py
@@ -102,7 +102,7 @@ def _get_obs_extra(self, info: Dict):
obs = dict(
tcp_pose=self.agent.tcp.pose.raw_pose,
)
if self._obs_mode in ["state", "state_dict"]:
if self.obs_mode_struct.use_state:
obs.update(
obj_pose=self.peg.pose.raw_pose,
)
2 changes: 1 addition & 1 deletion mani_skill/envs/tasks/tabletop/peg_insertion_side.py
@@ -288,7 +288,7 @@ def evaluate(self):

def _get_obs_extra(self, info: Dict):
obs = dict(tcp_pose=self.agent.tcp.pose.raw_pose)
if self._obs_mode in ["state", "state_dict"]:
if self.obs_mode_struct.use_state:
obs.update(
peg_pose=self.peg.pose.raw_pose,
peg_half_size=self.peg_half_sizes,
2 changes: 1 addition & 1 deletion mani_skill/envs/tasks/tabletop/plug_charger.py
@@ -268,7 +268,7 @@ def evaluate(self):

def _get_obs_extra(self, info: Dict):
obs = dict(tcp_pose=self.agent.tcp.pose.raw_pose)
if self._obs_mode in ["state", "state_dict"]:
if self.obs_mode_struct.use_state:
obs.update(
charger_pose=self.charger.pose.raw_pose,
receptacle_pose=self.receptacle.pose.raw_pose,
2 changes: 1 addition & 1 deletion mani_skill/envs/tasks/tabletop/poke_cube.py
@@ -145,7 +145,7 @@ def _get_obs_extra(self, info: Dict):
tcp_pose=self.agent.tcp.pose.raw_pose,
)

if self._obs_mode in ["state", "state_dict"]:
if self.obs_mode_struct.use_state:
obs.update(
cube_pose=self.cube.pose.raw_pose,
peg_pose=self.peg.pose.raw_pose,
4 changes: 2 additions & 2 deletions mani_skill/envs/tasks/tabletop/pull_cube.py
@@ -119,7 +119,7 @@ def _get_obs_extra(self, info: Dict):
tcp_pose=self.agent.tcp.pose.raw_pose,
goal_pos=self.goal_region.pose.p,
)
if self._obs_mode in ["state", "state_dict"]:
if self.obs_mode_struct.use_state:
obs.update(
obj_pose=self.obj.pose.raw_pose,
)
@@ -148,4 +148,4 @@ def compute_dense_reward(self, obs: Any, action: Array, info: Dict):

def compute_normalized_dense_reward(self, obs: Any, action: Array, info: Dict):
max_reward = 3.0
return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward
return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward
2 changes: 1 addition & 1 deletion mani_skill/envs/tasks/tabletop/pull_cube_tool.py
@@ -181,7 +181,7 @@ def _get_obs_extra(self, info: Dict):
tcp_pose=self.agent.tcp.pose.raw_pose,
)

if self._obs_mode in ["state", "state_dict"]:
if self.obs_mode_struct.use_state:
obs.update(
cube_pose=self.cube.pose.raw_pose,
tool_pose=self.l_shape_tool.pose.raw_pose,
4 changes: 2 additions & 2 deletions mani_skill/envs/tasks/tabletop/push_cube.py
@@ -197,8 +197,8 @@ def _get_obs_extra(self, info: Dict):
obs = dict(
tcp_pose=self.agent.tcp.pose.raw_pose,
)
if self._obs_mode in ["state", "state_dict"]:
# if the observation mode is state/state_dict, we provide ground truth information about where the cube is.
if self.obs_mode_struct.use_state:
# if the observation mode requests to use state, we provide ground truth information about where the cube is.
# for visual observation modes one should rely on the sensed visual data to determine where the cube is
obs.update(
goal_pos=self.goal_region.pose.p,
2 changes: 1 addition & 1 deletion mani_skill/envs/tasks/tabletop/push_t.py
@@ -494,7 +494,7 @@ def _get_obs_extra(self, info: Dict):
obs = dict(
tcp_pose=self.agent.tcp.pose.raw_pose,
)
if self._obs_mode in ["state", "state_dict"]:
if self.obs_mode_struct.use_state:
# state based gets info on goal position and t full pose - necessary to learn task
obs.update(
goal_pos=self.goal_tee.pose.p,
2 changes: 1 addition & 1 deletion mani_skill/envs/tasks/tabletop/roll_ball.py
@@ -140,7 +140,7 @@ def _get_obs_extra(self, info: Dict):
obs = dict(
tcp_pose=self.agent.tcp.pose.raw_pose,
)
if self._obs_mode in ["state", "state_dict"]:
if self.obs_mode_struct.use_state:
obs.update(
goal_pos=self.goal_region.pose.p,
ball_pose=self.ball.pose.raw_pose,
106 changes: 71 additions & 35 deletions mani_skill/envs/utils/observations/__init__.py
@@ -13,57 +13,93 @@ class CameraObsTextures:
albedo: bool


ALL_TEXTURES = ["rgb", "depth", "segmentation", "position", "normal", "albedo"]
@dataclass
class ObservationModeStruct:
"""A dataclass describing what observation data is being requested by the user"""

state_dict: bool
"""whether to include state data which generally means including privileged information such as object poses"""
state: bool
"""whether to include flattened state data which generally means including privileged information such as object poses"""
visual: CameraObsTextures
"""textures to capture from cameras"""

@property
def use_state(self):
"""whether or not the environment should return ground truth/privileged information such as object poses"""
return self.state or self.state_dict


ALL_VISUAL_TEXTURES = ["rgb", "depth", "segmentation", "position", "normal", "albedo"]
"""set of all standard textures that can come from cameras"""


def parse_visual_obs_mode_to_struct(obs_mode: str) -> CameraObsTextures:
def parse_obs_mode_to_struct(obs_mode: str) -> ObservationModeStruct:
"""Given user supplied observation mode, return a struct with the relevant textures that are to be captured"""
# parse obs mode into a string of possible textures
if obs_mode == "rgbd":
return CameraObsTextures(
rgb=True,
depth=True,
segmentation=False,
position=False,
normal=False,
albedo=False,
return ObservationModeStruct(
state_dict=False,
state=False,
visual=CameraObsTextures(
rgb=True,
depth=True,
segmentation=False,
position=False,
normal=False,
albedo=False,
),
)
elif obs_mode == "pointcloud":
return CameraObsTextures(
rgb=True,
depth=False,
segmentation=True,
position=True,
normal=False,
albedo=False,
return ObservationModeStruct(
state_dict=False,
state=False,
visual=CameraObsTextures(
rgb=True,
depth=False,
segmentation=True,
position=True,
normal=False,
albedo=False,
),
)
elif obs_mode == "sensor_data":
return CameraObsTextures(
rgb=True,
depth=True,
segmentation=True,
position=True,
normal=False,
albedo=False,
return ObservationModeStruct(
state_dict=False,
state=False,
visual=CameraObsTextures(
rgb=True,
depth=True,
segmentation=True,
position=True,
normal=False,
albedo=False,
),
)
elif obs_mode in ["state", "state_dict", "none"]:
return None
else:
# Parse obs mode into individual texture types
textures = obs_mode.split("+")
if "pointcloud" in textures:
textures.remove("pointcloud")
textures.append("position")
textures.append("rgb")
textures.append("segmentation")
for texture in textures:
if texture == "state" or texture == "state_dict":
if texture == "state" or texture == "state_dict" or texture == "none":
# allows fetching privileged state data in addition to visual data.
continue
assert (
texture in ALL_TEXTURES
), f"Invalid texture type '{texture}' requested in the obs mode '{obs_mode}'. Each individual texture must be one of {ALL_TEXTURES}"
return CameraObsTextures(
rgb="rgb" in textures,
depth="depth" in textures,
segmentation="segmentation" in textures,
position="position" in textures,
normal="normal" in textures,
albedo="albedo" in textures,
texture in ALL_VISUAL_TEXTURES
), f"Invalid texture type '{texture}' requested in the obs mode '{obs_mode}'. Each individual texture must be one of {ALL_VISUAL_TEXTURES}"
return ObservationModeStruct(
state_dict="state_dict" in textures,
state="state" in textures,
visual=CameraObsTextures(
rgb="rgb" in textures,
depth="depth" in textures,
segmentation="segmentation" in textures,
position="position" in textures,
normal="normal" in textures,
albedo="albedo" in textures,
),
)
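
A short usage sketch of the new parser; the import path follows the `sapien_env.py` change above and the printed values follow the branches in this file:

```python
from mani_skill.envs.utils.observations import parse_obs_mode_to_struct

struct = parse_obs_mode_to_struct("state+rgb")
print(struct.use_state)     # True: "state" was requested alongside visual data
print(struct.state_dict)    # False
print(struct.visual.rgb)    # True
print(struct.visual.depth)  # False

struct = parse_obs_mode_to_struct("rgbd")
print(struct.use_state)     # False: no privileged state requested
print(struct.visual.depth)  # True
```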
