Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions src/strands/experimental/steering/handlers/llm/llm_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,13 @@ async def steer(self, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> Steer
)

# Convert LLM decision to steering action
if llm_result.decision == "proceed":
return Proceed(reason=llm_result.reason)
elif llm_result.decision == "guide":
return Guide(reason=llm_result.reason)
elif llm_result.decision == "interrupt":
return Interrupt(reason=llm_result.reason)
else:
logger.warning("decision=<%s> | unknown llm decision, defaulting to proceed", llm_result.decision) # type: ignore[unreachable]
return Proceed(reason="Unknown LLM decision, defaulting to proceed")
match llm_result.decision:
case "proceed":
return Proceed(reason=llm_result.reason)
case "guide":
return Guide(reason=llm_result.reason)
case "interrupt":
return Interrupt(reason=llm_result.reason)
case _:
logger.warning("decision=<%s> | uŹknown llm decision, defaulting to proceed", llm_result.decision) # type: ignore[unreachable]
return Proceed(reason="Unknown LLM decision, defaulting to proceed")
13 changes: 10 additions & 3 deletions tests_integ/steering/test_llm_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ def send_notification(recipient: str, message: str) -> str:
@pytest.mark.asyncio
async def test_llm_steering_handler_proceed():
"""Test LLM handler returns Proceed effect."""
handler = LLMSteeringHandler(system_prompt="Always allow send_notification calls. Return proceed decision.")
handler = LLMSteeringHandler(
system_prompt="You MUST always allow send_notification calls. ALWAYS return proceed decision. "
"Never return guide or interrupt."
)

agent = Agent(tools=[send_notification])
tool_use = {"name": "send_notification", "input": {"recipient": "user", "message": "hello"}}
Expand All @@ -37,7 +40,8 @@ async def test_llm_steering_handler_guide():
"""Test LLM handler returns Guide effect."""
handler = LLMSteeringHandler(
system_prompt=(
"When agents try to send_email, guide them to use send_notification instead. Return GUIDE decision."
"You MUST guide agents away from send_email to use send_notification instead. "
"ALWAYS return guide decision for send_email. Never return proceed or interrupt for send_email."
)
)

Expand All @@ -52,7 +56,10 @@ async def test_llm_steering_handler_guide():
@pytest.mark.asyncio
async def test_llm_steering_handler_interrupt():
"""Test LLM handler returns Interrupt effect."""
handler = LLMSteeringHandler(system_prompt="Require human input for all tool calls. Return interrupt decision.")
handler = LLMSteeringHandler(
system_prompt="You MUST require human input for ALL tool calls regardless of context. "
"ALWAYS return interrupt decision. Never return proceed or guide."
)

agent = Agent(tools=[send_email])
tool_use = {"name": "send_email", "input": {"recipient": "user", "message": "hello"}}
Expand Down
Loading