diff --git a/src/codegate/muxing/models.py b/src/codegate/muxing/models.py index 9a263731..b26a38e7 100644 --- a/src/codegate/muxing/models.py +++ b/src/codegate/muxing/models.py @@ -3,6 +3,8 @@ import pydantic +from codegate.clients.clients import ClientType + class MuxMatcherType(str, Enum): """ @@ -11,6 +13,12 @@ class MuxMatcherType(str, Enum): # Always match this prompt catch_all = "catch_all" + # Match based on the filename. It will match if there is a filename + # in the request that matches the matcher either extension or full name (*.py or main.py) + filename_match = "filename_match" + # Match based on the request type. It will match if the request type + # matches the matcher (e.g. FIM or chat) + request_type_match = "request_type_match" class MuxRule(pydantic.BaseModel): @@ -25,3 +33,14 @@ class MuxRule(pydantic.BaseModel): # The actual matcher to use. Note that # this depends on the matcher type. matcher: Optional[str] = None + + +class ThingToMatchMux(pydantic.BaseModel): + """ + Represents the fields we can use to match a mux rule. 
+ """ + + body: dict + url_request_path: str + is_fim_request: bool + client_type: ClientType diff --git a/src/codegate/muxing/router.py b/src/codegate/muxing/router.py index 5771e59e..c32af579 100644 --- a/src/codegate/muxing/router.py +++ b/src/codegate/muxing/router.py @@ -1,14 +1,14 @@ import json +from typing import Optional import structlog from fastapi import APIRouter, HTTPException, Request -from codegate.clients.clients import ClientType from codegate.clients.detector import DetectClient -from codegate.extract_snippets.body_extractor import BodyCodeSnippetExtractorError -from codegate.extract_snippets.factory import BodyCodeExtractorFactory +from codegate.muxing import models as mux_models from codegate.muxing import rulematcher from codegate.muxing.adapter import BodyAdapter, ResponseAdapter +from codegate.providers.fim_analyzer import FIMAnalyzer from codegate.providers.registry import ProviderRegistry from codegate.workspaces.crud import WorkspaceCrud @@ -39,40 +39,20 @@ def get_routes(self) -> APIRouter: def _ensure_path_starts_with_slash(self, path: str) -> str: return path if path.startswith("/") else f"/{path}" - def _extract_request_filenames(self, detected_client: ClientType, data: dict) -> set[str]: + async def _get_model_route( + self, thing_to_match: mux_models.ThingToMatchMux + ) -> Optional[rulematcher.ModelRoute]: """ - Extract filenames from the request data. + Get the model route for the given things_to_match. """ - try: - body_extractor = BodyCodeExtractorFactory.create_snippet_extractor(detected_client) - return body_extractor.extract_unique_filenames(data) - except BodyCodeSnippetExtractorError as e: - logger.error(f"Error extracting filenames from request: {e}") - return set() - - async def _get_model_routes(self, filenames: set[str]) -> list[rulematcher.ModelRoute]: - """ - Get the model routes for the given filenames. 
- """ - model_routes = [] mux_registry = await rulematcher.get_muxing_rules_registry() try: - # Try to get a catch_all route - single_model_route = await mux_registry.get_match_for_active_workspace( - thing_to_match=None - ) - model_routes.append(single_model_route) - - # Get the model routes for each filename - for filename in filenames: - model_route = await mux_registry.get_match_for_active_workspace( - thing_to_match=filename - ) - model_routes.append(model_route) + # Try to get a model route for the active workspace + model_route = await mux_registry.get_match_for_active_workspace(thing_to_match) + return model_route except Exception as e: logger.error(f"Error getting active workspace muxes: {e}") raise HTTPException(str(e), status_code=404) - return model_routes def _setup_routes(self): @@ -88,34 +68,45 @@ async def route_to_dest_provider( 1. Get destination provider from DB and active workspace. 2. Map the request body to the destination provider format. 3. Run pipeline. Selecting the correct destination provider. - 4. Transmit the response back to the client in the correct format. + 4. Transmit the response back to the client in OpenAI format. """ body = await request.body() data = json.loads(body) + is_fim_request = FIMAnalyzer.is_fim_request(rest_of_path, data) + + # 1. Get destination provider from DB and active workspace. 
+ thing_to_match = mux_models.ThingToMatchMux( + body=data, + url_request_path=rest_of_path, + is_fim_request=is_fim_request, + client_type=request.state.detected_client, + ) + model_route = await self._get_model_route(thing_to_match) + if not model_route: + raise HTTPException( + "No matching rule found for the active workspace", status_code=404 + ) - filenames_in_data = self._extract_request_filenames(request.state.detected_client, data) - logger.info(f"Extracted filenames from request: {filenames_in_data}") - - model_routes = await self._get_model_routes(filenames_in_data) - if not model_routes: - raise HTTPException("No rule found for the active workspace", status_code=404) - - # We still need some logic here to handle the case where we have multiple model routes. - # For the moment since we match all only pick the first. - model_route = model_routes[0] + logger.info( + "Muxing request routed to destination provider", + model=model_route.model.name, + provider_type=model_route.endpoint.provider_type, + provider_name=model_route.endpoint.name, + ) - # Parse the input data and map it to the destination provider format + # 2. Map the request body to the destination provider format. rest_of_path = self._ensure_path_starts_with_slash(rest_of_path) new_data = self._body_adapter.map_body_to_dest(model_route, data) + + # 3. Run pipeline. Selecting the correct destination provider. provider = self._provider_registry.get_provider(model_route.endpoint.provider_type) api_key = model_route.auth_material.auth_blob - - # Send the request to the destination provider. It will run the pipeline response = await provider.process_request( - new_data, api_key, rest_of_path, request.state.detected_client + new_data, api_key, is_fim_request, request.state.detected_client ) - # Format the response to the client always using the OpenAI format + + # 4. Transmit the response back to the client in OpenAI format. 
return self._response_adapter.format_response_to_client( response, model_route.endpoint.provider_type ) diff --git a/src/codegate/muxing/rulematcher.py b/src/codegate/muxing/rulematcher.py index e20161c4..fb3c1da2 100644 --- a/src/codegate/muxing/rulematcher.py +++ b/src/codegate/muxing/rulematcher.py @@ -1,9 +1,17 @@ import copy from abc import ABC, abstractmethod from asyncio import Lock -from typing import List, Optional +from typing import Dict, List, Optional +import structlog + +from codegate.clients.clients import ClientType from codegate.db import models as db_models +from codegate.extract_snippets.body_extractor import BodyCodeSnippetExtractorError +from codegate.extract_snippets.factory import BodyCodeExtractorFactory +from codegate.muxing import models as mux_models + +logger = structlog.get_logger("codegate") _muxrules_sgtn = None @@ -40,11 +48,12 @@ def __init__( class MuxingRuleMatcher(ABC): """Base class for matching muxing rules.""" - def __init__(self, route: ModelRoute): + def __init__(self, route: ModelRoute, matcher_blob: str): self._route = route + self._matcher_blob = matcher_blob @abstractmethod - def match(self, thing_to_match) -> bool: + def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: """Return True if the rule matches the thing_to_match.""" pass @@ -61,12 +70,15 @@ class MuxingMatcherFactory: def create(mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatcher: """Create a muxing matcher for the given endpoint and model.""" - factory = { - "catch_all": CatchAllMuxingRuleMatcher, + factory: Dict[mux_models.MuxMatcherType, MuxingRuleMatcher] = { + mux_models.MuxMatcherType.catch_all: CatchAllMuxingRuleMatcher, + mux_models.MuxMatcherType.filename_match: FileMuxingRuleMatcher, + mux_models.MuxMatcherType.request_type_match: RequestTypeMuxingRuleMatcher, } try: - return factory[mux_rule.matcher_type](route) + # Initialize the MuxingRuleMatcher + return factory[mux_rule.matcher_type](route, mux_rule.matcher_blob) 
except KeyError: raise ValueError(f"Unknown matcher type: {mux_rule.matcher_type}") @@ -74,10 +86,66 @@ def create(mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatcher: class CatchAllMuxingRuleMatcher(MuxingRuleMatcher): """A catch all muxing rule matcher.""" - def match(self, thing_to_match) -> bool: + def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: + logger.info("Catch all rule matched") return True +class FileMuxingRuleMatcher(MuxingRuleMatcher): + """A file muxing rule matcher.""" + + def _extract_request_filenames(self, detected_client: ClientType, data: dict) -> set[str]: + """ + Extract filenames from the request data. + """ + try: + body_extractor = BodyCodeExtractorFactory.create_snippet_extractor(detected_client) + return body_extractor.extract_unique_filenames(data) + except BodyCodeSnippetExtractorError as e: + logger.error(f"Error extracting filenames from request: {e}") + return set() + + def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: + """ + Return True if there is a filename in the request that matches the matcher_blob. + The matcher_blob is either an extension (e.g. .py) or a filename (e.g. main.py). + """ + # If there is no matcher_blob, we don't match + if not self._matcher_blob: + return False + filenames_to_match = self._extract_request_filenames( + thing_to_match.client_type, thing_to_match.body + ) + is_filename_match = any(self._matcher_blob in filename for filename in filenames_to_match) + if is_filename_match: + logger.info( + "Filename rule matched", filenames=filenames_to_match, matcher=self._matcher_blob + ) + return is_filename_match + + +class RequestTypeMuxingRuleMatcher(MuxingRuleMatcher): + """A request type muxing rule matcher.""" + + def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: + """ + Return True if the request type matches the matcher_blob. + The matcher_blob is either "fim" or "chat". 
+ """ + # If there is no matcher_blob, we don't match + if not self._matcher_blob: + return False + incoming_request_type = "fim" if thing_to_match.is_fim_request else "chat" + is_request_type_match = self._matcher_blob == incoming_request_type + if is_request_type_match: + logger.info( + "Request type rule matched", + matcher=self._matcher_blob, + request_type=incoming_request_type, + ) + return is_request_type_match + + class MuxingRulesinWorkspaces: """A thread safe dictionary to store the muxing rules in workspaces.""" @@ -111,7 +179,9 @@ async def get_registries(self) -> List[str]: async with self._lock: return list(self._ws_rules.keys()) - async def get_match_for_active_workspace(self, thing_to_match) -> Optional[ModelRoute]: + async def get_match_for_active_workspace( + self, thing_to_match: mux_models.ThingToMatchMux + ) -> Optional[ModelRoute]: """Get the first match for the given thing_to_match.""" # We iterate over all the rules and return the first match diff --git a/src/codegate/pipeline/system_prompt/codegate.py b/src/codegate/pipeline/system_prompt/codegate.py index 0dbf39a8..03520358 100644 --- a/src/codegate/pipeline/system_prompt/codegate.py +++ b/src/codegate/pipeline/system_prompt/codegate.py @@ -1,8 +1,8 @@ from typing import Optional -from codegate.clients.clients import ClientType from litellm import ChatCompletionRequest, ChatCompletionSystemMessage +from codegate.clients.clients import ClientType from codegate.pipeline.base import ( PipelineContext, PipelineResult, diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py index fa4f146f..252a6947 100644 --- a/src/codegate/providers/anthropic/provider.py +++ b/src/codegate/providers/anthropic/provider.py @@ -11,6 +11,7 @@ from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer from codegate.providers.anthropic.completion_handler import AnthropicCompletion from codegate.providers.base import 
BaseProvider, ModelFetchError +from codegate.providers.fim_analyzer import FIMAnalyzer from codegate.providers.litellmshim import anthropic_stream_generator @@ -57,10 +58,9 @@ async def process_request( self, data: dict, api_key: str, - request_url_path: str, + is_fim_request: bool, client_type: ClientType, ): - is_fim_request = self._is_fim_request(request_url_path, data) try: stream = await self.complete(data, api_key, is_fim_request, client_type) except Exception as e: @@ -98,10 +98,11 @@ async def create_message( body = await request.body() data = json.loads(body) + is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, data) return await self.process_request( data, x_api_key, - request.url.path, + is_fim_request, request.state.detected_client, ) diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index 269fd0e5..0c20bab8 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -79,7 +79,7 @@ async def process_request( self, data: dict, api_key: str, - request_url_path: str, + is_fim_request: bool, client_type: ClientType, ): pass @@ -173,61 +173,6 @@ async def _run_input_pipeline( return result - def _is_fim_request_url(self, request_url_path: str) -> bool: - """ - Checks the request URL to determine if a request is FIM or chat completion. - Used by: llama.cpp - """ - # Evaluate first a larger substring. - if request_url_path.endswith("/chat/completions"): - return False - - # /completions is for OpenAI standard. /api/generate is for ollama. - if request_url_path.endswith("/completions") or request_url_path.endswith("/api/generate"): - return True - - return False - - def _is_fim_request_body(self, data: Dict) -> bool: - """ - Determine from the raw incoming data if it's a FIM request. 
- Used by: OpenAI and Anthropic - """ - messages = data.get("messages", []) - if not messages: - return False - - first_message_content = messages[0].get("content") - if first_message_content is None: - return False - - fim_stop_sequences = ["", "", "", ""] - if isinstance(first_message_content, str): - msg_prompt = first_message_content - elif isinstance(first_message_content, list): - msg_prompt = first_message_content[0].get("text", "") - else: - logger.warning(f"Could not determine if message was FIM from data: {data}") - return False - return all([stop_sequence in msg_prompt for stop_sequence in fim_stop_sequences]) - - def _is_fim_request(self, request_url_path: str, data: Dict) -> bool: - """ - Determine if the request is FIM by the URL or the data of the request. - """ - # first check if we are in specific tools to discard FIM - prompt = data.get("prompt", "") - tools = ["cline", "kodu", "open interpreter"] - for tool in tools: - if tool in prompt.lower(): - # those tools can never be FIM - return False - # Avoid more expensive inspection of body by just checking the URL. - if self._is_fim_request_url(request_url_path): - return True - - return self._is_fim_request_body(data) - async def _cleanup_after_streaming( self, stream: AsyncIterator[ModelResponse], context: PipelineContext ) -> AsyncIterator[ModelResponse]: diff --git a/src/codegate/providers/fim_analyzer.py b/src/codegate/providers/fim_analyzer.py new file mode 100644 index 00000000..e0cd090c --- /dev/null +++ b/src/codegate/providers/fim_analyzer.py @@ -0,0 +1,66 @@ +from typing import Dict + +import structlog + +logger = structlog.get_logger("codegate") + + +class FIMAnalyzer: + + @classmethod + def _is_fim_request_url(cls, request_url_path: str) -> bool: + """ + Checks the request URL to determine if a request is FIM or chat completion. + Used by: llama.cpp + """ + # Evaluate first a larger substring. 
+ if request_url_path.endswith("chat/completions"): + return False + + # /completions is for OpenAI standard. /api/generate is for ollama. + if request_url_path.endswith("completions") or request_url_path.endswith("api/generate"): + return True + + return False + + @classmethod + def _is_fim_request_body(cls, data: Dict) -> bool: + """ + Determine from the raw incoming data if it's a FIM request. + Used by: OpenAI and Anthropic + """ + messages = data.get("messages", []) + if not messages: + return False + + first_message_content = messages[0].get("content") + if first_message_content is None: + return False + + fim_stop_sequences = ["", "", "", ""] + if isinstance(first_message_content, str): + msg_prompt = first_message_content + elif isinstance(first_message_content, list): + msg_prompt = first_message_content[0].get("text", "") + else: + logger.warning(f"Could not determine if message was FIM from data: {data}") + return False + return all([stop_sequence in msg_prompt for stop_sequence in fim_stop_sequences]) + + @classmethod + def is_fim_request(cls, request_url_path: str, data: Dict) -> bool: + """ + Determine if the request is FIM by the URL or the data of the request. + """ + # first check if we are in specific tools to discard FIM + prompt = data.get("prompt", "") + tools = ["cline", "kodu", "open interpreter"] + for tool in tools: + if tool in prompt.lower(): + # those tools can never be FIM + return False + # Avoid more expensive inspection of body by just checking the URL. 
+ if cls._is_fim_request_url(request_url_path): + return True + + return cls._is_fim_request_body(data) diff --git a/src/codegate/providers/llamacpp/provider.py b/src/codegate/providers/llamacpp/provider.py index a57077b4..186fb784 100644 --- a/src/codegate/providers/llamacpp/provider.py +++ b/src/codegate/providers/llamacpp/provider.py @@ -10,6 +10,7 @@ from codegate.config import Config from codegate.pipeline.factory import PipelineFactory from codegate.providers.base import BaseProvider, ModelFetchError +from codegate.providers.fim_analyzer import FIMAnalyzer from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler from codegate.providers.llamacpp.normalizer import LLamaCppInputNormalizer, LLamaCppOutputNormalizer @@ -53,10 +54,9 @@ async def process_request( self, data: dict, api_key: str, - request_url_path: str, + is_fim_request: bool, client_type: ClientType, ): - is_fim_request = self._is_fim_request(request_url_path, data) try: stream = await self.complete( data, None, is_fim_request=is_fim_request, client_type=client_type @@ -92,9 +92,10 @@ async def create_completion( body = await request.body() data = json.loads(body) data["base_url"] = Config.get_config().model_base_path + is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, data) return await self.process_request( data, None, - request.url.path, + is_fim_request, request.state.detected_client, ) diff --git a/src/codegate/providers/lm_studio/provider.py b/src/codegate/providers/lm_studio/provider.py index f96fed7d..d6ab56e2 100644 --- a/src/codegate/providers/lm_studio/provider.py +++ b/src/codegate/providers/lm_studio/provider.py @@ -3,8 +3,10 @@ from fastapi import Header, HTTPException, Request from fastapi.responses import JSONResponse +from codegate.clients.detector import DetectClient from codegate.config import Config from codegate.pipeline.factory import PipelineFactory +from codegate.providers.fim_analyzer import FIMAnalyzer from 
codegate.providers.openai.provider import OpenAIProvider @@ -40,6 +42,7 @@ async def get_models(): @self.router.post(f"/{self.provider_route_name}/chat/completions") @self.router.post(f"/{self.provider_route_name}/completions") @self.router.post(f"/{self.provider_route_name}/v1/chat/completions") + @DetectClient() async def create_completion( request: Request, authorization: str = Header(..., description="Bearer token"), @@ -52,5 +55,7 @@ async def create_completion( data = json.loads(body) data["base_url"] = self.lm_studio_url + "/v1/" - - return await self.process_request(data, api_key, request.url.path) + is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, data) + return await self.process_request( + data, api_key, is_fim_request, request.state.detected_client + ) diff --git a/src/codegate/providers/ollama/provider.py b/src/codegate/providers/ollama/provider.py index cc9809db..4f5cd654 100644 --- a/src/codegate/providers/ollama/provider.py +++ b/src/codegate/providers/ollama/provider.py @@ -10,6 +10,7 @@ from codegate.config import Config from codegate.pipeline.factory import PipelineFactory from codegate.providers.base import BaseProvider, ModelFetchError +from codegate.providers.fim_analyzer import FIMAnalyzer from codegate.providers.ollama.adapter import OllamaInputNormalizer, OllamaOutputNormalizer from codegate.providers.ollama.completion_handler import OllamaShim @@ -61,10 +62,9 @@ async def process_request( self, data: dict, api_key: str, - request_url_path: str, + is_fim_request: bool, client_type: ClientType, ): - is_fim_request = self._is_fim_request(request_url_path, data) try: stream = await self.complete( data, @@ -138,10 +138,10 @@ async def create_completion(request: Request): # `base_url` is used in the providers pipeline to do the packages lookup. # Force it to be the one that comes in the configuration. 
data["base_url"] = self.base_url - + is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, data) return await self.process_request( data, None, - request.url.path, + is_fim_request, request.state.detected_client, ) diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py index c74f4e52..6e936cf4 100644 --- a/src/codegate/providers/openai/provider.py +++ b/src/codegate/providers/openai/provider.py @@ -9,6 +9,7 @@ from codegate.clients.detector import DetectClient from codegate.pipeline.factory import PipelineFactory from codegate.providers.base import BaseProvider, ModelFetchError +from codegate.providers.fim_analyzer import FIMAnalyzer from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator from codegate.providers.openai.adapter import OpenAIInputNormalizer, OpenAIOutputNormalizer @@ -48,11 +49,9 @@ async def process_request( self, data: dict, api_key: str, - request_url_path: str, + is_fim_request: bool, client_type: ClientType, ): - is_fim_request = self._is_fim_request(request_url_path, data) - try: stream = await self.complete( data, @@ -93,10 +92,11 @@ async def create_completion( api_key = authorization.split(" ")[1] body = await request.body() data = json.loads(body) + is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, data) return await self.process_request( data, api_key, - request.url.path, + is_fim_request, request.state.detected_client, ) diff --git a/src/codegate/providers/openrouter/provider.py b/src/codegate/providers/openrouter/provider.py index 0e770a01..c3124282 100644 --- a/src/codegate/providers/openrouter/provider.py +++ b/src/codegate/providers/openrouter/provider.py @@ -4,6 +4,7 @@ from codegate.clients.detector import DetectClient from codegate.pipeline.factory import PipelineFactory +from codegate.providers.fim_analyzer import FIMAnalyzer from codegate.providers.openai import OpenAIProvider @@ -39,9 +40,10 @@ async def create_completion( if not 
original_model.startswith("openrouter/"): data["model"] = f"openrouter/{original_model}" + is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, data) return await self.process_request( data, api_key, - request.url.path, + is_fim_request, request.state.detected_client, ) diff --git a/src/codegate/providers/vllm/provider.py b/src/codegate/providers/vllm/provider.py index 16f73d6d..bb5d9a02 100644 --- a/src/codegate/providers/vllm/provider.py +++ b/src/codegate/providers/vllm/provider.py @@ -11,6 +11,7 @@ from codegate.clients.detector import DetectClient from codegate.pipeline.factory import PipelineFactory from codegate.providers.base import BaseProvider, ModelFetchError +from codegate.providers.fim_analyzer import FIMAnalyzer from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator from codegate.providers.vllm.adapter import VLLMInputNormalizer, VLLMOutputNormalizer @@ -69,10 +70,9 @@ async def process_request( self, data: dict, api_key: str, - request_url_path: str, + is_fim_request: bool, client_type: ClientType, ): - is_fim_request = self._is_fim_request(request_url_path, data) try: # Pass the potentially None api_key to complete stream = await self.complete( @@ -146,10 +146,10 @@ async def create_completion( # Add the vLLM base URL to the request base_url = self._get_base_url() data["base_url"] = base_url - + is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, data) return await self.process_request( data, api_key, - request.url.path, + is_fim_request, request.state.detected_client, ) diff --git a/tests/muxing/test_rulematcher.py b/tests/muxing/test_rulematcher.py new file mode 100644 index 00000000..4e489799 --- /dev/null +++ b/tests/muxing/test_rulematcher.py @@ -0,0 +1,126 @@ +from unittest.mock import MagicMock + +import pytest + +from codegate.db import models as db_models +from codegate.muxing import models as mux_models +from codegate.muxing import rulematcher + +mocked_route_openai = rulematcher.ModelRoute( + 
db_models.ProviderModel( + provider_endpoint_id="1", provider_endpoint_name="fake-openai", name="fake-gpt" + ), + db_models.ProviderEndpoint( + id="1", + name="fake-openai", + description="fake-openai", + provider_type="fake-openai", + endpoint="http://localhost/openai", + auth_type="api_key", + ), + db_models.ProviderAuthMaterial( + provider_endpoint_id="1", auth_type="api_key", auth_blob="fake-api-key" + ), +) + + +@pytest.mark.parametrize( + "matcher_blob, thing_to_match", + [ + (None, None), + ("fake-matcher-blob", None), + ( + "fake-matcher-blob", + mux_models.ThingToMatchMux( + body={}, + url_request_path="/chat/completions", + is_fim_request=False, + client_type="generic", + ), + ), + ], +) +def test_catch_all(matcher_blob, thing_to_match): + muxing_rule_matcher = rulematcher.CatchAllMuxingRuleMatcher(mocked_route_openai, matcher_blob) + # It should always match + assert muxing_rule_matcher.match(thing_to_match) is True + + +@pytest.mark.parametrize( + "matcher_blob, filenames_to_match, expected_bool", + [ + (None, [], False), # Empty filenames and no blob + (None, ["main.py"], False), # Empty blob + (".py", ["main.py"], True), # Extension match + ("main.py", ["main.py"], True), # Full name match + (".py", ["main.py", "test.py"], True), # Extension match + ("main.py", ["main.py", "test.py"], True), # Full name match + ("main.py", ["test.py"], False), # Full name no match + (".js", ["main.py", "test.py"], False), # Extension no match + ], +) +def test_file_matcher(matcher_blob, filenames_to_match, expected_bool): + muxing_rule_matcher = rulematcher.FileMuxingRuleMatcher(mocked_route_openai, matcher_blob) + # We mock the _extract_request_filenames method to return a list of filenames + # The logic to get the correct filenames from snippets is tested in /tests/extract_snippets + muxing_rule_matcher._extract_request_filenames = MagicMock(return_value=filenames_to_match) + mocked_thing_to_match = mux_models.ThingToMatchMux( + body={}, + 
url_request_path="/chat/completions", + is_fim_request=False, + client_type="generic", + ) + assert muxing_rule_matcher.match(mocked_thing_to_match) == expected_bool + + +@pytest.mark.parametrize( + "matcher_blob, thing_to_match, expected_bool", + [ + (None, None, False), # Empty blob + ( + "fim", + mux_models.ThingToMatchMux( + body={}, + url_request_path="/chat/completions", + is_fim_request=False, + client_type="generic", + ), + False, + ), # No match + ( + "fim", + mux_models.ThingToMatchMux( + body={}, + url_request_path="/chat/completions", + is_fim_request=True, + client_type="generic", + ), + True, + ), # Match + ( + "chat", + mux_models.ThingToMatchMux( + body={}, + url_request_path="/chat/completions", + is_fim_request=True, + client_type="generic", + ), + False, + ), # No match + ( + "chat", + mux_models.ThingToMatchMux( + body={}, + url_request_path="/chat/completions", + is_fim_request=False, + client_type="generic", + ), + True, + ), # Match + ], +) +def test_request_type(matcher_blob, thing_to_match, expected_bool): + muxing_rule_matcher = rulematcher.RequestTypeMuxingRuleMatcher( + mocked_route_openai, matcher_blob + ) + assert muxing_rule_matcher.match(thing_to_match) == expected_bool diff --git a/tests/providers/test_fim_analyzer.py b/tests/providers/test_fim_analyzer.py new file mode 100644 index 00000000..e2b94b5d --- /dev/null +++ b/tests/providers/test_fim_analyzer.py @@ -0,0 +1,73 @@ +import pytest + +from codegate.providers.fim_analyzer import FIMAnalyzer + + +@pytest.mark.parametrize( + "url, expected_bool", + [ + ("http://localhost:8989", False), + ("http://test.com/chat/completions", False), + ("http://localhost:8989/completions", True), + ], +) +def test_is_fim_request_url(url, expected_bool): + assert FIMAnalyzer._is_fim_request_url(url) == expected_bool + + +DATA_CONTENT_STR = { + "messages": [ + { + "role": "user", + "content": " ", + } + ] +} +DATA_CONTENT_LIST = { + "messages": [ + { + "role": "user", + "content": [{"type": "text", 
"text": " "}], + } + ] +} +INVALID_DATA_CONTET = { + "messages": [ + { + "role": "user", + "content": "http://localhost:8989/completions", + } + ] +} +TOOL_DATA = { + "prompt": "cline", +} + + +@pytest.mark.parametrize( + "data, expected_bool", + [ + (DATA_CONTENT_STR, True), + (DATA_CONTENT_LIST, True), + (INVALID_DATA_CONTET, False), + ], +) +def test_is_fim_request_body(data, expected_bool): + assert FIMAnalyzer._is_fim_request_body(data) == expected_bool + + +@pytest.mark.parametrize( + "url, data, expected_bool", + [ + ("http://localhost:8989", DATA_CONTENT_STR, True), # True because of the data + ( + "http://test.com/chat/completions", + INVALID_DATA_CONTET, + False, + ), # False because of the url + ("http://localhost:8989/completions", DATA_CONTENT_STR, True), # True because of the url + ("http://localhost:8989/completions", TOOL_DATA, False), # False because of the tool data + ], +) +def test_is_fim_request(url, data, expected_bool): + assert FIMAnalyzer.is_fim_request(url, data) == expected_bool diff --git a/tests/test_provider.py b/tests/test_provider.py deleted file mode 100644 index fd9558fb..00000000 --- a/tests/test_provider.py +++ /dev/null @@ -1,85 +0,0 @@ -from unittest.mock import MagicMock - -import pytest - -from codegate.providers.base import BaseProvider - - -class MockProvider(BaseProvider): - - def __init__(self): - mocked_input_normalizer = MagicMock() - mocked_output_normalizer = MagicMock() - mocked_completion_handler = MagicMock() - mocked_factory = MagicMock() - super().__init__( - mocked_input_normalizer, - mocked_output_normalizer, - mocked_completion_handler, - mocked_factory, - ) - - async def process_request(self, data: dict, api_key: str, request_url_path: str): - return {"message": "test"} - - def models(self): - return [] - - def _setup_routes(self) -> None: - pass - - @property - def provider_route_name(self) -> str: - return "mock-provider" - - -@pytest.mark.parametrize( - "url, expected_bool", - [ - ("http://example.com", 
False), - ("http://test.com/chat/completions", False), - ("http://example.com/completions", True), - ], -) -def test_is_fim_request_url(url, expected_bool): - mock_provider = MockProvider() - assert mock_provider._is_fim_request_url(url) == expected_bool - - -DATA_CONTENT_STR = { - "messages": [ - { - "role": "user", - "content": " ", - } - ] -} -DATA_CONTENT_LIST = { - "messages": [ - { - "role": "user", - "content": [{"type": "text", "text": " "}], - } - ] -} -INVALID_DATA_CONTET = { - "messages": [ - { - "role": "user", - "content": "http://example.com/completions", - } - ] -} - - -@pytest.mark.parametrize( - "data, expected_bool", - [ - (DATA_CONTENT_STR, True), - (DATA_CONTENT_LIST, True), - (INVALID_DATA_CONTET, False), - ], -) -def test_is_fim_request_body(data, expected_bool): - mock_provider = MockProvider() - assert mock_provider._is_fim_request_body(data) == expected_bool