diff --git a/lib/mcp/prompt.rb b/lib/mcp/prompt.rb index 7624ef5..2b77c6e 100644 --- a/lib/mcp/prompt.rb +++ b/lib/mcp/prompt.rb @@ -9,7 +9,7 @@ class << self attr_reader :description_value attr_reader :arguments_value - def template(args, server_context:) + def template(args, server_context: nil) raise NotImplementedError, "Subclasses must implement template" end @@ -57,7 +57,7 @@ def define(name: nil, description: nil, arguments: [], &block) prompt_name name description description arguments arguments - define_singleton_method(:template) do |args, server_context:| + define_singleton_method(:template) do |args, server_context: nil| instance_exec(args, server_context:, &block) end end diff --git a/lib/mcp/server.rb b/lib/mcp/server.rb index 98c5ac1..7fc2480 100644 --- a/lib/mcp/server.rb +++ b/lib/mcp/server.rb @@ -243,13 +243,7 @@ def call_tool(request) end begin - call_params = tool_call_parameters(tool) - - if call_params.include?(:server_context) - tool.call(**arguments.transform_keys(&:to_sym), server_context:).to_h - else - tool.call(**arguments.transform_keys(&:to_sym)).to_h - end + call_tool_with_args(tool, arguments) rescue => e raise RequestHandlerError.new("Internal error calling tool #{tool_name}", request, original_error: e) end @@ -272,7 +266,7 @@ def get_prompt(request) prompt_args = request[:arguments] prompt.validate_arguments!(prompt_args) - prompt.template(prompt_args, server_context:).to_h + call_prompt_template_with_args(prompt, prompt_args) end def list_resources(request) @@ -299,22 +293,29 @@ def index_resources_by_uri(resources) end end - def tool_call_parameters(tool) - method_def = tool_call_method_def(tool) - method_def.parameters.flatten + def accepts_server_context?(method_object) + parameters = method_object.parameters + accepts_server_context = parameters.any? { |_type, name| name == :server_context } + has_kwargs = parameters.any? { |type, _| type == :keyrest } + + accepts_server_context || has_kwargs end - def tool_call_method_def(tool) - method = tool.method(:call) + def call_tool_with_args(tool, arguments) + args = arguments.transform_keys(&:to_sym) - if defined?(T::Utils) && T::Utils.respond_to?(:signature_for_method) - sorbet_typed_method_definition = T::Utils.signature_for_method(method)&.method + if accepts_server_context?(tool.method(:call)) + tool.call(**args, server_context: server_context).to_h + else + tool.call(**args).to_h + end + end - # Return the Sorbet typed method definition if it exists, otherwise fallback to original method - # definition if Sorbet is defined but not used by this tool. - sorbet_typed_method_definition || method + def call_prompt_template_with_args(prompt, args) + if accepts_server_context?(prompt.method(:template)) + prompt.template(args, server_context: server_context).to_h else - method + prompt.template(args).to_h end end end diff --git a/lib/mcp/tool.rb b/lib/mcp/tool.rb index 48378e4..441c048 100644 --- a/lib/mcp/tool.rb +++ b/lib/mcp/tool.rb @@ -9,7 +9,7 @@ class << self attr_reader :input_schema_value attr_reader :annotations_value - def call(*args, server_context:) + def call(*args, server_context: nil) raise NotImplementedError, "Subclasses must implement call" end diff --git a/test/mcp/server_context_test.rb b/test/mcp/server_context_test.rb new file mode 100644 index 0000000..03ad3e6 --- /dev/null +++ b/test/mcp/server_context_test.rb @@ -0,0 +1,417 @@ +# frozen_string_literal: true + +require "test_helper" + +module MCP + class ServerContextTest < ActiveSupport::TestCase + # Tool without server_context parameter + class SimpleToolWithoutContext < Tool + tool_name "simple_without_context" + description "A tool that doesn't use server_context" + input_schema({ properties: { message: { type: "string" } }, required: ["message"] }) + + class << self + def call(message:) + Tool::Response.new([ + { type: "text", content: "SimpleToolWithoutContext: #{message}" }, + ]) + end + end + end + + # Tool with optional server_context parameter + class ToolWithOptionalContext < Tool + tool_name "tool_with_optional_context" + description "A tool with optional server_context" + input_schema({ properties: { message: { type: "string" } }, required: ["message"] }) + + class << self + def call(message:, server_context: nil) + context_info = server_context ? "with context: #{server_context[:user]}" : "no context" + Tool::Response.new([ + { type: "text", content: "ToolWithOptionalContext: #{message} (#{context_info})" }, + ]) + end + end + end + + # Tool with required server_context parameter + class ToolWithRequiredContext < Tool + tool_name "tool_with_required_context" + description "A tool that requires server_context" + input_schema({ properties: { message: { type: "string" } }, required: ["message"] }) + + class << self + def call(message:, server_context:) + Tool::Response.new([ + { type: "text", content: "ToolWithRequiredContext: #{message} for user #{server_context[:user]}" }, + ]) + end + end + end + + setup do + @server_with_context = Server.new( + name: "test_server", + tools: [SimpleToolWithoutContext, ToolWithOptionalContext, ToolWithRequiredContext], + server_context: { user: "test_user" }, + ) + + @server_without_context = Server.new( + name: "test_server_no_context", + tools: [SimpleToolWithoutContext, ToolWithOptionalContext], + ) + end + + test "tool without server_context parameter works when server has context" do + request = { + jsonrpc: "2.0", + id: 1, + method: "tools/call", + params: { + name: "simple_without_context", + arguments: { message: "Hello" }, + }, + } + + response = @server_with_context.handle(request) + + assert response[:result] + assert_equal "SimpleToolWithoutContext: Hello", response[:result][:content][0][:content] + end + + test "tool with optional server_context receives context when server has it" do + request = { + jsonrpc: "2.0", + id: 1, + method: "tools/call", + params: { + name: "tool_with_optional_context", + arguments: { message: "Hello" }, + }, + } + + response = @server_with_context.handle(request) + + assert response[:result] + assert_equal "ToolWithOptionalContext: Hello (with context: test_user)", + response[:result][:content][0][:content] + end + + test "tool with optional server_context works when server has no context" do + request = { + jsonrpc: "2.0", + id: 1, + method: "tools/call", + params: { + name: "tool_with_optional_context", + arguments: { message: "Hello" }, + }, + } + + response = @server_without_context.handle(request) + + assert response[:result] + assert_equal "ToolWithOptionalContext: Hello (no context)", + response[:result][:content][0][:content] + end + + test "tool with required server_context receives context" do + request = { + jsonrpc: "2.0", + id: 1, + method: "tools/call", + params: { + name: "tool_with_required_context", + arguments: { message: "Hello" }, + }, + } + + response = @server_with_context.handle(request) + + assert response[:result] + assert_equal "ToolWithRequiredContext: Hello for user test_user", + response[:result][:content][0][:content] + end + + test "tool with required server_context fails when server has no context" do + server_no_context = Server.new( + name: "test_server_no_context", + tools: [ToolWithRequiredContext], + ) + + request = { + jsonrpc: "2.0", + id: 1, + method: "tools/call", + params: { + name: "tool_with_required_context", + arguments: { message: "Hello" }, + }, + } + + response = server_no_context.handle(request) + + assert response[:error] + # The error is wrapped as "Internal error calling tool..." + assert_equal "Internal error", response[:error][:message] + end + + test "call_tool_with_args correctly detects server_context parameter presence" do + # Tool without server_context + refute SimpleToolWithoutContext.method(:call).parameters.any? { |_type, name| name == :server_context } + + # Tool with optional server_context + assert ToolWithOptionalContext.method(:call).parameters.any? { |_type, name| name == :server_context } + + # Tool with required server_context + assert ToolWithRequiredContext.method(:call).parameters.any? { |_type, name| name == :server_context } + end + + test "tools can use splat kwargs to accept any arguments including server_context" do + class FlexibleTool < Tool + tool_name "flexible_tool" + + class << self + def call(**kwargs) + message = kwargs[:message] + context = kwargs[:server_context] + + Tool::Response.new([ + { + type: "text", + content: "FlexibleTool: #{message} (context: #{context ? "present" : "absent"})", + }, + ]) + end + end + end + + server = Server.new( + name: "test_server", + tools: [FlexibleTool], + server_context: { user: "test_user" }, + ) + + request = { + jsonrpc: "2.0", + id: 1, + method: "tools/call", + params: { + name: "flexible_tool", + arguments: { message: "Hello" }, + }, + } + + response = server.handle(request) + + assert response[:result] + assert_equal "FlexibleTool: Hello (context: present)", + response[:result][:content][0][:content] + end + + # Prompt tests + + # Prompt without server_context parameter + class SimplePromptWithoutContext < Prompt + prompt_name "simple_prompt_without_context" + description "A prompt that doesn't use server_context" + arguments [Prompt::Argument.new(name: "message", required: true)] + + class << self + def template(args) + Prompt::Result.new( + messages: [ + Prompt::Message.new( + role: "user", + content: Content::Text.new("SimplePromptWithoutContext: #{args[:message]}"), + ), + ], + ) + end + end + end + + # Prompt with optional server_context parameter + class PromptWithOptionalContext < Prompt + prompt_name "prompt_with_optional_context" + description "A prompt with optional server_context" + arguments [Prompt::Argument.new(name: "message", required: true)] + + class << self + def template(args, server_context: nil) + context_info = server_context ? "with context: #{server_context[:user]}" : "no context" + Prompt::Result.new( + messages: [ + Prompt::Message.new( + role: "user", + content: Content::Text.new("PromptWithOptionalContext: #{args[:message]} (#{context_info})"), + ), + ], + ) + end + end + end + + # Prompt with required server_context parameter + class PromptWithRequiredContext < Prompt + prompt_name "prompt_with_required_context" + description "A prompt that requires server_context" + arguments [Prompt::Argument.new(name: "message", required: true)] + + class << self + def template(args, server_context:) + Prompt::Result.new( + messages: [ + Prompt::Message.new( + role: "user", + content: Content::Text.new( + "PromptWithRequiredContext: #{args[:message]} for user #{server_context[:user]}", + ), + ), + ], + ) + end + end + end + + test "prompt without server_context parameter works when server has context" do + server = Server.new( + name: "test_server", + prompts: [SimplePromptWithoutContext], + server_context: { user: "test_user" }, + ) + + request = { + jsonrpc: "2.0", + id: 1, + method: "prompts/get", + params: { + name: "simple_prompt_without_context", + arguments: { message: "Hello" }, + }, + } + + response = server.handle(request) + + assert response[:result] + assert_equal "SimplePromptWithoutContext: Hello", response[:result][:messages][0][:content][:text] + end + + test "prompt with optional server_context receives context when server has it" do + server = Server.new( + name: "test_server", + prompts: [PromptWithOptionalContext], + server_context: { user: "test_user" }, + ) + + request = { + jsonrpc: "2.0", + id: 1, + method: "prompts/get", + params: { + name: "prompt_with_optional_context", + arguments: { message: "Hello" }, + }, + } + + response = server.handle(request) + + assert response[:result] + assert_equal "PromptWithOptionalContext: Hello (with context: test_user)", + response[:result][:messages][0][:content][:text] + end + + test "prompt with optional server_context works when server has no context" do + server = Server.new( + name: "test_server", + prompts: [PromptWithOptionalContext], + ) + + request = { + jsonrpc: "2.0", + id: 1, + method: "prompts/get", + params: { + name: "prompt_with_optional_context", + arguments: { message: "Hello" }, + }, + } + + response = server.handle(request) + + assert response[:result] + assert_equal "PromptWithOptionalContext: Hello (no context)", + response[:result][:messages][0][:content][:text] + end + + test "prompt with required server_context receives context" do + server = Server.new( + name: "test_server", + prompts: [PromptWithRequiredContext], + server_context: { user: "test_user" }, + ) + + request = { + jsonrpc: "2.0", + id: 1, + method: "prompts/get", + params: { + name: "prompt_with_required_context", + arguments: { message: "Hello" }, + }, + } + + response = server.handle(request) + + assert response[:result] + assert_equal "PromptWithRequiredContext: Hello for user test_user", + response[:result][:messages][0][:content][:text] + end + + test "prompts can use splat kwargs to accept any arguments including server_context" do + class FlexiblePrompt < Prompt + prompt_name "flexible_prompt" + arguments [Prompt::Argument.new(name: "message", required: true)] + + class << self + def template(args, **kwargs) + message = args[:message] + context = kwargs[:server_context] + + Prompt::Result.new( + messages: [ + Prompt::Message.new( + role: "user", + content: Content::Text.new("FlexiblePrompt: #{message} (context: #{context ? "present" : "absent"})"), + ), + ], + ) + end + end + end + + server = Server.new( + name: "test_server", + prompts: [FlexiblePrompt], + server_context: { user: "test_user" }, + ) + + request = { + jsonrpc: "2.0", + id: 1, + method: "prompts/get", + params: { + name: "flexible_prompt", + arguments: { message: "Hello" }, + }, + } + + response = server.handle(request) + + assert response[:result] + assert_equal "FlexiblePrompt: Hello (context: present)", + response[:result][:messages][0][:content][:text] + end + end +end diff --git a/test/mcp/server_test.rb b/test/mcp/server_test.rb index 267f4ca..5ab454d 100644 --- a/test/mcp/server_test.rb +++ b/test/mcp/server_test.rb @@ -193,7 +193,7 @@ class ServerTest < ActiveSupport::TestCase tool_args = { arg: "value" } tool_response = Tool::Response.new([{ result: "success" }]) - @tool.expects(:call).with(arg: "value").returns(tool_response) + @tool.expects(:call).with(arg: "value", server_context: nil).returns(tool_response) request = { jsonrpc: "2.0", @@ -242,7 +242,7 @@ class ServerTest < ActiveSupport::TestCase tool_args = { arg: "value" } tool_response = Tool::Response.new([{ result: "success" }]) - @tool.expects(:call).with(arg: "value").returns(tool_response) + @tool.expects(:call).with(arg: "value", server_context: nil).returns(tool_response) request = JSON.generate({ jsonrpc: "2.0", diff --git a/test/mcp/tool_test.rb b/test/mcp/tool_test.rb index e87c826..1f9d056 100644 --- a/test/mcp/tool_test.rb +++ b/test/mcp/tool_test.rb @@ -244,6 +244,41 @@ def call(message:, server_context: nil) assert_equal response.is_error, false end + class TestToolWithoutServerContext < Tool + tool_name "test_tool_without_server_context" + description "a test tool for testing without server context" + input_schema({ properties: { message: { type: "string" } }, required: ["message"] }) + + class << self + def call(message:) + Tool::Response.new([{ type: "text", content: "OK" }]) + end + end + end + + class TestToolWithoutRequired < Tool + tool_name "test_tool_without_required" + description "a test tool for testing without required server context" + + class << self + def call(message, server_context: nil) + Tool::Response.new([{ type: "text", content: "OK" }]) + end + end + end + + test "tool call without server context" do + tool = TestToolWithoutServerContext + response = tool.call(message: "test") + assert_equal response.content, [{ type: "text", content: "OK" }] + end + + test "tool call with server context and without required" do + tool = TestToolWithoutRequired + response = tool.call("test", server_context: { foo: "bar" }) + assert_equal response.content, [{ type: "text", content: "OK" }] + end + test "input_schema rejects any $ref in schema" do schema_with_ref = { properties: {