diff --git a/src/strands/__init__.py b/src/strands/__init__.py index e9f9e9cd8..ae784a58f 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -3,5 +3,6 @@ from . import agent, models, telemetry, types from .agent.agent import Agent from .tools.decorator import tool +from .types.tools import ToolContext -__all__ = ["Agent", "agent", "models", "tool", "types", "telemetry"] +__all__ = ["Agent", "agent", "models", "tool", "types", "telemetry", "ToolContext"] diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 5ec324b68..75abac9ed 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -61,7 +61,7 @@ def my_tool(param1: str, param2: int = 42) -> dict: from pydantic import BaseModel, Field, create_model from typing_extensions import override -from ..types.tools import AgentTool, JSONSchema, ToolGenerator, ToolSpec, ToolUse +from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolSpec, ToolUse logger = logging.getLogger(__name__) @@ -84,16 +84,18 @@ class FunctionToolMetadata: validate tool usage. """ - def __init__(self, func: Callable[..., Any]) -> None: + def __init__(self, func: Callable[..., Any], context_param: str | None = None) -> None: """Initialize with the function to process. Args: func: The function to extract metadata from. Can be a standalone function or a class method. + context_param: Name of the context parameter to inject, if any. """ self.func = func self.signature = inspect.signature(func) self.type_hints = get_type_hints(func) + self._context_param = context_param # Parse the docstring with docstring_parser doc_str = inspect.getdoc(func) or "" @@ -113,7 +115,7 @@ def _create_input_model(self) -> Type[BaseModel]: This method analyzes the function's signature, type hints, and docstring to create a Pydantic model that can validate input data before passing it to the function. - Special parameters like 'self', 'cls', and 'agent' are excluded from the model. + Special parameters that can be automatically injected are excluded from the model. Returns: A Pydantic BaseModel class customized for the function's parameters. @@ -121,8 +123,8 @@ def _create_input_model(self) -> Type[BaseModel]: field_definitions: dict[str, Any] = {} for name, param in self.signature.parameters.items(): - # Skip special parameters - if name in ("self", "cls", "agent"): + # Skip parameters that will be automatically injected + if self._is_special_parameter(name): continue # Get parameter type and default @@ -252,6 +254,49 @@ def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]: error_msg = str(e) raise ValueError(f"Validation failed for input parameters: {error_msg}") from e + def inject_special_parameters( + self, validated_input: dict[str, Any], tool_use: ToolUse, invocation_state: dict[str, Any] + ) -> None: + """Inject special framework-provided parameters into the validated input. + + This method automatically provides framework-level context to tools that request it + through their function signature. + + Args: + validated_input: The validated input parameters (modified in place). + tool_use: The tool use request containing tool invocation details. + invocation_state: Context for the tool invocation, including agent state. + """ + if self._context_param and self._context_param in self.signature.parameters: + tool_context = ToolContext(tool_use=tool_use, agent=invocation_state["agent"]) + validated_input[self._context_param] = tool_context + + # Inject agent if requested (backward compatibility) + if "agent" in self.signature.parameters and "agent" in invocation_state: + validated_input["agent"] = invocation_state["agent"] + + def _is_special_parameter(self, param_name: str) -> bool: + """Check if a parameter should be automatically injected by the framework or is a standard Python method param. + + Special parameters include: + - Standard Python method parameters: self, cls + - Framework-provided context parameters: agent, and configurable context parameter (defaults to tool_context) + + Args: + param_name: The name of the parameter to check. + + Returns: + True if the parameter should be excluded from input validation and + handled specially during tool execution. + """ + special_params = {"self", "cls", "agent"} + + # Add context parameter if configured + if self._context_param: + special_params.add(self._context_param) + + return param_name in special_params + P = ParamSpec("P") # Captures all parameters R = TypeVar("R") # Return type @@ -402,9 +447,8 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw # Validate input against the Pydantic model validated_input = self._metadata.validate_input(tool_input) - # Pass along the agent if provided and expected by the function - if "agent" in invocation_state and "agent" in self._metadata.signature.parameters: - validated_input["agent"] = invocation_state.get("agent") + # Inject special framework-provided parameters + self._metadata.inject_special_parameters(validated_input, tool_use, invocation_state) # "Too few arguments" expected, hence the type ignore if inspect.iscoroutinefunction(self._tool_func): @@ -474,6 +518,7 @@ def tool( description: Optional[str] = None, inputSchema: Optional[JSONSchema] = None, name: Optional[str] = None, + context: bool | str = False, ) -> Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: ... # Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the # call site, but the actual implementation handles that and it's not representable via the type-system @@ -482,6 +527,7 @@ def tool( # type: ignore description: Optional[str] = None, inputSchema: Optional[JSONSchema] = None, name: Optional[str] = None, + context: bool | str = False, ) -> Union[DecoratedFunctionTool[P, R], Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]]: """Decorator that transforms a Python function into a Strands tool. @@ -507,6 +553,9 @@ def tool( # type: ignore description: Optional custom description to override the function's docstring. inputSchema: Optional custom JSON schema to override the automatically generated schema. name: Optional custom name to override the function's name. + context: When provided, places an object in the designated parameter. If True, the param name + defaults to 'tool_context', or if an override is needed, set context equal to a string to designate + the param name. Returns: An AgentTool that also mimics the original function when invoked @@ -536,15 +585,24 @@ def my_tool(name: str, count: int = 1) -> str: Example with parameters: ```python - @tool(name="custom_tool", description="A tool with a custom name and description") - def my_tool(name: str, count: int = 1) -> str: - return f"Processed {name} {count} times" + @tool(name="custom_tool", description="A tool with a custom name and description", context=True) + def my_tool(name: str, count: int = 1, tool_context: ToolContext) -> str: + tool_id = tool_context["tool_use"]["toolUseId"] + return f"Processed {name} {count} times with tool ID {tool_id}" ``` """ def decorator(f: T) -> "DecoratedFunctionTool[P, R]": + # Resolve context parameter name + if isinstance(context, bool): + context_param = "tool_context" if context else None + else: + context_param = context.strip() + if not context_param: + raise ValueError("Context parameter name cannot be empty") + # Create function tool metadata - tool_meta = FunctionToolMetadata(f) + tool_meta = FunctionToolMetadata(f, context_param) tool_spec = tool_meta.extract_metadata() if name is not None: tool_spec["name"] = name diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index 533e5529c..bb7c874f6 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -6,12 +6,16 @@ """ from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union from typing_extensions import TypedDict from .media import DocumentContent, ImageContent +if TYPE_CHECKING: + from .. import Agent + JSONSchema = dict """Type alias for JSON Schema dictionaries.""" @@ -117,6 +121,27 @@ class ToolChoiceTool(TypedDict): name: str +@dataclass +class ToolContext: + """Context object containing framework-provided data for decorated tools. + + This object provides access to framework-level information that may be useful + for tool implementations. + + Attributes: + tool_use: The complete ToolUse object containing tool invocation details. + agent: The Agent instance executing this tool, providing access to conversation history, + model configuration, and other agent state. + + Note: + This class is intended to be instantiated by the SDK. Direct construction by users + is not supported and may break in future versions as new fields are added. + """ + + tool_use: ToolUse + agent: "Agent" + + ToolChoice = Union[ dict[Literal["auto"], ToolChoiceAuto], dict[Literal["any"], ToolChoiceAny], diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 52a9282e0..246879da7 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -8,7 +8,8 @@ import pytest import strands -from strands.types.tools import ToolUse +from strands import Agent +from strands.types.tools import AgentTool, ToolContext, ToolUse @pytest.fixture(scope="module") @@ -1036,3 +1037,159 @@ def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None] result = (await alist(stream))[-1] assert result["status"] == "success" assert "NoneType: None" in result["content"][0]["text"] + + +async def _run_context_injection_test(context_tool: AgentTool): + """Common test logic for context injection tests.""" + tool: AgentTool = context_tool + generator = tool.stream( + tool_use={ + "toolUseId": "test-id", + "name": "context_tool", + "input": { + "message": "some_message" # note that we do not include agent nor tool context + }, + }, + invocation_state={ + "agent": Agent(name="test_agent"), + }, + ) + tool_results = [value async for value in generator] + + assert len(tool_results) == 1 + tool_result = tool_results[0] + + assert tool_result == { + "status": "success", + "content": [ + {"text": "Tool 'context_tool' (ID: test-id)"}, + {"text": "injected agent 'test_agent' processed: some_message"}, + {"text": "context agent 'test_agent'"} + ], + "toolUseId": "test-id", + } + + +@pytest.mark.asyncio +async def test_tool_context_injection_default(): + """Test that ToolContext is properly injected with default parameter name (tool_context).""" + + @strands.tool(context=True) + def context_tool(message: str, agent: Agent, tool_context: ToolContext) -> dict: + """Tool that uses ToolContext to access tool_use_id.""" + tool_use_id = tool_context.tool_use["toolUseId"] + tool_name = tool_context.tool_use["name"] + agent_from_tool_context = tool_context.agent + + return { + "status": "success", + "content": [ + {"text": f"Tool '{tool_name}' (ID: {tool_use_id})"}, + {"text": f"injected agent '{agent.name}' processed: {message}"}, + {"text": f"context agent '{agent_from_tool_context.name}'"}, + ], + } + + await _run_context_injection_test(context_tool) + + +@pytest.mark.asyncio +async def test_tool_context_injection_custom_name(): + """Test that ToolContext is properly injected with custom parameter name.""" + + @strands.tool(context="custom_context_name") + def context_tool(message: str, agent: Agent, custom_context_name: ToolContext) -> dict: + """Tool that uses ToolContext to access tool_use_id.""" + tool_use_id = custom_context_name.tool_use["toolUseId"] + tool_name = custom_context_name.tool_use["name"] + agent_from_tool_context = custom_context_name.agent + + return { + "status": "success", + "content": [ + {"text": f"Tool '{tool_name}' (ID: {tool_use_id})"}, + {"text": f"injected agent '{agent.name}' processed: {message}"}, + {"text": f"context agent '{agent_from_tool_context.name}'"}, + ], + } + + await _run_context_injection_test(context_tool) + + +@pytest.mark.asyncio +async def test_tool_context_injection_disabled_missing_parameter(): + """Test that when context=False, missing tool_context parameter causes validation error.""" + + @strands.tool(context=False) + def context_tool(message: str, agent: Agent, tool_context: str) -> dict: + """Tool that expects tool_context as a regular string parameter.""" + return { + "status": "success", + "content": [ + {"text": f"Message: {message}"}, + {"text": f"Agent: {agent.name}"}, + {"text": f"Tool context string: {tool_context}"}, + ], + } + + # Verify that missing tool_context parameter causes validation error + tool: AgentTool = context_tool + generator = tool.stream( + tool_use={ + "toolUseId": "test-id", + "name": "context_tool", + "input": { + "message": "some_message" + # Missing tool_context parameter - should cause validation error instead of being auto injected + }, + }, + invocation_state={ + "agent": Agent(name="test_agent"), + }, + ) + tool_results = [value async for value in generator] + + assert len(tool_results) == 1 + tool_result = tool_results[0] + + # Should get a validation error because tool_context is required but not provided + assert tool_result["status"] == "error" + assert "tool_context" in tool_result["content"][0]["text"].lower() + assert "validation" in tool_result["content"][0]["text"].lower() + + +@pytest.mark.asyncio +async def test_tool_context_injection_disabled_string_parameter(): + """Test that when context=False, tool_context can be passed as a string parameter.""" + + @strands.tool(context=False) + def context_tool(message: str, agent: Agent, tool_context: str) -> str: + """Tool that expects tool_context as a regular string parameter.""" + return "success" + + # Verify that providing tool_context as a string works correctly + tool: AgentTool = context_tool + generator = tool.stream( + tool_use={ + "toolUseId": "test-id-2", + "name": "context_tool", + "input": { + "message": "some_message", + "tool_context": "my_custom_context_string" + }, + }, + invocation_state={ + "agent": Agent(name="test_agent"), + }, + ) + tool_results = [value async for value in generator] + + assert len(tool_results) == 1 + tool_result = tool_results[0] + + # Should succeed with the string parameter + assert tool_result == { + "status": "success", + "content": [{"text": "success"}], + "toolUseId": "test-id-2", + } diff --git a/tests_integ/test_tool_context_injection.py b/tests_integ/test_tool_context_injection.py new file mode 100644 index 000000000..3098604f1 --- /dev/null +++ b/tests_integ/test_tool_context_injection.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +""" +Integration test for ToolContext functionality with real agent interactions. +""" + +from strands import Agent, ToolContext, tool +from strands.types.tools import ToolResult + + +@tool(context="custom_context_field") +def good_story(message: str, custom_context_field: ToolContext) -> dict: + """Tool that writes a good story""" + tool_use_id = custom_context_field.tool_use["toolUseId"] + return { + "status": "success", + "content": [{"text": f"Context tool processed with ID: {tool_use_id}"}], + } + + +@tool(context=True) +def bad_story(message: str, tool_context: ToolContext) -> dict: + """Tool that writes a bad story""" + tool_use_id = tool_context.tool_use["toolUseId"] + return { + "status": "success", + "content": [{"text": f"Context tool processed with ID: {tool_use_id}"}], + } + + +def _validate_tool_result_content(agent: Agent): + first_tool_result: ToolResult = [ + block["toolResult"] for message in agent.messages for block in message["content"] if "toolResult" in block + ][0] + + assert first_tool_result["status"] == "success" + assert ( + first_tool_result["content"][0]["text"] == f"Context tool processed with ID: {first_tool_result['toolUseId']}" + ) + + +def test_strands_context_integration_context_true(): + """Test ToolContext functionality with real agent interactions.""" + + agent = Agent(tools=[good_story]) + agent("using a tool, write a good story") + + _validate_tool_result_content(agent) + + +def test_strands_context_integration_context_custom(): + """Test ToolContext functionality with real agent interactions.""" + + agent = Agent(tools=[bad_story]) + agent("using a tool, write a bad story") + + _validate_tool_result_content(agent)