diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index c1e99f1a2..1e4bf42a6 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -50,6 +50,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
         """
         self.client_args = client_args or {}
         self.config = dict(model_config)
+        self._apply_proxy_prefix()
 
         logger.debug("config=<%s> | initializing", self.config)
 
@@ -61,6 +62,7 @@ def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None:  # type:
             **model_config: Configuration overrides.
         """
         self.config.update(model_config)
+        self._apply_proxy_prefix()
 
     @override
     def get_config(self) -> LiteLLMConfig:
@@ -223,3 +225,14 @@ async def structured_output(
 
         # If no tool_calls found, raise an error
         raise ValueError("No tool_calls found in response")
+
+    def _apply_proxy_prefix(self) -> None:
+        """Apply the litellm_proxy/ prefix to model_id when use_litellm_proxy is True.
+
+        This is a workaround for https://github.com/BerriAI/litellm/issues/13454,
+        where the use_litellm_proxy parameter is not honored.
+        """
+        if self.client_args.get("use_litellm_proxy") and "model_id" in self.config:
+            model_id = self.get_config()["model_id"]
+            if not model_id.startswith("litellm_proxy/"):
+                self.config["model_id"] = f"litellm_proxy/{model_id}"
diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py
index 44b6df63b..443bd3e8f 100644
--- a/tests/strands/models/test_litellm.py
+++ b/tests/strands/models/test_litellm.py
@@ -58,6 +58,39 @@ def test_update_config(model, model_id):
     assert tru_model_id == exp_model_id
 
 
+@pytest.mark.parametrize(
+    "client_args, model_id, expected_model_id",
+    [
+        ({"use_litellm_proxy": True}, "openai/gpt-4", "litellm_proxy/openai/gpt-4"),
+        ({"use_litellm_proxy": False}, "openai/gpt-4", "openai/gpt-4"),
+        ({"use_litellm_proxy": None}, "openai/gpt-4", "openai/gpt-4"),
+        ({}, "openai/gpt-4", "openai/gpt-4"),
+        (None, "openai/gpt-4", "openai/gpt-4"),
+        ({"use_litellm_proxy": True}, "litellm_proxy/openai/gpt-4", "litellm_proxy/openai/gpt-4"),
+        ({"use_litellm_proxy": False}, "litellm_proxy/openai/gpt-4", "litellm_proxy/openai/gpt-4"),
+    ],
+)
+def test__init__use_litellm_proxy_prefix(client_args, model_id, expected_model_id):
+    """Test litellm_proxy prefix behavior for various configurations."""
+    model = LiteLLMModel(client_args=client_args, model_id=model_id)
+    assert model.get_config()["model_id"] == expected_model_id
+
+
+@pytest.mark.parametrize(
+    "client_args, initial_model_id, new_model_id, expected_model_id",
+    [
+        ({"use_litellm_proxy": True}, "openai/gpt-4", "anthropic/claude-3", "litellm_proxy/anthropic/claude-3"),
+        ({"use_litellm_proxy": False}, "openai/gpt-4", "anthropic/claude-3", "anthropic/claude-3"),
+        (None, "openai/gpt-4", "anthropic/claude-3", "anthropic/claude-3"),
+    ],
+)
+def test_update_config_proxy_prefix(client_args, initial_model_id, new_model_id, expected_model_id):
+    """Test that update_config applies the proxy prefix correctly."""
+    model = LiteLLMModel(client_args=client_args, model_id=initial_model_id)
+    model.update_config(model_id=new_model_id)
+    assert model.get_config()["model_id"] == expected_model_id
+
+
 @pytest.mark.parametrize(
     "content, exp_result",
     [
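
For reviewers, a minimal sketch of the behavior this change introduces, derived directly from the parametrized test cases above. It assumes the `strands` package layout shown in the diff and exercises only configuration handling, so no LiteLLM proxy or provider calls are made:

```python
# Sketch based on the tests above; assumes the strands package from this diff
# is importable. Only config handling runs, so no network access is needed.
from strands.models.litellm import LiteLLMModel

# With use_litellm_proxy=True, the prefix is applied at construction time.
model = LiteLLMModel(client_args={"use_litellm_proxy": True}, model_id="openai/gpt-4")
assert model.get_config()["model_id"] == "litellm_proxy/openai/gpt-4"

# An already-prefixed model_id is left untouched (the workaround is idempotent).
model = LiteLLMModel(client_args={"use_litellm_proxy": True}, model_id="litellm_proxy/openai/gpt-4")
assert model.get_config()["model_id"] == "litellm_proxy/openai/gpt-4"

# update_config re-applies the prefix, so swapping models keeps proxy routing.
model.update_config(model_id="anthropic/claude-3")
assert model.get_config()["model_id"] == "litellm_proxy/anthropic/claude-3"
```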