Skip to content

Commit 3e6d530

Browse files
committed
fix(tool/decorator): validate ToolContext parameter name to avoid opaque Pydantic error
1 parent 7fbc9dc commit 3e6d530

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

src/strands/tools/decorator.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def my_tool(param1: str, param2: int = 42) -> dict:
5454
TypeVar,
5555
Union,
5656
cast,
57+
get_args,
58+
get_origin,
5759
get_type_hints,
5860
overload,
5961
)
@@ -103,6 +105,35 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) -
103105
doc_str = inspect.getdoc(func) or ""
104106
self.doc = docstring_parser.parse(doc_str)
105107

108+
def _contains_tool_context(tp: Any) -> bool:
109+
"""Return True if the annotation `tp` (possibly Union/Optional) includes ToolContext."""
110+
if tp is None:
111+
return False
112+
origin = get_origin(tp)
113+
if origin is Union:
114+
return any(_contains_tool_context(a) for a in get_args(tp))
115+
# Handle direct ToolContext type
116+
return tp is ToolContext
117+
118+
for param in self.signature.parameters.values():
119+
# Prefer resolved type hints (handles forward refs); fall back to annotation
120+
ann = self.type_hints.get(param.name, param.annotation)
121+
if ann is inspect._empty:
122+
continue
123+
124+
if _contains_tool_context(ann):
125+
# If decorator didn't opt-in to context injection, complain
126+
if self._context_param is None:
127+
raise TypeError(
128+
f"Parameter '{param.name}' is of type 'ToolContext' but '@tool(context=True)' is missing."
129+
)
130+
# If decorator specified a different param name, complain
131+
if param.name != self._context_param:
132+
raise TypeError(
133+
f"Parameter '{param.name}' is of type 'ToolContext' but has the wrong name. "
134+
f"It should be named '{self._context_param}'."
135+
)
136+
106137
# Get parameter descriptions from parsed docstring
107138
self.param_descriptions = {
108139
param.arg_name: param.description or f"Parameter {param.arg_name}" for param in self.doc.params

tests/strands/tools/test_decorator.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1363,3 +1363,44 @@ async def async_generator() -> AsyncGenerator:
13631363
]
13641364

13651365
assert act_results == exp_results
1366+
1367+
1368+
def test_tool_with_mismatched_tool_context_param_name_raises_error():
1369+
"""Verify that a TypeError is raised for a mismatched tool_context parameter name."""
1370+
with pytest.raises(TypeError) as excinfo:
1371+
1372+
@strands.tool(context=True)
1373+
def my_tool(context: ToolContext):
1374+
pass
1375+
1376+
assert (
1377+
"Parameter 'context' is of type 'ToolContext' but has the wrong name. It should be named 'tool_context'."
1378+
in str(excinfo.value)
1379+
)
1380+
1381+
1382+
def test_tool_with_tool_context_but_no_context_flag_raises_error():
1383+
"""Verify that a TypeError is raised if ToolContext is used without context=True."""
1384+
with pytest.raises(TypeError) as excinfo:
1385+
1386+
@strands.tool
1387+
def my_tool(tool_context: ToolContext):
1388+
pass
1389+
1390+
assert "Parameter 'tool_context' is of type 'ToolContext' but '@tool(context=True)' is missing." in str(
1391+
excinfo.value
1392+
)
1393+
1394+
1395+
def test_tool_with_tool_context_named_custom_context_raises_error_if_mismatched():
1396+
"""Verify that a TypeError is raised when context param name doesn't match the decorator value."""
1397+
with pytest.raises(TypeError) as excinfo:
1398+
1399+
@strands.tool(context="my_context")
1400+
def my_tool(tool_context: ToolContext):
1401+
pass
1402+
1403+
assert (
1404+
"Parameter 'tool_context' is of type 'ToolContext' but has the wrong name. It should be named 'my_context'."
1405+
in str(excinfo.value)
1406+
)

0 commit comments

Comments
 (0)