Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ dependencies = [
"boto3>=1.26.0,<2.0.0",
"botocore>=1.29.0,<2.0.0",
"docstring_parser>=0.15,<1.0",
"mcp>=1.8.0,<2.0.0",
"mcp>=1.11.0,<2.0.0",
"pydantic>=2.0.0,<3.0.0",
"typing-extensions>=4.13.2,<5.0.0",
"watchdog>=6.0.0,<7.0.0",
Expand Down
53 changes: 50 additions & 3 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

from ..event_loop import streaming
from ..tools import convert_pydantic_to_tool_spec
from ..types.content import Messages
from ..types.content import ContentBlock, Message, Messages
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.streaming import StreamEvent
from ..types.tools import ToolSpec
from ..types.tools import ToolResult, ToolSpec
from .model import Model

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -181,7 +181,7 @@ def format_request(
"""
return {
"modelId": self.config["model_id"],
"messages": messages,
"messages": self._format_bedrock_messages(messages),
"system": [
*([{"text": system_prompt}] if system_prompt else []),
*([{"cachePoint": {"type": self.config["cache_prompt"]}}] if self.config.get("cache_prompt") else []),
Expand Down Expand Up @@ -246,6 +246,53 @@ def format_request(
),
}

def _format_bedrock_messages(self, messages: Messages) -> Messages:
"""Format messages for Bedrock API compatibility.

This function ensures messages conform to Bedrock's expected format by:
- Cleaning tool result content blocks by removing additional fields that may be
useful for retaining information in hooks but would cause Bedrock validation
exceptions when presented with unexpected fields
- Ensuring all message content blocks are properly formatted for the Bedrock API

Args:
messages: List of messages to format

Returns:
Messages formatted for Bedrock API compatibility

Note:
Bedrock will throw validation exceptions when presented with additional
unexpected fields in tool result blocks.
https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html
"""
cleaned_messages = []

for message in messages:
cleaned_content: list[ContentBlock] = []

for content_block in message["content"]:
if "toolResult" in content_block:
# Create a new content block with only the cleaned toolResult
tool_result: ToolResult = content_block["toolResult"]

# Keep only the required fields for Bedrock
cleaned_tool_result = ToolResult(
content=tool_result["content"], toolUseId=tool_result["toolUseId"], status=tool_result["status"]
)

cleaned_block: ContentBlock = {"toolResult": cleaned_tool_result}
cleaned_content.append(cleaned_block)
else:
# Keep other content blocks as-is
cleaned_content.append(content_block)

# Create new message with cleaned content
cleaned_message: Message = Message(content=cleaned_content, role=message["role"])
cleaned_messages.append(cleaned_message)

return cleaned_messages

def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
"""Check if guardrail data contains any blocked policies.

Expand Down
49 changes: 36 additions & 13 deletions src/strands/tools/mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
from ...types import PaginatedList
from ...types.exceptions import MCPClientInitializationError
from ...types.media import ImageFormat
from ...types.tools import ToolResult, ToolResultContent, ToolResultStatus
from ...types.tools import ToolResultContent, ToolResultStatus
from .mcp_agent_tool import MCPAgentTool
from .mcp_types import MCPTransport
from .mcp_types import MCPToolResult, MCPTransport

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -57,7 +57,8 @@ class MCPClient:
It handles the creation, initialization, and cleanup of MCP connections.

The connection runs in a background thread to avoid blocking the main application thread
while maintaining communication with the MCP service.
while maintaining communication with the MCP service. When structured content is available
from MCP tools, it will be returned as the last item in the content array of the ToolResult.
"""

def __init__(self, transport_callable: Callable[[], MCPTransport]):
Expand Down Expand Up @@ -170,11 +171,13 @@ def call_tool_sync(
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
) -> ToolResult:
) -> MCPToolResult:
"""Synchronously calls a tool on the MCP server.

This method calls the asynchronous call_tool method on the MCP session
and converts the result to the ToolResult format.
and converts the result to the ToolResult format. If the MCP tool returns
structured content, it will be included as the last item in the content array
of the returned ToolResult.

Args:
tool_use_id: Unique identifier for this tool use
Expand All @@ -183,7 +186,7 @@ def call_tool_sync(
read_timeout_seconds: Optional timeout for the tool call

Returns:
ToolResult: The result of the tool call
MCPToolResult: The result of the tool call
"""
self._log_debug_with_thread("calling MCP tool '%s' synchronously with tool_use_id=%s", name, tool_use_id)
if not self._is_session_active():
Expand All @@ -205,11 +208,11 @@ async def call_tool_async(
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
) -> ToolResult:
) -> MCPToolResult:
"""Asynchronously calls a tool on the MCP server.

This method calls the asynchronous call_tool method on the MCP session
and converts the result to the ToolResult format.
and converts the result to the MCPToolResult format.

Args:
tool_use_id: Unique identifier for this tool use
Expand All @@ -218,7 +221,7 @@ async def call_tool_async(
read_timeout_seconds: Optional timeout for the tool call

Returns:
ToolResult: The result of the tool call
MCPToolResult: The result of the tool call
"""
self._log_debug_with_thread("calling MCP tool '%s' asynchronously with tool_use_id=%s", name, tool_use_id)
if not self._is_session_active():
Expand All @@ -235,15 +238,27 @@ async def _call_tool_async() -> MCPCallToolResult:
logger.exception("tool execution failed")
return self._handle_tool_execution_error(tool_use_id, e)

def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -> ToolResult:
def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -> MCPToolResult:
"""Create error ToolResult with consistent logging."""
return ToolResult(
return MCPToolResult(
status="error",
toolUseId=tool_use_id,
content=[{"text": f"Tool execution failed: {str(exception)}"}],
)

def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolResult) -> ToolResult:
def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolResult) -> MCPToolResult:
"""Maps MCP tool result to the agent's MCPToolResult format.

This method processes the content from the MCP tool call result and converts it to the format
expected by the framework.

Args:
tool_use_id: Unique identifier for this tool use
call_tool_result: The result from the MCP tool call

Returns:
MCPToolResult: The converted tool result
"""
self._log_debug_with_thread("received tool result with %d content items", len(call_tool_result.content))

mapped_content = [
Expand All @@ -254,7 +269,15 @@ def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolRes

status: ToolResultStatus = "error" if call_tool_result.isError else "success"
self._log_debug_with_thread("tool execution completed with status: %s", status)
return ToolResult(status=status, toolUseId=tool_use_id, content=mapped_content)
result = MCPToolResult(
status=status,
toolUseId=tool_use_id,
content=mapped_content,
)
if call_tool_result.structuredContent:
result["structuredContent"] = call_tool_result.structuredContent

return result

async def _async_background_thread(self) -> None:
"""Asynchronous method that runs in the background thread to manage the MCP connection.
Expand Down
20 changes: 20 additions & 0 deletions src/strands/tools/mcp/mcp_types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
"""Type definitions for MCP integration."""

from contextlib import AbstractAsyncContextManager
from typing import Any, Dict

from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.client.streamable_http import GetSessionIdCallback
from mcp.shared.memory import MessageStream
from mcp.shared.message import SessionMessage
from typing_extensions import NotRequired

from strands.types.tools import ToolResult

"""
MCPTransport defines the interface for MCP transport implementations. This abstracts
Expand Down Expand Up @@ -41,3 +45,19 @@ async def my_transport_implementation():
MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage], GetSessionIdCallback
]
MCPTransport = AbstractAsyncContextManager[MessageStream | _MessageStreamWithGetSessionIdCallback]


class MCPToolResult(ToolResult):
"""Result of an MCP tool execution.

Extends the base ToolResult with MCP-specific structured content support.
The structuredContent field contains optional JSON data returned by MCP tools
that provides structured results beyond the standard text/image/document content.

Attributes:
structuredContent: Optional JSON object containing structured data returned
by the MCP tool. This allows MCP tools to return complex data structures
that can be processed programmatically by agents or other tools.
"""

structuredContent: NotRequired[Dict[str, Any]]
Loading
Loading