Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/strands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
from . import agent, models, telemetry, types
from .agent.agent import Agent
from .tools.decorator import tool
from .types.tools import ToolContext

__all__ = ["Agent", "agent", "models", "tool", "types", "telemetry"]
__all__ = ["Agent", "agent", "models", "tool", "types", "telemetry", "ToolContext"]
82 changes: 70 additions & 12 deletions src/strands/tools/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
from pydantic import BaseModel, Field, create_model
from typing_extensions import override

from ..types.tools import AgentTool, JSONSchema, ToolGenerator, ToolSpec, ToolUse
from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolSpec, ToolUse

logger = logging.getLogger(__name__)

Expand All @@ -84,16 +84,18 @@ class FunctionToolMetadata:
validate tool usage.
"""

def __init__(self, func: Callable[..., Any]) -> None:
def __init__(self, func: Callable[..., Any], context_param: str | None = None) -> None:
"""Initialize with the function to process.

Args:
func: The function to extract metadata from.
Can be a standalone function or a class method.
context_param: Name of the context parameter to inject, if any.
"""
self.func = func
self.signature = inspect.signature(func)
self.type_hints = get_type_hints(func)
self._context_param = context_param

# Parse the docstring with docstring_parser
doc_str = inspect.getdoc(func) or ""
Expand All @@ -113,16 +115,16 @@ def _create_input_model(self) -> Type[BaseModel]:
This method analyzes the function's signature, type hints, and docstring to create a Pydantic model that can
validate input data before passing it to the function.

Special parameters like 'self', 'cls', and 'agent' are excluded from the model.
Special parameters that can be automatically injected are excluded from the model.

Returns:
A Pydantic BaseModel class customized for the function's parameters.
"""
field_definitions: dict[str, Any] = {}

for name, param in self.signature.parameters.items():
# Skip special parameters
if name in ("self", "cls", "agent"):
# Skip parameters that will be automatically injected
if self._is_special_parameter(name):
continue

# Get parameter type and default
Expand Down Expand Up @@ -252,6 +254,49 @@ def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]:
error_msg = str(e)
raise ValueError(f"Validation failed for input parameters: {error_msg}") from e

def inject_special_parameters(
self, validated_input: dict[str, Any], tool_use: ToolUse, invocation_state: dict[str, Any]
) -> None:
"""Inject special framework-provided parameters into the validated input.

This method automatically provides framework-level context to tools that request it
through their function signature.

Args:
validated_input: The validated input parameters (modified in place).
tool_use: The tool use request containing tool invocation details.
invocation_state: Context for the tool invocation, including agent state.
"""
if self._context_param and self._context_param in self.signature.parameters:
tool_context = ToolContext(tool_use=tool_use, agent=invocation_state["agent"])
validated_input[self._context_param] = tool_context

# Inject agent if requested (backward compatibility)
if "agent" in self.signature.parameters and "agent" in invocation_state:
validated_input["agent"] = invocation_state["agent"]

def _is_special_parameter(self, param_name: str) -> bool:
"""Check if a parameter should be automatically injected by the framework or is a standard Python method param.

Special parameters include:
- Standard Python method parameters: self, cls
- Framework-provided context parameters: agent, and configurable context parameter (defaults to tool_context)

Args:
param_name: The name of the parameter to check.

Returns:
True if the parameter should be excluded from input validation and
handled specially during tool execution.
"""
special_params = {"self", "cls", "agent"}

# Add context parameter if configured
if self._context_param:
special_params.add(self._context_param)

return param_name in special_params


P = ParamSpec("P") # Captures all parameters
R = TypeVar("R") # Return type
Expand Down Expand Up @@ -402,9 +447,8 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
# Validate input against the Pydantic model
validated_input = self._metadata.validate_input(tool_input)

# Pass along the agent if provided and expected by the function
if "agent" in invocation_state and "agent" in self._metadata.signature.parameters:
validated_input["agent"] = invocation_state.get("agent")
# Inject special framework-provided parameters
self._metadata.inject_special_parameters(validated_input, tool_use, invocation_state)

# "Too few arguments" expected, hence the type ignore
if inspect.iscoroutinefunction(self._tool_func):
Expand Down Expand Up @@ -474,6 +518,7 @@ def tool(
description: Optional[str] = None,
inputSchema: Optional[JSONSchema] = None,
name: Optional[str] = None,
context: bool | str = False,
) -> Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: ...
# Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the
# call site, but the actual implementation handles that and it's not representable via the type-system
Expand All @@ -482,6 +527,7 @@ def tool( # type: ignore
description: Optional[str] = None,
inputSchema: Optional[JSONSchema] = None,
name: Optional[str] = None,
context: bool | str = False,
) -> Union[DecoratedFunctionTool[P, R], Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]]:
"""Decorator that transforms a Python function into a Strands tool.

Expand All @@ -507,6 +553,9 @@ def tool( # type: ignore
description: Optional custom description to override the function's docstring.
inputSchema: Optional custom JSON schema to override the automatically generated schema.
name: Optional custom name to override the function's name.
context: When provided, places an object in the designated parameter. If True, the param name
defaults to 'tool_context', or if an override is needed, set context equal to a string to designate
the param name.

Returns:
An AgentTool that also mimics the original function when invoked
Expand Down Expand Up @@ -536,15 +585,24 @@ def my_tool(name: str, count: int = 1) -> str:

Example with parameters:
```python
@tool(name="custom_tool", description="A tool with a custom name and description")
def my_tool(name: str, count: int = 1) -> str:
return f"Processed {name} {count} times"
@tool(name="custom_tool", description="A tool with a custom name and description", context=True)
def my_tool(name: str, count: int = 1, tool_context: ToolContext) -> str:
tool_id = tool_context["tool_use"]["toolUseId"]
return f"Processed {name} {count} times with tool ID {tool_id}"
```
"""

def decorator(f: T) -> "DecoratedFunctionTool[P, R]":
# Resolve context parameter name
if isinstance(context, bool):
context_param = "tool_context" if context else None
else:
context_param = context.strip()
if not context_param:
raise ValueError("Context parameter name cannot be empty")

# Create function tool metadata
tool_meta = FunctionToolMetadata(f)
tool_meta = FunctionToolMetadata(f, context_param)
tool_spec = tool_meta.extract_metadata()
if name is not None:
tool_spec["name"] = name
Expand Down
27 changes: 26 additions & 1 deletion src/strands/types/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@
"""

from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union

from typing_extensions import TypedDict

from .media import DocumentContent, ImageContent

if TYPE_CHECKING:
from .. import Agent

JSONSchema = dict
"""Type alias for JSON Schema dictionaries."""

Expand Down Expand Up @@ -117,6 +121,27 @@ class ToolChoiceTool(TypedDict):
name: str


@dataclass
class ToolContext:
"""Context object containing framework-provided data for decorated tools.

This object provides access to framework-level information that may be useful
for tool implementations.

Attributes:
tool_use: The complete ToolUse object containing tool invocation details.
agent: The Agent instance executing this tool, providing access to conversation history,
model configuration, and other agent state.

Note:
This class is intended to be instantiated by the SDK. Direct construction by users
is not supported and may break in future versions as new fields are added.
"""

tool_use: ToolUse
agent: "Agent"


ToolChoice = Union[
dict[Literal["auto"], ToolChoiceAuto],
dict[Literal["any"], ToolChoiceAny],
Expand Down
159 changes: 158 additions & 1 deletion tests/strands/tools/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import pytest

import strands
from strands.types.tools import ToolUse
from strands import Agent
from strands.types.tools import AgentTool, ToolContext, ToolUse


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -1036,3 +1037,159 @@ def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None]
result = (await alist(stream))[-1]
assert result["status"] == "success"
assert "NoneType: None" in result["content"][0]["text"]


async def _run_context_injection_test(context_tool: AgentTool):
"""Common test logic for context injection tests."""
tool: AgentTool = context_tool
generator = tool.stream(
tool_use={
"toolUseId": "test-id",
"name": "context_tool",
"input": {
"message": "some_message" # note that we do not include agent nor tool context
},
},
invocation_state={
"agent": Agent(name="test_agent"),
},
)
tool_results = [value async for value in generator]

assert len(tool_results) == 1
tool_result = tool_results[0]

assert tool_result == {
"status": "success",
"content": [
{"text": "Tool 'context_tool' (ID: test-id)"},
{"text": "injected agent 'test_agent' processed: some_message"},
{"text": "context agent 'test_agent'"}
],
"toolUseId": "test-id",
}


@pytest.mark.asyncio
async def test_tool_context_injection_default():
"""Test that ToolContext is properly injected with default parameter name (tool_context)."""

@strands.tool(context=True)
def context_tool(message: str, agent: Agent, tool_context: ToolContext) -> dict:
"""Tool that uses ToolContext to access tool_use_id."""
tool_use_id = tool_context.tool_use["toolUseId"]
tool_name = tool_context.tool_use["name"]
agent_from_tool_context = tool_context.agent

return {
"status": "success",
"content": [
{"text": f"Tool '{tool_name}' (ID: {tool_use_id})"},
{"text": f"injected agent '{agent.name}' processed: {message}"},
{"text": f"context agent '{agent_from_tool_context.name}'"},
],
}

await _run_context_injection_test(context_tool)


@pytest.mark.asyncio
async def test_tool_context_injection_custom_name():
"""Test that ToolContext is properly injected with custom parameter name."""

@strands.tool(context="custom_context_name")
def context_tool(message: str, agent: Agent, custom_context_name: ToolContext) -> dict:
"""Tool that uses ToolContext to access tool_use_id."""
tool_use_id = custom_context_name.tool_use["toolUseId"]
tool_name = custom_context_name.tool_use["name"]
agent_from_tool_context = custom_context_name.agent

return {
"status": "success",
"content": [
{"text": f"Tool '{tool_name}' (ID: {tool_use_id})"},
{"text": f"injected agent '{agent.name}' processed: {message}"},
{"text": f"context agent '{agent_from_tool_context.name}'"},
],
}

await _run_context_injection_test(context_tool)


@pytest.mark.asyncio
async def test_tool_context_injection_disabled_missing_parameter():
"""Test that when context=False, missing tool_context parameter causes validation error."""

@strands.tool(context=False)
def context_tool(message: str, agent: Agent, tool_context: str) -> dict:
"""Tool that expects tool_context as a regular string parameter."""
return {
"status": "success",
"content": [
{"text": f"Message: {message}"},
{"text": f"Agent: {agent.name}"},
{"text": f"Tool context string: {tool_context}"},
],
}

# Verify that missing tool_context parameter causes validation error
tool: AgentTool = context_tool
generator = tool.stream(
tool_use={
"toolUseId": "test-id",
"name": "context_tool",
"input": {
"message": "some_message"
# Missing tool_context parameter - should cause validation error instead of being auto injected
},
},
invocation_state={
"agent": Agent(name="test_agent"),
},
)
tool_results = [value async for value in generator]

assert len(tool_results) == 1
tool_result = tool_results[0]

# Should get a validation error because tool_context is required but not provided
assert tool_result["status"] == "error"
assert "tool_context" in tool_result["content"][0]["text"].lower()
assert "validation" in tool_result["content"][0]["text"].lower()


@pytest.mark.asyncio
async def test_tool_context_injection_disabled_string_parameter():
"""Test that when context=False, tool_context can be passed as a string parameter."""

@strands.tool(context=False)
def context_tool(message: str, agent: Agent, tool_context: str) -> str:
"""Tool that expects tool_context as a regular string parameter."""
return "success"

# Verify that providing tool_context as a string works correctly
tool: AgentTool = context_tool
generator = tool.stream(
tool_use={
"toolUseId": "test-id-2",
"name": "context_tool",
"input": {
"message": "some_message",
"tool_context": "my_custom_context_string"
},
},
invocation_state={
"agent": Agent(name="test_agent"),
},
)
tool_results = [value async for value in generator]

assert len(tool_results) == 1
tool_result = tool_results[0]

# Should succeed with the string parameter
assert tool_result == {
"status": "success",
"content": [{"text": "success"}],
"toolUseId": "test-id-2",
}
Loading
Loading