diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index bfa83fe20..0651d4521 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -44,6 +44,16 @@ logger = logging.getLogger(__name__) +# Sentinel class and object to distinguish between explicit None and default parameter value +class _DefaultCallbackHandlerSentinel: + """Sentinel class to distinguish between explicit None and default parameter value.""" + + pass + + +_DEFAULT_CALLBACK_HANDLER = _DefaultCallbackHandlerSentinel() + + class Agent: """Core Agent interface. @@ -70,7 +80,7 @@ def __init__(self, agent: "Agent") -> None: # agent tools and thus break their execution. self._agent = agent - def __getattr__(self, name: str) -> Callable: + def __getattr__(self, name: str) -> Callable[..., Any]: """Call tool as a function. This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`). @@ -177,7 +187,9 @@ def __init__( messages: Optional[Messages] = None, tools: Optional[List[Union[str, Dict[str, str], Any]]] = None, system_prompt: Optional[str] = None, - callback_handler: Optional[Callable] = PrintingCallbackHandler(), + callback_handler: Optional[ + Union[Callable[..., Any], _DefaultCallbackHandlerSentinel] + ] = _DEFAULT_CALLBACK_HANDLER, conversation_manager: Optional[ConversationManager] = None, max_parallel_tools: int = os.cpu_count() or 1, record_direct_tool_call: bool = True, @@ -204,7 +216,8 @@ def __init__( system_prompt: System prompt to guide model behavior. If None, the model will behave according to its default settings. callback_handler: Callback for processing events as they happen during agent execution. - Defaults to strands.handlers.PrintingCallbackHandler if None. + If not provided (using the default), a new PrintingCallbackHandler instance is created. + If explicitly set to None, null_callback_handler is used. conversation_manager: Manager for conversation history and context window. Defaults to strands.agent.conversation_manager.SlidingWindowConversationManager if None. max_parallel_tools: Maximum number of tools to run in parallel when the model returns multiple tool calls. @@ -222,7 +235,17 @@ def __init__( self.messages = messages if messages is not None else [] self.system_prompt = system_prompt - self.callback_handler = callback_handler or null_callback_handler + + # If not provided, create a new PrintingCallbackHandler instance + # If explicitly set to None, use null_callback_handler + # Otherwise use the passed callback_handler + self.callback_handler: Union[Callable[..., Any], PrintingCallbackHandler] + if isinstance(callback_handler, _DefaultCallbackHandlerSentinel): + self.callback_handler = PrintingCallbackHandler() + elif callback_handler is None: + self.callback_handler = null_callback_handler + else: + self.callback_handler = callback_handler self.conversation_manager = conversation_manager if conversation_manager else SlidingWindowConversationManager() @@ -415,7 +438,7 @@ def target_callback() -> None: thread.join() def _run_loop( - self, prompt: str, kwargs: Any, supplementary_callback_handler: Optional[Callable] = None + self, prompt: str, kwargs: Dict[str, Any], supplementary_callback_handler: Optional[Callable[..., Any]] = None ) -> AgentResult: """Execute the agent's event loop with the given prompt and parameters.""" try: @@ -441,7 +464,7 @@ def _run_loop( finally: self.conversation_manager.apply_management(self) - def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str, Any]) -> AgentResult: + def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs: Dict[str, Any]) -> AgentResult: """Execute the event loop cycle with retry logic for context window limits. This internal method handles the execution of the event loop cycle and implements diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 4a63fa31f..0ea20b642 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -686,6 +686,37 @@ def test_agent_with_callback_handler_none_uses_null_handler(): assert agent.callback_handler == null_callback_handler +def test_agent_callback_handler_not_provided_creates_new_instances(): + """Test that when callback_handler is not provided, new PrintingCallbackHandler instances are created.""" + # Create two agents without providing callback_handler + agent1 = Agent() + agent2 = Agent() + + # Both should have PrintingCallbackHandler instances + assert isinstance(agent1.callback_handler, PrintingCallbackHandler) + assert isinstance(agent2.callback_handler, PrintingCallbackHandler) + + # But they should be different object instances + assert agent1.callback_handler is not agent2.callback_handler + + +def test_agent_callback_handler_explicit_none_uses_null_handler(): + """Test that when callback_handler is explicitly set to None, null_callback_handler is used.""" + agent = Agent(callback_handler=None) + + # Should use null_callback_handler + assert agent.callback_handler is null_callback_handler + + +def test_agent_callback_handler_custom_handler_used(): + """Test that when a custom callback_handler is provided, it is used.""" + custom_handler = unittest.mock.Mock() + agent = Agent(callback_handler=custom_handler) + + # Should use the provided custom handler + assert agent.callback_handler is custom_handler + + @pytest.mark.asyncio async def test_stream_async_returns_all_events(mock_event_loop_cycle): agent = Agent()