The mcp-use middleware system provides a powerful, flexible way to intercept and process all MCP (Model Context Protocol) requests and responses. It uses a class-based hook pattern (with an Express-like next flow) so you can log, collect metrics, filter/transform requests, cache, rate limit, and more.

Middleware Process

mcp-use runs middleware in a chain. When a request such as initialize, tools/call, or resources/read is made, it passes through your middlewares in order. In each middleware you can:
  1. Analyze the incoming request through the context
  2. Modify the request (and its metadata) before execution
  3. Execute the next handler via await call_next(context)
  4. Inspect/transform the result or handle errors
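The four steps can be seen in isolation with a small standalone model that uses only the standard library. `Ctx`, `tag_request`, and `fake_server` are hypothetical stand-ins for `MiddlewareContext`, a middleware hook, and the real MCP handler; they are not part of mcp-use.

```python
import asyncio
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable


@dataclass
class Ctx:  # hypothetical stand-in for MiddlewareContext
    method: str
    params: dict[str, Any]
    metadata: dict[str, Any] = field(default_factory=dict)


Next = Callable[[Ctx], Awaitable[Any]]


async def fake_server(ctx: Ctx) -> dict[str, Any]:
    # Stands in for the actual MCP handler at the end of the chain
    return {"echo": ctx.params, "seen_metadata": dict(ctx.metadata)}


async def tag_request(ctx: Ctx, call_next: Next) -> dict[str, Any]:
    # 1. Analyze the incoming request
    is_tool_call = ctx.method == "tools/call"
    # 2. Modify the request metadata before execution
    ctx.metadata["tagged"] = is_tool_call
    # 3. Execute the next handler
    result = await call_next(ctx)
    # 4. Inspect/transform the result
    return {**result, "post_processed": True}


result = asyncio.run(tag_request(Ctx("tools/call", {"name": "search"}), fake_server))
print(result)
```

The handler sees the metadata written in step 2, and the caller sees the field added in step 4.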
Transport support:
  • Fully supported: stdio, http (streamable HTTP and SSE), sandbox
  • Not yet supported: websocket

Overview

You implement middleware by subclassing Middleware and overriding one or more hooks. Each hook receives:
  • context: MiddlewareContext[T] - typed request context
  • call_next: NextFunctionT[T, R] - call to the next middleware or the actual MCP handler
Hooks are typed to the specific MCP method so you get strong IDE assistance when overriding specific operations like on_call_tool.

Quick Start

from mcp.types import CallToolRequestParams, CallToolResult
from mcp_use import MCPClient
from mcp_use.middleware import Middleware, MiddlewareContext, NextFunctionT

class CustomMiddleware(Middleware):
    async def on_call_tool(
        self, context: MiddlewareContext[CallToolRequestParams], call_next: NextFunctionT
    ) -> CallToolResult:
        print(f"Calling tool {context.params.name}")
        return await call_next(context)

config = {
    "mcpServers": {
        "playwright": {"command": "npx", "args": ["@playwright/mcp@latest"], "env": {"DISPLAY": ":1"}}
    }
}

# MCPClient automatically prepends a default logging middleware.
# You can add your own middlewares after it.
client = MCPClient(config=config, middleware=[CustomMiddleware()])

Core Types

MiddlewareContext[T]

from dataclasses import dataclass, field
from typing import Any, Generic, TypeVar

T = TypeVar("T")

@dataclass
class MiddlewareContext(Generic[T]):
    id: str                 # Unique request ID
    method: str             # JSON-RPC method name (e.g., "tools/call")
    params: T               # Typed parameters for the method (may be None for list APIs)
    connection_id: str      # Connector identifier (e.g., "stdio:...", "http:...")
    timestamp: float        # Request start time
    metadata: dict[str, Any] = field(default_factory=dict)
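Because `metadata` uses `field(default_factory=dict)`, every context gets its own dictionary, so values one request's middleware writes do not leak into another request. The sketch below re-declares the dataclass as shown above (instead of importing it) so it runs with the standard library alone; the request and connection IDs are made-up values.

```python
import time
from dataclasses import dataclass, field
from typing import Any, Generic, TypeVar

T = TypeVar("T")


@dataclass
class MiddlewareContext(Generic[T]):
    id: str
    method: str
    params: T
    connection_id: str
    timestamp: float
    metadata: dict[str, Any] = field(default_factory=dict)


# Two hypothetical requests on the same connection
a = MiddlewareContext("req-1", "tools/call", {"name": "search"}, "stdio:demo", time.time())
b = MiddlewareContext("req-2", "tools/list", None, "stdio:demo", time.time())

a.metadata["user"] = "alice"
print(a.metadata, b.metadata)  # {'user': 'alice'} {}
```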

NextFunctionT[T, R]

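from typing import Protocol, TypeVar
from mcp_use.middleware import MiddlewareContext

T = TypeVar("T")
R = TypeVar("R")

class NextFunctionT(Protocol[T, R]):
    # Any async callable that takes a context and returns the method's result
    async def __call__(self, context: MiddlewareContext[T]) -> R: ...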

Middleware base class

You override hooks on the base class. If you only need a single entry point for all requests, override on_request.
from typing import Any
from mcp_use.middleware import Middleware, MiddlewareContext, NextFunctionT

class MyMiddleware(Middleware):
    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
        # before
        result = await call_next(context)
        # after
        return result

Hook Reference

Available hooks (override any subset). Types below come from mcp.types.
from typing import Any
from mcp.types import (
    InitializeRequestParams, InitializeResult,
    CallToolRequestParams, CallToolResult,
    ListToolsRequest, ListToolsResult,
    ListResourcesRequest, ListResourcesResult,
    ReadResourceRequestParams, ReadResourceResult,
    ListPromptsRequest, ListPromptsResult,
    GetPromptRequestParams, GetPromptResult,
)
from mcp_use.middleware import MiddlewareContext, NextFunctionT

# Signature reference for the hooks on mcp_use.middleware.Middleware
class Middleware:
    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any: ...

    async def on_initialize(
        self, context: MiddlewareContext[InitializeRequestParams], call_next: NextFunctionT
    ) -> InitializeResult: ...

    async def on_call_tool(
        self, context: MiddlewareContext[CallToolRequestParams], call_next: NextFunctionT
    ) -> CallToolResult: ...

    async def on_list_tools(
        self, context: MiddlewareContext[ListToolsRequest], call_next: NextFunctionT
    ) -> ListToolsResult: ...

    async def on_list_resources(
        self, context: MiddlewareContext[ListResourcesRequest], call_next: NextFunctionT
    ) -> ListResourcesResult: ...

    async def on_read_resource(
        self, context: MiddlewareContext[ReadResourceRequestParams], call_next: NextFunctionT
    ) -> ReadResourceResult: ...

    async def on_list_prompts(
        self, context: MiddlewareContext[ListPromptsRequest], call_next: NextFunctionT
    ) -> ListPromptsResult: ...

    async def on_get_prompt(
        self, context: MiddlewareContext[GetPromptRequestParams], call_next: NextFunctionT
    ) -> GetPromptResult: ...
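Each hook corresponds to a standard MCP JSON-RPC method name. The following is a hypothetical, standalone model of how a base class could route a request to a typed hook and fall back to on_request when no typed hook is overridden; `HOOKS`, `SketchMiddleware`, and `dispatch` are illustrative assumptions, not mcp-use's actual internals.

```python
import asyncio
from typing import Any, Awaitable, Callable, Optional

Next = Callable[[], Awaitable[Any]]

# Standard MCP method names mapped to the hook names listed above (assumed routing)
HOOKS = {
    "initialize": "on_initialize",
    "tools/call": "on_call_tool",
    "tools/list": "on_list_tools",
    "resources/list": "on_list_resources",
    "resources/read": "on_read_resource",
    "prompts/list": "on_list_prompts",
    "prompts/get": "on_get_prompt",
}


class SketchMiddleware:
    """Hypothetical base class: only on_request has a default implementation."""

    async def on_request(self, method: str, call_next: Next) -> Any:
        return await call_next()  # default: pass through untouched

    async def dispatch(self, method: str, call_next: Next) -> Any:
        hook_name: Optional[str] = HOOKS.get(method)
        hook = getattr(self, hook_name, None) if hook_name else None
        if hook is None:  # no typed hook overridden for this method
            return await self.on_request(method, call_next)
        return await hook(method, call_next)


class ToolCallsOnly(SketchMiddleware):
    def __init__(self) -> None:
        self.seen: list[str] = []

    async def on_call_tool(self, method: str, call_next: Next) -> Any:
        self.seen.append(method)
        return await call_next()


async def handler() -> str:
    return "ok"


async def main(mw: SketchMiddleware) -> tuple:
    return (await mw.dispatch("tools/call", handler),
            await mw.dispatch("tools/list", handler))


mw = ToolCallsOnly()
results = asyncio.run(main(mw))
print(results, mw.seen)  # ('ok', 'ok') ['tools/call']
```

Overriding only on_call_tool leaves every other method on the pass-through path, which is why the hooks can be overridden in any subset.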

Writing Middleware

Timing Example

import time
from typing import Any
from mcp_use.middleware import Middleware, MiddlewareContext, NextFunctionT

class TimingMiddleware(Middleware):
    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
        start = time.perf_counter()  # monotonic clock, unaffected by system clock changes
        try:
            return await call_next(context)
        finally:
            # Runs on success and on error, so failed requests are timed too
            duration = time.perf_counter() - start
            print(f"{context.method} took {duration:.3f}s")
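The same pattern covers caching, one of the use cases named at the top of this page. Below is a standalone sketch that memoizes results of read-only methods keyed by method name and parameters. `Ctx`, `CachingMiddleware`, and the choice of cacheable methods are assumptions for illustration; server-side data can change between calls, so a real cache would also need invalidation or a TTL.

```python
import asyncio
import json
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable


@dataclass
class Ctx:  # hypothetical stand-in for MiddlewareContext
    method: str
    params: Any
    metadata: dict[str, Any] = field(default_factory=dict)


Next = Callable[[Ctx], Awaitable[Any]]


class CachingMiddleware:
    """Hypothetical in-memory cache for read-only MCP methods."""

    CACHEABLE = {"tools/list", "resources/list", "prompts/list", "resources/read"}

    def __init__(self) -> None:
        self._cache: dict[str, Any] = {}

    async def on_request(self, ctx: Ctx, call_next: Next) -> Any:
        if ctx.method not in self.CACHEABLE:
            return await call_next(ctx)
        # Serialize params deterministically so equal params give equal keys
        key = ctx.method + ":" + json.dumps(ctx.params, sort_keys=True, default=str)
        if key in self._cache:
            ctx.metadata["cache"] = "hit"
            return self._cache[key]
        ctx.metadata["cache"] = "miss"
        result = await call_next(ctx)
        self._cache[key] = result
        return result


calls = 0


async def fake_server(ctx: Ctx) -> list[str]:
    global calls
    calls += 1
    return ["search", "fetch"]


async def main() -> None:
    mw = CachingMiddleware()
    for _ in range(3):
        await mw.on_request(Ctx("tools/list", None), fake_server)
    await mw.on_request(Ctx("tools/call", {"name": "search"}), fake_server)


asyncio.run(main())
print(calls)  # 2: one tools/list miss plus the uncached tools/call
```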

Built-in Middleware

Default logging

  • A default logging middleware is automatically prepended by MCPClient.
  • It logs each request/response at debug level with timing.
  • You don’t need to add it manually; just pass your custom middlewares.
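from mcp_use import MCPClient
from mcp_use.middleware import Middleware

class MyCustomMiddleware(Middleware):
    """Override the hooks you need here."""

# config as defined in the Quick Start
client = MCPClient(config=config, middleware=[MyCustomMiddleware()])
# Order = [default_logging_middleware, MyCustomMiddleware()]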

Metrics

Instantiate and pass to the client. Each middleware exposes a getter on the instance.
from mcp_use import MCPClient
from mcp_use.middleware import MetricsMiddleware, PerformanceMetricsMiddleware, CombinedAnalyticsMiddleware

metrics_mw = MetricsMiddleware()
perf_mw = PerformanceMetricsMiddleware()
analytics_mw = CombinedAnalyticsMiddleware()

client = MCPClient(config=config, middleware=[metrics_mw, perf_mw, analytics_mw])

# Later, retrieve data
print(metrics_mw.get_metrics())
print(perf_mw.get_performance_metrics())
print(analytics_mw.get_combined_analytics())
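If the built-in collectors do not track what you need, a custom middleware can keep its own counters and expose them through a getter, following the same instance-getter pattern. This standalone sketch counts calls and failures per method; `CountingMiddleware` and `snapshot()` are hypothetical names, and the sketch does not reflect how `MetricsMiddleware` is implemented.

```python
import asyncio
from collections import Counter
from dataclasses import dataclass
from typing import Any, Awaitable, Callable


@dataclass
class Ctx:  # hypothetical stand-in for MiddlewareContext
    method: str


Next = Callable[[Ctx], Awaitable[Any]]


class CountingMiddleware:
    def __init__(self) -> None:
        self.calls: Counter = Counter()
        self.errors: Counter = Counter()

    async def on_request(self, ctx: Ctx, call_next: Next) -> Any:
        self.calls[ctx.method] += 1
        try:
            return await call_next(ctx)
        except Exception:
            self.errors[ctx.method] += 1
            raise  # record, then re-raise so callers still see the failure

    def snapshot(self) -> dict[str, dict[str, int]]:
        return {"calls": dict(self.calls), "errors": dict(self.errors)}


async def ok(ctx: Ctx) -> str:
    return "ok"


async def boom(ctx: Ctx) -> str:
    raise RuntimeError("server down")


async def main(mw: CountingMiddleware) -> None:
    await mw.on_request(Ctx("tools/list"), ok)
    await mw.on_request(Ctx("tools/call"), ok)
    try:
        await mw.on_request(Ctx("tools/call"), boom)
    except RuntimeError:
        pass


mw = CountingMiddleware()
asyncio.run(main(mw))
print(mw.snapshot())
# {'calls': {'tools/list': 1, 'tools/call': 2}, 'errors': {'tools/call': 1}}
```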

Middleware Chain Execution

Middleware executes in the order provided (outermost first):
middleware = [
    mw1,  # Executes first (outermost)
    mw2,  # Executes second
    mw3,  # Executes third (innermost)
]
Flow:
  1. mw1 starts → calls await call_next(context)
  2. mw2 starts → calls await call_next(context)
  3. mw3 starts → calls await call_next(context)
  4. Actual MCP call executes
  5. mw3 resumes with result
  6. mw2 resumes with result
  7. mw1 resumes with result
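This ordering follows from wrapping each middleware around the next one, starting from the innermost. The standalone sketch below reproduces the seven steps with plain functions; `make_mw` and `build_chain` are hypothetical helpers, not mcp-use APIs. With MCPClient, the default logging middleware is prepended, so it would sit outside mw1.

```python
import asyncio
from functools import partial
from typing import Any, Awaitable, Callable

Handler = Callable[[dict], Awaitable[Any]]
log: list[str] = []


def make_mw(name: str):
    async def mw(ctx: dict, call_next: Handler) -> Any:
        log.append(f"{name} start")
        result = await call_next(ctx)  # hand off to the next layer
        log.append(f"{name} resume")
        return result
    return mw


async def mcp_call(ctx: dict) -> str:
    log.append("MCP call")  # stands in for the actual MCP request
    return "result"


def build_chain(middlewares, handler: Handler) -> Handler:
    # Wrap from the innermost outward so middlewares[0] ends up outermost
    for mw in reversed(middlewares):
        handler = partial(mw, call_next=handler)
    return handler


chain = build_chain([make_mw("mw1"), make_mw("mw2"), make_mw("mw3")], mcp_call)
result = asyncio.run(chain({"method": "tools/call"}))
print(result)
print(log)
```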

Error Handling

Always re-raise unless you’re intentionally transforming errors.
from typing import Any
from mcp_use.middleware import Middleware, MiddlewareContext, NextFunctionT

class ErrorAware(Middleware):
    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
        try:
            return await call_next(context)
        except Exception as e:
            context.metadata["error"] = str(e)
            raise
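Transforming an error is the exception to the re-raise rule. One case is a tool call where the caller should receive an error result rather than an exception; the MCP specification reports tool failures through a result with `isError` set. The sketch below models that with plain dicts; `Ctx`, `ToolErrorAsResult`, and the result shape are simplified assumptions, and since on_call_tool is declared to return `CallToolResult`, a real version would build one of those instead.

```python
import asyncio
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable


@dataclass
class Ctx:  # hypothetical stand-in for MiddlewareContext
    method: str
    params: dict[str, Any]
    metadata: dict[str, Any] = field(default_factory=dict)


Next = Callable[[Ctx], Awaitable[dict[str, Any]]]


class ToolErrorAsResult:
    async def on_call_tool(self, ctx: Ctx, call_next: Next) -> dict[str, Any]:
        try:
            return await call_next(ctx)
        except TimeoutError as e:
            # Intentional transformation: only timeouts become error results.
            # Any other exception is not caught and propagates unchanged.
            ctx.metadata["error"] = str(e)
            return {"isError": True, "content": [{"type": "text", "text": f"Tool timed out: {e}"}]}


async def slow_tool(ctx: Ctx) -> dict[str, Any]:
    raise TimeoutError("no response after 30s")


async def broken_tool(ctx: Ctx) -> dict[str, Any]:
    raise ValueError("bad arguments")


mw = ToolErrorAsResult()
result = asyncio.run(mw.on_call_tool(Ctx("tools/call", {"name": "slow"}), slow_tool))
print(result["isError"])  # True

try:
    asyncio.run(mw.on_call_tool(Ctx("tools/call", {"name": "broken"}), broken_tool))
except ValueError as e:
    print("re-raised:", e)  # re-raised: bad arguments
```

Catching one specific exception type keeps the transformation narrow, so unexpected failures still surface.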

Best Practices

  1. Re-raise exceptions unless you have a clear alternative behavior.
  2. Use type hints on hooks for better IDE support.
  3. Keep each middleware focused on a single concern.
  4. Use context.metadata for cross-middleware communication.
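from typing import Any
from mcp_use.middleware import Middleware, MiddlewareContext, NextFunctionT

class Auth(Middleware):
    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
        context.metadata["user"] = "alice"  # written by an outer middleware...
        return await call_next(context)

class Audit(Middleware):
    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
        # ...and read by an inner one, so Auth must come before Audit in the list
        user = context.metadata.get("user", "unknown")
        print(f"{user} -> {context.method}")
        return await call_next(context)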

Integration with MCP Clients

MCPClient

from mcp_use import MCPClient
from mcp_use.middleware import MetricsMiddleware

metrics_mw = MetricsMiddleware()

client = MCPClient(
    config={
        "mcpServers": {
            "my_server": {"command": "my-mcp-server", "args": ["--port", "8080"]}
        }
    },
    middleware=[metrics_mw],  # default logging is automatically prepended
)

# Middleware can also be added after the client has been created:
client.add_middleware(CustomMiddleware())

Per-session stacks

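Each client keeps its own middleware stack, so different clients can run different middleware.
# SecurityMw, prod_config and dev_config are placeholders for your own middleware and configs
prod_client = MCPClient(config=prod_config, middleware=[SecurityMw(), MetricsMiddleware()])
dev_client = MCPClient(config=dev_config)  # default logging only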
See the examples directory for a complete working example.