Skip to content

Allow MCP toolcallback to use ToolContext #3925

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ public class McpToolCallbackAutoConfiguration {
* <p>
* These callbacks enable integration with Spring AI's tool execution framework,
* allowing MCP tools to be used as part of AI interactions.
* @param syncClientsToolFilter list of {@link McpSyncClientBiPredicate}s for the sync
* client to filter the discovered tools
* @param syncMcpClients provider of MCP sync clients
* @return list of tool callbacks for MCP integration
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.ToolExecutionException;
import org.springframework.lang.Nullable;

/**
* Implementation of {@link ToolCallback} that adapts MCP tools to Spring AI's tool
Expand Down Expand Up @@ -55,6 +56,7 @@
* }</pre>
*
* @author Christian Tzolov
* @author Ilayaperumal Gopinathan
* @see ToolCallback
* @see McpAsyncClient
* @see Tool
Expand Down Expand Up @@ -109,25 +111,30 @@ public ToolDefinition getToolDefinition() {
*/
@Override
public String call(String functionInput) {
Map<String, Object> arguments = ModelOptionsUtils.jsonToMap(functionInput);
// Note that we use the original tool name here, not the adapted one from
// getToolDefinition
return this.asyncMcpClient.callTool(new CallToolRequest(this.tool.name(), arguments)).onErrorMap(exception -> {
// If the tool throws an error during execution
throw new ToolExecutionException(this.getToolDefinition(), exception);
}).map(response -> {
if (response.isError() != null && response.isError()) {
throw new ToolExecutionException(this.getToolDefinition(),
new IllegalStateException("Error calling tool: " + response.content()));
}
return ModelOptionsUtils.toJsonString(response.content());
}).contextWrite(ctx -> ctx.putAll(ToolCallReactiveContextHolder.getContext())).block();
return this.call(functionInput, null);
}

@Override
public String call(String toolArguments, ToolContext toolContext) {
// ToolContext is not supported by the MCP tools
return this.call(toolArguments);
public String call(String toolArguments, @Nullable ToolContext toolContext) {
Map<String, Object> arguments = ModelOptionsUtils.jsonToMap(toolArguments);
// Note that we use the original tool name here, not the adapted one from
// getToolDefinition
return this.asyncMcpClient
.callTool(new CallToolRequest(this.tool.name(), arguments,
toolContext != null ? toolContext.getContext() : Map.of()))
.onErrorMap(exception -> {
// If the tool throws an error during execution
throw new ToolExecutionException(this.getToolDefinition(), exception);
})
.map(response -> {
if (response.isError() != null && response.isError()) {
throw new ToolExecutionException(this.getToolDefinition(),
new IllegalStateException("Error calling tool: " + response.content()));
}
return ModelOptionsUtils.toJsonString(response.content());
})
.contextWrite(ctx -> ctx.putAll(ToolCallReactiveContextHolder.getContext()))
.block();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ public class AsyncMcpToolCallbackProvider implements ToolCallbackProvider {
/**
* Creates a new {@code AsyncMcpToolCallbackProvider} instance with a list of MCP
* clients.
* @param mcpClients the list of MCP clients to use for discovering tools
* @param toolFilter a filter to apply to each discovered tool
* @param mcpClients the list of MCP clients to use for discovering tools
*/
public AsyncMcpToolCallbackProvider(BiPredicate<McpAsyncClient, Tool> toolFilter, List<McpAsyncClient> mcpClients) {
Assert.notNull(mcpClients, "MCP clients must not be null");
Expand All @@ -106,8 +106,8 @@ public AsyncMcpToolCallbackProvider(List<McpAsyncClient> mcpClients) {
/**
* Creates a new {@code AsyncMcpToolCallbackProvider} instance with one or more MCP
* clients.
* @param mcpClients the MCP clients to use for discovering tools
* @param toolFilter a filter to apply to each discovered tool
* @param mcpClients the MCP clients to use for discovering tools
*/
public AsyncMcpToolCallbackProvider(BiPredicate<McpAsyncClient, Tool> toolFilter, McpAsyncClient... mcpClients) {
this(toolFilter, List.of(mcpClients));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.ToolExecutionException;
import org.springframework.lang.Nullable;

/**
* Implementation of {@link ToolCallback} that adapts MCP tools to Spring AI's tool
Expand Down Expand Up @@ -57,6 +58,7 @@
* }</pre>
*
* @author Christian Tzolov
* @author Ilayaperumal Gopinathan
* @see ToolCallback
* @see McpSyncClient
* @see Tool
Expand Down Expand Up @@ -114,13 +116,19 @@ public ToolDefinition getToolDefinition() {
*/
@Override
public String call(String functionInput) {
Map<String, Object> arguments = ModelOptionsUtils.jsonToMap(functionInput);
return this.call(functionInput, null);
}

@Override
public String call(String toolArguments, @Nullable ToolContext toolContext) {
Map<String, Object> arguments = ModelOptionsUtils.jsonToMap(toolArguments);

CallToolResult response;
try {
// Note that we use the original tool name here, not the adapted one from
// getToolDefinition
response = this.mcpClient.callTool(new CallToolRequest(this.tool.name(), arguments));
response = this.mcpClient.callTool(new CallToolRequest(this.tool.name(), arguments,
(toolContext != null ? toolContext.getContext() : Map.of())));
}
catch (Exception ex) {
logger.error("Exception while tool calling: ", ex);
Expand All @@ -135,10 +143,4 @@ public String call(String functionInput) {
return ModelOptionsUtils.toJsonString(response.content());
}

@Override
public String call(String toolArguments, ToolContext toolContext) {
// ToolContext is not supported by the MCP tools
return this.call(toolArguments);
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.springframework.ai.mcp;

import java.util.Map;

import io.modelcontextprotocol.client.McpAsyncClient;
import io.modelcontextprotocol.spec.McpSchema;
import org.junit.jupiter.api.Test;
Expand All @@ -8,9 +10,15 @@
import org.mockito.junit.jupiter.MockitoExtension;
import reactor.core.publisher.Mono;

import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.execution.ToolExecutionException;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
Expand Down Expand Up @@ -51,4 +59,24 @@ void callShouldWrapReactiveErrors() {
.hasMessage("Testing tool error");
}

@Test
void shouldApplyToolContext() {

when(this.tool.name()).thenReturn("testTool");
McpSchema.CallToolResult callResult = mock(McpSchema.CallToolResult.class);
when(this.mcpClient.callTool(any(McpSchema.CallToolRequest.class))).thenReturn(Mono.just(callResult));

AsyncMcpToolCallback callback = new AsyncMcpToolCallback(this.mcpClient, this.tool);

String response = callback.call("{\"param\":\"value\"}", new ToolContext(Map.of("foo", "bar")));

assertThat(response).isNotNull();

verify(this.mcpClient).callTool(argThat(callToolRequest -> callToolRequest.name().equals("testTool")));
verify(this.mcpClient)
.callTool(argThat(callToolRequest -> callToolRequest.arguments().equals(Map.of("param", "value"))));
verify(this.mcpClient)
.callTool(argThat(callToolRequest -> callToolRequest.meta().equals(Map.of("foo", "bar"))));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
Expand Down Expand Up @@ -84,9 +86,7 @@ void callShouldHandleJsonInputAndOutput() {
}

@Test
void callShouldIgnoreToolContext() {
// when(mcpClient.getClientInfo()).thenReturn(new Implementation("testClient",
// "1.0.0"));
void shouldApplyToolContext() {

when(this.tool.name()).thenReturn("testTool");
CallToolResult callResult = mock(CallToolResult.class);
Expand All @@ -97,6 +97,12 @@ void callShouldIgnoreToolContext() {
String response = callback.call("{\"param\":\"value\"}", new ToolContext(Map.of("foo", "bar")));

assertThat(response).isNotNull();

verify(this.mcpClient).callTool(argThat(callToolRequest -> callToolRequest.name().equals("testTool")));
verify(this.mcpClient)
.callTool(argThat(callToolRequest -> callToolRequest.arguments().equals(Map.of("param", "value"))));
verify(this.mcpClient)
.callTool(argThat(callToolRequest -> callToolRequest.meta().equals(Map.of("foo", "bar"))));
}

@Test
Expand Down