diff --git a/src/codegate/muxing/models.py b/src/codegate/muxing/models.py
index 9a263731..b26a38e7 100644
--- a/src/codegate/muxing/models.py
+++ b/src/codegate/muxing/models.py
@@ -3,6 +3,8 @@
import pydantic
+from codegate.clients.clients import ClientType
+
class MuxMatcherType(str, Enum):
"""
@@ -11,6 +13,12 @@ class MuxMatcherType(str, Enum):
# Always match this prompt
catch_all = "catch_all"
+    # Match based on the filename. It will match if there is a filename
+    # in the request that matches the matcher, either by extension or full name (*.py or main.py)
+ filename_match = "filename_match"
+ # Match based on the request type. It will match if the request type
+ # matches the matcher (e.g. FIM or chat)
+ request_type_match = "request_type_match"
class MuxRule(pydantic.BaseModel):
@@ -25,3 +33,14 @@ class MuxRule(pydantic.BaseModel):
# The actual matcher to use. Note that
# this depends on the matcher type.
matcher: Optional[str] = None
+
+
+class ThingToMatchMux(pydantic.BaseModel):
+ """
+ Represents the fields we can use to match a mux rule.
+ """
+
+ body: dict
+ url_request_path: str
+ is_fim_request: bool
+ client_type: ClientType
diff --git a/src/codegate/muxing/router.py b/src/codegate/muxing/router.py
index 5771e59e..c32af579 100644
--- a/src/codegate/muxing/router.py
+++ b/src/codegate/muxing/router.py
@@ -1,14 +1,14 @@
import json
+from typing import Optional
import structlog
from fastapi import APIRouter, HTTPException, Request
-from codegate.clients.clients import ClientType
from codegate.clients.detector import DetectClient
-from codegate.extract_snippets.body_extractor import BodyCodeSnippetExtractorError
-from codegate.extract_snippets.factory import BodyCodeExtractorFactory
+from codegate.muxing import models as mux_models
from codegate.muxing import rulematcher
from codegate.muxing.adapter import BodyAdapter, ResponseAdapter
+from codegate.providers.fim_analyzer import FIMAnalyzer
from codegate.providers.registry import ProviderRegistry
from codegate.workspaces.crud import WorkspaceCrud
@@ -39,40 +39,20 @@ def get_routes(self) -> APIRouter:
def _ensure_path_starts_with_slash(self, path: str) -> str:
return path if path.startswith("/") else f"/{path}"
- def _extract_request_filenames(self, detected_client: ClientType, data: dict) -> set[str]:
+ async def _get_model_route(
+ self, thing_to_match: mux_models.ThingToMatchMux
+ ) -> Optional[rulematcher.ModelRoute]:
"""
- Extract filenames from the request data.
+        Get the model route for the given thing_to_match.
"""
- try:
- body_extractor = BodyCodeExtractorFactory.create_snippet_extractor(detected_client)
- return body_extractor.extract_unique_filenames(data)
- except BodyCodeSnippetExtractorError as e:
- logger.error(f"Error extracting filenames from request: {e}")
- return set()
-
- async def _get_model_routes(self, filenames: set[str]) -> list[rulematcher.ModelRoute]:
- """
- Get the model routes for the given filenames.
- """
- model_routes = []
mux_registry = await rulematcher.get_muxing_rules_registry()
try:
- # Try to get a catch_all route
- single_model_route = await mux_registry.get_match_for_active_workspace(
- thing_to_match=None
- )
- model_routes.append(single_model_route)
-
- # Get the model routes for each filename
- for filename in filenames:
- model_route = await mux_registry.get_match_for_active_workspace(
- thing_to_match=filename
- )
- model_routes.append(model_route)
+ # Try to get a model route for the active workspace
+ model_route = await mux_registry.get_match_for_active_workspace(thing_to_match)
+ return model_route
except Exception as e:
logger.error(f"Error getting active workspace muxes: {e}")
raise HTTPException(str(e), status_code=404)
- return model_routes
def _setup_routes(self):
@@ -88,34 +68,45 @@ async def route_to_dest_provider(
1. Get destination provider from DB and active workspace.
2. Map the request body to the destination provider format.
3. Run pipeline. Selecting the correct destination provider.
- 4. Transmit the response back to the client in the correct format.
+ 4. Transmit the response back to the client in OpenAI format.
"""
body = await request.body()
data = json.loads(body)
+ is_fim_request = FIMAnalyzer.is_fim_request(rest_of_path, data)
+
+ # 1. Get destination provider from DB and active workspace.
+ thing_to_match = mux_models.ThingToMatchMux(
+ body=data,
+ url_request_path=rest_of_path,
+ is_fim_request=is_fim_request,
+ client_type=request.state.detected_client,
+ )
+ model_route = await self._get_model_route(thing_to_match)
+ if not model_route:
+            raise HTTPException(
+                status_code=404, detail="No matching rule found for the active workspace"
+            )
- filenames_in_data = self._extract_request_filenames(request.state.detected_client, data)
- logger.info(f"Extracted filenames from request: {filenames_in_data}")
-
- model_routes = await self._get_model_routes(filenames_in_data)
- if not model_routes:
- raise HTTPException("No rule found for the active workspace", status_code=404)
-
- # We still need some logic here to handle the case where we have multiple model routes.
- # For the moment since we match all only pick the first.
- model_route = model_routes[0]
+ logger.info(
+ "Muxing request routed to destination provider",
+ model=model_route.model.name,
+ provider_type=model_route.endpoint.provider_type,
+ provider_name=model_route.endpoint.name,
+ )
- # Parse the input data and map it to the destination provider format
+ # 2. Map the request body to the destination provider format.
rest_of_path = self._ensure_path_starts_with_slash(rest_of_path)
new_data = self._body_adapter.map_body_to_dest(model_route, data)
+
+ # 3. Run pipeline. Selecting the correct destination provider.
provider = self._provider_registry.get_provider(model_route.endpoint.provider_type)
api_key = model_route.auth_material.auth_blob
-
- # Send the request to the destination provider. It will run the pipeline
response = await provider.process_request(
- new_data, api_key, rest_of_path, request.state.detected_client
+ new_data, api_key, is_fim_request, request.state.detected_client
)
- # Format the response to the client always using the OpenAI format
+
+ # 4. Transmit the response back to the client in OpenAI format.
return self._response_adapter.format_response_to_client(
response, model_route.endpoint.provider_type
)
diff --git a/src/codegate/muxing/rulematcher.py b/src/codegate/muxing/rulematcher.py
index e20161c4..fb3c1da2 100644
--- a/src/codegate/muxing/rulematcher.py
+++ b/src/codegate/muxing/rulematcher.py
@@ -1,9 +1,17 @@
import copy
from abc import ABC, abstractmethod
from asyncio import Lock
-from typing import List, Optional
+from typing import Dict, List, Optional
+import structlog
+
+from codegate.clients.clients import ClientType
from codegate.db import models as db_models
+from codegate.extract_snippets.body_extractor import BodyCodeSnippetExtractorError
+from codegate.extract_snippets.factory import BodyCodeExtractorFactory
+from codegate.muxing import models as mux_models
+
+logger = structlog.get_logger("codegate")
_muxrules_sgtn = None
@@ -40,11 +48,12 @@ def __init__(
class MuxingRuleMatcher(ABC):
"""Base class for matching muxing rules."""
- def __init__(self, route: ModelRoute):
+ def __init__(self, route: ModelRoute, matcher_blob: str):
self._route = route
+ self._matcher_blob = matcher_blob
@abstractmethod
- def match(self, thing_to_match) -> bool:
+ def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
"""Return True if the rule matches the thing_to_match."""
pass
@@ -61,12 +70,15 @@ class MuxingMatcherFactory:
def create(mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatcher:
"""Create a muxing matcher for the given endpoint and model."""
- factory = {
- "catch_all": CatchAllMuxingRuleMatcher,
+ factory: Dict[mux_models.MuxMatcherType, MuxingRuleMatcher] = {
+ mux_models.MuxMatcherType.catch_all: CatchAllMuxingRuleMatcher,
+ mux_models.MuxMatcherType.filename_match: FileMuxingRuleMatcher,
+ mux_models.MuxMatcherType.request_type_match: RequestTypeMuxingRuleMatcher,
}
try:
- return factory[mux_rule.matcher_type](route)
+ # Initialize the MuxingRuleMatcher
+ return factory[mux_rule.matcher_type](route, mux_rule.matcher_blob)
except KeyError:
raise ValueError(f"Unknown matcher type: {mux_rule.matcher_type}")
@@ -74,10 +86,66 @@ def create(mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatcher:
class CatchAllMuxingRuleMatcher(MuxingRuleMatcher):
"""A catch all muxing rule matcher."""
- def match(self, thing_to_match) -> bool:
+ def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
+ logger.info("Catch all rule matched")
return True
+class FileMuxingRuleMatcher(MuxingRuleMatcher):
+ """A file muxing rule matcher."""
+
+ def _extract_request_filenames(self, detected_client: ClientType, data: dict) -> set[str]:
+ """
+ Extract filenames from the request data.
+ """
+ try:
+ body_extractor = BodyCodeExtractorFactory.create_snippet_extractor(detected_client)
+ return body_extractor.extract_unique_filenames(data)
+ except BodyCodeSnippetExtractorError as e:
+ logger.error(f"Error extracting filenames from request: {e}")
+ return set()
+
+ def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
+ """
+        Return True if there is a filename in the request that matches the matcher_blob.
+ The matcher_blob is either an extension (e.g. .py) or a filename (e.g. main.py).
+ """
+ # If there is no matcher_blob, we don't match
+ if not self._matcher_blob:
+ return False
+ filenames_to_match = self._extract_request_filenames(
+ thing_to_match.client_type, thing_to_match.body
+ )
+ is_filename_match = any(self._matcher_blob in filename for filename in filenames_to_match)
+ if is_filename_match:
+ logger.info(
+ "Filename rule matched", filenames=filenames_to_match, matcher=self._matcher_blob
+ )
+ return is_filename_match
+
+
+class RequestTypeMuxingRuleMatcher(MuxingRuleMatcher):
+ """A catch all muxing rule matcher."""
+
+ def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
+ """
+ Return True if the request type matches the matcher_blob.
+ The matcher_blob is either "fim" or "chat".
+ """
+ # If there is no matcher_blob, we don't match
+ if not self._matcher_blob:
+ return False
+ incoming_request_type = "fim" if thing_to_match.is_fim_request else "chat"
+ is_request_type_match = self._matcher_blob == incoming_request_type
+ if is_request_type_match:
+ logger.info(
+ "Request type rule matched",
+ matcher=self._matcher_blob,
+ request_type=incoming_request_type,
+ )
+ return is_request_type_match
+
+
class MuxingRulesinWorkspaces:
"""A thread safe dictionary to store the muxing rules in workspaces."""
@@ -111,7 +179,9 @@ async def get_registries(self) -> List[str]:
async with self._lock:
return list(self._ws_rules.keys())
- async def get_match_for_active_workspace(self, thing_to_match) -> Optional[ModelRoute]:
+ async def get_match_for_active_workspace(
+ self, thing_to_match: mux_models.ThingToMatchMux
+ ) -> Optional[ModelRoute]:
"""Get the first match for the given thing_to_match."""
# We iterate over all the rules and return the first match
diff --git a/src/codegate/pipeline/system_prompt/codegate.py b/src/codegate/pipeline/system_prompt/codegate.py
index 0dbf39a8..03520358 100644
--- a/src/codegate/pipeline/system_prompt/codegate.py
+++ b/src/codegate/pipeline/system_prompt/codegate.py
@@ -1,8 +1,8 @@
from typing import Optional
-from codegate.clients.clients import ClientType
from litellm import ChatCompletionRequest, ChatCompletionSystemMessage
+from codegate.clients.clients import ClientType
from codegate.pipeline.base import (
PipelineContext,
PipelineResult,
diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py
index fa4f146f..252a6947 100644
--- a/src/codegate/providers/anthropic/provider.py
+++ b/src/codegate/providers/anthropic/provider.py
@@ -11,6 +11,7 @@
from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer
from codegate.providers.anthropic.completion_handler import AnthropicCompletion
from codegate.providers.base import BaseProvider, ModelFetchError
+from codegate.providers.fim_analyzer import FIMAnalyzer
from codegate.providers.litellmshim import anthropic_stream_generator
@@ -57,10 +58,9 @@ async def process_request(
self,
data: dict,
api_key: str,
- request_url_path: str,
+ is_fim_request: bool,
client_type: ClientType,
):
- is_fim_request = self._is_fim_request(request_url_path, data)
try:
stream = await self.complete(data, api_key, is_fim_request, client_type)
except Exception as e:
@@ -98,10 +98,11 @@ async def create_message(
body = await request.body()
data = json.loads(body)
+ is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, data)
return await self.process_request(
data,
x_api_key,
- request.url.path,
+ is_fim_request,
request.state.detected_client,
)
diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py
index 269fd0e5..0c20bab8 100644
--- a/src/codegate/providers/base.py
+++ b/src/codegate/providers/base.py
@@ -79,7 +79,7 @@ async def process_request(
self,
data: dict,
api_key: str,
- request_url_path: str,
+ is_fim_request: bool,
client_type: ClientType,
):
pass
@@ -173,61 +173,6 @@ async def _run_input_pipeline(
return result
- def _is_fim_request_url(self, request_url_path: str) -> bool:
- """
- Checks the request URL to determine if a request is FIM or chat completion.
- Used by: llama.cpp
- """
- # Evaluate first a larger substring.
- if request_url_path.endswith("/chat/completions"):
- return False
-
- # /completions is for OpenAI standard. /api/generate is for ollama.
- if request_url_path.endswith("/completions") or request_url_path.endswith("/api/generate"):
- return True
-
- return False
-
- def _is_fim_request_body(self, data: Dict) -> bool:
- """
- Determine from the raw incoming data if it's a FIM request.
- Used by: OpenAI and Anthropic
- """
- messages = data.get("messages", [])
- if not messages:
- return False
-
- first_message_content = messages[0].get("content")
- if first_message_content is None:
- return False
-
- fim_stop_sequences = ["", "", "", ""]
- if isinstance(first_message_content, str):
- msg_prompt = first_message_content
- elif isinstance(first_message_content, list):
- msg_prompt = first_message_content[0].get("text", "")
- else:
- logger.warning(f"Could not determine if message was FIM from data: {data}")
- return False
- return all([stop_sequence in msg_prompt for stop_sequence in fim_stop_sequences])
-
- def _is_fim_request(self, request_url_path: str, data: Dict) -> bool:
- """
- Determine if the request is FIM by the URL or the data of the request.
- """
- # first check if we are in specific tools to discard FIM
- prompt = data.get("prompt", "")
- tools = ["cline", "kodu", "open interpreter"]
- for tool in tools:
- if tool in prompt.lower():
- # those tools can never be FIM
- return False
- # Avoid more expensive inspection of body by just checking the URL.
- if self._is_fim_request_url(request_url_path):
- return True
-
- return self._is_fim_request_body(data)
-
async def _cleanup_after_streaming(
self, stream: AsyncIterator[ModelResponse], context: PipelineContext
) -> AsyncIterator[ModelResponse]:
diff --git a/src/codegate/providers/fim_analyzer.py b/src/codegate/providers/fim_analyzer.py
new file mode 100644
index 00000000..e0cd090c
--- /dev/null
+++ b/src/codegate/providers/fim_analyzer.py
@@ -0,0 +1,66 @@
+from typing import Dict
+
+import structlog
+
+logger = structlog.get_logger("codegate")
+
+
+class FIMAnalyzer:
+
+ @classmethod
+ def _is_fim_request_url(cls, request_url_path: str) -> bool:
+ """
+ Checks the request URL to determine if a request is FIM or chat completion.
+ Used by: llama.cpp
+ """
+ # Evaluate first a larger substring.
+ if request_url_path.endswith("chat/completions"):
+ return False
+
+ # /completions is for OpenAI standard. /api/generate is for ollama.
+ if request_url_path.endswith("completions") or request_url_path.endswith("api/generate"):
+ return True
+
+ return False
+
+ @classmethod
+ def _is_fim_request_body(cls, data: Dict) -> bool:
+ """
+ Determine from the raw incoming data if it's a FIM request.
+ Used by: OpenAI and Anthropic
+ """
+ messages = data.get("messages", [])
+ if not messages:
+ return False
+
+ first_message_content = messages[0].get("content")
+ if first_message_content is None:
+ return False
+
+ fim_stop_sequences = ["", "", "", ""]
+ if isinstance(first_message_content, str):
+ msg_prompt = first_message_content
+ elif isinstance(first_message_content, list):
+ msg_prompt = first_message_content[0].get("text", "")
+ else:
+ logger.warning(f"Could not determine if message was FIM from data: {data}")
+ return False
+ return all([stop_sequence in msg_prompt for stop_sequence in fim_stop_sequences])
+
+ @classmethod
+ def is_fim_request(cls, request_url_path: str, data: Dict) -> bool:
+ """
+ Determine if the request is FIM by the URL or the data of the request.
+ """
+ # first check if we are in specific tools to discard FIM
+ prompt = data.get("prompt", "")
+ tools = ["cline", "kodu", "open interpreter"]
+ for tool in tools:
+ if tool in prompt.lower():
+ # those tools can never be FIM
+ return False
+ # Avoid more expensive inspection of body by just checking the URL.
+ if cls._is_fim_request_url(request_url_path):
+ return True
+
+ return cls._is_fim_request_body(data)
diff --git a/src/codegate/providers/llamacpp/provider.py b/src/codegate/providers/llamacpp/provider.py
index a57077b4..186fb784 100644
--- a/src/codegate/providers/llamacpp/provider.py
+++ b/src/codegate/providers/llamacpp/provider.py
@@ -10,6 +10,7 @@
from codegate.config import Config
from codegate.pipeline.factory import PipelineFactory
from codegate.providers.base import BaseProvider, ModelFetchError
+from codegate.providers.fim_analyzer import FIMAnalyzer
from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler
from codegate.providers.llamacpp.normalizer import LLamaCppInputNormalizer, LLamaCppOutputNormalizer
@@ -53,10 +54,9 @@ async def process_request(
self,
data: dict,
api_key: str,
- request_url_path: str,
+ is_fim_request: bool,
client_type: ClientType,
):
- is_fim_request = self._is_fim_request(request_url_path, data)
try:
stream = await self.complete(
data, None, is_fim_request=is_fim_request, client_type=client_type
@@ -92,9 +92,10 @@ async def create_completion(
body = await request.body()
data = json.loads(body)
data["base_url"] = Config.get_config().model_base_path
+ is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, data)
return await self.process_request(
data,
None,
- request.url.path,
+ is_fim_request,
request.state.detected_client,
)
diff --git a/src/codegate/providers/lm_studio/provider.py b/src/codegate/providers/lm_studio/provider.py
index f96fed7d..d6ab56e2 100644
--- a/src/codegate/providers/lm_studio/provider.py
+++ b/src/codegate/providers/lm_studio/provider.py
@@ -3,8 +3,10 @@
from fastapi import Header, HTTPException, Request
from fastapi.responses import JSONResponse
+from codegate.clients.detector import DetectClient
from codegate.config import Config
from codegate.pipeline.factory import PipelineFactory
+from codegate.providers.fim_analyzer import FIMAnalyzer
from codegate.providers.openai.provider import OpenAIProvider
@@ -40,6 +42,7 @@ async def get_models():
@self.router.post(f"/{self.provider_route_name}/chat/completions")
@self.router.post(f"/{self.provider_route_name}/completions")
@self.router.post(f"/{self.provider_route_name}/v1/chat/completions")
+ @DetectClient()
async def create_completion(
request: Request,
authorization: str = Header(..., description="Bearer token"),
@@ -52,5 +55,7 @@ async def create_completion(
data = json.loads(body)
data["base_url"] = self.lm_studio_url + "/v1/"
-
- return await self.process_request(data, api_key, request.url.path)
+ is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, data)
+ return await self.process_request(
+ data, api_key, is_fim_request, request.state.detected_client
+ )
diff --git a/src/codegate/providers/ollama/provider.py b/src/codegate/providers/ollama/provider.py
index cc9809db..4f5cd654 100644
--- a/src/codegate/providers/ollama/provider.py
+++ b/src/codegate/providers/ollama/provider.py
@@ -10,6 +10,7 @@
from codegate.config import Config
from codegate.pipeline.factory import PipelineFactory
from codegate.providers.base import BaseProvider, ModelFetchError
+from codegate.providers.fim_analyzer import FIMAnalyzer
from codegate.providers.ollama.adapter import OllamaInputNormalizer, OllamaOutputNormalizer
from codegate.providers.ollama.completion_handler import OllamaShim
@@ -61,10 +62,9 @@ async def process_request(
self,
data: dict,
api_key: str,
- request_url_path: str,
+ is_fim_request: bool,
client_type: ClientType,
):
- is_fim_request = self._is_fim_request(request_url_path, data)
try:
stream = await self.complete(
data,
@@ -138,10 +138,10 @@ async def create_completion(request: Request):
# `base_url` is used in the providers pipeline to do the packages lookup.
# Force it to be the one that comes in the configuration.
data["base_url"] = self.base_url
-
+ is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, data)
return await self.process_request(
data,
None,
- request.url.path,
+ is_fim_request,
request.state.detected_client,
)
diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py
index c74f4e52..6e936cf4 100644
--- a/src/codegate/providers/openai/provider.py
+++ b/src/codegate/providers/openai/provider.py
@@ -9,6 +9,7 @@
from codegate.clients.detector import DetectClient
from codegate.pipeline.factory import PipelineFactory
from codegate.providers.base import BaseProvider, ModelFetchError
+from codegate.providers.fim_analyzer import FIMAnalyzer
from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator
from codegate.providers.openai.adapter import OpenAIInputNormalizer, OpenAIOutputNormalizer
@@ -48,11 +49,9 @@ async def process_request(
self,
data: dict,
api_key: str,
- request_url_path: str,
+ is_fim_request: bool,
client_type: ClientType,
):
- is_fim_request = self._is_fim_request(request_url_path, data)
-
try:
stream = await self.complete(
data,
@@ -93,10 +92,11 @@ async def create_completion(
api_key = authorization.split(" ")[1]
body = await request.body()
data = json.loads(body)
+ is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, data)
return await self.process_request(
data,
api_key,
- request.url.path,
+ is_fim_request,
request.state.detected_client,
)
diff --git a/src/codegate/providers/openrouter/provider.py b/src/codegate/providers/openrouter/provider.py
index 0e770a01..c3124282 100644
--- a/src/codegate/providers/openrouter/provider.py
+++ b/src/codegate/providers/openrouter/provider.py
@@ -4,6 +4,7 @@
from codegate.clients.detector import DetectClient
from codegate.pipeline.factory import PipelineFactory
+from codegate.providers.fim_analyzer import FIMAnalyzer
from codegate.providers.openai import OpenAIProvider
@@ -39,9 +40,10 @@ async def create_completion(
if not original_model.startswith("openrouter/"):
data["model"] = f"openrouter/{original_model}"
+ is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, data)
return await self.process_request(
data,
api_key,
- request.url.path,
+ is_fim_request,
request.state.detected_client,
)
diff --git a/src/codegate/providers/vllm/provider.py b/src/codegate/providers/vllm/provider.py
index 16f73d6d..bb5d9a02 100644
--- a/src/codegate/providers/vllm/provider.py
+++ b/src/codegate/providers/vllm/provider.py
@@ -11,6 +11,7 @@
from codegate.clients.detector import DetectClient
from codegate.pipeline.factory import PipelineFactory
from codegate.providers.base import BaseProvider, ModelFetchError
+from codegate.providers.fim_analyzer import FIMAnalyzer
from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator
from codegate.providers.vllm.adapter import VLLMInputNormalizer, VLLMOutputNormalizer
@@ -69,10 +70,9 @@ async def process_request(
self,
data: dict,
api_key: str,
- request_url_path: str,
+ is_fim_request: bool,
client_type: ClientType,
):
- is_fim_request = self._is_fim_request(request_url_path, data)
try:
# Pass the potentially None api_key to complete
stream = await self.complete(
@@ -146,10 +146,10 @@ async def create_completion(
# Add the vLLM base URL to the request
base_url = self._get_base_url()
data["base_url"] = base_url
-
+ is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, data)
return await self.process_request(
data,
api_key,
- request.url.path,
+ is_fim_request,
request.state.detected_client,
)
diff --git a/tests/muxing/test_rulematcher.py b/tests/muxing/test_rulematcher.py
new file mode 100644
index 00000000..4e489799
--- /dev/null
+++ b/tests/muxing/test_rulematcher.py
@@ -0,0 +1,126 @@
+from unittest.mock import MagicMock
+
+import pytest
+
+from codegate.db import models as db_models
+from codegate.muxing import models as mux_models
+from codegate.muxing import rulematcher
+
+mocked_route_openai = rulematcher.ModelRoute(
+ db_models.ProviderModel(
+ provider_endpoint_id="1", provider_endpoint_name="fake-openai", name="fake-gpt"
+ ),
+ db_models.ProviderEndpoint(
+ id="1",
+ name="fake-openai",
+ description="fake-openai",
+ provider_type="fake-openai",
+ endpoint="http://localhost/openai",
+ auth_type="api_key",
+ ),
+ db_models.ProviderAuthMaterial(
+ provider_endpoint_id="1", auth_type="api_key", auth_blob="fake-api-key"
+ ),
+)
+
+
+@pytest.mark.parametrize(
+ "matcher_blob, thing_to_match",
+ [
+ (None, None),
+ ("fake-matcher-blob", None),
+ (
+ "fake-matcher-blob",
+ mux_models.ThingToMatchMux(
+ body={},
+ url_request_path="/chat/completions",
+ is_fim_request=False,
+ client_type="generic",
+ ),
+ ),
+ ],
+)
+def test_catch_all(matcher_blob, thing_to_match):
+ muxing_rule_matcher = rulematcher.CatchAllMuxingRuleMatcher(mocked_route_openai, matcher_blob)
+ # It should always match
+ assert muxing_rule_matcher.match(thing_to_match) is True
+
+
+@pytest.mark.parametrize(
+ "matcher_blob, filenames_to_match, expected_bool",
+ [
+ (None, [], False), # Empty filenames and no blob
+ (None, ["main.py"], False), # Empty blob
+ (".py", ["main.py"], True), # Extension match
+ ("main.py", ["main.py"], True), # Full name match
+ (".py", ["main.py", "test.py"], True), # Extension match
+ ("main.py", ["main.py", "test.py"], True), # Full name match
+ ("main.py", ["test.py"], False), # Full name no match
+ (".js", ["main.py", "test.py"], False), # Extension no match
+ ],
+)
+def test_file_matcher(matcher_blob, filenames_to_match, expected_bool):
+ muxing_rule_matcher = rulematcher.FileMuxingRuleMatcher(mocked_route_openai, matcher_blob)
+ # We mock the _extract_request_filenames method to return a list of filenames
+ # The logic to get the correct filenames from snippets is tested in /tests/extract_snippets
+ muxing_rule_matcher._extract_request_filenames = MagicMock(return_value=filenames_to_match)
+ mocked_thing_to_match = mux_models.ThingToMatchMux(
+ body={},
+ url_request_path="/chat/completions",
+ is_fim_request=False,
+ client_type="generic",
+ )
+ assert muxing_rule_matcher.match(mocked_thing_to_match) == expected_bool
+
+
+@pytest.mark.parametrize(
+ "matcher_blob, thing_to_match, expected_bool",
+ [
+ (None, None, False), # Empty blob
+ (
+ "fim",
+ mux_models.ThingToMatchMux(
+ body={},
+ url_request_path="/chat/completions",
+ is_fim_request=False,
+ client_type="generic",
+ ),
+ False,
+ ), # No match
+ (
+ "fim",
+ mux_models.ThingToMatchMux(
+ body={},
+ url_request_path="/chat/completions",
+ is_fim_request=True,
+ client_type="generic",
+ ),
+ True,
+ ), # Match
+ (
+ "chat",
+ mux_models.ThingToMatchMux(
+ body={},
+ url_request_path="/chat/completions",
+ is_fim_request=True,
+ client_type="generic",
+ ),
+ False,
+ ), # No match
+ (
+ "chat",
+ mux_models.ThingToMatchMux(
+ body={},
+ url_request_path="/chat/completions",
+ is_fim_request=False,
+ client_type="generic",
+ ),
+ True,
+ ), # Match
+ ],
+)
+def test_request_type(matcher_blob, thing_to_match, expected_bool):
+ muxing_rule_matcher = rulematcher.RequestTypeMuxingRuleMatcher(
+ mocked_route_openai, matcher_blob
+ )
+ assert muxing_rule_matcher.match(thing_to_match) == expected_bool
diff --git a/tests/providers/test_fim_analyzer.py b/tests/providers/test_fim_analyzer.py
new file mode 100644
index 00000000..e2b94b5d
--- /dev/null
+++ b/tests/providers/test_fim_analyzer.py
@@ -0,0 +1,73 @@
+import pytest
+
+from codegate.providers.fim_analyzer import FIMAnalyzer
+
+
+@pytest.mark.parametrize(
+ "url, expected_bool",
+ [
+ ("http://localhost:8989", False),
+ ("http://test.com/chat/completions", False),
+ ("http://localhost:8989/completions", True),
+ ],
+)
+def test_is_fim_request_url(url, expected_bool):
+ assert FIMAnalyzer._is_fim_request_url(url) == expected_bool
+
+
+DATA_CONTENT_STR = {
+ "messages": [
+ {
+ "role": "user",
+ "content": " ",
+ }
+ ]
+}
+DATA_CONTENT_LIST = {
+ "messages": [
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": " "}],
+ }
+ ]
+}
+INVALID_DATA_CONTET = {
+ "messages": [
+ {
+ "role": "user",
+ "content": "http://localhost:8989/completions",
+ }
+ ]
+}
+TOOL_DATA = {
+ "prompt": "cline",
+}
+
+
+@pytest.mark.parametrize(
+ "data, expected_bool",
+ [
+ (DATA_CONTENT_STR, True),
+ (DATA_CONTENT_LIST, True),
+ (INVALID_DATA_CONTET, False),
+ ],
+)
+def test_is_fim_request_body(data, expected_bool):
+ assert FIMAnalyzer._is_fim_request_body(data) == expected_bool
+
+
+@pytest.mark.parametrize(
+ "url, data, expected_bool",
+ [
+ ("http://localhost:8989", DATA_CONTENT_STR, True), # True because of the data
+ (
+ "http://test.com/chat/completions",
+ INVALID_DATA_CONTET,
+ False,
+ ), # False because of the url
+ ("http://localhost:8989/completions", DATA_CONTENT_STR, True), # True because of the url
+ ("http://localhost:8989/completions", TOOL_DATA, False), # False because of the tool data
+ ],
+)
+def test_is_fim_request(url, data, expected_bool):
+ assert FIMAnalyzer.is_fim_request(url, data) == expected_bool
diff --git a/tests/test_provider.py b/tests/test_provider.py
deleted file mode 100644
index fd9558fb..00000000
--- a/tests/test_provider.py
+++ /dev/null
@@ -1,85 +0,0 @@
-from unittest.mock import MagicMock
-
-import pytest
-
-from codegate.providers.base import BaseProvider
-
-
-class MockProvider(BaseProvider):
-
- def __init__(self):
- mocked_input_normalizer = MagicMock()
- mocked_output_normalizer = MagicMock()
- mocked_completion_handler = MagicMock()
- mocked_factory = MagicMock()
- super().__init__(
- mocked_input_normalizer,
- mocked_output_normalizer,
- mocked_completion_handler,
- mocked_factory,
- )
-
- async def process_request(self, data: dict, api_key: str, request_url_path: str):
- return {"message": "test"}
-
- def models(self):
- return []
-
- def _setup_routes(self) -> None:
- pass
-
- @property
- def provider_route_name(self) -> str:
- return "mock-provider"
-
-
-@pytest.mark.parametrize(
- "url, expected_bool",
- [
- ("http://example.com", False),
- ("http://test.com/chat/completions", False),
- ("http://example.com/completions", True),
- ],
-)
-def test_is_fim_request_url(url, expected_bool):
- mock_provider = MockProvider()
- assert mock_provider._is_fim_request_url(url) == expected_bool
-
-
-DATA_CONTENT_STR = {
- "messages": [
- {
- "role": "user",
- "content": " ",
- }
- ]
-}
-DATA_CONTENT_LIST = {
- "messages": [
- {
- "role": "user",
- "content": [{"type": "text", "text": " "}],
- }
- ]
-}
-INVALID_DATA_CONTET = {
- "messages": [
- {
- "role": "user",
- "content": "http://example.com/completions",
- }
- ]
-}
-
-
-@pytest.mark.parametrize(
- "data, expected_bool",
- [
- (DATA_CONTENT_STR, True),
- (DATA_CONTENT_LIST, True),
- (INVALID_DATA_CONTET, False),
- ],
-)
-def test_is_fim_request_body(data, expected_bool):
- mock_provider = MockProvider()
- assert mock_provider._is_fim_request_body(data) == expected_bool