Skip to content

Commit 83247f6

Browse files
committed
test(models): add integration tests for FallbackModel
1 parent 8062108 commit 83247f6

File tree

1 file changed

+268
-0
lines changed

1 file changed

+268
-0
lines changed
Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
"""Integration tests for FallbackModel with real model providers."""
2+
3+
import os
4+
5+
import pytest
6+
7+
from strands import Agent
8+
from strands.models import BedrockModel
9+
from strands.models.anthropic import AnthropicModel
10+
from strands.models.fallback import FallbackModel
11+
from strands.models.openai import OpenAIModel
12+
from tests_integ.models import providers
13+
14+
15+
class TestFallbackModelIntegration:
16+
"""Integration tests for FallbackModel with real model instances."""
17+
18+
@providers.bedrock.mark
19+
@pytest.mark.asyncio
20+
async def test_same_provider_fallback_bedrock(self):
21+
"""Test FallbackModel with two BedrockModel instances."""
22+
# Use different model IDs - opus as primary, haiku as fallback
23+
primary = BedrockModel(model_id="us.anthropic.claude-3-5-sonnet-20241022-v2:0", region_name="us-west-2")
24+
fallback = BedrockModel(model_id="us.anthropic.claude-3-haiku-20240307-v1:0", region_name="us-west-2")
25+
26+
fallback_model = FallbackModel(
27+
primary=primary,
28+
fallback=fallback,
29+
circuit_failure_threshold=1, # Open circuit quickly for testing
30+
circuit_time_window=60.0,
31+
circuit_cooldown_seconds=5,
32+
)
33+
34+
# Test successful primary model usage
35+
messages = [{"role": "user", "content": [{"text": "Say 'Hello from primary model'"}]}]
36+
37+
events = []
38+
async for event in fallback_model.stream(messages=messages):
39+
events.append(event)
40+
41+
# Should have received events
42+
assert len(events) > 0
43+
44+
# Check that primary was used (fallback_count should be 0)
45+
stats = fallback_model.get_stats()
46+
assert stats["fallback_count"] == 0
47+
assert not stats["using_fallback"]
48+
49+
@pytest.mark.skipif(
50+
"OPENAI_API_KEY" not in os.environ or "AWS_ACCESS_KEY_ID" not in os.environ,
51+
reason="Both OPENAI_API_KEY and AWS credentials required for cross-provider test",
52+
)
53+
@pytest.mark.asyncio
54+
async def test_cross_provider_fallback_openai_bedrock(self):
55+
"""Test FallbackModel with OpenAI primary and Bedrock fallback."""
56+
primary = OpenAIModel(
57+
model_id="gpt-4o",
58+
client_args={
59+
"api_key": os.getenv("OPENAI_API_KEY"),
60+
},
61+
)
62+
fallback = BedrockModel(model_id="us.anthropic.claude-3-haiku-20240307-v1:0", region_name="us-west-2")
63+
64+
fallback_model = FallbackModel(
65+
primary=primary,
66+
fallback=fallback,
67+
circuit_failure_threshold=2,
68+
circuit_time_window=60.0,
69+
circuit_cooldown_seconds=10,
70+
)
71+
72+
# Test successful cross-provider usage
73+
messages = [{"role": "user", "content": [{"text": "Respond with exactly: 'Cross-provider test successful'"}]}]
74+
75+
events = []
76+
async for event in fallback_model.stream(messages=messages):
77+
events.append(event)
78+
79+
# Should have received events
80+
assert len(events) > 0
81+
82+
# Verify we can get configuration from both models
83+
config = fallback_model.get_config()
84+
assert "primary_config" in config
85+
assert "fallback_model_config" in config
86+
assert "fallback_config" in config
87+
assert "stats" in config
88+
89+
@providers.bedrock.mark
90+
@pytest.mark.asyncio
91+
async def test_agent_integration_with_fallback(self):
92+
"""Test FallbackModel used with Agent class."""
93+
# Create fallback model with two Bedrock instances
94+
primary = BedrockModel(model_id="us.anthropic.claude-3-5-sonnet-20241022-v2:0", region_name="us-west-2")
95+
fallback = BedrockModel(model_id="us.anthropic.claude-3-haiku-20240307-v1:0", region_name="us-west-2")
96+
97+
fallback_model = FallbackModel(primary=primary, fallback=fallback, track_stats=True)
98+
99+
# Create agent with fallback model
100+
agent = Agent(model=fallback_model, system_prompt="You are a helpful assistant. Keep responses brief.")
101+
102+
# Send test message
103+
response = await agent.invoke_async("What is 2 + 2?")
104+
105+
# Assert response received
106+
assert response is not None
107+
assert response.message is not None
108+
assert len(response.message["content"]) > 0
109+
assert response.message["content"][0]["text"] is not None
110+
111+
# Check that the fallback model was used successfully
112+
stats = fallback_model.get_stats()
113+
assert isinstance(stats, dict)
114+
assert "fallback_count" in stats
115+
assert "primary_failures" in stats
116+
117+
@pytest.mark.skipif(
118+
"ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY required for Anthropic provider test"
119+
)
120+
@pytest.mark.asyncio
121+
async def test_cross_provider_anthropic_bedrock(self):
122+
"""Test FallbackModel with Anthropic primary and Bedrock fallback."""
123+
primary = AnthropicModel(
124+
client_args={
125+
"api_key": os.getenv("ANTHROPIC_API_KEY"),
126+
},
127+
model_id="claude-3-7-sonnet-20250219",
128+
max_tokens=512,
129+
)
130+
fallback = BedrockModel(model_id="us.anthropic.claude-3-haiku-20240307-v1:0", region_name="us-west-2")
131+
132+
fallback_model = FallbackModel(primary=primary, fallback=fallback)
133+
134+
# Test structured output
135+
from pydantic import BaseModel
136+
137+
class TestResponse(BaseModel):
138+
message: str
139+
number: int
140+
141+
messages = [{"role": "user", "content": [{"text": "Return a message 'test' and number 42"}]}]
142+
143+
events = []
144+
async for event in fallback_model.structured_output(output_model=TestResponse, prompt=messages):
145+
events.append(event)
146+
147+
# Should have received events
148+
assert len(events) > 0
149+
150+
# Check final event has the structured output
151+
final_event = events[-1]
152+
if "output" in final_event:
153+
output = final_event["output"]
154+
assert hasattr(output, "message")
155+
assert hasattr(output, "number")
156+
157+
@providers.bedrock.mark
158+
@pytest.mark.asyncio
159+
async def test_fallback_statistics_tracking(self):
160+
"""Test that statistics are properly tracked during integration tests."""
161+
primary = BedrockModel(model_id="us.anthropic.claude-3-5-sonnet-20241022-v2:0", region_name="us-west-2")
162+
fallback = BedrockModel(model_id="us.anthropic.claude-3-haiku-20240307-v1:0", region_name="us-west-2")
163+
164+
fallback_model = FallbackModel(primary=primary, fallback=fallback, track_stats=True)
165+
166+
# Make a successful request
167+
messages = [{"role": "user", "content": [{"text": "Say hello"}]}]
168+
169+
events = []
170+
async for event in fallback_model.stream(messages=messages):
171+
events.append(event)
172+
173+
# Check statistics
174+
stats = fallback_model.get_stats()
175+
assert stats["fallback_count"] == 0 # No fallback should have occurred
176+
assert stats["primary_failures"] == 0 # No failures
177+
assert not stats["using_fallback"] # Not using fallback
178+
assert not stats["circuit_open"] # Circuit should be closed
179+
180+
# Test configuration retrieval
181+
config = fallback_model.get_config()
182+
assert config["stats"] is not None
183+
assert config["fallback_config"]["track_stats"] is True
184+
185+
# Test stats reset
186+
fallback_model.reset_stats()
187+
reset_stats = fallback_model.get_stats()
188+
assert reset_stats["fallback_count"] == 0
189+
assert reset_stats["primary_failures"] == 0
190+
191+
@providers.bedrock.mark
192+
@pytest.mark.asyncio
193+
async def test_tool_calling_with_fallback_model(self):
194+
"""Test that tool_specs and tool_choice parameters work with FallbackModel."""
195+
# Create fallback model with two Bedrock instances
196+
primary = BedrockModel(model_id="us.anthropic.claude-3-5-sonnet-20241022-v2:0", region_name="us-west-2")
197+
fallback = BedrockModel(model_id="us.anthropic.claude-3-haiku-20240307-v1:0", region_name="us-west-2")
198+
199+
fallback_model = FallbackModel(primary=primary, fallback=fallback, track_stats=True)
200+
201+
# Define a simple tool spec
202+
tool_specs = [
203+
{
204+
"name": "get_weather",
205+
"description": "Get weather information for a location",
206+
"inputSchema": {
207+
"json": {
208+
"type": "object",
209+
"properties": {
210+
"location": {"type": "string", "description": "The location to get weather for"}
211+
},
212+
"required": ["location"],
213+
}
214+
},
215+
}
216+
]
217+
218+
tool_choice = {"auto": {}}
219+
220+
# Test message that might trigger tool use
221+
messages = [{"role": "user", "content": [{"text": "What's the weather in Seattle?"}]}]
222+
223+
# Stream with tool parameters
224+
events = []
225+
async for event in fallback_model.stream(messages=messages, tool_specs=tool_specs, tool_choice=tool_choice):
226+
events.append(event)
227+
228+
# Should have received events
229+
assert len(events) > 0
230+
231+
# Verify primary was used (no fallback)
232+
stats = fallback_model.get_stats()
233+
assert stats["fallback_count"] == 0
234+
assert not stats["using_fallback"]
235+
236+
@providers.bedrock.mark
237+
@pytest.mark.asyncio
238+
async def test_tool_calling_with_agent_and_fallback_model(self):
239+
"""Test that FallbackModel works with Agent class when tools are provided."""
240+
# Create fallback model with two Bedrock instances
241+
primary = BedrockModel(model_id="us.anthropic.claude-3-5-sonnet-20241022-v2:0", region_name="us-west-2")
242+
fallback = BedrockModel(model_id="us.anthropic.claude-3-haiku-20240307-v1:0", region_name="us-west-2")
243+
244+
fallback_model = FallbackModel(primary=primary, fallback=fallback, track_stats=True)
245+
246+
# Create a simple tool using the strands tool decorator
247+
from strands import tool
248+
249+
@tool
250+
def get_current_time(timezone: str = "UTC") -> dict:
251+
"""Get the current time in a specific timezone."""
252+
return {"status": "success", "content": [{"text": f"Current time in {timezone}: 12:00 PM"}]}
253+
254+
# Create agent with fallback model and tool
255+
agent = Agent(model=fallback_model, tools=[get_current_time], system_prompt="You are a helpful assistant.")
256+
257+
# Send test message
258+
response = await agent.invoke_async("What time is it?")
259+
260+
# Assert response received
261+
assert response is not None
262+
assert response.message is not None
263+
assert len(response.message["content"]) > 0
264+
265+
# Verify primary was used (no fallback)
266+
stats = fallback_model.get_stats()
267+
assert stats["fallback_count"] == 0
268+
assert not stats["using_fallback"]

0 commit comments

Comments
 (0)