Skip to content

Commit df69ea6

Browse files
committed
feat: add InvocationState TypedDict
1 parent ba59f6f commit df69ea6

File tree

10 files changed

+72
-16
lines changed

10 files changed

+72
-16
lines changed

src/strands/agent/agent.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from ..tools.watcher import ToolWatcher
3939
from ..types.content import ContentBlock, Message, Messages
4040
from ..types.exceptions import ContextWindowOverflowException
41+
from ..types.invocation import InvocationState
4142
from ..types.tools import ToolResult, ToolUse
4243
from ..types.traces import AttributeValue
4344
from .agent_result import AgentResult
@@ -523,7 +524,7 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
523524
raise
524525

525526
async def _run_loop(
526-
self, message: Message, invocation_state: dict[str, Any]
527+
self, message: Message, invocation_state: InvocationState
527528
) -> AsyncGenerator[dict[str, Any], None]:
528529
"""Execute the agent's event loop with the given message and parameters.
529530
@@ -563,7 +564,7 @@ async def _run_loop(
563564
self.conversation_manager.apply_management(self)
564565
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
565566

566-
async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
567+
async def _execute_event_loop_cycle(self, invocation_state: InvocationState) -> AsyncGenerator[dict[str, Any], None]:
567568
"""Execute the event loop cycle with retry logic for context window limits.
568569
569570
This internal method handles the execution of the event loop cycle and implements

src/strands/event_loop/event_loop.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
MaxTokensReachedException,
3535
ModelThrottledException,
3636
)
37+
from ..types.invocation import InvocationState
3738
from ..types.streaming import Metrics, StopReason
3839
from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse
3940
from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached
@@ -49,7 +50,7 @@
4950
MAX_DELAY = 240 # 4 minutes
5051

5152

52-
async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
53+
async def event_loop_cycle(agent: "Agent", invocation_state: InvocationState) -> AsyncGenerator[dict[str, Any], None]:
5354
"""Execute a single cycle of the event loop.
5455
5556
This core function processes a single conversation turn, handling model inference, tool execution, and error
@@ -273,7 +274,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
273274
yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])}
274275

275276

276-
async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
277+
async def recurse_event_loop(agent: "Agent", invocation_state: InvocationState) -> AsyncGenerator[dict[str, Any], None]:
277278
"""Make a recursive call to event_loop_cycle with the current state.
278279
279280
This function is used when the event loop needs to continue processing after tool execution.
@@ -306,7 +307,7 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -
306307
recursive_trace.end()
307308

308309

309-
async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str, Any]) -> ToolGenerator:
310+
async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: InvocationState) -> ToolGenerator:
310311
"""Process a tool invocation.
311312
312313
Looks up the tool in the registry and streams it with the provided parameters.
@@ -429,7 +430,7 @@ async def _handle_tool_execution(
429430
cycle_trace: Trace,
430431
cycle_span: Any,
431432
cycle_start_time: float,
432-
invocation_state: dict[str, Any],
433+
invocation_state: InvocationState,
433434
) -> AsyncGenerator[dict[str, Any], None]:
434435
tool_uses: list[ToolUse] = []
435436
tool_results: list[ToolResult] = []

src/strands/experimental/hooks/events.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from ...hooks import HookEvent
1010
from ...types.content import Message
11+
from ...types.invocation import InvocationState
1112
from ...types.streaming import StopReason
1213
from ...types.tools import AgentTool, ToolResult, ToolUse
1314

@@ -30,7 +31,7 @@ class BeforeToolInvocationEvent(HookEvent):
3031

3132
selected_tool: Optional[AgentTool]
3233
tool_use: ToolUse
33-
invocation_state: dict[str, Any]
34+
invocation_state: InvocationState
3435

3536
def _can_write(self, name: str) -> bool:
3637
return name in ["selected_tool", "tool_use"]
@@ -57,7 +58,7 @@ class AfterToolInvocationEvent(HookEvent):
5758

5859
selected_tool: Optional[AgentTool]
5960
tool_use: ToolUse
60-
invocation_state: dict[str, Any]
61+
invocation_state: InvocationState
6162
result: ToolResult
6263
exception: Optional[Exception] = None
6364

src/strands/telemetry/tracer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from opentelemetry.instrumentation.threading import ThreadingInstrumentor
1414
from opentelemetry.trace import Span, StatusCode
1515

16+
from ..types.invocation import InvocationState
17+
1618
from ..agent.agent_result import AgentResult
1719
from ..types.content import ContentBlock, Message, Messages
1820
from ..types.streaming import StopReason, Usage
@@ -343,7 +345,7 @@ def end_tool_call_span(
343345

344346
def start_event_loop_cycle_span(
345347
self,
346-
invocation_state: Any,
348+
invocation_state: InvocationState,
347349
messages: Messages,
348350
parent_span: Optional[Span] = None,
349351
**kwargs: Any,

src/strands/tools/decorator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
6161
from pydantic import BaseModel, Field, create_model
6262
from typing_extensions import override
6363

64+
from ..types.invocation import InvocationState
6465
from ..types.tools import AgentTool, JSONSchema, StrandsContext, ToolGenerator, ToolSpec, ToolUse
6566

6667
logger = logging.getLogger(__name__)
@@ -253,7 +254,7 @@ def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]:
253254
raise ValueError(f"Validation failed for input parameters: {error_msg}") from e
254255

255256
def inject_special_parameters(
256-
self, validated_input: dict[str, Any], tool_use: ToolUse, invocation_state: dict[str, Any]
257+
self, validated_input: dict[str, Any], tool_use: ToolUse, invocation_state: InvocationState
257258
) -> None:
258259
"""Inject special framework-provided parameters into the validated input.
259260
@@ -413,7 +414,7 @@ def tool_type(self) -> str:
413414
return "function"
414415

415416
@override
416-
async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator:
417+
async def stream(self, tool_use: ToolUse, invocation_state: InvocationState, **kwargs: Any) -> ToolGenerator:
417418
"""Stream the tool with a tool use specification.
418419
419420
This method handles tool use streams from a Strands Agent. It validates the input,

src/strands/tools/mcp/mcp_agent_tool.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from mcp.types import Tool as MCPTool
1212
from typing_extensions import override
1313

14+
from ...types.invocation import InvocationState
1415
from ...types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse
1516

1617
if TYPE_CHECKING:
@@ -75,7 +76,7 @@ def tool_type(self) -> str:
7576
return "python"
7677

7778
@override
78-
async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator:
79+
async def stream(self, tool_use: ToolUse, invocation_state: InvocationState, **kwargs: Any) -> ToolGenerator:
7980
"""Stream the MCP tool.
8081
8182
This method delegates the tool stream to the MCP server connection, passing the tool use ID, tool name, and

src/strands/tools/tools.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from typing_extensions import override
1414

15+
from ..types.invocation import InvocationState
1516
from ..types.tools import AgentTool, ToolFunc, ToolGenerator, ToolSpec, ToolUse
1617

1718
logger = logging.getLogger(__name__)
@@ -198,7 +199,7 @@ def tool_type(self) -> str:
198199
return "python"
199200

200201
@override
201-
async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator:
202+
async def stream(self, tool_use: ToolUse, invocation_state: InvocationState, **kwargs: Any) -> ToolGenerator:
202203
"""Stream the Python function with the given tool use request.
203204
204205
Args:
@@ -210,7 +211,7 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
210211
Tool events with the last being the tool result.
211212
"""
212213
if inspect.iscoroutinefunction(self._tool_func):
213-
result = await self._tool_func(tool_use, **invocation_state)
214+
result = await self._tool_func(tool_use, **invocation_state) # this will fail if invocation state and kwargs overlap
214215
else:
215216
result = await asyncio.to_thread(self._tool_func, tool_use, **invocation_state)
216217

src/strands/types/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""SDK type definitions."""
22

33
from .collections import PaginatedList
4+
from .invocation import InvocationState
45

56
__all__ = ["PaginatedList"]

src/strands/types/invocation.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""Types for agent invocation state and context."""
2+
3+
from typing import Any, Dict, TypedDict, TYPE_CHECKING
4+
5+
from opentelemetry.trace import Span
6+
7+
if TYPE_CHECKING:
8+
from ..agent import Agent
9+
from ..telemetry import Trace
10+
11+
class InvocationState(TypedDict, total=False):
12+
"""Type definition for invocation_state used throughout the agent framework.
13+
14+
This TypedDict defines the structure of the invocation_state dictionary that is
15+
passed through the agent's event loop and tool execution pipeline. All fields
16+
are optional since invocation_state is built incrementally during execution.
17+
18+
Core Framework Fields:
19+
agent: The Agent instance executing the invocation (added for backward compatibility).
20+
event_loop_cycle_id: Unique identifier for the current event loop cycle.
21+
request_state: State dictionary maintained across event loop cycles.
22+
event_loop_cycle_trace: Trace object for monitoring the current cycle.
23+
event_loop_cycle_span: Span object for distributed tracing.
24+
event_loop_parent_cycle_id: UUID of the parent cycle for recursive calls. # always uuid? or just string?
25+
event_loop_parent_span: Parent span for tracing hierarchy.
26+
27+
Additional Fields:
28+
Any additional keyword arguments passed during agent invocation or added
29+
by hooks and tools during execution are also included in this state.
30+
"""
31+
32+
# Core agent reference
33+
agent: "Agent" # Forward reference to avoid circular imports
34+
35+
# Event loop cycle management
36+
event_loop_cycle_id: str
37+
event_loop_parent_cycle_id: str
38+
39+
# State management
40+
request_state: Dict[str, Any]
41+
42+
# Tracing and monitoring
43+
event_loop_cycle_trace: "Trace" # "Trace" # Trace object type varies by implementation
44+
event_loop_cycle_span: Span # Span object type varies by implementation
45+
event_loop_parent_span: Span # Parent span for tracing hierarchy

src/strands/types/tools.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from abc import ABC, abstractmethod
99
from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union
1010

11+
from .invocation import InvocationState
12+
1113
from typing_extensions import TypedDict
1214

1315
from .media import DocumentContent, ImageContent
@@ -129,7 +131,7 @@ class StrandsContext(TypedDict, total=False):
129131
"""
130132

131133
tool_use: ToolUse
132-
invocation_state: dict[str, Any]
134+
invocation_state: InvocationState
133135

134136

135137
ToolChoice = Union[
@@ -231,7 +233,7 @@ def supports_hot_reload(self) -> bool:
231233

232234
@abstractmethod
233235
# pragma: no cover
234-
def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator:
236+
def stream(self, tool_use: ToolUse, invocation_state: InvocationState, **kwargs: Any) -> ToolGenerator:
235237
"""Stream tool events and return the final result.
236238
237239
Args:

0 commit comments

Comments
 (0)