diff --git a/src/codegate/muxing/adapter.py b/src/codegate/muxing/adapter.py index c2510e90..b000b0ab 100644 --- a/src/codegate/muxing/adapter.py +++ b/src/codegate/muxing/adapter.py @@ -84,14 +84,20 @@ def provider_format_funcs(self) -> Dict[str, Callable]: """ pass + def _clean_chunk(self, chunk: str) -> str: + """Clean the chunk from the "data:" and any extra characters.""" + # Find the first position of 'data:' and add 5 characters to skip 'data:' + start_pos = chunk.find("data:") + 5 + cleaned_chunk = chunk[start_pos:].strip() + return cleaned_chunk + def _format_openai(self, chunk: str) -> str: """ The chunk is already in OpenAI format. To standarize remove the "data:" prefix. This function is used by both chat and FIM formatters """ - cleaned_chunk = chunk.split("data:")[1].strip() - return cleaned_chunk + return self._clean_chunk(chunk) def _format_antropic(self, chunk: str) -> str: """ @@ -99,7 +105,7 @@ def _format_antropic(self, chunk: str) -> str: This function is used by both chat and FIM formatters """ - cleaned_chunk = chunk.split("data:")[1].strip() + cleaned_chunk = self._clean_chunk(chunk) try: # Use `strict=False` to allow the JSON payload to contain # newlines, tabs and other valid characters that might diff --git a/tests/muxing/test_adapter.py b/tests/muxing/test_adapter.py index ba510ef0..802439c1 100644 --- a/tests/muxing/test_adapter.py +++ b/tests/muxing/test_adapter.py @@ -1,7 +1,7 @@ import pytest from codegate.db.models import ProviderType -from codegate.muxing.adapter import BodyAdapter +from codegate.muxing.adapter import BodyAdapter, ChatStreamChunkFormatter class MockedEndpoint: @@ -30,3 +30,35 @@ def test_catch_all(provider_type, endpoint_route, expected_route): model_route = MockedModelRoute(provider_type, endpoint_route) actual_route = body_adapter._get_provider_formatted_url(model_route) assert actual_route == expected_route + + +@pytest.mark.parametrize( + "chunk, expected_cleaned_chunk", + [ + ( + ( + 'event: content_block_delta\ndata:{"type": "content_block_delta", "index": 0, ' + '"delta": {"type": "text_delta", "text": "\n metadata:\n name: trusty"}}' + ), + ( + '{"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", ' + '"text": "\n metadata:\n name: trusty"}}' + ), + ), + ( + ( + "event: content_block_delta\n" + 'data:{"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", ' + '"text": "v1\nkind: NetworkPolicy\nmetadata:"}}' + ), + ( + '{"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text"' + ': "v1\nkind: NetworkPolicy\nmetadata:"}}' + ), + ), + ], +) +def test_clean_chunk(chunk, expected_cleaned_chunk): + formatter = ChatStreamChunkFormatter() + gotten_chunk = formatter._clean_chunk(chunk) + assert gotten_chunk == expected_cleaned_chunk