diff --git a/src/rai_core/rai/tools/ros2/__init__.py b/src/rai_core/rai/tools/ros2/__init__.py index 70e411fbd..f996711dc 100644 --- a/src/rai_core/rai/tools/ros2/__init__.py +++ b/src/rai_core/rai/tools/ros2/__init__.py @@ -18,7 +18,6 @@ raise ImportError( "This is a ROS2 feature. Make sure ROS2 is installed and sourced." ) - from .cli import ( ROS2CLIToolkit, ros2_action, @@ -48,11 +47,6 @@ ROS2TopicsToolkit, StartROS2ActionTool, ) -from .manipulation.custom import ( - GetObjectPositionsTool, - MoveToPointTool, - MoveToPointToolInput, -) from .navigation.nav2 import ( CancelNavigateToPoseTool, GetNavigateToPoseFeedbackTool, @@ -71,7 +65,6 @@ "CancelROS2ActionTool", "GetNavigateToPoseFeedbackTool", "GetNavigateToPoseResultTool", - "GetObjectPositionsTool", "GetROS2ActionFeedbackTool", "GetROS2ActionIDsTool", "GetROS2ActionResultTool", @@ -83,8 +76,6 @@ "GetROS2TopicsNamesAndTypesTool", "GetROS2TransformConfiguredTool", "GetROS2TransformTool", - "MoveToPointTool", - "MoveToPointToolInput", "Nav2Toolkit", "NavigateToPoseTool", "PublishROS2MessageTool", diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py index ab8c868aa..96f9d3d73 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py @@ -16,11 +16,10 @@ import numpy as np import sensor_msgs.msg -from langchain_core.tools import BaseTool from pydantic import BaseModel, Field -from rai.communication.ros2 import ROS2Connector from rai.communication.ros2.api import convert_ros_img_to_ndarray from rai.communication.ros2.ros_async import get_future_result +from rai.tools.ros2.base import BaseROS2Tool from rclpy.exceptions import ( ParameterNotDeclaredException, ParameterUninitializedException, @@ -78,9 +77,7 @@ class DistanceMeasurement(NamedTuple): # --------------------- Tools --------------------- -class GroundingDinoBaseTool(BaseTool): - connector: ROS2Connector = Field(..., exclude=True) - +class GroundingDinoBaseTool(BaseROS2Tool): box_threshold: float = Field(default=0.35, description="Box threshold for GDINO") text_threshold: float = Field(default=0.45, description="Text threshold for GDINO") @@ -89,7 +86,7 @@ def _call_gdino_node( ) -> Future: cli = self.connector.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME) while not cli.wait_for_service(timeout_sec=1.0): - self.node.get_logger().info( + self.connector.node.get_logger().info( f"service {GDINO_SERVICE_NAME} not available, waiting again..." ) req = RAIGroundingDino.Request() diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py index 16c6fc2df..f4171a58e 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py @@ -18,14 +18,13 @@ import numpy as np import rclpy import sensor_msgs.msg -from langchain_core.tools import BaseTool from pydantic import BaseModel, Field from rai.communication.ros2.api import ( convert_ros_img_to_base64, convert_ros_img_to_ndarray, ) -from rai.communication.ros2.connectors import ROS2Connector from rai.communication.ros2.ros_async import get_future_result +from rai.tools.ros2.base import BaseROS2Tool from rclpy import Future from rclpy.exceptions import ( ParameterNotDeclaredException, @@ -67,12 +66,7 @@ class GetGrabbingPointInput(BaseModel): # --------------------- Tools --------------------- -class GetSegmentationTool: - connector: ROS2Connector = Field(..., exclude=True) - - name: str = "" - description: str = "" - +class GetSegmentationTool(BaseROS2Tool): box_threshold: float = Field(default=0.35, description="Box threshold for GDINO") text_threshold: float = Field(default=0.45, description="Text threshold for GDINO") @@ -194,9 +188,7 @@ def depth_to_point_cloud( return points -class GetGrabbingPointTool(BaseTool): - connector: ROS2Connector = Field(..., exclude=True) - +class GetGrabbingPointTool(BaseROS2Tool): name: str = "GetGrabbingPointTool" description: str = "Get the grabbing point of an object" pcd: List[Any] = []