Skip to content

Commit aedd41f

Browse files
committed
fix: rebase from main and address comments
1 parent 9f3a93a commit aedd41f

File tree

6 files changed

+43
-26
lines changed

6 files changed

+43
-26
lines changed

src/strands/multiagent/graph.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from ..experimental.hooks.multiagent import (
3030
AfterMultiAgentInvocationEvent,
3131
AfterNodeCallEvent,
32+
BeforeNodeCallEvent,
3233
MultiAgentInitializedEvent,
3334
)
3435
from ..hooks import HookProvider, HookRegistry
@@ -409,7 +410,6 @@ def __init__(
409410
reset_on_revisit: bool = False,
410411
session_manager: Optional[SessionManager] = None,
411412
hooks: Optional[list[HookProvider]] = None,
412-
*,
413413
id: str = _DEFAULT_GRAPH_ID,
414414
) -> None:
415415
"""Initialize Graph with execution limits and reset behavior.
@@ -430,7 +430,6 @@ def __init__(
430430

431431
# Validate nodes for duplicate instances
432432
self._validate_graph(nodes)
433-
self.id = id or _DEFAULT_GRAPH_ID
434433

435434
self.nodes = nodes
436435
self.edges = edges
@@ -451,6 +450,7 @@ def __init__(
451450

452451
self._resume_next_nodes: list[GraphNode] = []
453452
self._resume_from_session = False
453+
self.id = id
454454

455455
self.hooks.invoke_callbacks(MultiAgentInitializedEvent(self))
456456

@@ -773,7 +773,9 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
773773

774774
async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
775775
"""Execute a single node and yield TypedEvent objects."""
776-
# Reset the node's state if reset_on_revisit is enabled and it's being revisited
776+
self.hooks.invoke_callbacks(BeforeNodeCallEvent(self, node.node_id, invocation_state))
777+
778+
# Reset the node's state if reset_on_revisit is enabled, and it's being revisited
777779
if self.reset_on_revisit and node in self.state.completed_nodes:
778780
logger.debug("node_id=<%s> | resetting node state for revisit", node.node_id)
779781
node.reset_executor_state()
@@ -914,6 +916,9 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
914916
# Re-raise to stop graph execution (fail-fast behavior)
915917
raise
916918

919+
finally:
920+
self.hooks.invoke_callbacks(AfterNodeCallEvent(self, node.node_id, invocation_state))
921+
917922
def _accumulate_metrics(self, node_result: NodeResult) -> None:
918923
"""Accumulate metrics from a node result."""
919924
self.state.accumulated_usage["inputTokens"] += node_result.accumulated_usage.get("inputTokens", 0)
@@ -1001,13 +1006,12 @@ def _build_result(self) -> GraphResult:
10011006

10021007
def serialize_state(self) -> dict[str, Any]:
10031008
"""Serialize the current graph state to a dictionary."""
1004-
status_str = self.state.status.value
10051009
compute_nodes = self._compute_ready_nodes_for_resume()
10061010
next_nodes = [n.node_id for n in compute_nodes] if compute_nodes else []
10071011
return {
10081012
"type": "graph",
10091013
"id": self.id,
1010-
"status": status_str,
1014+
"status": self.state.status.value,
10111015
"completed_nodes": [n.node_id for n in self.state.completed_nodes],
10121016
"failed_nodes": [n.node_id for n in self.state.failed_nodes],
10131017
"node_results": {k: v.to_dict() for k, v in (self.state.results or {}).items()},
@@ -1020,10 +1024,10 @@ def deserialize_state(self, payload: dict[str, Any]) -> None:
10201024
"""Restore graph state from a session dict and prepare for execution.
10211025
10221026
This method handles two scenarios:
1023-
1. If the persisted status is COMPLETED, FAILED resets all nodes and graph state
1024-
to allow re-execution from the beginning.
1025-
2. Otherwise, restores the persisted state and prepares to resume execution
1026-
from the next ready nodes.
1027+
1. If the graph execution ended (no next_nodes_to_execute, e.g. Completed, or Failed with dead-end nodes),
1028+
resets all nodes and graph state to allow re-execution from the beginning.
1029+
2. If the graph execution was interrupted mid-execution (has next_nodes_to_execute),
1030+
restores the persisted state and prepares to resume execution from the next ready nodes.
10271031
10281032
Args:
10291033
payload: Dictionary containing persisted state data including status,
@@ -1041,7 +1045,6 @@ def deserialize_state(self, payload: dict[str, Any]) -> None:
10411045
self._from_dict(payload)
10421046
self._resume_from_session = True
10431047

1044-
# Helper functions for serialize and deserialize
10451048
def _compute_ready_nodes_for_resume(self) -> list[GraphNode]:
10461049
if self.state.status == Status.PENDING:
10471050
return []
@@ -1073,7 +1076,9 @@ def _from_dict(self, payload: dict[str, Any]) -> None:
10731076
raise
10741077
self.state.results = results
10751078

1076-
self.state.failed_nodes = set(payload.get("failed_nodes") or [])
1079+
self.state.failed_nodes = set(
1080+
self.nodes[node_id] for node_id in (payload.get("failed_nodes") or []) if node_id in self.nodes
1081+
)
10771082

10781083
# Restore completed nodes from persisted data
10791084
completed_node_ids = payload.get("completed_nodes") or []

src/strands/multiagent/swarm.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import logging
1919
import time
2020
from dataclasses import dataclass, field
21-
from typing import Any, AsyncIterator, Callable, Tuple, cast,Optional
21+
from typing import Any, AsyncIterator, Callable, Optional, Tuple, cast
2222

2323
from opentelemetry import trace as trace_api
2424

@@ -28,6 +28,7 @@
2828
from ..experimental.hooks.multiagent import (
2929
AfterMultiAgentInvocationEvent,
3030
AfterNodeCallEvent,
31+
BeforeNodeCallEvent,
3132
MultiAgentInitializedEvent,
3233
)
3334
from ..hooks import HookProvider, HookRegistry
@@ -210,8 +211,8 @@ class Swarm(MultiAgentBase):
210211

211212
def __init__(
212213
self,
213-
id:_DEFAULT_SWARM_ID,
214214
nodes: list[Agent],
215+
id: str = _DEFAULT_SWARM_ID,
215216
*,
216217
entry_point: Agent | None = None,
217218
max_handoffs: int = 20,
@@ -351,7 +352,6 @@ async def stream_async(
351352
self.state.completion_status = Status.EXECUTING
352353
self.state.start_time = time.time()
353354

354-
start_time = time.time()
355355
span = self.tracer.start_multiagent_span(task, "swarm")
356356
with trace_api.use_span(span, end_on_exit=True):
357357
try:
@@ -372,7 +372,7 @@ async def stream_async(
372372
self.state.completion_status = Status.FAILED
373373
raise
374374
finally:
375-
self.state.execution_time = round((time.time() - start_time) * 1000)
375+
self.state.execution_time = round((time.time() - self.state.start_time) * 1000)
376376
self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(self, invocation_state))
377377
self._resume_from_session = False
378378

@@ -685,6 +685,7 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
685685
# TODO: Implement cancellation token to stop _execute_node from continuing
686686
try:
687687
# Execute with timeout wrapper for async generator streaming
688+
self.hooks.invoke_callbacks(BeforeNodeCallEvent(self, current_node.node_id, invocation_state))
688689
node_stream = self._stream_with_timeout(
689690
self._execute_node(current_node, self.state.task, invocation_state),
690691
self.node_timeout,

tests/strands/experimental/hooks/multiagent/test_multi_agent_hooks.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def test_swarm_complete_hook_lifecycle(swarm, hook_provider):
6767
result = swarm("test task")
6868

6969
length, events = hook_provider.get_events()
70-
assert length == 3
70+
assert length == 4
7171
assert result.status.value == "completed"
7272

7373
events_list = list(events)
@@ -76,20 +76,24 @@ def test_swarm_complete_hook_lifecycle(swarm, hook_provider):
7676
assert isinstance(events_list[0], MultiAgentInitializedEvent)
7777
assert events_list[0].source == swarm
7878

79-
assert isinstance(events_list[1], AfterNodeCallEvent)
79+
assert isinstance(events_list[1], BeforeNodeCallEvent)
8080
assert events_list[1].source == swarm
8181
assert events_list[1].node_id == "agent1"
8282

83-
assert isinstance(events_list[2], AfterMultiAgentInvocationEvent)
83+
assert isinstance(events_list[2], AfterNodeCallEvent)
8484
assert events_list[2].source == swarm
85+
assert events_list[2].node_id == "agent1"
86+
87+
assert isinstance(events_list[3], AfterMultiAgentInvocationEvent)
88+
assert events_list[3].source == swarm
8589

8690

8791
def test_graph_complete_hook_lifecycle(graph, hook_provider):
8892
"""E2E test verifying complete hook lifecycle for Graph."""
8993
result = graph("test task")
9094

9195
length, events = hook_provider.get_events()
92-
assert length == 4
96+
assert length == 6
9397
assert result.status.value == "completed"
9498

9599
events_list = list(events)
@@ -98,13 +102,21 @@ def test_graph_complete_hook_lifecycle(graph, hook_provider):
98102
assert isinstance(events_list[0], MultiAgentInitializedEvent)
99103
assert events_list[0].source == graph
100104

101-
assert isinstance(events_list[1], AfterNodeCallEvent)
105+
assert isinstance(events_list[1], BeforeNodeCallEvent)
102106
assert events_list[1].source == graph
103107
assert events_list[1].node_id == "agent1"
104108

105109
assert isinstance(events_list[2], AfterNodeCallEvent)
106110
assert events_list[2].source == graph
107-
assert events_list[2].node_id == "agent2"
111+
assert events_list[2].node_id == "agent1"
108112

109-
assert isinstance(events_list[3], AfterMultiAgentInvocationEvent)
113+
assert isinstance(events_list[3], BeforeNodeCallEvent)
110114
assert events_list[3].source == graph
115+
assert events_list[3].node_id == "agent2"
116+
117+
assert isinstance(events_list[4], AfterNodeCallEvent)
118+
assert events_list[4].source == graph
119+
assert events_list[4].node_id == "agent2"
120+
121+
assert isinstance(events_list[5], AfterMultiAgentInvocationEvent)
122+
assert events_list[5].source == graph

tests/strands/multiagent/test_graph.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2010,6 +2010,7 @@ async def test_graph_persisted(mock_strands_tracer, mock_use_span):
20102010
persisted_state = {
20112011
"status": "executing",
20122012
"completed_nodes": [],
2013+
"failed_nodes": [],
20132014
"node_results": {},
20142015
"current_task": "persisted task",
20152016
"execution_order": [],

tests_integ/test_multiagent_graph.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1+
from typing import Any, AsyncIterator
12
from unittest.mock import patch
23
from uuid import uuid4
34

4-
from typing import Any, AsyncIterator
5-
65
import pytest
76

87
from strands import Agent, tool
@@ -14,7 +13,6 @@
1413
BeforeModelCallEvent,
1514
MessageAddedEvent,
1615
)
17-
from strands.multiagent.base import Status
1816
from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult, Status
1917
from strands.multiagent.graph import GraphBuilder
2018
from strands.session.file_session_manager import FileSessionManager

tests_integ/test_multiagent_swarm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
)
1616
from strands.multiagent.base import Status
1717
from strands.multiagent.swarm import Swarm
18+
from strands.session.file_session_manager import FileSessionManager
1819
from strands.types.content import ContentBlock
1920
from tests.fixtures.mock_hook_provider import MockHookProvider
20-
from strands.session.file_session_manager import FileSessionManager
2121

2222

2323
@tool

0 commit comments

Comments
 (0)