diff --git a/pyproject.toml b/pyproject.toml index 745c80e0c..095a38cb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "boto3>=1.26.0,<2.0.0", "botocore>=1.29.0,<2.0.0", "docstring_parser>=0.15,<1.0", - "mcp>=1.8.0,<2.0.0", + "mcp>=1.11.0,<2.0.0", "pydantic>=2.0.0,<3.0.0", "typing-extensions>=4.13.2,<5.0.0", "watchdog>=6.0.0,<7.0.0", diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 679f1ea3d..8b7ef68d7 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -17,10 +17,10 @@ from ..event_loop import streaming from ..tools import convert_pydantic_to_tool_spec -from ..types.content import Messages +from ..types.content import ContentBlock, Message, Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent -from ..types.tools import ToolSpec +from ..types.tools import ToolResult, ToolSpec from .model import Model logger = logging.getLogger(__name__) @@ -181,7 +181,7 @@ def format_request( """ return { "modelId": self.config["model_id"], - "messages": messages, + "messages": self._format_bedrock_messages(messages), "system": [ *([{"text": system_prompt}] if system_prompt else []), *([{"cachePoint": {"type": self.config["cache_prompt"]}}] if self.config.get("cache_prompt") else []), @@ -246,6 +246,53 @@ def format_request( ), } + def _format_bedrock_messages(self, messages: Messages) -> Messages: + """Format messages for Bedrock API compatibility. 
+ + This function ensures messages conform to Bedrock's expected format by: + - Cleaning tool result content blocks by removing additional fields that may be + useful for retaining information in hooks but would cause Bedrock validation + exceptions when presented with unexpected fields + - Ensuring all message content blocks are properly formatted for the Bedrock API + + Args: + messages: List of messages to format + + Returns: + Messages formatted for Bedrock API compatibility + + Note: + Bedrock will throw validation exceptions when presented with additional + unexpected fields in tool result blocks. + https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html + """ + cleaned_messages = [] + + for message in messages: + cleaned_content: list[ContentBlock] = [] + + for content_block in message["content"]: + if "toolResult" in content_block: + # Create a new content block with only the cleaned toolResult + tool_result: ToolResult = content_block["toolResult"] + + # Keep only the required fields for Bedrock + cleaned_tool_result = ToolResult( + content=tool_result["content"], toolUseId=tool_result["toolUseId"], status=tool_result["status"] + ) + + cleaned_block: ContentBlock = {"toolResult": cleaned_tool_result} + cleaned_content.append(cleaned_block) + else: + # Keep other content blocks as-is + cleaned_content.append(content_block) + + # Create new message with cleaned content + cleaned_message: Message = Message(content=cleaned_content, role=message["role"]) + cleaned_messages.append(cleaned_message) + + return cleaned_messages + def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool: """Check if guardrail data contains any blocked policies. 
diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 4cf4e1f85..784636fd0 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -26,9 +26,9 @@ from ...types import PaginatedList from ...types.exceptions import MCPClientInitializationError from ...types.media import ImageFormat -from ...types.tools import ToolResult, ToolResultContent, ToolResultStatus +from ...types.tools import ToolResultContent, ToolResultStatus from .mcp_agent_tool import MCPAgentTool -from .mcp_types import MCPTransport +from .mcp_types import MCPToolResult, MCPTransport logger = logging.getLogger(__name__) @@ -57,7 +57,8 @@ class MCPClient: It handles the creation, initialization, and cleanup of MCP connections. The connection runs in a background thread to avoid blocking the main application thread - while maintaining communication with the MCP service. + while maintaining communication with the MCP service. When structured content is available + from MCP tools, it will be returned in the structuredContent field of the MCPToolResult. """ def __init__(self, transport_callable: Callable[[], MCPTransport]): @@ -170,11 +171,13 @@ def call_tool_sync( name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, - ) -> ToolResult: + ) -> MCPToolResult: """Synchronously calls a tool on the MCP server. This method calls the asynchronous call_tool method on the MCP session - and converts the result to the ToolResult format. + and converts the result to the MCPToolResult format. If the MCP tool returns + structured content, it will be included in the structuredContent field + of the returned MCPToolResult. 
Args: tool_use_id: Unique identifier for this tool use @@ -183,7 +186,7 @@ def call_tool_sync( read_timeout_seconds: Optional timeout for the tool call Returns: - ToolResult: The result of the tool call + MCPToolResult: The result of the tool call """ self._log_debug_with_thread("calling MCP tool '%s' synchronously with tool_use_id=%s", name, tool_use_id) if not self._is_session_active(): @@ -205,11 +208,11 @@ async def call_tool_async( name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, - ) -> ToolResult: + ) -> MCPToolResult: """Asynchronously calls a tool on the MCP server. This method calls the asynchronous call_tool method on the MCP session - and converts the result to the ToolResult format. + and converts the result to the MCPToolResult format. Args: tool_use_id: Unique identifier for this tool use @@ -218,7 +221,7 @@ async def call_tool_async( read_timeout_seconds: Optional timeout for the tool call Returns: - ToolResult: The result of the tool call + MCPToolResult: The result of the tool call """ self._log_debug_with_thread("calling MCP tool '%s' asynchronously with tool_use_id=%s", name, tool_use_id) if not self._is_session_active(): @@ -235,15 +238,27 @@ async def _call_tool_async() -> MCPCallToolResult: logger.exception("tool execution failed") return self._handle_tool_execution_error(tool_use_id, e) - def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -> ToolResult: + def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -> MCPToolResult: """Create error ToolResult with consistent logging.""" - return ToolResult( + return MCPToolResult( status="error", toolUseId=tool_use_id, content=[{"text": f"Tool execution failed: {str(exception)}"}], ) - def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolResult) -> ToolResult: + def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolResult) -> MCPToolResult: + """Maps 
MCP tool result to the agent's MCPToolResult format. + + This method processes the content from the MCP tool call result and converts it to the format + expected by the framework. + + Args: + tool_use_id: Unique identifier for this tool use + call_tool_result: The result from the MCP tool call + + Returns: + MCPToolResult: The converted tool result + """ self._log_debug_with_thread("received tool result with %d content items", len(call_tool_result.content)) mapped_content = [ @@ -254,7 +269,15 @@ def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolRes status: ToolResultStatus = "error" if call_tool_result.isError else "success" self._log_debug_with_thread("tool execution completed with status: %s", status) - return ToolResult(status=status, toolUseId=tool_use_id, content=mapped_content) + result = MCPToolResult( + status=status, + toolUseId=tool_use_id, + content=mapped_content, + ) + if call_tool_result.structuredContent: + result["structuredContent"] = call_tool_result.structuredContent + + return result async def _async_background_thread(self) -> None: """Asynchronous method that runs in the background thread to manage the MCP connection. diff --git a/src/strands/tools/mcp/mcp_types.py b/src/strands/tools/mcp/mcp_types.py index 30defc585..5fafed5dc 100644 --- a/src/strands/tools/mcp/mcp_types.py +++ b/src/strands/tools/mcp/mcp_types.py @@ -1,11 +1,15 @@ """Type definitions for MCP integration.""" from contextlib import AbstractAsyncContextManager +from typing import Any, Dict from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp.client.streamable_http import GetSessionIdCallback from mcp.shared.memory import MessageStream from mcp.shared.message import SessionMessage +from typing_extensions import NotRequired + +from strands.types.tools import ToolResult """ MCPTransport defines the interface for MCP transport implementations. 
This abstracts @@ -41,3 +45,19 @@ async def my_transport_implementation(): MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage], GetSessionIdCallback ] MCPTransport = AbstractAsyncContextManager[MessageStream | _MessageStreamWithGetSessionIdCallback] + + +class MCPToolResult(ToolResult): + """Result of an MCP tool execution. + + Extends the base ToolResult with MCP-specific structured content support. + The structuredContent field contains optional JSON data returned by MCP tools + that provides structured results beyond the standard text/image/document content. + + Attributes: + structuredContent: Optional JSON object containing structured data returned + by the MCP tool. This allows MCP tools to return complex data structures + that can be processed programmatically by agents or other tools. + """ + + structuredContent: NotRequired[Dict[str, Any]] diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 47e028cb9..0a2846adf 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -13,6 +13,7 @@ from strands.models import BedrockModel from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION from strands.types.exceptions import ModelThrottledException +from strands.types.tools import ToolSpec @pytest.fixture @@ -51,7 +52,7 @@ def model(bedrock_client, model_id): @pytest.fixture def messages(): - return [{"role": "user", "content": {"text": "test"}}] + return [{"role": "user", "content": [{"text": "test"}]}] @pytest.fixture @@ -90,8 +91,12 @@ def inference_config(): @pytest.fixture -def tool_spec(): - return {"t1": 1} +def tool_spec() -> ToolSpec: + return { + "description": "description", + "name": "name", + "inputSchema": {"key": "val"}, + } @pytest.fixture @@ -750,7 +755,7 @@ async def test_stream_output_no_guardrail_redact( @pytest.mark.asyncio -async def test_stream_with_streaming_false(bedrock_client, alist): +async 
def test_stream_with_streaming_false(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, @@ -759,8 +764,7 @@ async def test_stream_with_streaming_false(bedrock_client, alist): # Create model and call stream model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ @@ -776,7 +780,7 @@ async def test_stream_with_streaming_false(bedrock_client, alist): @pytest.mark.asyncio -async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist): +async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": { @@ -790,8 +794,7 @@ async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist): # Create model and call stream model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ @@ -808,7 +811,7 @@ async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist): @pytest.mark.asyncio -async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist): +async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": { @@ -828,8 +831,7 @@ async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist): # Create model and call stream model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = 
model.stream(messages) tru_events = await alist(response) exp_events = [ @@ -847,7 +849,7 @@ async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist): @pytest.mark.asyncio -async def test_stream_and_reasoning_no_signature(bedrock_client, alist): +async def test_stream_and_reasoning_no_signature(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": { @@ -867,8 +869,7 @@ async def test_stream_and_reasoning_no_signature(bedrock_client, alist): # Create model and call stream model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ @@ -884,7 +885,7 @@ async def test_stream_and_reasoning_no_signature(bedrock_client, alist): @pytest.mark.asyncio -async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client, alist): +async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, @@ -895,8 +896,7 @@ async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client # Create model and call stream model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ @@ -919,7 +919,7 @@ async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client @pytest.mark.asyncio -async def test_stream_input_guardrails(bedrock_client, alist): +async def test_stream_input_guardrails(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": {"message": {"role": 
"assistant", "content": [{"text": "test"}]}}, @@ -937,8 +937,7 @@ async def test_stream_input_guardrails(bedrock_client, alist): # Create model and call stream model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ @@ -970,7 +969,7 @@ async def test_stream_input_guardrails(bedrock_client, alist): @pytest.mark.asyncio -async def test_stream_output_guardrails(bedrock_client, alist): +async def test_stream_output_guardrails(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, @@ -989,8 +988,7 @@ async def test_stream_output_guardrails(bedrock_client, alist): } model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ @@ -1024,7 +1022,7 @@ async def test_stream_output_guardrails(bedrock_client, alist): @pytest.mark.asyncio -async def test_stream_output_guardrails_redacts_output(bedrock_client, alist): +async def test_stream_output_guardrails_redacts_output(bedrock_client, alist, messages): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, @@ -1043,8 +1041,7 @@ async def test_stream_output_guardrails_redacts_output(bedrock_client, alist): } model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - response = model.stream(request) + response = model.stream(messages) tru_events = await alist(response) exp_events = [ @@ -1101,7 +1098,7 @@ async def test_structured_output(bedrock_client, model, test_output_model_cls, a @pytest.mark.skipif(sys.version_info 
< (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") @pytest.mark.asyncio -async def test_add_note_on_client_error(bedrock_client, model, alist): +async def test_add_note_on_client_error(bedrock_client, model, alist, messages): """Test that add_note is called on ClientError with region and model ID information.""" # Mock the client error response error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}} @@ -1109,13 +1106,13 @@ async def test_add_note_on_client_error(bedrock_client, model, alist): # Call the stream method which should catch and add notes to the exception with pytest.raises(ClientError) as err: - await alist(model.stream({"modelId": "test-model"})) + await alist(model.stream(messages)) assert err.value.__notes__ == ["└ Bedrock region: us-west-2", "└ Model id: m1"] @pytest.mark.asyncio -async def test_no_add_note_when_not_available(bedrock_client, model, alist): +async def test_no_add_note_when_not_available(bedrock_client, model, alist, messages): """Verify that on any python version (even < 3.11 where add_note is not available, we get the right exception).""" # Mock the client error response error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}} @@ -1123,12 +1120,12 @@ async def test_no_add_note_when_not_available(bedrock_client, model, alist): # Call the stream method which should catch and add notes to the exception with pytest.raises(ClientError): - await alist(model.stream({"modelId": "test-model"})) + await alist(model.stream(messages)) @pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") @pytest.mark.asyncio -async def test_add_note_on_access_denied_exception(bedrock_client, model, alist): +async def test_add_note_on_access_denied_exception(bedrock_client, model, alist, messages): """Test that add_note adds documentation link for AccessDeniedException.""" # Mock the client error 
response for access denied error_response = { @@ -1142,7 +1139,7 @@ async def test_add_note_on_access_denied_exception(bedrock_client, model, alist) # Call the stream method which should catch and add notes to the exception with pytest.raises(ClientError) as err: - await alist(model.stream({"modelId": "test-model"})) + await alist(model.stream(messages)) assert err.value.__notes__ == [ "└ Bedrock region: us-west-2", @@ -1154,7 +1151,7 @@ async def test_add_note_on_access_denied_exception(bedrock_client, model, alist) @pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") @pytest.mark.asyncio -async def test_add_note_on_validation_exception_throughput(bedrock_client, model, alist): +async def test_add_note_on_validation_exception_throughput(bedrock_client, model, alist, messages): """Test that add_note adds documentation link for ValidationException about on-demand throughput.""" # Mock the client error response for validation exception error_response = { @@ -1170,7 +1167,7 @@ async def test_add_note_on_validation_exception_throughput(bedrock_client, model # Call the stream method which should catch and add notes to the exception with pytest.raises(ClientError) as err: - await alist(model.stream({"modelId": "test-model"})) + await alist(model.stream(messages)) assert err.value.__notes__ == [ "└ Bedrock region: us-west-2", @@ -1202,3 +1199,32 @@ async def test_stream_logging(bedrock_client, model, messages, caplog, alist): assert "invoking model" in log_text assert "got response from model" in log_text assert "finished streaming response from model" in log_text + + +def test_format_request_cleans_tool_result_content_blocks(model, model_id): + """Test that format_request cleans toolResult blocks by removing extra fields.""" + messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "content": [{"text": "Tool output"}], + "toolUseId": "tool123", + "status": "success", + "extraField": "should 
be removed", + "mcpMetadata": {"server": "test"}, + } + }, + ], + } + ] + + formatted_request = model.format_request(messages) + + # Verify toolResult only contains allowed fields in the formatted request + tool_result = formatted_request["messages"][0]["content"][0]["toolResult"] + expected = {"content": [{"text": "Tool output"}], "toolUseId": "tool123", "status": "success"} + assert tool_result == expected + assert "extraField" not in tool_result + assert "mcpMetadata" not in tool_result diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index 6a2fdd00c..3d3792c71 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -8,6 +8,7 @@ from mcp.types import Tool as MCPTool from strands.tools.mcp import MCPClient +from strands.tools.mcp.mcp_types import MCPToolResult from strands.types.exceptions import MCPClientInitializationError @@ -129,6 +130,8 @@ def test_call_tool_sync_status(mock_transport, mock_session, is_error, expected_ assert result["toolUseId"] == "test-123" assert len(result["content"]) == 1 assert result["content"][0]["text"] == "Test message" + # No structured content should be present when not provided by MCP + assert result.get("structuredContent") is None def test_call_tool_sync_session_not_active(): @@ -139,6 +142,31 @@ def test_call_tool_sync_session_not_active(): client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) +def test_call_tool_sync_with_structured_content(mock_transport, mock_session): + """Test that call_tool_sync correctly handles structured content.""" + mock_content = MCPTextContent(type="text", text="Test message") + structured_content = {"result": 42, "status": "completed"} + mock_session.call_tool.return_value = MCPCallToolResult( + isError=False, content=[mock_content], structuredContent=structured_content + ) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = 
client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) + + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None) + + assert result["status"] == "success" + assert result["toolUseId"] == "test-123" + # Content should only contain the text content, not the structured content + assert len(result["content"]) == 1 + assert result["content"][0]["text"] == "Test message" + # Structured content should be in its own field + assert "structuredContent" in result + assert result["structuredContent"] == structured_content + assert result["structuredContent"]["result"] == 42 + assert result["structuredContent"]["status"] == "completed" + + def test_call_tool_sync_exception(mock_transport, mock_session): """Test that call_tool_sync correctly handles exceptions.""" mock_session.call_tool.side_effect = Exception("Test exception") @@ -312,6 +340,45 @@ def test_enter_with_initialization_exception(mock_transport): client.start() +def test_mcp_tool_result_type(): + """Test that MCPToolResult extends ToolResult correctly.""" + # Test basic ToolResult functionality + result = MCPToolResult(status="success", toolUseId="test-123", content=[{"text": "Test message"}]) + + assert result["status"] == "success" + assert result["toolUseId"] == "test-123" + assert result["content"][0]["text"] == "Test message" + + # Test that structuredContent is optional + assert "structuredContent" not in result or result.get("structuredContent") is None + + # Test with structuredContent + result_with_structured = MCPToolResult( + status="success", toolUseId="test-456", content=[{"text": "Test message"}], structuredContent={"key": "value"} + ) + + assert result_with_structured["structuredContent"] == {"key": "value"} + + +def test_call_tool_sync_without_structured_content(mock_transport, mock_session): + """Test that call_tool_sync works correctly when no structured content is provided.""" + mock_content = MCPTextContent(type="text", 
text="Test message") + mock_session.call_tool.return_value = MCPCallToolResult( + isError=False, + content=[mock_content], # No structuredContent + ) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) + + assert result["status"] == "success" + assert result["toolUseId"] == "test-123" + assert len(result["content"]) == 1 + assert result["content"][0]["text"] == "Test message" + # structuredContent should be None when not provided by MCP + assert result.get("structuredContent") is None + + def test_exception_when_future_not_running(): """Test exception handling when the future is not running.""" # Create a client.with a mock transport diff --git a/tests_integ/echo_server.py b/tests_integ/echo_server.py index d309607a8..52223792c 100644 --- a/tests_integ/echo_server.py +++ b/tests_integ/echo_server.py @@ -2,7 +2,7 @@ Echo Server for MCP Integration Testing This module implements a simple echo server using the Model Context Protocol (MCP). -It provides a basic tool that echoes back any input string, which is useful for +It provides basic tools that echo back input strings and structured content, which is useful for testing the MCP communication flow and validating that messages are properly transmitted between the client and server. @@ -15,6 +15,8 @@ $ python echo_server.py """ +from typing import Any, Dict + from mcp.server import FastMCP @@ -22,16 +24,22 @@ def start_echo_server(): """ Initialize and start the MCP echo server. - Creates a FastMCP server instance with a single 'echo' tool that returns - any input string back to the caller. The server uses stdio transport + Creates a FastMCP server instance with tools that return + input strings and structured content back to the caller. The server uses stdio transport for communication. 
+ """ mcp = FastMCP("Echo Server") - @mcp.tool(description="Echos response back to the user") + @mcp.tool(description="Echos response back to the user", structured_output=False) def echo(to_echo: str) -> str: return to_echo + # FastMCP automatically constructs structured output schema from method signature + @mcp.tool(description="Echos response back with structured content", structured_output=True) + def echo_with_structured_content(to_echo: str) -> Dict[str, Any]: + return {"echoed": to_echo} + mcp.run(transport="stdio") diff --git a/tests_integ/test_mcp_client.py b/tests_integ/test_mcp_client.py index 9163f625d..ebd4f5896 100644 --- a/tests_integ/test_mcp_client.py +++ b/tests_integ/test_mcp_client.py @@ -1,4 +1,5 @@ import base64 +import json import os import threading import time @@ -87,6 +88,24 @@ def test_mcp_client(): ] ) + tool_use_id = "test-structured-content-123" + result = stdio_mcp_client.call_tool_sync( + tool_use_id=tool_use_id, + name="echo_with_structured_content", + arguments={"to_echo": "STRUCTURED_DATA_TEST"}, + ) + + # With the new MCPToolResult, structured content is in its own field + assert "structuredContent" in result + assert result["structuredContent"]["result"] == {"echoed": "STRUCTURED_DATA_TEST"} + + # Verify the result is an MCPToolResult (at runtime it's just a dict, but type-wise it should be MCPToolResult) + assert result["status"] == "success" + assert result["toolUseId"] == tool_use_id + + assert len(result["content"]) == 1 + assert json.loads(result["content"][0]["text"]) == {"echoed": "STRUCTURED_DATA_TEST"} + def test_can_reuse_mcp_client(): stdio_mcp_client = MCPClient( @@ -103,6 +122,64 @@ def test_can_reuse_mcp_client(): assert any([block["name"] == "echo" for block in tool_use_content_blocks]) +@pytest.mark.asyncio +async def test_mcp_client_async_structured_content(): + """Test that async MCP client calls properly handle structured content. 
+ + This test demonstrates how tools configure structured output: FastMCP automatically + constructs structured output schema from method signature when structured_output=True + is set in the @mcp.tool decorator. The return type annotation defines the structure + that appears in structuredContent field. + """ + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) + ) + + with stdio_mcp_client: + tool_use_id = "test-async-structured-content-456" + result = await stdio_mcp_client.call_tool_async( + tool_use_id=tool_use_id, + name="echo_with_structured_content", + arguments={"to_echo": "ASYNC_STRUCTURED_TEST"}, + ) + + # Verify structured content is in its own field + assert "structuredContent" in result + # "result" nesting is not part of the MCP Structured Content specification, + # but rather a FastMCP implementation detail + assert result["structuredContent"]["result"] == {"echoed": "ASYNC_STRUCTURED_TEST"} + + # Verify basic MCPToolResult structure + assert result["status"] in ["success", "error"] + assert result["toolUseId"] == tool_use_id + + assert len(result["content"]) == 1 + assert json.loads(result["content"][0]["text"]) == {"echoed": "ASYNC_STRUCTURED_TEST"} + + +def test_mcp_client_without_structured_content(): + """Test that MCP client works correctly when tools don't return structured content.""" + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) + ) + + with stdio_mcp_client: + tool_use_id = "test-no-structured-content-789" + result = stdio_mcp_client.call_tool_sync( + tool_use_id=tool_use_id, + name="echo", # This tool doesn't return structured content + arguments={"to_echo": "SIMPLE_ECHO_TEST"}, + ) + + # Verify no structured content when tool doesn't provide it + assert result.get("structuredContent") is None + + # Verify basic result structure + assert result["status"] == "success" + assert 
result["toolUseId"] == tool_use_id + assert result["content"] == [{"text": "SIMPLE_ECHO_TEST"}] + + @pytest.mark.skipif( condition=os.environ.get("GITHUB_ACTIONS") == "true", reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue", diff --git a/tests_integ/test_mcp_client_structured_content_with_hooks.py b/tests_integ/test_mcp_client_structured_content_with_hooks.py new file mode 100644 index 000000000..ca2468c48 --- /dev/null +++ b/tests_integ/test_mcp_client_structured_content_with_hooks.py @@ -0,0 +1,65 @@ +"""Integration test demonstrating hooks system with MCP client structured content tool. + +This test shows how to use the hooks system to capture and inspect tool invocation +results, specifically testing the echo_with_structured_content tool from echo_server. +""" + +import json + +from mcp import StdioServerParameters, stdio_client + +from strands import Agent +from strands.experimental.hooks import AfterToolInvocationEvent +from strands.hooks import HookProvider, HookRegistry +from strands.tools.mcp.mcp_client import MCPClient + + +class StructuredContentHookProvider(HookProvider): + """Hook provider that captures structured content tool results.""" + + def __init__(self): + self.captured_result = None + + def register_hooks(self, registry: HookRegistry) -> None: + """Register callback for after tool invocation events.""" + registry.add_callback(AfterToolInvocationEvent, self.on_after_tool_invocation) + + def on_after_tool_invocation(self, event: AfterToolInvocationEvent) -> None: + """Capture structured content tool results.""" + if event.tool_use["name"] == "echo_with_structured_content": + self.captured_result = event.result + + +def test_mcp_client_hooks_structured_content(): + """Test using hooks to inspect echo_with_structured_content tool result.""" + # Create hook provider to capture tool result + hook_provider = StructuredContentHookProvider() + + # Set up MCP client for echo server + stdio_mcp_client = 
MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"])) + ) + + with stdio_mcp_client: + # Create agent with MCP tools and hook provider + agent = Agent(tools=stdio_mcp_client.list_tools_sync(), hooks=[hook_provider]) + + # Test structured content functionality + test_data = "HOOKS_TEST_DATA" + agent(f"Use the echo_with_structured_content tool to echo: {test_data}") + + # Verify hook captured the tool result + assert hook_provider.captured_result is not None + result = hook_provider.captured_result + + # Verify basic result structure + assert result["status"] == "success" + assert len(result["content"]) == 1 + + # Verify structured content is present and correct + assert "structuredContent" in result + assert result["structuredContent"]["result"] == {"echoed": test_data} + + # Verify text content matches structured content + text_content = json.loads(result["content"][0]["text"]) + assert text_content == {"echoed": test_data}