diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java
index a477af8a47a..5da1043f4be 100644
--- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java
+++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java
@@ -45,6 +45,8 @@ public class McpToolCallbackAutoConfiguration {
*
* These callbacks enable integration with Spring AI's tool execution framework,
* allowing MCP tools to be used as part of AI interactions.
+ * @param syncClientsToolFilter list of {@link McpSyncClientBiPredicate}s for the sync
+ * client to filter the discovered tools
* @param syncMcpClients provider of MCP sync clients
* @return list of tool callbacks for MCP integration
*/
diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java
index 5f8da416109..3936d201849 100644
--- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java
+++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java
@@ -28,6 +28,7 @@
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.ToolExecutionException;
+import org.springframework.lang.Nullable;
/**
* Implementation of {@link ToolCallback} that adapts MCP tools to Spring AI's tool
@@ -55,6 +56,7 @@
* }
*
* @author Christian Tzolov
+ * @author Ilayaperumal Gopinathan
* @see ToolCallback
* @see McpAsyncClient
* @see Tool
@@ -109,25 +111,30 @@ public ToolDefinition getToolDefinition() {
*/
@Override
public String call(String functionInput) {
- Map arguments = ModelOptionsUtils.jsonToMap(functionInput);
- // Note that we use the original tool name here, not the adapted one from
- // getToolDefinition
- return this.asyncMcpClient.callTool(new CallToolRequest(this.tool.name(), arguments)).onErrorMap(exception -> {
- // If the tool throws an error during execution
- throw new ToolExecutionException(this.getToolDefinition(), exception);
- }).map(response -> {
- if (response.isError() != null && response.isError()) {
- throw new ToolExecutionException(this.getToolDefinition(),
- new IllegalStateException("Error calling tool: " + response.content()));
- }
- return ModelOptionsUtils.toJsonString(response.content());
- }).contextWrite(ctx -> ctx.putAll(ToolCallReactiveContextHolder.getContext())).block();
+ return this.call(functionInput, null);
}
@Override
- public String call(String toolArguments, ToolContext toolContext) {
- // ToolContext is not supported by the MCP tools
- return this.call(toolArguments);
+ public String call(String toolArguments, @Nullable ToolContext toolContext) {
+ Map arguments = ModelOptionsUtils.jsonToMap(toolArguments);
+ // Note that we use the original tool name here, not the adapted one from
+ // getToolDefinition
+ return this.asyncMcpClient
+ .callTool(new CallToolRequest(this.tool.name(), arguments,
+ toolContext != null ? toolContext.getContext() : Map.of()))
+ .onErrorMap(exception -> {
+ // If the tool throws an error during execution
+ throw new ToolExecutionException(this.getToolDefinition(), exception);
+ })
+ .map(response -> {
+ if (response.isError() != null && response.isError()) {
+ throw new ToolExecutionException(this.getToolDefinition(),
+ new IllegalStateException("Error calling tool: " + response.content()));
+ }
+ return ModelOptionsUtils.toJsonString(response.content());
+ })
+ .contextWrite(ctx -> ctx.putAll(ToolCallReactiveContextHolder.getContext()))
+ .block();
}
}
diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java
index 3525b9593e3..25c17924ce5 100644
--- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java
+++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java
@@ -81,8 +81,8 @@ public class AsyncMcpToolCallbackProvider implements ToolCallbackProvider {
/**
* Creates a new {@code AsyncMcpToolCallbackProvider} instance with a list of MCP
* clients.
- * @param mcpClients the list of MCP clients to use for discovering tools
* @param toolFilter a filter to apply to each discovered tool
+ * @param mcpClients the list of MCP clients to use for discovering tools
*/
public AsyncMcpToolCallbackProvider(BiPredicate toolFilter, List mcpClients) {
Assert.notNull(mcpClients, "MCP clients must not be null");
@@ -106,8 +106,8 @@ public AsyncMcpToolCallbackProvider(List mcpClients) {
/**
* Creates a new {@code AsyncMcpToolCallbackProvider} instance with one or more MCP
* clients.
- * @param mcpClients the MCP clients to use for discovering tools
* @param toolFilter a filter to apply to each discovered tool
+ * @param mcpClients the MCP clients to use for discovering tools
*/
public AsyncMcpToolCallbackProvider(BiPredicate toolFilter, McpAsyncClient... mcpClients) {
this(toolFilter, List.of(mcpClients));
diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java
index fc61d801df1..eed40a911cb 100644
--- a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java
+++ b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java
@@ -30,6 +30,7 @@
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.ToolExecutionException;
+import org.springframework.lang.Nullable;
/**
* Implementation of {@link ToolCallback} that adapts MCP tools to Spring AI's tool
@@ -57,6 +58,7 @@
* }
*
* @author Christian Tzolov
+ * @author Ilayaperumal Gopinathan
* @see ToolCallback
* @see McpSyncClient
* @see Tool
@@ -114,13 +116,19 @@ public ToolDefinition getToolDefinition() {
*/
@Override
public String call(String functionInput) {
- Map arguments = ModelOptionsUtils.jsonToMap(functionInput);
+ return this.call(functionInput, null);
+ }
+
+ @Override
+ public String call(String toolArguments, @Nullable ToolContext toolContext) {
+ Map arguments = ModelOptionsUtils.jsonToMap(toolArguments);
CallToolResult response;
try {
// Note that we use the original tool name here, not the adapted one from
// getToolDefinition
- response = this.mcpClient.callTool(new CallToolRequest(this.tool.name(), arguments));
+ response = this.mcpClient.callTool(new CallToolRequest(this.tool.name(), arguments,
+ (toolContext != null ? toolContext.getContext() : Map.of())));
}
catch (Exception ex) {
logger.error("Exception while tool calling: ", ex);
@@ -135,10 +143,4 @@ public String call(String functionInput) {
return ModelOptionsUtils.toJsonString(response.content());
}
- @Override
- public String call(String toolArguments, ToolContext toolContext) {
- // ToolContext is not supported by the MCP tools
- return this.call(toolArguments);
- }
-
}
diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java b/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java
index 549535eb0ac..9144eaabbf5 100644
--- a/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java
+++ b/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java
@@ -1,5 +1,7 @@
package org.springframework.ai.mcp;
+import java.util.Map;
+
import io.modelcontextprotocol.client.McpAsyncClient;
import io.modelcontextprotocol.spec.McpSchema;
import org.junit.jupiter.api.Test;
@@ -8,9 +10,15 @@
import org.mockito.junit.jupiter.MockitoExtension;
import reactor.core.publisher.Mono;
+import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.execution.ToolExecutionException;
+
+import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.argThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
@@ -51,4 +59,24 @@ void callShouldWrapReactiveErrors() {
.hasMessage("Testing tool error");
}
+ @Test
+ void shouldApplyToolContext() {
+
+ when(this.tool.name()).thenReturn("testTool");
+ McpSchema.CallToolResult callResult = mock(McpSchema.CallToolResult.class);
+ when(this.mcpClient.callTool(any(McpSchema.CallToolRequest.class))).thenReturn(Mono.just(callResult));
+
+ AsyncMcpToolCallback callback = new AsyncMcpToolCallback(this.mcpClient, this.tool);
+
+ String response = callback.call("{\"param\":\"value\"}", new ToolContext(Map.of("foo", "bar")));
+
+ assertThat(response).isNotNull();
+
+ verify(this.mcpClient).callTool(argThat(callToolRequest -> callToolRequest.name().equals("testTool")));
+ verify(this.mcpClient)
+ .callTool(argThat(callToolRequest -> callToolRequest.arguments().equals(Map.of("param", "value"))));
+ verify(this.mcpClient)
+ .callTool(argThat(callToolRequest -> callToolRequest.meta().equals(Map.of("foo", "bar"))));
+ }
+
}
diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java
index a93d28b7332..69c9be95f91 100644
--- a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java
+++ b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java
@@ -37,7 +37,9 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
@@ -84,9 +86,7 @@ void callShouldHandleJsonInputAndOutput() {
}
@Test
- void callShouldIgnoreToolContext() {
- // when(mcpClient.getClientInfo()).thenReturn(new Implementation("testClient",
- // "1.0.0"));
+ void shouldApplyToolContext() {
when(this.tool.name()).thenReturn("testTool");
CallToolResult callResult = mock(CallToolResult.class);
@@ -97,6 +97,12 @@ void callShouldIgnoreToolContext() {
String response = callback.call("{\"param\":\"value\"}", new ToolContext(Map.of("foo", "bar")));
assertThat(response).isNotNull();
+
+ verify(this.mcpClient).callTool(argThat(callToolRequest -> callToolRequest.name().equals("testTool")));
+ verify(this.mcpClient)
+ .callTool(argThat(callToolRequest -> callToolRequest.arguments().equals(Map.of("param", "value"))));
+ verify(this.mcpClient)
+ .callTool(argThat(callToolRequest -> callToolRequest.meta().equals(Map.of("foo", "bar"))));
}
@Test