Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions src/strands/multiagent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, AsyncIterator, Mapping, Union
from typing import Any, AsyncIterator, Mapping, Type, Union

from pydantic import BaseModel

from .._async import run_async
from ..agent import AgentResult
Expand Down Expand Up @@ -188,20 +190,30 @@ class MultiAgentBase(ABC):

@abstractmethod
async def invoke_async(
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
self,
task: MultiAgentInput,
invocation_state: dict[str, Any] | None = None,
structured_output_model: Type[BaseModel] | None = None,
**kwargs: Any,
) -> MultiAgentResult:
"""Invoke asynchronously.

Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues.
structured_output_model: Pydantic model to use for structured output from nodes.
Individual nodes may override this with their own default model.
**kwargs: Additional keyword arguments passed to underlying agents.
"""
raise NotImplementedError("invoke_async not implemented")

async def stream_async(
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
self,
task: MultiAgentInput,
invocation_state: dict[str, Any] | None = None,
structured_output_model: Type[BaseModel] | None = None,
**kwargs: Any,
) -> AsyncIterator[dict[str, Any]]:
"""Stream events during multi-agent execution.

Expand All @@ -212,6 +224,8 @@ async def stream_async(
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues.
structured_output_model: Pydantic model to use for structured output from nodes.
Individual nodes may override this with their own default model.
**kwargs: Additional keyword arguments passed to underlying agents.

Yields:
Expand All @@ -222,18 +236,24 @@ async def stream_async(
"""
# Default implementation for backward compatibility
# Execute invoke_async and yield the result as a single event
result = await self.invoke_async(task, invocation_state, **kwargs)
result = await self.invoke_async(task, invocation_state, structured_output_model, **kwargs)
yield {"result": result}

def __call__(
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
self,
task: MultiAgentInput,
invocation_state: dict[str, Any] | None = None,
structured_output_model: Type[BaseModel] | None = None,
**kwargs: Any,
) -> MultiAgentResult:
"""Invoke synchronously.

Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues.
structured_output_model: Pydantic model to use for structured output from nodes.
Individual nodes may override this with their own default model.
**kwargs: Additional keyword arguments passed to underlying agents.
"""
if invocation_state is None:
Expand All @@ -243,7 +263,7 @@ def __call__(
invocation_state.update(kwargs)
warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2)

return run_async(lambda: self.invoke_async(task, invocation_state))
return run_async(lambda: self.invoke_async(task, invocation_state, structured_output_model))

def serialize_state(self) -> dict[str, Any]:
"""Return a JSON-serializable snapshot of the orchestrator state."""
Expand Down
75 changes: 45 additions & 30 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
import logging
import time
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, Callable, Mapping, Optional, Tuple, cast
from typing import Any, AsyncIterator, Callable, Mapping, Optional, Tuple, Type, cast

from opentelemetry import trace as trace_api
from pydantic import BaseModel

from .._async import run_async
from ..agent import Agent
Expand Down Expand Up @@ -461,24 +462,12 @@ def __init__(

run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))

def __call__(
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> GraphResult:
"""Invoke the graph synchronously.

Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues.
**kwargs: Keyword arguments allowing backward compatible future changes.
"""
if invocation_state is None:
invocation_state = {}

return run_async(lambda: self.invoke_async(task, invocation_state))

async def invoke_async(
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
self,
task: MultiAgentInput,
invocation_state: dict[str, Any] | None = None,
structured_output_model: Type[BaseModel] | None = None,
**kwargs: Any,
) -> GraphResult:
"""Invoke the graph asynchronously.

Expand All @@ -489,9 +478,10 @@ async def invoke_async(
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues.
structured_output_model: Pydantic model to use for structured output from nodes.
**kwargs: Keyword arguments allowing backward compatible future changes.
"""
events = self.stream_async(task, invocation_state, **kwargs)
events = self.stream_async(task, invocation_state, structured_output_model, **kwargs)
final_event = None
async for event in events:
final_event = event
Expand All @@ -502,14 +492,19 @@ async def invoke_async(
return cast(GraphResult, final_event["result"])

async def stream_async(
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
self,
task: MultiAgentInput,
invocation_state: dict[str, Any] | None = None,
structured_output_model: Type[BaseModel] | None = None,
**kwargs: Any,
) -> AsyncIterator[dict[str, Any]]:
"""Stream events during graph execution.

Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues.
structured_output_model: Pydantic model to use for structured output from nodes.
**kwargs: Keyword arguments allowing backward compatible future changes.

Yields:
Expand Down Expand Up @@ -552,7 +547,7 @@ async def stream_async(
self.node_timeout or "None",
)

async for event in self._execute_graph(invocation_state):
async for event in self._execute_graph(invocation_state, structured_output_model):
yield event.as_dict()

# Set final status based on execution results
Expand Down Expand Up @@ -591,7 +586,9 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
# Validate Agent-specific constraints for each node
_validate_node_executor(node.executor)

async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
async def _execute_graph(
self, invocation_state: dict[str, Any], structured_output_model: Type[BaseModel] | None = None
) -> AsyncIterator[Any]:
"""Execute graph and yield TypedEvent objects."""
ready_nodes = self._resume_next_nodes if self._resume_from_session else list(self.entry_points)

Expand All @@ -610,7 +607,7 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato
ready_nodes.clear()

# Execute current batch
async for event in self._execute_nodes_parallel(current_batch, invocation_state):
async for event in self._execute_nodes_parallel(current_batch, invocation_state, structured_output_model):
yield event

# Find newly ready nodes after batch execution
Expand All @@ -634,7 +631,10 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato
ready_nodes.extend(newly_ready)

async def _execute_nodes_parallel(
self, nodes: list["GraphNode"], invocation_state: dict[str, Any]
self,
nodes: list["GraphNode"],
invocation_state: dict[str, Any],
structured_output_model: Type[BaseModel] | None = None,
) -> AsyncIterator[Any]:
"""Execute multiple nodes in parallel and merge their event streams in real-time.

Expand All @@ -644,7 +644,12 @@ async def _execute_nodes_parallel(
event_queue: asyncio.Queue[Any | None | Exception] = asyncio.Queue()

# Start all node streams as independent tasks
tasks = [asyncio.create_task(self._stream_node_to_queue(node, event_queue, invocation_state)) for node in nodes]
tasks = [
asyncio.create_task(
self._stream_node_to_queue(node, event_queue, invocation_state, structured_output_model)
)
for node in nodes
]

try:
# Consume events from the queue as they arrive
Expand Down Expand Up @@ -695,14 +700,15 @@ async def _stream_node_to_queue(
node: GraphNode,
event_queue: asyncio.Queue[Any | None | Exception],
invocation_state: dict[str, Any],
structured_output_model: Type[BaseModel] | None = None,
) -> None:
"""Stream events from a node to the shared queue with optional timeout."""
try:
# Apply timeout to the entire streaming process if configured
if self.node_timeout is not None:

async def stream_node() -> None:
async for event in self._execute_node(node, invocation_state):
async for event in self._execute_node(node, invocation_state, structured_output_model):
await event_queue.put(event)

try:
Expand All @@ -713,7 +719,7 @@ async def stream_node() -> None:
await event_queue.put(timeout_exc)
else:
# No timeout - stream normally
async for event in self._execute_node(node, invocation_state):
async for event in self._execute_node(node, invocation_state, structured_output_model):
await event_queue.put(event)
except Exception as e:
# Send exception through queue for fail-fast behavior
Expand Down Expand Up @@ -780,7 +786,12 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
)
return False

async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
async def _execute_node(
self,
node: GraphNode,
invocation_state: dict[str, Any],
structured_output_model: Type[BaseModel] | None = None,
) -> AsyncIterator[Any]:
"""Execute a single node and yield TypedEvent objects."""
# Reset the node's state if reset_on_revisit is enabled, and it's being revisited
if self.reset_on_revisit and node in self.state.completed_nodes:
Expand Down Expand Up @@ -818,7 +829,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
if isinstance(node.executor, MultiAgentBase):
# For nested multi-agent systems, stream their events and collect result
multi_agent_result = None
async for event in node.executor.stream_async(node_input, invocation_state):
async for event in node.executor.stream_async(node_input, invocation_state, structured_output_model):
# Forward nested multi-agent events with node context
wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event)
yield wrapped_event
Expand All @@ -842,7 +853,11 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
elif isinstance(node.executor, Agent):
# For agents, stream their events and collect result
agent_response = None
async for event in node.executor.stream_async(node_input, invocation_state=invocation_state):
# Use agent's own model if it has one, otherwise use graph-level model
effective_output_model = node.executor._default_structured_output_model or structured_output_model
async for event in node.executor.stream_async(
node_input, invocation_state=invocation_state, structured_output_model=effective_output_model
):
# Forward agent events with node context
wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event)
yield wrapped_event
Expand Down
Loading
Loading