Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/mcp/server/mcpserver/prompts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from __future__ import annotations

import functools
import inspect
from collections.abc import Awaitable, Callable, Sequence
from typing import TYPE_CHECKING, Any, Literal

import anyio.to_thread
import pydantic_core
from pydantic import BaseModel, Field, TypeAdapter, validate_call

Expand Down Expand Up @@ -155,10 +157,10 @@ async def render(
# Add context to arguments if needed
call_args = inject_context(self.fn, arguments or {}, context, self.context_kwarg)

# Call function and check if result is a coroutine
result = self.fn(**call_args)
if inspect.iscoroutine(result):
result = await result
if inspect.iscoroutinefunction(self.fn):
result = await self.fn(**call_args)
else:
result = await anyio.to_thread.run_sync(functools.partial(self.fn, **call_args))

# Validate messages
if not isinstance(result, list | tuple):
Expand Down
10 changes: 6 additions & 4 deletions src/mcp/server/mcpserver/resources/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

from __future__ import annotations

import functools
import inspect
import re
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from urllib.parse import unquote

import anyio.to_thread
from pydantic import BaseModel, Field, validate_call

from mcp.server.mcpserver.resources.types import FunctionResource, Resource
Expand Down Expand Up @@ -110,10 +112,10 @@ async def create_resource(
# Add context to params if needed
params = inject_context(self.fn, params, context, self.context_kwarg)

# Call function and check if result is a coroutine
result = self.fn(**params)
if inspect.iscoroutine(result):
result = await result
if inspect.iscoroutinefunction(self.fn):
result = await self.fn(**params)
else:
result = await anyio.to_thread.run_sync(functools.partial(self.fn, **params))

return FunctionResource(
uri=uri, # type: ignore
Expand Down
9 changes: 4 additions & 5 deletions src/mcp/server/mcpserver/resources/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,10 @@ class FunctionResource(Resource):
async def read(self) -> str | bytes:
"""Read the resource by calling the wrapped function."""
try:
# Call the function first to see if it returns a coroutine
result = self.fn()
# If it's a coroutine, await it
if inspect.iscoroutine(result):
result = await result
if inspect.iscoroutinefunction(self.fn):
result = await self.fn()
else:
result = await anyio.to_thread.run_sync(self.fn)

if isinstance(result, Resource): # pragma: no cover
return await result.read()
Expand Down
19 changes: 19 additions & 0 deletions tests/server/mcpserver/prompts/test_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import threading
from typing import Any

import pytest
Expand Down Expand Up @@ -190,3 +191,21 @@ async def fn() -> dict[str, Any]:
)
)
]


@pytest.mark.anyio
async def test_sync_fn_runs_in_worker_thread():
"""Sync prompt functions must run in a worker thread, not the event loop."""

main_thread = threading.get_ident()
fn_thread: list[int] = []

def blocking_fn() -> str:
fn_thread.append(threading.get_ident())
return "hello"

prompt = Prompt.from_function(blocking_fn)
messages = await prompt.render(None, Context())

assert messages == [UserMessage(content=TextContent(type="text", text="hello"))]
assert fn_thread[0] != main_thread
52 changes: 52 additions & 0 deletions tests/server/mcpserver/resources/test_function_resources.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import threading

import anyio
import anyio.from_thread
import pytest
from pydantic import BaseModel

Expand Down Expand Up @@ -190,3 +194,51 @@ def get_data() -> str: # pragma: no cover
)

assert resource.meta is None


@pytest.mark.anyio
async def test_sync_fn_runs_in_worker_thread():
"""Sync resource functions must run in a worker thread, not the event loop."""

main_thread = threading.get_ident()
fn_thread: list[int] = []

def blocking_fn() -> str:
fn_thread.append(threading.get_ident())
return "data"

resource = FunctionResource(uri="resource://test", name="test", fn=blocking_fn)
result = await resource.read()

assert result == "data"
assert fn_thread[0] != main_thread


@pytest.mark.anyio
async def test_sync_fn_does_not_block_event_loop():
"""A blocking sync resource function must not stall the event loop.

On regression (sync runs inline), anyio.from_thread.run_sync raises
RuntimeError because there is no worker-thread context, failing fast.
"""
handler_entered = anyio.Event()
release = threading.Event()

def blocking_fn() -> str:
anyio.from_thread.run_sync(handler_entered.set)
release.wait()
return "done"

resource = FunctionResource(uri="resource://test", name="test", fn=blocking_fn)
result: list[str | bytes] = []

async def run() -> None:
result.append(await resource.read())

with anyio.fail_after(5):
async with anyio.create_task_group() as tg:
tg.start_soon(run)
await handler_entered.wait()
release.set()

assert result == ["done"]
20 changes: 20 additions & 0 deletions tests/server/mcpserver/resources/test_resource_template.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import threading
from typing import Any

import pytest
Expand Down Expand Up @@ -310,3 +311,22 @@ def get_item(item_id: str) -> str:
assert resource.meta == metadata
assert resource.meta["category"] == "inventory"
assert resource.meta["cacheable"] is True


@pytest.mark.anyio
async def test_sync_fn_runs_in_worker_thread():
"""Sync template functions must run in a worker thread, not the event loop."""

main_thread = threading.get_ident()
fn_thread: list[int] = []

def blocking_fn(name: str) -> str:
fn_thread.append(threading.get_ident())
return f"hello {name}"

template = ResourceTemplate.from_function(fn=blocking_fn, uri_template="test://{name}")
resource = await template.create_resource("test://world", {"name": "world"}, Context())

assert isinstance(resource, FunctionResource)
assert await resource.read() == "hello world"
assert fn_thread[0] != main_thread
Loading