Skip to content

Commit bd8b096

Browse files
committed
feat: switch to using MCPToolResult strategy
1 parent a87b1a5 commit bd8b096

File tree

7 files changed

+323
-32
lines changed

7 files changed

+323
-32
lines changed

src/strands/models/bedrock.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import json
88
import logging
99
import os
10-
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
10+
from typing import Any, AsyncGenerator, Callable, Iterable, List, Literal, Optional, Type, TypeVar, Union
1111

1212
import boto3
1313
from botocore.config import Config as BotocoreConfig
@@ -17,10 +17,10 @@
1717

1818
from ..event_loop import streaming
1919
from ..tools import convert_pydantic_to_tool_spec
20-
from ..types.content import Messages
20+
from ..types.content import ContentBlock, Message, Messages
2121
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
2222
from ..types.streaming import StreamEvent
23-
from ..types.tools import ToolSpec
23+
from ..types.tools import ToolResult, ToolSpec
2424
from .model import Model
2525

2626
logger = logging.getLogger(__name__)
@@ -181,7 +181,7 @@ def format_request(
181181
"""
182182
return {
183183
"modelId": self.config["model_id"],
184-
"messages": messages,
184+
"messages": self._clean_tool_result_content_blocks(messages),
185185
"system": [
186186
*([{"text": system_prompt}] if system_prompt else []),
187187
*([{"cachePoint": {"type": self.config["cache_prompt"]}}] if self.config.get("cache_prompt") else []),
@@ -246,6 +246,42 @@ def format_request(
246246
),
247247
}
248248

249+
def _clean_tool_result_content_blocks(self, messages: Messages) -> Messages:
250+
"""Additional fields may be added to ToolResult, like MCPToolResult. These can be useful for retaining
251+
information to be used later in hooks.
252+
253+
However, Bedrock will throw validation exceptions when presented with additional unexpected fields.
254+
https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html
255+
"""
256+
257+
cleaned_messages = []
258+
259+
for message in messages:
260+
cleaned_content: List[ContentBlock] = []
261+
262+
for content_block in message["content"]:
263+
if "toolResult" in content_block:
264+
# Create a new content block with only the cleaned toolResult
265+
cleaned_block: ContentBlock = content_block.copy()
266+
tool_result: ToolResult = content_block["toolResult"]
267+
268+
# Keep only the required fields for Bedrock
269+
cleaned_tool_result = ToolResult(
270+
content=tool_result["content"], toolUseId=tool_result["toolUseId"], status=tool_result["status"]
271+
)
272+
273+
cleaned_block["toolResult"] = cleaned_tool_result
274+
cleaned_content.append(cleaned_block)
275+
else:
276+
# Keep other content blocks as-is
277+
cleaned_content.append(content_block)
278+
279+
# Create new message with cleaned content
280+
cleaned_message: Message = Message(content=cleaned_content, role=message["role"])
281+
cleaned_messages.append(cleaned_message)
282+
283+
return cleaned_messages
284+
249285
def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
250286
"""Check if guardrail data contains any blocked policies.
251287

src/strands/tools/mcp/mcp_client.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626
from ...types import PaginatedList
2727
from ...types.exceptions import MCPClientInitializationError
2828
from ...types.media import ImageFormat
29-
from ...types.tools import ToolResult, ToolResultContent, ToolResultStatus
29+
from ...types.tools import ToolResultContent, ToolResultStatus
3030
from .mcp_agent_tool import MCPAgentTool
31-
from .mcp_types import MCPTransport
31+
from .mcp_types import MCPToolResult, MCPTransport
3232

3333
logger = logging.getLogger(__name__)
3434

@@ -171,7 +171,7 @@ def call_tool_sync(
171171
name: str,
172172
arguments: dict[str, Any] | None = None,
173173
read_timeout_seconds: timedelta | None = None,
174-
) -> ToolResult:
174+
) -> MCPToolResult:
175175
"""Synchronously calls a tool on the MCP server.
176176
177177
This method calls the asynchronous call_tool method on the MCP session
@@ -186,7 +186,7 @@ def call_tool_sync(
186186
read_timeout_seconds: Optional timeout for the tool call
187187
188188
Returns:
189-
ToolResult: The result of the tool call
189+
MCPToolResult: The result of the tool call
190190
"""
191191
self._log_debug_with_thread("calling MCP tool '%s' synchronously with tool_use_id=%s", name, tool_use_id)
192192
if not self._is_session_active():
@@ -208,13 +208,11 @@ async def call_tool_async(
208208
name: str,
209209
arguments: dict[str, Any] | None = None,
210210
read_timeout_seconds: timedelta | None = None,
211-
) -> ToolResult:
211+
) -> MCPToolResult:
212212
"""Asynchronously calls a tool on the MCP server.
213213
214214
This method calls the asynchronous call_tool method on the MCP session
215-
and converts the result to the ToolResult format. If the MCP tool returns
216-
structured content, it will be included as the last item in the content array
217-
of the returned ToolResult.
215+
and converts the result to the MCPToolResult format.
218216
219217
Args:
220218
tool_use_id: Unique identifier for this tool use
@@ -223,7 +221,7 @@ async def call_tool_async(
223221
read_timeout_seconds: Optional timeout for the tool call
224222
225223
Returns:
226-
ToolResult: The result of the tool call
224+
MCPToolResult: The result of the tool call
227225
"""
228226
self._log_debug_with_thread("calling MCP tool '%s' asynchronously with tool_use_id=%s", name, tool_use_id)
229227
if not self._is_session_active():
@@ -240,27 +238,26 @@ async def _call_tool_async() -> MCPCallToolResult:
240238
logger.exception("tool execution failed")
241239
return self._handle_tool_execution_error(tool_use_id, e)
242240

243-
def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -> ToolResult:
241+
def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -> MCPToolResult:
244242
"""Create error ToolResult with consistent logging."""
245-
return ToolResult(
243+
return MCPToolResult(
246244
status="error",
247245
toolUseId=tool_use_id,
248246
content=[{"text": f"Tool execution failed: {str(exception)}"}],
249247
)
250248

251-
def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolResult) -> ToolResult:
252-
"""Maps MCP tool result to the agent's ToolResult format.
249+
def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolResult) -> MCPToolResult:
250+
"""Maps MCP tool result to the agent's MCPToolResult format.
253251
254252
This method processes the content from the MCP tool call result and converts it to the format
255-
expected by the agent framework. If structured content is available in the MCP tool result,
256-
it will be appended as the last item in the content array of the returned ToolResult.
253+
expected by the framework.
257254
258255
Args:
259256
tool_use_id: Unique identifier for this tool use
260257
call_tool_result: The result from the MCP tool call
261258
262259
Returns:
263-
ToolResult: The converted tool result
260+
MCPToolResult: The converted tool result
264261
"""
265262
self._log_debug_with_thread("received tool result with %d content items", len(call_tool_result.content))
266263

@@ -270,12 +267,14 @@ def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolRes
270267
if (mapped_content := self._map_mcp_content_to_tool_result_content(content)) is not None
271268
]
272269

273-
if call_tool_result.structuredContent:
274-
mapped_content.append({"json": call_tool_result.structuredContent})
275-
276270
status: ToolResultStatus = "error" if call_tool_result.isError else "success"
277271
self._log_debug_with_thread("tool execution completed with status: %s", status)
278-
return ToolResult(status=status, toolUseId=tool_use_id, content=mapped_content)
272+
return MCPToolResult(
273+
status=status,
274+
toolUseId=tool_use_id,
275+
content=mapped_content,
276+
structuredContent=call_tool_result.structuredContent,
277+
)
279278

280279
async def _async_background_thread(self) -> None:
281280
"""Asynchronous method that runs in the background thread to manage the MCP connection.

src/strands/tools/mcp/mcp_types.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
"""Type definitions for MCP integration."""
22

33
from contextlib import AbstractAsyncContextManager
4+
from typing import Any, Dict
45

56
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
67
from mcp.client.streamable_http import GetSessionIdCallback
78
from mcp.shared.memory import MessageStream
89
from mcp.shared.message import SessionMessage
10+
from typing_extensions import NotRequired
11+
12+
from strands.types.tools import ToolResult
913

1014
"""
1115
MCPTransport defines the interface for MCP transport implementations. This abstracts
@@ -41,3 +45,19 @@ async def my_transport_implementation():
4145
MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage], GetSessionIdCallback
4246
]
4347
MCPTransport = AbstractAsyncContextManager[MessageStream | _MessageStreamWithGetSessionIdCallback]
48+
49+
50+
class MCPToolResult(ToolResult):
51+
"""Result of an MCP tool execution.
52+
53+
Extends the base ToolResult with MCP-specific structured content support.
54+
The structuredContent field contains optional JSON data returned by MCP tools
55+
that provides structured results beyond the standard text/image/document content.
56+
57+
Attributes:
58+
structuredContent: Optional JSON object containing structured data returned
59+
by the MCP tool. This allows MCP tools to return complex data structures
60+
that can be processed programmatically by agents or other tools.
61+
"""
62+
63+
structuredContent: NotRequired[Dict[str, Any]]

tests/strands/models/test_bedrock.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def model(bedrock_client, model_id):
5151

5252
@pytest.fixture
5353
def messages():
54-
return [{"role": "user", "content": {"text": "test"}}]
54+
return [{"role": "user", "content": [{"text": "test"}]}]
5555

5656

5757
@pytest.fixture
@@ -1202,3 +1202,37 @@ async def test_stream_logging(bedrock_client, model, messages, caplog, alist):
12021202
assert "invoking model" in log_text
12031203
assert "got response from model" in log_text
12041204
assert "finished streaming response from model" in log_text
1205+
1206+
1207+
def test_format_request_cleans_tool_result_content_blocks(model, model_id):
1208+
"""Test that format_request cleans toolResult blocks by removing extra fields."""
1209+
messages = [
1210+
{
1211+
"role": "user",
1212+
"content": [
1213+
{"text": "Hello"},
1214+
{
1215+
"toolResult": {
1216+
"content": [{"text": "Tool output"}],
1217+
"toolUseId": "tool-123",
1218+
"status": "success",
1219+
"extraField": "should be removed",
1220+
"mcpMetadata": {"server": "test"},
1221+
}
1222+
},
1223+
],
1224+
}
1225+
]
1226+
1227+
request = model.format_request(messages)
1228+
1229+
# Verify the request structure
1230+
assert request["modelId"] == model_id
1231+
assert "messages" in request
1232+
1233+
# Verify toolResult only contains allowed fields in the formatted request
1234+
tool_result = request["messages"][0]["content"][1]["toolResult"]
1235+
expected = {"content": [{"text": "Tool output"}], "toolUseId": "tool-123", "status": "success"}
1236+
assert tool_result == expected
1237+
assert "extraField" not in tool_result
1238+
assert "mcpMetadata" not in tool_result

0 commit comments

Comments
 (0)