diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java index a477af8a47a..5da1043f4be 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java @@ -45,6 +45,8 @@ public class McpToolCallbackAutoConfiguration { *

* These callbacks enable integration with Spring AI's tool execution framework, * allowing MCP tools to be used as part of AI interactions. + * @param syncClientsToolFilter list of {@link McpSyncClientBiPredicate}s for the sync + * client to filter the discovered tools * @param syncMcpClients provider of MCP sync clients * @return list of tool callbacks for MCP integration */ diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java index 5f8da416109..3936d201849 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java @@ -28,6 +28,7 @@ import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.ToolExecutionException; +import org.springframework.lang.Nullable; /** * Implementation of {@link ToolCallback} that adapts MCP tools to Spring AI's tool @@ -55,6 +56,7 @@ * } * * @author Christian Tzolov + * @author Ilayaperumal Gopinathan * @see ToolCallback * @see McpAsyncClient * @see Tool @@ -109,25 +111,30 @@ public ToolDefinition getToolDefinition() { */ @Override public String call(String functionInput) { - Map arguments = ModelOptionsUtils.jsonToMap(functionInput); - // Note that we use the original tool name here, not the adapted one from - // getToolDefinition - return this.asyncMcpClient.callTool(new CallToolRequest(this.tool.name(), arguments)).onErrorMap(exception -> { - // If the tool throws an error during execution - throw new ToolExecutionException(this.getToolDefinition(), exception); - }).map(response -> { - if (response.isError() != null && response.isError()) { - throw new ToolExecutionException(this.getToolDefinition(), - new IllegalStateException("Error calling tool: " + response.content())); - } - return ModelOptionsUtils.toJsonString(response.content()); - }).contextWrite(ctx -> ctx.putAll(ToolCallReactiveContextHolder.getContext())).block(); + return this.call(functionInput, null); } @Override - public String call(String toolArguments, ToolContext toolContext) { - // ToolContext is not supported by the MCP tools - return this.call(toolArguments); + public String call(String toolArguments, @Nullable ToolContext toolContext) { + Map arguments = ModelOptionsUtils.jsonToMap(toolArguments); + // Note that we use the original tool name here, not the adapted one from + // getToolDefinition + return this.asyncMcpClient + .callTool(new CallToolRequest(this.tool.name(), arguments, + toolContext != null ? toolContext.getContext() : Map.of())) + .onErrorMap(exception -> { + // If the tool throws an error during execution + throw new ToolExecutionException(this.getToolDefinition(), exception); + }) + .map(response -> { + if (response.isError() != null && response.isError()) { + throw new ToolExecutionException(this.getToolDefinition(), + new IllegalStateException("Error calling tool: " + response.content())); + } + return ModelOptionsUtils.toJsonString(response.content()); + }) + .contextWrite(ctx -> ctx.putAll(ToolCallReactiveContextHolder.getContext())) + .block(); } } diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java index 3525b9593e3..25c17924ce5 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java @@ -81,8 +81,8 @@ public class AsyncMcpToolCallbackProvider implements ToolCallbackProvider { /** * Creates a new {@code AsyncMcpToolCallbackProvider} instance with a list of MCP * clients. - * @param mcpClients the list of MCP clients to use for discovering tools * @param toolFilter a filter to apply to each discovered tool + * @param mcpClients the list of MCP clients to use for discovering tools */ public AsyncMcpToolCallbackProvider(BiPredicate toolFilter, List mcpClients) { Assert.notNull(mcpClients, "MCP clients must not be null"); @@ -106,8 +106,8 @@ public AsyncMcpToolCallbackProvider(List mcpClients) { /** * Creates a new {@code AsyncMcpToolCallbackProvider} instance with one or more MCP * clients. - * @param mcpClients the MCP clients to use for discovering tools * @param toolFilter a filter to apply to each discovered tool + * @param mcpClients the MCP clients to use for discovering tools */ public AsyncMcpToolCallbackProvider(BiPredicate toolFilter, McpAsyncClient... mcpClients) { this(toolFilter, List.of(mcpClients)); diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java index fc61d801df1..eed40a911cb 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java @@ -30,6 +30,7 @@ import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.ToolExecutionException; +import org.springframework.lang.Nullable; /** * Implementation of {@link ToolCallback} that adapts MCP tools to Spring AI's tool @@ -57,6 +58,7 @@ * } * * @author Christian Tzolov + * @author Ilayaperumal Gopinathan * @see ToolCallback * @see McpSyncClient * @see Tool @@ -114,13 +116,19 @@ public ToolDefinition getToolDefinition() { */ @Override public String call(String functionInput) { - Map arguments = ModelOptionsUtils.jsonToMap(functionInput); + return this.call(functionInput, null); + } + + @Override + public String call(String toolArguments, @Nullable ToolContext toolContext) { + Map arguments = ModelOptionsUtils.jsonToMap(toolArguments); CallToolResult response; try { // Note that we use the original tool name here, not the adapted one from // getToolDefinition - response = this.mcpClient.callTool(new CallToolRequest(this.tool.name(), arguments)); + response = this.mcpClient.callTool(new CallToolRequest(this.tool.name(), arguments, + (toolContext != null ? toolContext.getContext() : Map.of()))); } catch (Exception ex) { logger.error("Exception while tool calling: ", ex); @@ -135,10 +143,4 @@ public String call(String functionInput) { return ModelOptionsUtils.toJsonString(response.content()); } - @Override - public String call(String toolArguments, ToolContext toolContext) { - // ToolContext is not supported by the MCP tools - return this.call(toolArguments); - } - } diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java b/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java index 549535eb0ac..9144eaabbf5 100644 --- a/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java @@ -1,5 +1,7 @@ package org.springframework.ai.mcp; +import java.util.Map; + import io.modelcontextprotocol.client.McpAsyncClient; import io.modelcontextprotocol.spec.McpSchema; import org.junit.jupiter.api.Test; @@ -8,9 +10,15 @@ import org.mockito.junit.jupiter.MockitoExtension; import reactor.core.publisher.Mono; +import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.tool.execution.ToolExecutionException; + +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) @@ -51,4 +59,24 @@ void callShouldWrapReactiveErrors() { .hasMessage("Testing tool error"); } + @Test + void shouldApplyToolContext() { + + when(this.tool.name()).thenReturn("testTool"); + McpSchema.CallToolResult callResult = mock(McpSchema.CallToolResult.class); + when(this.mcpClient.callTool(any(McpSchema.CallToolRequest.class))).thenReturn(Mono.just(callResult)); + + AsyncMcpToolCallback callback = new AsyncMcpToolCallback(this.mcpClient, this.tool); + + String response = callback.call("{\"param\":\"value\"}", new ToolContext(Map.of("foo", "bar"))); + + assertThat(response).isNotNull(); + + verify(this.mcpClient).callTool(argThat(callToolRequest -> callToolRequest.name().equals("testTool"))); + verify(this.mcpClient) + .callTool(argThat(callToolRequest -> callToolRequest.arguments().equals(Map.of("param", "value")))); + verify(this.mcpClient) + .callTool(argThat(callToolRequest -> callToolRequest.meta().equals(Map.of("foo", "bar")))); + } + } diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java index a93d28b7332..69c9be95f91 100644 --- a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java @@ -37,7 +37,9 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) @@ -84,9 +86,7 @@ void callShouldHandleJsonInputAndOutput() { } @Test - void callShouldIgnoreToolContext() { - // when(mcpClient.getClientInfo()).thenReturn(new Implementation("testClient", - // "1.0.0")); + void shouldApplyToolContext() { when(this.tool.name()).thenReturn("testTool"); CallToolResult callResult = mock(CallToolResult.class); @@ -97,6 +97,12 @@ void callShouldIgnoreToolContext() { String response = callback.call("{\"param\":\"value\"}", new ToolContext(Map.of("foo", "bar"))); assertThat(response).isNotNull(); + + verify(this.mcpClient).callTool(argThat(callToolRequest -> callToolRequest.name().equals("testTool"))); + verify(this.mcpClient) + .callTool(argThat(callToolRequest -> callToolRequest.arguments().equals(Map.of("param", "value")))); + verify(this.mcpClient) + .callTool(argThat(callToolRequest -> callToolRequest.meta().equals(Map.of("foo", "bar")))); } @Test