@@ -1092,16 +1092,9 @@ async def test_state_reset_only_with_cycles_enabled():
10921092@pytest .mark .asyncio
10931093async def test_graph_kwargs_passing_agent (mock_strands_tracer , mock_use_span ):
10941094 """Test that kwargs are passed through to underlying Agent nodes."""
1095- # Create a mock agent that captures kwargs
1095+ # Create a mock agent
10961096 kwargs_agent = create_mock_agent ("kwargs_agent" , "Response with kwargs" )
10971097
1098- async def capture_kwargs (* args , ** kwargs ):
1099- # Store kwargs for verification
1100- capture_kwargs .captured_kwargs = kwargs
1101- return kwargs_agent .return_value
1102-
1103- kwargs_agent .invoke_async = MagicMock (side_effect = capture_kwargs )
1104-
11051098 # Create graph
11061099 builder = GraphBuilder ()
11071100 builder .add_node (kwargs_agent , "kwargs_node" )
@@ -1111,28 +1104,19 @@ async def capture_kwargs(*args, **kwargs):
11111104 test_kwargs = {"custom_param" : "test_value" , "another_param" : 42 }
11121105 result = await graph .invoke_async ("Test kwargs passing" , ** test_kwargs )
11131106
1114- # Verify kwargs were passed to agent
1115- assert hasattr (capture_kwargs , "captured_kwargs" )
1116- assert capture_kwargs .captured_kwargs == test_kwargs
1107+ # Verify kwargs were passed to agent using call_args
1108+ kwargs_agent .invoke_async .assert_called_once ()
1109+ call_args , call_kwargs = kwargs_agent .invoke_async .call_args
1110+ assert call_kwargs == test_kwargs
11171111 assert result .status == Status .COMPLETED
11181112
11191113
11201114@pytest .mark .asyncio
11211115async def test_graph_kwargs_passing_multiagent (mock_strands_tracer , mock_use_span ):
11221116 """Test that kwargs are passed through to underlying MultiAgentBase nodes."""
1123- # Create a mock MultiAgentBase that captures kwargs
1117+ # Create a mock MultiAgentBase
11241118 kwargs_multiagent = create_mock_multi_agent ("kwargs_multiagent" , "MultiAgent response with kwargs" )
11251119
1126- # Store the original return value
1127- original_result = kwargs_multiagent .invoke_async .return_value
1128-
1129- async def capture_kwargs (* args , ** kwargs ):
1130- # Store kwargs for verification
1131- capture_kwargs .captured_kwargs = kwargs
1132- return original_result
1133-
1134- kwargs_multiagent .invoke_async = AsyncMock (side_effect = capture_kwargs )
1135-
11361120 # Create graph
11371121 builder = GraphBuilder ()
11381122 builder .add_node (kwargs_multiagent , "multiagent_node" )
@@ -1142,24 +1126,18 @@ async def capture_kwargs(*args, **kwargs):
11421126 test_kwargs = {"custom_param" : "test_value" , "another_param" : 42 }
11431127 result = await graph .invoke_async ("Test kwargs passing to multiagent" , ** test_kwargs )
11441128
1145- # Verify kwargs were passed to multiagent
1146- assert hasattr (capture_kwargs , "captured_kwargs" )
1147- assert capture_kwargs .captured_kwargs == test_kwargs
1129+ # Verify kwargs were passed to multiagent using call_args
1130+ kwargs_multiagent .invoke_async .assert_called_once ()
1131+ call_args , call_kwargs = kwargs_multiagent .invoke_async .call_args
1132+ assert call_kwargs == test_kwargs
11481133 assert result .status == Status .COMPLETED
11491134
11501135
11511136def test_graph_kwargs_passing_sync (mock_strands_tracer , mock_use_span ):
11521137 """Test that kwargs are passed through to underlying nodes in sync execution."""
1153- # Create a mock agent that captures kwargs
1138+ # Create a mock agent
11541139 kwargs_agent = create_mock_agent ("kwargs_agent" , "Response with kwargs" )
11551140
1156- async def capture_kwargs (* args , ** kwargs ):
1157- # Store kwargs for verification
1158- capture_kwargs .captured_kwargs = kwargs
1159- return kwargs_agent .return_value
1160-
1161- kwargs_agent .invoke_async = MagicMock (side_effect = capture_kwargs )
1162-
11631141 # Create graph
11641142 builder = GraphBuilder ()
11651143 builder .add_node (kwargs_agent , "kwargs_node" )
@@ -1169,7 +1147,8 @@ async def capture_kwargs(*args, **kwargs):
11691147 test_kwargs = {"custom_param" : "test_value" , "another_param" : 42 }
11701148 result = graph ("Test kwargs passing sync" , ** test_kwargs )
11711149
1172- # Verify kwargs were passed to agent
1173- assert hasattr (capture_kwargs , "captured_kwargs" )
1174- assert capture_kwargs .captured_kwargs == test_kwargs
1150+ # Verify kwargs were passed to agent using call_args
1151+ kwargs_agent .invoke_async .assert_called_once ()
1152+ call_args , call_kwargs = kwargs_agent .invoke_async .call_args
1153+ assert call_kwargs == test_kwargs
11751154 assert result .status == Status .COMPLETED
0 commit comments