From b8d942f46cd57338611a8ad15bd50ac44634112f Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Fri, 5 Sep 2025 14:03:14 -0400
Subject: [PATCH 1/6] fix(models): patch litellm bug to honor passing in use_litellm_proxy as client_args

---
 src/strands/models/litellm.py        |  13 ++++
 tests/strands/models/test_litellm.py | 105 +++++++++++++++++++++++++++
 2 files changed, 118 insertions(+)

diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index c1e99f1a2..24c326211 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -50,6 +50,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
         """
         self.client_args = client_args or {}
         self.config = dict(model_config)
+        self._apply_proxy_prefix()
 
         logger.debug("config=<%s> | initializing", self.config)
 
@@ -61,6 +62,18 @@ def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None:  # type:
             **model_config: Configuration overrides.
         """
         self.config.update(model_config)
+        self._apply_proxy_prefix()
+
+    def _apply_proxy_prefix(self) -> None:
+        """Apply litellm_proxy/ prefix to model_id when use_litellm_proxy is True.
+
+        This is a workaround for https://github.com/BerriAI/litellm/issues/13454
+        where use_litellm_proxy parameter is not honored.
+        """
+        if self.client_args.get("use_litellm_proxy") and "model_id" in self.config:
+            model_id = self.config["model_id"]
+            if not model_id.startswith("litellm_proxy/"):
+                self.config["model_id"] = f"litellm_proxy/{model_id}"
 
     @override
     def get_config(self) -> LiteLLMConfig:
diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py
index 44b6df63b..0e154b1ce 100644
--- a/tests/strands/models/test_litellm.py
+++ b/tests/strands/models/test_litellm.py
@@ -58,6 +58,39 @@ def test_update_config(model, model_id):
     assert tru_model_id == exp_model_id
 
 
+@pytest.mark.parametrize(
+    "client_args, model_id, expected_model_id",
+    [
+        ({"use_litellm_proxy": True}, "openai/gpt-4", "litellm_proxy/openai/gpt-4"),
+        ({"use_litellm_proxy": False}, "openai/gpt-4", "openai/gpt-4"),
+        ({"use_litellm_proxy": None}, "openai/gpt-4", "openai/gpt-4"),
+        ({}, "openai/gpt-4", "openai/gpt-4"),
+        (None, "openai/gpt-4", "openai/gpt-4"),
+        ({"use_litellm_proxy": True}, "litellm_proxy/openai/gpt-4", "litellm_proxy/openai/gpt-4"),
+        ({"use_litellm_proxy": False}, "litellm_proxy/openai/gpt-4", "litellm_proxy/openai/gpt-4"),
+    ],
+)
+def test_use_litellm_proxy_prefix(client_args, model_id, expected_model_id):
+    """Test litellm_proxy prefix behavior for various configurations."""
+    model = LiteLLMModel(client_args=client_args, model_id=model_id)
+    assert model.get_config()["model_id"] == expected_model_id
+
+
+@pytest.mark.parametrize(
+    "client_args, initial_model_id, new_model_id, expected_model_id",
+    [
+        ({"use_litellm_proxy": True}, "openai/gpt-4", "anthropic/claude-3", "litellm_proxy/anthropic/claude-3"),
+        ({"use_litellm_proxy": False}, "openai/gpt-4", "anthropic/claude-3", "anthropic/claude-3"),
+        (None, "openai/gpt-4", "anthropic/claude-3", "anthropic/claude-3"),
+    ],
+)
+def test_update_config_proxy_prefix(client_args, initial_model_id, new_model_id, expected_model_id):
+    """Test that update_config applies proxy prefix correctly."""
+    model = LiteLLMModel(client_args=client_args, model_id=initial_model_id)
+    model.update_config(model_id=new_model_id)
+    assert model.get_config()["model_id"] == expected_model_id
+
+
 @pytest.mark.parametrize(
     "content, exp_result",
     [
@@ -197,6 +230,40 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator,
     litellm_acompletion.assert_called_once_with(**expected_request)
 
 
+@pytest.mark.parametrize(
+    "use_litellm_proxy, expected_model_id",
+    [
+        (True, "litellm_proxy/openai/gpt-4"),
+        (False, "openai/gpt-4"),
+        (None, "openai/gpt-4"),
+    ],
+)
+@pytest.mark.asyncio
+async def test_stream_with_proxy(litellm_acompletion, api_key, agenerator, alist, use_litellm_proxy, expected_model_id):
+    """Test that streaming works correctly with various proxy configurations."""
+    client_args = {"api_key": api_key}
+    if use_litellm_proxy is not None:
+        client_args["use_litellm_proxy"] = use_litellm_proxy
+
+    model = LiteLLMModel(client_args=client_args, model_id="openai/gpt-4")
+
+    mock_delta = unittest.mock.Mock(content="test response", tool_calls=None, reasoning_content=None)
+    mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)])
+    mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)])
+    mock_event_3 = unittest.mock.Mock(usage=None)
+
+    litellm_acompletion.side_effect = unittest.mock.AsyncMock(
+        return_value=agenerator([mock_event_1, mock_event_2, mock_event_3])
+    )
+
+    messages = [{"role": "user", "content": [{"text": "test"}]}]
+    response = model.stream(messages)
+    await alist(response)
+
+    call_args = litellm_acompletion.call_args[1]
+    assert call_args["model"] == expected_model_id
+
+
 @pytest.mark.asyncio
 async def test_stream_empty(litellm_acompletion, api_key, model_id, model, agenerator, alist):
     mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None)
@@ -252,3 +319,41 @@ async def test_structured_output(litellm_acompletion, model, test_output_model_c
     exp_result = {"output": test_output_model_cls(name="John", age=30)}
 
     assert tru_result == exp_result
+
+
+@pytest.mark.parametrize(
+    "use_litellm_proxy, expected_model_id",
+    [
+        (True, "litellm_proxy/openai/gpt-4"),
+        (False, "openai/gpt-4"),
+    ],
+)
+@pytest.mark.asyncio
+async def test_structured_output_with_proxy(
+    litellm_acompletion, api_key, test_output_model_cls, alist, use_litellm_proxy, expected_model_id
+):
+    """Test that structured_output works correctly with various proxy configurations."""
+    model = LiteLLMModel(
+        client_args={"api_key": api_key, "use_litellm_proxy": use_litellm_proxy}, model_id="openai/gpt-4"
+    )
+
+    messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
+
+    mock_choice = unittest.mock.Mock()
+    mock_choice.finish_reason = "tool_calls"
+    mock_choice.message.content = '{"name": "Jane", "age": 25}'
+    mock_response = unittest.mock.Mock()
+    mock_response.choices = [mock_choice]
+
+    litellm_acompletion.side_effect = unittest.mock.AsyncMock(return_value=mock_response)
+
+    with unittest.mock.patch.object(strands.models.litellm, "supports_response_schema", return_value=True):
+        stream = model.structured_output(test_output_model_cls, messages)
+        events = await alist(stream)
+        tru_result = events[-1]
+
+    exp_result = {"output": test_output_model_cls(name="Jane", age=25)}
+    assert tru_result == exp_result
+
+    call_args = litellm_acompletion.call_args[1]
+    assert call_args["model"] == expected_model_id
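
For illustration, a minimal usage sketch of the behavior patch 1 introduces, using only the constructor, get_config(), and update_config() calls exercised by the new tests. The import path and the api_key value are assumptions for the example, not something stated in the patch itself:

    # Hedged sketch: assumes LiteLLMModel is importable from strands.models.litellm.
    from strands.models.litellm import LiteLLMModel

    model = LiteLLMModel(
        client_args={"use_litellm_proxy": True, "api_key": "placeholder-key"},
        model_id="openai/gpt-4",
    )
    # __init__ runs _apply_proxy_prefix(), so the configured model id already carries
    # the litellm_proxy/ prefix that the workaround substitutes for the broken flag.
    assert model.get_config()["model_id"] == "litellm_proxy/openai/gpt-4"

    # update_config() re-applies the prefix; an already-prefixed id is left untouched.
    model.update_config(model_id="anthropic/claude-3")
    assert model.get_config()["model_id"] == "litellm_proxy/anthropic/claude-3"
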
From 42c7179c0d63ca176d4e604d5fdeee900e9a4a5a Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Fri, 5 Sep 2025 14:36:52 -0400
Subject: [PATCH 2/6] fix: move private method below public

---
 src/strands/models/litellm.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index 24c326211..4870fae2c 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -64,17 +64,6 @@ def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None:  # type:
         self.config.update(model_config)
         self._apply_proxy_prefix()
 
-    def _apply_proxy_prefix(self) -> None:
-        """Apply litellm_proxy/ prefix to model_id when use_litellm_proxy is True.
-
-        This is a workaround for https://github.com/BerriAI/litellm/issues/13454
-        where use_litellm_proxy parameter is not honored.
-        """
-        if self.client_args.get("use_litellm_proxy") and "model_id" in self.config:
-            model_id = self.config["model_id"]
-            if not model_id.startswith("litellm_proxy/"):
-                self.config["model_id"] = f"litellm_proxy/{model_id}"
-
     @override
     def get_config(self) -> LiteLLMConfig:
         """Get the LiteLLM model configuration.
@@ -236,3 +225,14 @@ async def structured_output(
 
         # If no tool_calls found, raise an error
         raise ValueError("No tool_calls found in response")
+
+    def _apply_proxy_prefix(self) -> None:
+        """Apply litellm_proxy/ prefix to model_id when use_litellm_proxy is True.
+
+        This is a workaround for https://github.com/BerriAI/litellm/issues/13454
+        where use_litellm_proxy parameter is not honored.
+        """
+        if self.client_args.get("use_litellm_proxy") and "model_id" in self.config:
+            model_id = self.config["model_id"]
+            if not model_id.startswith("litellm_proxy/"):
+                self.config["model_id"] = f"litellm_proxy/{model_id}"

From f1c046d82cd74de3c76406fda0f8818666263eee Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Fri, 5 Sep 2025 14:39:07 -0400
Subject: [PATCH 3/6] fix: remove redundant tests

---
 tests/strands/models/test_litellm.py | 66 ----------------------------
 1 file changed, 66 deletions(-)

diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py
index 0e154b1ce..c97ccf5ac 100644
--- a/tests/strands/models/test_litellm.py
+++ b/tests/strands/models/test_litellm.py
@@ -230,38 +230,7 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator,
     litellm_acompletion.assert_called_once_with(**expected_request)
 
 
-@pytest.mark.parametrize(
-    "use_litellm_proxy, expected_model_id",
-    [
-        (True, "litellm_proxy/openai/gpt-4"),
-        (False, "openai/gpt-4"),
-        (None, "openai/gpt-4"),
-    ],
-)
-@pytest.mark.asyncio
-async def test_stream_with_proxy(litellm_acompletion, api_key, agenerator, alist, use_litellm_proxy, expected_model_id):
-    """Test that streaming works correctly with various proxy configurations."""
-    client_args = {"api_key": api_key}
-    if use_litellm_proxy is not None:
-        client_args["use_litellm_proxy"] = use_litellm_proxy
-
-    model = LiteLLMModel(client_args=client_args, model_id="openai/gpt-4")
-
-    mock_delta = unittest.mock.Mock(content="test response", tool_calls=None, reasoning_content=None)
-    mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)])
-    mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)])
-    mock_event_3 = unittest.mock.Mock(usage=None)
-
-    litellm_acompletion.side_effect = unittest.mock.AsyncMock(
-        return_value=agenerator([mock_event_1, mock_event_2, mock_event_3])
-    )
-
-    messages = [{"role": "user", "content": [{"text": "test"}]}]
-    response = model.stream(messages)
-    await alist(response)
-
-    call_args = litellm_acompletion.call_args[1]
-    assert call_args["model"] == expected_model_id
 
 
 
 @pytest.mark.asyncio
 async def test_stream_empty(litellm_acompletion, api_key, model_id, model, agenerator, alist):
     mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None)
@@ -321,39 +290,4 @@ async def test_structured_output(litellm_acompletion, model, test_output_model_c
     assert tru_result == exp_result
 
 
 
-@pytest.mark.parametrize(
-    "use_litellm_proxy, expected_model_id",
-    [
-        (True, "litellm_proxy/openai/gpt-4"),
-        (False, "openai/gpt-4"),
-    ],
-)
-@pytest.mark.asyncio
-async def test_structured_output_with_proxy(
-    litellm_acompletion, api_key, test_output_model_cls, alist, use_litellm_proxy, expected_model_id
-):
-    """Test that structured_output works correctly with various proxy configurations."""
-    model = LiteLLMModel(
-        client_args={"api_key": api_key, "use_litellm_proxy": use_litellm_proxy}, model_id="openai/gpt-4"
-    )
-
-    messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
-
-    mock_choice = unittest.mock.Mock()
-    mock_choice.finish_reason = "tool_calls"
-    mock_choice.message.content = '{"name": "Jane", "age": 25}'
-    mock_response = unittest.mock.Mock()
-    mock_response.choices = [mock_choice]
-
-    litellm_acompletion.side_effect = unittest.mock.AsyncMock(return_value=mock_response)
-
-    with unittest.mock.patch.object(strands.models.litellm, "supports_response_schema", return_value=True):
-        stream = model.structured_output(test_output_model_cls, messages)
-        events = await alist(stream)
-        tru_result = events[-1]
-
-    exp_result = {"output": test_output_model_cls(name="Jane", age=25)}
-    assert tru_result == exp_result
-
-    call_args = litellm_acompletion.call_args[1]
-    assert call_args["model"] == expected_model_id

From 6e3b478fb70e6118eeaa4a76d085d21606c69e37 Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Fri, 5 Sep 2025 14:39:33 -0400
Subject: [PATCH 4/6] formatting

---
 tests/strands/models/test_litellm.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py
index c97ccf5ac..80525bad0 100644
--- a/tests/strands/models/test_litellm.py
+++ b/tests/strands/models/test_litellm.py
@@ -230,9 +230,6 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator,
     litellm_acompletion.assert_called_once_with(**expected_request)
 
 
-
-
-
 @pytest.mark.asyncio
 async def test_stream_empty(litellm_acompletion, api_key, model_id, model, agenerator, alist):
     mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None)
@@ -288,6 +285,3 @@ async def test_structured_output(litellm_acompletion, model, test_output_model_c
     exp_result = {"output": test_output_model_cls(name="John", age=30)}
 
     assert tru_result == exp_result
-
-
-

From 768244655aeb2678f3ca5f1ed6fe8a0bc3a5d750 Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Fri, 5 Sep 2025 14:48:05 -0400
Subject: [PATCH 5/6] linting to use get_config()

---
 src/strands/models/litellm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index 4870fae2c..1e4bf42a6 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -233,6 +233,6 @@ def _apply_proxy_prefix(self) -> None:
         where use_litellm_proxy parameter is not honored.
""" if self.client_args.get("use_litellm_proxy") and "model_id" in self.config: - model_id = self.config["model_id"] + model_id = self.get_config()["model_id"] if not model_id.startswith("litellm_proxy/"): self.config["model_id"] = f"litellm_proxy/{model_id}" From 70c636ac11ab123b28751c880375de73a5d60e4c Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 8 Sep 2025 17:37:11 -0400 Subject: [PATCH 6/6] Update tests/strands/models/test_litellm.py --- tests/strands/models/test_litellm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 80525bad0..443bd3e8f 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -70,7 +70,7 @@ def test_update_config(model, model_id): ({"use_litellm_proxy": False}, "litellm_proxy/openai/gpt-4", "litellm_proxy/openai/gpt-4"), ], ) -def test_use_litellm_proxy_prefix(client_args, model_id, expected_model_id): +def test__init__use_litellm_proxy_prefix(client_args, model_id, expected_model_id): """Test litellm_proxy prefix behavior for various configurations.""" model = LiteLLMModel(client_args=client_args, model_id=model_id) assert model.get_config()["model_id"] == expected_model_id