diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 1d437e944..f2eed063c 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -15,13 +15,7 @@ from opentelemetry import trace as trace_api -from ..experimental.hooks import ( - AfterModelInvocationEvent, - BeforeModelInvocationEvent, -) -from ..hooks import ( - MessageAddedEvent, -) +from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, MessageAddedEvent from ..telemetry.metrics import Trace from ..telemetry.tracer import get_tracer from ..tools._validator import validate_and_prepare_tools @@ -133,7 +127,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> ) with trace_api.use_span(model_invoke_span): agent.hooks.invoke_callbacks( - BeforeModelInvocationEvent( + BeforeModelCallEvent( agent=agent, ) ) @@ -149,9 +143,9 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> invocation_state.setdefault("request_state", {}) agent.hooks.invoke_callbacks( - AfterModelInvocationEvent( + AfterModelCallEvent( agent=agent, - stop_response=AfterModelInvocationEvent.ModelStopResponse( + stop_response=AfterModelCallEvent.ModelStopResponse( stop_reason=stop_reason, message=message, ), @@ -170,7 +164,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> tracer.end_span_with_error(model_invoke_span, str(e), e) agent.hooks.invoke_callbacks( - AfterModelInvocationEvent( + AfterModelCallEvent( agent=agent, exception=e, ) diff --git a/src/strands/experimental/hooks/events.py b/src/strands/experimental/hooks/events.py index d03e65d85..d711dd7ed 100644 --- a/src/strands/experimental/hooks/events.py +++ b/src/strands/experimental/hooks/events.py @@ -3,121 +3,19 @@ This module defines the events that are emitted as Agents run through the lifecycle of a request. """ -from dataclasses import dataclass -from typing import Any, Optional - -from ...hooks import HookEvent -from ...types.content import Message -from ...types.streaming import StopReason -from ...types.tools import AgentTool, ToolResult, ToolUse - - -@dataclass -class BeforeToolInvocationEvent(HookEvent): - """Event triggered before a tool is invoked. - - This event is fired just before the agent executes a tool, allowing hook - providers to inspect, modify, or replace the tool that will be executed. - The selected_tool can be modified by hook callbacks to change which tool - gets executed. - - Attributes: - selected_tool: The tool that will be invoked. Can be modified by hooks - to change which tool gets executed. This may be None if tool lookup failed. - tool_use: The tool parameters that will be passed to selected_tool. - invocation_state: Keyword arguments that will be passed to the tool. - """ - - selected_tool: Optional[AgentTool] - tool_use: ToolUse - invocation_state: dict[str, Any] - - def _can_write(self, name: str) -> bool: - return name in ["selected_tool", "tool_use"] - - -@dataclass -class AfterToolInvocationEvent(HookEvent): - """Event triggered after a tool invocation completes. - - This event is fired after the agent has finished executing a tool, - regardless of whether the execution was successful or resulted in an error. - Hook providers can use this event for cleanup, logging, or post-processing. - - Note: This event uses reverse callback ordering, meaning callbacks registered - later will be invoked first during cleanup. - - Attributes: - selected_tool: The tool that was invoked. It may be None if tool lookup failed. - tool_use: The tool parameters that were passed to the tool invoked. - invocation_state: Keyword arguments that were passed to the tool - result: The result of the tool invocation. Either a ToolResult on success - or an Exception if the tool execution failed. - """ - - selected_tool: Optional[AgentTool] - tool_use: ToolUse - invocation_state: dict[str, Any] - result: ToolResult - exception: Optional[Exception] = None - - def _can_write(self, name: str) -> bool: - return name == "result" - - @property - def should_reverse_callbacks(self) -> bool: - """True to invoke callbacks in reverse order.""" - return True - - -@dataclass -class BeforeModelInvocationEvent(HookEvent): - """Event triggered before the model is invoked. - - This event is fired just before the agent calls the model for inference, - allowing hook providers to inspect or modify the messages and configuration - that will be sent to the model. - - Note: This event is not fired for invocations to structured_output. - """ - - pass - - -@dataclass -class AfterModelInvocationEvent(HookEvent): - """Event triggered after the model invocation completes. - - This event is fired after the agent has finished calling the model, - regardless of whether the invocation was successful or resulted in an error. - Hook providers can use this event for cleanup, logging, or post-processing. - - Note: This event uses reverse callback ordering, meaning callbacks registered - later will be invoked first during cleanup. - - Note: This event is not fired for invocations to structured_output. - - Attributes: - stop_response: The model response data if invocation was successful, None if failed. - exception: Exception if the model invocation failed, None if successful. - """ - - @dataclass - class ModelStopResponse: - """Model response data from successful invocation. - - Attributes: - stop_reason: The reason the model stopped generating. - message: The generated message from the model. - """ - - message: Message - stop_reason: StopReason - - stop_response: Optional[ModelStopResponse] = None - exception: Optional[Exception] = None - - @property - def should_reverse_callbacks(self) -> bool: - """True to invoke callbacks in reverse order.""" - return True +import warnings +from typing import TypeAlias + +from ...hooks.events import AfterModelCallEvent, AfterToolCallEvent, BeforeModelCallEvent, BeforeToolCallEvent + +warnings.warn( + "These events have been moved to production with updated names. Use BeforeModelCallEvent, " + "AfterModelCallEvent, BeforeToolCallEvent, and AfterToolCallEvent from strands.hooks instead.", + DeprecationWarning, + stacklevel=2, +) + +BeforeToolInvocationEvent: TypeAlias = BeforeToolCallEvent +AfterToolInvocationEvent: TypeAlias = AfterToolCallEvent +BeforeModelInvocationEvent: TypeAlias = BeforeModelCallEvent +AfterModelInvocationEvent: TypeAlias = AfterModelCallEvent diff --git a/src/strands/hooks/__init__.py b/src/strands/hooks/__init__.py index b98e95a6e..9e0850d32 100644 --- a/src/strands/hooks/__init__.py +++ b/src/strands/hooks/__init__.py @@ -31,8 +31,12 @@ def log_end(self, event: AfterInvocationEvent) -> None: from .events import ( AfterInvocationEvent, + AfterModelCallEvent, + AfterToolCallEvent, AgentInitializedEvent, BeforeInvocationEvent, + BeforeModelCallEvent, + BeforeToolCallEvent, MessageAddedEvent, ) from .registry import HookCallback, HookEvent, HookProvider, HookRegistry @@ -40,6 +44,10 @@ def log_end(self, event: AfterInvocationEvent) -> None: __all__ = [ "AgentInitializedEvent", "BeforeInvocationEvent", + "BeforeToolCallEvent", + "AfterToolCallEvent", + "BeforeModelCallEvent", + "AfterModelCallEvent", "AfterInvocationEvent", "MessageAddedEvent", "HookEvent", diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 42509dc9f..b3b2014f3 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -4,8 +4,11 @@ """ from dataclasses import dataclass +from typing import Any, Optional from ..types.content import Message +from ..types.streaming import StopReason +from ..types.tools import AgentTool, ToolResult, ToolUse from .registry import HookEvent @@ -78,3 +81,114 @@ class MessageAddedEvent(HookEvent): """ message: Message + + +@dataclass +class BeforeToolCallEvent(HookEvent): + """Event triggered before a tool is invoked. + + This event is fired just before the agent executes a tool, allowing hook + providers to inspect, modify, or replace the tool that will be executed. + The selected_tool can be modified by hook callbacks to change which tool + gets executed. + + Attributes: + selected_tool: The tool that will be invoked. Can be modified by hooks + to change which tool gets executed. This may be None if tool lookup failed. + tool_use: The tool parameters that will be passed to selected_tool. + invocation_state: Keyword arguments that will be passed to the tool. + """ + + selected_tool: Optional[AgentTool] + tool_use: ToolUse + invocation_state: dict[str, Any] + + def _can_write(self, name: str) -> bool: + return name in ["selected_tool", "tool_use"] + + +@dataclass +class AfterToolCallEvent(HookEvent): + """Event triggered after a tool invocation completes. + + This event is fired after the agent has finished executing a tool, + regardless of whether the execution was successful or resulted in an error. + Hook providers can use this event for cleanup, logging, or post-processing. + + Note: This event uses reverse callback ordering, meaning callbacks registered + later will be invoked first during cleanup. + + Attributes: + selected_tool: The tool that was invoked. It may be None if tool lookup failed. + tool_use: The tool parameters that were passed to the tool invoked. + invocation_state: Keyword arguments that were passed to the tool + result: The result of the tool invocation. Either a ToolResult on success + or an Exception if the tool execution failed. + """ + + selected_tool: Optional[AgentTool] + tool_use: ToolUse + invocation_state: dict[str, Any] + result: ToolResult + exception: Optional[Exception] = None + + def _can_write(self, name: str) -> bool: + return name == "result" + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True + + +@dataclass +class BeforeModelCallEvent(HookEvent): + """Event triggered before the model is invoked. + + This event is fired just before the agent calls the model for inference, + allowing hook providers to inspect or modify the messages and configuration + that will be sent to the model. + + Note: This event is not fired for invocations to structured_output. + """ + + pass + + +@dataclass +class AfterModelCallEvent(HookEvent): + """Event triggered after the model invocation completes. + + This event is fired after the agent has finished calling the model, + regardless of whether the invocation was successful or resulted in an error. + Hook providers can use this event for cleanup, logging, or post-processing. + + Note: This event uses reverse callback ordering, meaning callbacks registered + later will be invoked first during cleanup. + + Note: This event is not fired for invocations to structured_output. + + Attributes: + stop_response: The model response data if invocation was successful, None if failed. + exception: Exception if the model invocation failed, None if successful. + """ + + @dataclass + class ModelStopResponse: + """Model response data from successful invocation. + + Attributes: + stop_reason: The reason the model stopped generating. + message: The generated message from the model. + """ + + message: Message + stop_reason: StopReason + + stop_response: Optional[ModelStopResponse] = None + exception: Optional[Exception] = None + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True diff --git a/src/strands/hooks/rules.md b/src/strands/hooks/rules.md index a55a71fa3..4d0f571c6 100644 --- a/src/strands/hooks/rules.md +++ b/src/strands/hooks/rules.md @@ -9,6 +9,7 @@ - All hook events have a suffix of `Event` - Paired events follow the naming convention of `Before{Item}Event` and `After{Item}Event` +- Pre actions in the name. i.e. prefer `BeforeToolCallEvent` over `BeforeToolEvent`. ## Paired Events @@ -17,4 +18,4 @@ ## Writable Properties -For events with writable properties, those values are re-read after invoking the hook callbacks and used in subsequent processing. For example, `BeforeToolInvocationEvent.selected_tool` is writable - after invoking the callback for `BeforeToolInvocationEvent`, the `selected_tool` takes effect for the tool call. \ No newline at end of file +For events with writable properties, those values are re-read after invoking the hook callbacks and used in subsequent processing. For example, `BeforeToolEvent.selected_tool` is writable - after invoking the callback for `BeforeToolEvent`, the `selected_tool` takes effect for the tool call. \ No newline at end of file diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 5354991c3..2a75c48f2 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -11,7 +11,7 @@ from opentelemetry import trace as trace_api -from ...experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent +from ...hooks import AfterToolCallEvent, BeforeToolCallEvent from ...telemetry.metrics import Trace from ...telemetry.tracer import get_tracer from ...types._events import ToolResultEvent, ToolStreamEvent, TypedEvent @@ -73,7 +73,7 @@ async def _stream( ) before_event = agent.hooks.invoke_callbacks( - BeforeToolInvocationEvent( + BeforeToolCallEvent( agent=agent, selected_tool=tool_func, tool_use=tool_use, @@ -106,7 +106,7 @@ async def _stream( "content": [{"text": f"Unknown tool: {tool_name}"}], } after_event = agent.hooks.invoke_callbacks( - AfterToolInvocationEvent( + AfterToolCallEvent( agent=agent, selected_tool=selected_tool, tool_use=tool_use, @@ -137,7 +137,7 @@ async def _stream( result = cast(ToolResult, event) after_event = agent.hooks.invoke_callbacks( - AfterToolInvocationEvent( + AfterToolCallEvent( agent=agent, selected_tool=selected_tool, tool_use=tool_use, @@ -157,7 +157,7 @@ async def _stream( "content": [{"text": f"Error: {str(e)}"}], } after_event = agent.hooks.invoke_callbacks( - AfterToolInvocationEvent( + AfterToolCallEvent( agent=agent, selected_tool=selected_tool, tool_use=tool_use, diff --git a/tests/fixtures/mock_hook_provider.py b/tests/fixtures/mock_hook_provider.py index 6bf7b8c77..091f44d06 100644 --- a/tests/fixtures/mock_hook_provider.py +++ b/tests/fixtures/mock_hook_provider.py @@ -1,16 +1,14 @@ from typing import Iterator, Literal, Tuple, Type from strands import Agent -from strands.experimental.hooks import ( - AfterModelInvocationEvent, - AfterToolInvocationEvent, - BeforeModelInvocationEvent, - BeforeToolInvocationEvent, -) from strands.hooks import ( AfterInvocationEvent, + AfterModelCallEvent, + AfterToolCallEvent, AgentInitializedEvent, BeforeInvocationEvent, + BeforeModelCallEvent, + BeforeToolCallEvent, HookEvent, HookProvider, HookRegistry, @@ -25,10 +23,10 @@ def __init__(self, event_types: list[Type] | Literal["all"]): AgentInitializedEvent, BeforeInvocationEvent, AfterInvocationEvent, - AfterToolInvocationEvent, - BeforeToolInvocationEvent, - BeforeModelInvocationEvent, - AfterModelInvocationEvent, + BeforeToolCallEvent, + AfterToolCallEvent, + BeforeModelCallEvent, + AfterModelCallEvent, MessageAddedEvent, ] diff --git a/tests/strands/agent/hooks/__init__.py b/tests/strands/agent/hooks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/hooks/test_events.py b/tests/strands/agent/hooks/test_events.py similarity index 96% rename from tests/strands/experimental/hooks/test_events.py rename to tests/strands/agent/hooks/test_events.py index 231327732..8bbd89c17 100644 --- a/tests/strands/experimental/hooks/test_events.py +++ b/tests/strands/agent/hooks/test_events.py @@ -2,11 +2,12 @@ import pytest -from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent from strands.hooks import ( AfterInvocationEvent, + AfterToolCallEvent, AgentInitializedEvent, BeforeInvocationEvent, + BeforeToolCallEvent, MessageAddedEvent, ) from strands.types.tools import ToolResult, ToolUse @@ -61,7 +62,7 @@ def end_request_event(agent): @pytest.fixture def before_tool_event(agent, tool, tool_use, tool_invocation_state): - return BeforeToolInvocationEvent( + return BeforeToolCallEvent( agent=agent, selected_tool=tool, tool_use=tool_use, @@ -71,7 +72,7 @@ def before_tool_event(agent, tool, tool_use, tool_invocation_state): @pytest.fixture def after_tool_event(agent, tool, tool_use, tool_invocation_state, tool_result): - return AfterToolInvocationEvent( + return AfterToolCallEvent( agent=agent, selected_tool=tool, tool_use=tool_use, diff --git a/tests/strands/experimental/hooks/test_hook_registry.py b/tests/strands/agent/hooks/test_hook_registry.py similarity index 57% rename from tests/strands/experimental/hooks/test_hook_registry.py rename to tests/strands/agent/hooks/test_hook_registry.py index a61c0a1cb..680ded682 100644 --- a/tests/strands/experimental/hooks/test_hook_registry.py +++ b/tests/strands/agent/hooks/test_hook_registry.py @@ -9,20 +9,20 @@ @dataclass -class TestEvent(HookEvent): +class NormalTestEvent(HookEvent): @property def should_reverse_callbacks(self) -> bool: return False @dataclass -class TestAfterEvent(HookEvent): +class AfterTestEvent(HookEvent): @property def should_reverse_callbacks(self) -> bool: return True -class TestHookProvider(HookProvider): +class HookProviderForTests(HookProvider): """Test hook provider for testing hook registry.""" def __init__(self): @@ -38,13 +38,13 @@ def hook_registry(): @pytest.fixture -def test_event(): - return TestEvent(agent=Mock()) +def normal_event(): + return NormalTestEvent(agent=Mock()) @pytest.fixture -def test_after_event(): - return TestAfterEvent(agent=Mock()) +def after_event(): + return AfterTestEvent(agent=Mock()) def test_hook_registry_init(): @@ -53,26 +53,26 @@ def test_hook_registry_init(): assert registry._registered_callbacks == {} -def test_add_callback(hook_registry, test_event): +def test_add_callback(hook_registry, normal_event): """Test that callbacks can be added to the registry.""" callback = unittest.mock.Mock() - hook_registry.add_callback(TestEvent, callback) + hook_registry.add_callback(NormalTestEvent, callback) - assert TestEvent in hook_registry._registered_callbacks - assert callback in hook_registry._registered_callbacks[TestEvent] + assert NormalTestEvent in hook_registry._registered_callbacks + assert callback in hook_registry._registered_callbacks[NormalTestEvent] -def test_add_multiple_callbacks_same_event(hook_registry, test_event): +def test_add_multiple_callbacks_same_event(hook_registry, normal_event): """Test that multiple callbacks can be added for the same event type.""" callback1 = unittest.mock.Mock() callback2 = unittest.mock.Mock() - hook_registry.add_callback(TestEvent, callback1) - hook_registry.add_callback(TestEvent, callback2) + hook_registry.add_callback(NormalTestEvent, callback1) + hook_registry.add_callback(NormalTestEvent, callback2) - assert len(hook_registry._registered_callbacks[TestEvent]) == 2 - assert callback1 in hook_registry._registered_callbacks[TestEvent] - assert callback2 in hook_registry._registered_callbacks[TestEvent] + assert len(hook_registry._registered_callbacks[NormalTestEvent]) == 2 + assert callback1 in hook_registry._registered_callbacks[NormalTestEvent] + assert callback2 in hook_registry._registered_callbacks[NormalTestEvent] def test_add_hook(hook_registry): @@ -83,58 +83,58 @@ def test_add_hook(hook_registry): assert hook_provider.register_hooks.call_count == 1 -def test_get_callbacks_for_normal_event(hook_registry, test_event): +def test_get_callbacks_for_normal_event(hook_registry, normal_event): """Test that get_callbacks_for returns callbacks in the correct order for normal events.""" callback1 = unittest.mock.Mock() callback2 = unittest.mock.Mock() - hook_registry.add_callback(TestEvent, callback1) - hook_registry.add_callback(TestEvent, callback2) + hook_registry.add_callback(NormalTestEvent, callback1) + hook_registry.add_callback(NormalTestEvent, callback2) - callbacks = list(hook_registry.get_callbacks_for(test_event)) + callbacks = list(hook_registry.get_callbacks_for(normal_event)) assert len(callbacks) == 2 assert callbacks[0] == callback1 assert callbacks[1] == callback2 -def test_get_callbacks_for_after_event(hook_registry, test_after_event): +def test_get_callbacks_for_after_event(hook_registry, after_event): """Test that get_callbacks_for returns callbacks in reverse order for after events.""" callback1 = Mock() callback2 = Mock() - hook_registry.add_callback(TestAfterEvent, callback1) - hook_registry.add_callback(TestAfterEvent, callback2) + hook_registry.add_callback(AfterTestEvent, callback1) + hook_registry.add_callback(AfterTestEvent, callback2) - callbacks = list(hook_registry.get_callbacks_for(test_after_event)) + callbacks = list(hook_registry.get_callbacks_for(after_event)) assert len(callbacks) == 2 assert callbacks[0] == callback2 # Reverse order assert callbacks[1] == callback1 # Reverse order -def test_invoke_callbacks(hook_registry, test_event): +def test_invoke_callbacks(hook_registry, normal_event): """Test that invoke_callbacks calls all registered callbacks for an event.""" callback1 = Mock() callback2 = Mock() - hook_registry.add_callback(TestEvent, callback1) - hook_registry.add_callback(TestEvent, callback2) + hook_registry.add_callback(NormalTestEvent, callback1) + hook_registry.add_callback(NormalTestEvent, callback2) - hook_registry.invoke_callbacks(test_event) + hook_registry.invoke_callbacks(normal_event) - callback1.assert_called_once_with(test_event) - callback2.assert_called_once_with(test_event) + callback1.assert_called_once_with(normal_event) + callback2.assert_called_once_with(normal_event) -def test_invoke_callbacks_no_registered_callbacks(hook_registry, test_event): +def test_invoke_callbacks_no_registered_callbacks(hook_registry, normal_event): """Test that invoke_callbacks doesn't fail when there are no registered callbacks.""" # No callbacks registered - hook_registry.invoke_callbacks(test_event) + hook_registry.invoke_callbacks(normal_event) # Test passes if no exception is raised -def test_invoke_callbacks_after_event(hook_registry, test_after_event): +def test_invoke_callbacks_after_event(hook_registry, after_event): """Test that invoke_callbacks calls callbacks in reverse order for after events.""" call_order: List[str] = [] @@ -144,24 +144,24 @@ def callback1(_event): def callback2(_event): call_order.append("callback2") - hook_registry.add_callback(TestAfterEvent, callback1) - hook_registry.add_callback(TestAfterEvent, callback2) + hook_registry.add_callback(AfterTestEvent, callback1) + hook_registry.add_callback(AfterTestEvent, callback2) - hook_registry.invoke_callbacks(test_after_event) + hook_registry.invoke_callbacks(after_event) assert call_order == ["callback2", "callback1"] # Reverse order -def test_has_callbacks(hook_registry, test_event): +def test_has_callbacks(hook_registry, normal_event): """Test that has_callbacks returns correct boolean values.""" # Empty registry should return False assert not hook_registry.has_callbacks() # Registry with callbacks should return True callback = Mock() - hook_registry.add_callback(TestEvent, callback) + hook_registry.add_callback(NormalTestEvent, callback) assert hook_registry.has_callbacks() # Test with multiple event types - hook_registry.add_callback(TestAfterEvent, Mock()) + hook_registry.add_callback(AfterTestEvent, Mock()) assert hook_registry.has_callbacks() diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 9ab008ca2..6c5625e0b 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -5,16 +5,14 @@ import strands from strands import Agent -from strands.experimental.hooks import ( - AfterModelInvocationEvent, - AfterToolInvocationEvent, - BeforeModelInvocationEvent, - BeforeToolInvocationEvent, -) from strands.hooks import ( AfterInvocationEvent, + AfterModelCallEvent, + AfterToolCallEvent, AgentInitializedEvent, BeforeInvocationEvent, + BeforeModelCallEvent, + BeforeToolCallEvent, MessageAddedEvent, ) from strands.types.content import Messages @@ -30,10 +28,10 @@ def hook_provider(): AgentInitializedEvent, BeforeInvocationEvent, AfterInvocationEvent, - AfterToolInvocationEvent, - BeforeToolInvocationEvent, - BeforeModelInvocationEvent, - AfterModelInvocationEvent, + AfterToolCallEvent, + BeforeToolCallEvent, + BeforeModelCallEvent, + AfterModelCallEvent, MessageAddedEvent, ] ) @@ -125,10 +123,10 @@ def test_agent_tool_call(agent, hook_provider, agent_tool): assert length == 6 - assert next(events) == BeforeToolInvocationEvent( + assert next(events) == BeforeToolCallEvent( agent=agent, selected_tool=agent_tool, tool_use=tool_use, invocation_state=ANY ) - assert next(events) == AfterToolInvocationEvent( + assert next(events) == AfterToolCallEvent( agent=agent, selected_tool=agent_tool, tool_use=tool_use, @@ -157,10 +155,10 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u agent=agent, message=agent.messages[0], ) - assert next(events) == BeforeModelInvocationEvent(agent=agent) - assert next(events) == AfterModelInvocationEvent( + assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == AfterModelCallEvent( agent=agent, - stop_response=AfterModelInvocationEvent.ModelStopResponse( + stop_response=AfterModelCallEvent.ModelStopResponse( message={ "content": [{"toolUse": tool_use}], "role": "assistant", @@ -171,10 +169,10 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1]) - assert next(events) == BeforeToolInvocationEvent( + assert next(events) == BeforeToolCallEvent( agent=agent, selected_tool=agent_tool, tool_use=tool_use, invocation_state=ANY ) - assert next(events) == AfterToolInvocationEvent( + assert next(events) == AfterToolCallEvent( agent=agent, selected_tool=agent_tool, tool_use=tool_use, @@ -182,10 +180,10 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"}, ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2]) - assert next(events) == BeforeModelInvocationEvent(agent=agent) - assert next(events) == AfterModelInvocationEvent( + assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == AfterModelCallEvent( agent=agent, - stop_response=AfterModelInvocationEvent.ModelStopResponse( + stop_response=AfterModelCallEvent.ModelStopResponse( message=mock_model.agent_responses[1], stop_reason="end_turn", ), @@ -218,10 +216,10 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m agent=agent, message=agent.messages[0], ) - assert next(events) == BeforeModelInvocationEvent(agent=agent) - assert next(events) == AfterModelInvocationEvent( + assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == AfterModelCallEvent( agent=agent, - stop_response=AfterModelInvocationEvent.ModelStopResponse( + stop_response=AfterModelCallEvent.ModelStopResponse( message={ "content": [{"toolUse": tool_use}], "role": "assistant", @@ -232,10 +230,10 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1]) - assert next(events) == BeforeToolInvocationEvent( + assert next(events) == BeforeToolCallEvent( agent=agent, selected_tool=agent_tool, tool_use=tool_use, invocation_state=ANY ) - assert next(events) == AfterToolInvocationEvent( + assert next(events) == AfterToolCallEvent( agent=agent, selected_tool=agent_tool, tool_use=tool_use, @@ -243,10 +241,10 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"}, ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2]) - assert next(events) == BeforeModelInvocationEvent(agent=agent) - assert next(events) == AfterModelInvocationEvent( + assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == AfterModelCallEvent( agent=agent, - stop_response=AfterModelInvocationEvent.ModelStopResponse( + stop_response=AfterModelCallEvent.ModelStopResponse( message=mock_model.agent_responses[1], stop_reason="end_turn", ), diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 9d9e20863..2b71f3502 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -6,13 +6,12 @@ import strands import strands.telemetry -from strands.experimental.hooks import ( - AfterModelInvocationEvent, - AfterToolInvocationEvent, - BeforeModelInvocationEvent, - BeforeToolInvocationEvent, +from strands.hooks import ( + AfterModelCallEvent, + BeforeModelCallEvent, + HookRegistry, + MessageAddedEvent, ) -from strands.hooks import HookRegistry from strands.telemetry.metrics import EventLoopMetrics from strands.tools.executors import SequentialToolExecutor from strands.tools.registry import ToolRegistry @@ -117,14 +116,7 @@ def hook_registry(): @pytest.fixture def hook_provider(hook_registry): - provider = MockHookProvider( - event_types=[ - BeforeToolInvocationEvent, - AfterToolInvocationEvent, - BeforeModelInvocationEvent, - AfterModelInvocationEvent, - ] - ) + provider = MockHookProvider(event_types="all") hook_registry.add_hook(provider) return provider @@ -842,26 +834,31 @@ async def test_event_loop_cycle_exception_model_hooks(mock_sleep, agent, model, count, events = hook_provider.get_events() - assert count == 8 + assert count == 9 # 1st call - throttled - assert next(events) == BeforeModelInvocationEvent(agent=agent) - assert next(events) == AfterModelInvocationEvent(agent=agent, stop_response=None, exception=exception) + assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) # 2nd call - throttled - assert next(events) == BeforeModelInvocationEvent(agent=agent) - assert next(events) == AfterModelInvocationEvent(agent=agent, stop_response=None, exception=exception) + assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) # 3rd call - throttled - assert next(events) == BeforeModelInvocationEvent(agent=agent) - assert next(events) == AfterModelInvocationEvent(agent=agent, stop_response=None, exception=exception) + assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) # 4th call - successful - assert next(events) == BeforeModelInvocationEvent(agent=agent) - assert next(events) == AfterModelInvocationEvent( + assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == AfterModelCallEvent( agent=agent, - stop_response=AfterModelInvocationEvent.ModelStopResponse( + stop_response=AfterModelCallEvent.ModelStopResponse( message={"content": [{"text": "test text"}], "role": "assistant"}, stop_reason="end_turn" ), exception=None, ) + + # Final message + assert next(events) == MessageAddedEvent( + agent=agent, message={"content": [{"text": "test text"}], "role": "assistant"} + ) diff --git a/tests/strands/experimental/hooks/test_hook_aliases.py b/tests/strands/experimental/hooks/test_hook_aliases.py new file mode 100644 index 000000000..db9cd3783 --- /dev/null +++ b/tests/strands/experimental/hooks/test_hook_aliases.py @@ -0,0 +1,135 @@ +"""Tests to verify that experimental hook aliases work interchangeably with real types. + +This test module ensures that the experimental hook event aliases maintain +backwards compatibility and can be used interchangeably with the actual +hook event types. +""" + +import importlib +import sys +from unittest.mock import Mock + +from strands.experimental.hooks import ( + AfterModelInvocationEvent, + AfterToolInvocationEvent, + BeforeModelInvocationEvent, + BeforeToolInvocationEvent, +) +from strands.hooks import ( + AfterModelCallEvent, + AfterToolCallEvent, + BeforeModelCallEvent, + BeforeToolCallEvent, + HookRegistry, +) + + +def test_experimental_aliases_are_same_types(): + """Verify that experimental aliases are identical to the actual types.""" + assert BeforeToolInvocationEvent is BeforeToolCallEvent + assert AfterToolInvocationEvent is AfterToolCallEvent + assert BeforeModelInvocationEvent is BeforeModelCallEvent + assert AfterModelInvocationEvent is AfterModelCallEvent + + assert BeforeToolCallEvent is BeforeToolInvocationEvent + assert AfterToolCallEvent is AfterToolInvocationEvent + assert BeforeModelCallEvent is BeforeModelInvocationEvent + assert AfterModelCallEvent is AfterModelInvocationEvent + + +def test_before_tool_call_event_type_equality(): + """Verify that BeforeToolInvocationEvent alias has the same type identity.""" + before_tool_event = BeforeToolCallEvent( + agent=Mock(), + selected_tool=Mock(), + tool_use={"name": "test", "toolUseId": "123", "input": {}}, + invocation_state={}, + ) + + assert isinstance(before_tool_event, BeforeToolInvocationEvent) + assert isinstance(before_tool_event, BeforeToolCallEvent) + + +def test_after_tool_call_event_type_equality(): + """Verify that AfterToolInvocationEvent alias has the same type identity.""" + after_tool_event = AfterToolCallEvent( + agent=Mock(), + selected_tool=Mock(), + tool_use={"name": "test", "toolUseId": "123", "input": {}}, + invocation_state={}, + result={"toolUseId": "123", "status": "success", "content": [{"text": "result"}]}, + ) + + assert isinstance(after_tool_event, AfterToolInvocationEvent) + assert isinstance(after_tool_event, AfterToolCallEvent) + + +def test_before_model_call_event_type_equality(): + """Verify that BeforeModelInvocationEvent alias has the same type identity.""" + before_model_event = BeforeModelCallEvent(agent=Mock()) + + assert isinstance(before_model_event, BeforeModelInvocationEvent) + assert isinstance(before_model_event, BeforeModelCallEvent) + + +def test_after_model_call_event_type_equality(): + """Verify that AfterModelInvocationEvent alias has the same type identity.""" + after_model_event = AfterModelCallEvent(agent=Mock()) + + assert isinstance(after_model_event, AfterModelInvocationEvent) + assert isinstance(after_model_event, AfterModelCallEvent) + + +def test_experimental_aliases_in_hook_registry(): + """Verify that experimental aliases work with hook registry callbacks.""" + hook_registry = HookRegistry() + callback_called = False + received_event = None + + def experimental_callback(event: BeforeToolInvocationEvent): + nonlocal callback_called, received_event + callback_called = True + received_event = event + + # Register callback using experimental alias + hook_registry.add_callback(BeforeToolInvocationEvent, experimental_callback) + + # Create event using actual type + test_event = BeforeToolCallEvent( + agent=Mock(), + selected_tool=Mock(), + tool_use={"name": "test", "toolUseId": "123", "input": {}}, + invocation_state={}, + ) + + # Invoke callbacks - should work since alias points to same type + hook_registry.invoke_callbacks(test_event) + + assert callback_called + assert received_event is test_event + + +def test_deprecation_warning_on_import(captured_warnings): + """Verify that importing from experimental module emits deprecation warning.""" + + module = sys.modules.get("strands.experimental.hooks.events") + if module: + importlib.reload(module) + else: + importlib.import_module("strands.experimental.hooks.events") + + assert len(captured_warnings) == 1 + assert issubclass(captured_warnings[0].category, DeprecationWarning) + assert "moved to production with updated names" in str(captured_warnings[0].message) + + +def test_deprecation_warning_on_import_only_for_experimental(captured_warnings): + """Verify that importing from experimental module emits deprecation warning.""" + # Re-import the module to trigger the warning + module = sys.modules.get("strands.hooks") + if module: + importlib.reload(module) + else: + importlib.import_module("strands.hooks") + + assert len(captured_warnings) == 0 diff --git a/tests/strands/tools/executors/conftest.py b/tests/strands/tools/executors/conftest.py index 1576b7578..be90226f6 100644 --- a/tests/strands/tools/executors/conftest.py +++ b/tests/strands/tools/executors/conftest.py @@ -4,8 +4,7 @@ import pytest import strands -from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent -from strands.hooks import HookRegistry +from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent, HookRegistry from strands.tools.registry import ToolRegistry @@ -26,8 +25,8 @@ def callback(event): @pytest.fixture def hook_registry(tool_hook): registry = HookRegistry() - registry.add_callback(BeforeToolInvocationEvent, tool_hook) - registry.add_callback(AfterToolInvocationEvent, tool_hook) + registry.add_callback(BeforeToolCallEvent, tool_hook) + registry.add_callback(AfterToolCallEvent, tool_hook) return registry diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index 903a11e5a..3bbedb477 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -4,7 +4,7 @@ import pytest import strands -from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent +from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent from strands.telemetry.metrics import Trace from strands.tools.executors._executor import ToolExecutor from strands.types._events import ToolResultEvent, ToolStreamEvent @@ -50,13 +50,13 @@ async def test_executor_stream_yields_result( tru_hook_events = hook_events exp_hook_events = [ - BeforeToolInvocationEvent( + BeforeToolCallEvent( agent=agent, selected_tool=weather_tool, tool_use=tool_use, invocation_state=invocation_state, ), - AfterToolInvocationEvent( + AfterToolCallEvent( agent=agent, selected_tool=weather_tool, tool_use=tool_use, @@ -153,7 +153,7 @@ async def test_executor_stream_yields_tool_error( assert tru_results == exp_results tru_hook_after_event = hook_events[-1] - exp_hook_after_event = AfterToolInvocationEvent( + exp_hook_after_event = AfterToolCallEvent( agent=agent, selected_tool=exception_tool, tool_use=tool_use, @@ -180,7 +180,7 @@ async def test_executor_stream_yields_unknown_tool(executor, agent, tool_results assert tru_results == exp_results tru_hook_after_event = hook_events[-1] - exp_hook_after_event = AfterToolInvocationEvent( + exp_hook_after_event = AfterToolCallEvent( agent=agent, selected_tool=None, tool_use=tool_use, diff --git a/tests_integ/mcp/test_mcp_client_structured_content_with_hooks.py b/tests_integ/mcp/test_mcp_client_structured_content_with_hooks.py index b671184d9..ef4993b05 100644 --- a/tests_integ/mcp/test_mcp_client_structured_content_with_hooks.py +++ b/tests_integ/mcp/test_mcp_client_structured_content_with_hooks.py @@ -9,8 +9,7 @@ from mcp import StdioServerParameters, stdio_client from strands import Agent -from strands.experimental.hooks import AfterToolInvocationEvent -from strands.hooks import HookProvider, HookRegistry +from strands.hooks import AfterToolCallEvent, HookProvider, HookRegistry from strands.tools.mcp.mcp_client import MCPClient @@ -22,9 +21,9 @@ def __init__(self): def register_hooks(self, registry: HookRegistry) -> None: """Register callback for after tool invocation events.""" - registry.add_callback(AfterToolInvocationEvent, self.on_after_tool_invocation) + registry.add_callback(AfterToolCallEvent, self.on_after_tool_invocation) - def on_after_tool_invocation(self, event: AfterToolInvocationEvent) -> None: + def on_after_tool_invocation(self, event: AfterToolCallEvent) -> None: """Capture structured content tool results.""" if event.tool_use["name"] == "echo_with_structured_content": self.captured_result = event.result diff --git a/tests_integ/test_multiagent_graph.py b/tests_integ/test_multiagent_graph.py index bc9b0ea8b..c2c13c443 100644 --- a/tests_integ/test_multiagent_graph.py +++ b/tests_integ/test_multiagent_graph.py @@ -1,8 +1,14 @@ import pytest from strands import Agent, tool -from strands.experimental.hooks import AfterModelInvocationEvent, BeforeModelInvocationEvent -from strands.hooks import AfterInvocationEvent, AgentInitializedEvent, BeforeInvocationEvent, MessageAddedEvent +from strands.hooks import ( + AfterInvocationEvent, + AfterModelCallEvent, + AgentInitializedEvent, + BeforeInvocationEvent, + BeforeModelCallEvent, + MessageAddedEvent, +) from strands.multiagent.graph import GraphBuilder from strands.types.content import ContentBlock from tests.fixtures.mock_hook_provider import MockHookProvider @@ -204,8 +210,8 @@ async def test_graph_execution_with_image(image_analysis_agent, summary_agent, y AgentInitializedEvent, BeforeInvocationEvent, MessageAddedEvent, - BeforeModelInvocationEvent, - AfterModelInvocationEvent, + BeforeModelCallEvent, + AfterModelCallEvent, MessageAddedEvent, AfterInvocationEvent, ] diff --git a/tests_integ/test_multiagent_swarm.py b/tests_integ/test_multiagent_swarm.py index 76860f687..9a8c79bf8 100644 --- a/tests_integ/test_multiagent_swarm.py +++ b/tests_integ/test_multiagent_swarm.py @@ -1,13 +1,15 @@ import pytest from strands import Agent, tool -from strands.experimental.hooks import ( - AfterModelInvocationEvent, - AfterToolInvocationEvent, - BeforeModelInvocationEvent, - BeforeToolInvocationEvent, +from strands.hooks import ( + AfterInvocationEvent, + AfterModelCallEvent, + AfterToolCallEvent, + BeforeInvocationEvent, + BeforeModelCallEvent, + BeforeToolCallEvent, + MessageAddedEvent, ) -from strands.hooks import AfterInvocationEvent, BeforeInvocationEvent, MessageAddedEvent from strands.multiagent.swarm import Swarm from strands.types.content import ContentBlock from tests.fixtures.mock_hook_provider import MockHookProvider @@ -102,10 +104,10 @@ def test_swarm_execution_with_string(researcher_agent, analyst_agent, writer_age researcher_hooks = hook_provider.extract_for(researcher_agent).event_types_received assert BeforeInvocationEvent in researcher_hooks assert MessageAddedEvent in researcher_hooks - assert BeforeModelInvocationEvent in researcher_hooks - assert BeforeToolInvocationEvent in researcher_hooks - assert AfterToolInvocationEvent in researcher_hooks - assert AfterModelInvocationEvent in researcher_hooks + assert BeforeModelCallEvent in researcher_hooks + assert BeforeToolCallEvent in researcher_hooks + assert AfterToolCallEvent in researcher_hooks + assert AfterModelCallEvent in researcher_hooks assert AfterInvocationEvent in researcher_hooks