diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py
index 6e936cf4..f4d3e8ed 100644
--- a/src/codegate/providers/openai/provider.py
+++ b/src/codegate/providers/openai/provider.py
@@ -18,8 +18,9 @@ class OpenAIProvider(BaseProvider):
     def __init__(
         self,
         pipeline_factory: PipelineFactory,
+        # Allow child providers (e.g. OpenRouter and LM Studio) to pass their own completion handler
+        completion_handler: LiteLLmShim = LiteLLmShim(stream_generator=sse_stream_generator),
     ):
-        completion_handler = LiteLLmShim(stream_generator=sse_stream_generator)
         super().__init__(
             OpenAIInputNormalizer(),
             OpenAIOutputNormalizer(),
diff --git a/src/codegate/providers/openrouter/provider.py b/src/codegate/providers/openrouter/provider.py
index de65662d..dd934161 100644
--- a/src/codegate/providers/openrouter/provider.py
+++ b/src/codegate/providers/openrouter/provider.py
@@ -2,12 +2,14 @@ from typing import Dict
 
 from fastapi import Header, HTTPException, Request
+from litellm import atext_completion
 from litellm.types.llms.openai import ChatCompletionRequest
 
 from codegate.clients.clients import ClientType
 from codegate.clients.detector import DetectClient
 from codegate.pipeline.factory import PipelineFactory
 from codegate.providers.fim_analyzer import FIMAnalyzer
+from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator
 from codegate.providers.normalizer.completion import CompletionNormalizer
 from codegate.providers.openai import OpenAIProvider
 
 
@@ -20,15 +22,45 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
         return super().normalize(data)
 
     def denormalize(self, data: ChatCompletionRequest) -> Dict:
-        if data.get("had_prompt_before", False):
-            del data["had_prompt_before"]
-
-        return data
+        """
+        Denormalize a FIM OpenRouter request into the format accepted by atext_completion.
+        """
+        denormalized_data = super().denormalize(data)
+        # We are forcing atext_completion, which expects a "prompt" key in the data.
+        # Build one in case it is not present.
+        if "prompt" in data:
+            return denormalized_data
+        custom_prompt = ""
+        for msg_dict in denormalized_data.get("messages", []):
+            content_obj = msg_dict.get("content")
+            if not content_obj:
+                continue
+            if isinstance(content_obj, list):
+                for content_dict in content_obj:
+                    custom_prompt += (
+                        content_dict.get("text", "") if isinstance(content_dict, dict) else ""
+                    )
+            elif isinstance(content_obj, str):
+                custom_prompt += content_obj
+
+        # Erase the original "messages" key and replace it with "prompt".
+        del denormalized_data["messages"]
+        denormalized_data["prompt"] = custom_prompt
+
+        return denormalized_data
 
 
 class OpenRouterProvider(OpenAIProvider):
     def __init__(self, pipeline_factory: PipelineFactory):
-        super().__init__(pipeline_factory)
+        super().__init__(
+            pipeline_factory,
+            # We get FIM requests on /completions. LiteLLM forces /chat/completions,
+            # which returns "choices":[{"delta":{"content":"some text"}}]
+            # instead of the "choices":[{"text":"some text"}] expected by the client (Continue).
+            completion_handler=LiteLLmShim(
+                stream_generator=sse_stream_generator, fim_completion_func=atext_completion
+            ),
+        )
         self._fim_normalizer = OpenRouterNormalizer()
 
     @property
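
For reviewers unfamiliar with the FIM path, the change has two parts. OpenAIProvider now accepts an injected completion handler, and OpenRouterProvider uses that to route FIM requests through litellm's atext_completion (text-completions API shape), whose streamed chunks carry "choices":[{"text": ...}] rather than "choices":[{"delta":{"content": ...}}], matching what the client (Continue) expects. OpenRouterNormalizer.denormalize then flattens the chat-style "messages" list into the single "prompt" string that atext_completion requires. Below is a minimal, standalone sketch of that flattening step; the helper name messages_to_prompt and the sample request are illustrative only and not part of the codegate codebase.

    from typing import Any, Dict


    def messages_to_prompt(data: Dict[str, Any]) -> Dict[str, Any]:
        """Illustrative mirror of OpenRouterNormalizer.denormalize in this diff:
        concatenate message text into a "prompt" key and drop "messages"."""
        if "prompt" in data:
            # Already in text-completion shape; nothing to do.
            return data

        custom_prompt = ""
        for msg in data.get("messages", []):
            content = msg.get("content")
            if not content:
                continue
            if isinstance(content, list):
                # Multi-part content: keep only the text fragments.
                custom_prompt += "".join(
                    part.get("text", "") for part in content if isinstance(part, dict)
                )
            elif isinstance(content, str):
                custom_prompt += content

        flattened = {k: v for k, v in data.items() if k != "messages"}
        flattened["prompt"] = custom_prompt
        return flattened


    # Made-up FIM request arriving in chat-completion shape:
    request = {
        "model": "openrouter/some-model",
        "messages": [{"role": "user", "content": [{"type": "text", "text": "def add(a, b):"}]}],
    }
    print(messages_to_prompt(request))
    # -> {'model': 'openrouter/some-model', 'prompt': 'def add(a, b):'}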