|
| 1 | +"""Integration tests for FallbackModel with real model providers.""" |
| 2 | + |
| 3 | +import os |
| 4 | + |
| 5 | +import pytest |
| 6 | + |
| 7 | +from strands import Agent |
| 8 | +from strands.models import BedrockModel |
| 9 | +from strands.models.anthropic import AnthropicModel |
| 10 | +from strands.models.fallback import FallbackModel |
| 11 | +from strands.models.openai import OpenAIModel |
| 12 | +from tests_integ.models import providers |
| 13 | + |
| 14 | + |
class TestFallbackModelIntegration:
    """Integration tests for FallbackModel with real model instances.

    Every test talks to a live provider, so each one is guarded by a provider
    mark or an environment-variable ``skipif``. The happy-path tests expect
    the primary model to answer and therefore assert that no fallback was
    recorded in the statistics.
    """

    # Bedrock model IDs and region shared by the tests below.
    _PRIMARY_BEDROCK_ID = "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
    _FALLBACK_BEDROCK_ID = "us.anthropic.claude-3-haiku-20240307-v1:0"
    _BEDROCK_REGION = "us-west-2"

    @classmethod
    def _bedrock(cls, model_id):
        """Build a BedrockModel for ``model_id`` in the shared test region."""
        return BedrockModel(model_id=model_id, region_name=cls._BEDROCK_REGION)

    @classmethod
    def _bedrock_pair_model(cls, **fallback_kwargs):
        """Build a FallbackModel with Sonnet as primary and Haiku as fallback.

        Args:
            **fallback_kwargs: Extra keyword arguments forwarded unchanged to
                the FallbackModel constructor.

        Returns:
            The configured FallbackModel.
        """
        return FallbackModel(
            primary=cls._bedrock(cls._PRIMARY_BEDROCK_ID),
            fallback=cls._bedrock(cls._FALLBACK_BEDROCK_ID),
            **fallback_kwargs,
        )

    @staticmethod
    async def _collect(stream):
        """Drain an async event stream and return its events as a list."""
        return [event async for event in stream]

    @providers.bedrock.mark
    @pytest.mark.asyncio
    async def test_same_provider_fallback_bedrock(self):
        """Test FallbackModel with two BedrockModel instances."""
        # Use different model IDs - sonnet as primary, haiku as fallback
        fallback_model = self._bedrock_pair_model(
            circuit_failure_threshold=1,  # Open circuit quickly for testing
            circuit_time_window=60.0,
            circuit_cooldown_seconds=5,
        )

        # Test successful primary model usage
        messages = [{"role": "user", "content": [{"text": "Say 'Hello from primary model'"}]}]
        events = await self._collect(fallback_model.stream(messages=messages))

        # Should have received events
        assert len(events) > 0

        # Check that primary was used (fallback_count should be 0)
        stats = fallback_model.get_stats()
        assert stats["fallback_count"] == 0
        assert not stats["using_fallback"]

    @pytest.mark.skipif(
        "OPENAI_API_KEY" not in os.environ or "AWS_ACCESS_KEY_ID" not in os.environ,
        reason="Both OPENAI_API_KEY and AWS credentials required for cross-provider test",
    )
    @pytest.mark.asyncio
    async def test_cross_provider_fallback_openai_bedrock(self):
        """Test FallbackModel with OpenAI primary and Bedrock fallback."""
        primary = OpenAIModel(
            model_id="gpt-4o",
            client_args={
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        )
        fallback = self._bedrock(self._FALLBACK_BEDROCK_ID)

        fallback_model = FallbackModel(
            primary=primary,
            fallback=fallback,
            circuit_failure_threshold=2,
            circuit_time_window=60.0,
            circuit_cooldown_seconds=10,
        )

        # Test successful cross-provider usage
        messages = [{"role": "user", "content": [{"text": "Respond with exactly: 'Cross-provider test successful'"}]}]
        events = await self._collect(fallback_model.stream(messages=messages))

        # Should have received events
        assert len(events) > 0

        # Verify the combined configuration exposes every expected section
        config = fallback_model.get_config()
        for section in ("primary_config", "fallback_model_config", "fallback_config", "stats"):
            assert section in config

    @providers.bedrock.mark
    @pytest.mark.asyncio
    async def test_agent_integration_with_fallback(self):
        """Test FallbackModel used with Agent class."""
        # Create fallback model with two Bedrock instances
        fallback_model = self._bedrock_pair_model(track_stats=True)

        # Create agent with fallback model
        agent = Agent(model=fallback_model, system_prompt="You are a helpful assistant. Keep responses brief.")

        # Send test message
        response = await agent.invoke_async("What is 2 + 2?")

        # Assert response received
        assert response is not None
        assert response.message is not None
        assert len(response.message["content"]) > 0
        assert response.message["content"][0]["text"] is not None

        # Check that statistics are exposed with the expected keys
        stats = fallback_model.get_stats()
        assert isinstance(stats, dict)
        assert "fallback_count" in stats
        assert "primary_failures" in stats

    @pytest.mark.skipif(
        "ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY required for Anthropic provider test"
    )
    @pytest.mark.asyncio
    async def test_cross_provider_anthropic_bedrock(self):
        """Test structured output through FallbackModel with Anthropic primary and Bedrock fallback."""
        primary = AnthropicModel(
            client_args={
                "api_key": os.getenv("ANTHROPIC_API_KEY"),
            },
            model_id="claude-3-7-sonnet-20250219",
            max_tokens=512,
        )
        fallback = self._bedrock(self._FALLBACK_BEDROCK_ID)

        fallback_model = FallbackModel(primary=primary, fallback=fallback)

        # Test structured output
        from pydantic import BaseModel

        class TestResponse(BaseModel):
            message: str
            number: int

        messages = [{"role": "user", "content": [{"text": "Return a message 'test' and number 42"}]}]
        events = await self._collect(fallback_model.structured_output(output_model=TestResponse, prompt=messages))

        # Should have received events
        assert len(events) > 0

        # The final event must carry the structured output. Asserting this
        # unconditionally keeps the test from passing vacuously when no
        # output is produced.
        final_event = events[-1]
        assert "output" in final_event, "final structured_output event has no 'output' key"
        output = final_event["output"]
        assert isinstance(output, TestResponse)
        assert hasattr(output, "message")
        assert hasattr(output, "number")

    @providers.bedrock.mark
    @pytest.mark.asyncio
    async def test_fallback_statistics_tracking(self):
        """Test that statistics are properly tracked during integration tests."""
        fallback_model = self._bedrock_pair_model(track_stats=True)

        # Make a successful request
        messages = [{"role": "user", "content": [{"text": "Say hello"}]}]
        events = await self._collect(fallback_model.stream(messages=messages))

        # Should have received events
        assert len(events) > 0

        # Check statistics
        stats = fallback_model.get_stats()
        assert stats["fallback_count"] == 0  # No fallback should have occurred
        assert stats["primary_failures"] == 0  # No failures
        assert not stats["using_fallback"]  # Not using fallback
        assert not stats["circuit_open"]  # Circuit should be closed

        # Test configuration retrieval
        config = fallback_model.get_config()
        assert config["stats"] is not None
        assert config["fallback_config"]["track_stats"] is True

        # Test stats reset
        fallback_model.reset_stats()
        reset_stats = fallback_model.get_stats()
        assert reset_stats["fallback_count"] == 0
        assert reset_stats["primary_failures"] == 0

    @providers.bedrock.mark
    @pytest.mark.asyncio
    async def test_tool_calling_with_fallback_model(self):
        """Test that tool_specs and tool_choice parameters work with FallbackModel."""
        # Create fallback model with two Bedrock instances
        fallback_model = self._bedrock_pair_model(track_stats=True)

        # Define a simple tool spec
        tool_specs = [
            {
                "name": "get_weather",
                "description": "Get weather information for a location",
                "inputSchema": {
                    "json": {
                        "type": "object",
                        "properties": {
                            "location": {"type": "string", "description": "The location to get weather for"}
                        },
                        "required": ["location"],
                    }
                },
            }
        ]

        tool_choice = {"auto": {}}

        # Test message that might trigger tool use
        messages = [{"role": "user", "content": [{"text": "What's the weather in Seattle?"}]}]

        # Stream with tool parameters
        events = await self._collect(
            fallback_model.stream(messages=messages, tool_specs=tool_specs, tool_choice=tool_choice)
        )

        # Should have received events
        assert len(events) > 0

        # Verify primary was used (no fallback)
        stats = fallback_model.get_stats()
        assert stats["fallback_count"] == 0
        assert not stats["using_fallback"]

    @providers.bedrock.mark
    @pytest.mark.asyncio
    async def test_tool_calling_with_agent_and_fallback_model(self):
        """Test that FallbackModel works with Agent class when tools are provided."""
        # Create fallback model with two Bedrock instances
        fallback_model = self._bedrock_pair_model(track_stats=True)

        # Create a simple tool using the strands tool decorator
        from strands import tool

        @tool
        def get_current_time(timezone: str = "UTC") -> dict:
            """Get the current time in a specific timezone."""
            return {"status": "success", "content": [{"text": f"Current time in {timezone}: 12:00 PM"}]}

        # Create agent with fallback model and tool
        agent = Agent(model=fallback_model, tools=[get_current_time], system_prompt="You are a helpful assistant.")

        # Send test message
        response = await agent.invoke_async("What time is it?")

        # Assert response received
        assert response is not None
        assert response.message is not None
        assert len(response.message["content"]) > 0

        # Verify primary was used (no fallback)
        stats = fallback_model.get_stats()
        assert stats["fallback_count"] == 0
        assert not stats["using_fallback"]
0 commit comments