diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 7c63c1e89..b62501146 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -33,6 +33,7 @@ from .. import _identifier from .._async import run_async from ..event_loop.event_loop import event_loop_cycle +from ..tools._tool_helpers import generate_missing_tool_result_content if TYPE_CHECKING: from ..experimental.tools import ToolProvider @@ -280,7 +281,7 @@ def __init__( Defaults to None. session_manager: Manager for handling agent sessions including conversation history and state. If provided, enables session-based persistence and state management. - tool_executor: Definition of tool execution stragety (e.g., sequential, concurrent, etc.). + tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). Raises: ValueError: If agent id contains path separators. @@ -816,6 +817,21 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: messages: Messages | None = None if prompt is not None: + # Check if the latest message is toolUse + if len(self.messages) > 0 and any("toolUse" in content for content in self.messages[-1]["content"]): + # Add toolResult message after to have a valid conversation + logger.info( + "Agents latest message is toolUse, appending a toolResult message to have valid conversation." + ) + tool_use_ids = [ + content["toolUse"]["toolUseId"] for content in self.messages[-1]["content"] if "toolUse" in content + ] + self._append_message( + { + "role": "user", + "content": generate_missing_tool_result_content(tool_use_ids), + } + ) if isinstance(prompt, str): # String input - convert to user message messages = [{"role": "user", "content": [{"text": prompt}]}] diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index 86c6044a6..a042452d3 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Optional from ..agent.state import AgentState +from ..tools._tool_helpers import generate_missing_tool_result_content from ..types.content import Message from ..types.exceptions import SessionException from ..types.session import ( @@ -159,6 +160,50 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: # Restore the agents messages array including the optional prepend messages agent.messages = prepend_messages + [session_message.to_message() for session_message in session_messages] + # Fix broken session histories: https://github.com/strands-agents/sdk-python/issues/859 + agent.messages = self._fix_broken_tool_use(agent.messages) + + def _fix_broken_tool_use(self, messages: list[Message]) -> list[Message]: + """Add tool_result after orphaned tool_use messages. + + Before 1.15.0, strands had a bug where they persisted sessions with a potentially broken messages array. + This method retroactively fixes that issue by adding a tool_result outside of session management. After 1.15.0, + this bug is no longer present. + """ + for index, message in enumerate(messages): + # Check all but the latest message in the messages array + # The latest message being orphaned is handled in the agent class + if index + 1 < len(messages): + if any("toolUse" in content for content in message["content"]): + tool_use_ids = [ + content["toolUse"]["toolUseId"] for content in message["content"] if "toolUse" in content + ] + + # Check if there are more messages after the current toolUse message + tool_result_ids = [ + content["toolResult"]["toolUseId"] + for content in messages[index + 1]["content"] + if "toolResult" in content + ] + + missing_tool_use_ids = list(set(tool_use_ids) - set(tool_result_ids)) + # If there area missing tool use ids, that means the messages history is broken + if missing_tool_use_ids: + logger.warning( + "Session message history has an orphaned toolUse with no toolResult. " + "Adding toolResult content blocks to create valid conversation." + ) + # Create the missing toolResult content blocks + missing_content_blocks = generate_missing_tool_result_content(missing_tool_use_ids) + + if tool_result_ids: + # If there were any toolResult ids, that means only some of the content blocks are missing + messages[index + 1]["content"].extend(missing_content_blocks) + else: + # The message following the toolUse was not a toolResult, so lets insert it + messages.insert(index + 1, {"role": "user", "content": missing_content_blocks}) + return messages + def sync_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None: """Serialize and update the multi-agent state into the session repository. diff --git a/src/strands/tools/_tool_helpers.py b/src/strands/tools/_tool_helpers.py index d640f23b8..d023caeec 100644 --- a/src/strands/tools/_tool_helpers.py +++ b/src/strands/tools/_tool_helpers.py @@ -1,6 +1,7 @@ """Helpers for tools.""" -from strands.tools.decorator import tool +from ..tools.decorator import tool +from ..types.content import ContentBlock # https://github.com/strands-agents/sdk-python/issues/998 @@ -13,3 +14,17 @@ def noop_tool() -> None: summarization will fail. As a workaround, we register the no-op tool. """ pass + + +def generate_missing_tool_result_content(tool_use_ids: list[str]) -> list[ContentBlock]: + """Generate ToolResult content blocks for orphaned ToolUse message.""" + return [ + { + "toolResult": { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": "Tool was interrupted."}], + } + } + for tool_use_id in tool_use_ids + ] diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 52840f1a2..6c04c45c4 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -2215,3 +2215,143 @@ def test_redact_user_content(content, expected): agent = Agent() result = agent._redact_user_content(content, "REDACTED") assert result == expected + + +def test_agent_fixes_orphaned_tool_use_on_new_prompt(mock_model, agenerator): + """Test that agent adds toolResult for orphaned toolUse when called with new prompt.""" + mock_model.mock_stream.return_value = agenerator( + [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {"text": ""}}}, + {"contentBlockDelta": {"delta": {"text": "Fixed!"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ) + + # Start with orphaned toolUse message + messages = [ + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "orphaned-123", "name": "tool_decorated", "input": {"random_string": "test"}}} + ], + } + ] + + agent = Agent(model=mock_model, messages=messages) + + # Call with new prompt should fix orphaned toolUse + agent("Continue conversation") + + # Should have added toolResult message + assert len(agent.messages) >= 3 + assert agent.messages[1] == { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "orphaned-123", + "status": "error", + "content": [{"text": "Tool was interrupted."}], + } + } + ], + } + + +def test_agent_fixes_multiple_orphaned_tool_uses(mock_model, agenerator): + """Test that agent handles multiple orphaned toolUse messages.""" + mock_model.mock_stream.return_value = agenerator( + [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {"text": ""}}}, + {"contentBlockDelta": {"delta": {"text": "Fixed multiple!"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ) + + messages = [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "orphaned-123", + "name": "tool_decorated", + "input": {"random_string": "test1"}, + } + }, + { + "toolUse": { + "toolUseId": "orphaned-456", + "name": "tool_decorated", + "input": {"random_string": "test2"}, + } + }, + ], + } + ] + + agent = Agent(model=mock_model, messages=messages) + agent("Continue") + + # Should have toolResult for both toolUse IDs + assert agent.messages[1] == { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "orphaned-123", + "status": "error", + "content": [{"text": "Tool was interrupted."}], + } + }, + { + "toolResult": { + "toolUseId": "orphaned-456", + "status": "error", + "content": [{"text": "Tool was interrupted."}], + } + }, + ], + } + + +def test_agent_skips_fix_for_valid_conversation(mock_model, agenerator): + """Test that agent doesn't modify valid toolUse/toolResult pairs.""" + mock_model.mock_stream.return_value = agenerator( + [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {"text": ""}}}, + {"contentBlockDelta": {"delta": {"text": "No fix needed!"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ) + + # Valid conversation with toolUse followed by toolResult + messages = [ + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "valid-123", "name": "tool_decorated", "input": {"random_string": "test"}}} + ], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "valid-123", "status": "success", "content": [{"text": "result"}]}} + ], + }, + ] + + agent = Agent(model=mock_model, messages=messages) + original_length = len(agent.messages) + + agent("Continue") + + # Should not have added any toolResult messages + # Only the new user message and assistant response should be added + assert len(agent.messages) == original_length + 2 diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index e346f01e0..ed0ec9072 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -233,3 +233,183 @@ def test_initialize_multi_agent_existing(session_manager, mock_multi_agent): # Verify deserialize_state was called with existing state mock_multi_agent.deserialize_state.assert_called_once_with(existing_state) + + +def test_fix_broken_tool_use_adds_missing_tool_results(session_manager): + """Test that _fix_broken_tool_use adds missing toolResult messages.""" + conversation_manager = SlidingWindowConversationManager() + + # Create agent in repository first + session_agent = SessionAgent( + agent_id="existing-agent", + state={"key": "value"}, + conversation_manager_state=conversation_manager.get_state(), + ) + session_manager.session_repository.create_agent("test-session", session_agent) + + broken_messages = [ + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "orphaned-123", "name": "test_tool", "input": {"input": "test"}}}], + }, + {"role": "user", "content": [{"text": "Some other message"}]}, + ] + # Create some session messages + for index, broken_message in enumerate(broken_messages): + broken_session_message = SessionMessage( + message=broken_message, + message_id=index, + ) + session_manager.session_repository.create_message("test-session", "existing-agent", broken_session_message) + + # Initialize agent + agent = Agent(agent_id="existing-agent") + session_manager.initialize(agent) + + fixed_messages = agent.messages + + # Should insert toolResult message between toolUse and other message + assert len(fixed_messages) == 3 + assert "toolResult" in fixed_messages[1]["content"][0] + assert fixed_messages[1]["content"][0]["toolResult"]["toolUseId"] == "orphaned-123" + assert fixed_messages[1]["content"][0]["toolResult"]["status"] == "error" + assert fixed_messages[1]["content"][0]["toolResult"]["content"][0]["text"] == "Tool was interrupted." + + +def test_fix_broken_tool_use_extends_partial_tool_results(session_manager): + """Test fixing messages where some toolResults are missing.""" + conversation_manager = SlidingWindowConversationManager() + # Create agent in repository first + session_agent = SessionAgent( + agent_id="existing-agent", + state={"key": "value"}, + conversation_manager_state=conversation_manager.get_state(), + ) + session_manager.session_repository.create_agent("test-session", session_agent) + + broken_messages = [ + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "complete-123", "name": "test_tool", "input": {"input": "test1"}}}, + {"toolUse": {"toolUseId": "missing-456", "name": "test_tool", "input": {"input": "test2"}}}, + ], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "complete-123", "status": "success", "content": [{"text": "result"}]}} + ], + }, + ] + # Create some session messages + for index, broken_message in enumerate(broken_messages): + broken_session_message = SessionMessage( + message=broken_message, + message_id=index, + ) + session_manager.session_repository.create_message("test-session", "existing-agent", broken_session_message) + + # Initialize agent + agent = Agent(agent_id="existing-agent") + session_manager.initialize(agent) + + fixed_messages = agent.messages + + # Should add missing toolResult to existing message + assert len(fixed_messages) == 2 + assert len(fixed_messages[1]["content"]) == 2 + + tool_use_ids = {tr["toolResult"]["toolUseId"] for tr in fixed_messages[1]["content"]} + assert tool_use_ids == {"complete-123", "missing-456"} + + # Check the added toolResult has correct properties + missing_result = next(tr for tr in fixed_messages[1]["content"] if tr["toolResult"]["toolUseId"] == "missing-456") + assert missing_result["toolResult"]["status"] == "error" + assert missing_result["toolResult"]["content"][0]["text"] == "Tool was interrupted." + + +def test_fix_broken_tool_use_handles_multiple_orphaned_tools(session_manager): + """Test fixing multiple orphaned toolUse messages.""" + + conversation_manager = SlidingWindowConversationManager() + # Create agent in repository first + session_agent = SessionAgent( + agent_id="existing-agent", + state={"key": "value"}, + conversation_manager_state=conversation_manager.get_state(), + ) + session_manager.session_repository.create_agent("test-session", session_agent) + + broken_messages = [ + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "orphaned-123", "name": "test_tool", "input": {"input": "test1"}}}, + {"toolUse": {"toolUseId": "orphaned-456", "name": "test_tool", "input": {"input": "test2"}}}, + ], + }, + {"role": "user", "content": [{"text": "Next message"}]}, + ] + # Create some session messages + for index, broken_message in enumerate(broken_messages): + broken_session_message = SessionMessage( + message=broken_message, + message_id=index, + ) + session_manager.session_repository.create_message("test-session", "existing-agent", broken_session_message) + + # Initialize agent + agent = Agent(agent_id="existing-agent") + session_manager.initialize(agent) + + fixed_messages = agent.messages + + # Should insert message with both toolResults + assert len(fixed_messages) == 3 + assert len(fixed_messages[1]["content"]) == 2 + + tool_use_ids = {tr["toolResult"]["toolUseId"] for tr in fixed_messages[1]["content"]} + assert tool_use_ids == {"orphaned-123", "orphaned-456"} + + +def test_fix_broken_tool_use_ignores_last_message(session_manager): + """Test that orphaned toolUse in the last message is not fixed.""" + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "last-message-123", "name": "test_tool", "input": {"input": "test"}}} + ], + }, + ] + + fixed_messages = session_manager._fix_broken_tool_use(messages) + + # Should remain unchanged since toolUse is in last message + assert fixed_messages == messages + + +def test_fix_broken_tool_use_does_not_change_valid_message(session_manager): + """Test that orphaned toolUse in the last message is not fixed.""" + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "last-message-123", "name": "test_tool", "input": {"input": "test"}}} + ], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "last-message-123", "input": {"input": "test"}, "status": "success"}} + ], + }, + ] + + fixed_messages = session_manager._fix_broken_tool_use(messages) + + # Should remain unchanged since toolUse is in last message + assert fixed_messages == messages