Skip to content

Commit 6d225cd

Browse files
committed
hooks - before node call - cancel node
1 parent 95ac650 commit 6d225cd

File tree

4 files changed

+60
-7
lines changed

4 files changed

+60
-7
lines changed

src/strands/experimental/hooks/multiagent/events.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,18 @@ class BeforeNodeCallEvent(BaseHookEvent):
3535
source: The multi-agent orchestrator instance
3636
node_id: ID of the node about to execute
3737
invocation_state: Configuration that user passes in
38+
cancel_node: A user defined message that when set, will cancel the node execution with status FAILED.
39+
The message will be emitted under a MultiAgentNodeCancel event. If set to `True`, Strands will cancel the
40+
node using a default cancel message.
3841
"""
3942

4043
source: "MultiAgentBase"
4144
node_id: str
4245
invocation_state: dict[str, Any] | None = None
46+
cancel_node: bool | str = False
47+
48+
def _can_write(self, name: str) -> bool:
49+
return name in ["cancel_node"]
4350

4451

4552
@dataclass

src/strands/multiagent/graph.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from ..telemetry import get_tracer
3939
from ..types._events import (
4040
MultiAgentHandoffEvent,
41+
MultiAgentNodeCancelEvent,
4142
MultiAgentNodeStartEvent,
4243
MultiAgentNodeStopEvent,
4344
MultiAgentNodeStreamEvent,
@@ -776,7 +777,6 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
776777

777778
async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
778779
"""Execute a single node and yield TypedEvent objects."""
779-
await self.hooks.invoke_callbacks_async(BeforeNodeCallEvent(self, node.node_id, invocation_state))
780780

781781
# Reset the node's state if reset_on_revisit is enabled, and it's being revisited
782782
if self.reset_on_revisit and node in self.state.completed_nodes:
@@ -795,6 +795,18 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
795795

796796
start_time = time.time()
797797
try:
798+
before_event, _ = await self.hooks.invoke_callbacks_async(BeforeNodeCallEvent(self, node.node_id, invocation_state))
799+
800+
if before_event.cancel_node:
801+
cancel_message = (
802+
before_event.cancel_node
803+
if isinstance(before_event.cancel_node, str)
804+
else "node cancelled by user"
805+
)
806+
logger.debug("reason=<%s> | cancelling execution", cancel_message)
807+
yield MultiAgentNodeCancelEvent(node.node_id, cancel_message)
808+
raise RuntimeError(cancel_message)
809+
798810
# Build node input from satisfied dependencies
799811
node_input = self._build_node_input(node)
800812

src/strands/multiagent/swarm.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from ..tools.decorator import tool
3939
from ..types._events import (
4040
MultiAgentHandoffEvent,
41+
MultiAgentNodeCancelEvent,
4142
MultiAgentNodeStartEvent,
4243
MultiAgentNodeStopEvent,
4344
MultiAgentNodeStreamEvent,
@@ -680,9 +681,21 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
680681

681682
# TODO: Implement cancellation token to stop _execute_node from continuing
682683
try:
683-
await self.hooks.invoke_callbacks_async(
684+
before_event, _ = await self.hooks.invoke_callbacks_async(
684685
BeforeNodeCallEvent(self, current_node.node_id, invocation_state)
685686
)
687+
688+
if before_event.cancel_node:
689+
cancel_message = (
690+
before_event.cancel_node
691+
if isinstance(before_event.cancel_node, str)
692+
else "node cancelled by user"
693+
)
694+
logger.debug("reason=<%s> | cancelling execution", cancel_message)
695+
yield MultiAgentNodeCancelEvent(current_node.node_id, cancel_message)
696+
self.state.completion_status = Status.FAILED
697+
break
698+
686699
node_stream = self._stream_with_timeout(
687700
self._execute_node(current_node, self.state.task, invocation_state),
688701
self.node_timeout,
@@ -692,6 +705,13 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
692705
yield event
693706

694707
self.state.node_history.append(current_node)
708+
709+
except Exception:
710+
logger.exception("node=<%s> | node execution failed", current_node.node_id)
711+
self.state.completion_status = Status.FAILED
712+
break
713+
714+
finally:
695715
await self.hooks.invoke_callbacks_async(
696716
AfterNodeCallEvent(self, current_node.node_id, invocation_state)
697717
)
@@ -723,11 +743,6 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
723743
self.state.completion_status = Status.COMPLETED
724744
break
725745

726-
except Exception:
727-
logger.exception("node=<%s> | node execution failed", current_node.node_id)
728-
self.state.completion_status = Status.FAILED
729-
break
730-
731746
except Exception:
732747
logger.exception("swarm execution failed")
733748
self.state.completion_status = Status.FAILED

src/strands/types/_events.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,3 +524,22 @@ def __init__(self, node_id: str, agent_event: dict[str, Any]) -> None:
524524
"event": agent_event, # Nest agent event to avoid field conflicts
525525
}
526526
)
527+
528+
529+
class MultiAgentNodeCancelEvent(TypedEvent):
530+
"""Event emitted when a user cancels node execution from their BeforeNodeCallEvent hook."""
531+
532+
def __init__(self, node_id: str, message: str) -> None:
533+
"""Initialize with cancel message.
534+
535+
Args:
536+
node_id: Unique identifier for the node.
537+
message: The node cancellation message.
538+
"""
539+
super().__init__(
540+
{
541+
"type": "multiagent_node_cancel",
542+
"node_id": node_id,
543+
"message": message,
544+
}
545+
)

0 commit comments

Comments
 (0)