diff --git a/src/strands/hooks/__init__.py b/src/strands/hooks/__init__.py index 9e0850d32..30163f207 100644 --- a/src/strands/hooks/__init__.py +++ b/src/strands/hooks/__init__.py @@ -39,7 +39,7 @@ def log_end(self, event: AfterInvocationEvent) -> None: BeforeToolCallEvent, MessageAddedEvent, ) -from .registry import HookCallback, HookEvent, HookProvider, HookRegistry +from .registry import BaseHookEvent, HookCallback, HookEvent, HookProvider, HookRegistry __all__ = [ "AgentInitializedEvent", @@ -54,4 +54,6 @@ def log_end(self, event: AfterInvocationEvent) -> None: "HookProvider", "HookCallback", "HookRegistry", + "HookEvent", + "BaseHookEvent", ] diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index a3b76d743..b8e7f82ab 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -15,14 +15,8 @@ @dataclass -class HookEvent: - """Base class for all hook events. - - Attributes: - agent: The agent instance that triggered this event. - """ - - agent: "Agent" +class BaseHookEvent: + """Base class for all hook events.""" @property def should_reverse_callbacks(self) -> bool: @@ -66,10 +60,21 @@ def __setattr__(self, name: str, value: Any) -> None: raise AttributeError(f"Property {name} is not writable") -TEvent = TypeVar("TEvent", bound=HookEvent, contravariant=True) +@dataclass +class HookEvent(BaseHookEvent): + """Base class for single agent hook events. + + Attributes: + agent: The agent instance that triggered this event. + """ + + agent: "Agent" + + +TEvent = TypeVar("TEvent", bound=BaseHookEvent, contravariant=True) """Generic for adding callback handlers - contravariant to allow adding handlers which take in base classes.""" -TInvokeEvent = TypeVar("TInvokeEvent", bound=HookEvent) +TInvokeEvent = TypeVar("TInvokeEvent", bound=BaseHookEvent) """Generic for invoking events - non-contravariant to enable returning events."""