Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/strands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
from . import agent, models, telemetry, types
from .agent.agent import Agent
from .tools.decorator import tool
from .types.tools import ToolContext

__all__ = ["Agent", "agent", "models", "tool", "types", "telemetry"]
__all__ = ["Agent", "agent", "models", "tool", "types", "telemetry", "ToolContext"]
82 changes: 70 additions & 12 deletions src/strands/tools/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
from pydantic import BaseModel, Field, create_model
from typing_extensions import override

from ..types.tools import AgentTool, JSONSchema, ToolGenerator, ToolSpec, ToolUse
from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolSpec, ToolUse

logger = logging.getLogger(__name__)

Expand All @@ -84,16 +84,18 @@ class FunctionToolMetadata:
validate tool usage.
"""

def __init__(self, func: Callable[..., Any]) -> None:
def __init__(self, func: Callable[..., Any], context_param: str | None = None) -> None:
"""Initialize with the function to process.

Args:
func: The function to extract metadata from.
Can be a standalone function or a class method.
context_param: Name of the context parameter to inject, if any.
"""
self.func = func
self.signature = inspect.signature(func)
self.type_hints = get_type_hints(func)
self._context_param = context_param

# Parse the docstring with docstring_parser
doc_str = inspect.getdoc(func) or ""
Expand All @@ -113,16 +115,16 @@ def _create_input_model(self) -> Type[BaseModel]:
This method analyzes the function's signature, type hints, and docstring to create a Pydantic model that can
validate input data before passing it to the function.

Special parameters like 'self', 'cls', and 'agent' are excluded from the model.
Special parameters that can be automatically injected are excluded from the model.

Returns:
A Pydantic BaseModel class customized for the function's parameters.
"""
field_definitions: dict[str, Any] = {}

for name, param in self.signature.parameters.items():
# Skip special parameters
if name in ("self", "cls", "agent"):
# Skip parameters that will be automatically injected
if self._is_special_parameter(name):
continue

# Get parameter type and default
Expand Down Expand Up @@ -252,6 +254,49 @@ def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]:
error_msg = str(e)
raise ValueError(f"Validation failed for input parameters: {error_msg}") from e

def inject_special_parameters(
self, validated_input: dict[str, Any], tool_use: ToolUse, invocation_state: dict[str, Any]
) -> None:
"""Inject special framework-provided parameters into the validated input.

This method automatically provides framework-level context to tools that request it
through their function signature.

Args:
validated_input: The validated input parameters (modified in place).
tool_use: The tool use request containing tool invocation details.
invocation_state: Context for the tool invocation, including agent state.
"""
if self._context_param and self._context_param in self.signature.parameters:
tool_context = ToolContext(tool_use=tool_use, agent=invocation_state["agent"])
validated_input[self._context_param] = tool_context

# Inject agent if requested (backward compatibility)
if "agent" in self.signature.parameters and "agent" in invocation_state:
validated_input["agent"] = invocation_state["agent"]

def _is_special_parameter(self, param_name: str) -> bool:
"""Check if a parameter should be automatically injected by the framework or is a standard Python method param.

Special parameters include:
- Standard Python method parameters: self, cls
- Framework-provided context parameters: agent, and configurable context parameter (defaults to tool_context)

Args:
param_name: The name of the parameter to check.

Returns:
True if the parameter should be excluded from input validation and
handled specially during tool execution.
"""
special_params = {"self", "cls", "agent"}

# Add context parameter if configured
if self._context_param:
special_params.add(self._context_param)

return param_name in special_params


P = ParamSpec("P") # Captures all parameters
R = TypeVar("R") # Return type
Expand Down Expand Up @@ -402,9 +447,8 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
# Validate input against the Pydantic model
validated_input = self._metadata.validate_input(tool_input)

# Pass along the agent if provided and expected by the function
if "agent" in invocation_state and "agent" in self._metadata.signature.parameters:
validated_input["agent"] = invocation_state.get("agent")
# Inject special framework-provided parameters
self._metadata.inject_special_parameters(validated_input, tool_use, invocation_state)

# "Too few arguments" expected, hence the type ignore
if inspect.iscoroutinefunction(self._tool_func):
Expand Down Expand Up @@ -474,6 +518,7 @@ def tool(
description: Optional[str] = None,
inputSchema: Optional[JSONSchema] = None,
name: Optional[str] = None,
context: bool | str = False,
) -> Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: ...
# Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the
# call site, but the actual implementation handles that and it's not representable via the type-system
Expand All @@ -482,6 +527,7 @@ def tool( # type: ignore
description: Optional[str] = None,
inputSchema: Optional[JSONSchema] = None,
name: Optional[str] = None,
context: bool | str = False,
) -> Union[DecoratedFunctionTool[P, R], Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]]:
"""Decorator that transforms a Python function into a Strands tool.

Expand All @@ -507,6 +553,9 @@ def tool( # type: ignore
description: Optional custom description to override the function's docstring.
inputSchema: Optional custom JSON schema to override the automatically generated schema.
name: Optional custom name to override the function's name.
context: When provided, places a object in the designated parameter. If True, the param name
defaults to 'tool_context', or if an override is needed, set context equal to a string to designate
the param name.

Returns:
An AgentTool that also mimics the original function when invoked
Expand Down Expand Up @@ -536,15 +585,24 @@ def my_tool(name: str, count: int = 1) -> str:

Example with parameters:
```python
@tool(name="custom_tool", description="A tool with a custom name and description")
def my_tool(name: str, count: int = 1) -> str:
return f"Processed {name} {count} times"
@tool(name="custom_tool", description="A tool with a custom name and description", context=True)
def my_tool(name: str, count: int = 1, tool_context: ) -> str:
tool_id = tool_context["tool_use"]["toolUseId"]
return f"Processed {name} {count} times with tool ID {tool_id}"
```
"""

def decorator(f: T) -> "DecoratedFunctionTool[P, R]":
# Resolve context parameter name
if isinstance(context, bool):
context_param = "tool_context" if context else None
else:
context_param = context.strip()
if not context_param:
raise ValueError("Context parameter name cannot be empty")

# Create function tool metadata
tool_meta = FunctionToolMetadata(f)
tool_meta = FunctionToolMetadata(f, context_param)
tool_spec = tool_meta.extract_metadata()
if name is not None:
tool_spec["name"] = name
Expand Down
23 changes: 22 additions & 1 deletion src/strands/types/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@
"""

from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union

from typing_extensions import TypedDict

from .media import DocumentContent, ImageContent

if TYPE_CHECKING:
from .. import Agent

JSONSchema = dict
"""Type alias for JSON Schema dictionaries."""

Expand Down Expand Up @@ -117,6 +121,23 @@ class ToolChoiceTool(TypedDict):
name: str


@dataclass
class ToolContext:
"""Context object containing framework-provided data for decorated tools.

This object provides access to framework-level information that may be useful
for tool implementations.

Attributes:
tool_use: The complete ToolUse object containing tool invocation details.
agent: The Agent instance executing this tool, providing access to conversation history,
model configuration, and other agent state.
"""

tool_use: ToolUse
agent: "Agent"


ToolChoice = Union[
dict[Literal["auto"], ToolChoiceAuto],
dict[Literal["any"], ToolChoiceAny],
Expand Down
48 changes: 47 additions & 1 deletion tests/strands/tools/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import pytest

import strands
from strands.types.tools import ToolUse
from strands import Agent
from strands.types.tools import AgentTool, ToolContext, ToolUse


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -1036,3 +1037,48 @@ def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None]
result = (await alist(stream))[-1]
assert result["status"] == "success"
assert "NoneType: None" in result["content"][0]["text"]


@pytest.mark.asyncio
async def test_tool_context_injection(alist):
"""Test that ToolContext is properly injected into tools that request it."""

@strands.tool(context=True)
def context_tool(message: str, agent: Agent, tool_context: ToolContext) -> dict:
"""Tool that uses ToolContext to access tool_use_id."""
tool_use_id = tool_context.tool_use["toolUseId"]
tool_name = tool_context.tool_use["name"]
agent_from_tool_context = tool_context.agent

return {
"status": "success",
"content": [
{"text": f"Tool '{tool_name}' (ID: {tool_use_id})"},
{"text": f"injected agent '{agent.name}' processed: {message}"},
{"text": f"context agent '{agent_from_tool_context.name}'"},
],
}

tool: AgentTool = context_tool
generator = tool.stream(
tool_use={
"toolUseId": "test-id",
"name": "context_tool",
"input": {
"message": "some_message" # note that we do not include agent nor tool context
},
},
invocation_state={
"agent": Agent(name="test_agent"),
},
)
tool_results = [value async for value in generator]

assert len(tool_results) == 1
tool_result = tool_results[0]

assert tool_result["status"] == "success"
assert tool_result["toolUseId"] == "test-id"
assert tool_result["content"][0]["text"] == "Tool 'context_tool' (ID: test-id)"
assert tool_result["content"][1]["text"] == "injected agent 'test_agent' processed: some_message"
assert tool_result["content"][2]["text"] == "context agent 'test_agent'"
56 changes: 56 additions & 0 deletions tests_integ/test_tool_context_injection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#!/usr/bin/env python3
"""
Integration test for ToolContext functionality with real agent interactions.
"""

from strands import Agent, ToolContext, tool
from strands.types.tools import ToolResult


@tool(context="custom_context_field")
def good_story(message: str, custom_context_field: ToolContext) -> dict:
"""Tool that writes a good story"""
tool_use_id = custom_context_field.tool_use["toolUseId"]
return {
"status": "success",
"content": [{"text": f"Context tool processed with ID: {tool_use_id}"}],
}


@tool(context=True)
def bad_story(message: str, tool_context: ToolContext) -> dict:
"""Tool that writes a bad story"""
tool_use_id = tool_context.tool_use["toolUseId"]
return {
"status": "success",
"content": [{"text": f"Context tool processed with ID: {tool_use_id}"}],
}


def _validate_tool_result_content(agent: Agent):
first_tool_result: ToolResult = [
block["toolResult"] for message in agent.messages for block in message["content"] if "toolResult" in block
][0]

assert first_tool_result["status"] == "success"
assert (
first_tool_result["content"][0]["text"] == f"Context tool processed with ID: {first_tool_result['toolUseId']}"
)


def test_strands_context_integration_context_true():
"""Test ToolContext functionality with real agent interactions."""

agent = Agent(tools=[good_story])
agent("using a tool, write a good story")

_validate_tool_result_content(agent)


def test_strands_context_integration_context_custom():
"""Test ToolContext functionality with real agent interactions."""

agent = Agent(tools=[bad_story])
agent("using a tool, write a bad story")

_validate_tool_result_content(agent)
Loading