Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 55 additions & 12 deletions mcpgateway/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@
from mcpgateway.middleware.request_logging_middleware import RequestLoggingMiddleware
from mcpgateway.middleware.security_headers import SecurityHeadersMiddleware
from mcpgateway.middleware.token_scoping import token_scoping_middleware
from mcpgateway.models import InitializeResult, ListResourceTemplatesResult, LogLevel, Root
from mcpgateway.models import InitializeResult
from mcpgateway.models import JSONRPCError as PydanticJSONRPCError
from mcpgateway.models import ListResourceTemplatesResult, LogLevel, Root
from mcpgateway.observability import init_telemetry
from mcpgateway.plugins.framework import PluginError, PluginManager, PluginViolationError
from mcpgateway.routers.well_known import router as well_known_router
Expand Down Expand Up @@ -702,15 +704,16 @@ async def plugin_violation_exception_handler(_request: Request, exc: PluginViola
violation details.

Returns:
JSONResponse: A 403 response with access forbidden.
JSONResponse: A 200 response with error details in JSON-RPC format.

Examples:
>>> from mcpgateway.plugins.framework import PluginViolationError
>>> from mcpgateway.plugins.framework.models import PluginViolation
>>> from fastapi import Request
>>> import asyncio
>>> import json
>>>
>>> # Create a mock integrity error
>>> # Create a plugin violation error
>>> mock_error = PluginViolationError(message="plugin violation",violation = PluginViolation(
... reason="Invalid input",
... description="The input contains prohibited content",
Expand All @@ -719,11 +722,31 @@ async def plugin_violation_exception_handler(_request: Request, exc: PluginViola
... ))
>>> result = asyncio.run(plugin_violation_exception_handler(None, mock_error))
>>> result.status_code
403
200
>>> content = json.loads(result.body.decode())
>>> content["error"]["code"]
-32602
>>> "Plugin Violation:" in content["error"]["message"]
True
>>> content["error"]["data"]["plugin_error_code"]
'PROHIBITED_CONTENT'
"""
policy_violation = exc.violation.model_dump() if exc.violation else {}
message = exc.violation.description if exc.violation else "A plugin violation occurred."
policy_violation["message"] = exc.message
return JSONResponse(status_code=403, content=policy_violation)
status_code = exc.violation.mcp_error_code if exc.violation and exc.violation.mcp_error_code else -32602
violation_details: dict[str, Any] = {}
if exc.violation:
if exc.violation.description:
violation_details["description"] = exc.violation.description
if exc.violation.details:
violation_details["details"] = exc.violation.details
if exc.violation.code:
violation_details["plugin_error_code"] = exc.violation.code
if exc.violation.plugin_name:
violation_details["plugin_name"] = exc.violation.plugin_name
json_rpc_error = PydanticJSONRPCError(code=status_code, message="Plugin Violation: " + message, data=violation_details)
return JSONResponse(status_code=200, content={"error": json_rpc_error.model_dump()})


@app.exception_handler(PluginError)
Expand All @@ -740,15 +763,16 @@ async def plugin_exception_handler(_request: Request, exc: PluginError):
violation details.

Returns:
JSONResponse: A 500 response with internal server error.
JSONResponse: A 200 response with error details in JSON-RPC format.

Examples:
>>> from mcpgateway.plugins.framework import PluginViolationError
>>> from mcpgateway.plugins.framework import PluginError
>>> from mcpgateway.plugins.framework.models import PluginErrorModel
>>> from fastapi import Request
>>> import asyncio
>>> import json
>>>
>>> # Create a mock integrity error
>>> # Create a plugin error
>>> mock_error = PluginError(error = PluginErrorModel(
... message="plugin error",
... code="timeout",
Expand All @@ -757,10 +781,29 @@ async def plugin_exception_handler(_request: Request, exc: PluginError):
... ))
>>> result = asyncio.run(plugin_exception_handler(None, mock_error))
>>> result.status_code
500
"""
error_obj = exc.error.model_dump() if exc.error else {}
return JSONResponse(status_code=500, content=error_obj)
200
>>> content = json.loads(result.body.decode())
>>> content["error"]["code"]
-32603
>>> "Plugin Error:" in content["error"]["message"]
True
>>> content["error"]["data"]["plugin_error_code"]
'timeout'
>>> content["error"]["data"]["plugin_name"]
'abc'
"""
message = exc.error.message if exc.error else "A plugin error occurred."
status_code = exc.error.mcp_error_code if exc.error else -32603
error_details: dict[str, Any] = {}
if exc.error:
if exc.error.details:
error_details["details"] = exc.error.details
if exc.error.code:
error_details["plugin_error_code"] = exc.error.code
if exc.error.plugin_name:
error_details["plugin_name"] = exc.error.plugin_name
json_rpc_error = PydanticJSONRPCError(code=status_code, message="Plugin Error: " + message, data=error_details)
return JSONResponse(status_code=200, content={"error": json_rpc_error.model_dump()})


class DocsAuthMiddleware(BaseHTTPMiddleware):
Expand Down
4 changes: 4 additions & 0 deletions mcpgateway/plugins/framework/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,12 +667,14 @@ class PluginErrorModel(BaseModel):
code (str): an error code.
details: (dict[str, Any]): additional error details.
plugin_name (str): the plugin name.
mcp_error_code ([int]): The MCP error code passed back to the client. Defaults to Internal Error.
"""

message: str
code: Optional[str] = ""
details: Optional[dict[str, Any]] = Field(default_factory=dict)
plugin_name: str
mcp_error_code: int = -32603


class PluginViolation(BaseModel):
Expand All @@ -684,6 +686,7 @@ class PluginViolation(BaseModel):
code (str): a violation code.
details: (dict[str, Any]): additional violation details.
_plugin_name (str): the plugin name, private attribute set by the plugin manager.
mcp_error_code(Optional[int]): A valid mcp error code which will be sent back to the client if plugin enabled.

Examples:
>>> violation = PluginViolation(
Expand All @@ -706,6 +709,7 @@ class PluginViolation(BaseModel):
code: str
details: dict[str, Any]
_plugin_name: str = PrivateAttr(default="")
mcp_error_code: Optional[int] = None

@property
def plugin_name(self) -> str:
Expand Down
218 changes: 218 additions & 0 deletions tests/unit/mcpgateway/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1582,3 +1582,221 @@ def test_jsonpath_modifier_invalid_expressions(sample_people):

with pytest.raises(HTTPException):
jsonpath_modifier(sample_people, "$[*]", mappings={"bad": "$["}) # invalid mapping expr


# ----------------------------------------------------- #
# Plugin Exception Handler Tests #
# ----------------------------------------------------- #
class TestPluginExceptionHandlers:
"""Tests for plugin exception handlers: PluginViolationError and PluginError."""

def test_plugin_violation_exception_handler_with_full_violation(self):
"""Test plugin_violation_exception_handler with complete violation details."""
# Standard
import asyncio

# First-Party
from mcpgateway.main import plugin_violation_exception_handler
from mcpgateway.plugins.framework.errors import PluginViolationError
from mcpgateway.plugins.framework.models import PluginViolation

violation = PluginViolation(
reason="Invalid input",
description="The input contains prohibited content",
code="PROHIBITED_CONTENT",
details={"field": "message", "value": "sensitive_data"},
)
violation._plugin_name = "content_filter"
exc = PluginViolationError(message="Policy violation detected", violation=violation)

result = asyncio.run(plugin_violation_exception_handler(None, exc))

assert result.status_code == 200
content = json.loads(result.body.decode())
assert "error" in content
assert content["error"]["code"] == -32602
assert "Plugin Violation:" in content["error"]["message"]
assert "The input contains prohibited content" in content["error"]["message"]
assert content["error"]["data"]["description"] == "The input contains prohibited content"
assert content["error"]["data"]["details"] == {"field": "message", "value": "sensitive_data"}
assert content["error"]["data"]["plugin_error_code"] == "PROHIBITED_CONTENT"
assert content["error"]["data"]["plugin_name"] == "content_filter"

def test_plugin_violation_exception_handler_with_custom_mcp_error_code(self):
"""Test plugin_violation_exception_handler with custom MCP error code."""
# Standard
import asyncio

# First-Party
from mcpgateway.main import plugin_violation_exception_handler
from mcpgateway.plugins.framework.errors import PluginViolationError
from mcpgateway.plugins.framework.models import PluginViolation

violation = PluginViolation(
reason="Rate limit exceeded",
description="Too many requests from this client",
code="RATE_LIMIT",
details={"requests": 100, "limit": 50},
mcp_error_code=-32000, # Custom error code
)
violation._plugin_name = "rate_limiter"
exc = PluginViolationError(message="Rate limit violation", violation=violation)

result = asyncio.run(plugin_violation_exception_handler(None, exc))

assert result.status_code == 200
content = json.loads(result.body.decode())
assert content["error"]["code"] == -32000
assert "Too many requests from this client" in content["error"]["message"]
assert content["error"]["data"]["plugin_error_code"] == "RATE_LIMIT"
assert content["error"]["data"]["plugin_name"] == "rate_limiter"

def test_plugin_violation_exception_handler_with_minimal_violation(self):
"""Test plugin_violation_exception_handler with minimal violation details."""
# Standard
import asyncio

# First-Party
from mcpgateway.main import plugin_violation_exception_handler
from mcpgateway.plugins.framework.errors import PluginViolationError
from mcpgateway.plugins.framework.models import PluginViolation

violation = PluginViolation(
reason="Violation occurred",
description="Minimal violation",
code="MIN_VIOLATION",
details={},
)
exc = PluginViolationError(message="Minimal violation", violation=violation)

result = asyncio.run(plugin_violation_exception_handler(None, exc))

assert result.status_code == 200
content = json.loads(result.body.decode())
assert content["error"]["code"] == -32602
assert "Minimal violation" in content["error"]["message"]
assert content["error"]["data"]["plugin_error_code"] == "MIN_VIOLATION"

def test_plugin_violation_exception_handler_without_violation_object(self):
"""Test plugin_violation_exception_handler when violation object is None."""
# Standard
import asyncio

# First-Party
from mcpgateway.main import plugin_violation_exception_handler
from mcpgateway.plugins.framework.errors import PluginViolationError

exc = PluginViolationError(message="Generic plugin violation", violation=None)

result = asyncio.run(plugin_violation_exception_handler(None, exc))

assert result.status_code == 200
content = json.loads(result.body.decode())
assert content["error"]["code"] == -32602
assert "A plugin violation occurred" in content["error"]["message"]
assert content["error"]["data"] == {}

def test_plugin_exception_handler_with_full_error(self):
"""Test plugin_exception_handler with complete error details."""
# Standard
import asyncio

# First-Party
from mcpgateway.main import plugin_exception_handler
from mcpgateway.plugins.framework.errors import PluginError
from mcpgateway.plugins.framework.models import PluginErrorModel

error = PluginErrorModel(
message="Plugin execution failed",
code="EXECUTION_ERROR",
plugin_name="data_processor",
details={"error_type": "timeout", "duration": 30},
)
exc = PluginError(error=error)

result = asyncio.run(plugin_exception_handler(None, exc))

assert result.status_code == 200
content = json.loads(result.body.decode())
assert "error" in content
assert content["error"]["code"] == -32603
assert "Plugin Error:" in content["error"]["message"]
assert "Plugin execution failed" in content["error"]["message"]
assert content["error"]["data"]["details"] == {"error_type": "timeout", "duration": 30}
assert content["error"]["data"]["plugin_error_code"] == "EXECUTION_ERROR"
assert content["error"]["data"]["plugin_name"] == "data_processor"

def test_plugin_exception_handler_with_custom_mcp_error_code(self):
"""Test plugin_exception_handler with custom MCP error code."""
# Standard
import asyncio

# First-Party
from mcpgateway.main import plugin_exception_handler
from mcpgateway.plugins.framework.errors import PluginError
from mcpgateway.plugins.framework.models import PluginErrorModel

error = PluginErrorModel(
message="Custom error occurred",
code="CUSTOM_ERROR",
plugin_name="custom_plugin",
details={"context": "test"},
mcp_error_code=-32001, # Custom MCP error code
)
exc = PluginError(error=error)

result = asyncio.run(plugin_exception_handler(None, exc))

assert result.status_code == 200
content = json.loads(result.body.decode())
assert content["error"]["code"] == -32001
assert "Custom error occurred" in content["error"]["message"]
assert content["error"]["data"]["plugin_error_code"] == "CUSTOM_ERROR"

def test_plugin_exception_handler_with_minimal_error(self):
"""Test plugin_exception_handler with minimal error details."""
# Standard
import asyncio

# First-Party
from mcpgateway.main import plugin_exception_handler
from mcpgateway.plugins.framework.errors import PluginError
from mcpgateway.plugins.framework.models import PluginErrorModel

error = PluginErrorModel(message="Minimal error", plugin_name="minimal_plugin")
exc = PluginError(error=error)

result = asyncio.run(plugin_exception_handler(None, exc))

assert result.status_code == 200
content = json.loads(result.body.decode())
assert content["error"]["code"] == -32603
assert "Minimal error" in content["error"]["message"]
assert content["error"]["data"]["plugin_name"] == "minimal_plugin"

def test_plugin_exception_handler_with_empty_code(self):
"""Test plugin_exception_handler when error has empty code field."""
# Standard
import asyncio

# First-Party
from mcpgateway.main import plugin_exception_handler
from mcpgateway.plugins.framework.errors import PluginError
from mcpgateway.plugins.framework.models import PluginErrorModel

error = PluginErrorModel(
message="Error without code",
code="",
plugin_name="test_plugin",
details={"info": "test"},
)
exc = PluginError(error=error)

result = asyncio.run(plugin_exception_handler(None, exc))

assert result.status_code == 200
content = json.loads(result.body.decode())
assert content["error"]["code"] == -32603
assert "Error without code" in content["error"]["message"]
# Empty code should not be included in data
assert "plugin_error_code" not in content["error"]["data"] or content["error"]["data"]["plugin_error_code"] == ""
Loading