diff --git a/docs/simulation_and_benchmarking/rai_bench.md b/docs/simulation_and_benchmarking/rai_bench.md index 1969be88e..a86831273 100644 --- a/docs/simulation_and_benchmarking/rai_bench.md +++ b/docs/simulation_and_benchmarking/rai_bench.md @@ -73,7 +73,7 @@ score = (correctly_placed_now - correctly_placed_initially) / initially_incorrec You can find predefined scene configs in `rai_bench/manipulation_o3de/predefined/configs/`. -Predefined scenarios can be imported like: +Predefined scenarios can be imported like, chosing tasks by difficulty: ```python from rai_bench.manipulation_o3de import get_scenarios @@ -81,8 +81,6 @@ from rai_bench.manipulation_o3de import get_scenarios get_scenarios(levels=["easy", "medium"]) ``` -Choose which task you want by selecting the difficulty, from trivial to very hard scenarios. - ## Tool Calling Agent Benchmark Evaluates agent performance independently from any simulation, based only on tool calls that the agent makes. To make it independent from simulations, this benchmark introduces tool mocks which can be adjusted for different tasks. This makes the benchmark more universal and a lot faster. @@ -106,6 +104,7 @@ The `Validator` class can combine single or multiple subtasks to create a single - OrderedCallsValidator - requires a strict order of subtasks. The next subtask will be validated only when the previous one was completed. Validator passes when all subtasks pass. - NotOrderedCallsValidator - doesn't enforce order of subtasks. Every subtask will be validated against every tool call. Validator passes when all subtasks pass. +- OneFromManyValidator - passes when any one of the given subtasks passes. ### Task diff --git a/docs/tutorials/benchmarking.md b/docs/tutorials/benchmarking.md index fa65de41f..e68f3d387 100644 --- a/docs/tutorials/benchmarking.md +++ b/docs/tutorials/benchmarking.md @@ -15,60 +15,61 @@ If your goal is creating custom tasks and scenarios, visit [Creating Custom Task ## Manipulation O3DE -- Follow setup from [Manipulation demo Setup](../demos/manipulation.md#setup) -- Run the benchmark with: +- Follow the main setup [Basic Setup](../setup/install.md) and setup from [Manipulation demo Setup](../demos/manipulation.md#setup) +- To see available options run: + ```bash + python src/rai_bench/rai_bench/examples/manipulation_o3de.py --help + ``` +- Example usage: ```bash - python src/rai_bench/rai_bench/examples/manipulation_o3de.py --model-name --vendor --levels + python src/rai_bench/rai_bench/examples/manipulation_o3de.py --model-name qwen2.5:7b --vendor ollama --levels trivial ``` + !!! note + + When using Ollama, be sure to pull the model first. + !!! warning - Running all scenarios will take a while. If you want to just try it out, we recommend choosing just one level of difficulty. + Running all scenarios will take a while. If you want to just try it out, we recommend choosing just one level of difficulty. ## Tool Calling Agent -This benchmark does not require any additional setup besides the main one [Basic Setup](../setup/install.md), just run: +- This benchmark does not require any additional setup besides the main one [Basic Setup](../setup/install.md) +- To see available options run: + ```bash + python src/rai_bench/rai_bench/examples/tool_calling_agent.py --help + ``` +- Example usage: ```bash -python src/rai_bench/rai_bench/examples/tool_calling_agent.py --model-name --vendor --extra-tool-calls <0 5> --task-types basic --n-shots <0 2> --prompt-detail --complexities --out-dir +python src/rai_bench/rai_bench/examples/tool_calling_agent.py --model-name qwen2.5:7b --vendor ollama --extra-tool-calls 5 --task-types basic --n-shots 5 --prompt-detail descriptive --complexities easy ``` -!!! note - - This Benchmark is significantly faster, but still, if just trying out, we recommend choosing just one parameter per flag as every combination on params will create more tasks. - ## Testing Models -The best way of benchmarking your models is using the `rai_bench.test_models` function with benchmark configs. - -??? info "test_models function definition" - - ::: rai_bench.test_models.test_models +The best way of benchmarking your models is using the `src/rai_bench/rai_bench/examples/benchmarking_models.py` -Example usage: +Feel free to modify the benchmark configs to suit your needs, you can chose every possible set of params +and the benchmark will be run tasks with every combination: ```python -from rai_bench import ( - ManipulationO3DEBenchmarkConfig, - ToolCallingAgentBenchmarkConfig, - test_models, -) - if __name__ == "__main__": # Define models you want to benchmark - model_names = ["qwen2.5:7b", "llama3.2:3b"] + model_names = ["qwen3:4b", "llama3.2:3b"] vendors = ["ollama", "ollama"] # Define benchmarks that will be used - man_conf = ManipulationO3DEBenchmarkConfig( - o3de_config_path="path/to/your/o3de_config.yaml", # path to your O3DE config + mani_conf = ManipulationO3DEBenchmarkConfig( + o3de_config_path="src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/o3de_config.yaml", levels=[ # define what difficulty of tasks to include in benchmark "trivial", + "easy", ], repeats=1, # how many times to repeat ) - tool_conf = ToolCallingAgentBenchmarkConfig( + tool_conf = ToolCallingAgentBenchmarkConfig( extra_tool_calls=[0, 5], # how many extra tool calls allowed to still pass task_types=[ # what types of tasks to include "basic", @@ -76,10 +77,7 @@ if __name__ == "__main__": "custom_interfaces", ], N_shots=[0, 2], # examples in system prompt - prompt_detail=[ # how descriptive should task prompt be - "brief", - "descriptive" - ], + prompt_detail=["brief", "descriptive"], # how descriptive should task prompt be repeats=1, ) @@ -87,11 +85,22 @@ if __name__ == "__main__": test_models( model_names=model_names, vendors=vendors, - benchmark_configs=[man_conf, tool_conf], + benchmark_configs=[mani_conf, tool_conf], out_dir=out_dir, + # if you want to pass any additinal args to model + additional_model_args=[ + {"reasoning": False}, + {}, + ], ) ``` +Based on the example above the `Tool Calling` benchmark will run basic, spatial_reasoning and custom_interfaces tasks with every configuration of [extra_tool_calls x N_shots x prompt_detail] provided which will result in almost 500 tasks. Manipulation benchmark will run all specified task level once as there is no additional params. Reapeat is set to 1 in both configs so there will be no additional runs. + +!!! note + + When using ollama vendor make sure to pull used models first + ## Viewing Results From every benchmark run, there will be results saved in the provided output directory: @@ -100,7 +109,7 @@ From every benchmark run, there will be results saved in the provided output dir - results_summary.csv - for overall metrics - results.csv - for detailed results of every task/scenario -When using `test_models`, the output directories will be saved as `////...` and this format can be visualized with our Streamlit script: +When using `test_models`, the output directories will be saved as `////...` and this format can be visualized with our Streamlit script: ```bash streamlit run src/rai_bench/rai_bench/examples/visualise_streamlit.py @@ -110,16 +119,29 @@ streamlit run src/rai_bench/rai_bench/examples/visualise_streamlit.py ### Manipulation O3DE Scenarios -To create your own Scenarios, you will need a Scene Config and Task. You can combine already existing Scene and existing Task to create a new Scenario like: +To create your own Scenarios, you will need a Scene Config and Task - check out example `src/rai_bench/rai_bench/examples/custom_scenario.py`. +You can combine already existing Scene and existing Task to create a new Scenario like: ```python +import logging from pathlib import Path -from rai_bench.manipulation_o3de.tasks import PlaceObjectAtCoordTask -from rai_sim.simulation_bridge import SceneConfig +from typing import List, Sequence, Tuple, Union + +from rclpy.impl.rcutils_logger import RcutilsLogger + from rai_bench.manipulation_o3de.benchmark import Scenario +from rai_bench.manipulation_o3de.interfaces import ( + ManipulationTask, +) +from rai_bench.manipulation_o3de.tasks import PlaceObjectAtCoordTask +from rai_sim.simulation_bridge import Entity, SceneConfig +loggers_type = Union[RcutilsLogger, logging.Logger] -path_to_your_config = "src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/1a.yaml" +### Define your scene setup ####################3 +path_to_your_config = ( + "src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/1a.yaml" +) scene_config = SceneConfig.load_base_config(Path(path_to_your_config)) # configure existing Task with different params @@ -156,17 +178,6 @@ entities: Creating your own Task will require slightly more effort. Let's start with something simple - a Task that will require throwing given objects off the table: ```python -import logging -from typing import List, Tuple, Union -from rclpy.impl.rcutils_logger import RcutilsLogger -from rai_bench.manipulation_o3de.interfaces import ( - ManipulationTask, -) -from rai_sim.simulation_bridge import Entity, SimulationConfig - -loggers_type = Union[RcutilsLogger, logging.Logger] - - class ThrowObjectsOffTableTask(ManipulationTask): def __init__(self, obj_types: List[str], logger: loggers_type | None = None): super().__init__(logger=logger) @@ -180,11 +191,9 @@ class ThrowObjectsOffTableTask(ManipulationTask): # define prompt obj_names = ", ".join(obj + "s" for obj in self.obj_types).replace("_", " ") # 0.0 z is the level of table, so any coord below that means it is off the table - return f"Manipulate objects, so that all of the {obj_names} are thrown off the table (negative z)" + return f"Manipulate objects, so that all of the {obj_names} are dropped outside of the table (for example y<-0.75)." - def check_if_required_objects_present( - self, simulation_config: SimulationConfig - ) -> bool: + def check_if_required_objects_present(self, simulation_config: SceneConfig) -> bool: # Validate if any required objects are present in sim config # if there is not a single object of provided type, there is no point in running # this task of given scene config @@ -193,7 +202,7 @@ class ThrowObjectsOffTableTask(ManipulationTask): ) return count > 1 - def calculate_correct(self, entities: List[Entity]) -> Tuple[int, int]: + def calculate_correct(self, entities: Sequence[Entity]) -> Tuple[int, int]: selected_type_objects = self.filter_entities_by_object_type( entities=entities, object_types=self.obj_types ) @@ -206,88 +215,178 @@ class ThrowObjectsOffTableTask(ManipulationTask): incorrect: int = len(selected_type_objects) - correct return correct, incorrect + # configure existing Task with different params target_coords = (0.1, 0.1) disp = 0.1 -task = PlaceObjectAtCoordTask( - obj_type="apple", - target_position=target_coords, - allowable_displacement=disp, +task = ThrowObjectsOffTableTask( + obj_types=["apple"], ) -Scenario( - task=task, - scene_config=scene_config, - scene_config_path=path_to_your_config +super_scenario = Scenario( + task=task, scene_config=scene_config, scene_config_path=path_to_your_config ) ``` As `obj_types` is parameterizable, it enables various variants of this Task. In combination with a lot of simulation configs available, it means that a single Task can provide dozens of scenarios. -Congratulations, you just created your first Scenario from scratch! +Then yo test it simply run: + +```python +##### Now you can run it in benchmark ################## +if __name__ == "__main__": + from pathlib import Path + + from rai_bench import ( + define_benchmark_logger, + ) + from rai_bench.manipulation_o3de import run_benchmark + from rai_bench.utils import get_llm_for_benchmark + + experiment_dir = Path(out_dir="src/rai_bench/experiments/custom_task/") + + experiment_dir.mkdir(parents=True, exist_ok=True) + bench_logger = define_benchmark_logger(out_dir=experiment_dir) + + llm = get_llm_for_benchmark( + model_name="gpt-4o", + vendor="openai", + ) + + run_benchmark( + llm=llm, + out_dir=experiment_dir, + # use your scenario + scenarios=[super_scenario], + bench_logger=bench_logger, + ) + +``` + +Congratulations, you just created and launched your first Scenario from scratch! ### Tool Calling Tasks To create a Tool Calling Task, you will need to define Subtasks, Validators, and Task itself. +Check the example `src/rai_bench/rai_bench/examples/custom_task.py`. Let's create a basic task that requires using a tool to receive a message from a specific topic. ```python +from typing import List + +from langchain_core.tools import BaseTool + +from rai_bench.tool_calling_agent.interfaces import Task, TaskArgs +from rai_bench.tool_calling_agent.mocked_tools import ( + MockGetROS2TopicsNamesAndTypesTool, + MockReceiveROS2MessageTool, +) from rai_bench.tool_calling_agent.subtasks import ( CheckArgsToolCallSubTask, ) from rai_bench.tool_calling_agent.validators import ( OrderedCallsValidator, ) -from rai_bench.tool_calling_agent.mocked_tools import ( - MockGetROS2TopicsNamesAndTypesTool, -) -from rai_bench.tool_calling_agent.interfaces import Task, TaskArgs -from langchain_core.tools import BaseTool -from typing import List - - - -# define subtask that requires -receive_robot_pos_subtask = CheckArgsToolCallSubTask( - expected_tool_name="receive_ros2_message", - expected_args={"topic": "/robot_position"}, - expected_optional_args={ - "timeout_sec": int - }, # if there is not exact value expected, you can pass type -) -# use OrderedCallValidator as there is only 1 subtask -topics_ord_val = OrderedCallsValidator(subtasks=[receive_robot_pos_subtask]) +# This Task will check if robot can receive msessage from specified topic class GetROS2RobotPositionTask(Task): complexity = "easy" + type = "custom" @property def available_tools(self) -> List[BaseTool]: + # define topics that will be seen by agent + TOPICS = [ + "/robot_position", + "/attached_collision_object", + "/clock", + "/collision_object", + ] + + TOPICS_STRING = [ + "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", + "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", + "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", + "topic: /robot_position\n type: sensor_msgs/msg/RobotPosition", + ] + # define which tools will be available for agent return [ - # define which topics will be seen by agent MockGetROS2TopicsNamesAndTypesTool( - mock_topics_names_and_types=[ - "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", - "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", - "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", - "topic: /robot_position\n type: sensor_msgs/msg/RobotPosition", - ] + mock_topics_names_and_types=TOPICS_STRING ), + MockReceiveROS2MessageTool(available_topics=TOPICS), ] def get_system_prompt(self) -> str: return "You are a ROS 2 expert that want to solve tasks. You have access to various tools that allow you to query the ROS 2 system." - def get_prompt(self) -> str: + def get_base_prompt(self) -> str: return "Get the position of the robot." + def get_prompt(self) -> str: + # Create versions for different levels + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can discover what topics are currently active." + ) + @property def optional_tool_calls_number(self) -> int: - # Listing topics before getting any message + # Listing topics before getting any message is fine return 1 + +# define subtask +receive_robot_pos_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/robot_position"}, + expected_optional_args={ + "timeout_sec": int # if there is not exact value expected, you can pass type + }, +) +# use OrderedCallValidator as there is only 1 subtask to check +topics_ord_val = OrderedCallsValidator(subtasks=[receive_robot_pos_subtask]) + + # optionally pass number of extra tool calls args = TaskArgs(extra_tool_calls=0) -task = GetROS2RobotPositionTask(validators=[topics_ord_val], task_args=args) +super_task = GetROS2RobotPositionTask(validators=[topics_ord_val], task_args=args) +``` + +Then run it with: + +```python +##### Now you can run it in benchmark ################## +if __name__ == "__main__": + from pathlib import Path + + from rai_bench import ( + define_benchmark_logger, + ) + from rai_bench.tool_calling_agent import ( + run_benchmark, + ) + from rai_bench.utils import get_llm_for_benchmark + + experiment_dir = Path("src/rai_bench/rai_bench/experiments/custom_task") + experiment_dir.mkdir(parents=True, exist_ok=True) + bench_logger = define_benchmark_logger(out_dir=experiment_dir) + + super_task.set_logger(bench_logger) + + llm = get_llm_for_benchmark( + model_name="gpt-4o", + vendor="openai", + ) + + run_benchmark( + llm=llm, + out_dir=experiment_dir, + tasks=[super_task], + bench_logger=bench_logger, + ) ``` diff --git a/src/rai_bench/README.md b/src/rai_bench/README.md index d475ad42c..bdaebe9c4 100644 --- a/src/rai_bench/README.md +++ b/src/rai_bench/README.md @@ -1,209 +1,4 @@ -# RAI Benchmarks +# RAI Bench -The RAI Bench is a package including benchmarks and providing frame for creating new benchmarks - -## Manipulation O3DE Benchmark - -The Manipulation O3DE Benchmark [manipulation_o3de_benchmark_module](./rai_bench//manipulation_o3de/) provides tasks and scene configurations for robotic arm manipulation simulation in O3DE. The tasks use a common `ManipulationTask` logic and can be parameterized, which allows for many task variants. The current tasks include: - -- **MoveObjectToLeftTask** -- **GroupObjectsTask** -- **BuildCubeTowerTask** -- **PlaceObjectAtCoordTask** -- **RotateObjectTask** (currently not applicable due to limitations in the ManipulatorMoveTo tool) - -The result of a task is a value between 0 and 1, calculated like initially_misplaced_now_correct / initially_misplaced. This score is calculated at the end of each scenario. - -### Frame Components - -- `Task` -- `Scenario` -- `Benchmark` - -For more information about these classes go to -> [benchmark](./rai_bench//manipulation_o3de/benchmark.py) and [Task](./rai_bench//manipulation_o3de//interfaces.py) and - -### Example usage - -Example of how to load scenes, define scenarios and run benchmark can be found in [manipulation_o3de_benchmark_example](rai_bench/examples/manipulation_o3de/main.py) - -Scenarios can be loaded manually like: - -```python -one_carrot_simulation_config = O3DExROS2SimulationConfig.load_config( - base_config_path=Path("path_to_scene.yaml"), - connector_config_path=Path("path_to_o3de_config.yaml"), - ) - -Scenario(task=GrabCarrotTask(logger=some_logger), simulation_config=one_carrot_simulation_config) -``` - -or automatically like: - -```python -scenarios = Benchmark.create_scenarios( - tasks=tasks, simulation_configs=simulations_configs - ) -``` - -which will result in list of scenarios with combination of every possible task and scene(task decides if scene config is suitable for it). - -or can be imported from exisitng packets [scenarios_packets](rai_bench/examples/manipulation_o3de/scenarios.py): - -```python -t_scenarios = trivial_scenarios( - configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger - ) -e_scenarios = easy_scenarios( - configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger -) -m_scenarios = medium_scenarios( - configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger -) -h_scenarios = hard_scenarios( - configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger -) -vh_scenarios = very_hard_scenarios( - configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger -) -``` - -which are grouped by their subjective difficulty. For now there are 10 trivial, 42 easy, 23 medium, 38 hard and 47 very hard scenarios. -Check docstrings and code in [scenarios_packets](rai_bench/examples/manipulation_o3de/scenarios.py) if you want to know how scenarios are assigned to difficulty level. - -### Running - -1. Download O3DE simulation binary and unzip it. - - - [ros2-humble](https://robotec-ml-rai-public.s3.eu-north-1.amazonaws.com/RAIManipulationDemo_jammyhumble.zip) - - [ros2-jazzy](https://robotec-ml-rai-public.s3.eu-north-1.amazonaws.com/RAIManipulationDemo_noblejazzy.zip) - -2. Follow step 2 from [Manipulation demo Setup section](../../docs/demos/manipulation.md#setup) - -3. Adjust the path to the binary in: [o3de_config.yaml](./rai_bench/examples/manipulation_o3de/configs/o3de_config.yaml) -4. Choose the model you want to run and a vendor. - > [!NOTE] - > The configs of vendors are defined in [config.toml](../../config.toml) Change ithem if needed. -5. Run benchmark with: - -```bash -cd rai -source setup_shell.sh -python src/rai_bench/rai_bench/examples/manipulation_o3de/main.py --model-name llama3.2 --vendor ollama -``` - -> [!NOTE] -> For now benchmark runs all available scenarios (~160). See [Examples](#example-usege) -> section for details. - -### Development - -When creating new task or changing existing ones, make sure to add unit tests for score calculation in [rai_bench_tests](../../tests/rai_bench/manipulation_o3de/tasks/). -This applies also when you are adding or changing the helper methods in `Task` or `ManipulationTask`. - -The number of scenarios can be easily extened without writing new tasks, by increasing number of variants of the same task and adding more simulation configs but it won't improve variety of scenarios as much as creating new tasks. - -## Tool Calling Agent Benchmark - -The Tool Calling Agent Benchmark is the benchmark for LangChain tool calling agents. It includes a set of tasks and a benchmark that evaluates the performance of the agent on those tasks by verifying the correctness of the tool calls requested by the agent. The benchmark is integrated with LangSmith and Langfuse tracing backends to easily track the performance of the agents. - -### Frame Components - -- [Tool Calling Agent Benchmark](rai_bench//tool_calling_agent/benchmark.py) - Benchmark for LangChain tool calling agents -- [Scores tracing](rai_bench/tool_calling_agent_bench/scores_tracing.py) - Component handling sending scores to tracing backends -- [Interfaces](rai_bench//tool_calling_agent/interfaces.py) - Interfaces for validation classes - Task, Validator, SubTask - For detailed description of validation visit -> [Validation](.//rai_bench/docs/tool_calling_agent_benchmark.md) - -[tool_calling_agent_test_bench.py](rai_bench/examples/tool_calling_agent/main.py) - Script providing benchmark on tasks based on the ROS2 tools usage. - -### Example Usage - -Validators can be constructed from any SubTasks, Tasks can be validated by any numer of Validators, which makes whole validation process incredibly versital. - -```python -# subtasks -get_topics_subtask = CheckArgsToolCallSubTask( - expected_tool_name="get_ros2_topics_names_and_types" -) -color_image_subtask = CheckArgsToolCallSubTask( - expected_tool_name="get_ros2_image", expected_args={"topic": "/camera_image_color"} -) -# validators - consist of subtasks -topics_ord_val = OrderedCallsValidator(subtasks=[get_topics_subtask]) -color_image_ord_val = OrderedCallsValidator(subtasks=[color_image_subtask]) -topics_and_color_image_ord_val = OrderedCallsValidator( - subtasks=[ - get_topics_subtask, - color_image_subtask, - ] -) -# tasks - validated by list of validators -GetROS2TopicsTask(validators=[topics_ord_val]) -GetROS2RGBCameraTask(validators=[topics_and_color_image_ord_val]), -GetROS2RGBCameraTask(validators=[topics_ord_val, color_image_ord_val]), -``` - -### Running - -To set up tracing backends, please follow the instructions in the [tracing.md](../../docs/tracing.md) document. - -To run the benchmark: - -```bash -cd rai -source setup_shell.sh -python src/rai_bench/rai_bench/examples/tool_calling_agent/main.py -``` - -There is also flags to declare model type and vendor: - -```bash -python src/rai_bench/rai_bench/examples/tool_calling_agent/main.py --model-name llama3.2 --vendor ollama -``` - -> [!NOTE] -> The configs of vendors are defined in [config.toml](../../config.toml) Change ithem if needed. - -## Testing Models - -To test multiple models, different benchamrks or couple repeats in one go - use script [test_models](./rai_bench/examples/test_models.py) - -Modify these params: - -```python -models_name = ["llama3.2", "qwen2.5:7b"] -vendors = ["ollama", "ollama"] -benchmarks = ["tool_calling_agent"] -repeats = 1 -``` - -to your liking and run the script! - -```bash -python src/rai_bench/rai_bench/examples/test_models.py -``` - -### Results and Visualization - -All results from running benchmarks will be saved to folder [experiments](./rai_bench/experiments/) - -If you run single benchmark test like: - -```bash -python src/rai_bench/rai_bench/examples//main.py -``` - -Results will be saved to dedicated directory named `` - -When you run a test via: - -```bash -python src/rai_bench/rai_bench/examples/test_models.py -``` - -results will be saved to separate folder in [results](./rai_bench/experiments/), with prefix `run_` - -To visualise the results run: - -```bash -streamlit run src/rai_bench/rai_bench/results_processing/visualise.py -``` +For tutorial see [RAI Bench Tutorial](../../docs/tutorials/benchmarking.md) +For understanding the structure of the package visit [RAI Bench Overview](../../docs/simulation_and_benchmarking) diff --git a/src/rai_bench/rai_bench/examples/benchmarking_models.py b/src/rai_bench/rai_bench/examples/benchmarking_models.py index c97d11cbd..2a0cc188c 100644 --- a/src/rai_bench/rai_bench/examples/benchmarking_models.py +++ b/src/rai_bench/rai_bench/examples/benchmarking_models.py @@ -20,26 +20,26 @@ if __name__ == "__main__": # Define models you want to benchmark - model_names = ["qwen2.5:7b"] - vendors = ["ollama"] + model_names = ["qwen3:4b", "llama3.2:3b"] + vendors = ["ollama", "ollama"] # Define benchmarks that will be used mani_conf = ManipulationO3DEBenchmarkConfig( - o3de_config_path="src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/o3de_config.yaml", # path to your o3de config + o3de_config_path="src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/o3de_config.yaml", levels=[ # define what difficulty of tasks to include in benchmark "trivial", + "easy", ], repeats=1, # how many times to repeat ) tool_conf = ToolCallingAgentBenchmarkConfig( - extra_tool_calls=[0], # how many extra tool calls allowed to still pass + extra_tool_calls=[0, 5], # how many extra tool calls allowed to still pass task_types=[ # what types of tasks to include "basic", "spatial_reasoning", "custom_interfaces", - "manipulation", ], - N_shots=[2], # examples in system prompt + N_shots=[0, 2], # examples in system prompt prompt_detail=["brief", "descriptive"], # how descriptive should task prompt be repeats=1, ) @@ -48,6 +48,11 @@ test_models( model_names=model_names, vendors=vendors, - benchmark_configs=[tool_conf], + benchmark_configs=[mani_conf, tool_conf], out_dir=out_dir, + # if you want to pass any additinal args to model + additional_model_args=[ + {"reasoning": False}, + {}, + ], ) diff --git a/src/rai_bench/rai_bench/examples/custom_scenario.py b/src/rai_bench/rai_bench/examples/custom_scenario.py new file mode 100644 index 000000000..731c678e2 --- /dev/null +++ b/src/rai_bench/rai_bench/examples/custom_scenario.py @@ -0,0 +1,125 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pathlib import Path +from typing import List, Sequence, Tuple, Union + +from rclpy.impl.rcutils_logger import RcutilsLogger + +from rai_bench.manipulation_o3de.benchmark import Scenario +from rai_bench.manipulation_o3de.interfaces import ( + ManipulationTask, +) +from rai_bench.manipulation_o3de.tasks import PlaceObjectAtCoordTask +from rai_sim.simulation_bridge import Entity, SceneConfig + +loggers_type = Union[RcutilsLogger, logging.Logger] + +### Define your scene setup ####################3 +path_to_your_config = ( + "src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/1a.yaml" +) +scene_config = SceneConfig.load_base_config(Path(path_to_your_config)) + +# configure existing Task with different params +target_coords = (0.1, 0.1) +disp = 0.1 +task = PlaceObjectAtCoordTask( + obj_type="apple", + target_position=target_coords, + allowable_displacement=disp, +) + +Scenario(task=task, scene_config=scene_config, scene_config_path=path_to_your_config) + + +######### Define your task ################### +class ThrowObjectsOffTableTask(ManipulationTask): + def __init__(self, obj_types: List[str], logger: loggers_type | None = None): + super().__init__(logger=logger) + # obj_types is a list of objects that are subject of the task + # In this case, it will mean which objects should be thrown off the table + # can be any objects + self.obj_types = obj_types + + @property + def task_prompt(self) -> str: + # define prompt + obj_names = ", ".join(obj + "s" for obj in self.obj_types).replace("_", " ") + # 0.0 z is the level of table, so any coord below that means it is off the table + return f"Manipulate objects, so that all of the {obj_names} are dropped outside of the table (for example y<-0.75)." + + def check_if_required_objects_present(self, simulation_config: SceneConfig) -> bool: + # Validate if any required objects are present in sim config + # if there is not a single object of provided type, there is no point in running + # this task of given scene config + count = sum( + 1 for ent in simulation_config.entities if ent.prefab_name in self.obj_types + ) + return count > 1 + + def calculate_correct(self, entities: Sequence[Entity]) -> Tuple[int, int]: + selected_type_objects = self.filter_entities_by_object_type( + entities=entities, object_types=self.obj_types + ) + + # check how many objects are below table, that will be our metric + correct = sum( + 1 for ent in selected_type_objects if ent.pose.pose.position.z < 0.0 + ) + + incorrect: int = len(selected_type_objects) - correct + return correct, incorrect + + +# configure existing Task with different params +target_coords = (0.1, 0.1) +disp = 0.1 +task = ThrowObjectsOffTableTask( + obj_types=["apple"], +) + +super_scenario = Scenario( + task=task, scene_config=scene_config, scene_config_path=path_to_your_config +) + + +##### Now you can run it in benchmark ################## +if __name__ == "__main__": + from pathlib import Path + + from rai_bench import ( + define_benchmark_logger, + ) + from rai_bench.manipulation_o3de import run_benchmark + from rai_bench.utils import get_llm_for_benchmark + + experiment_dir = Path("src/rai_bench/rai_bench/experiments/custom_task/") + + experiment_dir.mkdir(parents=True, exist_ok=True) + bench_logger = define_benchmark_logger(out_dir=experiment_dir) + + llm = get_llm_for_benchmark( + model_name="gpt-4o", + vendor="openai", + ) + + run_benchmark( + llm=llm, + out_dir=experiment_dir, + # use your scenario + scenarios=[super_scenario], + bench_logger=bench_logger, + ) diff --git a/src/rai_bench/rai_bench/examples/custom_task.py b/src/rai_bench/rai_bench/examples/custom_task.py new file mode 100644 index 000000000..fdfb9b763 --- /dev/null +++ b/src/rai_bench/rai_bench/examples/custom_task.py @@ -0,0 +1,128 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List + +from langchain_core.tools import BaseTool + +from rai_bench.tool_calling_agent.interfaces import Task, TaskArgs +from rai_bench.tool_calling_agent.mocked_tools import ( + MockGetROS2TopicsNamesAndTypesTool, + MockReceiveROS2MessageTool, +) +from rai_bench.tool_calling_agent.subtasks import ( + CheckArgsToolCallSubTask, +) +from rai_bench.tool_calling_agent.validators import ( + OrderedCallsValidator, +) + + +# This Task will check if robot can receive msessage from specified topic +class GetROS2RobotPositionTask(Task): + complexity = "easy" + type = "custom" + + @property + def available_tools(self) -> List[BaseTool]: + # define topics that will be seen by agent + TOPICS = [ + "/robot_position", + "/attached_collision_object", + "/clock", + "/collision_object", + ] + + TOPICS_STRING = [ + "topic: /attached_collision_object\ntype: moveit_msgs/msg/AttachedCollisionObject\n", + "topic: /clock\ntype: rosgraph_msgs/msg/Clock\n", + "topic: /collision_object\ntype: moveit_msgs/msg/CollisionObject\n", + "topic: /robot_position\n type: sensor_msgs/msg/RobotPosition", + ] + # define which tools will be available for agent + return [ + MockGetROS2TopicsNamesAndTypesTool( + mock_topics_names_and_types=TOPICS_STRING + ), + MockReceiveROS2MessageTool(available_topics=TOPICS), + ] + + def get_system_prompt(self) -> str: + return "You are a ROS 2 expert that want to solve tasks. You have access to various tools that allow you to query the ROS 2 system." + + def get_base_prompt(self) -> str: + return "Get the position of the robot." + + def get_prompt(self) -> str: + # Create versions for different levels + if self.prompt_detail == "brief": + return self.get_base_prompt() + else: + return ( + f"{self.get_base_prompt()} " + "You can discover what topics are currently active." + ) + + @property + def optional_tool_calls_number(self) -> int: + # Listing topics before getting any message is fine + return 1 + + +# define subtask +receive_robot_pos_subtask = CheckArgsToolCallSubTask( + expected_tool_name="receive_ros2_message", + expected_args={"topic": "/robot_position"}, + expected_optional_args={ + "timeout_sec": int # if there is not exact value expected, you can pass type + }, +) +# use OrderedCallValidator as there is only 1 subtask to check +topics_ord_val = OrderedCallsValidator(subtasks=[receive_robot_pos_subtask]) + + +# optionally pass number of extra tool calls +args = TaskArgs(extra_tool_calls=0) +super_task = GetROS2RobotPositionTask(validators=[topics_ord_val], task_args=args) + +##### Now you can run it in benchmark ################## +if __name__ == "__main__": + from pathlib import Path + + from rai_bench import ( + define_benchmark_logger, + ) + from rai_bench.tool_calling_agent import ( + run_benchmark, + ) + from rai_bench.utils import get_llm_for_benchmark + + experiment_dir = Path("src/rai_bench/rai_bench/experiments/custom_task") + experiment_dir.mkdir(parents=True, exist_ok=True) + bench_logger = define_benchmark_logger(out_dir=experiment_dir) + + super_task.set_logger(bench_logger) + + llm = get_llm_for_benchmark( + model_name="gpt-4o", + vendor="openai", + ) + + run_benchmark( + llm=llm, + out_dir=experiment_dir, + tasks=[super_task], + bench_logger=bench_logger, + ) diff --git a/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py b/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py index ed53cad85..fae462edb 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py @@ -423,9 +423,9 @@ def _setup_benchmark_environment( def run_benchmark( llm: BaseChatModel, out_dir: Path, - o3de_config_path: str, scenarios: List[Scenario], bench_logger: logging.Logger, + o3de_config_path: str = "src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/o3de_config.yaml", experiment_id: uuid.UUID = uuid.uuid4(), ): connector, o3de, benchmark, tools = _setup_benchmark_environment( diff --git a/src/rai_bench/rai_bench/manipulation_o3de/interfaces.py b/src/rai_bench/rai_bench/manipulation_o3de/interfaces.py index 91e595320..8ee3bcc27 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/interfaces.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/interfaces.py @@ -15,7 +15,7 @@ import math from abc import ABC, abstractmethod from collections import defaultdict -from typing import Dict, List, Set, Tuple, TypeVar, Union +from typing import Dict, List, Sequence, Set, Tuple, Union from rai.types import Pose from rclpy.impl.rcutils_logger import RcutilsLogger @@ -24,12 +24,10 @@ Entity, SceneConfig, SimulationBridge, - SimulationConfigT, - SpawnedEntity, ) loggers_type = Union[RcutilsLogger, logging.Logger] -EntityT = TypeVar("EntityT", bound=Entity) +# EntityT = TypeVar("EntityT", bound=Entity) class EntitiesMismatchException(Exception): @@ -79,9 +77,7 @@ def validate_config(self, simulation_config: SceneConfig) -> bool: pass @abstractmethod - def calculate_score( - self, simulation_bridge: SimulationBridge[SimulationConfigT] - ) -> float: + def calculate_score(self, simulation_bridge: SimulationBridge) -> float: """ Calculate the task score based on the simulation information. @@ -98,8 +94,8 @@ def calculate_score( pass def filter_entities_by_object_type( - self, entities: List[EntityT], object_types: List[str] - ) -> List[EntityT]: + self, entities: Sequence[Entity], object_types: List[str] + ) -> List[Entity]: """ Filter and return only the entities that match the provided prefab types. @@ -198,14 +194,14 @@ def count_adjacent(self, positions: List[Pose], threshold_distance: float) -> in return adjacent_count def build_neighbourhood_list( - self, entities: List[EntityT], threshold_distance: float = 0.15 - ) -> Dict[EntityT, List[EntityT]]: + self, entities: Sequence[Entity], threshold_distance: float = 0.15 + ) -> Dict[Entity, List[Entity]]: """ Build a neighbourhood list assigning a list of neighbours to every entity based on a threshold distance. Parameters ---------- - entities : List[EntityT] + entities : Sequence[EntityT] # Changed from List[EntityT] The list of entities. threshold_distance : float, optional The maximum distance between entities to consider them neighbours. Default is 0.15. @@ -215,7 +211,7 @@ def build_neighbourhood_list( Dict[EntityT, List[EntityT]] A dictionary mapping each entity to a list of neighbouring entities. """ - neighbourhood_graph: Dict[EntityT, List[EntityT]] = { + neighbourhood_graph: Dict[Entity, List[Entity]] = { entity: [] for entity in entities } for entity in entities: @@ -230,8 +226,8 @@ def build_neighbourhood_list( return neighbourhood_graph def group_entities_by_type( - self, entities: List[EntityT] - ) -> Dict[str, List[EntityT]]: + self, entities: Sequence[Entity] + ) -> Dict[str, List[Entity]]: """ Group entities by their prefab type. @@ -245,14 +241,14 @@ def group_entities_by_type( Dict[str, List[EntityT]] A dictionary with keys as prefab names and values as lists of entities of that type. """ - entities_by_type: Dict[str, List[EntityT]] = defaultdict(list) + entities_by_type: Dict[str, List[Entity]] = defaultdict(list) for entity in entities: entities_by_type[entity.prefab_name].append(entity) return entities_by_type def check_neighbourhood_types( self, - neighbourhood: List[EntityT], + neighbourhood: Sequence[Entity], allowed_types: List[str], ) -> bool: """ @@ -275,8 +271,8 @@ def check_neighbourhood_types( ) def find_clusters( - self, neighbourhood_list: Dict[EntityT, List[EntityT]] - ) -> List[List[EntityT]]: + self, neighbourhood_list: Dict[Entity, List[Entity]] + ) -> List[List[Entity]]: """ Identify clusters of entities using a DFS algorithm. @@ -293,10 +289,10 @@ def find_clusters( List[List[EntityT]] A list of clusters, where each cluster is a list of connected entities. """ - visited: Set[EntityT] = set() - clusters: List[List[EntityT]] = [] + visited: Set[Entity] = set() + clusters: List[List[Entity]] = [] - def dfs(node: EntityT, cluster: List[EntityT]): + def dfs(node: Entity, cluster: List[Entity]): visited.add(node) cluster.append(node) for neighbor in neighbourhood_list.get(node, []): @@ -305,7 +301,7 @@ def dfs(node: EntityT, cluster: List[EntityT]): for node in neighbourhood_list.keys(): if node not in visited: - component: List[EntityT] = [] + component: List[Entity] = [] dfs(node, component) clusters.append(component) @@ -314,9 +310,9 @@ def dfs(node: EntityT, cluster: List[EntityT]): def group_entities_along_z_axis( # NOTE (jmatejcz) figure out how to group by other coords and orientation, without reapeting code self, - entities: List[EntityT], + entities: List[Entity], margin: float, - ) -> List[List[EntityT]]: + ) -> List[List[Entity]]: """ Group entities that are aligned along the z axis based on their x and y coordinates. @@ -347,7 +343,7 @@ def group_entities_along_z_axis( key=lambda ent: (ent.pose.pose.position.x, ent.pose.pose.position.y), ) - groups: List[List[EntityT]] = [] + groups: List[List[Entity]] = [] for entity in entities: placed = False for group in groups: @@ -440,9 +436,7 @@ def validate_config(self, simulation_config: SceneConfig) -> bool: return False @abstractmethod - def calculate_correct( - self, entities: List[Entity] | List[SpawnedEntity] - ) -> Tuple[int, int]: + def calculate_correct(self, entities: Sequence[Entity]) -> Tuple[int, int]: """Method to calculate how many objects are placed correctly Parameters @@ -458,7 +452,7 @@ def calculate_correct( pass def calculate_current_placements( - self, simulation_bridge: SimulationBridge[SimulationConfigT] + self, simulation_bridge: SimulationBridge ) -> tuple[int, int]: """ Get the current placements of objects in the simulation @@ -485,9 +479,7 @@ def calculate_current_placements( ) return current_correct, current_incorrect - def calculate_score( - self, simulation_bridge: SimulationBridge[SceneConfig] - ) -> float: + def calculate_score(self, simulation_bridge: SimulationBridge) -> float: """ Calculate the task score based on the difference between initial and current placements. diff --git a/src/rai_bench/rai_bench/manipulation_o3de/tasks/build_tower_task.py b/src/rai_bench/rai_bench/manipulation_o3de/tasks/build_tower_task.py index 771b99387..58f4cf4e9 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/tasks/build_tower_task.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/tasks/build_tower_task.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import List, Tuple, Union +from typing import List, Sequence, Tuple, Union from rclpy.impl.rcutils_logger import RcutilsLogger @@ -94,7 +94,7 @@ def check_if_required_objects_present(self, simulation_config: SceneConfig) -> b ) return cube_count > 1 - def calculate_correct(self, entities: List[Entity]) -> Tuple[int, int]: + def calculate_correct(self, entities: Sequence[Entity]) -> Tuple[int, int]: """ Calculate the number of correctly and incorrectly placed cubes. diff --git a/src/rai_bench/rai_bench/manipulation_o3de/tasks/group_objects_task.py b/src/rai_bench/rai_bench/manipulation_o3de/tasks/group_objects_task.py index 41ce103ec..18b2dcea1 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/tasks/group_objects_task.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/tasks/group_objects_task.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List, Tuple, Union +from typing import List, Sequence, Tuple, Union from rclpy.impl.rcutils_logger import RcutilsLogger @@ -76,7 +76,7 @@ def check_if_required_objects_present(self, simulation_config: SceneConfig) -> b return set(self.obj_types) <= object_types_present - def calculate_correct(self, entities: List[Entity]) -> Tuple[int, int]: + def calculate_correct(self, entities: Sequence[Entity]) -> Tuple[int, int]: """ Count correctly and incorrectly clustered objects based on clustering rules. diff --git a/src/rai_bench/rai_bench/manipulation_o3de/tasks/move_object_to_left_task.py b/src/rai_bench/rai_bench/manipulation_o3de/tasks/move_object_to_left_task.py index b6032ec1b..239306acd 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/tasks/move_object_to_left_task.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/tasks/move_object_to_left_task.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List, Tuple, Union +from typing import List, Sequence, Tuple, Union from rclpy.impl.rcutils_logger import RcutilsLogger @@ -51,7 +51,7 @@ def check_if_required_objects_present(self, simulation_config: SceneConfig) -> b ) return set(self.obj_types) <= object_types_present.keys() - def calculate_correct(self, entities: List[Entity]) -> Tuple[int, int]: + def calculate_correct(self, entities: Sequence[Entity]) -> Tuple[int, int]: """ Calculate the number of objects correctly moved to the left side of the table. diff --git a/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_at_coord_task.py b/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_at_coord_task.py index cd0f9fc28..69811fd3a 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_at_coord_task.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_at_coord_task.py @@ -14,7 +14,7 @@ import logging import math -from typing import List, Tuple, Union +from typing import Sequence, Tuple, Union from rclpy.impl.rcutils_logger import RcutilsLogger @@ -67,7 +67,7 @@ def check_if_required_objects_present(self, simulation_config: SceneConfig) -> b ) return count >= 1 - def calculate_correct(self, entities: List[Entity]) -> Tuple[int, int]: + def calculate_correct(self, entities: Sequence[Entity]) -> Tuple[int, int]: """ Calculate the number of correctly and incorrectly placed objects. diff --git a/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_cubes_task.py b/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_cubes_task.py index 5ec7c5d1b..4337f7b19 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_cubes_task.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_cubes_task.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List, Tuple, Union +from typing import Sequence, Tuple, Union from rclpy.impl.rcutils_logger import RcutilsLogger @@ -65,7 +65,7 @@ def check_if_required_objects_present(self, simulation_config: SceneConfig) -> b return False - def calculate_correct(self, entities: List[Entity]) -> Tuple[int, int]: + def calculate_correct(self, entities: Sequence[Entity]) -> Tuple[int, int]: """ Calculate the number of correctly and incorrectly placed cubes based on adjacency. diff --git a/src/rai_bench/rai_bench/manipulation_o3de/tasks/rotate_object_task.py b/src/rai_bench/rai_bench/manipulation_o3de/tasks/rotate_object_task.py index e91af00c8..893619f1e 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/tasks/rotate_object_task.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/tasks/rotate_object_task.py @@ -14,7 +14,7 @@ import logging import math -from typing import List, Tuple, Union +from typing import List, Sequence, Tuple, Union from rai.types import Quaternion from rclpy.impl.rcutils_logger import RcutilsLogger @@ -72,7 +72,7 @@ def check_if_required_objects_present(self, simulation_config: SceneConfig) -> b ) def calculate_correct( - self, entities: List[Entity], allowable_rotation_error: float = 5.0 + self, entities: Sequence[Entity], allowable_rotation_error: float = 5.0 ) -> Tuple[int, int]: """ Calculate the number of correctly rotated objects and incorrectly rotated objects, diff --git a/src/rai_bench/rai_bench/test_models.py b/src/rai_bench/rai_bench/test_models.py index 0501d811c..fbf84bfd0 100644 --- a/src/rai_bench/rai_bench/test_models.py +++ b/src/rai_bench/rai_bench/test_models.py @@ -15,9 +15,8 @@ from abc import abstractmethod from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Literal +from typing import Any, Dict, List, Literal, Optional -from git import Optional from langchain.chat_models.base import BaseChatModel from pydantic import BaseModel diff --git a/src/rai_bench/rai_bench/utils.py b/src/rai_bench/rai_bench/utils.py index 60fbac038..e1150082c 100644 --- a/src/rai_bench/rai_bench/utils.py +++ b/src/rai_bench/rai_bench/utils.py @@ -34,6 +34,7 @@ def parse_tool_calling_benchmark_args(): parser.add_argument( "--extra-tool-calls", type=int, + nargs="+", help="Number of extra tools calls agent can make and still pass the task", default=0, ) diff --git a/src/rai_sim/rai_sim/o3de/o3de_bridge.py b/src/rai_sim/rai_sim/o3de/o3de_bridge.py index dac0e1027..4a3e4a2e7 100644 --- a/src/rai_sim/rai_sim/o3de/o3de_bridge.py +++ b/src/rai_sim/rai_sim/o3de/o3de_bridge.py @@ -61,7 +61,7 @@ def load_config(cls, config_path: Path) -> "O3DExROS2SimulationConfig": return cls(**connector_content) -class O3DExROS2Bridge(SimulationBridge[O3DExROS2SimulationConfig]): +class O3DExROS2Bridge(SimulationBridge): def __init__( self, connector: ROS2Connector, logger: Optional[logging.Logger] = None ): diff --git a/src/rai_sim/rai_sim/simulation_bridge.py b/src/rai_sim/rai_sim/simulation_bridge.py index bb514e270..6a4f6657b 100644 --- a/src/rai_sim/rai_sim/simulation_bridge.py +++ b/src/rai_sim/rai_sim/simulation_bridge.py @@ -15,7 +15,7 @@ import logging from abc import ABC, abstractmethod from pathlib import Path -from typing import Generic, List, Optional, TypeVar +from typing import List, Optional import yaml from pydantic import BaseModel, Field, field_validator @@ -167,10 +167,10 @@ class SceneState(BaseModel): class SimulationConfig(BaseModel): ... -SimulationConfigT = TypeVar("SimulationConfigT", bound=SimulationConfig) +# SimulationConfigT = TypeVar("SimulationConfigT", bound=SimulationConfig) -class SimulationBridge(ABC, Generic[SimulationConfigT]): +class SimulationBridge(ABC): """ Responsible for communication with simulation. """ @@ -185,7 +185,7 @@ def __init__(self, logger: Optional[logging.Logger] = None): self.logger = logger @abstractmethod - def init_simulation(self, simulation_config: SimulationConfigT): + def init_simulation(self, simulation_config: SimulationConfig): """ Initialize simulation binary and all other required processes, for example ros2 nodes diff --git a/tests/rai_sim/test_simulation_bridge.py b/tests/rai_sim/test_simulation_bridge.py index 5e383fb85..f68264d01 100644 --- a/tests/rai_sim/test_simulation_bridge.py +++ b/tests/rai_sim/test_simulation_bridge.py @@ -148,7 +148,7 @@ def test_load_base_config(sample_base_yaml_config: Path): assert len(config.entities) == 2 -class MockSimulationBridge(SimulationBridge[SimulationConfig]): +class MockSimulationBridge(SimulationBridge): """Mock implementation of SimulationBridge for testing.""" def init_simulation(self, simulation_config: SimulationConfig):