Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion src/strands/multiagent/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def __init__(
self,
nodes: list[Agent],
*,
entry_point: Agent | None = None,
max_handoffs: int = 20,
max_iterations: int = 20,
execution_timeout: float = 900.0,
Expand All @@ -207,6 +208,7 @@ def __init__(

Args:
nodes: List of nodes (e.g. Agent) to include in the swarm
entry_point: Agent to start with. If None, uses the first agent (default: None)
max_handoffs: Maximum handoffs to agents and users (default: 20)
max_iterations: Maximum node executions within the swarm (default: 20)
execution_timeout: Total execution timeout in seconds (default: 900.0)
Expand All @@ -218,6 +220,7 @@ def __init__(
"""
super().__init__()

self.entry_point = entry_point
self.max_handoffs = max_handoffs
self.max_iterations = max_iterations
self.execution_timeout = execution_timeout
Expand Down Expand Up @@ -276,7 +279,11 @@ async def invoke_async(
logger.debug("starting swarm execution")

# Initialize swarm state with configuration
initial_node = next(iter(self.nodes.values())) # First SwarmNode
if self.entry_point:
initial_node = self.nodes[str(self.entry_point.name)]
else:
initial_node = next(iter(self.nodes.values())) # First SwarmNode

self.state = SwarmState(
current_node=initial_node,
task=task,
Expand Down Expand Up @@ -326,9 +333,28 @@ def _setup_swarm(self, nodes: list[Agent]) -> None:

self.nodes[node_id] = SwarmNode(node_id=node_id, executor=node)

# Validate entry point if specified
if self.entry_point is not None:
entry_point_node_id = str(self.entry_point.name)
if (
entry_point_node_id not in self.nodes
or self.nodes[entry_point_node_id].executor is not self.entry_point
):
available_agents = [
f"{node_id} ({type(node.executor).__name__})" for node_id, node in self.nodes.items()
]
raise ValueError(f"Entry point agent not found in swarm nodes. Available agents: {available_agents}")

swarm_nodes = list(self.nodes.values())
logger.debug("nodes=<%s> | initialized swarm with nodes", [node.node_id for node in swarm_nodes])

if self.entry_point:
entry_point_name = getattr(self.entry_point, "name", "unnamed_agent")
logger.debug("entry_point=<%s> | configured entry point", entry_point_name)
else:
first_node = next(iter(self.nodes.keys()))
logger.debug("entry_point=<%s> | using first node as entry point", first_node)

def _validate_swarm(self, nodes: list[Agent]) -> None:
"""Validate swarm structure and nodes."""
# Check for duplicate object instances
Expand Down
76 changes: 76 additions & 0 deletions tests/strands/multiagent/test_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,82 @@ def test_swarm_auto_completion_without_handoff():
no_handoff_agent.invoke_async.assert_called()


def test_swarm_configurable_entry_point():
"""Test swarm with configurable entry point."""
# Create multiple agents
agent1 = create_mock_agent("agent1", "Agent 1 response")
agent2 = create_mock_agent("agent2", "Agent 2 response")
agent3 = create_mock_agent("agent3", "Agent 3 response")

# Create swarm with agent2 as entry point
swarm = Swarm([agent1, agent2, agent3], entry_point=agent2)

# Verify entry point is set correctly
assert swarm.entry_point is agent2

# Execute swarm
result = swarm("Test task")

# Verify agent2 was the first to execute
assert result.status == Status.COMPLETED
assert len(result.node_history) == 1
assert result.node_history[0].node_id == "agent2"


def test_swarm_invalid_entry_point():
"""Test swarm with invalid entry point raises error."""
agent1 = create_mock_agent("agent1", "Agent 1 response")
agent2 = create_mock_agent("agent2", "Agent 2 response")
agent3 = create_mock_agent("agent3", "Agent 3 response") # Not in swarm

# Try to create swarm with agent not in the swarm
with pytest.raises(ValueError, match="Entry point agent not found in swarm nodes"):
Swarm([agent1, agent2], entry_point=agent3)


def test_swarm_default_entry_point():
"""Test swarm uses first agent as default entry point."""
agent1 = create_mock_agent("agent1", "Agent 1 response")
agent2 = create_mock_agent("agent2", "Agent 2 response")

# Create swarm without specifying entry point
swarm = Swarm([agent1, agent2])

# Verify no explicit entry point is set
assert swarm.entry_point is None

# Execute swarm
result = swarm("Test task")

# Verify first agent was used as entry point
assert result.status == Status.COMPLETED
assert len(result.node_history) == 1
assert result.node_history[0].node_id == "agent1"


def test_swarm_duplicate_agent_names():
"""Test swarm rejects agents with duplicate names."""
agent1 = create_mock_agent("duplicate_name", "Agent 1 response")
agent2 = create_mock_agent("duplicate_name", "Agent 2 response")

# Try to create swarm with duplicate names
with pytest.raises(ValueError, match="Node ID 'duplicate_name' is not unique"):
Swarm([agent1, agent2])


def test_swarm_entry_point_same_name_different_object():
"""Test entry point validation with same name but different object."""
agent1 = create_mock_agent("agent1", "Agent 1 response")
agent2 = create_mock_agent("agent2", "Agent 2 response")

# Create a different agent with same name as agent1
different_agent_same_name = create_mock_agent("agent1", "Different agent response")

# Try to use the different agent as entry point
with pytest.raises(ValueError, match="Entry point agent not found in swarm nodes"):
Swarm([agent1, agent2], entry_point=different_agent_same_name)


def test_swarm_validate_unsupported_features():
"""Test Swarm validation for session persistence and callbacks."""
# Test with normal agent (should work)
Expand Down
Loading