Skip to content

Commit 0a86928

Browse files
committed
feat: Implement typed events internally
Step 1/N for implementing typed-events; first just preserve the existing behaviors with no changes to the public api. A follow-up change will update how we invoke callbacks and pass invocation_state around, while this one just adds typed classes for events internally.
1 parent 0fac648 commit 0a86928

File tree

6 files changed

+445
-22
lines changed

6 files changed

+445
-22
lines changed

src/strands/agent/agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from ..tools.executors._executor import ToolExecutor
5151
from ..tools.registry import ToolRegistry
5252
from ..tools.watcher import ToolWatcher
53+
from ..types._events import InitEventLoopEvent
5354
from ..types.agent import AgentInput
5455
from ..types.content import ContentBlock, Message, Messages
5556
from ..types.exceptions import ContextWindowOverflowException
@@ -604,7 +605,7 @@ async def _run_loop(
604605
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
605606

606607
try:
607-
yield {"callback": {"init_event_loop": True, **invocation_state}}
608+
yield InitEventLoopEvent(invocation_state)
608609

609610
for message in messages:
610611
self._append_message(message)

src/strands/event_loop/event_loop.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@
2525
from ..telemetry.metrics import Trace
2626
from ..telemetry.tracer import get_tracer
2727
from ..tools._validator import validate_and_prepare_tools
28+
from ..types._events import (
29+
EventLoopStopEvent,
30+
EventLoopThrottleEvent,
31+
ForceStopEvent,
32+
ModelMessageEvent,
33+
StartEvent,
34+
StartEventLoopEvent,
35+
ToolResultMessageEvent,
36+
)
2837
from ..types.content import Message
2938
from ..types.exceptions import (
3039
ContextWindowOverflowException,
@@ -91,8 +100,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
91100
cycle_start_time, cycle_trace = agent.event_loop_metrics.start_cycle(attributes=attributes)
92101
invocation_state["event_loop_cycle_trace"] = cycle_trace
93102

94-
yield {"callback": {"start": True}}
95-
yield {"callback": {"start_event_loop": True}}
103+
yield StartEvent()
104+
yield StartEventLoopEvent()
96105

97106
# Create tracer span for this event loop cycle
98107
tracer = get_tracer()
@@ -175,7 +184,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
175184

176185
if isinstance(e, ModelThrottledException):
177186
if attempt + 1 == MAX_ATTEMPTS:
178-
yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}}
187+
yield ForceStopEvent(reason=e)
179188
raise e
180189

181190
logger.debug(
@@ -189,7 +198,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
189198
time.sleep(current_delay)
190199
current_delay = min(current_delay * 2, MAX_DELAY)
191200

192-
yield {"callback": {"event_loop_throttled_delay": current_delay, **invocation_state}}
201+
yield EventLoopThrottleEvent(delay=current_delay, invocation_state=invocation_state)
193202
else:
194203
raise e
195204

@@ -201,7 +210,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
201210
# Add the response message to the conversation
202211
agent.messages.append(message)
203212
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
204-
yield {"callback": {"message": message}}
213+
yield ModelMessageEvent(message=message)
205214

206215
# Update metrics
207216
agent.event_loop_metrics.update_usage(usage)
@@ -264,11 +273,11 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
264273
tracer.end_span_with_error(cycle_span, str(e), e)
265274

266275
# Handle any other exceptions
267-
yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}}
276+
yield ForceStopEvent(reason=e)
268277
logger.exception("cycle failed")
269278
raise EventLoopException(e, invocation_state["request_state"]) from e
270279

271-
yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])}
280+
yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
272281

273282

274283
async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
@@ -295,7 +304,7 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -
295304
recursive_trace = Trace("Recursive call", parent_id=cycle_trace.id)
296305
cycle_trace.add_child(recursive_trace)
297306

298-
yield {"callback": {"start": True}}
307+
yield StartEvent()
299308

300309
events = event_loop_cycle(agent=agent, invocation_state=invocation_state)
301310
async for event in events:
@@ -339,7 +348,7 @@ async def _handle_tool_execution(
339348
validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
340349
tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]
341350
if not tool_uses:
342-
yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])}
351+
yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
343352
return
344353

345354
tool_events = agent.tool_executor._execute(
@@ -358,15 +367,15 @@ async def _handle_tool_execution(
358367

359368
agent.messages.append(tool_result_message)
360369
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message))
361-
yield {"callback": {"message": tool_result_message}}
370+
yield ToolResultMessageEvent(message=message)
362371

363372
if cycle_span:
364373
tracer = get_tracer()
365374
tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message)
366375

367376
if invocation_state["request_state"].get("stop_event_loop", False):
368377
agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
369-
yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])}
378+
yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
370379
return
371380

372381
events = recurse_event_loop(agent=agent, invocation_state=invocation_state)

src/strands/event_loop/streaming.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@
55
from typing import Any, AsyncGenerator, AsyncIterable, Optional
66

77
from ..models.model import Model
8+
from ..types._events import (
9+
ModelStopReason,
10+
ModelStreamChunkEvent,
11+
ModelStreamEvent,
12+
ToolUseStreamEvent,
13+
TypedEvent,
14+
)
815
from ..types.content import ContentBlock, Message, Messages
916
from ..types.streaming import (
1017
ContentBlockDeltaEvent,
@@ -105,7 +112,7 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]:
105112

106113
def handle_content_block_delta(
107114
event: ContentBlockDeltaEvent, state: dict[str, Any]
108-
) -> tuple[dict[str, Any], dict[str, Any]]:
115+
) -> tuple[dict[str, Any], ModelStreamEvent]:
109116
"""Handles content block delta updates by appending text, tool input, or reasoning content to the state.
110117
111118
Args:
@@ -117,26 +124,26 @@ def handle_content_block_delta(
117124
"""
118125
delta_content = event["delta"]
119126

120-
callback_event = {}
127+
typed_event: ModelStreamEvent = ModelStreamEvent({})
121128

122129
if "toolUse" in delta_content:
123130
if "input" not in state["current_tool_use"]:
124131
state["current_tool_use"]["input"] = ""
125132

126133
state["current_tool_use"]["input"] += delta_content["toolUse"]["input"]
127-
callback_event["callback"] = {"delta": delta_content, "current_tool_use": state["current_tool_use"]}
134+
typed_event = ToolUseStreamEvent(delta_content, state["current_tool_use"])
128135

129136
elif "text" in delta_content:
130137
state["text"] += delta_content["text"]
131-
callback_event["callback"] = {"data": delta_content["text"], "delta": delta_content}
138+
typed_event["callback"] = {"data": delta_content["text"], "delta": delta_content}
132139

133140
elif "reasoningContent" in delta_content:
134141
if "text" in delta_content["reasoningContent"]:
135142
if "reasoningText" not in state:
136143
state["reasoningText"] = ""
137144

138145
state["reasoningText"] += delta_content["reasoningContent"]["text"]
139-
callback_event["callback"] = {
146+
typed_event["callback"] = {
140147
"reasoningText": delta_content["reasoningContent"]["text"],
141148
"delta": delta_content,
142149
"reasoning": True,
@@ -147,13 +154,13 @@ def handle_content_block_delta(
147154
state["signature"] = ""
148155

149156
state["signature"] += delta_content["reasoningContent"]["signature"]
150-
callback_event["callback"] = {
157+
typed_event["callback"] = {
151158
"reasoning_signature": delta_content["reasoningContent"]["signature"],
152159
"delta": delta_content,
153160
"reasoning": True,
154161
}
155162

156-
return state, callback_event
163+
return state, typed_event
157164

158165

159166
def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]:
@@ -251,7 +258,7 @@ def extract_usage_metrics(event: MetadataEvent) -> tuple[Usage, Metrics]:
251258
return usage, metrics
252259

253260

254-
async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[dict[str, Any], None]:
261+
async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[TypedEvent, None]:
255262
"""Processes the response stream from the API, constructing the final message and extracting usage metrics.
256263
257264
Args:
@@ -274,7 +281,7 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d
274281
metrics: Metrics = Metrics(latencyMs=0)
275282

276283
async for chunk in chunks:
277-
yield {"callback": {"event": chunk}}
284+
yield ModelStreamChunkEvent(chunk=chunk)
278285
if "messageStart" in chunk:
279286
state["message"] = handle_message_start(chunk["messageStart"], state["message"])
280287
elif "contentBlockStart" in chunk:
@@ -291,7 +298,7 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d
291298
elif "redactContent" in chunk:
292299
handle_redact_content(chunk["redactContent"], state)
293300

294-
yield {"stop": (stop_reason, state["message"], usage, metrics)}
301+
yield ModelStopReason(stop_reason=stop_reason, message=state["message"], usage=usage, metrics=metrics)
295302

296303

297304
async def stream_messages(

0 commit comments

Comments
 (0)