|
34 | 34 | MaxTokensReachedException, |
35 | 35 | ModelThrottledException, |
36 | 36 | ) |
| 37 | +from ..types.invocation import InvocationState |
37 | 38 | from ..types.streaming import Metrics, StopReason |
38 | 39 | from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse |
39 | 40 | from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached |
|
49 | 50 | MAX_DELAY = 240 # 4 minutes |
50 | 51 |
|
51 | 52 |
|
52 | | -async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: |
| 53 | +async def event_loop_cycle(agent: "Agent", invocation_state: InvocationState) -> AsyncGenerator[dict[str, Any], None]: |
53 | 54 | """Execute a single cycle of the event loop. |
54 | 55 |
|
55 | 56 | This core function processes a single conversation turn, handling model inference, tool execution, and error |
@@ -273,7 +274,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> |
273 | 274 | yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])} |
274 | 275 |
|
275 | 276 |
|
276 | | -async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: |
| 277 | +async def recurse_event_loop(agent: "Agent", invocation_state: InvocationState) -> AsyncGenerator[dict[str, Any], None]: |
277 | 278 | """Make a recursive call to event_loop_cycle with the current state. |
278 | 279 |
|
279 | 280 | This function is used when the event loop needs to continue processing after tool execution. |
@@ -306,7 +307,7 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) - |
306 | 307 | recursive_trace.end() |
307 | 308 |
|
308 | 309 |
|
309 | | -async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str, Any]) -> ToolGenerator: |
| 310 | +async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: InvocationState) -> ToolGenerator: |
310 | 311 | """Process a tool invocation. |
311 | 312 |
|
312 | 313 | Looks up the tool in the registry and streams it with the provided parameters. |
@@ -429,7 +430,7 @@ async def _handle_tool_execution( |
429 | 430 | cycle_trace: Trace, |
430 | 431 | cycle_span: Any, |
431 | 432 | cycle_start_time: float, |
432 | | - invocation_state: dict[str, Any], |
| 433 | + invocation_state: InvocationState, |
433 | 434 | ) -> AsyncGenerator[dict[str, Any], None]: |
434 | 435 | tool_uses: list[ToolUse] = [] |
435 | 436 | tool_results: list[ToolResult] = [] |
|
0 commit comments