Skip to content

Commit 8caa9cb

Browse files
committed
fix(tests): use call_args instead of capturing kwargs in multiagent tests
- Replace custom capture_kwargs functions with direct mock verification using call_args - Use existing mock setup from create_mock_agent/create_mock_multi_agent instead of overriding with AsyncMock - Apply consistent pattern across all three kwargs passing tests - Addresses reviewer feedback for cleaner test implementation Fixes strands-agents#816
1 parent 9a895f7 commit 8caa9cb

File tree

1 file changed

+15
-36
lines changed

1 file changed

+15
-36
lines changed

tests/strands/multiagent/test_graph.py

Lines changed: 15 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,16 +1092,9 @@ async def test_state_reset_only_with_cycles_enabled():
10921092
@pytest.mark.asyncio
10931093
async def test_graph_kwargs_passing_agent(mock_strands_tracer, mock_use_span):
10941094
"""Test that kwargs are passed through to underlying Agent nodes."""
1095-
# Create a mock agent that captures kwargs
1095+
# Create a mock agent
10961096
kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs")
10971097

1098-
async def capture_kwargs(*args, **kwargs):
1099-
# Store kwargs for verification
1100-
capture_kwargs.captured_kwargs = kwargs
1101-
return kwargs_agent.return_value
1102-
1103-
kwargs_agent.invoke_async = MagicMock(side_effect=capture_kwargs)
1104-
11051098
# Create graph
11061099
builder = GraphBuilder()
11071100
builder.add_node(kwargs_agent, "kwargs_node")
@@ -1111,28 +1104,19 @@ async def capture_kwargs(*args, **kwargs):
11111104
test_kwargs = {"custom_param": "test_value", "another_param": 42}
11121105
result = await graph.invoke_async("Test kwargs passing", **test_kwargs)
11131106

1114-
# Verify kwargs were passed to agent
1115-
assert hasattr(capture_kwargs, "captured_kwargs")
1116-
assert capture_kwargs.captured_kwargs == test_kwargs
1107+
# Verify kwargs were passed to agent using call_args
1108+
kwargs_agent.invoke_async.assert_called_once()
1109+
call_args, call_kwargs = kwargs_agent.invoke_async.call_args
1110+
assert call_kwargs == test_kwargs
11171111
assert result.status == Status.COMPLETED
11181112

11191113

11201114
@pytest.mark.asyncio
11211115
async def test_graph_kwargs_passing_multiagent(mock_strands_tracer, mock_use_span):
11221116
"""Test that kwargs are passed through to underlying MultiAgentBase nodes."""
1123-
# Create a mock MultiAgentBase that captures kwargs
1117+
# Create a mock MultiAgentBase
11241118
kwargs_multiagent = create_mock_multi_agent("kwargs_multiagent", "MultiAgent response with kwargs")
11251119

1126-
# Store the original return value
1127-
original_result = kwargs_multiagent.invoke_async.return_value
1128-
1129-
async def capture_kwargs(*args, **kwargs):
1130-
# Store kwargs for verification
1131-
capture_kwargs.captured_kwargs = kwargs
1132-
return original_result
1133-
1134-
kwargs_multiagent.invoke_async = AsyncMock(side_effect=capture_kwargs)
1135-
11361120
# Create graph
11371121
builder = GraphBuilder()
11381122
builder.add_node(kwargs_multiagent, "multiagent_node")
@@ -1142,24 +1126,18 @@ async def capture_kwargs(*args, **kwargs):
11421126
test_kwargs = {"custom_param": "test_value", "another_param": 42}
11431127
result = await graph.invoke_async("Test kwargs passing to multiagent", **test_kwargs)
11441128

1145-
# Verify kwargs were passed to multiagent
1146-
assert hasattr(capture_kwargs, "captured_kwargs")
1147-
assert capture_kwargs.captured_kwargs == test_kwargs
1129+
# Verify kwargs were passed to multiagent using call_args
1130+
kwargs_multiagent.invoke_async.assert_called_once()
1131+
call_args, call_kwargs = kwargs_multiagent.invoke_async.call_args
1132+
assert call_kwargs == test_kwargs
11481133
assert result.status == Status.COMPLETED
11491134

11501135

11511136
def test_graph_kwargs_passing_sync(mock_strands_tracer, mock_use_span):
11521137
"""Test that kwargs are passed through to underlying nodes in sync execution."""
1153-
# Create a mock agent that captures kwargs
1138+
# Create a mock agent
11541139
kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs")
11551140

1156-
async def capture_kwargs(*args, **kwargs):
1157-
# Store kwargs for verification
1158-
capture_kwargs.captured_kwargs = kwargs
1159-
return kwargs_agent.return_value
1160-
1161-
kwargs_agent.invoke_async = MagicMock(side_effect=capture_kwargs)
1162-
11631141
# Create graph
11641142
builder = GraphBuilder()
11651143
builder.add_node(kwargs_agent, "kwargs_node")
@@ -1169,7 +1147,8 @@ async def capture_kwargs(*args, **kwargs):
11691147
test_kwargs = {"custom_param": "test_value", "another_param": 42}
11701148
result = graph("Test kwargs passing sync", **test_kwargs)
11711149

1172-
# Verify kwargs were passed to agent
1173-
assert hasattr(capture_kwargs, "captured_kwargs")
1174-
assert capture_kwargs.captured_kwargs == test_kwargs
1150+
# Verify kwargs were passed to agent using call_args
1151+
kwargs_agent.invoke_async.assert_called_once()
1152+
call_args, call_kwargs = kwargs_agent.invoke_async.call_args
1153+
assert call_kwargs == test_kwargs
11751154
assert result.status == Status.COMPLETED

0 commit comments

Comments
 (0)