diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 8233c4bfe..1e64f5adb 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -50,7 +50,7 @@ from ..tools.executors._executor import ToolExecutor from ..tools.registry import ToolRegistry from ..tools.watcher import ToolWatcher -from ..types._events import InitEventLoopEvent +from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent from ..types.agent import AgentInput from ..types.content import ContentBlock, Message, Messages from ..types.exceptions import ContextWindowOverflowException @@ -576,13 +576,16 @@ async def stream_async( events = self._run_loop(messages, invocation_state=kwargs) async for event in events: - if "callback" in event: - callback_handler(**event["callback"]) - yield event["callback"] + event.prepare(invocation_state=kwargs) + + if event.is_callback_event: + as_dict = event.as_dict() + callback_handler(**as_dict) + yield as_dict result = AgentResult(*event["stop"]) callback_handler(result=result) - yield {"result": result} + yield AgentResultEvent(result=result).as_dict() self._end_agent_trace_span(response=result) @@ -590,9 +593,7 @@ async def stream_async( self._end_agent_trace_span(error=e) raise - async def _run_loop( - self, messages: Messages, invocation_state: dict[str, Any] - ) -> AsyncGenerator[dict[str, Any], None]: + async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: """Execute the agent's event loop with the given message and parameters. Args: @@ -605,7 +606,7 @@ async def _run_loop( self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) try: - yield InitEventLoopEvent(invocation_state) + yield InitEventLoopEvent() for message in messages: self._append_message(message) @@ -616,13 +617,13 @@ async def _run_loop( # Signal from the model provider that the message sent by the user should be redacted, # likely due to a guardrail. if ( - event.get("callback") - and event["callback"].get("event") - and event["callback"]["event"].get("redactContent") - and event["callback"]["event"]["redactContent"].get("redactUserContentMessage") + isinstance(event, ModelStreamChunkEvent) + and event.chunk + and event.chunk.get("redactContent") + and event.chunk["redactContent"].get("redactUserContentMessage") ): self.messages[-1]["content"] = [ - {"text": event["callback"]["event"]["redactContent"]["redactUserContentMessage"]} + {"text": str(event.chunk["redactContent"]["redactUserContentMessage"])} ] if self._session_manager: self._session_manager.redact_latest_message(self.messages[-1], self) @@ -632,7 +633,7 @@ async def _run_loop( self.conversation_manager.apply_management(self) self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) - async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: + async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: """Execute the event loop cycle with retry logic for context window limits. This internal method handles the execution of the event loop cycle and implements diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index a166902eb..a99ecc8a6 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -30,9 +30,11 @@ EventLoopThrottleEvent, ForceStopEvent, ModelMessageEvent, + ModelStopReason, StartEvent, StartEventLoopEvent, ToolResultMessageEvent, + TypedEvent, ) from ..types.content import Message from ..types.exceptions import ( @@ -56,7 +58,7 @@ MAX_DELAY = 240 # 4 minutes -async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: +async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: """Execute a single cycle of the event loop. This core function processes a single conversation turn, handling model inference, tool execution, and error @@ -139,17 +141,9 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> ) try: - # TODO: To maintain backwards compatibility, we need to combine the stream event with invocation_state - # before yielding to the callback handler. This will be revisited when migrating to strongly - # typed events. async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs): - if "callback" in event: - yield { - "callback": { - **event["callback"], - **(invocation_state if "delta" in event["callback"] else {}), - } - } + if not isinstance(event, ModelStopReason): + yield event stop_reason, message, usage, metrics = event["stop"] invocation_state.setdefault("request_state", {}) @@ -198,7 +192,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> time.sleep(current_delay) current_delay = min(current_delay * 2, MAX_DELAY) - yield EventLoopThrottleEvent(delay=current_delay, invocation_state=invocation_state) + yield EventLoopThrottleEvent(delay=current_delay) else: raise e @@ -280,7 +274,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) -async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: +async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: """Make a recursive call to event_loop_cycle with the current state. This function is used when the event loop needs to continue processing after tool execution. @@ -321,7 +315,7 @@ async def _handle_tool_execution( cycle_span: Any, cycle_start_time: float, invocation_state: dict[str, Any], -) -> AsyncGenerator[dict[str, Any], None]: +) -> AsyncGenerator[TypedEvent, None]: """Handles the execution of tools requested by the model during an event loop cycle. Args: diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 9999b77fc..701a3bac0 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -7,15 +7,16 @@ import abc import logging import time -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, AsyncGenerator, cast from opentelemetry import trace as trace_api from ...experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent from ...telemetry.metrics import Trace from ...telemetry.tracer import get_tracer +from ...types._events import ToolResultEvent, ToolStreamEvent, TypedEvent from ...types.content import Message -from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse +from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolResult, ToolUse if TYPE_CHECKING: # pragma: no cover from ...agent import Agent @@ -33,7 +34,7 @@ async def _stream( tool_results: list[ToolResult], invocation_state: dict[str, Any], **kwargs: Any, - ) -> ToolGenerator: + ) -> AsyncGenerator[TypedEvent, None]: """Stream tool events. This method adds additional logic to the stream invocation including: @@ -113,12 +114,12 @@ async def _stream( result=result, ) ) - yield after_event.result + yield ToolResultEvent(after_event.result) tool_results.append(after_event.result) return async for event in selected_tool.stream(tool_use, invocation_state, **kwargs): - yield event + yield ToolStreamEvent(tool_use, event) result = cast(ToolResult, event) @@ -131,7 +132,8 @@ async def _stream( result=result, ) ) - yield after_event.result + + yield ToolResultEvent(after_event.result) tool_results.append(after_event.result) except Exception as e: @@ -151,7 +153,7 @@ async def _stream( exception=e, ) ) - yield after_event.result + yield ToolResultEvent(after_event.result) tool_results.append(after_event.result) @staticmethod @@ -163,7 +165,7 @@ async def _stream_with_trace( cycle_span: Any, invocation_state: dict[str, Any], **kwargs: Any, - ) -> ToolGenerator: + ) -> AsyncGenerator[TypedEvent, None]: """Execute tool with tracing and metrics collection. Args: @@ -190,7 +192,8 @@ async def _stream_with_trace( async for event in ToolExecutor._stream(agent, tool_use, tool_results, invocation_state, **kwargs): yield event - result = cast(ToolResult, event) + result_event = cast(ToolResultEvent, event) + result = result_event.tool_result tool_success = result.get("status") == "success" tool_duration = time.time() - tool_start_time @@ -210,7 +213,7 @@ def _execute( cycle_trace: Trace, cycle_span: Any, invocation_state: dict[str, Any], - ) -> ToolGenerator: + ) -> AsyncGenerator[TypedEvent, None]: """Execute the given tools according to this executor's strategy. Args: diff --git a/src/strands/tools/executors/concurrent.py b/src/strands/tools/executors/concurrent.py index 7d5dd7fe7..767071bae 100644 --- a/src/strands/tools/executors/concurrent.py +++ b/src/strands/tools/executors/concurrent.py @@ -1,12 +1,13 @@ """Concurrent tool executor implementation.""" import asyncio -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, AsyncGenerator from typing_extensions import override from ...telemetry.metrics import Trace -from ...types.tools import ToolGenerator, ToolResult, ToolUse +from ...types._events import TypedEvent +from ...types.tools import ToolResult, ToolUse from ._executor import ToolExecutor if TYPE_CHECKING: # pragma: no cover @@ -25,7 +26,7 @@ async def _execute( cycle_trace: Trace, cycle_span: Any, invocation_state: dict[str, Any], - ) -> ToolGenerator: + ) -> AsyncGenerator[TypedEvent, None]: """Execute tools concurrently. Args: diff --git a/src/strands/tools/executors/sequential.py b/src/strands/tools/executors/sequential.py index 55b26f6d3..60e5c7fa7 100644 --- a/src/strands/tools/executors/sequential.py +++ b/src/strands/tools/executors/sequential.py @@ -1,11 +1,12 @@ """Sequential tool executor implementation.""" -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, AsyncGenerator from typing_extensions import override from ...telemetry.metrics import Trace -from ...types.tools import ToolGenerator, ToolResult, ToolUse +from ...types._events import TypedEvent +from ...types.tools import ToolResult, ToolUse from ._executor import ToolExecutor if TYPE_CHECKING: # pragma: no cover @@ -24,7 +25,7 @@ async def _execute( cycle_trace: Trace, cycle_span: Any, invocation_state: dict[str, Any], - ) -> ToolGenerator: + ) -> AsyncGenerator[TypedEvent, None]: """Execute tools sequentially. Args: diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 1bddc5877..cc2330a81 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -5,15 +5,18 @@ agent lifecycle. """ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast + +from typing_extensions import override from ..telemetry import EventLoopMetrics from .content import Message from .event_loop import Metrics, StopReason, Usage from .streaming import ContentBlockDelta, StreamEvent +from .tools import ToolResult, ToolUse if TYPE_CHECKING: - pass + from ..agent import AgentResult class TypedEvent(dict): @@ -27,6 +30,23 @@ def __init__(self, data: dict[str, Any] | None = None) -> None: """ super().__init__(data or {}) + @property + def is_callback_event(self) -> bool: + """True if this event should trigger the callback_handler to fire.""" + return True + + def as_dict(self) -> dict: + """Convert this event to a raw dictionary for emitting purposes.""" + return {**self} + + def prepare(self, invocation_state: dict) -> None: + """Prepare the event for emission by adding invocation state. + + This allows a subset of events to merge with the invocation_state without needing to + pass around the invocation_state throughout the system. + """ + ... + class InitEventLoopEvent(TypedEvent): """Event emitted at the very beginning of agent execution. @@ -38,9 +58,13 @@ class InitEventLoopEvent(TypedEvent): invocation_state: The invocation state passed into the request """ - def __init__(self, invocation_state: dict) -> None: + def __init__(self) -> None: """Initialize the event loop initialization event.""" - super().__init__({"callback": {"init_event_loop": True, **invocation_state}}) + super().__init__({"init_event_loop": True}) + + @override + def prepare(self, invocation_state: dict) -> None: + self.update(invocation_state) class StartEvent(TypedEvent): @@ -55,7 +79,7 @@ class StartEvent(TypedEvent): def __init__(self) -> None: """Initialize the event loop start event.""" - super().__init__({"callback": {"start": True}}) + super().__init__({"start": True}) class StartEventLoopEvent(TypedEvent): @@ -67,7 +91,7 @@ class StartEventLoopEvent(TypedEvent): def __init__(self) -> None: """Initialize the event loop processing start event.""" - super().__init__({"callback": {"start_event_loop": True}}) + super().__init__({"start_event_loop": True}) class ModelStreamChunkEvent(TypedEvent): @@ -79,7 +103,11 @@ def __init__(self, chunk: StreamEvent) -> None: Args: chunk: Incremental streaming data from the model response """ - super().__init__({"callback": {"event": chunk}}) + super().__init__({"event": chunk}) + + @property + def chunk(self) -> StreamEvent: + return cast(StreamEvent, self.get("event")) class ModelStreamEvent(TypedEvent): @@ -97,13 +125,23 @@ def __init__(self, delta_data: dict[str, Any]) -> None: """ super().__init__(delta_data) + @property + def is_callback_event(self) -> bool: + # Only invoke a callback if we're non-empty + return len(self.keys()) > 0 + + @override + def prepare(self, invocation_state: dict) -> None: + if "delta" in self: + self.update(invocation_state) + class ToolUseStreamEvent(ModelStreamEvent): """Event emitted during tool use input streaming.""" def __init__(self, delta: ContentBlockDelta, current_tool_use: dict[str, Any]) -> None: """Initialize with delta and current tool use state.""" - super().__init__({"callback": {"delta": delta, "current_tool_use": current_tool_use}}) + super().__init__({"delta": delta, "current_tool_use": current_tool_use}) class TextStreamEvent(ModelStreamEvent): @@ -111,7 +149,7 @@ class TextStreamEvent(ModelStreamEvent): def __init__(self, delta: ContentBlockDelta, text: str) -> None: """Initialize with delta and text content.""" - super().__init__({"callback": {"data": text, "delta": delta}}) + super().__init__({"data": text, "delta": delta}) class ReasoningTextStreamEvent(ModelStreamEvent): @@ -119,7 +157,7 @@ class ReasoningTextStreamEvent(ModelStreamEvent): def __init__(self, delta: ContentBlockDelta, reasoning_text: str | None) -> None: """Initialize with delta and reasoning text.""" - super().__init__({"callback": {"reasoningText": reasoning_text, "delta": delta, "reasoning": True}}) + super().__init__({"reasoningText": reasoning_text, "delta": delta, "reasoning": True}) class ReasoningSignatureStreamEvent(ModelStreamEvent): @@ -127,7 +165,7 @@ class ReasoningSignatureStreamEvent(ModelStreamEvent): def __init__(self, delta: ContentBlockDelta, reasoning_signature: str | None) -> None: """Initialize with delta and reasoning signature.""" - super().__init__({"callback": {"reasoning_signature": reasoning_signature, "delta": delta, "reasoning": True}}) + super().__init__({"reasoning_signature": reasoning_signature, "delta": delta, "reasoning": True}) class ModelStopReason(TypedEvent): @@ -150,6 +188,11 @@ def __init__( """ super().__init__({"stop": (stop_reason, message, usage, metrics)}) + @property + @override + def is_callback_event(self) -> bool: + return False + class EventLoopStopEvent(TypedEvent): """Event emitted when the agent execution completes normally.""" @@ -171,18 +214,76 @@ def __init__( """ super().__init__({"stop": (stop_reason, message, metrics, request_state)}) + @property + @override + def is_callback_event(self) -> bool: + return False + class EventLoopThrottleEvent(TypedEvent): """Event emitted when the event loop is throttled due to rate limiting.""" - def __init__(self, delay: int, invocation_state: dict[str, Any]) -> None: + def __init__(self, delay: int) -> None: """Initialize with the throttle delay duration. Args: delay: Delay in seconds before the next retry attempt - invocation_state: The invocation state passed into the request """ - super().__init__({"callback": {"event_loop_throttled_delay": delay, **invocation_state}}) + super().__init__({"event_loop_throttled_delay": delay}) + + @override + def prepare(self, invocation_state: dict) -> None: + self.update(invocation_state) + + +class ToolResultEvent(TypedEvent): + """Event emitted when a tool execution completes.""" + + def __init__(self, tool_result: ToolResult) -> None: + """Initialize with the completed tool result. + + Args: + tool_result: Final result from the tool execution + """ + super().__init__({"tool_result": tool_result}) + + @property + def tool_use_id(self) -> str: + """The toolUseId associated with this result.""" + return cast(str, cast(ToolResult, self.get("tool_result")).get("toolUseId")) + + @property + def tool_result(self) -> ToolResult: + """Final result from the completed tool execution.""" + return cast(ToolResult, self.get("tool_result")) + + @property + @override + def is_callback_event(self) -> bool: + return False + + +class ToolStreamEvent(TypedEvent): + """Event emitted when a tool yields sub-events as part of tool execution.""" + + def __init__(self, tool_use: ToolUse, tool_sub_event: Any) -> None: + """Initialize with tool streaming data. + + Args: + tool_use: The tool invocation producing the stream + tool_sub_event: The yielded event from the tool execution + """ + super().__init__({"tool_stream_tool_use": tool_use, "tool_stream_event": tool_sub_event}) + + @property + def tool_use_id(self) -> str: + """The toolUseId associated with this stream.""" + return cast(str, cast(ToolUse, self.get("tool_stream_tool_use")).get("toolUseId")) + + @property + @override + def is_callback_event(self) -> bool: + return False class ModelMessageEvent(TypedEvent): @@ -198,7 +299,7 @@ def __init__(self, message: Message) -> None: Args: message: The response message from the model """ - super().__init__({"callback": {"message": message}}) + super().__init__({"message": message}) class ToolResultMessageEvent(TypedEvent): @@ -215,7 +316,7 @@ def __init__(self, message: Any) -> None: Args: message: Message containing tool results for conversation history """ - super().__init__({"callback": {"message": message}}) + super().__init__({"message": message}) class ForceStopEvent(TypedEvent): @@ -229,10 +330,12 @@ def __init__(self, reason: str | Exception) -> None: """ super().__init__( { - "callback": { - "force_stop": True, - "force_stop_reason": str(reason), - # "force_stop_reason_exception": reason if reason and isinstance(reason, Exception) else MISSING, - } + "force_stop": True, + "force_stop_reason": str(reason), } ) + + +class AgentResultEvent(TypedEvent): + def __init__(self, result: "AgentResult"): + super().__init__({"result": result}) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index a4a8af09a..a8561abe4 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -19,6 +19,7 @@ from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel from strands.session.repository_session_manager import RepositorySessionManager from strands.telemetry.tracer import serialize +from strands.types._events import EventLoopStopEvent, ModelStreamEvent from strands.types.content import Messages from strands.types.exceptions import ContextWindowOverflowException, EventLoopException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType @@ -406,7 +407,7 @@ async def check_invocation_state(**kwargs): assert invocation_state["agent"] == agent # Return expected values from event_loop_cycle - yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})} + yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {}) mock_event_loop_cycle.side_effect = check_invocation_state @@ -1144,12 +1145,12 @@ async def test_stream_async_returns_all_events(mock_event_loop_cycle, alist): # Define the side effect to simulate callback handler being called multiple times async def test_event_loop(*args, **kwargs): - yield {"callback": {"data": "First chunk"}} - yield {"callback": {"data": "Second chunk"}} - yield {"callback": {"data": "Final chunk", "complete": True}} + yield ModelStreamEvent({"data": "First chunk"}) + yield ModelStreamEvent({"data": "Second chunk"}) + yield ModelStreamEvent({"data": "Final chunk", "complete": True}) # Return expected values from event_loop_cycle - yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})} + yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {}) mock_event_loop_cycle.side_effect = test_event_loop mock_callback = unittest.mock.Mock() @@ -1234,7 +1235,7 @@ async def check_invocation_state(**kwargs): invocation_state = kwargs["invocation_state"] assert invocation_state["some_value"] == "a_value" # Return expected values from event_loop_cycle - yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})} + yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {}) mock_event_loop_cycle.side_effect = check_invocation_state @@ -1366,7 +1367,7 @@ async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_trac mock_get_tracer.return_value = mock_tracer async def test_event_loop(*args, **kwargs): - yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Agent Response"}]}, {}, {})} + yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Agent Response"}]}, {}, {}) mock_event_loop_cycle.side_effect = test_event_loop diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index c76514ac8..68f9cc5ab 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -486,7 +486,7 @@ async def test_cycle_exception( ] tru_stop_event = None - exp_stop_event = {"callback": {"force_stop": True, "force_stop_reason": "Invalid error presented"}} + exp_stop_event = {"force_stop": True, "force_stop_reason": "Invalid error presented"} with pytest.raises(EventLoopException): stream = strands.event_loop.event_loop.event_loop_cycle( diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index fd9548dae..fdd560b22 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -146,7 +146,7 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) ], ) def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_updated_state, callback_args): - exp_callback_event = {"callback": {**callback_args, "delta": event["delta"]}} if callback_args else {} + exp_callback_event = {**callback_args, "delta": event["delta"]} if callback_args else {} tru_updated_state, tru_callback_event = strands.event_loop.streaming.handle_content_block_delta(event, state) @@ -316,85 +316,71 @@ def test_extract_usage_metrics_with_cache_tokens(): ], [ { - "callback": { - "event": { - "messageStart": { - "role": "assistant", - }, + "event": { + "messageStart": { + "role": "assistant", }, }, }, { - "callback": { - "event": { - "contentBlockStart": { - "start": { - "toolUse": { - "name": "test", - "toolUseId": "123", - }, + "event": { + "contentBlockStart": { + "start": { + "toolUse": { + "name": "test", + "toolUseId": "123", }, }, }, }, }, { - "callback": { - "event": { - "contentBlockDelta": { - "delta": { - "toolUse": { - "input": '{"key": "value"}', - }, + "event": { + "contentBlockDelta": { + "delta": { + "toolUse": { + "input": '{"key": "value"}', }, }, }, }, }, { - "callback": { - "current_tool_use": { - "input": { - "key": "value", - }, - "name": "test", - "toolUseId": "123", + "current_tool_use": { + "input": { + "key": "value", }, - "delta": { - "toolUse": { - "input": '{"key": "value"}', - }, + "name": "test", + "toolUseId": "123", + }, + "delta": { + "toolUse": { + "input": '{"key": "value"}', }, }, }, { - "callback": { - "event": { - "contentBlockStop": {}, - }, + "event": { + "contentBlockStop": {}, }, }, { - "callback": { - "event": { - "messageStop": { - "stopReason": "tool_use", - }, + "event": { + "messageStop": { + "stopReason": "tool_use", }, }, }, { - "callback": { - "event": { - "metadata": { - "metrics": { - "latencyMs": 1, - }, - "usage": { - "inputTokens": 1, - "outputTokens": 1, - "totalTokens": 1, - }, + "event": { + "metadata": { + "metrics": { + "latencyMs": 1, + }, + "usage": { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, }, }, }, @@ -417,9 +403,7 @@ def test_extract_usage_metrics_with_cache_tokens(): [{}], [ { - "callback": { - "event": {}, - }, + "event": {}, }, { "stop": ( @@ -463,80 +447,64 @@ def test_extract_usage_metrics_with_cache_tokens(): ], [ { - "callback": { - "event": { - "messageStart": { - "role": "assistant", - }, + "event": { + "messageStart": { + "role": "assistant", }, }, }, { - "callback": { - "event": { - "contentBlockStart": { - "start": {}, - }, + "event": { + "contentBlockStart": { + "start": {}, }, }, }, { - "callback": { - "event": { - "contentBlockDelta": { - "delta": { - "text": "Hello!", - }, + "event": { + "contentBlockDelta": { + "delta": { + "text": "Hello!", }, }, }, }, { - "callback": { - "data": "Hello!", - "delta": { - "text": "Hello!", - }, + "data": "Hello!", + "delta": { + "text": "Hello!", }, }, { - "callback": { - "event": { - "contentBlockStop": {}, - }, + "event": { + "contentBlockStop": {}, }, }, { - "callback": { - "event": { - "messageStop": { - "stopReason": "guardrail_intervened", - }, + "event": { + "messageStop": { + "stopReason": "guardrail_intervened", }, }, }, { - "callback": { - "event": { - "redactContent": { - "redactAssistantContentMessage": "REDACTED.", - "redactUserContentMessage": "REDACTED", - }, + "event": { + "redactContent": { + "redactAssistantContentMessage": "REDACTED.", + "redactUserContentMessage": "REDACTED", }, }, }, { - "callback": { - "event": { - "metadata": { - "metrics": { - "latencyMs": 1, - }, - "usage": { - "inputTokens": 1, - "outputTokens": 1, - "totalTokens": 1, - }, + "event": { + "metadata": { + "metrics": { + "latencyMs": 1, + }, + "usage": { + "inputTokens": 1, + "outputTokens": 1, + "totalTokens": 1, }, }, }, @@ -588,29 +556,23 @@ async def test_stream_messages(agenerator, alist): tru_events = await alist(stream) exp_events = [ { - "callback": { - "event": { - "contentBlockDelta": { - "delta": { - "text": "test", - }, + "event": { + "contentBlockDelta": { + "delta": { + "text": "test", }, }, }, }, { - "callback": { - "data": "test", - "delta": { - "text": "test", - }, + "data": "test", + "delta": { + "text": "test", }, }, { - "callback": { - "event": { - "contentBlockStop": {}, - }, + "event": { + "contentBlockStop": {}, }, }, { diff --git a/tests/strands/tools/executors/test_concurrent.py b/tests/strands/tools/executors/test_concurrent.py index 7e0d6c2df..140537add 100644 --- a/tests/strands/tools/executors/test_concurrent.py +++ b/tests/strands/tools/executors/test_concurrent.py @@ -1,6 +1,8 @@ import pytest from strands.tools.executors import ConcurrentToolExecutor +from strands.types._events import ToolResultEvent, ToolStreamEvent +from strands.types.tools import ToolUse @pytest.fixture @@ -12,21 +14,21 @@ def executor(): async def test_concurrent_executor_execute( executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist ): - tool_uses = [ + tool_uses: list[ToolUse] = [ {"name": "weather_tool", "toolUseId": "1", "input": {}}, {"name": "temperature_tool", "toolUseId": "2", "input": {}}, ] stream = executor._execute(agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state) - tru_events = sorted(await alist(stream), key=lambda event: event.get("toolUseId")) + tru_events = sorted(await alist(stream), key=lambda event: event.tool_use_id) exp_events = [ - {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, - {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, - {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}, - {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}, + ToolStreamEvent(tool_uses[0], {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), + ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), + ToolStreamEvent(tool_uses[1], {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), + ToolResultEvent({"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), ] assert tru_events == exp_events tru_results = sorted(tool_results, key=lambda result: result.get("toolUseId")) - exp_results = [exp_events[1], exp_events[3]] + exp_results = [exp_events[1].tool_result, exp_events[3].tool_result] assert tru_results == exp_results diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index edbad3939..56caa950a 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -6,6 +6,8 @@ from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent from strands.telemetry.metrics import Trace from strands.tools.executors._executor import ToolExecutor +from strands.types._events import ToolResultEvent, ToolStreamEvent +from strands.types.tools import ToolUse @pytest.fixture @@ -32,18 +34,18 @@ def tracer(): async def test_executor_stream_yields_result( executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist ): - tool_use = {"name": "weather_tool", "toolUseId": "1", "input": {}} + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} stream = executor._stream(agent, tool_use, tool_results, invocation_state) tru_events = await alist(stream) exp_events = [ - {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, - {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, + ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), + ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ] assert tru_events == exp_events tru_results = tool_results - exp_results = [exp_events[-1]] + exp_results = [exp_events[-1].tool_result] assert tru_results == exp_results tru_hook_events = hook_events @@ -73,11 +75,11 @@ async def test_executor_stream_yields_tool_error( stream = executor._stream(agent, tool_use, tool_results, invocation_state) tru_events = await alist(stream) - exp_events = [{"toolUseId": "1", "status": "error", "content": [{"text": "Error: Tool error"}]}] + exp_events = [ToolResultEvent({"toolUseId": "1", "status": "error", "content": [{"text": "Error: Tool error"}]})] assert tru_events == exp_events tru_results = tool_results - exp_results = [exp_events[-1]] + exp_results = [exp_events[-1].tool_result] assert tru_results == exp_results tru_hook_after_event = hook_events[-1] @@ -98,11 +100,13 @@ async def test_executor_stream_yields_unknown_tool(executor, agent, tool_results stream = executor._stream(agent, tool_use, tool_results, invocation_state) tru_events = await alist(stream) - exp_events = [{"toolUseId": "1", "status": "error", "content": [{"text": "Unknown tool: unknown_tool"}]}] + exp_events = [ + ToolResultEvent({"toolUseId": "1", "status": "error", "content": [{"text": "Unknown tool: unknown_tool"}]}) + ] assert tru_events == exp_events tru_results = tool_results - exp_results = [exp_events[-1]] + exp_results = [exp_events[-1].tool_result] assert tru_results == exp_results tru_hook_after_event = hook_events[-1] @@ -120,18 +124,18 @@ async def test_executor_stream_yields_unknown_tool(executor, agent, tool_results async def test_executor_stream_with_trace( executor, tracer, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist ): - tool_use = {"name": "weather_tool", "toolUseId": "1", "input": {}} + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} stream = executor._stream_with_trace(agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state) tru_events = await alist(stream) exp_events = [ - {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, - {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, + ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), + ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ] assert tru_events == exp_events tru_results = tool_results - exp_results = [exp_events[-1]] + exp_results = [exp_events[-1].tool_result] assert tru_results == exp_results tracer.start_tool_call_span.assert_called_once_with(tool_use, cycle_span) diff --git a/tests/strands/tools/executors/test_sequential.py b/tests/strands/tools/executors/test_sequential.py index d9b32c129..d4e98223e 100644 --- a/tests/strands/tools/executors/test_sequential.py +++ b/tests/strands/tools/executors/test_sequential.py @@ -1,6 +1,7 @@ import pytest from strands.tools.executors import SequentialToolExecutor +from strands.types._events import ToolResultEvent, ToolStreamEvent @pytest.fixture @@ -20,13 +21,13 @@ async def test_sequential_executor_execute( tru_events = await alist(stream) exp_events = [ - {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, - {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}, - {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}, - {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}, + ToolStreamEvent(tool_uses[0], {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), + ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), + ToolStreamEvent(tool_uses[1], {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), + ToolResultEvent({"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), ] assert tru_events == exp_events tru_results = tool_results - exp_results = [exp_events[1], exp_events[2]] + exp_results = [exp_events[1].tool_result, exp_events[3].tool_result] assert tru_results == exp_results