# Middleware System

> Flexible middleware system for intercepting and processing MCP requests and responses with typed hooks

The mcp-use middleware system lets you intercept and process every MCP (Model Context Protocol) request and response. It uses a class-based hook pattern with an Express-like `next` flow, so you can log, collect metrics, filter or transform requests, cache, rate limit, and more.

## Middleware Process

mcp-use runs middleware in a chain. When a request such as `initialize`, `tools/call`, or `resources/read` is made, it passes through your middleware in order. In each middleware you can:

1. Analyze the incoming request through the context
2. Modify the request (and its metadata) before execution
3. Execute the next handler via `await call_next(context)`
4. Inspect/transform the result or handle errors

Transport support:

* Fully supported: `stdio`, `http` (streamable HTTP and SSE), `sandbox`
* Not yet supported: `websocket`

## Overview

You implement middleware by subclassing `Middleware` and overriding one or more hooks. Each hook receives:

* `context: MiddlewareContext[T]` - typed request context
* `call_next: NextFunctionT[T, R]` - call to the next middleware or the actual MCP handler

Hooks are typed to the specific MCP method so you get strong IDE assistance when overriding specific operations like `on_call_tool`.

## Quick Start

```python theme={null}
from mcp.types import CallToolRequestParams, CallToolResult
from mcp_use import MCPClient
from mcp_use.middleware import Middleware, MiddlewareContext, NextFunctionT

class CustomMiddleware(Middleware):
    async def on_call_tool(
        self, context: MiddlewareContext[CallToolRequestParams], call_next: NextFunctionT
    ) -> CallToolResult:
        print(f"Calling tool {context.params.name}")
        return await call_next(context)

config = {
    "mcpServers": {
        "playwright": {"command": "npx", "args": ["@playwright/mcp@latest"], "env": {"DISPLAY": ":1"}}
    }
}

# MCPClient automatically prepends a default logging middleware.
# You can add your own middlewares after it.
client = MCPClient(config=config, middleware=[CustomMiddleware()])
```

## Core Types

### `MiddlewareContext[T]`

```python theme={null}
from dataclasses import dataclass, field
from typing import Any, Generic, TypeVar

T = TypeVar("T")

@dataclass
class MiddlewareContext(Generic[T]):
    id: str                 # Unique request ID
    method: str             # JSON-RPC method name (e.g., "tools/call")
    params: T               # Typed parameters for the method (may be None for list APIs)
    connection_id: str      # Connector identifier (e.g., "stdio:...", "http:...")
    timestamp: float        # Request start time
    metadata: dict[str, Any] = field(default_factory=dict)
```

### `NextFunctionT[T, R]`

```python theme={null}
from typing import Protocol, TypeVar

R = TypeVar("R")  # result type; T and MiddlewareContext are defined above

class NextFunctionT(Protocol[T, R]):
    async def __call__(self, context: MiddlewareContext[T]) -> R: ...
```

### `Middleware` base class

You override hooks on the base class. If you only need a single entry point for all requests, override `on_request`.

```python theme={null}
from typing import Any
from mcp_use.middleware import Middleware, MiddlewareContext, NextFunctionT

class MyMiddleware(Middleware):
    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
        # before
        result = await call_next(context)
        # after
        return result
```
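Because a hook is an ordinary coroutine that takes a context and a `call_next` callable, it can be exercised without a server. The following is a minimal sketch under stated assumptions: `FakeContext` is a stand-in that mirrors the `MiddlewareContext` fields listed above, `TagMiddleware` is a hypothetical hook that does not subclass the real `Middleware`, and `fake_call_next` replaces the rest of the chain. None of these names come from mcp-use.

```python
import asyncio
import time
from dataclasses import dataclass, field
from typing import Any


@dataclass
class FakeContext:
    """Stand-in mirroring the MiddlewareContext fields shown above."""
    id: str
    method: str
    params: Any
    connection_id: str
    timestamp: float
    metadata: dict[str, Any] = field(default_factory=dict)


class TagMiddleware:
    """Hypothetical hook under test; real code would subclass Middleware."""

    async def on_request(self, context, call_next):
        context.metadata["seen"] = True       # "before" step
        result = await call_next(context)     # hand off to the rest of the chain
        return {"wrapped": result}            # "after" step


async def fake_call_next(context):
    # Stub standing in for later middleware and the actual MCP handler.
    return f"handled {context.method}"


def run_hook():
    ctx = FakeContext(
        id="1",
        method="tools/list",
        params=None,
        connection_id="stdio:test",
        timestamp=time.time(),
    )
    result = asyncio.run(TagMiddleware().on_request(ctx, fake_call_next))
    return ctx, result


if __name__ == "__main__":
    print(run_hook())
```

With the real library, the same approach should carry over by subclassing `Middleware` and building a `MiddlewareContext`, though the exact constructor arguments should be checked against the installed version.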

## Hook Reference

Available hooks (override any subset). Types below come from `mcp.types`.

```python theme={null}
from typing import Any
from mcp.types import (
    InitializeRequestParams, InitializeResult,
    CallToolRequestParams, CallToolResult,
    ListToolsRequest, ListToolsResult,
    ListResourcesRequest, ListResourcesResult,
    ReadResourceRequestParams, ReadResourceResult,
    ListPromptsRequest, ListPromptsResult,
    GetPromptRequestParams, GetPromptResult,
)
from mcp_use.middleware import MiddlewareContext, NextFunctionT

# Signatures of the base class, shown for reference; bodies are elided.
class Middleware:
    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any: ...

    async def on_initialize(
        self, context: MiddlewareContext[InitializeRequestParams], call_next: NextFunctionT
    ) -> InitializeResult: ...

    async def on_call_tool(
        self, context: MiddlewareContext[CallToolRequestParams], call_next: NextFunctionT
    ) -> CallToolResult: ...

    async def on_list_tools(
        self, context: MiddlewareContext[ListToolsRequest], call_next: NextFunctionT
    ) -> ListToolsResult: ...

    async def on_list_resources(
        self, context: MiddlewareContext[ListResourcesRequest], call_next: NextFunctionT
    ) -> ListResourcesResult: ...

    async def on_read_resource(
        self, context: MiddlewareContext[ReadResourceRequestParams], call_next: NextFunctionT
    ) -> ReadResourceResult: ...

    async def on_list_prompts(
        self, context: MiddlewareContext[ListPromptsRequest], call_next: NextFunctionT
    ) -> ListPromptsResult: ...

    async def on_get_prompt(
        self, context: MiddlewareContext[GetPromptRequestParams], call_next: NextFunctionT
    ) -> GetPromptResult: ...
```

## Writing Middleware

### Timing Example

```python theme={null}
import time
from typing import Any
from mcp_use.middleware import Middleware, MiddlewareContext, NextFunctionT

class TimingMiddleware(Middleware):
    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
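        start = time.perf_counter()  # monotonic clock, suited to measuring durations
        try:
            return await call_next(context)
        finally:
            # Runs on success and on error, so failed requests are timed too.
            duration = time.perf_counter() - start
            print(f"{context.method} took {duration:.3f}s")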
```
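A hook can also return a result without forwarding the request, which is the basis for the caching use case mentioned in the introduction. The sketch below assumes that the chain allows a hook to return without calling `call_next`. `ReadCacheSketch` is a hypothetical stand-in (real code would subclass `Middleware` and override `on_read_resource`), and the context is faked with `SimpleNamespace` so the block runs without mcp-use installed.

```python
import asyncio
from types import SimpleNamespace


class ReadCacheSketch:
    """Stand-in for a Middleware subclass that overrides on_read_resource."""

    def __init__(self) -> None:
        self._cache: dict[str, object] = {}

    async def on_read_resource(self, context, call_next):
        key = str(context.params.uri)
        if key in self._cache:
            return self._cache[key]           # cache hit: call_next is skipped
        result = await call_next(context)     # cache miss: run the real read
        self._cache[key] = result
        return result


def run_demo():
    calls = 0

    async def fake_call_next(context):
        # Counts how often the "server" is actually reached.
        nonlocal calls
        calls += 1
        return f"contents of {context.params.uri}"

    async def main():
        mw = ReadCacheSketch()
        ctx = SimpleNamespace(params=SimpleNamespace(uri="file:///notes.txt"))
        first = await mw.on_read_resource(ctx, fake_call_next)
        second = await mw.on_read_resource(ctx, fake_call_next)
        return [first, second]

    results = asyncio.run(main())
    return results, calls


if __name__ == "__main__":
    print(run_demo())
```

This sketch never evicts entries; a real cache would need an expiry or size limit, and should only cover resources that are safe to serve stale.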

## Built-in Middleware

### Default logging

* A default logging middleware is automatically prepended by `MCPClient`.
* It logs each request/response at debug level with timing.
* You don’t need to add it manually; just pass your custom middlewares.

```python theme={null}
from mcp_use import MCPClient

# `config` and `CustomMiddleware` are the ones defined in the Quick Start.
client = MCPClient(config=config, middleware=[CustomMiddleware()])
# Order = [default_logging_middleware, CustomMiddleware()]
```

### Metrics

Instantiate the metrics middleware you need and pass the instances to the client. Keep a reference to each instance: it exposes a getter for reading the collected data later.

```python theme={null}
from mcp_use import MCPClient
from mcp_use.middleware import MetricsMiddleware, PerformanceMetricsMiddleware, CombinedAnalyticsMiddleware

metrics_mw = MetricsMiddleware()
perf_mw = PerformanceMetricsMiddleware()
analytics_mw = CombinedAnalyticsMiddleware()

client = MCPClient(config=config, middleware=[metrics_mw, perf_mw, analytics_mw])

# Later, retrieve data
print(metrics_mw.get_metrics())
print(perf_mw.get_performance_metrics())
print(analytics_mw.get_combined_analytics())
```

## Middleware Chain Execution

Middleware executes in the order provided, outermost first. The default logging middleware that `MCPClient` prepends runs before all of them.

```python theme={null}
middleware = [
    mw1,  # Executes first (outermost)
    mw2,  # Executes second
    mw3,  # Executes third (innermost)
]
```

Flow:

1. `mw1` starts → calls `await call_next(context)`
2. `mw2` starts → calls `await call_next(context)`
3. `mw3` starts → calls `await call_next(context)`
4. Actual MCP call executes
5. `mw3` resumes with result
6. `mw2` resumes with result
7. `mw1` resumes with result
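The flow above can be reproduced with a few lines of plain `asyncio`. This is a sketch of the general chaining technique, not the mcp-use implementation: `make_mw`, `build_chain`, and the dict-based context are hypothetical names chosen for illustration.

```python
import asyncio
from typing import Any, Awaitable, Callable

Handler = Callable[[dict], Awaitable[Any]]


def make_mw(name: str, log: list):
    """Create a middleware that records when it starts and when it resumes."""
    async def mw(context: dict, call_next: Handler) -> Any:
        log.append(f"{name} start")
        result = await call_next(context)
        log.append(f"{name} resume")
        return result
    return mw


def build_chain(middlewares: list, handler: Handler) -> Handler:
    """Wrap the handler from the innermost middleware outward."""
    chain = handler
    for mw in reversed(middlewares):
        # Bind mw and the current chain as defaults to avoid late-binding bugs.
        def bind(mw=mw, nxt=chain) -> Handler:
            async def call(context: dict) -> Any:
                return await mw(context, nxt)
            return call
        chain = bind()
    return chain


def run_demo() -> list:
    log: list = []

    async def handler(context: dict) -> str:
        log.append("MCP call")
        return "result"

    chain = build_chain([make_mw(n, log) for n in ("mw1", "mw2", "mw3")], handler)
    asyncio.run(chain({"method": "tools/call"}))
    return log


if __name__ == "__main__":
    for entry in run_demo():
        print(entry)
```

Iterating over the list in reverse while wrapping is what makes the first entry the outermost layer: the last wrapper applied is the first one entered.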

## Error Handling

Always re-raise unless you’re intentionally transforming errors.

```python theme={null}
from typing import Any
from mcp_use.middleware import Middleware, MiddlewareContext, NextFunctionT

class ErrorAware(Middleware):
    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
        try:
            return await call_next(context)
        except Exception as e:
            context.metadata["error"] = str(e)
            raise
```
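When you do transform an error on purpose, chaining with `raise ... from` keeps the original exception available for debugging. A minimal sketch, assuming a hypothetical application-level `ToolCallError`; `TranslateErrors` is a stand-in that does not subclass the real `Middleware`, and the context is faked with `SimpleNamespace`.

```python
import asyncio
from types import SimpleNamespace


class ToolCallError(RuntimeError):
    """Hypothetical application-level error; not part of mcp-use."""


class TranslateErrors:
    """Stand-in hook; real code would subclass Middleware."""

    async def on_request(self, context, call_next):
        try:
            return await call_next(context)
        except TimeoutError as e:
            # Intentional transformation: add request context and keep the
            # original exception reachable as __cause__.
            raise ToolCallError(f"{context.method} timed out") from e
        # Every other exception type propagates unchanged.


async def failing_call_next(context):
    # Simulates a downstream failure.
    raise TimeoutError("no response")


def run_demo():
    ctx = SimpleNamespace(method="tools/call", metadata={})
    try:
        asyncio.run(TranslateErrors().on_request(ctx, failing_call_next))
    except ToolCallError as e:
        return str(e), type(e.__cause__).__name__
    return None


if __name__ == "__main__":
    print(run_demo())
```

Catching only the exception types you mean to translate keeps unrelated failures visible to outer middleware and to the caller.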

## Best Practices

1. Re-raise exceptions unless you have a clear alternative behavior.
2. Use type hints on hooks for better IDE support.
3. Keep each middleware focused on a single concern.
4. Use `context.metadata` for cross-middleware communication.

```python theme={null}
from mcp_use.middleware import Middleware

class Auth(Middleware):
    async def on_request(self, context, call_next):
        context.metadata["user"] = "alice"  # written before the rest of the chain runs
        return await call_next(context)

class Audit(Middleware):
    # Must come after Auth in the middleware list to see the "user" entry.
    async def on_request(self, context, call_next):
        user = context.metadata.get("user", "unknown")
        print(f"{user} -> {context.method}")
        return await call_next(context)

**Note on Modifying `context.params` and Headers**

Middleware receives a typed `context.params` object and may modify it before the request is executed. The runtime guarantees and recommended patterns are:

* **Preferred: mutate fields on `context.params`**
  * Example: `context.params.arguments["user_id"] = "alice"`
  * The final MCP client call sees these mutations, which makes this the most compatible approach. Note that `arguments` on `CallToolRequestParams` is optional, so check for `None` before writing to it.

* **Replacement: reassign `context.params`**
  * Example: `context.params = NewParams(...)`
  * This pattern is supported: the middleware system reads `context.params` at call time, so replacements are respected. Prefer mutation anyway; it is clearer and less surprising to readers of middleware code.

* **Per-request HTTP headers:**
  * Adding `context.metadata["headers"]` is useful for carrying header-like information through middleware, but it will only be applied to the actual HTTP transport if the connector/transport code explicitly reads and merges those values into the request headers.
  * There is no global automatic mechanism that takes `context.metadata["headers"]` and injects them into every transport unless the connector implements that behavior.

Example: add a trace id that downstream middleware or the server can observe (this does not automatically modify HTTP headers):

```python theme={null}
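from mcp.types import CallToolRequestParams, CallToolResult
from mcp_use.middleware import Middleware, MiddlewareContext, NextFunctionT

class AddTraceMiddleware(Middleware):
    async def on_call_tool(
        self, context: MiddlewareContext[CallToolRequestParams], call_next: NextFunctionT
    ) -> CallToolResult:
        # `arguments` is optional on CallToolRequestParams, so guard against None.
        if context.params.arguments is None:
            context.params.arguments = {}
        # Mutate params (preferred)
        context.params.arguments.setdefault("meta", {})["trace_id"] = "trace-123"

        # Or store header-like info in metadata for connector-level handling
        context.metadata.setdefault("headers", {})["X-Trace-Id"] = "trace-123"

        return await call_next(context)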
```

If you need middleware to inject actual HTTP headers per request, there are two safe approaches:

1. **Connector support (recommended):** update the connector/transport to read `context.metadata["headers"]` and merge them into the outgoing HTTPX request headers for that call. This is robust and concurrency-safe when implemented correctly.
2. **Mutate request params the server understands:** include header-like fields inside `context.params` (for example inside a tool's `arguments`) and let the server interpret them.
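For the connector-support approach, the merge step itself can be a small pure function. The sketch below is an assumption about how a connector might do it, not existing mcp-use code: `merge_request_headers` is a hypothetical helper, and a connector would still have to call it and pass the result to its HTTP client for each request.

```python
from typing import Any, Mapping


def merge_request_headers(
    base_headers: Mapping[str, str], metadata: Mapping[str, Any]
) -> dict:
    """Overlay metadata["headers"] on the connector's default headers.

    Returns a new dict, so shared connector state is never mutated and
    concurrent requests cannot see each other's headers.
    """
    merged = dict(base_headers)
    extra = metadata.get("headers") or {}
    for name, value in extra.items():
        # HTTP header names are case-insensitive: drop any existing entry
        # that differs only by case before adding the per-request value.
        for existing in [k for k in merged if k.lower() == str(name).lower()]:
            del merged[existing]
        merged[str(name)] = str(value)
    return merged


if __name__ == "__main__":
    defaults = {"Authorization": "Bearer token", "x-trace-id": "old"}
    per_request = {"headers": {"X-Trace-Id": "trace-123"}}
    print(merge_request_headers(defaults, per_request))
```

Building a fresh dict per call is the design choice that matters here: mutating a header dict shared by the connector would leak one request's headers into another when calls overlap.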

## Integration with MCP Clients

### MCPClient

```python theme={null}
from mcp_use import MCPClient
from mcp_use.middleware import MetricsMiddleware

metrics_mw = MetricsMiddleware()

client = MCPClient(
    config={
        "mcpServers": {
            "my_server": {"command": "my-mcp-server", "args": ["--port", "8080"]}
        }
    },
    middleware=[metrics_mw],  # default logging is automatically prepended
)

# You can also add middleware after construction
# (CustomMiddleware is the class defined in the Quick Start):
client.add_middleware(CustomMiddleware())
```

### Per-client stacks

```python theme={null}
# SecurityMw, prod_config, and dev_config are placeholders for your own middleware and configs.
prod_client = MCPClient(config=prod_config, middleware=[SecurityMw(), MetricsMiddleware()])
dev_client  = MCPClient(config=dev_config)  # default logging only
```

<Note>
  See the examples directory for a complete working example:

  * [`examples/example_middleware.py`](https://github.com/mcp-use/mcp-use/blob/main/examples/example_middleware.py)
</Note>
