Commit

preprocess in init

fan-ziqi committed Dec 12, 2024
1 parent 7f98385 commit a1ed867
Showing 5 changed files with 35 additions and 48 deletions.
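
Summary of the change: the dict-style `clip` configuration is no longer re-resolved on every call to `process_actions`. Each action term now preprocesses it once in `__init__` into a `(num_envs, action_dim, 2)` tensor of per-joint `[min, max]` bounds (defaulting to -inf/+inf), so per-step clipping becomes a single vectorized `torch.clamp`. The sketch below illustrates the idea in isolation; it is a simplified stand-in and not the Isaac Lab code itself: `resolve_clip` approximates `string_utils.resolve_matching_names_values`, and the joint names and limits are made up.

import re

import torch


def resolve_clip(
    clip_cfg: dict[str, tuple[float, float]],
    joint_names: list[str],
    num_envs: int,
    device: str = "cpu",
) -> torch.Tensor:
    """Resolve a {joint-name pattern: (min, max)} dict into a (num_envs, num_joints, 2) tensor."""
    # start unbounded: every joint in every environment gets [-inf, +inf]
    clip = torch.tensor([[-float("inf"), float("inf")]], device=device).expand(num_envs, len(joint_names), 2).clone()
    # overwrite the bounds of joints whose names match a configured pattern
    for pattern, (low, high) in clip_cfg.items():
        indices = [i for i, name in enumerate(joint_names) if re.fullmatch(pattern, name)]
        clip[:, indices] = torch.tensor([low, high], device=device)
    return clip


# made-up joint names and limits, 4 parallel environments
joint_names = ["shoulder_pan", "shoulder_lift", "elbow"]
clip = resolve_clip({"shoulder_.*": (-1.0, 1.0), "elbow": (-0.5, 0.5)}, joint_names, num_envs=4)

actions = torch.randn(4, len(joint_names)) * 2.0
# per-step clipping is now one vectorized clamp instead of a Python loop over joints
processed = torch.clamp(actions, min=clip[:, :, 0], max=clip[:, :, 1])

Doing the name-pattern resolution once at construction time removes the per-step Python loop over matched joints and keeps `process_actions` purely tensor-based. The same pattern is applied to all five action terms below.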
@@ -42,7 +42,7 @@ class BinaryJointAction(ActionTerm):
     """The configuration of the action term."""
     _asset: Articulation
     """The articulation asset on which the action term is applied."""
-    _clip: dict[str, tuple] | None = None
+    _clip: torch.Tensor
     """The clip applied to the input action."""
 
     def __init__(self, cfg: actions_cfg.BinaryJointActionCfg, env: ManagerBasedEnv) -> None:
@@ -85,9 +85,11 @@ def __init__(self, cfg: actions_cfg.BinaryJointActionCfg, env: ManagerBasedEnv)
         self._close_command[index_list] = torch.tensor(value_list, device=self.device)
 
         # parse clip
-        if cfg.clip is not None:
+        if self.cfg.clip is not None:
             if isinstance(cfg.clip, dict):
-                self._clip = cfg.clip
+                self._clip = torch.tensor([[-float('inf'), float('inf')]], device=self.device).expand(self.num_envs, self.action_dim, 2).clone()
+                index_list, _, value_list = string_utils.resolve_matching_names_values(self.cfg.clip, self._joint_names)
+                self._clip[:, index_list] = torch.tensor(value_list, device=self.device)
             else:
                 raise ValueError(f"Unsupported clip type: {type(cfg.clip)}. Supported types are dict.")
 
@@ -123,13 +125,8 @@ def process_actions(self, actions: torch.Tensor):
         binary_mask = actions < 0
         # compute the command
         self._processed_actions = torch.where(binary_mask, self._close_command, self._open_command)
-        # clip actions
-        if self._clip is not None:
-            # resolve the dictionary config
-            index_list, _, value_list = string_utils.resolve_matching_names_values(self._clip, self._joint_names)
-            for index in range(len(index_list)):
-                min_value, max_value = value_list[index]
-                self._processed_actions[:, index_list[index]].clip_(min_value, max_value)
+        if self.cfg.clip is not None:
+            self._processed_actions = torch.clamp(self._processed_actions, min=self._clip[:, :, 0], max=self._clip[:, :, 1])
 
     def reset(self, env_ids: Sequence[int] | None = None) -> None:
         self._raw_actions[env_ids] = 0.0
@@ -50,7 +50,7 @@ class JointAction(ActionTerm):
     """The scaling factor applied to the input action."""
     _offset: torch.Tensor | float
     """The offset applied to the input action."""
-    _clip: dict[str, tuple] | None = None
+    _clip: torch.Tensor
     """The clip applied to the input action."""
 
     def __init__(self, cfg: actions_cfg.JointActionCfg, env: ManagerBasedEnv) -> None:
@@ -97,9 +97,11 @@ def __init__(self, cfg: actions_cfg.JointActionCfg, env: ManagerBasedEnv) -> Non
         else:
             raise ValueError(f"Unsupported offset type: {type(cfg.offset)}. Supported types are float and dict.")
         # parse clip
-        if cfg.clip is not None:
+        if self.cfg.clip is not None:
             if isinstance(cfg.clip, dict):
-                self._clip = cfg.clip
+                self._clip = torch.tensor([[-float('inf'), float('inf')]], device=self.device).expand(self.num_envs, self.action_dim, 2).clone()
+                index_list, _, value_list = string_utils.resolve_matching_names_values(self.cfg.clip, self._joint_names)
+                self._clip[:, index_list] = torch.tensor(value_list, device=self.device)
             else:
                 raise ValueError(f"Unsupported clip type: {type(cfg.clip)}. Supported types are dict.")
 
@@ -129,12 +131,8 @@ def process_actions(self, actions: torch.Tensor):
         # apply the affine transformations
         self._processed_actions = self._raw_actions * self._scale + self._offset
         # clip actions
-        if self._clip is not None:
-            # resolve the dictionary config
-            index_list, _, value_list = string_utils.resolve_matching_names_values(self._clip, self._joint_names)
-            for index in range(len(index_list)):
-                min_value, max_value = value_list[index]
-                self._processed_actions[:, index_list[index]].clip_(min_value, max_value)
+        if self.cfg.clip is not None:
+            self._processed_actions = torch.clamp(self._processed_actions, min=self._clip[:, :, 0], max=self._clip[:, :, 1])
 
     def reset(self, env_ids: Sequence[int] | None = None) -> None:
         self._raw_actions[env_ids] = 0.0
@@ -44,7 +44,7 @@ class JointPositionToLimitsAction(ActionTerm):
     """The articulation asset on which the action term is applied."""
     _scale: torch.Tensor | float
     """The scaling factor applied to the input action."""
-    _clip: dict[str, tuple] | None = None
+    _clip: torch.Tensor
     """The clip applied to the input action."""
 
     def __init__(self, cfg: actions_cfg.JointPositionToLimitsActionCfg, env: ManagerBasedEnv):
@@ -79,9 +79,11 @@ def __init__(self, cfg: actions_cfg.JointPositionToLimitsActionCfg, env: Manager
         else:
             raise ValueError(f"Unsupported scale type: {type(cfg.scale)}. Supported types are float and dict.")
         # parse clip
-        if cfg.clip is not None:
+        if self.cfg.clip is not None:
             if isinstance(cfg.clip, dict):
-                self._clip = cfg.clip
+                self._clip = torch.tensor([[-float('inf'), float('inf')]], device=self.device).expand(self.num_envs, self.action_dim, 2).clone()
+                index_list, _, value_list = string_utils.resolve_matching_names_values(self.cfg.clip, self._joint_names)
+                self._clip[:, index_list] = torch.tensor(value_list, device=self.device)
             else:
                 raise ValueError(f"Unsupported clip type: {type(cfg.clip)}. Supported types are dict.")
 
@@ -110,13 +112,8 @@ def process_actions(self, actions: torch.Tensor):
         self._raw_actions[:] = actions
         # apply affine transformations
         self._processed_actions = self._raw_actions * self._scale
-        # clip actions
-        if self._clip is not None:
-            # resolve the dictionary config
-            index_list, _, value_list = string_utils.resolve_matching_names_values(self._clip, self._joint_names)
-            for index in range(len(index_list)):
-                min_value, max_value = value_list[index]
-                self._processed_actions[:, index_list[index]].clip_(min_value, max_value)
+        if self.cfg.clip is not None:
+            self._processed_actions = torch.clamp(self._processed_actions, min=self._clip[:, :, 0], max=self._clip[:, :, 1])
         # rescale the position targets if configured
         # this is useful when the input actions are in the range [-1, 1]
         if self.cfg.rescale_to_limits:
@@ -60,7 +60,7 @@ class NonHolonomicAction(ActionTerm):
     """The scaling factor applied to the input action. Shape is (1, 2)."""
     _offset: torch.Tensor
     """The offset applied to the input action. Shape is (1, 2)."""
-    _clip: dict[str, tuple] | None = None
+    _clip: torch.Tensor
     """The clip applied to the input action."""
 
     def __init__(self, cfg: actions_cfg.NonHolonomicActionCfg, env: ManagerBasedEnv):
@@ -108,9 +108,11 @@ def __init__(self, cfg: actions_cfg.NonHolonomicActionCfg, env: ManagerBasedEnv)
         self._scale = torch.tensor(self.cfg.scale, device=self.device).unsqueeze(0)
         self._offset = torch.tensor(self.cfg.offset, device=self.device).unsqueeze(0)
         # parse clip
-        if cfg.clip is not None:
+        if self.cfg.clip is not None:
             if isinstance(cfg.clip, dict):
-                self._clip = cfg.clip
+                self._clip = torch.tensor([[-float('inf'), float('inf')]], device=self.device).expand(self.num_envs, self.action_dim, 2).clone()
+                index_list, _, value_list = string_utils.resolve_matching_names_values(self.cfg.clip, self._joint_names)
+                self._clip[:, index_list] = torch.tensor(value_list, device=self.device)
             else:
                 raise ValueError(f"Unsupported clip type: {type(cfg.clip)}. Supported types are dict.")
 
@@ -139,12 +141,8 @@ def process_actions(self, actions):
         self._raw_actions[:] = actions
         self._processed_actions = self.raw_actions * self._scale + self._offset
         # clip actions
-        if self._clip is not None:
-            # resolve the dictionary config
-            index_list, _, value_list = string_utils.resolve_matching_names_values(self._clip, self._joint_names)
-            for index in range(len(index_list)):
-                min_value, max_value = value_list[index]
-                self._processed_actions[:, index_list[index]].clip_(min_value, max_value)
+        if self.cfg.clip is not None:
+            self._processed_actions = torch.clamp(self._processed_actions, min=self._clip[:, :, 0], max=self._clip[:, :, 1])
 
     def apply_actions(self):
         # obtain current heading
@@ -43,7 +43,7 @@ class DifferentialInverseKinematicsAction(ActionTerm):
     """The articulation asset on which the action term is applied."""
     _scale: torch.Tensor
     """The scaling factor applied to the input action. Shape is (1, action_dim)."""
-    _clip: dict[str, tuple] | None = None
+    _clip: torch.Tensor
     """The clip applied to the input action."""
 
     def __init__(self, cfg: actions_cfg.DifferentialInverseKinematicsActionCfg, env: ManagerBasedEnv):
@@ -105,9 +105,11 @@ def __init__(self, cfg: actions_cfg.DifferentialInverseKinematicsActionCfg, env:
             self._offset_pos, self._offset_rot = None, None
 
         # parse clip
-        if cfg.clip is not None:
+        if self.cfg.clip is not None:
             if isinstance(cfg.clip, dict):
-                self._clip = cfg.clip
+                self._clip = torch.tensor([[-float('inf'), float('inf')]], device=self.device).expand(self.num_envs, self.action_dim, 2).clone()
+                index_list, _, value_list = string_utils.resolve_matching_names_values(self.cfg.clip, self._joint_names)
+                self._clip[:, index_list] = torch.tensor(value_list, device=self.device)
             else:
                 raise ValueError(f"Unsupported clip type: {type(cfg.clip)}. Supported types are dict.")
 
@@ -135,13 +137,8 @@ def process_actions(self, actions: torch.Tensor):
         # store the raw actions
         self._raw_actions[:] = actions
         self._processed_actions[:] = self.raw_actions * self._scale
-        # clip actions
-        if self._clip is not None:
-            # resolve the dictionary config
-            index_list, _, value_list = string_utils.resolve_matching_names_values(self._clip, self._joint_names)
-            for index in range(len(index_list)):
-                min_value, max_value = value_list[index]
-                self._processed_actions[:, index_list[index]].clip_(min_value, max_value)
+        if self.cfg.clip is not None:
+            self._processed_actions = torch.clamp(self._processed_actions, min=self._clip[:, :, 0], max=self._clip[:, :, 1])
         # obtain quantities from simulation
         ee_pos_curr, ee_quat_curr = self._compute_frame_pose()
         # set command into controller
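
For context, a hypothetical configuration exercising the clip option on one of these terms might look like the following. The joint patterns, limits, and the use of `mdp.JointPositionActionCfg` are illustrative assumptions about a typical Isaac Lab manager-based setup, not part of this commit.

import omni.isaac.lab.envs.mdp as mdp

# Hypothetical example: clip maps joint-name regex patterns to (min, max) bounds.
# With this commit, the bounds are resolved once in the action term's __init__
# and applied every step with a single torch.clamp on the processed actions.
arm_action = mdp.JointPositionActionCfg(
    asset_name="robot",
    joint_names=[".*"],
    scale=0.5,
    clip={".*_shoulder_.*": (-1.5, 1.5), ".*_elbow_joint": (-1.0, 1.0)},
)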