diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 1c2302c28..620fa5e24 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -196,6 +196,7 @@ def __init__( self, nodes: list[Agent], *, + entry_point: Agent | None = None, max_handoffs: int = 20, max_iterations: int = 20, execution_timeout: float = 900.0, @@ -207,6 +208,7 @@ def __init__( Args: nodes: List of nodes (e.g. Agent) to include in the swarm + entry_point: Agent to start with. If None, uses the first agent (default: None) max_handoffs: Maximum handoffs to agents and users (default: 20) max_iterations: Maximum node executions within the swarm (default: 20) execution_timeout: Total execution timeout in seconds (default: 900.0) @@ -218,6 +220,7 @@ def __init__( """ super().__init__() + self.entry_point = entry_point self.max_handoffs = max_handoffs self.max_iterations = max_iterations self.execution_timeout = execution_timeout @@ -276,7 +279,11 @@ async def invoke_async( logger.debug("starting swarm execution") # Initialize swarm state with configuration - initial_node = next(iter(self.nodes.values())) # First SwarmNode + if self.entry_point: + initial_node = self.nodes[str(self.entry_point.name)] + else: + initial_node = next(iter(self.nodes.values())) # First SwarmNode + self.state = SwarmState( current_node=initial_node, task=task, @@ -326,9 +333,28 @@ def _setup_swarm(self, nodes: list[Agent]) -> None: self.nodes[node_id] = SwarmNode(node_id=node_id, executor=node) + # Validate entry point if specified + if self.entry_point is not None: + entry_point_node_id = str(self.entry_point.name) + if ( + entry_point_node_id not in self.nodes + or self.nodes[entry_point_node_id].executor is not self.entry_point + ): + available_agents = [ + f"{node_id} ({type(node.executor).__name__})" for node_id, node in self.nodes.items() + ] + raise ValueError(f"Entry point agent not found in swarm nodes. Available agents: {available_agents}") + swarm_nodes = list(self.nodes.values()) logger.debug("nodes=<%s> | initialized swarm with nodes", [node.node_id for node in swarm_nodes]) + if self.entry_point: + entry_point_name = getattr(self.entry_point, "name", "unnamed_agent") + logger.debug("entry_point=<%s> | configured entry point", entry_point_name) + else: + first_node = next(iter(self.nodes.keys())) + logger.debug("entry_point=<%s> | using first node as entry point", first_node) + def _validate_swarm(self, nodes: list[Agent]) -> None: """Validate swarm structure and nodes.""" # Check for duplicate object instances diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index be463c7fd..7d3e69695 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -451,6 +451,82 @@ def test_swarm_auto_completion_without_handoff(): no_handoff_agent.invoke_async.assert_called() +def test_swarm_configurable_entry_point(): + """Test swarm with configurable entry point.""" + # Create multiple agents + agent1 = create_mock_agent("agent1", "Agent 1 response") + agent2 = create_mock_agent("agent2", "Agent 2 response") + agent3 = create_mock_agent("agent3", "Agent 3 response") + + # Create swarm with agent2 as entry point + swarm = Swarm([agent1, agent2, agent3], entry_point=agent2) + + # Verify entry point is set correctly + assert swarm.entry_point is agent2 + + # Execute swarm + result = swarm("Test task") + + # Verify agent2 was the first to execute + assert result.status == Status.COMPLETED + assert len(result.node_history) == 1 + assert result.node_history[0].node_id == "agent2" + + +def test_swarm_invalid_entry_point(): + """Test swarm with invalid entry point raises error.""" + agent1 = create_mock_agent("agent1", "Agent 1 response") + agent2 = create_mock_agent("agent2", "Agent 2 response") + agent3 = create_mock_agent("agent3", "Agent 3 response") # Not in swarm + + # Try to create swarm with agent not in the swarm + with pytest.raises(ValueError, match="Entry point agent not found in swarm nodes"): + Swarm([agent1, agent2], entry_point=agent3) + + +def test_swarm_default_entry_point(): + """Test swarm uses first agent as default entry point.""" + agent1 = create_mock_agent("agent1", "Agent 1 response") + agent2 = create_mock_agent("agent2", "Agent 2 response") + + # Create swarm without specifying entry point + swarm = Swarm([agent1, agent2]) + + # Verify no explicit entry point is set + assert swarm.entry_point is None + + # Execute swarm + result = swarm("Test task") + + # Verify first agent was used as entry point + assert result.status == Status.COMPLETED + assert len(result.node_history) == 1 + assert result.node_history[0].node_id == "agent1" + + +def test_swarm_duplicate_agent_names(): + """Test swarm rejects agents with duplicate names.""" + agent1 = create_mock_agent("duplicate_name", "Agent 1 response") + agent2 = create_mock_agent("duplicate_name", "Agent 2 response") + + # Try to create swarm with duplicate names + with pytest.raises(ValueError, match="Node ID 'duplicate_name' is not unique"): + Swarm([agent1, agent2]) + + +def test_swarm_entry_point_same_name_different_object(): + """Test entry point validation with same name but different object.""" + agent1 = create_mock_agent("agent1", "Agent 1 response") + agent2 = create_mock_agent("agent2", "Agent 2 response") + + # Create a different agent with same name as agent1 + different_agent_same_name = create_mock_agent("agent1", "Different agent response") + + # Try to use the different agent as entry point + with pytest.raises(ValueError, match="Entry point agent not found in swarm nodes"): + Swarm([agent1, agent2], entry_point=different_agent_same_name) + + def test_swarm_validate_unsupported_features(): """Test Swarm validation for session persistence and callbacks.""" # Test with normal agent (should work)