Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 75 additions & 14 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,8 +681,12 @@ def _stream(
logger.debug("got response from model")
if streaming:
response = self.client.converse_stream(**request)
# Track tool use events to fix stopReason for streaming responses
has_tool_use = False
# Track tool use/result events to fix stopReason for streaming responses
# We need to distinguish server-side tools (already executed) from client-side tools
tool_use_info: dict[str, str] = {} # toolUseId -> type (e.g., "server_tool_use")
tool_result_ids: set[str] = set() # IDs of tools with results
has_client_tools = False

for chunk in response["stream"]:
if (
"metadata" in chunk
Expand All @@ -694,22 +698,40 @@ def _stream(
for event in self._generate_redaction_events():
callback(event)

# Track if we see tool use events
if "contentBlockStart" in chunk and chunk["contentBlockStart"].get("start", {}).get("toolUse"):
has_tool_use = True
# Track tool use events with their types
if "contentBlockStart" in chunk:
tool_use_start = chunk["contentBlockStart"].get("start", {}).get("toolUse")
if tool_use_start:
tool_use_id = tool_use_start.get("toolUseId", "")
tool_type = tool_use_start.get("type", "")
tool_use_info[tool_use_id] = tool_type
# Check if it's a client-side tool (not server_tool_use)
if tool_type != "server_tool_use":
has_client_tools = True

# Track tool result events (for server-side tools that were already executed)
tool_result_start = chunk["contentBlockStart"].get("start", {}).get("toolResult")
if tool_result_start:
tool_result_ids.add(tool_result_start.get("toolUseId", ""))

# Fix stopReason for streaming responses that contain tool use
# BUT: Only override if there are client-side tools without results
if (
has_tool_use
and "messageStop" in chunk
"messageStop" in chunk
and (message_stop := chunk["messageStop"]).get("stopReason") == "end_turn"
):
# Create corrected chunk with tool_use stopReason
modified_chunk = chunk.copy()
modified_chunk["messageStop"] = message_stop.copy()
modified_chunk["messageStop"]["stopReason"] = "tool_use"
logger.warning("Override stop reason from end_turn to tool_use")
callback(modified_chunk)
# Check if we have client-side tools that need execution
needs_execution = has_client_tools and not set(tool_use_info.keys()).issubset(tool_result_ids)

if needs_execution:
# Create corrected chunk with tool_use stopReason
modified_chunk = chunk.copy()
modified_chunk["messageStop"] = message_stop.copy()
modified_chunk["messageStop"]["stopReason"] = "tool_use"
logger.warning("Override stop reason from end_turn to tool_use")
callback(modified_chunk)
else:
callback(chunk)
else:
callback(chunk)

Expand Down Expand Up @@ -771,6 +793,43 @@ def _stream(
callback()
logger.debug("finished streaming response from model")

def _has_client_side_tools_to_execute(self, message_content: list[dict[str, Any]]) -> bool:
"""Check if message contains client-side tools that need execution.

Server-side tools (like nova_grounding) are executed by Bedrock and include
toolResult blocks in the response. We should NOT override stopReason to
"tool_use" for these tools.

Args:
message_content: The content array from Bedrock response.

Returns:
True if there are client-side tools without results, False otherwise.
"""
tool_use_ids = set()
tool_result_ids = set()
has_client_tools = False

for content in message_content:
if "toolUse" in content:
tool_use = content["toolUse"]
tool_use_ids.add(tool_use["toolUseId"])

# Check if it's a server-side tool (Bedrock executes these)
if tool_use.get("type") != "server_tool_use":
has_client_tools = True

elif "toolResult" in content:
# Track which tools already have results
tool_result_ids.add(content["toolResult"]["toolUseId"])

# Only return True if there are client-side tools without results
if not has_client_tools:
return False

# Check if all tool uses have corresponding results
return not tool_use_ids.issubset(tool_result_ids)

def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
"""Convert a non-streaming response to the streaming format.

Expand Down Expand Up @@ -851,10 +910,12 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera

# Yield messageStop event
# Fix stopReason for models that return end_turn when they should return tool_use on non-streaming side
# BUT: Don't override for server-side tools (like nova_grounding) that are already executed
current_stop_reason = response["stopReason"]
if current_stop_reason == "end_turn":
message_content = response["output"]["message"]["content"]
if any("toolUse" in content for content in message_content):
# Only override if there are client-side tools that need execution
if self._has_client_side_tools_to_execute(message_content):
current_stop_reason = "tool_use"
logger.warning("Override stop reason from end_turn to tool_use")

Expand Down
167 changes: 167 additions & 0 deletions tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2070,3 +2070,170 @@ async def test_stream_backward_compatibility_system_prompt(bedrock_client, model
"system": [{"text": system_prompt}],
}
bedrock_client.converse_stream.assert_called_once_with(**expected_request)


def test_has_client_side_tools_to_execute_with_client_tools(model):
"""Test that client-side tools are correctly identified as needing execution."""
message_content = [
{
"toolUse": {
"toolUseId": "tool-123",
"name": "my_tool",
"input": {"param": "value"},
}
}
]

assert model._has_client_side_tools_to_execute(message_content) is True


def test_has_client_side_tools_to_execute_with_server_tools(model):
"""Test that server-side tools (like nova_grounding) are NOT identified as needing execution."""
message_content = [
{
"toolUse": {
"toolUseId": "tool-123",
"name": "nova_grounding",
"type": "server_tool_use",
"input": {},
}
},
{
"toolResult": {
"toolUseId": "tool-123",
"content": [{"text": "Grounding result"}],
}
},
]

assert model._has_client_side_tools_to_execute(message_content) is False


def test_has_client_side_tools_to_execute_with_mixed_tools(model):
"""Test mixed server and client tools - should return True if client tools need execution."""
message_content = [
# Server-side tool with result
{
"toolUse": {
"toolUseId": "server-tool-123",
"name": "nova_grounding",
"type": "server_tool_use",
"input": {},
}
},
{
"toolResult": {
"toolUseId": "server-tool-123",
"content": [{"text": "Grounding result"}],
}
},
# Client-side tool without result
{
"toolUse": {
"toolUseId": "client-tool-456",
"name": "my_tool",
"input": {"param": "value"},
}
},
]

assert model._has_client_side_tools_to_execute(message_content) is True


def test_has_client_side_tools_to_execute_with_no_tools(model):
"""Test that no tools returns False."""
message_content = [{"text": "Just some text"}]

assert model._has_client_side_tools_to_execute(message_content) is False


@pytest.mark.asyncio
async def test_stream_server_tool_use_does_not_override_stop_reason(bedrock_client, alist, messages):
"""Test that stopReason is NOT overridden for server-side tools like nova_grounding."""
model = BedrockModel(model_id="amazon.nova-premier-v1:0")
model.client = bedrock_client

# Simulate streaming response with server-side tool use and result
bedrock_client.converse_stream.return_value = {
"stream": [
{"messageStart": {"role": "assistant"}},
{
"contentBlockStart": {
"start": {
"toolUse": {
"toolUseId": "tool-123",
"name": "nova_grounding",
"type": "server_tool_use",
}
}
}
},
{"contentBlockDelta": {"delta": {"toolUse": {"input": "{}"}}}},
{"contentBlockStop": {}},
{
"contentBlockStart": {
"start": {
"toolResult": {
"toolUseId": "tool-123",
}
}
}
},
{"contentBlockDelta": {"delta": {"text": "Grounding result"}}},
{"contentBlockStop": {}},
{"contentBlockStart": {"start": {}}},
{"contentBlockDelta": {"delta": {"text": "Final response"}}},
{"contentBlockStop": {}},
{"messageStop": {"stopReason": "end_turn"}},
]
}

events = await alist(model.stream(messages))

# Find the messageStop event
message_stop_event = next(e for e in events if "messageStop" in e)

# Verify stopReason was NOT overridden (should remain end_turn for server-side tools)
assert message_stop_event["messageStop"]["stopReason"] == "end_turn"


@pytest.mark.asyncio
async def test_stream_non_streaming_server_tool_use_does_not_override_stop_reason(bedrock_client, alist, messages):
"""Test that stopReason is NOT overridden for server-side tools in non-streaming mode."""
model = BedrockModel(model_id="amazon.nova-premier-v1:0", streaming=False)
model.client = bedrock_client

bedrock_client.converse.return_value = {
"output": {
"message": {
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": "tool-123",
"name": "nova_grounding",
"type": "server_tool_use",
"input": {},
}
},
{
"toolResult": {
"toolUseId": "tool-123",
"content": [{"text": "Grounding result"}],
}
},
{"text": "Final response based on grounding"},
],
}
},
"stopReason": "end_turn",
"usage": {"inputTokens": 10, "outputTokens": 20},
}

events = await alist(model.stream(messages))

# Find the messageStop event
message_stop_event = next(e for e in events if "messageStop" in e)

# Verify stopReason was NOT overridden (should remain end_turn for server-side tools)
assert message_stop_event["messageStop"]["stopReason"] == "end_turn"