diff --git a/src/codegate/pipeline/output.py b/src/codegate/pipeline/output.py index 608c36de0..266485c54 100644 --- a/src/codegate/pipeline/output.py +++ b/src/codegate/pipeline/output.py @@ -153,6 +153,8 @@ async def process_stream( step_result = await step.process_chunk( c, self._context, self._input_context ) + if not step_result: + break processed_chunks.extend(step_result) current_chunks = processed_chunks diff --git a/src/codegate/pipeline/pii/pii.py b/src/codegate/pipeline/pii/pii.py index d7f33d670..4dd7d5db9 100644 --- a/src/codegate/pipeline/pii/pii.py +++ b/src/codegate/pipeline/pii/pii.py @@ -20,6 +20,38 @@ logger = structlog.get_logger("codegate") +def can_be_uuid(buffer): + """ + This is a way to check if a buffer can be a UUID. It aims to return as soon as possible + meaning that we buffer as little as possible. This is important for performance reasons + but also to make sure other steps don't wait too long as we don't buffer more than we need to. + """ + # UUID structure: 8-4-4-4-12 hex digits + # Expected positions of hyphens + hyphen_positions = {8, 13, 18, 23} + + # Maximum length of a UUID + max_uuid_length = 36 + + if buffer == "": + return True + + # If buffer is longer than a UUID, it can't be a UUID + if len(buffer) > max_uuid_length: + return False + + for i, char in enumerate(buffer): + # Check if hyphens are in the right positions + if i in hyphen_positions: + if char != "-": + return False + # Check if non-hyphen positions contain hex digits + elif not (char.isdigit() or char.lower() in "abcdef"): + return False + + return True + + class CodegatePii(PipelineStep): """ CodegatePii is a pipeline step that handles the detection and redaction of PII @@ -278,8 +310,13 @@ async def process_chunk( # noqa: C901 end_idx = content.find(self.marker_end, start_idx + 1) if end_idx == -1: - # Incomplete marker, buffer the rest - context.prefix_buffer = content[current_pos:] + # Incomplete marker, buffer the rest only if it can be a UUID + if start_idx + 1 < len(content) and not can_be_uuid(content[start_idx + 1 :]): + # the buffer can't be a UUID, so we can't process it, just return + result.append(content[current_pos:]) + else: + # this can still be a UUID + context.prefix_buffer = content[current_pos:] break # Add text before marker diff --git a/tests/pipeline/pii/test_pi.py b/tests/pipeline/pii/test_pi.py index 06d2881fe..6ced039a8 100644 --- a/tests/pipeline/pii/test_pi.py +++ b/tests/pipeline/pii/test_pi.py @@ -120,6 +120,52 @@ async def test_process_chunk_with_uuid(self, unredaction_step): result = await unredaction_step.process_chunk(chunk, context, input_context) assert result[0].choices[0].delta.content == "Text with test@example.com" + @pytest.mark.asyncio + async def test_detect_not_an_uuid(self, unredaction_step): + chunk1 = ModelResponse( + id="test", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta(content="#"), + logprobs=None, + ) + ], + created=1234567890, + model="test-model", + object="chat.completion.chunk", + ) + chunk2 = ModelResponse( + id="test", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta(content=" filepath"), + logprobs=None, + ) + ], + created=1234567890, + model="test-model", + object="chat.completion.chunk", + ) + + context = OutputPipelineContext() + manager = SensitiveDataManager() + sensitive = PipelineSensitiveData(manager=manager, session_id="session-id") + input_context = PipelineContext(sensitive=sensitive) + + # Mock PII manager in input context + mock_sensitive_data_manager = MagicMock() + mock_sensitive_data_manager.get_original_value = MagicMock(return_value="test@example.com") + input_context.metadata["sensitive_data_manager"] = mock_sensitive_data_manager + + result = await unredaction_step.process_chunk(chunk1, context, input_context) + assert not result + result = await unredaction_step.process_chunk(chunk2, context, input_context) + assert result[0].choices[0].delta.content == "# filepath" + class TestPiiRedactionNotifier: @pytest.fixture