diff --git a/embodichain/lab/sim/atomic_actions/__init__.py b/embodichain/lab/sim/atomic_actions/__init__.py index f3f92da7..a154cf89 100644 --- a/embodichain/lab/sim/atomic_actions/__init__.py +++ b/embodichain/lab/sim/atomic_actions/__init__.py @@ -23,17 +23,18 @@ from .core import ( Affordance, - GraspPose, InteractionPoints, ObjectSemantics, ActionCfg, AtomicAction, ) from .actions import ( - ReachAction, - GraspAction, - ReleaseAction, MoveAction, + PickUpAction, + PlaceAction, + MoveActionCfg, + PickUpActionCfg, + PlaceActionCfg, ) from .engine import ( AtomicActionEngine, @@ -51,10 +52,12 @@ "ActionCfg", "AtomicAction", # Action implementations - "ReachAction", - "GraspAction", - "ReleaseAction", "MoveAction", + "PickUpAction", + "PlaceAction", + "MoveActionCfg", + "PickUpActionCfg", + "PlaceActionCfg", # Engine "AtomicActionEngine", "register_action", diff --git a/embodichain/lab/sim/atomic_actions/actions.py b/embodichain/lab/sim/atomic_actions/actions.py index d15a7eff..b21c31e9 100644 --- a/embodichain/lab/sim/atomic_actions/actions.py +++ b/embodichain/lab/sim/atomic_actions/actions.py @@ -22,523 +22,622 @@ from embodichain.lab.sim.planners import PlanResult, PlanState, MoveType from embodichain.lab.sim.planners.motion_generator import MotionGenOptions from embodichain.lab.sim.planners.toppra_planner import ToppraPlanOptions -from .core import AtomicAction, ObjectSemantics +from .core import AtomicAction, ObjectSemantics, AntipodalAffordance, ActionCfg +from embodichain.utils import logger +from embodichain.utils import configclass +from embodichain.lab.sim.utility.action_utils import interpolate_with_distance +import numpy as np if TYPE_CHECKING: from embodichain.lab.sim.planners import MotionGenerator from embodichain.lab.sim.objects import Robot -# ============================================================================= -# Reach Action -# ============================================================================= +@configclass +class MoveActionCfg(ActionCfg): + sample_interval: int = 50 + """Number of waypoints to sample for the motion trajectory. Should be large enough to ensure smooth motion, but not too large to cause unnecessary computation overhead.""" -class ReachAction(AtomicAction): - """Atomic action for reaching a target pose or object.""" - +class MoveAction(AtomicAction): def __init__( self, - motion_generator: "MotionGenerator", - robot: "Robot", - control_part: str, - device: torch.device = torch.device("cuda"), - interpolation_type: str = "linear", # "linear", "cubic", "toppra" + motion_generator: MotionGenerator, + cfg: MoveActionCfg | None = None, ): - super().__init__(motion_generator, robot, control_part, device) - self.interpolation_type = interpolation_type - - def execute( - self, - target: Union[torch.Tensor, ObjectSemantics], - start_qpos: Optional[torch.Tensor] = None, - approach_offset: Optional[torch.Tensor] = None, - use_affordance: bool = True, - **kwargs, - ) -> PlanResult: - """Execute reach action. - + """ + Initialize the atomic action. Args: - target: Target pose [4, 4] or ObjectSemantics - start_qpos: Starting joint configuration - approach_offset: Offset for pre-grasp approach [x, y, z] - use_affordance: Whether to use object's affordance data - - Returns: - PlanResult with trajectory and execution status + motion_generator: The motion generator instance to use for planning. + cfg: Configuration for the action. 
""" - # Resolve target pose from ObjectSemantics if needed - if isinstance(target, ObjectSemantics): - target_pose = self._resolve_target_pose(target, use_affordance) - else: - target_pose = target - - # Apply approach offset if specified - if approach_offset is not None: - approach_pose = self._apply_offset(target_pose, approach_offset) - else: - approach_pose = target_pose - - # Get current state if not provided - if start_qpos is None: - start_qpos = self._get_current_qpos() - - # Create plan states - target_states = [ - PlanState(qpos=start_qpos, move_type=MoveType.JOINT_MOVE), - PlanState(xpos=approach_pose, move_type=MoveType.EEF_MOVE), - ] - - # Plan trajectory - options = MotionGenOptions( - control_part=self.control_part, - is_interpolate=True, - is_linear=self.interpolation_type == "linear", - interpolate_position_step=0.002, - plan_opts=ToppraPlanOptions( - sample_interval=kwargs.get("sample_interval", 30), - ), + super().__init__( + motion_generator, cfg=cfg if cfg is not None else MoveActionCfg() ) - result = self.plan_trajectory(target_states, options) - - # Return PlanResult directly - return result + self.n_envs = self.robot.get_qpos().shape[0] + self.arm_joint_ids = self.robot.get_joint_ids(name=self.cfg.control_part) + self.dof = len(self.arm_joint_ids) - def validate( + def _resolve_pose_target( self, - target: Union[torch.Tensor, ObjectSemantics], - start_qpos: Optional[torch.Tensor] = None, - **kwargs, - ) -> bool: - """Check if the reach action is feasible.""" - try: - # Quick IK feasibility check - if isinstance(target, ObjectSemantics): - target_pose = self._resolve_target_pose(target, use_affordance=True) - else: - target_pose = target - - # Attempt IK - qpos_seed = ( - start_qpos if start_qpos is not None else self._get_current_qpos() + target: Union[ObjectSemantics, torch.Tensor], + *, + action_name: str, + ) -> tuple[bool, torch.Tensor]: + """Resolve a pose target into a batched homogeneous transform tensor.""" + if isinstance(target, ObjectSemantics): + logger.log_error( + f"{action_name} currently does not support ObjectSemantics target. 
" + f"Please provide target pose as torch.Tensor of shape (4, 4) or " + f"(n_envs, 4, 4)", + NotImplementedError, ) - success, _ = self.robot.compute_ik( - pose=target_pose.unsqueeze(0), - qpos_seed=qpos_seed.unsqueeze(0), - name=self.control_part, + if not isinstance(target, torch.Tensor): + logger.log_error( + "Target must be either ObjectSemantics or torch.Tensor of shape " + f"(4, 4) or ({self.n_envs}, 4, 4)", + TypeError, ) - return success.all().item() - except Exception: - return False - - def _resolve_target_pose( - self, semantics: ObjectSemantics, use_affordance: bool - ) -> torch.Tensor: - """Resolve target pose from object semantics.""" - from .core import GraspPose - - if use_affordance and isinstance(semantics.affordance, GraspPose): - # Use precomputed grasp pose from affordance data - grasp_pose = semantics.affordance.get_best_grasp() - object_pose = self._get_object_pose(semantics.label) - target_pose = object_pose @ grasp_pose - else: - # Default to object center with approach direction - object_pose = self._get_object_pose(semantics.label) - approach_offset = torch.tensor([0, 0, 0.05], device=self.device) - target_pose = object_pose.clone() - target_pose[:3, 3] += approach_offset - - return target_pose - def _get_object_pose(self, label: str) -> torch.Tensor: - """Get current pose of object by label.""" - # Implementation depends on environment's object management - # This is a placeholder - should be implemented based on environment - raise NotImplementedError( - "_get_object_pose must be implemented by subclass or " - "provided with environment-specific object management" - ) - - -# ============================================================================= -# Grasp Action -# ============================================================================= - - -class GraspAction(AtomicAction): - """Atomic action for grasping objects.""" + if target.shape == (4, 4): + target = target.unsqueeze(0).repeat(self.n_envs, 1, 1) + if target.shape != (self.n_envs, 4, 4): + logger.log_error( + f"Target tensor must have shape (4, 4) or ({self.n_envs}, 4, 4), but got {target.shape}", + ValueError, + ) + return True, target - def __init__( + def _resolve_start_qpos( self, - motion_generator: "MotionGenerator", - robot: "Robot", - control_part: str, - device: torch.device = torch.device("cuda"), - pre_grasp_distance: float = 0.05, - approach_direction: str = "z", # "x", "y", "z", or "custom" - ): - super().__init__(motion_generator, robot, control_part, device) - self.pre_grasp_distance = pre_grasp_distance - self.approach_direction = approach_direction - - def execute( - self, - target: ObjectSemantics, - start_qpos: Optional[torch.Tensor] = None, - use_affordance: bool = True, - grasp_type: str = "default", # "default", "pinch", "power" - **kwargs, - ) -> PlanResult: - """Execute grasp action. 
-
-        Args:
-            target: ObjectSemantics with grasp affordances
-            start_qpos: Starting joint configuration
-            use_affordance: Whether to use precomputed grasp poses
-            grasp_type: Type of grasp to execute
-        """
-        # Resolve grasp pose
-        grasp_pose = self._resolve_grasp_pose(target, use_affordance, grasp_type)
-
-        # Compute pre-grasp pose (approach position)
-        pre_grasp_pose = self._compute_pre_grasp_pose(grasp_pose)
-
-        # Get current state
+        start_qpos: Optional[torch.Tensor],
+        arm_dof: Optional[int] = None,
+    ) -> torch.Tensor:
+        """Resolve planning start joint positions into batched arm joint positions."""
+        arm_dof = self.dof if arm_dof is None else arm_dof
         if start_qpos is None:
-            start_qpos = self._get_current_qpos()
+            start_qpos = self.robot.get_qpos(name=self.cfg.control_part)
+        if start_qpos.shape == (arm_dof,):
+            start_qpos = start_qpos.unsqueeze(0).repeat(self.n_envs, 1)
+        if start_qpos.shape != (self.n_envs, arm_dof):
+            logger.log_error(
+                f"start_qpos must have shape ({self.n_envs}, {arm_dof}), but got {start_qpos.shape}",
+                ValueError,
+            )
+        return start_qpos
 
-        # Build trajectory plan states
-        target_states = [
-            PlanState(qpos=start_qpos, move_type=MoveType.JOINT_MOVE),
-            PlanState(xpos=pre_grasp_pose, move_type=MoveType.EEF_MOVE),
-            PlanState(xpos=grasp_pose, move_type=MoveType.EEF_MOVE),
-        ]
+    def _compute_three_phase_waypoints(
+        self,
+        hand_interp_steps: int,
+        *,
+        first_phase_name: str,
+        third_phase_name: str,
+        first_phase_ratio: float = 0.6,
+    ) -> tuple[int, int, int]:
+        """Split total sample interval into motion, hand interpolation, and motion phases."""
+        first_phase_waypoint = int(
+            np.round((self.cfg.sample_interval - hand_interp_steps) * first_phase_ratio)
+        )
+        if first_phase_waypoint < 2:
+            logger.log_error(
+                f"Not enough waypoints for {first_phase_name} trajectory. "
+                "Please increase sample_interval or decrease hand_interp_steps.",
+                ValueError,
+            )
+        second_phase_waypoint = hand_interp_steps
+        third_phase_waypoint = (
+            self.cfg.sample_interval - first_phase_waypoint - second_phase_waypoint
+        )
+        if third_phase_waypoint < 2:
+            logger.log_error(
+                f"Not enough waypoints for {third_phase_name} trajectory. "
+                "Please increase sample_interval or decrease hand_interp_steps.",
+                ValueError,
+            )
+        return first_phase_waypoint, second_phase_waypoint, third_phase_waypoint
 
-        # Plan trajectory
-        options = MotionGenOptions(
-            control_part=self.control_part,
+    def _build_motion_gen_options(
+        self,
+        start_qpos: torch.Tensor,
+        sample_interval: int,
+    ) -> MotionGenOptions:
+        """Build default motion generation options for an atomic action."""
+        return MotionGenOptions(
+            start_qpos=start_qpos[0],
+            control_part=self.cfg.control_part,
             is_interpolate=True,
             is_linear=False,
             interpolate_position_step=0.001,
             plan_opts=ToppraPlanOptions(
-                sample_interval=kwargs.get("sample_interval", 30),
+                sample_interval=sample_interval,
             ),
         )
 
-        result = self.plan_trajectory(target_states, options)
+    def _plan_arm_trajectory(
+        self,
+        target_states_list: list[list[PlanState]],
+        start_qpos: torch.Tensor,
+        n_waypoints: int,
+        arm_dof: Optional[int] = None,
+    ) -> tuple[bool, torch.Tensor]:
+        """Plan batched arm trajectories for all environments."""
+        arm_dof = self.dof if arm_dof is None else arm_dof
+
+        # Gather the per-env target end-effector poses into one batched tensor.
+        n_state = len(target_states_list[0])
+        xpos_traj = torch.zeros(
+            size=(self.n_envs, n_state, 4, 4),
+            dtype=torch.float32,
+            device=self.device,
+        )
+        for i, target_states in enumerate(target_states_list):
+            for j, target_state in enumerate(target_states):
+                # [env_i, state_j, 4, 4]
+                xpos_traj[i, j] = target_state.xpos
+
+        trajectory = torch.zeros(
+            size=(self.n_envs, n_state, arm_dof),
+            dtype=torch.float32,
+            device=self.device,
+        )
+        qpos_seed = start_qpos
+        for j in range(n_state):
+            is_success, qpos = self.robot.compute_ik(
+                pose=xpos_traj[:, j],
+                name=self.cfg.control_part,
+                joint_seed=qpos_seed,
+            )
+            if not is_success.all():
+                logger.log_warning(
+                    f"Failed to compute IK for target state {j} in some environments. "
+                    "The resulting trajectory may be invalid."
+                )
+                return False, trajectory
+            else:
+                trajectory[:, j] = qpos
+                qpos_seed = qpos
+        trajectory = torch.concatenate([start_qpos.unsqueeze(1), trajectory], dim=1)
+        interp_traj = interpolate_with_distance(
+            trajectory=trajectory,
+            interp_num=n_waypoints,
+            device=self.device,
+        )
+        return True, interp_traj
 
-    def validate(
+    def _interpolate_hand_qpos(
         self,
-        target: ObjectSemantics,
+        start_hand_qpos: torch.Tensor,
+        end_hand_qpos: torch.Tensor,
+        n_waypoints: int,
+    ) -> torch.Tensor:
+        """Interpolate hand joint positions between two gripper states."""
+        weights = torch.linspace(0, 1, steps=n_waypoints, device=self.device)
+        hand_qpos_list = [
+            torch.lerp(start_hand_qpos, end_hand_qpos, weight) for weight in weights
+        ]
+        return torch.stack(hand_qpos_list, dim=0)
 
-    def validate(
+    def execute(
         self,
-        target: ObjectSemantics,
+        target: Union[ObjectSemantics, torch.Tensor],
         start_qpos: Optional[torch.Tensor] = None,
         **kwargs,
-    ) -> bool:
-        """Validate if grasp is feasible."""
-        try:
-            grasp_pose = self._resolve_grasp_pose(
-                target, use_affordance=True, grasp_type="default"
-            )
-            qpos_seed = (
-                start_qpos if start_qpos is not None else self._get_current_qpos()
-            )
-            success, _ = self.robot.compute_ik(
-                pose=grasp_pose.unsqueeze(0),
-                qpos_seed=qpos_seed.unsqueeze(0),
-                name=self.control_part,
-            )
-            return success.all().item()
-        except Exception:
-            return False
+    ) -> tuple[bool, torch.Tensor, list[int]]:
+        """Execute the move action.
 
-    def _resolve_grasp_pose(
-        self, semantics: ObjectSemantics, use_affordance: bool, grasp_type: str
-    ) -> torch.Tensor:
-        """Resolve grasp pose from object semantics."""
-        from .core import GraspPose
+        Args:
+            target (Union[ObjectSemantics, torch.Tensor]): Target end-effector pose of shape
+                (4, 4) or (n_envs, 4, 4). ObjectSemantics targets are not yet supported.
+            start_qpos (Optional[torch.Tensor], optional): Planning start qpos. Defaults to None.
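+
+        Example:
+            A minimal sketch (``pose_4x4`` and ``move_action`` are hypothetical;
+            the pose is a homogeneous end-effector transform)::
+
+                ok, traj, joint_ids = move_action.execute(target=pose_4x4)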
-        if use_affordance and isinstance(semantics.affordance, GraspPose):
-            grasp_pose_affordance = semantics.affordance
-            grasp_pose = grasp_pose_affordance.get_grasp_by_type(grasp_type)
-            if grasp_pose is None:
-                grasp_pose = grasp_pose_affordance.get_best_grasp()
+        Returns:
+            tuple[bool, torch.Tensor, list[int]]:
+                is_success,
+                trajectory of shape (n_envs, n_waypoints, dof),
+                joint_ids corresponding to trajectory
+        """
+        is_success, move_xpos = self._resolve_pose_target(
+            target, action_name=self.__class__.__name__
+        )
+        start_qpos = self._resolve_start_qpos(start_qpos)
 
-            # Transform to world frame
-            object_pose = self._get_object_pose(semantics.label)
-            return object_pose @ grasp_pose
+        # TODO: add a fallback strategy when the target pose cannot be resolved
+        if not is_success:
+            logger.log_warning(
+                "Failed to resolve target pose; aborting move action"
+            )
+            return False, torch.empty(0), self.arm_joint_ids
 
-        # Fallback: compute grasp pose from geometry
-        return self._compute_grasp_from_geometry(semantics)
+        target_states_list = [
+            [
+                PlanState(xpos=move_xpos[i], move_type=MoveType.EEF_MOVE),
+            ]
+            for i in range(self.n_envs)
+        ]
+        is_plan_success, trajectory = self._plan_arm_trajectory(
+            target_states_list, start_qpos, self.cfg.sample_interval
+        )
+        return is_plan_success, trajectory, self.arm_joint_ids
 
-    def _compute_pre_grasp_pose(self, grasp_pose: torch.Tensor) -> torch.Tensor:
-        """Compute pre-grasp pose with offset."""
-        offset = torch.zeros(3, device=self.device)
-        if self.approach_direction == "z":
-            offset[2] = -self.pre_grasp_distance
-        elif self.approach_direction == "x":
-            offset[0] = -self.pre_grasp_distance
-        elif self.approach_direction == "y":
-            offset[1] = -self.pre_grasp_distance
+    def validate(self, target, start_qpos=None, **kwargs):
+        # TODO: implement proper validation logic for move action
+        return True
 
-        pre_grasp = grasp_pose.clone()
-        pre_grasp[:3, 3] += grasp_pose[:3, :3] @ offset
-        return pre_grasp
 
-    def _get_object_pose(self, label: str) -> torch.Tensor:
-        """Get current pose of object by label."""
-        raise NotImplementedError(
-            "_get_object_pose must be implemented by subclass or "
-            "provided with environment-specific object management"
-        )
+@configclass
+class PickUpActionCfg(MoveActionCfg):
+    hand_open_qpos: torch.Tensor | None = None
+    """[hand_dof,] of float. Joint positions for open hand state. Must be specified for PickUpAction."""
 
-    def _compute_grasp_from_geometry(self, semantics: ObjectSemantics) -> torch.Tensor:
-        """Compute grasp pose from object geometry."""
-        # Get object pose
-        object_pose = self._get_object_pose(semantics.label)
+    hand_close_qpos: torch.Tensor | None = None
+    """[hand_dof,] of float. Joint positions for closed hand state. Must be specified for PickUpAction."""
 
-        # Get bounding box from geometry
-        bbox = semantics.geometry.get("bounding_box", [0.1, 0.1, 0.1])
+    hand_control_part: str = "hand"
+    """Name of the robot part that controls the hand joints. Must correspond to a valid control part in the robot definition."""
 
-        # Default top-down grasp
-        grasp_offset = torch.eye(4, device=self.device)
-        grasp_offset[2, 3] = bbox[2] / 2 + 0.02  # Slightly above object
+    pre_grasp_distance: float = 0.15
+    """Distance to offset back from the grasp pose along the approach direction to get the pre-grasp pose. Should be large enough to avoid collision during approach, but not too large to cause unnecessary detour."""
 
-        return object_pose @ grasp_offset
+    approach_direction: torch.Tensor = torch.tensor([0, 0, -1], dtype=torch.float32)
+    """Direction from which the gripper approaches the object for grasping, expressed in the object local frame. Should be a unit vector. Default is [0, 0, -1], which means approaching from above along the negative z-axis."""
 
+    lift_height: float = 0.1
+    """Height to lift the object after grasping, expressed in meters. Should be large enough to avoid collision with the environment, but not too large to cause unnecessary motion."""
 
-# =============================================================================
-# Move Action
-# =============================================================================
+    sample_interval: int = 80
+    """Number of waypoints to sample for the entire pick up motion trajectory, including approach, hand closing, and lifting. Should be large enough to ensure smooth motion, but not too large to cause unnecessary computation overhead."""
 
+    hand_interp_steps: int = 5
+    """Number of waypoints to interpolate for the hand closing motion. Should be at least 2 to ensure smooth interpolation between open and closed hand states, but not too large to cause unnecessary computation overhead."""
 
-class MoveAction(AtomicAction):
-    """Atomic action for moving to a target position."""
 
+class PickUpAction(MoveAction):
     def __init__(
         self,
-        motion_generator: "MotionGenerator",
-        robot: "Robot",
-        control_part: str,
-        device: torch.device = torch.device("cuda"),
-        move_type: str = "cartesian",  # "cartesian", "joint"
-        interpolation: str = "linear",  # "linear", "cubic", "toppra"
+        motion_generator: MotionGenerator,
+        cfg: PickUpActionCfg | None = None,
     ):
-        super().__init__(motion_generator, robot, control_part, device)
-        self.move_type = move_type
-        self.interpolation = interpolation
+        """
+        Initialize the atomic action.
+
+        Args:
+            motion_generator: The motion generator instance to use for planning.
+            cfg: Configuration for the action.
+        """
+        super().__init__(
+            motion_generator, cfg=cfg if cfg is not None else PickUpActionCfg()
+        )
+        self.approach_direction = self.cfg.approach_direction.to(self.device)
+        if self.cfg.hand_open_qpos is None:
+            logger.log_error("hand_open_qpos must be specified in PickUpActionCfg")
+        if self.cfg.hand_close_qpos is None:
+            logger.log_error("hand_close_qpos must be specified in PickUpActionCfg")
+        self.hand_open_qpos = self.cfg.hand_open_qpos.to(self.device)
+        self.hand_close_qpos = self.cfg.hand_close_qpos.to(self.device)
+
+        self.hand_joint_ids = self.robot.get_joint_ids(name=self.cfg.hand_control_part)
+        self.joint_ids = self.arm_joint_ids + self.hand_joint_ids
+        self.arm_dof = len(self.arm_joint_ids)
+        self.dof = len(self.joint_ids)
 
     def execute(
         self,
-        target: Union[torch.Tensor, ObjectSemantics],
+        target: Union[ObjectSemantics, torch.Tensor],
         start_qpos: Optional[torch.Tensor] = None,
-        offset: Optional[torch.Tensor] = None,
-        velocity_limit: Optional[float] = None,
-        acceleration_limit: Optional[float] = None,
-    ) -> PlanResult:
-        """Execute move action.
+        **kwargs,
+    ) -> tuple[bool, torch.Tensor, list[int]]:
+        """Execute the pick up action.
 
         Args:
-            target: Target pose [4, 4] or ObjectSemantics
-            start_qpos: Starting joint configuration
-            offset: Optional offset from target
-            velocity_limit: Max velocity for trajectory
-            acceleration_limit: Max acceleration for trajectory
+            target (Union[ObjectSemantics, torch.Tensor]): target object semantics or target pose for grasping
+            start_qpos (Optional[torch.Tensor], optional): Planning start qpos. Defaults to None.
+
+        Returns:
+            tuple[bool, torch.Tensor, list[int]]:
+                is_success,
+                trajectory of shape (n_envs, n_waypoints, dof),
+                joint_ids corresponding to trajectory
         """
-        # Resolve target
+
+        # Resolve grasp pose
         if isinstance(target, ObjectSemantics):
-            target_pose = self._get_object_pose(target.label)
+            is_success, grasp_xpos, open_length = self._resolve_grasp_pose(target)
         else:
-            target_pose = target
-
-        # Apply offset if specified
-        if offset is not None:
-            target_pose = self._apply_offset(target_pose, offset)
+            is_success, grasp_xpos = self._resolve_pose_target(
+                target, action_name=self.__class__.__name__
+            )
 
-        # Get start state
-        if start_qpos is None:
-            start_qpos = self._get_current_qpos()
+        # TODO: add a fallback strategy when no valid grasp pose is found
+        if not is_success:
+            logger.log_warning(
+                "Failed to resolve grasp pose; aborting pick up action"
+            )
+            return False, torch.empty(0), self.joint_ids
+
+        # Compute pre-grasp pose
+        # TODO: only for parallel gripper, approach in negative grasp z direction
+        grasp_z = grasp_xpos[:, :3, 2]
+        pre_grasp_xpos = self._apply_offset(
+            pose=grasp_xpos,
+            offset=-grasp_z * self.cfg.pre_grasp_distance,
+        )
+        start_qpos = self._resolve_start_qpos(start_qpos, self.arm_dof)
+
+        # compute waypoint number for each phase
+        n_approach_waypoint, n_close_waypoint, n_lift_waypoint = (
+            self._compute_three_phase_waypoints(
+                self.cfg.hand_interp_steps,
+                first_phase_name="approach",
+                third_phase_name="lift",
+            )
+        )
 
-        # Create plan states based on move type
-        if self.move_type == "cartesian":
-            target_states = [
-                PlanState(qpos=start_qpos, move_type=MoveType.JOINT_MOVE),
-                PlanState(xpos=target_pose, move_type=MoveType.EEF_MOVE),
+        # Plan the approach (pre-grasp -> grasp) trajectory
+        target_states_list = [
+            [
+                PlanState(xpos=pre_grasp_xpos[i], move_type=MoveType.EEF_MOVE),
+                PlanState(xpos=grasp_xpos[i], move_type=MoveType.EEF_MOVE),
             ]
-            is_linear = self.interpolation == "linear"
-        else:  # joint space
-            target_qpos = self._ik_solve(target_pose, start_qpos)
-            target_states = [
-                PlanState(qpos=start_qpos, move_type=MoveType.JOINT_MOVE),
-                PlanState(qpos=target_qpos, move_type=MoveType.JOINT_MOVE),
+            for i in range(self.n_envs)
+        ]
+        pick_trajectory = torch.zeros(
+            size=(self.n_envs, n_approach_waypoint, self.dof),
+            dtype=torch.float32,
+            device=self.device,
+        )
+        is_success, plan_traj = self._plan_arm_trajectory(
+            target_states_list,
+            start_qpos,
+            n_approach_waypoint,
+            self.arm_dof,
+        )
+        if not is_success:
+            logger.log_warning("Failed to plan approach trajectory.")
+            return False, pick_trajectory, self.joint_ids
+        pick_trajectory[:, :, : self.arm_dof] = plan_traj
+        # Pad the hand open qpos onto the approach trajectory
+        pick_trajectory[:, :, self.arm_dof :] = self.hand_open_qpos
+
+        # Build the hand closing trajectory
+        grasp_qpos = pick_trajectory[
+            :, -1, : self.arm_dof
+        ]  # The last waypoint of the approach trajectory is the grasp configuration
+        hand_close_path = self._interpolate_hand_qpos(
+            self.hand_open_qpos,
+            self.hand_close_qpos,
+            n_close_waypoint,
+        )
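+        # During the hand closing phase the arm is held at the grasp
+        # configuration while only the hand joints sweep from open to closed;
+        # each combined waypoint is laid out as [arm_qpos | hand_qpos].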
+        hand_close_trajectory = torch.zeros(
+            size=(self.n_envs, n_close_waypoint, self.dof),
+            dtype=torch.float32,
+            device=self.device,
+        )
+        hand_close_trajectory[:, :, : self.arm_dof] = grasp_qpos
+        hand_close_trajectory[:, :, self.arm_dof :] = hand_close_path
+
+        # Plan the lift trajectory
+        lift_trajectory = torch.zeros(
+            size=(self.n_envs, n_lift_waypoint, self.dof),
+            dtype=torch.float32,
+            device=self.device,
+        )
+        lift_xpos = self._apply_offset(
+            pose=grasp_xpos,
+            offset=torch.tensor([0, 0, 1], device=self.device) * self.cfg.lift_height,
+        )
+        target_states_list = [
+            [
+                PlanState(xpos=lift_xpos[i], move_type=MoveType.EEF_MOVE),
             ]
-            is_linear = False
-
-        # Configure motion generation
-        options = MotionGenOptions(
-            control_part=self.control_part,
-            is_interpolate=True,
-            is_linear=is_linear,
-            interpolate_position_step=0.002,
-            plan_opts=ToppraPlanOptions(
-                sample_interval=kwargs.get("sample_interval", 30),
-                velocity_limit=velocity_limit,
-                acceleration_limit=acceleration_limit,
-            ),
+            for i in range(self.n_envs)
+        ]
+        is_success, plan_traj = self._plan_arm_trajectory(
+            target_states_list,
+            grasp_qpos,
+            n_lift_waypoint,
+            self.arm_dof,
+        )
+        if not is_success:
+            logger.log_warning("Failed to plan lift trajectory.")
+            return False, lift_trajectory, self.joint_ids
+        lift_trajectory[:, :, : self.arm_dof] = plan_traj
+        # Pad the hand close qpos onto the lift trajectory
+        lift_trajectory[:, :, self.arm_dof :] = self.hand_close_qpos
+
+        # concatenate trajectories
+        trajectory = torch.cat(
+            [pick_trajectory, hand_close_trajectory, lift_trajectory], dim=1
         )
+        return True, trajectory, self.joint_ids
 
-        result = self.plan_trajectory(target_states, options)
+    def _resolve_grasp_pose(
+        self, semantics: ObjectSemantics
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        if not isinstance(semantics.affordance, AntipodalAffordance):
+            logger.log_error(
+                "Grasp pose affordance must be of type AntipodalAffordance"
+            )
+        if semantics.entity is None:
+            logger.log_error(
+                "ObjectSemantics must be associated with an entity to get object pose"
+            )
+        obj_poses = semantics.entity.get_local_pose(to_matrix=True)
 
-        # Return PlanResult directly - it contains all trajectory data
-        return result
+        is_success, grasp_xpos, open_length = semantics.affordance.get_best_grasp_poses(
+            obj_poses=obj_poses, approach_direction=self.approach_direction
+        )
+        return is_success, grasp_xpos, open_length
 
-    def validate(
-        self,
-        target: Union[torch.Tensor, ObjectSemantics],
-        start_qpos: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> bool:
-        """Validate if move action is feasible."""
-        try:
-            if isinstance(target, ObjectSemantics):
-                target_pose = self._get_object_pose(target.label)
-            else:
-                target_pose = target
+    def validate(self, target, start_qpos=None, **kwargs):
+        # TODO: implement proper validation logic for pick up action
+        return True
 
-            qpos_seed = (
-                start_qpos if start_qpos is not None else self._get_current_qpos()
-            )
-            if self.move_type == "joint":
-                # For joint space moves, we need IK solvability
-                self._ik_solve(target_pose, qpos_seed)
-            else:
-                # For cartesian moves, just check IK
-                success, _ = self.robot.compute_ik(
-                    pose=target_pose.unsqueeze(0),
-                    qpos_seed=qpos_seed.unsqueeze(0),
-                    name=self.control_part,
-                )
-                if not success.all():
-                    return False
 
-            return True
-        except Exception:
-            return False
+@configclass
+class PlaceActionCfg(MoveActionCfg):
+    hand_open_qpos: torch.Tensor | None = None
+    """[hand_dof,] of float. Joint positions for open hand state. Must be specified for PlaceAction."""
 
-    def _get_object_pose(self, label: str) -> torch.Tensor:
-        """Get current pose of object by label."""
-        raise NotImplementedError(
-            "_get_object_pose must be implemented by subclass or "
-            "provided with environment-specific object management"
-        )
+    hand_close_qpos: torch.Tensor | None = None
+    """[hand_dof,] of float. Joint positions for closed hand state. Must be specified for PlaceAction."""
 
+    hand_control_part: str = "hand"
+    """Name of the robot part that controls the hand joints. Must correspond to a valid control part in the robot definition."""
 
-# =============================================================================
-# Release Action
-# =============================================================================
+    lift_height: float = 0.1
+    """Height above the place pose from which the gripper descends and to which it retreats after releasing, expressed in meters. Should be large enough to avoid collision with the environment, but not too large to cause unnecessary motion."""
 
+    sample_interval: int = 80
+    """Number of waypoints to sample for the entire place motion trajectory, including lowering, hand opening, and retreating. Should be large enough to ensure smooth motion, but not too large to cause unnecessary computation overhead."""
 
-class ReleaseAction(AtomicAction):
-    """Atomic action for releasing an object."""
+    hand_interp_steps: int = 5
+    """Number of waypoints to interpolate for the hand opening motion. Should be at least 2 to ensure smooth interpolation between closed and open hand states, but not too large to cause unnecessary computation overhead."""
+
+
+class PlaceAction(MoveAction):
+    def __init__(
+        self,
+        motion_generator: MotionGenerator,
+        cfg: PlaceActionCfg | None = None,
+    ):
+        """
+        Initialize the atomic action.
+
+        Args:
+            motion_generator: The motion generator instance to use for planning.
+            cfg: Configuration for the action.
+        """
+        super().__init__(
+            motion_generator, cfg=cfg if cfg is not None else PlaceActionCfg()
+        )
+        if self.cfg.hand_open_qpos is None:
+            logger.log_error("hand_open_qpos must be specified in PlaceActionCfg")
+        if self.cfg.hand_close_qpos is None:
+            logger.log_error("hand_close_qpos must be specified in PlaceActionCfg")
+        self.hand_open_qpos = self.cfg.hand_open_qpos.to(self.device)
+        self.hand_close_qpos = self.cfg.hand_close_qpos.to(self.device)
+
+        self.hand_joint_ids = self.robot.get_joint_ids(name=self.cfg.hand_control_part)
+        self.joint_ids = self.arm_joint_ids + self.hand_joint_ids
+        self.arm_dof = len(self.arm_joint_ids)
+        self.dof = len(self.joint_ids)
 
     def execute(
         self,
-        target: Optional[Union[torch.Tensor, ObjectSemantics]] = None,
+        target: Union[ObjectSemantics, torch.Tensor],
         start_qpos: Optional[torch.Tensor] = None,
         **kwargs,
-    ) -> PlanResult:
-        """Execute release action.
+    ) -> tuple[bool, torch.Tensor, list[int]]:
+        """Execute the place action.
 
         Args:
-            target: Optional target pose after release (for place operations)
-            start_qpos: Starting joint configuration
+            target (Union[ObjectSemantics, torch.Tensor]): Target placement pose of shape
+                (4, 4) or (n_envs, 4, 4). ObjectSemantics targets are not yet supported.
+            start_qpos (Optional[torch.Tensor], optional): Planning start qpos. Defaults to None.
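+
+        Example:
+            A minimal sketch (``place_pose`` and ``place_action`` are
+            hypothetical; the pose is a homogeneous target transform)::
+
+                ok, traj, joint_ids = place_action.execute(target=place_pose)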
 
         Returns:
-            PlanResult with trajectory (may be empty for simple release)
+            tuple[bool, torch.Tensor, list[int]]:
+                is_success,
+                trajectory of shape (n_envs, n_waypoints, dof),
+                joint_ids corresponding to trajectory
         """
-        # Get current state
-        if start_qpos is None:
-            start_qpos = self._get_current_qpos()
-
-        # If target is specified, move to that pose first
-        if target is not None:
-            if isinstance(target, ObjectSemantics):
-                # Move above the object
-                target_pose = self._get_object_pose(target.label)
-                approach_offset = torch.tensor([0, 0, 0.1], device=self.device)
-                target_pose = self._apply_offset(target_pose, approach_offset)
-            else:
-                target_pose = target
-
-            target_states = [
-                PlanState(qpos=start_qpos, move_type=MoveType.JOINT_MOVE),
-                PlanState(xpos=target_pose, move_type=MoveType.EEF_MOVE),
-            ]
+        is_success, place_xpos = self._resolve_pose_target(
+            target, action_name=self.__class__.__name__
+        )
+        start_qpos = self._resolve_start_qpos(start_qpos, self.arm_dof)
 
-            options = MotionGenOptions(
-                control_part=self.control_part,
-                is_interpolate=True,
-                is_linear=False,
-                interpolate_position_step=0.002,
-                plan_opts=ToppraPlanOptions(
-                    sample_interval=kwargs.get("sample_interval", 30),
-                ),
+        # TODO: add a fallback strategy when the place pose cannot be resolved
+        if not is_success:
+            logger.log_warning(
+                "Failed to resolve place pose; aborting place action"
             )
-
-            result = self.plan_trajectory(target_states, options)
-        else:
-            # Simple release - return success with current state
-            result = PlanResult(
-                success=True,
-                positions=start_qpos.unsqueeze(0),
+            return False, torch.empty(0), self.joint_ids
+
+        # compute waypoint number for each phase
+        n_down_waypoint, n_open_waypoint, n_lift_waypoint = (
+            self._compute_three_phase_waypoints(
+                self.cfg.hand_interp_steps,
+                first_phase_name="lower",
+                third_phase_name="retreat",
             )
+        )
 
-        # Open gripper (if applicable)
-        # This would be robot-specific and should be implemented by subclasses
-        self._open_gripper()
-
-        return result
+        down_trajectory = torch.zeros(
+            size=(self.n_envs, n_down_waypoint, self.dof),
+            dtype=torch.float32,
+            device=self.device,
+        )
+        lift_xpos = self._apply_offset(
+            pose=place_xpos,
+            offset=torch.tensor([0, 0, 1], device=self.device) * self.cfg.lift_height,
+        )
+        target_states_list = [
+            [
+                PlanState(xpos=lift_xpos[i], move_type=MoveType.EEF_MOVE),
+                PlanState(xpos=place_xpos[i], move_type=MoveType.EEF_MOVE),
+            ]
+            for i in range(self.n_envs)
+        ]
+        is_success, plan_traj = self._plan_arm_trajectory(
+            target_states_list,
+            start_qpos,
+            n_down_waypoint,
+            self.arm_dof,
+        )
+        if not is_success:
+            logger.log_warning("Failed to plan down trajectory.")
+            return False, down_trajectory, self.joint_ids
+        down_trajectory[:, :, : self.arm_dof] = plan_traj
+        # Pad the hand close qpos onto the lowering trajectory (the object is still held)
+        down_trajectory[:, :, self.arm_dof :] = self.hand_close_qpos
+
+        # Build the hand opening trajectory
+        reach_qpos = down_trajectory[
+            :, -1, : self.arm_dof
+        ]  # The last waypoint of the lowering trajectory is the place configuration
+        hand_open_path = self._interpolate_hand_qpos(
+            self.hand_close_qpos,
+            self.hand_open_qpos,
+            n_open_waypoint,
+        )
+        hand_open_trajectory = torch.zeros(
+            size=(self.n_envs, n_open_waypoint, self.dof),
+            dtype=torch.float32,
+            device=self.device,
+        )
+        hand_open_trajectory[:, :, : self.arm_dof] = reach_qpos
+        hand_open_trajectory[:, :, self.arm_dof :] = hand_open_path
+
+        # Plan the retreat trajectory
+        back_trajectory = torch.zeros(
+            size=(self.n_envs, n_lift_waypoint, self.dof),
+            dtype=torch.float32,
+            device=self.device,
+        )
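+        # Retreat phase: plan from the place configuration straight back up to
+        # the lift pose while the hand stays open, releasing the object.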
+        target_states_list = [
+            [
+                PlanState(xpos=lift_xpos[i], move_type=MoveType.EEF_MOVE),
+            ]
+            for i in range(self.n_envs)
+        ]
+        is_success, plan_traj = self._plan_arm_trajectory(
+            target_states_list,
+            reach_qpos,
+            n_lift_waypoint,
+            self.arm_dof,
+        )
+        if not is_success:
+            logger.log_warning("Failed to plan back trajectory.")
+            return False, back_trajectory, self.joint_ids
+        back_trajectory[:, :, : self.arm_dof] = plan_traj
+        # Pad the hand open qpos onto the retreat trajectory
+        back_trajectory[:, :, self.arm_dof :] = self.hand_open_qpos
+
+        # concatenate trajectories
+        trajectory = torch.cat(
+            [down_trajectory, hand_open_trajectory, back_trajectory], dim=1
+        )
+        return True, trajectory, self.joint_ids
 
-    def validate(
-        self,
-        target: Optional[Union[torch.Tensor, ObjectSemantics]] = None,
-        start_qpos: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> bool:
-        """Validate if release action is feasible."""
-        # Release is generally always feasible
-        # If target is specified, validate that we can reach it
-        if target is not None and isinstance(target, torch.Tensor):
-            try:
-                qpos_seed = (
-                    start_qpos if start_qpos is not None else self._get_current_qpos()
-                )
-                success, _ = self.robot.compute_ik(
-                    pose=target.unsqueeze(0),
-                    qpos_seed=qpos_seed.unsqueeze(0),
-                    name=self.control_part,
-                )
-                return success.all().item()
-            except Exception:
-                return False
+    def validate(self, target, start_qpos=None, **kwargs):
+        # TODO: implement proper validation logic for place action
         return True
-
-    def _open_gripper(self) -> None:
-        """Open the gripper to release the object.
-
-        This is a placeholder method that should be implemented by subclasses
-        based on the specific robot hardware or simulation environment.
-        """
-        # Placeholder - should be implemented by subclass
-        pass
-
-    def _get_object_pose(self, label: str) -> torch.Tensor:
-        """Get current pose of object by label."""
-        raise NotImplementedError(
-            "_get_object_pose must be implemented by subclass or "
-            "provided with environment-specific object management"
-        )
diff --git a/embodichain/lab/sim/atomic_actions/core.py b/embodichain/lab/sim/atomic_actions/core.py
index 42337fa3..4ac51466 100644
--- a/embodichain/lab/sim/atomic_actions/core.py
+++ b/embodichain/lab/sim/atomic_actions/core.py
@@ -24,8 +24,18 @@
 from embodichain.lab.sim.planners import PlanResult, PlanState, MoveType
 from embodichain.utils import configclass
+from embodichain.toolkits.graspkit.pg_grasp import (
+    GraspGenerator,
+    GraspGeneratorCfg,
+)
+from embodichain.toolkits.graspkit.pg_grasp.gripper_collision_checker import (
+    GripperCollisionCfg,
+)
+from embodichain.lab.sim.common import BatchEntity
+from embodichain.utils import logger
+
 if TYPE_CHECKING:
-    from embodichain.lab.sim.planners import MotionGenerator
+    from embodichain.lab.sim.planners import MotionGenerator, MotionGenOptions
     from embodichain.lab.sim.objects import Robot
@@ -45,69 +55,128 @@ class Affordance:
     object_label: str = ""
     """Label of the object this affordance belongs to."""
 
-    def get_batch_size(self) -> int:
-        """Return the batch size of this affordance data."""
-        return 1
-
+    geometry: Dict[str, Any] = field(default_factory=dict)
+    """Geometry dictionary shared with ObjectSemantics.
 
-@dataclass
-class GraspPose(Affordance):
-    """Grasp pose affordance containing a batch of 4x4 transformation matrices.
-
-    Each grasp pose represents a valid end-effector pose for grasping the object.
-    Multiple poses may be available for different grasp types (pinch, power, etc.)
-    or approach directions.
+ The mesh payload is expected to be stored in: + - ``mesh_vertices``: torch.Tensor with shape [N, 3] + - ``mesh_triangles``: torch.Tensor with shape [M, 3] """ - poses: torch.Tensor = field(default_factory=lambda: torch.eye(4).unsqueeze(0)) - """Batch of grasp poses with shape [B, 4, 4]. + custom_config: Dict[str, Any] = field(default_factory=dict) + """User-defined configuration payload for affordance creation and usage.""" - Each pose is a 4x4 homogeneous transformation matrix representing - the end-effector pose in the object's local coordinate frame. - """ + @property + def mesh_vertices(self) -> torch.Tensor | None: + """Get mesh vertices from geometry. - grasp_types: List[str] = field(default_factory=lambda: ["default"]) - """List of grasp type labels for each pose in the batch. + Returns: + Mesh vertices tensor [N, 3], or None if unavailable. - Examples: "pinch", "power", "hook", "spherical", etc. - Length must match the batch dimension of `poses`. - """ + Raises: + TypeError: If ``mesh_vertices`` exists but is not a torch tensor. + """ + vertices = self.geometry.get("mesh_vertices") + if vertices is None: + return None + if not isinstance(vertices, torch.Tensor): + raise TypeError("geometry['mesh_vertices'] must be a torch.Tensor") + return vertices - confidence_scores: torch.Tensor | None = None - """Optional confidence scores for each grasp pose with shape [B]. + @property + def mesh_triangles(self) -> torch.Tensor | None: + """Get mesh triangles from geometry. - Higher values indicate more reliable/ stable grasps. - Used for grasp selection when multiple options exist. - """ + Returns: + Mesh triangle index tensor [M, 3], or None if unavailable. - def get_batch_size(self) -> int: - """Return the number of grasp poses in this affordance.""" - return self.poses.shape[0] + Raises: + TypeError: If ``mesh_triangles`` exists but is not a torch tensor. + """ + triangles = self.geometry.get("mesh_triangles") + if triangles is None: + return None + if not isinstance(triangles, torch.Tensor): + raise TypeError("geometry['mesh_triangles'] must be a torch.Tensor") + return triangles - def get_grasp_by_type(self, grasp_type: str) -> Optional[torch.Tensor]: - """Get grasp pose by type label. + def set_custom_config(self, key: str, value: Any) -> None: + """Set a custom affordance configuration value.""" + self.custom_config[key] = value - Args: - grasp_type: Type of grasp (e.g., "pinch", "power") + def get_custom_config(self, key: str, default: Any = None) -> Any: + """Get a custom affordance configuration value.""" + return self.custom_config.get(key, default) - Returns: - 4x4 pose tensor if found, None otherwise - """ - if grasp_type in self.grasp_types: - idx = self.grasp_types.index(grasp_type) - return self.poses[idx] - return None + def get_batch_size(self) -> int: + """Return the batch size of this affordance data.""" + return 1 - def get_best_grasp(self) -> torch.Tensor: - """Get the best grasp pose based on confidence scores. 
- Returns: - 4x4 pose tensor with highest confidence - """ - if self.confidence_scores is not None: - best_idx = torch.argmax(self.confidence_scores) - return self.poses[best_idx] - return self.poses[0] # Default to first if no scores available +@dataclass +class AntipodalAffordance(Affordance): + generator: GraspGenerator | None = None + """Grasp generator instance, initialized lazily when needed.""" + + force_reannotate: bool = False + """Whether to force re-annotation of grasp generator on each access.""" + + def _init_generator(self): + if ( + self.geometry.get("mesh_vertices", None) is None + or self.geometry.get("mesh_triangles", None) is None + ): + logger.log_error( + "Mesh vertices and triangles must be provided in geometry to initialize AntipodalAffordance." + ) + self.generator = GraspGenerator( + vertices=self.geometry.get("mesh_vertices"), + triangles=self.geometry.get("mesh_triangles"), + cfg=self.custom_config.get("generator_cfg", None), + gripper_collision_cfg=self.custom_config.get("gripper_collision_cfg", None), + ) + if self.force_reannotate: + self.generator.annotate() + else: + if self.generator._hit_point_pairs is None: + self.generator.annotate() + + def get_best_grasp_poses( + self, + obj_poses: torch.Tensor, + approach_direction: torch.Tensor = torch.tensor( + [0, 0, -1], dtype=torch.float32 + ), + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if self.generator is None: + self._init_generator() + + grasp_xpos_list = [] + is_success_list = [] + open_length_list = [] + for i, obj_pose in enumerate(obj_poses): + is_success, grasp_xpos, open_length = self.generator.get_grasp_poses( + obj_pose, approach_direction + ) + if is_success: + grasp_xpos_list.append(grasp_xpos.unsqueeze(0)) + else: + logger.log_warning(f"No valid grasp pose found for {i}-th object.") + grasp_xpos_list.append( + torch.eye( + 4, dtype=torch.float32, device=self.generator.device + ).unsqueeze(0) + ) # Default to identity pose if no grasp found + is_success_list.append(is_success) + open_length_list.append(open_length) + is_success = torch.tensor( + is_success_list, dtype=torch.bool, device=self.generator.device + ) + grasp_xpos = torch.concatenate(grasp_xpos_list, dim=0) # [B, 4, 4] + open_length = torch.tensor( + open_length_list, dtype=torch.float32, device=self.generator.device + ) + return is_success, grasp_xpos, open_length @dataclass @@ -199,8 +268,17 @@ class ObjectSemantics: label: str = "none" """Object category label (e.g., 'apple', 'bottle').""" - uid: Optional[str] = None - """Optional unique identifier for the object instance.""" + entity: BatchEntity | None = None + """Optional reference to the underlying simulation entity representing this object.""" + + def __post_init__(self) -> None: + """Bind affordance metadata to this semantic object. + + The affordance shares the same geometry dict instance as + ``ObjectSemantics.geometry`` so mesh tensors are authored in one place. + """ + self.affordance.object_label = self.label + self.affordance.geometry = self.geometry # ============================================================================= @@ -212,7 +290,7 @@ class ObjectSemantics: class ActionCfg: """Configuration for atomic actions.""" - control_part: str = "left_arm" + control_part: str = "arm" """Control part name for the action.""" interpolation_type: str = "linear" @@ -236,9 +314,18 @@ class AtomicAction(ABC): def __init__( self, motion_generator: MotionGenerator, + cfg: ActionCfg = ActionCfg(), ): + """ + Initialize the atomic action. 
+        Args:
+            motion_generator: The motion generator instance to use for planning.
+            cfg: Configuration for the action.
+        """
         self.motion_generator = motion_generator
+        self.cfg = cfg
         self.robot = motion_generator.robot
+        self.control_part = cfg.control_part
         self.device = self.robot.device
@@ -247,21 +334,19 @@ def execute(
         target: Union[torch.Tensor, ObjectSemantics],
         start_qpos: Optional[torch.Tensor] = None,
         **kwargs,
-    ) -> PlanResult:
-        """Execute the atomic action.
+    ) -> tuple[bool, torch.Tensor, list[int]]:
+        """Execute the atomic action and return the planned trajectory.
 
         Args:
-            target: Target pose [4, 4] or ObjectSemantics
-            start_qpos: Starting joint configuration [DOF]
-            **kwargs: Additional action-specific parameters
+            target (Union[torch.Tensor, ObjectSemantics]): Target pose or object
+                semantics for the action.
+            start_qpos (Optional[torch.Tensor], optional): Planning start qpos. Defaults to None.
 
         Returns:
-            PlanResult with trajectory (positions, velocities, accelerations),
-            end-effector poses (xpos_list), and success status.
-            Use result.positions for joint trajectory [T, DOF].
-            Use result.xpos_list for EE poses [T, 4, 4].
+            tuple[bool, torch.Tensor, list[int]]:
+                is_success,
+                trajectory of shape (n_envs, n_waypoints, dof),
+                joint_ids corresponding to trajectory
         """
-        pass
 
     @abstractmethod
     def validate(
@@ -333,14 +418,20 @@ def _apply_offset(self, pose: torch.Tensor, offset: torch.Tensor) -> torch.Tenso
-        """Apply offset to pose in local frame.
+        """Apply a translation offset to a batch of poses in the world frame.
 
         Args:
-            pose: Base pose [4, 4]
-            offset: Offset in local frame [3]
+            pose: Base pose [N, 4, 4]
+            offset: Translation offset in the world frame, [N, 3] or [3]
 
         Returns:
-            Pose with offset applied [4, 4]
+            Pose with offset applied [N, 4, 4]
         """
+        if not len(pose.shape) == 3 or pose.shape[1:] != (4, 4):
+            logger.log_error("pose must have shape [N, 4, 4]")
+        if len(offset.shape) == 1:
+            offset = offset.unsqueeze(0)
+        if not len(offset.shape) == 2 or offset.shape[1] != 3:
+            logger.log_error("offset must have shape [N, 3] or [3]")
         result = pose.clone()
-        result[:3, 3] += pose[:3, :3] @ offset
+        result[:, :3, 3] += offset
         return result
 
     def plan_trajectory(
diff --git a/embodichain/lab/sim/atomic_actions/engine.py b/embodichain/lab/sim/atomic_actions/engine.py
index b152adfa..e7f6d2b2 100644
--- a/embodichain/lab/sim/atomic_actions/engine.py
+++ b/embodichain/lab/sim/atomic_actions/engine.py
@@ -17,9 +17,10 @@
 from __future__ import annotations
 
 import torch
-from typing import Dict, List, Optional, Type, Union, TYPE_CHECKING
+from typing import Any, Dict, List, Optional, Type, Union, TYPE_CHECKING
 
 from embodichain.lab.sim.planners import PlanResult
+from embodichain.utils import logger
 from .core import AtomicAction, ObjectSemantics, ActionCfg
 
 if TYPE_CHECKING:
@@ -94,7 +95,13 @@ class SemanticAnalyzer:
     def __init__(self):
         self._object_cache: Dict[str, ObjectSemantics] = {}
 
-    def analyze(self, label: str) -> ObjectSemantics:
+    def analyze(
+        self,
+        label: str,
+        geometry: Optional[Dict[str, Any]] = None,
+        custom_config: Optional[Dict[str, Any]] = None,
+        use_cache: bool = True,
+    ) -> ObjectSemantics:
         """Analyze object by label and return ObjectSemantics.
 
         This is a placeholder implementation that should be extended
 
         Args:
             label: Object category label (e.g., "apple", "bottle")
+            geometry: Optional geometry payload. Can include mesh tensors:
+                ``mesh_vertices`` [N, 3] and ``mesh_triangles`` [M, 3].
+            custom_config: Optional user-defined affordance configuration.
+            use_cache: Whether to use cached semantics when available.
 
         Returns:
             ObjectSemantics containing affordance data
         """
-        # Check cache first
-        if label in self._object_cache:
+        # Only use cache for default analyze path
+        if (
+            use_cache
+            and geometry is None
+            and custom_config is None
+            and label in self._object_cache
+        ):
             return self._object_cache[label]
 
         # Create default semantics (placeholder implementation)
-        from .core import GraspPose, InteractionPoints
+        from .core import AntipodalAffordance
 
-        # Generate default grasp poses based on object type
-        default_poses = torch.eye(4).unsqueeze(0)
-        default_poses[0, 2, 3] = 0.1  # Default offset
+        default_geometry: Dict[str, Any] = {"bounding_box": [0.1, 0.1, 0.1]}
+        if geometry is not None:
+            default_geometry.update(geometry)
 
-        grasp_affordance = GraspPose(
+        # AntipodalAffordance carries no precomputed poses or grasp_types;
+        # grasps are generated lazily from the mesh geometry on first request.
+        grasp_affordance = AntipodalAffordance(
             object_label=label,
-            poses=default_poses,
-            grasp_types=["default"],
-        )
-
-        # Default interaction points
-        interaction_affordance = InteractionPoints(
-            object_label=label,
-            points=torch.zeros(1, 3),
-            point_types=["contact"],
+            custom_config=custom_config or {},
         )
 
         semantics = ObjectSemantics(
             label=label,
             affordance=grasp_affordance,
-            geometry={"bounding_box": [0.1, 0.1, 0.1]},
+            geometry=default_geometry,
             properties={"mass": 1.0, "friction": 0.5},
         )
 
-        # Cache and return
-        self._object_cache[label] = semantics
+        # Cache only default path
+        if use_cache and geometry is None and custom_config is None:
+            self._object_cache[label] = semantics
         return semantics
 
     def clear_cache(self) -> None:
@@ -159,6 +174,7 @@ def __init__(
         robot: "Robot",
         motion_generator: "MotionGenerator",
         device: torch.device = torch.device("cuda"),
+        actions_cfg_dict: Optional[Dict[str, ActionCfg]] = None,
     ):
         self.robot = robot
         self.motion_generator = motion_generator
@@ -171,71 +187,51 @@ def __init__(
         self._semantic_analyzer = SemanticAnalyzer()
 
         # Initialize default actions
-        self._init_default_actions()
+        self._init_default_actions(actions_cfg_dict)
 
-    def _init_default_actions(self):
+    def _init_default_actions(
+        self,
+        actions_cfg_dict: Optional[Dict[str, ActionCfg]] = None,
+    ):
         """Initialize default atomic action instances."""
-        from .actions import ReachAction, GraspAction, MoveAction, ReleaseAction
+        from .actions import MoveAction, PickUpAction, PlaceAction
 
+        actions_cfg_dict = actions_cfg_dict or {}
         control_parts = getattr(self.robot, "control_parts", None) or ["default"]
+        default_action_dict = {
+            "move": MoveAction,
+            "pick_up": PickUpAction,
+            "place": PlaceAction,
+        }
+        # set default actions for each control part
         for part in control_parts:
-            self.register_action(
-                f"reach_{part}",
-                ReachAction(
-                    motion_generator=self.motion_generator,
-                    robot=self.robot,
-                    control_part=part,
-                    device=self.device,
-                ),
-            )
-            self.register_action(
-                f"grasp_{part}",
-                GraspAction(
-                    motion_generator=self.motion_generator,
-                    robot=self.robot,
-                    control_part=part,
-                    device=self.device,
-                ),
-            )
-            self.register_action(
-                f"move_{part}",
-                MoveAction(
-                    motion_generator=self.motion_generator,
-                    robot=self.robot,
-                    control_part=part,
-                    device=self.device,
-                ),
-            )
-            self.register_action(
-                f"release_{part}",
-                ReleaseAction(
-                    motion_generator=self.motion_generator,
-                    robot=self.robot,
-                    control_part=part,
-                    device=self.device,
-                ),
-            )
-
-        # Register action classes for dynamic instantiation
+            for action_name, action_class in default_action_dict.items():
+                action_key = f"{action_name}_{part}"
+                if action_key not in self._actions:
+                    cfg = actions_cfg_dict.get(action_name)
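+                    # One cfg is shared per base action name (e.g. "pick_up")
+                    # across all control parts; a None cfg makes the action
+                    # fall back to its default config class.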
+                    instance = action_class(
+                        motion_generator=self.motion_generator, cfg=cfg
+                    )
+                    self._actions[action_key] = instance
 
+        # Register user defined action classes for dynamic instantiation
         for action_name, action_class in _global_action_registry.items():
             # Don't override default actions
-            if action_name not in ["reach", "grasp", "move", "release"]:
+            if action_name not in default_action_dict:
                 for part in control_parts:
                     action_key = f"{action_name}_{part}"
                     if action_key not in self._actions:
-                        # Create instance with default config
-                        try:
-                            instance = action_class(
-                                motion_generator=self.motion_generator,
-                                robot=self.robot,
-                                control_part=part,
-                                device=self.device,
-                            )
-                            self._actions[action_key] = instance
-                        except Exception:
-                            # Skip if instantiation fails
-                            pass
+                        cfg = actions_cfg_dict.get(action_name)
+                        instance = action_class(
+                            motion_generator=self.motion_generator, cfg=cfg
+                        )
+                        self._actions[action_key] = instance
 
     def register_action(self, name: str, action: AtomicAction):
         """Register a custom atomic action."""
@@ -262,36 +258,47 @@ def get_action_names(self) -> List[str]:
     def execute(
         self,
         action_name: str,
-        target: Union[torch.Tensor, str, ObjectSemantics],
+        target: Union[torch.Tensor, str, ObjectSemantics, Dict[str, Any]],
         control_part: Optional[str] = None,
         **kwargs,
-    ) -> PlanResult:
+    ) -> tuple[bool, torch.Tensor, list[int]]:
         """Execute an atomic action.
 
         Args:
             action_name: Name of registered action
-            target: Target pose, object label, or ObjectSemantics
+            target: One of:
+                - Target pose tensor [4, 4]
+                - Object label string
+                - ObjectSemantics instance
+                - Dict convenience spec. Supported keys include:
+                    - ``label`` (required unless using ``pose``/``semantics``)
+                    - ``geometry`` (e.g., mesh tensors)
+                    - ``custom_config`` (affordance custom config)
+                    - ``use_cache`` (bool, default ``True``)
+                    - ``properties`` (merged into semantics properties)
+                    - ``uid`` (stored on the semantics as an extra attribute)
+                    - ``pose`` (direct tensor passthrough)
+                    - ``semantics`` (direct ObjectSemantics passthrough)
             control_part: Robot control part to use
             **kwargs: Additional action parameters
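+
+        Example:
+            A minimal sketch of the dict target form (the label and control
+            part names here are assumptions)::
+
+                is_success, trajectory, joint_ids = engine.execute(
+                    "pick_up",
+                    {"label": "mug", "use_cache": False},
+                    control_part="arm",
+                )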
 
         Returns:
-            PlanResult with trajectory (positions, velocities, accelerations),
-            end-effector poses (xpos_list), and success status.
+            tuple[bool, torch.Tensor, list[int]]:
+                is_success,
+                trajectory of shape (n_envs, n_waypoints, dof),
+                joint_ids corresponding to trajectory
         """
         # Resolve action
         if control_part:
             action_key = f"{action_name}_{control_part}"
         else:
             action_key = action_name
-
         if action_key not in self._actions:
-            raise ValueError(f"Unknown action: {action_key}")
+            logger.log_error(f"Unknown action: {action_key}")
 
         action = self._actions[action_key]
 
-        # Resolve target to ObjectSemantics if string label provided
-        if isinstance(target, str):
-            target = self._semantic_analyzer.analyze(target)
+        target = self._resolve_target(target)
 
-        # Execute action - returns PlanResult directly
+        # Execute action - returns (is_success, trajectory, joint_ids)
         return action.execute(target, **kwargs)
@@ -299,7 +306,7 @@ def execute(
     def validate(
         self,
         action_name: str,
-        target: Union[torch.Tensor, str, ObjectSemantics],
+        target: Union[torch.Tensor, str, ObjectSemantics, Dict[str, Any]],
         control_part: Optional[str] = None,
         **kwargs,
     ) -> bool:
@@ -314,11 +321,75 @@
         action = self._actions[action_key]
 
-        if isinstance(target, str):
-            target = self._semantic_analyzer.analyze(target)
+        target = self._resolve_target(target)
 
         return action.validate(target, **kwargs)
 
+    def _resolve_target(
+        self,
+        target: Union[torch.Tensor, str, ObjectSemantics, Dict[str, Any]],
+    ) -> Union[torch.Tensor, ObjectSemantics]:
+        """Resolve user target input into tensor pose or ObjectSemantics.
+
+        Supports the convenience dict format in ``execute`` and ``validate``.
+        """
+        if isinstance(target, torch.Tensor):
+            return target
+
+        if isinstance(target, ObjectSemantics):
+            return target
+
+        if isinstance(target, str):
+            return self._semantic_analyzer.analyze(target)
+
+        if isinstance(target, dict):
+            if "pose" in target:
+                pose = target["pose"]
+                if not isinstance(pose, torch.Tensor):
+                    raise TypeError("target['pose'] must be a torch.Tensor")
+                return pose
+
+            if "semantics" in target:
+                semantics = target["semantics"]
+                if not isinstance(semantics, ObjectSemantics):
+                    raise TypeError(
+                        "target['semantics'] must be an ObjectSemantics instance"
+                    )
+                return semantics
+
+            label = target.get("label")
+            if label is None:
+                raise ValueError(
+                    "Dict target must provide 'label', or use 'pose'/'semantics'."
+ ) + if not isinstance(label, str): + raise TypeError("target['label'] must be a string") + + geometry = target.get("geometry") + custom_config = target.get("custom_config") + use_cache = target.get("use_cache", True) + + semantics = self._semantic_analyzer.analyze( + label=label, + geometry=geometry, + custom_config=custom_config, + use_cache=use_cache, + ) + + properties = target.get("properties") + if properties is not None: + semantics.properties.update(properties) + + uid = target.get("uid") + if uid is not None: + semantics.uid = uid + + return semantics + + raise TypeError( + "target must be torch.Tensor, str, ObjectSemantics, or Dict[str, Any]" + ) + def get_semantic_analyzer(self) -> SemanticAnalyzer: """Get the semantic analyzer for object understanding.""" return self._semantic_analyzer diff --git a/embodichain/lab/sim/planners/motion_generator.py b/embodichain/lab/sim/planners/motion_generator.py index 0682c492..f8dedac7 100644 --- a/embodichain/lab/sim/planners/motion_generator.py +++ b/embodichain/lab/sim/planners/motion_generator.py @@ -508,7 +508,11 @@ def interpolate_trajectory( qpos_seed = options.start_qpos if qpos_seed is None and qpos_list is not None: + # first waypoint as seed qpos_seed = qpos_list[0] + if qpos_seed is None: + # fallback to current robot state as seed + qpos_seed = self.robot.get_qpos(name=control_part)[0] # Generate trajectory interpolate_qpos_list = [] @@ -551,9 +555,14 @@ def interpolate_trajectory( # compute_batch_ik expects (n_envs, n_batch, 7) or (n_envs, n_batch, 4, 4) # Here we assume n_envs = 1 or we want to apply this to all envs if available. # Since MotionGenerator usually works with self.robot.device, we use its batching capabilities. + qpos_seed_repeat = ( + qpos_seed.unsqueeze(0) + .repeat(total_interpolated_poses.shape[0], 1) + .unsqueeze(0) + ) success_batch, qpos_batch = self.robot.compute_batch_ik( pose=total_interpolated_poses.unsqueeze(0), - joint_seed=None, # Or use qpos_seed if properly shaped + joint_seed=qpos_seed_repeat, # Or use qpos_seed if properly shaped name=control_part, ) diff --git a/embodichain/lab/sim/planners/toppra_planner.py b/embodichain/lab/sim/planners/toppra_planner.py index 0c20ccf9..218d17ed 100644 --- a/embodichain/lab/sim/planners/toppra_planner.py +++ b/embodichain/lab/sim/planners/toppra_planner.py @@ -191,11 +191,9 @@ def plan( ) # Build waypoints - waypoints = [] - for target in target_states: - waypoints.append(np.array(target.qpos)) - - waypoints = np.array(waypoints) + waypoints = np.array( + [target.qpos.to("cpu").numpy() for target in target_states] + ) # Create spline interpolation # NOTE: Suitable for dense waypoints ss = np.linspace(0, 1, len(waypoints)) diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py index 658f4f88..9ec009bc 100644 --- a/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_generator.py @@ -73,7 +73,7 @@ class GraspGeneratorCfg: number of sampled surface points, ray perturbation angle, and gripper jaw distance limits. See :class:`AntipodalSamplerCfg` for details.""" - max_deviation_angle: float = np.pi / 12 + max_deviation_angle: float = np.pi / 6 """Maximum allowed angle (in radians) between the specified approach direction and the axis connecting an antipodal point pair. 
Pairs that deviate more than this threshold from perpendicular to the approach are
diff --git a/scripts/tutorials/sim/atom_action.py b/scripts/tutorials/sim/atom_action.py
new file mode 100644
index 00000000..28fe7ae8
--- /dev/null
+++ b/scripts/tutorials/sim/atom_action.py
@@ -0,0 +1,327 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+"""
+This script demonstrates the atomic action engine: a robot picks up a mug,
+places it at a new pose, and moves back to a rest pose in a simulated
+environment.
+"""
+
+import argparse
+import numpy as np
+import time
+import torch
+
+from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
+from embodichain.lab.sim.objects import Robot, RigidObject
+from embodichain.lab.sim.shapes import MeshCfg
+from embodichain.lab.sim.solvers import PytorchSolverCfg
+from embodichain.data import get_data_path
+from embodichain.utils import logger
+from embodichain.lab.sim.cfg import (
+    JointDrivePropertiesCfg,
+    RobotCfg,
+    RigidObjectCfg,
+    RigidBodyAttributesCfg,
+    LightCfg,
+    URDFCfg,
+)
+from embodichain.lab.sim.planners import MotionGenerator, MotionGenCfg, ToppraPlannerCfg
+
+from embodichain.toolkits.graspkit.pg_grasp.gripper_collision_checker import (
+    GripperCollisionCfg,
+)
+from embodichain.toolkits.graspkit.pg_grasp.antipodal_generator import (
+    GraspGeneratorCfg,
+    AntipodalSamplerCfg,
+)
+from embodichain.lab.sim.atomic_actions.engine import AtomicActionEngine
+from embodichain.lab.sim.atomic_actions.core import ObjectSemantics, AntipodalAffordance
+from embodichain.lab.sim.atomic_actions.actions import (
+    PickUpActionCfg,
+    PlaceActionCfg,
+    MoveActionCfg,
+)
+
+
+def parse_arguments():
+    """
+    Parse command-line arguments to configure the simulation.
+
+    Returns:
+        argparse.Namespace: Parsed arguments including the number of environments and rendering options.
+    """
+    parser = argparse.ArgumentParser(
+        description="Create and simulate a robot in SimulationManager"
+    )
+    parser.add_argument(
+        "--enable_rt", action="store_true", help="Enable ray tracing rendering"
+    )
+    parser.add_argument(
+        "--num_envs", type=int, default=1, help="Number of parallel environments"
+    )
+    return parser.parse_args()
+
+
+def initialize_simulation(args):
+    """
+    Initialize the simulation environment based on the provided arguments.
+
+    Args:
+        args (argparse.Namespace): Parsed command-line arguments.
+
+    Returns:
+        SimulationManager: Configured simulation manager instance.
+    """
+    config = SimulationManagerCfg(
+        headless=True,
+        sim_device="cuda",
+        enable_rt=args.enable_rt,
+        physics_dt=1.0 / 100.0,
+        num_envs=args.num_envs,
+    )
+    sim = SimulationManager(config)
+
+    sim.add_light(
+        cfg=LightCfg(uid="main_light", intensity=50.0, init_pos=(0, 0, 2.0))
+    )
+
+    return sim
+
+
+def create_robot(sim: SimulationManager, position=(0.0, 0.0, 0.0)):
+    """
+    Create and configure a robot with an arm and a parallel-jaw gripper in the simulation.
+
+    Args:
+        sim (SimulationManager): The simulation manager instance.
+        position: Base position of the robot in the world frame.
+
+    Returns:
+        Robot: The configured robot instance added to the simulation.
+    """
+    # Retrieve URDF paths for the robot arm and hand
+    ur10_urdf_path = get_data_path("UniversalRobots/UR10/UR10.urdf")
+    gripper_urdf_path = get_data_path("DH_PGC_140_50_M/DH_PGC_140_50_M.urdf")
+    # Configure the robot with its components and control properties
+    cfg = RobotCfg(
+        uid="UR10",
+        urdf_cfg=URDFCfg(
+            components=[
+                {"component_type": "arm", "urdf_path": ur10_urdf_path},
+                {"component_type": "hand", "urdf_path": gripper_urdf_path},
+            ]
+        ),
+        drive_pros=JointDrivePropertiesCfg(
+            stiffness={"JOINT[0-9]": 1e4, "FINGER[1-2]": 1e2},
+            damping={"JOINT[0-9]": 1e3, "FINGER[1-2]": 1e1},
+            max_effort={"JOINT[0-9]": 1e5, "FINGER[1-2]": 1e3},
+            drive_type="force",
+        ),
+        control_parts={
+            "arm": ["JOINT[0-9]"],
+            "hand": ["FINGER[1-2]"],
+        },
+        solver_cfg={
+            "arm": PytorchSolverCfg(
+                end_link_name="ee_link",
+                root_link_name="base_link",
+                tcp=[
+                    [0.0, 1.0, 0.0, 0.0],
+                    [-1.0, 0.0, 0.0, 0.0],
+                    [0.0, 0.0, 1.0, 0.12],
+                    [0.0, 0.0, 0.0, 1.0],
+                ],
+            )
+        },
+        init_qpos=[0.0, -np.pi / 2, -np.pi / 2, np.pi / 2, -np.pi / 2, 0.0, 0.0, 0.0],
+        init_pos=position,
+    )
+    return sim.add_robot(cfg=cfg)
+
+
+def create_mug(sim: SimulationManager) -> RigidObject:
+    mug_cfg = RigidObjectCfg(
+        uid="mug",
+        shape=MeshCfg(
+            fpath=get_data_path("CoffeeCup/cup.ply"),
+        ),
+        attrs=RigidBodyAttributesCfg(
+            mass=0.01,
+            dynamic_friction=0.97,
+            static_friction=0.99,
+        ),
+        max_convex_hull_num=16,
+        init_pos=[0.55, 0.0, 0.01],
+        init_rot=[0.0, 0.0, -90],
+        body_scale=(4, 4, 4),
+    )
+    mug = sim.add_rigid_object(cfg=mug_cfg)
+    return mug
+
+
+def run_trajectory(
+    robot: Robot,
+    trajectory: torch.Tensor,
+    joint_ids: list[int],
+    sim: SimulationManager,
+):
+    """Replay a planned trajectory waypoint by waypoint on the robot."""
+    n_waypoint = trajectory.shape[1]
+    for i in range(n_waypoint):
+        robot.set_qpos(trajectory[:, i, :], joint_ids=joint_ids)
+        sim.update(step=4)
+        time.sleep(1e-2)
+
+
+def main():
+    """
+    Main function to demonstrate robot simulation.
+
+    This function initializes the simulation, creates the robot and the mug,
+    and runs a pick-and-place sequence with the atomic action engine.
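+
+    The rough flow: configure the pick_up/place/move actions, build an
+    AtomicActionEngine, resolve a grasp from the mug's antipodal affordance,
+    plan each action's trajectory, then replay the trajectories in simulation.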
+    """
+    args = parse_arguments()
+    sim: SimulationManager = initialize_simulation(args)
+    robot = create_robot(sim)
+    mug = create_mug(sim)
+
+    motion_gen = MotionGenerator(
+        cfg=MotionGenCfg(planner_cfg=ToppraPlannerCfg(robot_uid=robot.uid))
+    )
+
+    pickup_cfg = PickUpActionCfg(
+        hand_open_qpos=torch.tensor(
+            [0.00, 0.00], dtype=torch.float32, device=sim.device
+        ),
+        hand_close_qpos=torch.tensor(
+            [0.025, 0.025], dtype=torch.float32, device=sim.device
+        ),
+        control_part="arm",
+        hand_control_part="hand",
+        approach_direction=torch.tensor(
+            [0.0, 0.0, -1.0], dtype=torch.float32, device=sim.device
+        ),
+        pre_grasp_distance=0.15,
+        lift_height=0.15,
+    )
+
+    place_cfg = PlaceActionCfg(
+        hand_open_qpos=torch.tensor(
+            [0.00, 0.00], dtype=torch.float32, device=sim.device
+        ),
+        hand_close_qpos=torch.tensor(
+            [0.025, 0.025], dtype=torch.float32, device=sim.device
+        ),
+        control_part="arm",
+        hand_control_part="hand",
+        lift_height=0.15,
+    )
+
+    move_cfg = MoveActionCfg(
+        control_part="arm",
+    )
+
+    atom_engine = AtomicActionEngine(
+        robot=robot,
+        motion_generator=motion_gen,
+        device=sim.device,
+        actions_cfg_dict={"pick_up": pickup_cfg, "place": place_cfg, "move": move_cfg},
+    )
+
+    sim.init_gpu_physics()
+    sim.open_window()
+
+    # Define object semantics and affordances for the mug
+    gripper_collision_cfg = GripperCollisionCfg(
+        max_open_length=0.088, finger_length=0.078, point_sample_dense=0.012
+    )
+    generator_cfg = GraspGeneratorCfg(
+        viser_port=11801,
+        antipodal_sampler_cfg=AntipodalSamplerCfg(
+            n_sample=20000, max_length=0.088, min_length=0.003
+        ),
+    )
+    mug_grasp_affordance = AntipodalAffordance(
+        object_label="mug",
+        # Set force_reannotate=True to regenerate grasp annotations even for a
+        # previously seen object, e.g. after changing the grasp generator
+        # configuration or when debugging the annotation step. Leave it False
+        # otherwise, since re-annotation takes extra time.
+        force_reannotate=False,
+        custom_config={
+            "gripper_collision_cfg": gripper_collision_cfg,
+            "generator_cfg": generator_cfg,
+        },
+    )
+    mug_semantics = ObjectSemantics(
+        label="mug",
+        geometry={
+            "mesh_vertices": mug.get_vertices(env_ids=[0], scale=True)[0],
+            "mesh_triangles": mug.get_triangles(env_ids=[0])[0],
+        },
+        affordance=mug_grasp_affordance,
+        entity=mug,  # in order to fetch object pose
+    )
+    start_qpos = robot.get_qpos(name="arm")
+
+    target_grasp_xpos = torch.tensor(
+        [
+            [-0.0539, -0.9985, -0.0022, 0.4489],
+            [-0.9977, 0.0540, -0.0401, -0.0030],
+            [0.0401, 0.0000, -0.9992, 0.1400],
+            [0.0000, 0.0000, 0.0000, 1.0000],
+        ],
+        dtype=torch.float32,
+        device=sim.device,
+    )
+
+    is_success, pick_trajectory, joint_ids = atom_engine.execute(
+        start_qpos=start_qpos,
+        action_name="pick_up",
+        target=mug_semantics,
+        # target=target_grasp_xpos,  # A raw grasp pose also works, but it skips
+        # the affordance and grasp generator, so it is not recommended.
+        control_part="arm",
+    )
+    arm_joint_ids = robot.get_joint_ids("arm")
+    place_start_qpos = pick_trajectory[:, -1, arm_joint_ids]
+    place_xpos = target_grasp_xpos.clone()
+    place_xpos[:3, 3] += torch.tensor([-0.2, 0.4, 0.1], device=sim.device)
+    is_success, place_trajectory, joint_ids = atom_engine.execute(
+        start_qpos=place_start_qpos,
+        action_name="place",
+        target=place_xpos,
+        control_part="arm",
+    )
+    rest_xpos = target_grasp_xpos.clone()
+    rest_xpos[:3, 3] = torch.tensor([0.5, 0.0, 0.5], device=sim.device)
+    move_start_qpos = place_trajectory[:, -1, arm_joint_ids]
+    is_success, move_trajectory, arm_joint_ids = atom_engine.execute(
+        start_qpos=move_start_qpos,
+        action_name="move",
+        target=rest_xpos,
+        control_part="arm",
+    )
+
+    logger.logger.info(f"Starting trajectory replay, last action success: {is_success}")
+    run_trajectory(robot, pick_trajectory, joint_ids, sim)
+    run_trajectory(robot, place_trajectory, joint_ids, sim)
+    run_trajectory(robot, move_trajectory, arm_joint_ids, sim)
+
+    input("Press Enter to exit...")
+
+
+if __name__ == "__main__":
+    main()
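As an aside, the engine's validate mirrors execute's target handling (both funnel through _resolve_target), so feasibility can be checked before committing to a full plan. A minimal sketch, reusing the tutorial's atom_engine, mug_semantics, and start_qpos (illustrative; not part of the patch):

    # Hypothetical pre-check before planning the pick.
    if atom_engine.validate(
        action_name="pick_up", target=mug_semantics, control_part="arm"
    ):
        is_success, pick_trajectory, joint_ids = atom_engine.execute(
            start_qpos=start_qpos,
            action_name="pick_up",
            target=mug_semantics,
            control_part="arm",
        )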
diff --git a/tests/sim/atomic_actions/test_actions.py b/tests/sim/atomic_actions/test_actions.py
new file mode 100644
index 00000000..d2750234
--- /dev/null
+++ b/tests/sim/atomic_actions/test_actions.py
@@ -0,0 +1,508 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+import numpy as np
+import pytest
+import torch
+
+from embodichain.data import get_data_path
+from embodichain.lab.sim import SimulationManager, SimulationManagerCfg
+from embodichain.lab.sim.atomic_actions.actions import (
+    MoveAction,
+    MoveActionCfg,
+    PickUpAction,
+    PickUpActionCfg,
+    PlaceAction,
+    PlaceActionCfg,
+)
+from embodichain.lab.sim.atomic_actions.core import AntipodalAffordance, ObjectSemantics
+from embodichain.lab.sim.cfg import (
+    JointDrivePropertiesCfg,
+    RigidBodyAttributesCfg,
+    RigidObjectCfg,
+    RobotCfg,
+    URDFCfg,
+)
+from embodichain.lab.sim.objects import RigidObject, Robot
+from embodichain.lab.sim.planners import MoveType
+from embodichain.lab.sim.shapes import MeshCfg
+from embodichain.lab.sim.solvers import PytorchSolverCfg
+
+
+class DummyMotionGenerator:
+    def __init__(self, robot: Robot) -> None:
+        self.robot = robot
+
+
+def create_robot(sim: SimulationManager, position: list[float] | None = None) -> Robot:
+    if position is None:
+        position = [0.0, 0.0, 0.0]
+
+    ur10_urdf_path = get_data_path("UniversalRobots/UR10/UR10.urdf")
+    gripper_urdf_path = get_data_path("DH_PGC_140_50_M/DH_PGC_140_50_M.urdf")
+
+    cfg = RobotCfg(
+        uid="UR10",
+        urdf_cfg=URDFCfg(
+            components=[
+                {"component_type": "arm", "urdf_path": ur10_urdf_path},
+                {"component_type": "hand", "urdf_path": gripper_urdf_path},
+            ]
+        ),
+        drive_pros=JointDrivePropertiesCfg(
+            stiffness={"JOINT[0-9]": 1e4, "FINGER[1-2]": 1e2},
+            damping={"JOINT[0-9]": 1e3, "FINGER[1-2]": 1e1},
+            max_effort={"JOINT[0-9]": 1e5, "FINGER[1-2]": 1e3},
+            drive_type="force",
+        ),
+        control_parts={
+            "arm": ["JOINT[0-9]"],
+            "hand": ["FINGER[1-2]"],
+        },
+        solver_cfg={
+            "arm": PytorchSolverCfg(
+                end_link_name="ee_link",
+                root_link_name="base_link",
+                tcp=[
+                    [0.0, 1.0, 0.0, 0.0],
+                    [-1.0, 0.0, 0.0, 0.0],
+                    [0.0, 0.0, 1.0, 0.12],
+                    [0.0, 0.0, 0.0, 1.0],
+                ],
+            )
+        },
+        init_qpos=[0.0, -np.pi / 2, -np.pi / 2, np.pi / 2, -np.pi / 2, 0.0, 0.0, 0.0],
+        init_pos=position,
+    )
+    return sim.add_robot(cfg=cfg)
+
+
+def create_mug(sim: SimulationManager) -> RigidObject:
+    mug_cfg = RigidObjectCfg(
+        uid="mug",
+        shape=MeshCfg(
+            fpath=get_data_path("CoffeeCup/cup.ply"),
+        ),
+        attrs=RigidBodyAttributesCfg(
+            mass=0.01,
+            dynamic_friction=0.97,
+            static_friction=0.99,
+        ),
+        max_convex_hull_num=16,
+        init_pos=[0.55, 0.0, 0.01],
+        init_rot=[0.0, 0.0, -90],
+        body_scale=(4, 4, 4),
+    )
+    return sim.add_rigid_object(cfg=mug_cfg)
+
+
+class BaseActionTest:
+    @classmethod
+    def setup_class(cls) -> None:
+        cls.sim = SimulationManager(
+            SimulationManagerCfg(headless=True, sim_device="cpu", num_envs=2)
+        )
+        cls.robot = create_robot(cls.sim)
+        cls.mug = create_mug(cls.sim)
+        cls.motion_generator = DummyMotionGenerator(cls.robot)
+        cls.arm_joint_ids = cls.robot.get_joint_ids("arm")
+        cls.hand_joint_ids = cls.robot.get_joint_ids("hand")
+        cls.arm_dof = len(cls.arm_joint_ids)
+        cls.hand_dof = len(cls.hand_joint_ids)
+        cls.device = cls.robot.device
+        cls.sim.update(step=1)
+
+    @classmethod
+    def teardown_class(cls) -> None:
+        cls.sim.destroy()
+
+    def _make_pose(
+        self,
+        translation: tuple[float, float, float],
+    ) -> torch.Tensor:
+        pose = torch.eye(4, dtype=torch.float32, device=self.device)
+        pose[:3, 3] = torch.tensor(translation, dtype=torch.float32, device=self.device)
+        return pose
+
+    def _make_semantics(self) -> ObjectSemantics:
+        return ObjectSemantics(
+            label="mug",
+            geometry={
+                "mesh_vertices":
torch.zeros( + (3, 3), dtype=torch.float32, device=self.device + ), + "mesh_triangles": torch.tensor( + [[0, 1, 2]], dtype=torch.int64, device=self.device + ), + }, + affordance=AntipodalAffordance(object_label="mug"), + entity=self.mug, + ) + + +class TestActions(BaseActionTest): + def test_move_action_execute_with_pose_target( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + action = MoveAction( + motion_generator=self.motion_generator, + cfg=MoveActionCfg(control_part="arm", sample_interval=6), + ) + target_pose = self._make_pose((0.45, -0.05, 0.18)) + plan_call: dict[str, object] = {} + + def fake_plan_arm_trajectory( + target_states_list, + start_qpos, + n_waypoints, + arm_dof=None, + ): + actual_arm_dof = action.dof if arm_dof is None else arm_dof + plan_call["target_states_list"] = target_states_list + plan_call["start_qpos"] = start_qpos.clone() + plan_call["n_waypoints"] = n_waypoints + return torch.full( + (action.n_envs, n_waypoints, actual_arm_dof), + fill_value=0.5, + dtype=torch.float32, + device=action.device, + ) + + monkeypatch.setattr(action, "_plan_arm_trajectory", fake_plan_arm_trajectory) + + success, trajectory, joint_ids = action.execute(target=target_pose) + + assert success is True + assert trajectory.shape == ( + action.n_envs, + action.cfg.sample_interval, + action.dof, + ) + assert joint_ids == self.arm_joint_ids + assert plan_call["n_waypoints"] == action.cfg.sample_interval + assert torch.allclose( + plan_call["start_qpos"], + self.robot.get_qpos(name="arm"), + ) + + target_states_list = plan_call["target_states_list"] + assert len(target_states_list) == action.n_envs + for target_states in target_states_list: + assert len(target_states) == 1 + assert target_states[0].move_type == MoveType.EEF_MOVE + assert torch.allclose(target_states[0].xpos, target_pose) + + def test_move_action_resolve_start_qpos_repeats_single_configuration(self) -> None: + action = MoveAction( + motion_generator=self.motion_generator, + cfg=MoveActionCfg(control_part="arm", sample_interval=6), + ) + single_qpos = torch.zeros(self.arm_dof, dtype=torch.float32, device=self.device) + + resolved_qpos = action._resolve_start_qpos(single_qpos) + + assert resolved_qpos.shape == (action.n_envs, self.arm_dof) + assert torch.allclose( + resolved_qpos, + torch.zeros( + (action.n_envs, self.arm_dof), dtype=torch.float32, device=self.device + ), + ) + + def test_pick_up_action_execute_builds_three_phase_trajectory( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + action = PickUpAction( + motion_generator=self.motion_generator, + cfg=PickUpActionCfg( + control_part="arm", + hand_control_part="hand", + hand_open_qpos=torch.tensor( + [0.0, 0.0], dtype=torch.float32, device=self.device + ), + hand_close_qpos=torch.tensor( + [0.025, 0.025], dtype=torch.float32, device=self.device + ), + approach_direction=torch.tensor( + [0.0, 0.0, -1.0], dtype=torch.float32, device=self.device + ), + pre_grasp_distance=0.15, + lift_height=0.2, + sample_interval=11, + hand_interp_steps=3, + ), + ) + grasp_pose = ( + self._make_pose((0.5, 0.02, 0.12)).unsqueeze(0).repeat(action.n_envs, 1, 1) + ) + semantics = self._make_semantics() + plan_calls: list[dict[str, object]] = [] + + def fake_resolve_grasp_pose(_semantics): + return ( + True, + grasp_pose, + torch.full( + (action.n_envs,), 0.025, dtype=torch.float32, device=self.device + ), + ) + + def fake_plan_arm_trajectory( + target_states_list, + start_qpos, + n_waypoints, + arm_dof=None, + ): + actual_arm_dof = action.arm_dof if arm_dof is None else arm_dof + 
fill_value = float(len(plan_calls) + 1) + plan_calls.append( + { + "target_states_list": target_states_list, + "start_qpos": start_qpos.clone(), + "n_waypoints": n_waypoints, + "arm_dof": actual_arm_dof, + } + ) + return torch.full( + (action.n_envs, n_waypoints, actual_arm_dof), + fill_value=fill_value, + dtype=torch.float32, + device=action.device, + ) + + monkeypatch.setattr(action, "_resolve_grasp_pose", fake_resolve_grasp_pose) + monkeypatch.setattr(action, "_plan_arm_trajectory", fake_plan_arm_trajectory) + + start_qpos = self.robot.get_qpos(name="arm") + success, trajectory, joint_ids = action.execute( + target=semantics, + start_qpos=start_qpos, + ) + + n_approach, n_close, n_lift = action._compute_three_phase_waypoints( + action.cfg.hand_interp_steps, + first_phase_name="approach", + third_phase_name="lift", + ) + + assert success is True + assert len(plan_calls) == 2 + assert trajectory.shape == ( + action.n_envs, + n_approach + n_close + n_lift, + action.dof, + ) + assert joint_ids == action.joint_ids + assert torch.allclose(plan_calls[0]["start_qpos"], start_qpos) + assert torch.allclose( + plan_calls[1]["start_qpos"], + torch.ones( + (action.n_envs, action.arm_dof), + dtype=torch.float32, + device=self.device, + ), + ) + + approach_states = plan_calls[0]["target_states_list"] + for env_id in range(action.n_envs): + assert len(approach_states[env_id]) == 2 + pre_grasp_pose = approach_states[env_id][0].xpos + final_grasp_pose = approach_states[env_id][1].xpos + assert approach_states[env_id][0].move_type == MoveType.EEF_MOVE + assert approach_states[env_id][1].move_type == MoveType.EEF_MOVE + assert torch.allclose(final_grasp_pose, grasp_pose[env_id]) + assert pre_grasp_pose[2, 3] == pytest.approx( + final_grasp_pose[2, 3].item() + action.cfg.pre_grasp_distance + ) + + lift_states = plan_calls[1]["target_states_list"] + for env_id in range(action.n_envs): + assert len(lift_states[env_id]) == 1 + assert lift_states[env_id][0].xpos[2, 3] > grasp_pose[env_id][2, 3] + + assert torch.allclose( + trajectory[:, :n_approach, : action.arm_dof], + torch.ones( + (action.n_envs, n_approach, action.arm_dof), + dtype=torch.float32, + device=self.device, + ), + ) + assert torch.allclose( + trajectory[:, :n_approach, action.arm_dof :], + action.hand_open_qpos.view(1, 1, -1).expand(action.n_envs, n_approach, -1), + ) + + expected_close_path = action._interpolate_hand_qpos( + action.hand_open_qpos, + action.hand_close_qpos, + n_close, + ) + assert torch.allclose( + trajectory[:, n_approach : n_approach + n_close, : action.arm_dof], + torch.ones( + (action.n_envs, n_close, action.arm_dof), + dtype=torch.float32, + device=self.device, + ), + ) + assert torch.allclose( + trajectory[:, n_approach : n_approach + n_close, action.arm_dof :], + expected_close_path.unsqueeze(0).expand(action.n_envs, -1, -1), + ) + assert torch.allclose( + trajectory[:, n_approach + n_close :, : action.arm_dof], + torch.full( + (action.n_envs, n_lift, action.arm_dof), + fill_value=2.0, + dtype=torch.float32, + device=self.device, + ), + ) + assert torch.allclose( + trajectory[:, n_approach + n_close :, action.arm_dof :], + action.hand_close_qpos.view(1, 1, -1).expand(action.n_envs, n_lift, -1), + ) + + def test_place_action_execute_builds_release_trajectory( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + action = PlaceAction( + motion_generator=self.motion_generator, + cfg=PlaceActionCfg( + control_part="arm", + hand_control_part="hand", + hand_open_qpos=torch.tensor( + [0.0, 0.0], dtype=torch.float32, 
device=self.device + ), + hand_close_qpos=torch.tensor( + [0.025, 0.025], dtype=torch.float32, device=self.device + ), + lift_height=0.2, + sample_interval=11, + hand_interp_steps=3, + ), + ) + place_pose = self._make_pose((0.42, 0.18, 0.1)) + start_qpos = self.robot.get_qpos(name="arm") + plan_calls: list[dict[str, object]] = [] + + def fake_plan_arm_trajectory( + target_states_list, + start_qpos_arg, + n_waypoints, + arm_dof=None, + ): + actual_arm_dof = action.arm_dof if arm_dof is None else arm_dof + fill_value = float(len(plan_calls) + 1) + plan_calls.append( + { + "target_states_list": target_states_list, + "start_qpos": start_qpos_arg.clone(), + "n_waypoints": n_waypoints, + "arm_dof": actual_arm_dof, + } + ) + return torch.full( + (action.n_envs, n_waypoints, actual_arm_dof), + fill_value=fill_value, + dtype=torch.float32, + device=action.device, + ) + + monkeypatch.setattr(action, "_plan_arm_trajectory", fake_plan_arm_trajectory) + + success, trajectory, joint_ids = action.execute( + target=place_pose, + start_qpos=start_qpos, + ) + + n_down, n_open, n_lift = action._compute_three_phase_waypoints( + action.cfg.hand_interp_steps, + first_phase_name="approach", + third_phase_name="lift", + ) + + assert success is True + assert len(plan_calls) == 2 + assert trajectory.shape == ( + action.n_envs, + n_down + n_open + n_lift, + action.dof, + ) + assert joint_ids == action.joint_ids + assert torch.allclose(plan_calls[0]["start_qpos"], start_qpos) + assert torch.allclose( + plan_calls[1]["start_qpos"], + torch.ones( + (action.n_envs, action.arm_dof), + dtype=torch.float32, + device=self.device, + ), + ) + + down_states = plan_calls[0]["target_states_list"] + for env_id in range(action.n_envs): + assert len(down_states[env_id]) == 2 + lifted_pose = down_states[env_id][0].xpos + final_place_pose = down_states[env_id][1].xpos + assert torch.allclose(final_place_pose, place_pose) + assert lifted_pose[2, 3] > final_place_pose[2, 3] + + expected_open_path = action._interpolate_hand_qpos( + action.hand_close_qpos, + action.hand_open_qpos, + n_open, + ) + assert torch.allclose( + trajectory[:, :n_down, : action.arm_dof], + torch.ones( + (action.n_envs, n_down, action.arm_dof), + dtype=torch.float32, + device=self.device, + ), + ) + assert torch.allclose( + trajectory[:, :n_down, action.arm_dof :], + action.hand_close_qpos.view(1, 1, -1).expand(action.n_envs, n_down, -1), + ) + assert torch.allclose( + trajectory[:, n_down : n_down + n_open, : action.arm_dof], + torch.ones( + (action.n_envs, n_open, action.arm_dof), + dtype=torch.float32, + device=self.device, + ), + ) + assert torch.allclose( + trajectory[:, n_down : n_down + n_open, action.arm_dof :], + expected_open_path.unsqueeze(0).expand(action.n_envs, -1, -1), + ) + assert torch.allclose( + trajectory[:, n_down + n_open :, : action.arm_dof], + torch.full( + (action.n_envs, n_lift, action.arm_dof), + fill_value=2.0, + dtype=torch.float32, + device=self.device, + ), + ) + assert torch.allclose( + trajectory[:, n_down + n_open :, action.arm_dof :], + action.hand_open_qpos.view(1, 1, -1).expand(action.n_envs, n_lift, -1), + ) diff --git a/tests/sim/atomic_actions/test_core.py b/tests/sim/atomic_actions/test_core.py index bbb02a36..220455d1 100644 --- a/tests/sim/atomic_actions/test_core.py +++ b/tests/sim/atomic_actions/test_core.py @@ -14,224 +14,321 @@ # limitations under the License. 
# ---------------------------------------------------------------------------- -import torch +from __future__ import annotations + import pytest +import torch + +import embodichain.lab.sim.atomic_actions.core as core_module -from embodichain.lab.sim.atomic_actions import ( +from embodichain.lab.sim.atomic_actions.core import ( + ActionCfg, Affordance, - GraspPose, + AntipodalAffordance, + AtomicAction, InteractionPoints, ObjectSemantics, - ActionCfg, - register_action, - unregister_action, - get_registered_actions, +) +from embodichain.lab.sim.planners import ( + MotionGenOptions, + MoveType, + PlanResult, + PlanState, ) -class TestAffordance: - """Test affordance base class and subclasses.""" - - def test_affordance_base(self): - """Test base affordance class.""" - aff = Affordance(object_label="test_object") - assert aff.object_label == "test_object" - assert aff.get_batch_size() == 1 - - def test_grasp_pose_default(self): - """Test GraspPose with default values.""" - grasp = GraspPose(object_label="bottle") - assert grasp.object_label == "bottle" - assert grasp.poses.shape == (1, 4, 4) - assert grasp.grasp_types == ["default"] - assert grasp.get_batch_size() == 1 - - def test_grasp_pose_multiple(self): - """Test GraspPose with multiple poses.""" - poses = torch.stack( - [ - torch.eye(4), - torch.eye(4), - torch.eye(4), - ] - ) - grasp = GraspPose( - object_label="bottle", - poses=poses, - grasp_types=["pinch", "power", "hook"], - ) - assert grasp.get_batch_size() == 3 - - # Test get_grasp_by_type - pinch_pose = grasp.get_grasp_by_type("pinch") - assert pinch_pose is not None - assert torch.allclose(pinch_pose, torch.eye(4)) - - nonexistent = grasp.get_grasp_by_type("nonexistent") - assert nonexistent is None - - def test_grasp_pose_best_grasp(self): - """Test get_best_grasp method.""" - poses = torch.stack( - [ - torch.eye(4), - torch.eye(4) * 2, - ] - ) - confidence = torch.tensor([0.7, 0.9]) - grasp = GraspPose( - poses=poses, - grasp_types=["low_conf", "high_conf"], - confidence_scores=confidence, - ) - - best = grasp.get_best_grasp() - # Should return the second pose (higher confidence) - assert torch.allclose(best, poses[1]) - - def test_interaction_points(self): - """Test InteractionPoints class.""" - points = torch.tensor( - [ - [0.1, 0.0, 0.0], - [0.0, 0.1, 0.0], - [0.0, 0.0, 0.1], - ] - ) - normals = torch.tensor( - [ - [1.0, 0.0, 0.0], - [0.0, 1.0, 0.0], - [0.0, 0.0, 1.0], - ] - ) - interaction = InteractionPoints( - object_label="cube", - points=points, - normals=normals, - point_types=["push", "poke", "touch"], - ) - - assert interaction.get_batch_size() == 3 - - # Test get_points_by_type - push_points = interaction.get_points_by_type("push") - assert push_points is not None - assert torch.allclose(push_points, points[0:1]) +class DummyRobot: + def __init__(self) -> None: + self.device = torch.device("cpu") + self.qpos = torch.tensor([[0.1, 0.2, 0.3]], dtype=torch.float32) + self.fail_ik = False + self.last_ik_call: dict | None = None + self.last_fk_call: dict | None = None + + def get_qpos(self) -> torch.Tensor: + return self.qpos.clone() + + def compute_ik( + self, + pose: torch.Tensor, + qpos_seed: torch.Tensor, + name: str, + ) -> tuple[torch.Tensor, torch.Tensor]: + self.last_ik_call = { + "pose": pose.clone(), + "qpos_seed": qpos_seed.clone(), + "name": name, + } + success = torch.tensor([not self.fail_ik], dtype=torch.bool, device=self.device) + if self.fail_ik: + return success, torch.zeros_like(qpos_seed) + return success, qpos_seed + 1.0 + + def compute_fk( + self, 
+ qpos: torch.Tensor, + name: str, + to_matrix: bool, + ) -> torch.Tensor: + self.last_fk_call = { + "qpos": qpos.clone(), + "name": name, + "to_matrix": to_matrix, + } + batch_size = qpos.shape[0] + poses = torch.eye(4, dtype=torch.float32).unsqueeze(0).repeat(batch_size, 1, 1) + poses[:, 0, 3] = qpos.sum(dim=-1) + return poses + + +class DummyMotionGenerator: + def __init__(self) -> None: + self.robot = DummyRobot() + self.last_target_states: list[PlanState] | None = None + self.last_options: MotionGenOptions | None = None + + def generate( + self, + target_states: list[PlanState], + options: MotionGenOptions, + ) -> PlanResult: + self.last_target_states = target_states + self.last_options = options + positions = torch.tensor( + [[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]], dtype=torch.float32 + ) + return PlanResult(success=True, positions=positions) + + +class DummyAtomicAction(AtomicAction): + def execute( + self, + target: torch.Tensor | ObjectSemantics, + start_qpos: torch.Tensor | None = None, + **kwargs, + ) -> tuple[bool, torch.Tensor, list[float]]: + return True, torch.empty(0), [] + + def validate( + self, + target: torch.Tensor | ObjectSemantics, + start_qpos: torch.Tensor | None = None, + **kwargs, + ) -> bool: + return True + + +class DummyGraspGenerator: + instances: list["DummyGraspGenerator"] = [] + + def __init__( + self, + vertices: torch.Tensor, + triangles: torch.Tensor, + cfg=None, + gripper_collision_cfg=None, + ) -> None: + self.vertices = vertices + self.triangles = triangles + self.cfg = cfg + self.gripper_collision_cfg = gripper_collision_cfg + self.device = vertices.device + self._hit_point_pairs: torch.Tensor | None = None + self.annotate_calls = 0 + self.get_grasp_pose_calls: list[tuple[torch.Tensor, torch.Tensor]] = [] + DummyGraspGenerator.instances.append(self) + + def annotate(self) -> None: + self.annotate_calls += 1 + self._hit_point_pairs = torch.ones( + (1, 2, 3), dtype=torch.float32, device=self.device + ) - nonexistent = interaction.get_points_by_type("nonexistent") - assert nonexistent is None + def get_grasp_poses( + self, + obj_pose: torch.Tensor, + approach_direction: torch.Tensor, + ) -> tuple[bool, torch.Tensor, float]: + self.get_grasp_pose_calls.append((obj_pose.clone(), approach_direction.clone())) + if float(obj_pose[0, 3]) > 0.5: + return False, torch.eye(4, dtype=torch.float32, device=self.device), 0.0 + + grasp_pose = obj_pose.clone() + grasp_pose[2, 3] += 0.02 + return True, grasp_pose, 0.04 + + +class TestAffordanceAndSemantics: + def test_affordance_mesh_properties_and_custom_config(self) -> None: + vertices = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=torch.float32) + triangles = torch.tensor([[0, 1, 1]], dtype=torch.int64) + affordance = Affordance( + object_label="mug", + geometry={"mesh_vertices": vertices, "mesh_triangles": triangles}, + ) - # Test get_approach_direction - approach = interaction.get_approach_direction(0) - assert torch.allclose(approach, torch.tensor([-1.0, 0.0, 0.0])) + affordance.set_custom_config("score_threshold", 0.8) - def test_interaction_points_no_normals(self): - """Test InteractionPoints without normals.""" - points = torch.tensor([[0.1, 0.2, 0.3]]) - interaction = InteractionPoints(points=points) + assert torch.equal(affordance.mesh_vertices, vertices) + assert torch.equal(affordance.mesh_triangles, triangles) + assert affordance.get_custom_config("score_threshold") == pytest.approx(0.8) + assert affordance.get_batch_size() == 1 - # Default approach direction should be +z - approach = 
interaction.get_approach_direction(0) - assert torch.allclose(approach, torch.tensor([0.0, 0.0, 1.0])) + def test_affordance_mesh_properties_raise_on_invalid_types(self) -> None: + affordance = Affordance( + geometry={"mesh_vertices": [[0.0, 0.0, 0.0]], "mesh_triangles": [[0, 1, 2]]} + ) + with pytest.raises(TypeError): + _ = affordance.mesh_vertices + + with pytest.raises(TypeError): + _ = affordance.mesh_triangles + + def test_interaction_points_helpers(self) -> None: + interaction_points = InteractionPoints( + points=torch.tensor( + [[0.0, 0.0, 0.0], [0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + dtype=torch.float32, + ), + normals=torch.tensor( + [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], + dtype=torch.float32, + ), + point_types=["push", "touch", "push"], + ) -class TestObjectSemantics: - """Test ObjectSemantics dataclass.""" + push_points = interaction_points.get_points_by_type("push") - def test_basic_creation(self): - """Test basic ObjectSemantics creation.""" - affordance = GraspPose(object_label="bottle") - semantics = ObjectSemantics( - label="bottle", - affordance=affordance, - geometry={"bounding_box": [0.1, 0.2, 0.3]}, - properties={"mass": 0.5, "friction": 0.8}, - uid="bottle_001", + assert push_points is not None + assert push_points.shape == (2, 3) + assert interaction_points.get_batch_size() == 3 + assert torch.allclose( + interaction_points.get_approach_direction(1), + torch.tensor([-1.0, 0.0, 0.0], dtype=torch.float32), ) - assert semantics.label == "bottle" - assert semantics.uid == "bottle_001" - assert semantics.affordance.object_label == "bottle" - assert semantics.properties["mass"] == 0.5 + no_normal_points = InteractionPoints( + points=torch.zeros((1, 3), dtype=torch.float32) + ) + assert torch.allclose( + no_normal_points.get_approach_direction(0), + torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32), + ) + + def test_object_semantics_binds_label_and_geometry_to_affordance(self) -> None: + geometry = {"mesh_vertices": torch.zeros((4, 3), dtype=torch.float32)} + affordance = Affordance() - def test_no_uid(self): - """Test ObjectSemantics without UID.""" - affordance = GraspPose() semantics = ObjectSemantics( - label="apple", + label="cup", affordance=affordance, - geometry={}, - properties={}, + geometry=geometry, + properties={"mass": 0.1}, ) - assert semantics.uid is None + assert semantics.affordance.object_label == "cup" + assert semantics.affordance.geometry is geometry + assert semantics.properties["mass"] == pytest.approx(0.1) + def test_antipodal_affordance_requires_mesh_geometry(self) -> None: + affordance = AntipodalAffordance(object_label="mug") -class TestActionCfg: - """Test ActionCfg dataclass.""" + with pytest.raises(RuntimeError): + affordance.get_best_grasp_poses( + torch.eye(4, dtype=torch.float32).unsqueeze(0) + ) - def test_defaults(self): - """Test ActionCfg default values.""" - cfg = ActionCfg() - assert cfg.control_part == "left_arm" - assert cfg.interpolation_type == "linear" - assert cfg.velocity_limit is None - assert cfg.acceleration_limit is None + def test_antipodal_affordance_get_best_grasp_poses( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + DummyGraspGenerator.instances.clear() + monkeypatch.setattr(core_module, "GraspGenerator", DummyGraspGenerator) - def test_custom_values(self): - """Test ActionCfg with custom values.""" - cfg = ActionCfg( - control_part="right_arm", - interpolation_type="toppra", - velocity_limit=0.5, - acceleration_limit=1.0, + vertices = torch.tensor( + [[0.0, 0.0, 0.0], [0.1, 0.0, 0.0], [0.0, 0.1, 
0.0]], + dtype=torch.float32, + ) + triangles = torch.tensor([[0, 1, 2]], dtype=torch.int64) + affordance = AntipodalAffordance( + object_label="mug", + geometry={"mesh_vertices": vertices, "mesh_triangles": triangles}, + custom_config={ + "generator_cfg": object(), + "gripper_collision_cfg": object(), + }, ) - assert cfg.control_part == "right_arm" - assert cfg.interpolation_type == "toppra" - assert cfg.velocity_limit == 0.5 - assert cfg.acceleration_limit == 1.0 + object_poses = torch.eye(4, dtype=torch.float32).unsqueeze(0).repeat(2, 1, 1) + object_poses[1, 0, 3] = 1.0 + is_success, grasp_xpos, open_length = affordance.get_best_grasp_poses( + object_poses + ) + generator = DummyGraspGenerator.instances[-1] + + assert generator.annotate_calls == 1 + assert torch.equal(is_success, torch.tensor([True, False], dtype=torch.bool)) + assert torch.allclose( + grasp_xpos[0, :3, 3], + torch.tensor([0.0, 0.0, 0.02], dtype=torch.float32), + ) + assert torch.allclose(grasp_xpos[1], torch.eye(4, dtype=torch.float32)) + assert open_length.tolist() == pytest.approx([0.04, 0.0]) -class TestActionRegistry: - """Test action registry functions.""" - def test_register_and_unregister(self): - """Test registering and unregistering actions.""" - from embodichain.lab.sim.atomic_actions import AtomicAction +class TestAtomicActionHelpers: + def setup_method(self) -> None: + self.motion_generator = DummyMotionGenerator() + self.action = DummyAtomicAction( + motion_generator=self.motion_generator, + cfg=ActionCfg(control_part="arm"), + ) - class TestAction(AtomicAction): - def execute(self, target, **kwargs): - return PlanResult(success=True) + def test_ik_solve_uses_control_part_and_seed(self) -> None: + target_pose = torch.eye(4, dtype=torch.float32) + qpos_seed = torch.tensor([0.5, 0.6, 0.7], dtype=torch.float32) - def validate(self, target, **kwargs): - return True + result = self.action._ik_solve(target_pose, qpos_seed) - # Register - register_action("test", TestAction) - assert "test" in get_registered_actions() + assert torch.allclose(result, qpos_seed + 1.0) + assert self.motion_generator.robot.last_ik_call is not None + assert self.motion_generator.robot.last_ik_call["name"] == "arm" + assert torch.allclose( + self.motion_generator.robot.last_ik_call["qpos_seed"], + qpos_seed.unsqueeze(0), + ) - # Unregister - unregister_action("test") - assert "test" not in get_registered_actions() + def test_ik_solve_raises_when_solver_fails(self) -> None: + self.motion_generator.robot.fail_ik = True - def test_get_registered_actions_copy(self): - """Test that get_registered_actions returns a copy.""" - from embodichain.lab.sim.atomic_actions import AtomicAction + with pytest.raises(RuntimeError): + self.action._ik_solve(torch.eye(4, dtype=torch.float32)) - initial = get_registered_actions() + def test_fk_compute_handles_single_and_batched_qpos(self) -> None: + single_qpos = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + batched_qpos = torch.tensor( + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=torch.float32 + ) - class DummyAction(AtomicAction): - def execute(self, target, **kwargs): - return PlanResult(success=True) + single_pose = self.action._fk_compute(single_qpos) + batched_pose = self.action._fk_compute(batched_qpos) - def validate(self, target, **kwargs): - return True + assert single_pose.shape == (4, 4) + assert batched_pose.shape == (2, 4, 4) + assert single_pose[0, 3] == pytest.approx(6.0) + assert torch.allclose( + batched_pose[:, 0, 3], + torch.tensor([1.0, 1.0], dtype=torch.float32), + ) + assert 
self.motion_generator.robot.last_fk_call is not None + assert self.motion_generator.robot.last_fk_call["to_matrix"] is True + assert self.motion_generator.robot.last_fk_call["name"] == "arm" - register_action("dummy", DummyAction) + def test_apply_offset_updates_translation_in_place_copy(self) -> None: + pose = torch.eye(4, dtype=torch.float32).unsqueeze(0).repeat(2, 1, 1) + offset = torch.tensor([[0.1, 0.2, 0.3], [-0.1, 0.0, 0.2]], dtype=torch.float32) - # Original should not contain the new action - assert "dummy" not in initial + result = self.action._apply_offset(pose, offset) - # Cleanup - unregister_action("dummy") + assert torch.allclose(result[:, :3, 3], offset) + assert torch.allclose(pose[:, :3, 3], torch.zeros((2, 3), dtype=torch.float32))
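
For readers reconstructing the expected trajectory layout, a standalone sketch of the three-phase stitching these tests assert (arm moves with the hand held open, hand closes with the arm held, lift with the hand closed). The linear hand interpolation is an assumption about what `_interpolate_hand_qpos` computes; all sizes are illustrative:

    import torch

    def interpolate_hand_qpos(start: torch.Tensor, end: torch.Tensor, n: int) -> torch.Tensor:
        # Linear interpolation from start to end over n steps -> (n, hand_dof).
        # Assumed behavior of the actions' _interpolate_hand_qpos helper.
        alphas = torch.linspace(0.0, 1.0, n).unsqueeze(-1)
        return start + alphas * (end - start)

    arm_dof, hand_dof = 6, 2
    n_approach, n_close, n_lift = 8, 3, 5
    approach_arm = torch.zeros(n_approach, arm_dof)  # stand-in for the planned approach segment
    lift_arm = torch.ones(n_lift, arm_dof)           # stand-in for the planned lift segment
    hand_open = torch.zeros(hand_dof)
    hand_close = torch.full((hand_dof,), 0.025)

    phases = [
        # Phase 1: arm approaches, hand stays open.
        torch.cat([approach_arm, hand_open.unsqueeze(0).expand(n_approach, -1)], dim=-1),
        # Phase 2: arm holds its last configuration, hand closes.
        torch.cat(
            [
                approach_arm[-1:].expand(n_close, -1),
                interpolate_hand_qpos(hand_open, hand_close, n_close),
            ],
            dim=-1,
        ),
        # Phase 3: arm lifts, hand stays closed.
        torch.cat([lift_arm, hand_close.unsqueeze(0).expand(n_lift, -1)], dim=-1),
    ]
    trajectory = torch.cat(phases, dim=0)  # (n_approach + n_close + n_lift, arm_dof + hand_dof)
    assert trajectory.shape == (n_approach + n_close + n_lift, arm_dof + hand_dof)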