Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/strands/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ async def structured_output(
stop_reason, messages, _, _ = event["stop"]

if stop_reason != "tool_use":
raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".")
raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".')

content = messages["content"]
output_response: dict[str, Any] | None = None
Expand Down
2 changes: 1 addition & 1 deletion src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ async def structured_output(
stop_reason, messages, _, _ = event["stop"]

if stop_reason != "tool_use":
raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".")
raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".')

content = messages["content"]
output_response: dict[str, Any] | None = None
Expand Down
14 changes: 13 additions & 1 deletion src/strands/tools/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ def _create_input_model(self) -> Type[BaseModel]:
if name in ("self", "cls", "agent"):
continue

# Skip **kwargs parameter
if param.kind == inspect.Parameter.VAR_KEYWORD:
continue

# Get parameter type and default
param_type = self.type_hints.get(name, Any)
default = ... if param.default is inspect.Parameter.empty else param.default
Expand Down Expand Up @@ -402,10 +406,18 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
# Validate input against the Pydantic model
validated_input = self._metadata.validate_input(tool_input)

# Pass along the agent if provided and expected by the function
# To keep backwards-compatibility
# The prefered option is to pass the agent via invocation state (or **kwargs)
if "agent" in invocation_state and "agent" in self._metadata.signature.parameters:
validated_input["agent"] = invocation_state.get("agent")

# Pass invocation_state contents as kwargs if the function has **kwargs
has_kwargs = any(
param.kind == inspect.Parameter.VAR_KEYWORD for param in self._metadata.signature.parameters.values()
)
if has_kwargs:
validated_input.update(invocation_state)

# "Too few arguments" expected, hence the type ignore
if inspect.iscoroutinefunction(self._tool_func):
result = await self._tool_func(**validated_input) # type: ignore
Expand Down
33 changes: 33 additions & 0 deletions tests/strands/agent/test_agent_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,36 @@ def update_state(agent: Agent):

assert agent.state.get("hello") == "world"
assert agent.state.get("foo") == "baz"


def test_agent_state_update_from_tool_using_keyword_param():
@tool
def update_state(**kwargs):
agent = kwargs.get("agent")
assert agent is not None
assert agent.state.get("foo") == "bar"
agent.state.set("hello", "world")
agent.state.set("foo", "baz")

agent_messages: Messages = [
{
"role": "assistant",
"content": [{"toolUse": {"name": "update_state", "toolUseId": "123", "input": {}}}],
},
{"role": "assistant", "content": [{"text": "I invoked a tool!"}]},
]
mocked_model_provider = MockedModelProvider(agent_messages)

agent = Agent(
model=mocked_model_provider,
tools=[update_state],
state={"foo": "bar"},
)

assert agent.state.get("hello") is None
assert agent.state.get("foo") == "bar"

agent("Invoke Mocked!")

assert agent.state.get("hello") == "world"
assert agent.state.get("foo") == "baz"
222 changes: 221 additions & 1 deletion tests/strands/tools/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ async def test_stream(identity_tool, alist):
@pytest.mark.asyncio
async def test_stream_with_agent(alist):
@strands.tool
def identity(a: int, agent: dict = None):
def identity(a: int, **kwargs):
agent = kwargs.get("agent")
return a, agent

stream = identity.stream({"input": {"a": 2}}, {"agent": {"state": 1}})
Expand Down Expand Up @@ -1036,3 +1037,222 @@ def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None]
result = (await alist(stream))[-1]
assert result["status"] == "success"
assert "NoneType: None" in result["content"][0]["text"]


@pytest.mark.asyncio
async def test_tool_with_kwargs_invocation_state(alist):
"""Test that tools with **kwargs receive invocation_state contents."""

@strands.tool
def kwargs_tool(param: str, **kwargs) -> str:
"""Tool with **kwargs that should receive invocation_state.

Args:
param: Regular parameter
"""
agent = kwargs.get("agent", "no_agent")
custom_key = kwargs.get("custom_key", "no_custom_key")
return f"param: {param}, agent: {agent}, custom_key: {custom_key}"

# Test with invocation_state containing agent and custom keys
invocation_state = {"agent": "test_agent", "custom_key": "test_value", "another_key": "another_value"}

tool_use = {"toolUseId": "test-id", "input": {"param": "test_param"}}
stream = kwargs_tool.stream(tool_use, invocation_state)

result = (await alist(stream))[-1]
assert result["status"] == "success"
assert "param: test_param" in result["content"][0]["text"]
assert "agent: test_agent" in result["content"][0]["text"]
assert "custom_key: test_value" in result["content"][0]["text"]


@pytest.mark.asyncio
async def test_tool_with_different_keyword_param_name(alist):
"""Test that tools with **kwargs receive invocation_state contents."""

@strands.tool
def kwargs_tool(param: str, **context) -> str:
"""Tool with different keyword parameter name that should receive invocation_state.

Args:
param: Regular parameter
"""
agent = context.get("agent", "no_agent")
custom_key = context.get("custom_key", "no_custom_key")
return f"param: {param}, agent: {agent}, custom_key: {custom_key}"

# Test with invocation_state containing agent and custom keys
invocation_state = {"agent": "test_agent", "custom_key": "test_value", "another_key": "another_value"}

tool_use = {"toolUseId": "test-id", "input": {"param": "test_param"}}
stream = kwargs_tool.stream(tool_use, invocation_state)

result = (await alist(stream))[-1]
assert result["status"] == "success"
assert "param: test_param" in result["content"][0]["text"]
assert "agent: test_agent" in result["content"][0]["text"]
assert "custom_key: test_value" in result["content"][0]["text"]


@pytest.mark.asyncio
async def test_tool_without_kwargs_no_invocation_state(alist):
"""Test that tools without **kwargs don't receive invocation_state contents."""

@strands.tool
def no_kwargs_tool(param: str) -> str:
"""Tool without **kwargs.

Args:
param: Regular parameter
"""
return f"param: {param}"

# Test with invocation_state - should not be passed to function
invocation_state = {"agent": "test_agent", "custom_key": "test_value"}

tool_use = {"toolUseId": "test-id", "input": {"param": "test_param"}}
stream = no_kwargs_tool.stream(tool_use, invocation_state)

result = (await alist(stream))[-1]
assert result["status"] == "success"
assert result["content"][0]["text"] == "param: test_param"


@pytest.mark.asyncio
async def test_tool_kwargs_with_empty_invocation_state(alist):
"""Test that tools with **kwargs work with empty invocation_state."""

@strands.tool
def kwargs_tool_empty(param: str, **kwargs) -> str:
"""Tool with **kwargs tested with empty invocation_state.

Args:
param: Regular parameter
"""
kwargs_count = len(kwargs)
return f"param: {param}, kwargs_count: {kwargs_count}"

# Test with empty invocation_state
tool_use = {"toolUseId": "test-id", "input": {"param": "test_param"}}
stream = kwargs_tool_empty.stream(tool_use, {})

result = (await alist(stream))[-1]
assert result["status"] == "success"
assert "param: test_param, kwargs_count: 0" in result["content"][0]["text"]


@pytest.mark.asyncio
async def test_tool_kwargs_schema_generation(alist):
"""Test that **kwargs parameters are excluded from schema generation."""

@strands.tool
def kwargs_schema_tool(required_param: str, optional_param: str = "default", **kwargs) -> str:
"""Tool with **kwargs to test schema generation.

Args:
required_param: Required parameter
optional_param: Optional parameter with default
"""
return f"required: {required_param}, optional: {optional_param}"

# Check that schema doesn't include **kwargs
spec = kwargs_schema_tool.tool_spec
schema = spec["inputSchema"]["json"]

# Should have required_param and optional_param, but not kwargs
assert "required_param" in schema["properties"]
assert "optional_param" in schema["properties"]
assert len(schema["properties"]) == 2 # Only these two parameters

# Only required_param should be required
assert schema["required"] == ["required_param"]


@pytest.mark.asyncio
async def test_tool_kwargs_with_mixed_parameters(alist):
"""Test tool with **kwargs alongside regular and optional parameters."""

@strands.tool
def mixed_params_tool(required: str, optional: int = 42, **kwargs) -> str:
"""Tool with mixed parameter types.

Args:
required: Required parameter
optional: Optional parameter with default
"""
agent = kwargs.get("agent", "no_agent")
extra_data = kwargs.get("extra_data", "no_extra")
return f"required: {required}, optional: {optional}, agent: {agent}, extra: {extra_data}"

# Test with all types of parameters
invocation_state = {"agent": "test_agent", "extra_data": "extra_value", "unused_key": "unused_value"}

tool_use = {"toolUseId": "test-id", "input": {"required": "test_req", "optional": 100}}
stream = mixed_params_tool.stream(tool_use, invocation_state)

result = (await alist(stream))[-1]
assert result["status"] == "success"
content = result["content"][0]["text"]
assert "required: test_req" in content
assert "optional: 100" in content
assert "agent: test_agent" in content
assert "extra: extra_value" in content


@pytest.mark.asyncio
async def test_tool_kwargs_async_function(alist):
"""Test that **kwargs work with async tool functions."""

@strands.tool
async def async_kwargs_tool(param: str, **kwargs) -> str:
"""Async tool with **kwargs.

Args:
param: Regular parameter
"""
agent = kwargs.get("agent", "no_agent")
return f"async param: {param}, agent: {agent}"

# Test async function with invocation_state
invocation_state = {"agent": "async_agent"}

tool_use = {"toolUseId": "test-id", "input": {"param": "async_test"}}
stream = async_kwargs_tool.stream(tool_use, invocation_state)

result = (await alist(stream))[-1]
assert result["status"] == "success"
assert "async param: async_test" in result["content"][0]["text"]
assert "agent: async_agent" in result["content"][0]["text"]


@pytest.mark.asyncio
async def test_tool_kwargs_class_method(alist):
"""Test that **kwargs work with class methods."""

class TestClass:
def __init__(self, prefix):
self.prefix = prefix

@strands.tool
def kwargs_method(self, param: str, **kwargs) -> str:
"""Class method with **kwargs.

Args:
param: Test parameter
"""
agent = kwargs.get("agent", "no_agent")
return f"{self.prefix}: param: {param}, agent: {agent}"

# Test class method with invocation_state
instance = TestClass("TestPrefix")
invocation_state = {"agent": "class_agent"}

tool_use = {"toolUseId": "test-id", "input": {"param": "class_test"}}
stream = instance.kwargs_method.stream(tool_use, invocation_state)

result = (await alist(stream))[-1]
assert result["status"] == "success"
content = result["content"][0]["text"]
assert "TestPrefix: param: class_test" in content
assert "agent: class_agent" in content
Loading