31 changes: 16 additions & 15 deletions src/strands/agent/agent.py
@@ -50,7 +50,7 @@
from ..tools.executors._executor import ToolExecutor
from ..tools.registry import ToolRegistry
from ..tools.watcher import ToolWatcher
from ..types._events import InitEventLoopEvent
from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent
from ..types.agent import AgentInput
from ..types.content import ContentBlock, Message, Messages
from ..types.exceptions import ContextWindowOverflowException
@@ -576,23 +576,24 @@ async def stream_async(
events = self._run_loop(messages, invocation_state=kwargs)

async for event in events:
if "callback" in event:
callback_handler(**event["callback"])
yield event["callback"]
event.prepare(invocation_state=kwargs)

if event.is_callback_event:
as_dict = event.as_dict()
callback_handler(**as_dict)
yield as_dict

result = AgentResult(*event["stop"])
callback_handler(result=result)
yield {"result": result}
yield AgentResultEvent(result=result).as_dict()

self._end_agent_trace_span(response=result)

except Exception as e:
self._end_agent_trace_span(error=e)
raise

async def _run_loop(
self, messages: Messages, invocation_state: dict[str, Any]
) -> AsyncGenerator[dict[str, Any], None]:
async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]:
"""Execute the agent's event loop with the given message and parameters.

Args:
@@ -605,7 +606,7 @@ async def _run_loop(
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))

try:
yield InitEventLoopEvent(invocation_state)
yield InitEventLoopEvent()

for message in messages:
self._append_message(message)
Expand All @@ -616,13 +617,13 @@ async def _run_loop(
# Signal from the model provider that the message sent by the user should be redacted,
# likely due to a guardrail.
if (
event.get("callback")
and event["callback"].get("event")
and event["callback"]["event"].get("redactContent")
and event["callback"]["event"]["redactContent"].get("redactUserContentMessage")
isinstance(event, ModelStreamChunkEvent)
and event.chunk
and event.chunk.get("redactContent")
and event.chunk["redactContent"].get("redactUserContentMessage")
):
self.messages[-1]["content"] = [
{"text": event["callback"]["event"]["redactContent"]["redactUserContentMessage"]}
{"text": str(event.chunk["redactContent"]["redactUserContentMessage"])}
]
if self._session_manager:
self._session_manager.redact_latest_message(self.messages[-1], self)
@@ -632,7 +633,7 @@
self.conversation_manager.apply_management(self)
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))

async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]:
"""Execute the event loop cycle with retry logic for context window limits.

This internal method handles the execution of the event loop cycle and implements
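
Note: the reworked `stream_async` loop above depends on the `TypedEvent` interface from `src/strands/types/_events.py`, which is not part of this diff. A minimal sketch of the surface it appears to rely on, with `prepare`, `is_callback_event`, and `as_dict` inferred from the call sites (the real implementation may differ):

```python
# Sketch only: the real TypedEvent lives in src/strands/types/_events.py and is
# not shown in this diff. Method names mirror the call sites in stream_async.
from typing import Any


class TypedEvent(dict):
    """Dict-backed event, so callback handlers keep receiving plain dicts."""

    def __init__(self, data: dict[str, Any] | None = None) -> None:
        super().__init__(data or {})
        self._invocation_state: dict[str, Any] = {}

    def prepare(self, invocation_state: dict[str, Any]) -> None:
        # Hypothetical hook: attach the caller's kwargs so they can be merged
        # into the callback payload, preserving the old dict behavior.
        self._invocation_state = invocation_state

    @property
    def is_callback_event(self) -> bool:
        # Internal-only events (e.g. a stop signal) would override this.
        return True

    def as_dict(self) -> dict[str, Any]:
        return {**self, **self._invocation_state}


class AgentResultEvent(TypedEvent):
    def __init__(self, result: Any) -> None:
        super().__init__({"result": result})
```

Under an interface like this, `callback_handler(**as_dict)` and `yield as_dict` keep handing plain dicts to consumers, matching the old `event["callback"]` plumbing.
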
22 changes: 8 additions & 14 deletions src/strands/event_loop/event_loop.py
@@ -30,9 +30,11 @@
EventLoopThrottleEvent,
ForceStopEvent,
ModelMessageEvent,
ModelStopReason,
StartEvent,
StartEventLoopEvent,
ToolResultMessageEvent,
TypedEvent,
)
from ..types.content import Message
from ..types.exceptions import (
@@ -56,7 +58,7 @@
MAX_DELAY = 240 # 4 minutes


async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]:
"""Execute a single cycle of the event loop.

This core function processes a single conversation turn, handling model inference, tool execution, and error
@@ -139,17 +141,9 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
)

try:
# TODO: To maintain backwards compatibility, we need to combine the stream event with invocation_state
# before yielding to the callback handler. This will be revisited when migrating to strongly
# typed events.
async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs):
if "callback" in event:
yield {
"callback": {
**event["callback"],
**(invocation_state if "delta" in event["callback"] else {}),
}
}
if not isinstance(event, ModelStopReason):
yield event

stop_reason, message, usage, metrics = event["stop"]
invocation_state.setdefault("request_state", {})
@@ -198,7 +192,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
time.sleep(current_delay)
current_delay = min(current_delay * 2, MAX_DELAY)

yield EventLoopThrottleEvent(delay=current_delay, invocation_state=invocation_state)
yield EventLoopThrottleEvent(delay=current_delay)
else:
raise e

@@ -280,7 +274,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])


async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]:
"""Make a recursive call to event_loop_cycle with the current state.

This function is used when the event loop needs to continue processing after tool execution.
@@ -321,7 +315,7 @@ async def _handle_tool_execution(
cycle_span: Any,
cycle_start_time: float,
invocation_state: dict[str, Any],
) -> AsyncGenerator[dict[str, Any], None]:
) -> AsyncGenerator[TypedEvent, None]:
"""Handles the execution of tools requested by the model during an event loop cycle.

Args:
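
Note: `event_loop_cycle` now forwards every streamed event except the terminal `ModelStopReason`, then unpacks the stop tuple from that final event via `event["stop"]`. A rough sketch of the consumption pattern this assumes, with `ModelStopReason` stubbed out for illustration:

```python
# Sketch only: ModelStopReason is stubbed here; the real class lives in
# src/strands/types/_events.py. This mirrors the filtering loop above.
from typing import Any, AsyncGenerator


class TypedEvent(dict):
    """Minimal dict-backed event (see the note under agent.py)."""


class ModelStopReason(TypedEvent):
    def __init__(self, stop: tuple[Any, Any, Any, Any]) -> None:
        super().__init__({"stop": stop})


async def drain_model_stream(
    stream: AsyncGenerator[TypedEvent, None],
) -> tuple[list[TypedEvent], tuple[Any, ...]]:
    """Collect forwardable events and return them with the final stop tuple."""
    forwarded: list[TypedEvent] = []
    event: TypedEvent | None = None

    async for event in stream:
        if not isinstance(event, ModelStopReason):
            forwarded.append(event)  # these are what event_loop_cycle yields on

    # The last event from stream_messages is expected to be the stop event
    # carrying (stop_reason, message, usage, metrics).
    assert isinstance(event, ModelStopReason)
    return forwarded, event["stop"]
```

This also appears to be why `EventLoopThrottleEvent(delay=current_delay)` no longer takes `invocation_state`: merging the caller's kwargs into callback payloads is now handled by `event.prepare(...)` in `Agent.stream_async` rather than inside the loop.
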
23 changes: 13 additions & 10 deletions src/strands/tools/executors/_executor.py
@@ -7,15 +7,16 @@
import abc
import logging
import time
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, AsyncGenerator, cast

from opentelemetry import trace as trace_api

from ...experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent
from ...telemetry.metrics import Trace
from ...telemetry.tracer import get_tracer
from ...types._events import ToolResultEvent, ToolStreamEvent, TypedEvent
from ...types.content import Message
from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse
from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolResult, ToolUse

if TYPE_CHECKING: # pragma: no cover
from ...agent import Agent
@@ -33,7 +34,7 @@ async def _stream(
tool_results: list[ToolResult],
invocation_state: dict[str, Any],
**kwargs: Any,
) -> ToolGenerator:
) -> AsyncGenerator[TypedEvent, None]:
"""Stream tool events.

This method adds additional logic to the stream invocation including:
@@ -113,12 +114,12 @@ async def _stream(
result=result,
)
)
yield after_event.result
yield ToolResultEvent(after_event.result)
tool_results.append(after_event.result)
return

async for event in selected_tool.stream(tool_use, invocation_state, **kwargs):
yield event
yield ToolStreamEvent(tool_use, event)

result = cast(ToolResult, event)

@@ -131,7 +132,8 @@ async def _stream(
result=result,
)
)
yield after_event.result

yield ToolResultEvent(after_event.result)
tool_results.append(after_event.result)

except Exception as e:
@@ -151,7 +153,7 @@ async def _stream(
exception=e,
)
)
yield after_event.result
yield ToolResultEvent(after_event.result)
tool_results.append(after_event.result)

@staticmethod
@@ -163,7 +165,7 @@ async def _stream_with_trace(
cycle_span: Any,
invocation_state: dict[str, Any],
**kwargs: Any,
) -> ToolGenerator:
) -> AsyncGenerator[TypedEvent, None]:
"""Execute tool with tracing and metrics collection.

Args:
@@ -190,7 +192,8 @@ async def _stream_with_trace(
async for event in ToolExecutor._stream(agent, tool_use, tool_results, invocation_state, **kwargs):
yield event

result = cast(ToolResult, event)
result_event = cast(ToolResultEvent, event)
result = result_event.tool_result

tool_success = result.get("status") == "success"
tool_duration = time.time() - tool_start_time
@@ -210,7 +213,7 @@ def _execute(
cycle_trace: Trace,
cycle_span: Any,
invocation_state: dict[str, Any],
) -> ToolGenerator:
) -> AsyncGenerator[TypedEvent, None]:
"""Execute the given tools according to this executor's strategy.

Args:
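
Note: the executor now wraps every intermediate tool yield in `ToolStreamEvent` and the final `ToolResult` in `ToolResultEvent`, which `_stream_with_trace` unwraps via `.tool_result`. A hedged sketch of that wrapping pattern, with constructor shapes inferred from the call sites above (the real classes may differ):

```python
# Sketch only: ToolStreamEvent/ToolResultEvent shapes are inferred from the
# call sites above; the real classes live in src/strands/types/_events.py.
from typing import Any, AsyncGenerator


class ToolStreamEvent(dict):
    def __init__(self, tool_use: dict[str, Any], data: Any) -> None:
        super().__init__({"tool_stream": {"tool_use": tool_use, "data": data}})


class ToolResultEvent(dict):
    def __init__(self, tool_result: dict[str, Any]) -> None:
        super().__init__({"tool_result": tool_result})
        self.tool_result = tool_result  # read by _stream_with_trace


async def wrap_tool_stream(
    tool_use: dict[str, Any],
    raw_stream: AsyncGenerator[Any, None],
) -> AsyncGenerator[dict[str, Any], None]:
    last: Any = None
    async for last in raw_stream:
        # Every yield from the tool (including its final result, per the tool
        # contract) is surfaced as a ToolStreamEvent for progress reporting.
        yield ToolStreamEvent(tool_use, last)

    # The tool's final yield is its ToolResult; re-emit it as the terminal
    # ToolResultEvent so downstream code can record it.
    yield ToolResultEvent(last)
```

In the real `_stream`, the result that gets wrapped is `after_event.result` from the `AfterToolInvocationEvent` hook, not the raw final yield.
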
7 changes: 4 additions & 3 deletions src/strands/tools/executors/concurrent.py
@@ -1,12 +1,13 @@
"""Concurrent tool executor implementation."""

import asyncio
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, AsyncGenerator

from typing_extensions import override

from ...telemetry.metrics import Trace
from ...types.tools import ToolGenerator, ToolResult, ToolUse
from ...types._events import TypedEvent
from ...types.tools import ToolResult, ToolUse
from ._executor import ToolExecutor

if TYPE_CHECKING: # pragma: no cover
@@ -25,7 +26,7 @@ async def _execute(
cycle_trace: Trace,
cycle_span: Any,
invocation_state: dict[str, Any],
) -> ToolGenerator:
) -> AsyncGenerator[TypedEvent, None]:
"""Execute tools concurrently.

Args:
7 changes: 4 additions & 3 deletions src/strands/tools/executors/sequential.py
@@ -1,11 +1,12 @@
"""Sequential tool executor implementation."""

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, AsyncGenerator

from typing_extensions import override

from ...telemetry.metrics import Trace
from ...types.tools import ToolGenerator, ToolResult, ToolUse
from ...types._events import TypedEvent
from ...types.tools import ToolResult, ToolUse
from ._executor import ToolExecutor

if TYPE_CHECKING: # pragma: no cover
@@ -24,7 +25,7 @@ async def _execute(
cycle_trace: Trace,
cycle_span: Any,
invocation_state: dict[str, Any],
) -> ToolGenerator:
) -> AsyncGenerator[TypedEvent, None]:
"""Execute tools sequentially.

Args:
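
Note: for both the concurrent and sequential executors the only change is the `_execute` return annotation, from `ToolGenerator` to `AsyncGenerator[TypedEvent, None]`; custom executors subclassing `ToolExecutor` would update their signatures the same way. A minimal sketch under that assumption — the full `_execute` parameter list and the `_stream_with_trace` argument order are reconstructed from context and should be checked against `_executor.py`:

```python
# Sketch only: a custom executor under the new return annotation. Absolute
# imports correspond to the relative ones used by sequential.py above; the
# parameter list of _execute and the argument order of _stream_with_trace are
# assumptions reconstructed from the surrounding diff.
from typing import TYPE_CHECKING, Any, AsyncGenerator

from typing_extensions import override

from strands.telemetry.metrics import Trace
from strands.tools.executors._executor import ToolExecutor
from strands.types._events import TypedEvent
from strands.types.tools import ToolResult, ToolUse

if TYPE_CHECKING:  # pragma: no cover
    from strands.agent import Agent


class VerboseSequentialToolExecutor(ToolExecutor):
    """Runs tools one at a time, printing each tool name before it starts."""

    @override
    async def _execute(
        self,
        agent: "Agent",
        tool_uses: list[ToolUse],
        tool_results: list[ToolResult],
        cycle_trace: Trace,
        cycle_span: Any,
        invocation_state: dict[str, Any],
    ) -> AsyncGenerator[TypedEvent, None]:
        for tool_use in tool_uses:
            print(f"starting tool: {tool_use.get('name')}")
            events = ToolExecutor._stream_with_trace(
                agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state
            )
            async for event in events:
                yield event
```
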