2929from ..experimental .hooks .multiagent import (
3030 AfterMultiAgentInvocationEvent ,
3131 AfterNodeCallEvent ,
32+ BeforeNodeCallEvent ,
3233 MultiAgentInitializedEvent ,
3334)
3435from ..hooks import HookProvider , HookRegistry
@@ -409,7 +410,6 @@ def __init__(
409410 reset_on_revisit : bool = False ,
410411 session_manager : Optional [SessionManager ] = None ,
411412 hooks : Optional [list [HookProvider ]] = None ,
412- * ,
413413 id : str = _DEFAULT_GRAPH_ID ,
414414 ) -> None :
415415 """Initialize Graph with execution limits and reset behavior.
@@ -430,7 +430,6 @@ def __init__(
430430
431431 # Validate nodes for duplicate instances
432432 self ._validate_graph (nodes )
433- self .id = id or _DEFAULT_GRAPH_ID
434433
435434 self .nodes = nodes
436435 self .edges = edges
@@ -451,6 +450,7 @@ def __init__(
451450
452451 self ._resume_next_nodes : list [GraphNode ] = []
453452 self ._resume_from_session = False
453+ self .id = id
454454
455455 self .hooks .invoke_callbacks (MultiAgentInitializedEvent (self ))
456456
@@ -773,7 +773,9 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
773773
774774 async def _execute_node (self , node : GraphNode , invocation_state : dict [str , Any ]) -> AsyncIterator [Any ]:
775775 """Execute a single node and yield TypedEvent objects."""
776- # Reset the node's state if reset_on_revisit is enabled and it's being revisited
776+ self .hooks .invoke_callbacks (BeforeNodeCallEvent (self , node .node_id , invocation_state ))
777+
778+ # Reset the node's state if reset_on_revisit is enabled, and it's being revisited
777779 if self .reset_on_revisit and node in self .state .completed_nodes :
778780 logger .debug ("node_id=<%s> | resetting node state for revisit" , node .node_id )
779781 node .reset_executor_state ()
@@ -914,6 +916,9 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
914916 # Re-raise to stop graph execution (fail-fast behavior)
915917 raise
916918
919+ finally :
920+ self .hooks .invoke_callbacks (AfterNodeCallEvent (self , node .node_id , invocation_state ))
921+
917922 def _accumulate_metrics (self , node_result : NodeResult ) -> None :
918923 """Accumulate metrics from a node result."""
919924 self .state .accumulated_usage ["inputTokens" ] += node_result .accumulated_usage .get ("inputTokens" , 0 )
@@ -1001,13 +1006,12 @@ def _build_result(self) -> GraphResult:
10011006
10021007 def serialize_state (self ) -> dict [str , Any ]:
10031008 """Serialize the current graph state to a dictionary."""
1004- status_str = self .state .status .value
10051009 compute_nodes = self ._compute_ready_nodes_for_resume ()
10061010 next_nodes = [n .node_id for n in compute_nodes ] if compute_nodes else []
10071011 return {
10081012 "type" : "graph" ,
10091013 "id" : self .id ,
1010- "status" : status_str ,
1014+ "status" : self . state . status . value ,
10111015 "completed_nodes" : [n .node_id for n in self .state .completed_nodes ],
10121016 "failed_nodes" : [n .node_id for n in self .state .failed_nodes ],
10131017 "node_results" : {k : v .to_dict () for k , v in (self .state .results or {}).items ()},
@@ -1020,10 +1024,10 @@ def deserialize_state(self, payload: dict[str, Any]) -> None:
10201024 """Restore graph state from a session dict and prepare for execution.
10211025
10221026 This method handles two scenarios:
1023- 1. If the persisted status is COMPLETED, FAILED resets all nodes and graph state
1024- to allow re-execution from the beginning.
1025- 2. Otherwise, restores the persisted state and prepares to resume execution
1026- from the next ready nodes.
1027+ 1. If the graph execution ended (no next_nodes_to_execute, eg: Completed, or Failed with dead end nodes),
1028+ resets all nodes and graph state to allow re-execution from the beginning.
1029+ 2. If the graph execution was interrupted mid-execution (has next_nodes_to_execute),
1030+ restores the persisted state and prepares to resume execution from the next ready nodes.
10271031
10281032 Args:
10291033 payload: Dictionary containing persisted state data including status,
@@ -1041,7 +1045,6 @@ def deserialize_state(self, payload: dict[str, Any]) -> None:
10411045 self ._from_dict (payload )
10421046 self ._resume_from_session = True
10431047
1044- # Helper functions for serialize and deserialize
10451048 def _compute_ready_nodes_for_resume (self ) -> list [GraphNode ]:
10461049 if self .state .status == Status .PENDING :
10471050 return []
@@ -1073,7 +1076,9 @@ def _from_dict(self, payload: dict[str, Any]) -> None:
10731076 raise
10741077 self .state .results = results
10751078
1076- self .state .failed_nodes = set (payload .get ("failed_nodes" ) or [])
1079+ self .state .failed_nodes = set (
1080+ self .nodes [node_id ] for node_id in (payload .get ("failed_nodes" ) or []) if node_id in self .nodes
1081+ )
10771082
10781083 # Restore completed nodes from persisted data
10791084 completed_node_ids = payload .get ("completed_nodes" ) or []
0 commit comments