Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,32 @@ async def start_conversation(
)
return conversation_info, False

# Dynamically register tools from client's registry
if request.tool_module_qualnames:
import importlib

for tool_name, module_qualname in request.tool_module_qualnames.items():
try:
# Import the module to trigger tool auto-registration
importlib.import_module(module_qualname)
logger.debug(
f"Tool '{tool_name}' registered via module '{module_qualname}'"
)
except ImportError as e:
logger.warning(
f"Failed to import module '{module_qualname}' for tool "
f"'{tool_name}': {e}. Tool will not be available."
)
# Continue even if some tools fail to register
# The agent will fail gracefully if it tries to use unregistered
# tools
if request.tool_module_qualnames:
logger.info(
f"Dynamically registered {len(request.tool_module_qualnames)} "
f"tools for conversation {conversation_id}: "
f"{list(request.tool_module_qualnames.keys())}"
)

stored = StoredConversation(id=conversation_id, **request.model_dump())
event_service = await self._start_event_service(stored)
initial_message = request.initial_message
Expand Down
8 changes: 8 additions & 0 deletions openhands-agent-server/openhands/agent_server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ class StartConversationRequest(BaseModel):
default_factory=dict,
description="Secrets available in the conversation",
)
tool_module_qualnames: dict[str, str] = Field(
default_factory=dict,
description=(
"Mapping of tool names to their module qualnames from the client's "
"registry. These modules will be dynamically imported on the server "
"to register the tools for this conversation."
),
)


class StoredConversation(StartConversationRequest):
Expand Down
5 changes: 3 additions & 2 deletions openhands-agent-server/openhands/agent_server/tool_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

from openhands.sdk.tool.registry import list_registered_tools
from openhands.tools.preset.default import register_default_tools
from openhands.tools.preset.planning import register_planning_tools


tool_router = APIRouter(prefix="/tools", tags=["Tools"])
# Register default tools for backward compatibility
# Planning tools and other custom tools are now dynamically registered
# when creating a RemoteConversation
register_default_tools(enable_browser=True)
register_planning_tools()


# Tool listing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,9 @@ def __init__(
self._client = workspace.client

if conversation_id is None:
# Import here to avoid circular imports
from openhands.sdk.tool.registry import get_tool_module_qualnames

payload = {
"agent": agent.model_dump(
mode="json", context={"expose_secrets": True}
Expand All @@ -462,6 +465,8 @@ def __init__(
"workspace": LocalWorkspace(
working_dir=self.workspace.working_dir
).model_dump(),
# Include tool module qualnames for dynamic registration on server
"tool_module_qualnames": get_tool_module_qualnames(),
}
resp = _send_request(
self._client, "POST", "/api/conversations", json=payload
Expand Down
23 changes: 23 additions & 0 deletions openhands-sdk/openhands/sdk/tool/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

_LOCK = RLock()
_REG: dict[str, Resolver] = {}
_MODULE_QUALNAMES: dict[str, str] = {} # Maps tool name to module qualname


def _resolver_from_instance(name: str, tool: ToolDefinition) -> Resolver:
Expand Down Expand Up @@ -137,11 +138,22 @@ def register_tool(
"(3) a callable factory returning a Sequence[ToolDefinition]"
)

# Track the module qualname for this tool
module_qualname = None
if isinstance(factory, type):
module_qualname = factory.__module__
elif callable(factory):
module_qualname = getattr(factory, "__module__", None)
elif isinstance(factory, ToolDefinition):
module_qualname = factory.__class__.__module__

with _LOCK:
# TODO: throw exception when registering duplicate name tools
if name in _REG:
logger.warning(f"Duplicate tool name registerd {name}")
_REG[name] = resolver
if module_qualname:
_MODULE_QUALNAMES[name] = module_qualname


def resolve_tool(
Expand All @@ -159,3 +171,14 @@ def resolve_tool(
def list_registered_tools() -> list[str]:
with _LOCK:
return list(_REG.keys())


def get_tool_module_qualnames() -> dict[str, str]:
"""Get a mapping of tool names to their module qualnames.

Returns:
A dictionary mapping tool names to module qualnames (e.g.,
{"glob": "openhands.tools.glob.definition"}).
"""
with _LOCK:
return dict(_MODULE_QUALNAMES)
106 changes: 106 additions & 0 deletions tests/agent_server/test_conversation_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,3 +1169,109 @@ def test_generate_conversation_title_invalid_params(
assert response.status_code == 422 # Validation error
finally:
client.app.dependency_overrides.clear()


def test_start_conversation_with_tool_module_qualnames(
client, mock_conversation_service, sample_conversation_info
):
"""Test start_conversation endpoint with tool_module_qualnames field."""

# Mock the service response
mock_conversation_service.start_conversation.return_value = (
sample_conversation_info,
True,
)

# Override the dependency
client.app.dependency_overrides[get_conversation_service] = (
lambda: mock_conversation_service
)

try:
request_data = {
"agent": {
"llm": {
"model": "gpt-4o",
"api_key": "test-key",
"usage_id": "test-llm",
},
"tools": [
{"name": "glob"},
{"name": "grep"},
{"name": "planning_file_editor"},
],
},
"workspace": {"working_dir": "/tmp/test"},
"tool_module_qualnames": {
"glob": "openhands.tools.glob.definition",
"grep": "openhands.tools.grep.definition",
"planning_file_editor": (
"openhands.tools.planning_file_editor.definition"
),
},
}

response = client.post("/api/conversations", json=request_data)

assert response.status_code == 201
data = response.json()
assert data["id"] == str(sample_conversation_info.id)

# Verify service was called
mock_conversation_service.start_conversation.assert_called_once()
call_args = mock_conversation_service.start_conversation.call_args
request_arg = call_args[0][0]
assert hasattr(request_arg, "tool_module_qualnames")
assert request_arg.tool_module_qualnames == {
"glob": "openhands.tools.glob.definition",
"grep": "openhands.tools.grep.definition",
"planning_file_editor": ("openhands.tools.planning_file_editor.definition"),
}
finally:
client.app.dependency_overrides.clear()


def test_start_conversation_without_tool_module_qualnames(
client, mock_conversation_service, sample_conversation_info
):
"""Test start_conversation endpoint without tool_module_qualnames field."""

# Mock the service response
mock_conversation_service.start_conversation.return_value = (
sample_conversation_info,
True,
)

# Override the dependency
client.app.dependency_overrides[get_conversation_service] = (
lambda: mock_conversation_service
)

try:
request_data = {
"agent": {
"llm": {
"model": "gpt-4o",
"api_key": "test-key",
"usage_id": "test-llm",
},
"tools": [{"name": "TerminalTool"}],
},
"workspace": {"working_dir": "/tmp/test"},
}

response = client.post("/api/conversations", json=request_data)

assert response.status_code == 201
data = response.json()
assert data["id"] == str(sample_conversation_info.id)

# Verify service was called
mock_conversation_service.start_conversation.assert_called_once()
call_args = mock_conversation_service.start_conversation.call_args
request_arg = call_args[0][0]
assert hasattr(request_arg, "tool_module_qualnames")
# Should default to empty dict
assert request_arg.tool_module_qualnames == {}
finally:
client.app.dependency_overrides.clear()
68 changes: 68 additions & 0 deletions tests/sdk/test_registry_qualnames.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Tests for tool registry module qualname tracking."""

from openhands.sdk.tool.registry import (
get_tool_module_qualnames,
list_registered_tools,
register_tool,
)


def test_get_tool_module_qualnames_with_class():
"""Test that module qualnames are tracked when registering a class."""
from openhands.tools.glob import GlobTool

# Register the GlobTool class
register_tool("test_glob_class", GlobTool)

# Get the module qualnames
qualnames = get_tool_module_qualnames()

# Verify the tool is tracked with its module
assert "test_glob_class" in qualnames
assert qualnames["test_glob_class"] == "openhands.tools.glob.definition"


def test_get_tool_module_qualnames_with_callable():
"""Test that module qualnames are tracked when registering a callable."""

def test_factory(conv_state):
return []

# Register the callable
register_tool("test_callable", test_factory)

# Get the module qualnames
qualnames = get_tool_module_qualnames()

# Verify the tool is tracked with its module
assert "test_callable" in qualnames
assert "test_registry_qualnames" in qualnames["test_callable"]


def test_get_tool_module_qualnames_after_import():
"""Test that importing a tool module registers it with qualname."""
# Import glob tool module to trigger auto-registration
import openhands.tools.glob.definition # noqa: F401

# Get registered tools
registered_tools = list_registered_tools()

# Should have glob registered
assert "glob" in registered_tools

# Get module qualnames
qualnames = get_tool_module_qualnames()

# Verify glob has its module qualname tracked
if "glob" in qualnames:
assert qualnames["glob"] == "openhands.tools.glob.definition"


def test_get_tool_module_qualnames_returns_copy():
"""Test that get_tool_module_qualnames returns a copy, not the original dict."""
qualnames1 = get_tool_module_qualnames()
qualnames2 = get_tool_module_qualnames()

# Should be equal but not the same object
assert qualnames1 == qualnames2
assert qualnames1 is not qualnames2