Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 6c342b7

Browse files
committed
Replace litellm with native API implementations.
Refactors client architecture to use native implementations instead of `litellm` dependency. Adds support for OpenAPI, Ollama, OpenRouter, and fixes multiple issues with Anthropic and Copilot providers. Improves message handling and streaming responses. Commit message brought to you by Anthropic Claude 3.7.
1 parent f9b9bca commit 6c342b7

File tree

98 files changed

+6464
-2914
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

98 files changed

+6464
-2914
lines changed

prompts/default.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ pii_redacted: |
4646
The context files contain redacted personally identifiable information (PII) that is represented by a UUID encased within <>. For example:
4747
- <123e4567-e89b-12d3-a456-426614174000>
4848
- <2d040296-98e9-4350-84be-fda4336057eb>
49-
If you encounter any PII redacted with a UUID, DO NOT WARN the user about it. Simplt respond to the user request and keep the PII redacted and intact, using the same UUID.
49+
If you encounter any PII redacted with a UUID, DO NOT WARN the user about it. Simply respond to the user request and keep the PII redacted and intact, using the same UUID.
5050
# Security-focused prompts
5151
security_audit: "You are a security expert conducting a thorough code review. Identify potential security vulnerabilities, suggest improvements, and explain security best practices."
5252

@@ -56,6 +56,6 @@ red_team: "You are a red team member conducting a security assessment. Identify
5656
# BlueTeam prompts
5757
blue_team: "You are a blue team member conducting a security assessment. Identify security controls, misconfigurations, and potential vulnerabilities."
5858

59-
# Per client prompts
59+
# Per client prompts
6060
client_prompts:
6161
kodu: "If malicious packages or leaked secrets are found, please end the task, sending the problems found embedded in <attempt_completion><result> tags"

src/codegate/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616

1717
# Default provider URLs
1818
DEFAULT_PROVIDER_URLS = {
19-
"openai": "https://api.openai.com/v1",
20-
"openrouter": "https://openrouter.ai/api/v1",
21-
"anthropic": "https://api.anthropic.com/v1",
19+
"openai": "https://api.openai.com",
20+
"openrouter": "https://openrouter.ai/api",
21+
"anthropic": "https://api.anthropic.com",
2222
"vllm": "http://localhost:8000", # Base URL without /v1 path
2323
"ollama": "http://localhost:11434", # Default Ollama server URL
2424
"lm_studio": "http://localhost:1234",

src/codegate/db/connection.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,17 @@ def does_db_exist(self):
121121
return self._db_path.is_file()
122122

123123

124+
def row_from_model(model: BaseModel) -> dict:
125+
return dict(
126+
id=model.id,
127+
timestamp=model.timestamp,
128+
provider=model.provider,
129+
request=model.request.json(exclude_defaults=True, exclude_unset=True),
130+
type=model.type,
131+
workspace_id=model.workspace_id,
132+
)
133+
134+
124135
class DbRecorder(DbCodeGate):
125136
def __init__(self, sqlite_path: Optional[str] = None, *args, **kwargs):
126137
super().__init__(sqlite_path, *args, **kwargs)
@@ -131,7 +142,10 @@ async def _execute_update_pydantic_model(
131142
"""Execute an update or insert command for a Pydantic model."""
132143
try:
133144
async with self._async_db_engine.begin() as conn:
134-
result = await conn.execute(sql_command, model.model_dump())
145+
row = model
146+
if isinstance(model, BaseModel):
147+
row = model.model_dump()
148+
result = await conn.execute(sql_command, row)
135149
row = result.first()
136150
if row is None:
137151
return None
@@ -173,7 +187,8 @@ async def record_request(self, prompt_params: Optional[Prompt] = None) -> Option
173187
RETURNING *
174188
"""
175189
)
176-
recorded_request = await self._execute_update_pydantic_model(prompt_params, sql)
190+
row = row_from_model(prompt_params)
191+
recorded_request = await self._execute_update_pydantic_model(row, sql)
177192
# Uncomment to debug the recorded request
178193
# logger.debug(f"Recorded request: {recorded_request}")
179194
return recorded_request # type: ignore
@@ -192,7 +207,8 @@ async def update_request(
192207
RETURNING *
193208
"""
194209
)
195-
updated_request = await self._execute_update_pydantic_model(prompt_params, sql)
210+
row = row_from_model(prompt_params)
211+
updated_request = await self._execute_update_pydantic_model(row, sql)
196212
# Uncomment to debug the recorded request
197213
# logger.debug(f"Recorded request: {recorded_request}")
198214
return updated_request # type: ignore
@@ -215,7 +231,7 @@ async def record_outputs(
215231
output=first_output.output,
216232
)
217233
full_outputs = []
218-
# Just store the model respnses in the list of JSON objects.
234+
# Just store the model responses in the list of JSON objects.
219235
for output in outputs:
220236
full_outputs.append(output.output)
221237

@@ -339,7 +355,7 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
339355
f"Alerts: {len(context.alerts_raised)}."
340356
)
341357
except Exception as e:
342-
logger.error(f"Failed to record context: {context}.", error=str(e))
358+
logger.error(f"Failed to record context: {context}.", error=str(e), exc_info=e)
343359

344360
async def add_workspace(self, workspace_name: str) -> WorkspaceRow:
345361
"""Add a new workspace to the DB.

src/codegate/db/fim_cache.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,18 @@ def __init__(self):
3333

3434
def _extract_message_from_fim_request(self, request: str) -> Optional[str]:
3535
"""Extract the user message from the FIM request"""
36+
### NEW CODE PATH ###
37+
if not isinstance(request, str):
38+
content_message = None
39+
for message in request.get_messages():
40+
for content in message.get_content():
41+
if content_message is None:
42+
content_message = content.get_text()
43+
else:
44+
logger.warning("Expected one user message, found multiple.")
45+
return None
46+
return content_message
47+
3648
try:
3749
parsed_request = json.loads(request)
3850
except Exception as e:

src/codegate/extract_snippets/body_extractor.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
KoduCodeSnippetExtractor,
1010
OpenInterpreterCodeSnippetExtractor,
1111
)
12+
from codegate.types.common import MessageTypeFilter
1213

1314

1415
class BodyCodeSnippetExtractorError(Exception):
@@ -32,25 +33,22 @@ def _extract_from_user_messages(self, data: dict) -> set[str]:
3233
raise BodyCodeSnippetExtractorError("Code Extractor not set.")
3334

3435
filenames: List[str] = []
35-
for msg in data.get("messages", []):
36-
if msg.get("role", "") == "user":
36+
for msg in data.get_messages(filters=[MessageTypeFilter.USER]):
37+
for content in msg.get_content():
3738
extracted_snippets = self._snippet_extractor.extract_unique_snippets(
38-
msg.get("content")
39+
content.get_text(),
3940
)
4041
filenames.extend(extracted_snippets.keys())
4142
return set(filenames)
4243

4344
def _extract_from_list_user_messages(self, data: dict) -> set[str]:
4445
filenames: List[str] = []
45-
for msg in data.get("messages", []):
46-
if msg.get("role", "") == "user":
47-
msgs_content = msg.get("content", [])
48-
for msg_content in msgs_content:
49-
if msg_content.get("type", "") == "text":
50-
extracted_snippets = self._snippet_extractor.extract_unique_snippets(
51-
msg_content.get("text")
52-
)
53-
filenames.extend(extracted_snippets.keys())
46+
for msg in data.get_messages(filters=[MessageTypeFilter.USER]):
47+
for content in msg.get_content():
48+
extracted_snippets = self._snippet_extractor.extract_unique_snippets(
49+
content.get_text(),
50+
)
51+
filenames.extend(extracted_snippets.keys())
5452
return set(filenames)
5553

5654
@abstractmethod

src/codegate/extract_snippets/message_extractor.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,10 +279,16 @@ def extract_snippets(self, message: str, require_filepath: bool = False) -> List
279279
"""
280280
regexes = self._choose_regex(require_filepath)
281281
# Find all code block matches
282+
if isinstance(message, str):
283+
return [
284+
self._get_snippet_for_match(match)
285+
for regex in regexes
286+
for match in regex.finditer(message)
287+
]
282288
return [
283289
self._get_snippet_for_match(match)
284290
for regex in regexes
285-
for match in regex.finditer(message)
291+
for match in regex.finditer(message.get_text())
286292
]
287293

288294
def extract_unique_snippets(self, message: str) -> Dict[str, CodeSnippet]:

src/codegate/llm_utils/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

src/codegate/llm_utils/llmclient.py

Lines changed: 0 additions & 155 deletions
This file was deleted.

src/codegate/muxing/adapter.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
from codegate.db import models as db_models
1515
from codegate.muxing import rulematcher
1616
from codegate.providers.ollama.adapter import OLlamaToModel
17+
from codegate.types.ollama import StreamingChatCompletion as OllamaStreamingChatCompletion
18+
from codegate.types.ollama import StreamingGenerateCompletion as OllamaStreamingGenerateCompletion
19+
from codegate.muxing.ollama_mappers import openai_chunk_from_ollama_chat, openai_chunk_from_ollama_generate
20+
from codegate.types.openai import StreamingChatCompletion as OpenAIStreamingChatCompletion
1721

1822
logger = structlog.get_logger("codegate")
1923

@@ -41,12 +45,9 @@ def _get_provider_formatted_url(self, model_route: rulematcher.ModelRoute) -> st
4145
return urljoin(model_route.endpoint.endpoint, "/api/v1")
4246
return model_route.endpoint.endpoint
4347

44-
def set_destination_info(self, model_route: rulematcher.ModelRoute, data: dict) -> dict:
48+
def get_destination_info(self, model_route: rulematcher.ModelRoute) -> dict:
4549
"""Set the destination provider info."""
46-
new_data = copy.deepcopy(data)
47-
new_data["model"] = model_route.model.name
48-
new_data["base_url"] = self._get_provider_formatted_url(model_route)
49-
return new_data
50+
return model_route.model.name, self._get_provider_formatted_url(model_route)
5051

5152

5253
class OutputFormatter(ABC):
@@ -215,8 +216,8 @@ def _format_ollama(self, chunk: str) -> str:
215216
"""Format the Ollama chunk to OpenAI format."""
216217
try:
217218
chunk_dict = json.loads(chunk)
218-
ollama_chunk = ChatResponse(**chunk_dict)
219-
open_ai_chunk = OLlamaToModel.normalize_chat_chunk(ollama_chunk)
219+
ollama_chunk = OllamaStreamingChatCompletion.model_validate(chunk_dict)
220+
open_ai_chunk = openai_chunk_from_ollama_chat(ollama_chunk)
220221
return open_ai_chunk.model_dump_json(exclude_none=True, exclude_unset=True)
221222
except Exception as e:
222223
# Sometimes we receive an OpenAI formatted chunk from ollama. Specifically when
@@ -251,10 +252,11 @@ def _format_ollama(self, chunk: str) -> str:
251252
"""Format the Ollama chunk to OpenAI format."""
252253
try:
253254
chunk_dict = json.loads(chunk)
254-
ollama_chunk = GenerateResponse(**chunk_dict)
255-
open_ai_chunk = OLlamaToModel.normalize_fim_chunk(ollama_chunk)
256-
return json.dumps(open_ai_chunk, separators=(",", ":"), indent=None)
257-
except Exception:
255+
ollama_chunk = OllamaStreamingGenerateCompletion.model_validate(chunk_dict)
256+
open_ai_chunk = openai_chunk_from_ollama_generate(ollama_chunk)
257+
return open_ai_chunk.model_dump_json(exclude_none=True, exclude_unset=True)
258+
except Exception as e:
259+
print("Error formatting Ollama chunk: ", chunk, e)
258260
return chunk
259261

260262

0 commit comments

Comments
 (0)