Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/strands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
from . import agent, models, telemetry, types
from .agent.agent import Agent
from .tools.decorator import tool
from .types.tools import StrandsContext

__all__ = ["Agent", "agent", "models", "tool", "types", "telemetry"]
__all__ = ["Agent", "agent", "models", "tool", "types", "telemetry", "StrandsContext"]
11 changes: 7 additions & 4 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from ..tools.watcher import ToolWatcher
from ..types.content import ContentBlock, Message, Messages
from ..types.exceptions import ContextWindowOverflowException
from ..types.invocation import InvocationState
from ..types.tools import ToolResult, ToolUse
from ..types.traces import AttributeValue
from .agent_result import AgentResult
Expand Down Expand Up @@ -138,7 +139,7 @@ def caller(

async def acall() -> ToolResult:
# Pass kwargs as invocation_state
async for event in run_tool(self._agent, tool_use, kwargs):
async for event in run_tool(self._agent, tool_use, cast(InvocationState, kwargs)):
_ = event

return cast(ToolResult, event)
Expand Down Expand Up @@ -506,7 +507,7 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
self.trace_span = self._start_agent_trace_span(message)
with trace_api.use_span(self.trace_span):
try:
events = self._run_loop(message, invocation_state=kwargs)
events = self._run_loop(message, invocation_state=cast(InvocationState, kwargs))
async for event in events:
if "callback" in event:
callback_handler(**event["callback"])
Expand All @@ -523,7 +524,7 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
raise

async def _run_loop(
self, message: Message, invocation_state: dict[str, Any]
self, message: Message, invocation_state: InvocationState
) -> AsyncGenerator[dict[str, Any], None]:
"""Execute the agent's event loop with the given message and parameters.

Expand Down Expand Up @@ -563,7 +564,9 @@ async def _run_loop(
self.conversation_manager.apply_management(self)
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))

async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
async def _execute_event_loop_cycle(
self, invocation_state: InvocationState
) -> AsyncGenerator[dict[str, Any], None]:
"""Execute the event loop cycle with retry logic for context window limits.

This internal method handles the execution of the event loop cycle and implements
Expand Down
9 changes: 5 additions & 4 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
MaxTokensReachedException,
ModelThrottledException,
)
from ..types.invocation import InvocationState
from ..types.streaming import Metrics, StopReason
from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse
from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached
Expand All @@ -49,7 +50,7 @@
MAX_DELAY = 240 # 4 minutes


async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
async def event_loop_cycle(agent: "Agent", invocation_state: InvocationState) -> AsyncGenerator[dict[str, Any], None]:
"""Execute a single cycle of the event loop.

This core function processes a single conversation turn, handling model inference, tool execution, and error
Expand Down Expand Up @@ -273,7 +274,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
yield {"stop": (stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])}


async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
async def recurse_event_loop(agent: "Agent", invocation_state: InvocationState) -> AsyncGenerator[dict[str, Any], None]:
"""Make a recursive call to event_loop_cycle with the current state.

This function is used when the event loop needs to continue processing after tool execution.
Expand Down Expand Up @@ -306,7 +307,7 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -
recursive_trace.end()


async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str, Any]) -> ToolGenerator:
async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: InvocationState) -> ToolGenerator:
"""Process a tool invocation.

Looks up the tool in the registry and streams it with the provided parameters.
Expand Down Expand Up @@ -429,7 +430,7 @@ async def _handle_tool_execution(
cycle_trace: Trace,
cycle_span: Any,
cycle_start_time: float,
invocation_state: dict[str, Any],
invocation_state: InvocationState,
) -> AsyncGenerator[dict[str, Any], None]:
tool_uses: list[ToolUse] = []
tool_results: list[ToolResult] = []
Expand Down
7 changes: 4 additions & 3 deletions src/strands/experimental/hooks/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
"""

from dataclasses import dataclass
from typing import Any, Optional
from typing import Optional

from ...hooks import HookEvent
from ...types.content import Message
from ...types.invocation import InvocationState
from ...types.streaming import StopReason
from ...types.tools import AgentTool, ToolResult, ToolUse

Expand All @@ -30,7 +31,7 @@ class BeforeToolInvocationEvent(HookEvent):

selected_tool: Optional[AgentTool]
tool_use: ToolUse
invocation_state: dict[str, Any]
invocation_state: InvocationState

def _can_write(self, name: str) -> bool:
return name in ["selected_tool", "tool_use"]
Expand All @@ -57,7 +58,7 @@ class AfterToolInvocationEvent(HookEvent):

selected_tool: Optional[AgentTool]
tool_use: ToolUse
invocation_state: dict[str, Any]
invocation_state: InvocationState
result: ToolResult
exception: Optional[Exception] = None

Expand Down
3 changes: 2 additions & 1 deletion src/strands/telemetry/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from ..agent.agent_result import AgentResult
from ..types.content import ContentBlock, Message, Messages
from ..types.invocation import InvocationState
from ..types.streaming import StopReason, Usage
from ..types.tools import ToolResult, ToolUse
from ..types.traces import AttributeValue
Expand Down Expand Up @@ -343,7 +344,7 @@ def end_tool_call_span(

def start_event_loop_cycle_span(
self,
invocation_state: Any,
invocation_state: InvocationState,
messages: Messages,
parent_span: Optional[Span] = None,
**kwargs: Any,
Expand Down
57 changes: 49 additions & 8 deletions src/strands/tools/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def my_tool(param1: str, param2: int = 42) -> dict:
from pydantic import BaseModel, Field, create_model
from typing_extensions import override

from ..types.tools import AgentTool, JSONSchema, ToolGenerator, ToolSpec, ToolUse
from ..types.invocation import InvocationState
from ..types.tools import AgentTool, JSONSchema, StrandsContext, ToolGenerator, ToolSpec, ToolUse

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -113,16 +114,16 @@ def _create_input_model(self) -> Type[BaseModel]:
This method analyzes the function's signature, type hints, and docstring to create a Pydantic model that can
validate input data before passing it to the function.

Special parameters like 'self', 'cls', and 'agent' are excluded from the model.
Special parameters that can be automatically injected are excluded from the model.

Returns:
A Pydantic BaseModel class customized for the function's parameters.
"""
field_definitions: dict[str, Any] = {}

for name, param in self.signature.parameters.items():
# Skip special parameters
if name in ("self", "cls", "agent"):
# Skip parameters that will be automatically injected
if self._is_special_parameter(name):
continue

# Get parameter type and default
Expand Down Expand Up @@ -252,6 +253,47 @@ def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]:
error_msg = str(e)
raise ValueError(f"Validation failed for input parameters: {error_msg}") from e

def inject_special_parameters(
self, validated_input: dict[str, Any], tool_use: ToolUse, invocation_state: InvocationState
) -> None:
"""Inject special framework-provided parameters into the validated input.

This method automatically provides framework-level context to tools that request it
through their function signature.

Args:
validated_input: The validated input parameters (modified in place).
tool_use: The tool use request containing tool invocation details.
invocation_state: Context for the tool invocation, including agent state.
"""
# Inject StrandsContext if requested
if "strands_context" in self.signature.parameters:
strands_context: StrandsContext = {
"tool_use": tool_use,
"invocation_state": invocation_state,
}
validated_input["strands_context"] = strands_context

# Inject agent if requested (backward compatibility)
if "agent" in self.signature.parameters and "agent" in invocation_state:
validated_input["agent"] = invocation_state["agent"]

def _is_special_parameter(self, param_name: str) -> bool:
"""Check if a parameter should be automatically injected by the framework.

Special parameters include:
- Standard Python parameters: self, cls
- Framework-provided context parameters: agent, strands_context

Args:
param_name: The name of the parameter to check.

Returns:
True if the parameter should be excluded from input validation and
automatically injected during tool execution.
"""
return param_name in {"self", "cls", "agent", "strands_context"}


P = ParamSpec("P") # Captures all parameters
R = TypeVar("R") # Return type
Expand Down Expand Up @@ -372,7 +414,7 @@ def tool_type(self) -> str:
return "function"

@override
async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator:
async def stream(self, tool_use: ToolUse, invocation_state: InvocationState, **kwargs: Any) -> ToolGenerator:
"""Stream the tool with a tool use specification.

This method handles tool use streams from a Strands Agent. It validates the input,
Expand Down Expand Up @@ -402,9 +444,8 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
# Validate input against the Pydantic model
validated_input = self._metadata.validate_input(tool_input)

# Pass along the agent if provided and expected by the function
if "agent" in invocation_state and "agent" in self._metadata.signature.parameters:
validated_input["agent"] = invocation_state.get("agent")
# Inject special framework-provided parameters
self._metadata.inject_special_parameters(validated_input, tool_use, invocation_state)

# "Too few arguments" expected, hence the type ignore
if inspect.iscoroutinefunction(self._tool_func):
Expand Down
3 changes: 2 additions & 1 deletion src/strands/tools/mcp/mcp_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from mcp.types import Tool as MCPTool
from typing_extensions import override

from ...types.invocation import InvocationState
from ...types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse

if TYPE_CHECKING:
Expand Down Expand Up @@ -75,7 +76,7 @@ def tool_type(self) -> str:
return "python"

@override
async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator:
async def stream(self, tool_use: ToolUse, invocation_state: InvocationState, **kwargs: Any) -> ToolGenerator:
"""Stream the MCP tool.

This method delegates the tool stream to the MCP server connection, passing the tool use ID, tool name, and
Expand Down
7 changes: 5 additions & 2 deletions src/strands/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from typing_extensions import override

from ..types.invocation import InvocationState
from ..types.tools import AgentTool, ToolFunc, ToolGenerator, ToolSpec, ToolUse

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -198,7 +199,7 @@ def tool_type(self) -> str:
return "python"

@override
async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator:
async def stream(self, tool_use: ToolUse, invocation_state: InvocationState, **kwargs: Any) -> ToolGenerator:
"""Stream the Python function with the given tool use request.

Args:
Expand All @@ -210,7 +211,9 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
Tool events with the last being the tool result.
"""
if inspect.iscoroutinefunction(self._tool_func):
result = await self._tool_func(tool_use, **invocation_state)
result = await self._tool_func(
tool_use, **invocation_state
) # this will fail if invocation state and kwargs overlap
else:
result = await asyncio.to_thread(self._tool_func, tool_use, **invocation_state)

Expand Down
62 changes: 62 additions & 0 deletions src/strands/types/invocation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Types for agent invocation state and context."""

from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict
from uuid import UUID

from opentelemetry.trace import Span

if TYPE_CHECKING:
from ..agent import Agent
from ..models.model import Model
from ..telemetry import Trace
from ..types.content import Message
from ..types.tools import ToolConfig


class InvocationState(TypedDict, total=False):
"""Type definition for invocation_state used throughout the agent framework.

This TypedDict defines the structure of the invocation_state dictionary that is
passed through the agent's event loop and tool execution pipeline. All fields
are optional since invocation_state is built incrementally during execution.

Core Framework Fields:
agent: The Agent instance executing the invocation (added for backward compatibility).
event_loop_cycle_id: Unique identifier for the current event loop cycle.
event_loop_parent_cycle_id: id of the parent cycle for recursive calls.
request_state: State dictionary maintained across event loop cycles.
event_loop_cycle_trace: Trace object for monitoring the current cycle.
event_loop_cycle_span: Span object for distributed tracing.
event_loop_parent_span: Parent span for tracing hierarchy.

Agent Context Fields:
model: The model instance used by the agent for inference.
system_prompt: The system prompt used to guide the agent's behavior.
messages: The conversation history as a list of messages.
tool_config: Configuration for tools available to the agent.

Additional Fields:
Any additional keyword arguments passed during agent invocation or added
by hooks and tools during execution are also included in this state.
"""

# Core agent reference
agent: "Agent" # Forward reference to avoid circular imports

# Event loop cycle management
event_loop_cycle_id: UUID
event_loop_parent_cycle_id: UUID

# State management
request_state: Dict[str, Any]

# Tracing and monitoring
event_loop_cycle_trace: "Trace" # "Trace" # Trace object type varies by implementation
event_loop_cycle_span: Span | None # Span object type varies by implementation
event_loop_parent_span: Span | None # Parent span for tracing hierarchy

# Agent context fields
model: "Model" # The model instance used by the agent
system_prompt: Optional[str] # The system prompt for the agent
messages: List["Message"] # The conversation history
tool_config: "ToolConfig" # Configuration for available tools
18 changes: 17 additions & 1 deletion src/strands/types/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from typing_extensions import TypedDict

from .invocation import InvocationState
from .media import DocumentContent, ImageContent

JSONSchema = dict
Expand Down Expand Up @@ -117,6 +118,21 @@ class ToolChoiceTool(TypedDict):
name: str


class StrandsContext(TypedDict, total=False):
"""Context object containing framework-provided data for decorated tools.

This object provides access to framework-level information that may be useful
for tool implementations. All fields are optional to maintain backward compatibility.

Attributes:
tool_use: The complete ToolUse object containing tool invocation details.
invocation_state: Context for the tool invocation, including agent state.
"""

tool_use: ToolUse
invocation_state: InvocationState


ToolChoice = Union[
dict[Literal["auto"], ToolChoiceAuto],
dict[Literal["any"], ToolChoiceAny],
Expand Down Expand Up @@ -216,7 +232,7 @@ def supports_hot_reload(self) -> bool:

@abstractmethod
# pragma: no cover
def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator:
def stream(self, tool_use: ToolUse, invocation_state: InvocationState, **kwargs: Any) -> ToolGenerator:
"""Stream tool events and return the final result.

Args:
Expand Down
Loading
Loading