Skip to content

Commit 784a91f

Browse files
committed
fix: update tests to pass agent in to direct tool call
1 parent 26c081d commit 784a91f

File tree

2 files changed

+31
-23
lines changed

2 files changed

+31
-23
lines changed

src/strands/tools/decorator.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -143,22 +143,6 @@ def _create_input_model(self) -> Type[BaseModel]:
143143
# Handle case with no parameters
144144
return create_model(model_name)
145145

146-
def _is_special_parameter(self, param_name: str) -> bool:
147-
"""Check if a parameter should be automatically injected by the framework.
148-
149-
Special parameters include:
150-
- Standard Python parameters: self, cls
151-
- Framework-provided context parameters: agent, strands_context
152-
153-
Args:
154-
param_name: The name of the parameter to check.
155-
156-
Returns:
157-
True if the parameter should be excluded from input validation and
158-
automatically injected during tool execution.
159-
"""
160-
return param_name in {"self", "cls", "agent", "strands_context"}
161-
162146
def extract_metadata(self) -> ToolSpec:
163147
"""Extract metadata from the function to create a tool specification.
164148
@@ -293,6 +277,22 @@ def inject_special_parameters(
293277
if "agent" in self.signature.parameters and "agent" in invocation_state:
294278
validated_input["agent"] = invocation_state["agent"]
295279

280+
def _is_special_parameter(self, param_name: str) -> bool:
281+
"""Check if a parameter should be automatically injected by the framework.
282+
283+
Special parameters include:
284+
- Standard Python parameters: self, cls
285+
- Framework-provided context parameters: agent, strands_context
286+
287+
Args:
288+
param_name: The name of the parameter to check.
289+
290+
Returns:
291+
True if the parameter should be excluded from input validation and
292+
automatically injected during tool execution.
293+
"""
294+
return param_name in {"self", "cls", "agent", "strands_context"}
295+
296296

297297
P = ParamSpec("P") # Captures all parameters
298298
R = TypeVar("R") # Return type

tests_integ/test_strands_context_integration.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
def tool_with_context(message: str, strands_context: StrandsContext) -> dict:
1616
"""Tool that uses StrandsContext to access tool_use_id."""
1717
tool_use_id = strands_context["tool_use"]["toolUseId"]
18-
return {"status": "success", "content": [{"text": f"Context tool processed '{message}' with ID: {tool_use_id}"}]}
18+
return {
19+
"status": "success",
20+
"content": [{"text": f"Context tool processed '{message}' with ID: {tool_use_id}"}],
21+
}
1922

2023

2124
@tool
@@ -36,9 +39,14 @@ def test_strands_context_integration():
3639
agent = Agent(tools=[tool_with_context, tool_with_agent_and_context])
3740

3841
# Test tool with StrandsContext
39-
result1 = agent.tool.tool_with_context(message="hello world")
40-
assert result1.get("status") == "success"
41-
42-
# Test tool with both agent and StrandsContext
43-
result = agent.tool.tool_with_agent_and_context(message="hello agent")
44-
assert result.get("status") == "success"
42+
result_with_context = agent.tool.tool_with_context(message="hello world")
43+
assert (
44+
"Context tool processed 'hello world' with ID: tooluse_tool_with_context_"
45+
in result_with_context["content"][0]["text"]
46+
)
47+
48+
result_with_agent_and_context = agent.tool.tool_with_agent_and_context(message="hello agent", agent=agent)
49+
assert (
50+
"Agent 'Strands Agents' processed 'hello agent' with ID: tooluse_tool_with_agent_and_context_"
51+
in result_with_agent_and_context["content"][0]["text"]
52+
)

0 commit comments

Comments
 (0)