diff --git a/src/mcp/server/fastmcp/authorizer.py b/src/mcp/server/fastmcp/authorizer.py new file mode 100644 index 000000000..41a174cf7 --- /dev/null +++ b/src/mcp/server/fastmcp/authorizer.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +import abc +from typing import TYPE_CHECKING, Any + +from pydantic import AnyUrl +from starlette.requests import Request + +from mcp.server.lowlevel.server import LifespanResultT +from mcp.server.session import ServerSession + +if TYPE_CHECKING: + from mcp.server.fastmcp.server import Context + + +class Authorizer: + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def permit_get_tool( + self, name: str, context: Context[ServerSession, LifespanResultT, Request] | None = None + ) -> bool: + """Check if the specified tool can be retrieved from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_list_tool( + self, + name: str, + context: Context[ServerSession, LifespanResultT, Request] | None = None, + ) -> bool: + """Check if the specified tool can be listed from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_call_tool( + self, + name: str, + arguments: dict[str, Any], + context: Context[ServerSession, LifespanResultT, Request] | None = None, + ) -> bool: + """Check if the specified tool can be called from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_get_resource( + self, resource: AnyUrl | str, context: Context[ServerSession, LifespanResultT, Request] | None = None + ) -> bool: + """Check if the specified resource can be retrieved from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_create_resource( + self, uri: str, params: dict[str, Any], context: Context[ServerSession, LifespanResultT, Request] | None = None + ) -> bool: + """Check if the specified resource can be created on the associated mcp server""" + return False + + @abc.abstractmethod + def permit_list_resource( + self, resource: AnyUrl | str, context: Context[ServerSession, LifespanResultT, Request] | None = None + ) -> bool: + """Check if the specified resource can be listed from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_list_template( + self, resource: AnyUrl | str, context: Context[ServerSession, LifespanResultT, Request] | None = None + ) -> bool: + """Check if the specified template can be listed from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_get_prompt( + self, name: str, context: Context[ServerSession, LifespanResultT, Request] | None = None + ) -> bool: + """Check if the specified prompt can be retrieved from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_list_prompt( + self, name: str, context: Context[ServerSession, LifespanResultT, Request] | None = None + ) -> bool: + """Check if the specified prompt can be listed from the associated mcp server""" + return False + + @abc.abstractmethod + def permit_render_prompt( + self, + name: str, + arguments: dict[str, Any] | None = None, + context: Context[ServerSession, object, Request] | None = None, + ) -> bool: + """Check if the specified prompt can be rendered from the associated mcp server""" + return False + + +class AllowAllAuthorizer(Authorizer): + def permit_get_tool( + self, name: str, context: Context[ServerSession, LifespanResultT, Request] | None = None + ) -> bool: + return True + + def permit_list_tool( + self, name: str, context: Context[ServerSession, LifespanResultT, Request] | None = None + ) -> bool: + return True + + def permit_call_tool( + self, + name: str, + arguments: dict[str, Any], + context: Context[ServerSession, LifespanResultT, Request] | None = None, + ) -> bool: + return True + + def permit_get_resource( + self, resource: AnyUrl | str, context: Context[ServerSession, LifespanResultT, Request] | None = None + ) -> bool: + return True + + def permit_create_resource( + self, uri: str, params: dict[str, Any], context: Context[ServerSession, LifespanResultT, Request] | None = None + ) -> bool: + return True + + def permit_list_resource( + self, resource: AnyUrl | str, context: Context[ServerSession, LifespanResultT, Request] | None = None + ) -> bool: + return True + + def permit_list_template( + self, resource: AnyUrl | str, context: Context[ServerSession, LifespanResultT, Request] | None = None + ) -> bool: + return True + + def permit_get_prompt( + self, name: str, context: Context[ServerSession, LifespanResultT, Request] | None = None + ) -> bool: + return True + + def permit_list_prompt( + self, name: str, context: Context[ServerSession, LifespanResultT, Request] | None = None + ) -> bool: + return True + + def permit_render_prompt( + self, + name: str, + arguments: dict[str, Any] | None = None, + context: Context[ServerSession, object, Request] | None = None, + ) -> bool: + return True diff --git a/src/mcp/server/fastmcp/prompts/manager.py b/src/mcp/server/fastmcp/prompts/manager.py index 6b01d91cd..5480ae1be 100644 --- a/src/mcp/server/fastmcp/prompts/manager.py +++ b/src/mcp/server/fastmcp/prompts/manager.py @@ -1,9 +1,18 @@ """Prompt management functionality.""" -from typing import Any +from __future__ import annotations as _annotations -from mcp.server.fastmcp.prompts.base import Message, Prompt +from typing import TYPE_CHECKING, Any + +from starlette.requests import Request + +from mcp.server.fastmcp.authorizer import AllowAllAuthorizer, Authorizer +from mcp.server.fastmcp.prompts.base import Prompt from mcp.server.fastmcp.utilities.logging import get_logger +from mcp.server.session import ServerSession + +if TYPE_CHECKING: + from mcp.server.fastmcp.server import Context logger = get_logger(__name__) @@ -11,17 +20,30 @@ class PromptManager: """Manages FastMCP prompts.""" - def __init__(self, warn_on_duplicate_prompts: bool = True): + def __init__( + self, + warn_on_duplicate_prompts: bool = True, + authorizer: Authorizer = AllowAllAuthorizer(), + ): self._prompts: dict[str, Prompt] = {} + self._authorizer = authorizer self.warn_on_duplicate_prompts = warn_on_duplicate_prompts - def get_prompt(self, name: str) -> Prompt | None: + def get_prompt( + self, + name: str, + arguments: dict[str, Any] | None = None, + context: Context[ServerSession, object, Request] | None = None, + ) -> Prompt | None: """Get prompt by name.""" - return self._prompts.get(name) + if self._authorizer.permit_get_prompt(name, context): + return self._prompts.get(name) + else: + return None - def list_prompts(self) -> list[Prompt]: + def list_prompts(self, context: Context[ServerSession, object, Request] | None = None) -> list[Prompt]: """List all registered prompts.""" - return list(self._prompts.values()) + return [prompt for name, prompt in self._prompts.items() if self._authorizer.permit_list_prompt(name, context)] def add_prompt( self, @@ -39,10 +61,17 @@ def add_prompt( self._prompts[prompt.name] = prompt return prompt - async def render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> list[Message]: + async def render_prompt( + self, + name: str, + arguments: dict[str, Any] | None = None, + context: Context[ServerSession, object, Request] | None = None, + ) -> Prompt: """Render a prompt by name with arguments.""" prompt = self.get_prompt(name) if not prompt: raise ValueError(f"Unknown prompt: {name}") - - return await prompt.render(arguments) + if self._authorizer.permit_render_prompt(name, arguments, context): + return prompt + else: + raise ValueError(f"Unknown prompt: {name}") diff --git a/src/mcp/server/fastmcp/resources/resource_manager.py b/src/mcp/server/fastmcp/resources/resource_manager.py index 35e4ec04d..207149418 100644 --- a/src/mcp/server/fastmcp/resources/resource_manager.py +++ b/src/mcp/server/fastmcp/resources/resource_manager.py @@ -1,13 +1,22 @@ """Resource manager functionality.""" +from __future__ import annotations as _annotations + from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any from pydantic import AnyUrl +from starlette.requests import Request +from mcp.server.fastmcp.authorizer import AllowAllAuthorizer, Authorizer from mcp.server.fastmcp.resources.base import Resource from mcp.server.fastmcp.resources.templates import ResourceTemplate from mcp.server.fastmcp.utilities.logging import get_logger +from mcp.server.lowlevel.server import LifespanResultT +from mcp.server.session import ServerSession + +if TYPE_CHECKING: + from mcp.server.fastmcp.server import Context logger = get_logger(__name__) @@ -15,10 +24,15 @@ class ResourceManager: """Manages FastMCP resources.""" - def __init__(self, warn_on_duplicate_resources: bool = True): + def __init__( + self, + warn_on_duplicate_resources: bool = True, + authorizer: Authorizer = AllowAllAuthorizer(), + ): self._resources: dict[str, Resource] = {} self._templates: dict[str, ResourceTemplate] = {} self.warn_on_duplicate_resources = warn_on_duplicate_resources + self._authorizer = authorizer def add_resource(self, resource: Resource) -> Resource: """Add a resource to the manager. @@ -67,31 +81,45 @@ def add_template( self._templates[template.uri_template] = template return template - async def get_resource(self, uri: AnyUrl | str) -> Resource | None: + async def get_resource( + self, uri: AnyUrl | str, context: Context[ServerSession, LifespanResultT, Request] | None = None + ) -> Resource | None: """Get resource by URI, checking concrete resources first, then templates.""" uri_str = str(uri) logger.debug("Getting resource", extra={"uri": uri_str}) # First check concrete resources if resource := self._resources.get(uri_str): - return resource + if self._authorizer.permit_get_resource(uri_str, context): + return resource + else: + raise ValueError(f"Unknown resource: {uri}") # Then check templates for template in self._templates.values(): if params := template.matches(uri_str): try: - return await template.create_resource(uri_str, params) + if self._authorizer.permit_create_resource(uri_str, params): + return await template.create_resource(uri_str, params) + else: + raise ValueError(f"Unknown resource: {uri}") except Exception as e: raise ValueError(f"Error creating resource from template: {e}") raise ValueError(f"Unknown resource: {uri}") - def list_resources(self) -> list[Resource]: + def list_resources(self, context: Context[ServerSession, LifespanResultT, Request] | None = None) -> list[Resource]: """List all registered resources.""" logger.debug("Listing resources", extra={"count": len(self._resources)}) - return list(self._resources.values()) + return [ + resource for uri, resource in self._resources.items() if self._authorizer.permit_list_resource(uri, context) + ] - def list_templates(self) -> list[ResourceTemplate]: + def list_templates( + self, context: Context[ServerSession, LifespanResultT, Request] | None = None + ) -> list[ResourceTemplate]: """List all registered templates.""" logger.debug("Listing templates", extra={"count": len(self._templates)}) - return list(self._templates.values()) + return [ + template for uri, template in self._templates.items() if self._authorizer.permit_list_template(uri, context) + ] diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 924baaa9b..8b025c776 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -26,6 +26,7 @@ from mcp.server.auth.provider import OAuthAuthorizationServerProvider, ProviderTokenVerifier, TokenVerifier from mcp.server.auth.settings import AuthSettings from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, elicit_with_validation +from mcp.server.fastmcp.authorizer import AllowAllAuthorizer, Authorizer from mcp.server.fastmcp.exceptions import ResourceError from mcp.server.fastmcp.prompts import Prompt, PromptManager from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager @@ -105,6 +106,8 @@ class Settings(BaseSettings, Generic[LifespanResultT]): # Transport security settings (DNS rebinding protection) transport_security: TransportSecuritySettings | None + authorizer: Authorizer | None = None + def lifespan_wrapper( app: FastMCP[LifespanResultT], @@ -145,6 +148,7 @@ def __init__( lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None = None, auth: AuthSettings | None = None, transport_security: TransportSecuritySettings | None = None, + authorizer: Authorizer | None = None, ): self.settings = Settings( debug=debug, @@ -164,6 +168,7 @@ def __init__( lifespan=lifespan, auth=auth, transport_security=transport_security, + authorizer=authorizer, ) self._mcp_server = MCPServer( @@ -173,9 +178,20 @@ def __init__( # We need to create a Lifespan type that is a generic on the server type, like Starlette does. lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), # type: ignore ) - self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) - self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) - self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) + authorizer = self.settings.authorizer or AllowAllAuthorizer() + self._tool_manager = ToolManager( + tools=tools, + warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools, + authorizer=authorizer, + ) + self._resource_manager = ResourceManager( + warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources, + authorizer=authorizer, + ) + self._prompt_manager = PromptManager( + warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts, + authorizer=authorizer, + ) # Validate auth configuration if self.settings.auth is not None: if auth_server_provider and token_verifier: @@ -268,7 +284,8 @@ def _setup_handlers(self) -> None: async def list_tools(self) -> list[MCPTool]: """List all available tools.""" - tools = self._tool_manager.list_tools() + context = self.get_context() + tools = self._tool_manager.list_tools(context) return [ MCPTool( name=info.name, @@ -299,8 +316,8 @@ async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[Cont async def list_resources(self) -> list[MCPResource]: """List all available resources.""" - - resources = self._resource_manager.list_resources() + context = self.get_context() + resources = self._resource_manager.list_resources(context) return [ MCPResource( uri=resource.uri, @@ -313,7 +330,8 @@ async def list_resources(self) -> list[MCPResource]: ] async def list_resource_templates(self) -> list[MCPResourceTemplate]: - templates = self._resource_manager.list_templates() + context = self.get_context() + templates = self._resource_manager.list_templates(context) return [ MCPResourceTemplate( uriTemplate=template.uri_template, @@ -326,8 +344,8 @@ async def list_resource_templates(self) -> list[MCPResourceTemplate]: async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContents]: """Read a resource by URI.""" - - resource = await self._resource_manager.get_resource(uri) + context = self.get_context() + resource = await self._resource_manager.get_resource(uri, context) if not resource: raise ResourceError(f"Unknown resource: {uri}") @@ -956,9 +974,9 @@ def streamable_http_app(self) -> Starlette: lifespan=lambda app: self.session_manager.run(), ) - async def list_prompts(self) -> list[MCPPrompt]: + async def list_prompts(self, context: Context[ServerSession, object, Request] | None = None) -> list[MCPPrompt]: """List all available prompts.""" - prompts = self._prompt_manager.list_prompts() + prompts = self._prompt_manager.list_prompts(context) return [ MCPPrompt( name=prompt.name, @@ -976,13 +994,15 @@ async def list_prompts(self) -> list[MCPPrompt]: for prompt in prompts ] - async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> GetPromptResult: + async def get_prompt( + self, + name: str, + arguments: dict[str, Any] | None = None, + context: Context[ServerSession, object, Request] | None = None, + ) -> GetPromptResult: """Get a prompt by name with arguments.""" try: - prompt = self._prompt_manager.get_prompt(name) - if not prompt: - raise ValueError(f"Unknown prompt: {name}") - + prompt = await self._prompt_manager.render_prompt(name, arguments, context) messages = await prompt.render(arguments) return GetPromptResult( diff --git a/src/mcp/server/fastmcp/tools/tool_manager.py b/src/mcp/server/fastmcp/tools/tool_manager.py index bfa8b2382..727368175 100644 --- a/src/mcp/server/fastmcp/tools/tool_manager.py +++ b/src/mcp/server/fastmcp/tools/tool_manager.py @@ -3,15 +3,18 @@ from collections.abc import Callable from typing import TYPE_CHECKING, Any +from starlette.requests import Request + +from mcp.server.fastmcp.authorizer import AllowAllAuthorizer, Authorizer from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.tools.base import Tool from mcp.server.fastmcp.utilities.logging import get_logger -from mcp.shared.context import LifespanContextT, RequestT +from mcp.server.lowlevel.server import LifespanResultT +from mcp.server.session import ServerSession from mcp.types import ToolAnnotations if TYPE_CHECKING: from mcp.server.fastmcp.server import Context - from mcp.server.session import ServerSessionT logger = get_logger(__name__) @@ -24,6 +27,7 @@ def __init__( warn_on_duplicate_tools: bool = True, *, tools: list[Tool] | None = None, + authorizer: Authorizer = AllowAllAuthorizer(), ): self._tools: dict[str, Tool] = {} if tools is not None: @@ -32,15 +36,21 @@ def __init__( logger.warning(f"Tool already exists: {tool.name}") self._tools[tool.name] = tool - self.warn_on_duplicate_tools = warn_on_duplicate_tools + self.warn_on_duplicate_tools = (warn_on_duplicate_tools,) + self._authorizer = authorizer - def get_tool(self, name: str) -> Tool | None: + def get_tool( + self, name: str, context: Context[ServerSession, LifespanResultT, Request] | None = None + ) -> Tool | None: """Get tool by name.""" - return self._tools.get(name) + if self._authorizer.permit_get_tool(name, context): + return self._tools.get(name) + else: + return None - def list_tools(self) -> list[Tool]: + def list_tools(self, context: Context[ServerSession, LifespanResultT, Request] | None = None) -> list[Tool]: """List all registered tools.""" - return list(self._tools.values()) + return [tool for name, tool in self._tools.items() if self._authorizer.permit_list_tool(name, context)] def add_tool( self, @@ -72,12 +82,12 @@ async def call_tool( self, name: str, arguments: dict[str, Any], - context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, + context: Context[ServerSession, LifespanResultT, Request] | None = None, convert_result: bool = False, ) -> Any: """Call a tool by name with arguments.""" - tool = self.get_tool(name) - if not tool: + tool = self._tools.get(name) + if not tool or not self._authorizer.permit_call_tool(name, arguments, context): raise ToolError(f"Unknown tool: {name}") return await tool.run(arguments, context=context, convert_result=convert_result) diff --git a/tests/server/fastmcp/prompts/test_manager.py b/tests/server/fastmcp/prompts/test_manager.py index 82b234638..a5af1591d 100644 --- a/tests/server/fastmcp/prompts/test_manager.py +++ b/tests/server/fastmcp/prompts/test_manager.py @@ -71,7 +71,8 @@ def fn() -> str: manager = PromptManager() prompt = Prompt.from_function(fn) manager.add_prompt(prompt) - messages = await manager.render_prompt("fn") + prompt = await manager.render_prompt("fn") + messages = await prompt.render({}) assert messages == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] @pytest.mark.anyio @@ -84,7 +85,8 @@ def fn(name: str) -> str: manager = PromptManager() prompt = Prompt.from_function(fn) manager.add_prompt(prompt) - messages = await manager.render_prompt("fn", arguments={"name": "World"}) + prompt = await manager.render_prompt("fn", arguments={"name": "World"}) + messages = await prompt.render({"name": "World"}) assert messages == [UserMessage(content=TextContent(type="text", text="Hello, World!"))] @pytest.mark.anyio @@ -105,4 +107,5 @@ def fn(name: str) -> str: prompt = Prompt.from_function(fn) manager.add_prompt(prompt) with pytest.raises(ValueError, match="Missing required arguments"): - await manager.render_prompt("fn") + prompt = await manager.render_prompt("fn") + await prompt.render({}) diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index 27e16cc8e..7d19a0fcf 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -7,6 +7,7 @@ from pydantic import BaseModel from mcp.server.fastmcp import Context, FastMCP +from mcp.server.fastmcp.authorizer import Authorizer from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.tools import Tool, ToolManager from mcp.server.fastmcp.utilities.func_metadata import ArgModelBase, FuncMetadata @@ -173,7 +174,7 @@ def f(x: int) -> int: manager = ToolManager() manager.add_tool(f) - manager.warn_on_duplicate_tools = False + manager.warn_on_duplicate_tools = False # type: ignore with caplog.at_level(logging.WARNING): manager.add_tool(f) assert "Tool already exists: f" not in caplog.text @@ -313,6 +314,30 @@ def name_shrimp(tank: MyShrimpTank, ctx: Context) -> list[str]: ) assert result == ["rex", "gertrude"] + @pytest.mark.anyio + async def test_call_tool_not_permitted(self): + async def double(n: int) -> int: + """Double a number.""" + return n * 2 + + class TestAuthorizer(Authorizer): + allow: bool = True + + def permit_list_tool(self, name, context=None): + return self.allow + + def permit_call_tool(self, name, arguments, context=None): + return self.allow + + authorizer = TestAuthorizer() + manager = ToolManager(authorizer=authorizer) + manager.add_tool(double) + result = await manager.call_tool("double", {"n": 5}) + assert result == 10 + authorizer.allow = False + with pytest.raises(ToolError, match="Unknown tool: double"): + await manager.call_tool("double", {"n": 5}) + class TestToolSchema: @pytest.mark.anyio