diff --git a/src/strands/experimental/steering/handlers/llm/llm_handler.py b/src/strands/experimental/steering/handlers/llm/llm_handler.py index b269d4b60..9d9b34911 100644 --- a/src/strands/experimental/steering/handlers/llm/llm_handler.py +++ b/src/strands/experimental/steering/handlers/llm/llm_handler.py @@ -83,12 +83,13 @@ async def steer(self, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> Steer ) # Convert LLM decision to steering action - if llm_result.decision == "proceed": - return Proceed(reason=llm_result.reason) - elif llm_result.decision == "guide": - return Guide(reason=llm_result.reason) - elif llm_result.decision == "interrupt": - return Interrupt(reason=llm_result.reason) - else: - logger.warning("decision=<%s> | unknown llm decision, defaulting to proceed", llm_result.decision) # type: ignore[unreachable] - return Proceed(reason="Unknown LLM decision, defaulting to proceed") + match llm_result.decision: + case "proceed": + return Proceed(reason=llm_result.reason) + case "guide": + return Guide(reason=llm_result.reason) + case "interrupt": + return Interrupt(reason=llm_result.reason) + case _: + logger.warning("decision=<%s> | uŹknown llm decision, defaulting to proceed", llm_result.decision) # type: ignore[unreachable] + return Proceed(reason="Unknown LLM decision, defaulting to proceed") diff --git a/tests_integ/steering/test_llm_handler.py b/tests_integ/steering/test_llm_handler.py index e0cf122d8..8a8cebea2 100644 --- a/tests_integ/steering/test_llm_handler.py +++ b/tests_integ/steering/test_llm_handler.py @@ -22,7 +22,10 @@ def send_notification(recipient: str, message: str) -> str: @pytest.mark.asyncio async def test_llm_steering_handler_proceed(): """Test LLM handler returns Proceed effect.""" - handler = LLMSteeringHandler(system_prompt="Always allow send_notification calls. Return proceed decision.") + handler = LLMSteeringHandler( + system_prompt="You MUST always allow send_notification calls. ALWAYS return proceed decision. " + "Never return guide or interrupt." + ) agent = Agent(tools=[send_notification]) tool_use = {"name": "send_notification", "input": {"recipient": "user", "message": "hello"}} @@ -37,7 +40,8 @@ async def test_llm_steering_handler_guide(): """Test LLM handler returns Guide effect.""" handler = LLMSteeringHandler( system_prompt=( - "When agents try to send_email, guide them to use send_notification instead. Return GUIDE decision." + "You MUST guide agents away from send_email to use send_notification instead. " + "ALWAYS return guide decision for send_email. Never return proceed or interrupt for send_email." ) ) @@ -52,7 +56,10 @@ async def test_llm_steering_handler_guide(): @pytest.mark.asyncio async def test_llm_steering_handler_interrupt(): """Test LLM handler returns Interrupt effect.""" - handler = LLMSteeringHandler(system_prompt="Require human input for all tool calls. Return interrupt decision.") + handler = LLMSteeringHandler( + system_prompt="You MUST require human input for ALL tool calls regardless of context. " + "ALWAYS return interrupt decision. Never return proceed or guide." + ) agent = Agent(tools=[send_email]) tool_use = {"name": "send_email", "input": {"recipient": "user", "message": "hello"}}