diff --git a/scripts/import_packages.py b/scripts/import_packages.py index c66a271f..f980d17d 100644 --- a/scripts/import_packages.py +++ b/scripts/import_packages.py @@ -13,8 +13,7 @@ class PackageImporter: def __init__(self): self.client = weaviate.WeaviateClient( embedded_options=EmbeddedOptions( - persistence_data_path="./weaviate_data", - grpc_port=50052 + persistence_data_path="./weaviate_data", grpc_port=50052 ) ) self.json_files = [ @@ -46,13 +45,13 @@ def generate_vector_string(self, package): "npm": "JavaScript package available on NPM", "go": "Go package", "crates": "Rust package available on Crates", - "java": "Java package" + "java": "Java package", } status_messages = { "archived": "However, this package is found to be archived and no longer maintained.", "deprecated": "However, this package is found to be deprecated and no longer " "recommended for use.", - "malicious": "However, this package is found to be malicious." + "malicious": "However, this package is found to be malicious.", } vector_str += f" is a {type_map.get(package['type'], 'unknown type')} " package_url = f"https://trustypkg.dev/{package['type']}/{package['name']}" @@ -75,8 +74,9 @@ async def add_data(self): packages_dict = { f"{package.properties['name']}/{package.properties['type']}": { "status": package.properties["status"], - "description": package.properties["description"] - } for package in existing_packages + "description": package.properties["description"], + } + for package in existing_packages } for json_file in self.json_files: @@ -85,12 +85,12 @@ async def add_data(self): packages_to_insert = [] for line in f: package = json.loads(line) - package["status"] = json_file.split('/')[-1].split('.')[0] + package["status"] = json_file.split("/")[-1].split(".")[0] key = f"{package['name']}/{package['type']}" if key in packages_dict and packages_dict[key] == { "status": package["status"], - "description": package["description"] + "description": package["description"], }: print("Package already exists", key) continue @@ -102,8 +102,9 @@ async def add_data(self): # Synchronous batch insert after preparing all data with collection.batch.dynamic() as batch: for package, vector in packages_to_insert: - batch.add_object(properties=package, vector=vector, - uuid=generate_uuid5(package)) + batch.add_object( + properties=package, vector=vector, uuid=generate_uuid5(package) + ) async def run_import(self): self.setup_schema() diff --git a/src/codegate/__init__.py b/src/codegate/__init__.py index 15535fbb..042fba14 100644 --- a/src/codegate/__init__.py +++ b/src/codegate/__init__.py @@ -1,10 +1,10 @@ """Codegate - A Generative AI security gateway.""" -from importlib import metadata import logging as python_logging +from importlib import metadata +from codegate.codegate_logging import LogFormat, LogLevel, setup_logging from codegate.config import Config -from codegate.codegate_logging import setup_logging, LogFormat, LogLevel from codegate.exceptions import ConfigurationError try: diff --git a/src/codegate/cli.py b/src/codegate/cli.py index acd28530..8688947b 100644 --- a/src/codegate/cli.py +++ b/src/codegate/cli.py @@ -6,8 +6,8 @@ import click +from codegate.codegate_logging import LogFormat, LogLevel, setup_logging from codegate.config import Config, ConfigurationError -from codegate.codegate_logging import setup_logging, LogFormat, LogLevel from codegate.server import init_app diff --git a/src/codegate/codegate_logging.py b/src/codegate/codegate_logging.py index 9656eadc..a57a1579 100644 --- a/src/codegate/codegate_logging.py +++ b/src/codegate/codegate_logging.py @@ -1,8 +1,8 @@ import datetime -from enum import Enum import json import logging import sys +from enum import Enum from typing import Any, Optional diff --git a/src/codegate/config.py b/src/codegate/config.py index 3d39134c..e63e5fc1 100644 --- a/src/codegate/config.py +++ b/src/codegate/config.py @@ -7,7 +7,7 @@ import yaml -from codegate.codegate_logging import setup_logging, LogFormat, LogLevel +from codegate.codegate_logging import LogFormat, LogLevel, setup_logging from codegate.exceptions import ConfigurationError from codegate.prompts import PromptConfig @@ -52,9 +52,7 @@ def __post_init__(self) -> None: @staticmethod def _load_default_prompts() -> PromptConfig: """Load default prompts from prompts/default.yaml.""" - default_prompts_path = ( - Path(__file__).parent.parent.parent / "prompts" / "default.yaml" - ) + default_prompts_path = Path(__file__).parent.parent.parent / "prompts" / "default.yaml" try: return PromptConfig.from_file(default_prompts_path) except Exception as e: diff --git a/src/codegate/pipeline/base.py b/src/codegate/pipeline/base.py index d0e77602..b8875dda 100644 --- a/src/codegate/pipeline/base.py +++ b/src/codegate/pipeline/base.py @@ -14,6 +14,7 @@ class CodeSnippet: language: The programming language identifier (e.g., 'python', 'javascript') code: The actual code content """ + language: str code: str @@ -24,6 +25,7 @@ def __post_init__(self): raise ValueError("Code must not be empty") self.language = self.language.strip().lower() + @dataclass class PipelineContext: code_snippets: List[CodeSnippet] = field(default_factory=list) @@ -35,13 +37,16 @@ def add_code_snippet(self, snippet: CodeSnippet): def get_snippets_by_language(self, language: str) -> List[CodeSnippet]: return [s for s in self.code_snippets if s.language.lower() == language.lower()] + @dataclass class PipelineResponse: """Response generated by a pipeline step""" + content: str step_name: str # The name of the pipeline step that generated this response model: str # Taken from the original request's model field + @dataclass class PipelineResult: """ @@ -49,6 +54,7 @@ class PipelineResult: Either contains a modified request to continue processing, or a response to return to the client. """ + request: Optional[ChatCompletionRequest] = None response: Optional[PipelineResponse] = None error_message: Optional[str] = None @@ -79,8 +85,8 @@ def name(self) -> str: @staticmethod def get_last_user_message( - request: ChatCompletionRequest, - ) -> Optional[tuple[str, int]]: + request: ChatCompletionRequest, + ) -> Optional[tuple[str, int]]: """ Get the last user message and its index from the request. @@ -122,9 +128,7 @@ def get_last_user_message( @abstractmethod async def process( - self, - request: ChatCompletionRequest, - context: PipelineContext + self, request: ChatCompletionRequest, context: PipelineContext ) -> PipelineResult: """Process a request and return either modified request or response stream""" pass @@ -135,8 +139,8 @@ def __init__(self, pipeline_steps: List[PipelineStep]): self.pipeline_steps = pipeline_steps async def process_request( - self, - request: ChatCompletionRequest, + self, + request: ChatCompletionRequest, ) -> PipelineResult: """ Process a request through all pipeline steps diff --git a/src/codegate/pipeline/version/version.py b/src/codegate/pipeline/version/version.py index 314c831f..9f809ace 100644 --- a/src/codegate/pipeline/version/version.py +++ b/src/codegate/pipeline/version/version.py @@ -23,9 +23,7 @@ def name(self) -> str: return "codegate-version" async def process( - self, - request: ChatCompletionRequest, - context: PipelineContext + self, request: ChatCompletionRequest, context: PipelineContext ) -> PipelineResult: """ Checks if the last user message contains "codegate-version" and diff --git a/src/codegate/prompts.py b/src/codegate/prompts.py index a656155d..63405a08 100644 --- a/src/codegate/prompts.py +++ b/src/codegate/prompts.py @@ -44,9 +44,7 @@ def from_file(cls, prompt_path: Union[str, Path]) -> "PromptConfig": # Validate all values are strings for key, value in prompt_data.items(): if not isinstance(value, str): - raise ConfigurationError( - f"Prompt '{key}' must be a string, got {type(value)}" - ) + raise ConfigurationError(f"Prompt '{key}' must be a string, got {type(value)}") return cls(prompts=prompt_data) except yaml.YAMLError as e: diff --git a/src/codegate/providers/anthropic/adapter.py b/src/codegate/providers/anthropic/adapter.py index ee9221d4..6f89bd4a 100644 --- a/src/codegate/providers/anthropic/adapter.py +++ b/src/codegate/providers/anthropic/adapter.py @@ -1,46 +1,30 @@ -from typing import Any, Dict, Optional - -from litellm import AdapterCompletionStreamWrapper, ChatCompletionRequest, ModelResponse from litellm.adapters.anthropic_adapter import ( AnthropicAdapter as LitellmAnthropicAdapter, ) -from litellm.types.llms.anthropic import AnthropicResponse -from codegate.providers.base import StreamGenerator -from codegate.providers.litellmshim import anthropic_stream_generator, BaseAdapter +from codegate.providers.litellmshim.adapter import ( + LiteLLMAdapterInputNormalizer, + LiteLLMAdapterOutputNormalizer, +) -class AnthropicAdapter(BaseAdapter): +class AnthropicInputNormalizer(LiteLLMAdapterInputNormalizer): """ LiteLLM's adapter class interface is used to translate between the Anthropic data format and the underlying model. The AnthropicAdapter class contains the actual implementation of the interface methods, we just forward the calls to it. """ - def __init__(self, stream_generator: StreamGenerator = anthropic_stream_generator): - self.litellm_anthropic_adapter = LitellmAnthropicAdapter() - super().__init__(stream_generator) + def __init__(self): + super().__init__(LitellmAnthropicAdapter()) - def translate_completion_input_params( - self, - completion_request: Dict, - ) -> Optional[ChatCompletionRequest]: - return self.litellm_anthropic_adapter.translate_completion_input_params( - completion_request - ) - def translate_completion_output_params( - self, response: ModelResponse - ) -> Optional[AnthropicResponse]: - return self.litellm_anthropic_adapter.translate_completion_output_params( - response - ) +class AnthropicOutputNormalizer(LiteLLMAdapterOutputNormalizer): + """ + LiteLLM's adapter class interface is used to translate between the Anthropic data + format and the underlying model. The AnthropicAdapter class contains the actual + implementation of the interface methods, we just forward the calls to it. + """ - def translate_completion_output_params_streaming( - self, completion_stream: Any - ) -> AdapterCompletionStreamWrapper | None: - return ( - self.litellm_anthropic_adapter.translate_completion_output_params_streaming( - completion_stream - ) - ) + def __init__(self): + super().__init__(LitellmAnthropicAdapter()) diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py index 1b39ee07..a16c5921 100644 --- a/src/codegate/providers/anthropic/provider.py +++ b/src/codegate/providers/anthropic/provider.py @@ -2,16 +2,20 @@ from fastapi import Header, HTTPException, Request +from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer from codegate.providers.base import BaseProvider -from codegate.providers.litellmshim import LiteLLmShim -from codegate.providers.anthropic.adapter import AnthropicAdapter +from codegate.providers.litellmshim import LiteLLmShim, anthropic_stream_generator class AnthropicProvider(BaseProvider): def __init__(self, pipeline_processor=None): - adapter = AnthropicAdapter() - completion_handler = LiteLLmShim(adapter) - super().__init__(completion_handler, pipeline_processor) + completion_handler = LiteLLmShim(stream_generator=anthropic_stream_generator) + super().__init__( + AnthropicInputNormalizer(), + AnthropicOutputNormalizer(), + completion_handler, + pipeline_processor, + ) @property def provider_route_name(self) -> str: diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index 3fe29cc8..940d93b2 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -3,14 +3,16 @@ from fastapi import APIRouter from litellm import ModelResponse +from litellm.types.llms.openai import ChatCompletionRequest +from codegate.pipeline.base import PipelineResult, SequentialPipelineProcessor from codegate.providers.completion.base import BaseCompletionHandler from codegate.providers.formatting.input_pipeline import PipelineResponseFormatter - -from ..pipeline.base import SequentialPipelineProcessor +from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer StreamGenerator = Callable[[AsyncIterator[Any]], AsyncIterator[str]] + class BaseProvider(ABC): """ The provider class is responsible for defining the API routes and @@ -19,14 +21,19 @@ class BaseProvider(ABC): def __init__( self, + input_normalizer: ModelInputNormalizer, + output_normalizer: ModelOutputNormalizer, completion_handler: BaseCompletionHandler, - pipeline_processor: Optional[SequentialPipelineProcessor] = None + pipeline_processor: Optional[SequentialPipelineProcessor] = None, ): self.router = APIRouter() self._completion_handler = completion_handler + self._input_normalizer = input_normalizer + self._output_normalizer = output_normalizer self._pipeline_processor = pipeline_processor - self._pipeline_response_formatter = \ - PipelineResponseFormatter(completion_handler) + + self._pipeline_response_formatter = PipelineResponseFormatter(output_normalizer) + self._setup_routes() @abstractmethod @@ -38,9 +45,26 @@ def _setup_routes(self) -> None: def provider_route_name(self) -> str: pass + async def _run_input_pipeline( + self, + normalized_request: ChatCompletionRequest, + ) -> PipelineResult: + if self._pipeline_processor is None: + return PipelineResult(request=normalized_request) + + result = await self._pipeline_processor.process_request(normalized_request) + + # TODO(jakub): handle this by returning a message to the client + if result.error_message: + raise Exception(result.error_message) + + return result + async def complete( - self, data: Dict, api_key: str, - ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]: + self, + data: Dict, + api_key: Optional[str], + ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]: """ Main completion flow with pipeline integration @@ -52,31 +76,27 @@ async def complete( - Execute the completion and translate the response back to the provider-specific format """ - completion_request = self._completion_handler.translate_request(data, api_key) + normalized_request = self._input_normalizer.normalize(data) streaming = data.get("stream", False) - if self._pipeline_processor is not None: - result = await self._pipeline_processor.process_request(completion_request) + input_pipeline_result = await self._run_input_pipeline(normalized_request) + if input_pipeline_result.response: + return self._pipeline_response_formatter.handle_pipeline_response( + input_pipeline_result.response, streaming + ) - if result.error_message: - raise Exception(result.error_message) - - if result.response: - return self._pipeline_response_formatter.handle_pipeline_response( - result.response, streaming) - - completion_request = result.request + provider_request = self._input_normalizer.denormalize(input_pipeline_result.request) # Execute the completion and translate the response # This gives us either a single response or a stream of responses # based on the streaming flag - raw_response = await self._completion_handler.execute_completion( - completion_request, - stream=streaming + model_response = await self._completion_handler.execute_completion( + provider_request, api_key=api_key, stream=streaming ) + if not streaming: - return self._completion_handler.translate_response(raw_response) - return self._completion_handler.translate_streaming_response(raw_response) + return self._output_normalizer.denormalize(model_response) + return self._output_normalizer.denormalize_streaming(model_response) def get_routes(self) -> APIRouter: return self.router diff --git a/src/codegate/providers/completion/__init__.py b/src/codegate/providers/completion/__init__.py index e69de29b..80a0fefd 100644 --- a/src/codegate/providers/completion/__init__.py +++ b/src/codegate/providers/completion/__init__.py @@ -0,0 +1,5 @@ +from codegate.providers.completion.base import BaseCompletionHandler + +__all__ = [ + "BaseCompletionHandler", +] diff --git a/src/codegate/providers/completion/base.py b/src/codegate/providers/completion/base.py index 906e20d4..2bba9bc2 100644 --- a/src/codegate/providers/completion/base.py +++ b/src/codegate/providers/completion/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, AsyncIterator, Dict, Union +from typing import Any, AsyncIterator, Optional, Union from fastapi.responses import StreamingResponse from litellm import ChatCompletionRequest, ModelResponse @@ -11,39 +11,16 @@ class BaseCompletionHandler(ABC): and creating the streaming response. """ - @abstractmethod - def translate_request(self, data: Dict, api_key: str) -> ChatCompletionRequest: - """Convert raw request data into a ChatCompletionRequest""" - pass - @abstractmethod async def execute_completion( - self, - request: ChatCompletionRequest, - stream: bool = False + self, + request: ChatCompletionRequest, + api_key: Optional[str], + stream: bool = False, # TODO: remove this param? ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]: """Execute the completion request""" pass @abstractmethod - def create_streaming_response( - self, stream: AsyncIterator[Any] - ) -> StreamingResponse: - pass - - @abstractmethod - def translate_response( - self, - response: ModelResponse, - ) -> ModelResponse: - """Convert pipeline response to provider-specific format""" + def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse: pass - - @abstractmethod - def translate_streaming_response( - self, - response: AsyncIterator[ModelResponse], - ) -> AsyncIterator[ModelResponse]: - """Convert pipeline response to provider-specific format""" - pass - diff --git a/src/codegate/providers/formatting/__init__.py b/src/codegate/providers/formatting/__init__.py index e69de29b..13ba54a4 100644 --- a/src/codegate/providers/formatting/__init__.py +++ b/src/codegate/providers/formatting/__init__.py @@ -0,0 +1,5 @@ +from codegate.providers.formatting.input_pipeline import PipelineResponseFormatter + +__all__ = [ + "PipelineResponseFormatter", +] diff --git a/src/codegate/providers/formatting/input_pipeline.py b/src/codegate/providers/formatting/input_pipeline.py index 6cf54a8d..01d5ef7a 100644 --- a/src/codegate/providers/formatting/input_pipeline.py +++ b/src/codegate/providers/formatting/input_pipeline.py @@ -5,7 +5,7 @@ from litellm.types.utils import Delta, StreamingChoices from codegate.pipeline.base import PipelineResponse -from codegate.providers.completion.base import BaseCompletionHandler +from codegate.providers.normalizer.base import ModelOutputNormalizer def _create_stream_end_response(original_response: ModelResponse) -> ModelResponse: @@ -14,24 +14,21 @@ def _create_stream_end_response(original_response: ModelResponse) -> ModelRespon id=original_response.id, choices=[ StreamingChoices( - finish_reason="stop", - index=0, - delta=Delta( - content="", - role=None - ), - logprobs=None + finish_reason="stop", index=0, delta=Delta(content="", role=None), logprobs=None ) ], created=original_response.created, model=original_response.model, - object="chat.completion.chunk" + object="chat.completion.chunk", ) def _create_model_response( - content: str, step_name: str, model: str, streaming: bool, - ) -> ModelResponse: + content: str, + step_name: str, + model: str, + streaming: bool, +) -> ModelResponse: """ Create a ModelResponse in either streaming or non-streaming format This is required because the ModelResponse format is different for streaming @@ -47,33 +44,28 @@ def _create_model_response( StreamingChoices( finish_reason=None, index=0, - delta=Delta( - content=content, - role="assistant" - ), - logprobs=None + delta=Delta(content=content, role="assistant"), + logprobs=None, ) ], created=created, model=model, - object="chat.completion.chunk" + object="chat.completion.chunk", ) else: return ModelResponse( id=response_id, - choices=[{ - "text": content, - "index": 0, - "finish_reason": None - }], + choices=[{"text": content, "index": 0, "finish_reason": None}], created=created, - model=model + model=model, ) async def _convert_to_stream( - content: str, step_name: str, model: str, - ) -> AsyncIterator[ModelResponse]: + content: str, + step_name: str, + model: str, +) -> AsyncIterator[ModelResponse]: """ Converts a single completion response, provided by our pipeline as a shortcut to a streaming response. The streaming response has two chunks: the first @@ -87,15 +79,14 @@ async def _convert_to_stream( class PipelineResponseFormatter: - def __init__(self, - completion_handler: BaseCompletionHandler, - ): - self._completion_handler = completion_handler + def __init__( + self, + output_normalizer: ModelOutputNormalizer, + ): + self._output_normalizer = output_normalizer def handle_pipeline_response( - self, - pipeline_response: PipelineResponse, - streaming: bool + self, pipeline_response: PipelineResponse, streaming: bool ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]: """ Convert pipeline response to appropriate format based on streaming flag @@ -109,21 +100,16 @@ def handle_pipeline_response( pipeline_response.content, pipeline_response.step_name, pipeline_response.model, - streaming=streaming + streaming=streaming, ) if not streaming: # If we're not streaming, we just return the response translated # to the provider-specific format - return self._completion_handler.translate_response(model_response) + return self._output_normalizer.denormalize(model_response) # If we're streaming, we need to convert the response to a stream first # then feed the stream into the completion handler's conversion method model_response_stream = _convert_to_stream( - pipeline_response.content, - pipeline_response.step_name, - pipeline_response.model + pipeline_response.content, pipeline_response.step_name, pipeline_response.model ) - return self._completion_handler.translate_streaming_response( - model_response_stream - ) - + return self._output_normalizer.denormalize_streaming(model_response_stream) diff --git a/src/codegate/providers/litellmshim/__init__.py b/src/codegate/providers/litellmshim/__init__.py index ab470e3c..b2561059 100644 --- a/src/codegate/providers/litellmshim/__init__.py +++ b/src/codegate/providers/litellmshim/__init__.py @@ -1,13 +1,13 @@ from codegate.providers.litellmshim.adapter import BaseAdapter from codegate.providers.litellmshim.generators import ( - anthropic_stream_generator, sse_stream_generator, llamacpp_stream_generator + anthropic_stream_generator, + sse_stream_generator, ) from codegate.providers.litellmshim.litellmshim import LiteLLmShim __all__ = [ "sse_stream_generator", "anthropic_stream_generator", - "llamacpp_stream_generator", "LiteLLmShim", "BaseAdapter", ] diff --git a/src/codegate/providers/litellmshim/adapter.py b/src/codegate/providers/litellmshim/adapter.py index b1c349f0..c0b1a6a9 100644 --- a/src/codegate/providers/litellmshim/adapter.py +++ b/src/codegate/providers/litellmshim/adapter.py @@ -1,9 +1,10 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, Optional +from typing import Any, AsyncIterable, AsyncIterator, Dict, Iterable, Iterator, Optional, Union from litellm import ChatCompletionRequest, ModelResponse from codegate.providers.base import StreamGenerator +from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer class BaseAdapter(ABC): @@ -22,9 +23,7 @@ def __init__(self, stream_generator: StreamGenerator): self.stream_generator = stream_generator @abstractmethod - def translate_completion_input_params( - self, kwargs: Dict - ) -> Optional[ChatCompletionRequest]: + def translate_completion_input_params(self, kwargs: Dict) -> Optional[ChatCompletionRequest]: """Convert input parameters to LiteLLM's ChatCompletionRequest format""" pass @@ -34,11 +33,67 @@ def translate_completion_output_params(self, response: ModelResponse) -> Any: pass @abstractmethod - def translate_completion_output_params_streaming( - self, completion_stream: Any - ) -> Any: + def translate_completion_output_params_streaming(self, completion_stream: Any) -> Any: """ Convert streaming response from LiteLLM format to a format that can be passed to a stream generator and to the client. """ pass + + +class LiteLLMAdapterInputNormalizer(ModelInputNormalizer): + def __init__(self, adapter: BaseAdapter): + self._adapter = adapter + + def normalize(self, data: Dict) -> ChatCompletionRequest: + """ + Uses an LiteLLM adapter to translate the request data from the native + LLM format to the OpenAI API format used by LiteLLM internally. + """ + return self._adapter.translate_completion_input_params(data) + + def denormalize(self, data: ChatCompletionRequest) -> Dict: + """ + For LiteLLM, we don't have to de-normalize as the input format is + always ChatCompletionRequest which is a TypedDict which is a Dict + """ + return data + + +class LiteLLMAdapterOutputNormalizer(ModelOutputNormalizer): + def __init__(self, adapter: BaseAdapter): + self._adapter = adapter + + def normalize_streaming( + self, + model_reply: Union[AsyncIterable[Any], Iterable[Any]], + ) -> Union[AsyncIterator[ModelResponse], Iterator[ModelResponse]]: + """ + Normalize the output stream. This is a pass-through for liteLLM output normalizer + as the liteLLM output is already in the normalized format. + """ + return model_reply + + def normalize(self, model_reply: Any) -> ModelResponse: + """ + Normalize the output data. This is a pass-through for liteLLM output normalizer + as the liteLLM output is already in the normalized format. + """ + return model_reply + + def denormalize(self, normalized_reply: ModelResponse) -> Any: + """ + Denormalize the output data from the completion function to the format + expected by the client + """ + return self._adapter.translate_completion_output_params(normalized_reply) + + def denormalize_streaming( + self, + normalized_reply: Union[AsyncIterable[ModelResponse], Iterable[ModelResponse]], + ) -> Union[AsyncIterator[Any], Iterator[Any]]: + """ + Denormalize the output stream from the completion function to the format + expected by the client + """ + return self._adapter.translate_completion_output_params_streaming(normalized_reply) diff --git a/src/codegate/providers/litellmshim/generators.py b/src/codegate/providers/litellmshim/generators.py index c9ad8fc8..306f1900 100644 --- a/src/codegate/providers/litellmshim/generators.py +++ b/src/codegate/providers/litellmshim/generators.py @@ -1,6 +1,5 @@ import json -from typing import Any, AsyncIterator, Iterator -import asyncio +from typing import Any, AsyncIterator from pydantic import BaseModel @@ -38,20 +37,3 @@ async def anthropic_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterato yield f"event: {event_type}\ndata:{str(e)}\n\n" except Exception as e: yield f"data: {str(e)}\n\n" - - -async def llamacpp_stream_generator(stream: Iterator[Any]) -> AsyncIterator[str]: - """OpenAI-style SSE format""" - try: - for chunk in stream: - if hasattr(chunk, "model_dump_json"): - chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True) - try: - yield f"data:{json.dumps(chunk)}\n\n" - await asyncio.sleep(0) - except Exception as e: - yield f"data:{str(e)}\n\n" - except Exception as e: - yield f"data: {str(e)}\n\n" - finally: - yield "data: [DONE]\n\n" diff --git a/src/codegate/providers/litellmshim/litellmshim.py b/src/codegate/providers/litellmshim/litellmshim.py index 53f08443..1b4dcdf5 100644 --- a/src/codegate/providers/litellmshim/litellmshim.py +++ b/src/codegate/providers/litellmshim/litellmshim.py @@ -1,10 +1,9 @@ -from typing import Any, AsyncIterator, Dict, Union +from typing import Any, AsyncIterator, Optional, Union from fastapi.responses import StreamingResponse from litellm import ChatCompletionRequest, ModelResponse, acompletion -from codegate.providers.base import BaseCompletionHandler -from codegate.providers.litellmshim.adapter import BaseAdapter +from codegate.providers.base import BaseCompletionHandler, StreamGenerator class LiteLLmShim(BaseCompletionHandler): @@ -14,60 +13,29 @@ class LiteLLmShim(BaseCompletionHandler): LiteLLM API. """ - def __init__(self, adapter: BaseAdapter, completion_func=acompletion): - self._adapter = adapter + def __init__(self, stream_generator: StreamGenerator, completion_func=acompletion): + self._stream_generator = stream_generator self._completion_func = completion_func - def translate_request(self, data: Dict, api_key: str) -> ChatCompletionRequest: - """ - Uses the configured adapter to translate the request data from the native - LLM API format to the OpenAI API format used by LiteLLM internally. - - The OpenAPI format is also what our pipeline expects. - """ - data["api_key"] = api_key - completion_request = self._adapter.translate_completion_input_params(data) - if completion_request is None: - raise Exception("Couldn't translate the request") - return completion_request - - def translate_streaming_response( - self, - response: AsyncIterator[ModelResponse], - ) -> AsyncIterator[ModelResponse]: - """ - Convert pipeline or completion response to provider-specific stream - """ - return self._adapter.translate_completion_output_params_streaming(response) - - def translate_response( - self, - response: ModelResponse, - ) -> ModelResponse: - """ - Convert pipeline or completion response to provider-specific format - """ - return self._adapter.translate_completion_output_params(response) - async def execute_completion( self, request: ChatCompletionRequest, - stream: bool = False + api_key: Optional[str], + stream: bool = False, ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]: """ Execute the completion request with LiteLLM's API """ + request["api_key"] = api_key return await self._completion_func(**request) - def create_streaming_response( - self, stream: AsyncIterator[Any] - ) -> StreamingResponse: + def create_streaming_response(self, stream: AsyncIterator[Any]) -> StreamingResponse: """ Create a streaming response from a stream generator. The StreamingResponse is the format that FastAPI expects for streaming responses. """ return StreamingResponse( - self._adapter.stream_generator(stream), + self._stream_generator(stream), headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", diff --git a/src/codegate/providers/llamacpp/adapter.py b/src/codegate/providers/llamacpp/adapter.py deleted file mode 100644 index b6f9d394..00000000 --- a/src/codegate/providers/llamacpp/adapter.py +++ /dev/null @@ -1,32 +0,0 @@ -from typing import Any, AsyncIterator, Dict, Optional - -from litellm import ChatCompletionRequest, ModelResponse - -from codegate.providers.base import StreamGenerator -from codegate.providers.litellmshim import llamacpp_stream_generator, BaseAdapter - - -class LlamaCppAdapter(BaseAdapter): - """ - This is just a wrapper around LiteLLM's adapter class interface that passes - through the input and output as-is - LiteLLM's API expects OpenAI's API - format. - """ - def __init__(self, stream_generator: StreamGenerator = llamacpp_stream_generator): - super().__init__(stream_generator) - - def translate_completion_input_params( - self, kwargs: Dict - ) -> Optional[ChatCompletionRequest]: - try: - return ChatCompletionRequest(**kwargs) - except Exception as e: - raise ValueError(f"Invalid completion parameters: {str(e)}") - - def translate_completion_output_params(self, response: ModelResponse) -> Any: - return response - - def translate_completion_output_params_streaming( - self, completion_stream: AsyncIterator[ModelResponse] - ) -> AsyncIterator[ModelResponse]: - return completion_stream diff --git a/src/codegate/providers/llamacpp/completion_handler.py b/src/codegate/providers/llamacpp/completion_handler.py index 822947eb..40046975 100644 --- a/src/codegate/providers/llamacpp/completion_handler.py +++ b/src/codegate/providers/llamacpp/completion_handler.py @@ -1,49 +1,38 @@ -from typing import Any, AsyncIterator, Dict, Union +import json +import asyncio +from typing import Any, AsyncIterator, Iterator, Optional, Union from fastapi.responses import StreamingResponse from litellm import ChatCompletionRequest, ModelResponse -from codegate.providers.base import BaseCompletionHandler -from codegate.providers.llamacpp.adapter import BaseAdapter -from codegate.inference.inference_engine import LlamaCppInferenceEngine from codegate.config import Config +from codegate.inference.inference_engine import LlamaCppInferenceEngine +from codegate.providers.base import BaseCompletionHandler -class LlamaCppCompletionHandler(BaseCompletionHandler): - def __init__(self, adapter: BaseAdapter): - self._adapter = adapter - self.inference_engine = LlamaCppInferenceEngine() - - def translate_request(self, data: Dict, api_key: str) -> ChatCompletionRequest: - completion_request = self._adapter.translate_completion_input_params( - data) - if completion_request is None: - raise Exception("Couldn't translate the request") - - return ChatCompletionRequest(**completion_request) +async def llamacpp_stream_generator(stream: Iterator[Any]) -> AsyncIterator[str]: + """OpenAI-style SSE format""" + try: + for chunk in stream: + if hasattr(chunk, "model_dump_json"): + chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True) + try: + yield f"data:{json.dumps(chunk)}\n\n" + await asyncio.sleep(0) + except Exception as e: + yield f"data:{str(e)}\n\n" + except Exception as e: + yield f"data: {str(e)}\n\n" + finally: + yield "data: [DONE]\n\n" - def translate_streaming_response( - self, - response: AsyncIterator[ModelResponse], - ) -> AsyncIterator[ModelResponse]: - """ - Convert pipeline or completion response to provider-specific stream - """ - return self._adapter.translate_completion_output_params_streaming(response) - def translate_response( - self, - response: ModelResponse, - ) -> ModelResponse: - """ - Convert pipeline or completion response to provider-specific format - """ - return self._adapter.translate_completion_output_params(response) +class LlamaCppCompletionHandler(BaseCompletionHandler): + def __init__(self): + self.inference_engine = LlamaCppInferenceEngine() async def execute_completion( - self, - request: ChatCompletionRequest, - stream: bool = False + self, request: ChatCompletionRequest, api_key: Optional[str], stream: bool = False ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]: """ Execute the completion request with inference engine API @@ -62,15 +51,13 @@ async def execute_completion( **request) return response - def create_streaming_response( - self, stream: AsyncIterator[Any] - ) -> StreamingResponse: + def create_streaming_response(self, stream: Iterator[Any]) -> StreamingResponse: """ Create a streaming response from a stream generator. The StreamingResponse is the format that FastAPI expects for streaming responses. """ return StreamingResponse( - self._adapter.stream_generator(stream), + llamacpp_stream_generator(stream), headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", diff --git a/src/codegate/providers/llamacpp/provider.py b/src/codegate/providers/llamacpp/provider.py index a3227085..26291cdc 100644 --- a/src/codegate/providers/llamacpp/provider.py +++ b/src/codegate/providers/llamacpp/provider.py @@ -4,14 +4,18 @@ from codegate.providers.base import BaseProvider from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler -from codegate.providers.llamacpp.adapter import LlamaCppAdapter +from codegate.providers.llamacpp.normalizer import LLamaCppInputNormalizer, LLamaCppOutputNormalizer class LlamaCppProvider(BaseProvider): def __init__(self, pipeline_processor=None): - adapter = LlamaCppAdapter() - completion_handler = LlamaCppCompletionHandler(adapter) - super().__init__(completion_handler, pipeline_processor) + completion_handler = LlamaCppCompletionHandler() + super().__init__( + LLamaCppInputNormalizer(), + LLamaCppOutputNormalizer(), + completion_handler, + pipeline_processor, + ) @property def provider_route_name(self) -> str: @@ -30,5 +34,5 @@ async def create_completion( body = await request.body() data = json.loads(body) - stream = await self.complete(data, None) + stream = await self.complete(data, api_key=None) return self._completion_handler.create_streaming_response(stream) diff --git a/src/codegate/providers/normalizer/__init__.py b/src/codegate/providers/normalizer/__init__.py new file mode 100644 index 00000000..6d5ba244 --- /dev/null +++ b/src/codegate/providers/normalizer/__init__.py @@ -0,0 +1,6 @@ +from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer + +__all__ = [ + "ModelInputNormalizer", + "ModelOutputNormalizer", +] diff --git a/src/codegate/providers/normalizer/base.py b/src/codegate/providers/normalizer/base.py new file mode 100644 index 00000000..625842c9 --- /dev/null +++ b/src/codegate/providers/normalizer/base.py @@ -0,0 +1,58 @@ +from abc import ABC, abstractmethod +from typing import Any, AsyncIterable, AsyncIterator, Dict, Iterable, Iterator, Union + +from litellm import ChatCompletionRequest, ModelResponse + + +class ModelInputNormalizer(ABC): + """ + The normalizer class is responsible for normalizing the input data + before it is passed to the pipeline. It converts the input data (raw request) + to the format expected by the pipeline. + """ + + @abstractmethod + def normalize(self, data: Dict) -> ChatCompletionRequest: + """Normalize the input data""" + pass + + @abstractmethod + def denormalize(self, data: ChatCompletionRequest) -> Dict: + """Denormalize the input data""" + pass + + +class ModelOutputNormalizer(ABC): + """ + The output normalizer class is responsible for normalizing the output data + from a model to the format expected by the output pipeline. + + The normalize methods are not implemented yet - they will be when we get + around to implementing output pipelines. + """ + + @abstractmethod + def normalize_streaming( + self, + model_reply: Union[AsyncIterable[Any], Iterable[Any]], + ) -> Union[AsyncIterator[ModelResponse], Iterator[ModelResponse]]: + """Normalize the output data""" + pass + + @abstractmethod + def normalize(self, model_reply: Any) -> ModelResponse: + """Normalize the output data""" + pass + + @abstractmethod + def denormalize(self, normalized_reply: ModelResponse) -> Any: + """Denormalize the output data""" + pass + + @abstractmethod + def denormalize_streaming( + self, + normalized_reply: Union[AsyncIterable[ModelResponse], Iterable[ModelResponse]], + ) -> Union[AsyncIterator[Any], Iterator[Any]]: + """Denormalize the output data""" + pass diff --git a/src/codegate/providers/openai/adapter.py b/src/codegate/providers/openai/adapter.py index c7f9b6a6..b5f4565a 100644 --- a/src/codegate/providers/openai/adapter.py +++ b/src/codegate/providers/openai/adapter.py @@ -1,33 +1,57 @@ -from typing import Any, AsyncIterator, Dict, Optional - -from litellm import ChatCompletionRequest, ModelResponse - -from codegate.providers.base import StreamGenerator -from codegate.providers.litellmshim import sse_stream_generator, BaseAdapter - - -class OpenAIAdapter(BaseAdapter): - """ - This is just a wrapper around LiteLLM's adapter class interface that passes - through the input and output as-is - LiteLLM's API expects OpenAI's API - format. - """ - - def __init__(self, stream_generator: StreamGenerator = sse_stream_generator): - super().__init__(stream_generator) - - def translate_completion_input_params( - self, kwargs: Dict - ) -> Optional[ChatCompletionRequest]: - try: - return ChatCompletionRequest(**kwargs) - except Exception as e: - raise ValueError(f"Invalid completion parameters: {str(e)}") - - def translate_completion_output_params(self, response: ModelResponse) -> Any: - return response - - def translate_completion_output_params_streaming( - self, completion_stream: AsyncIterator[ModelResponse] - ) -> AsyncIterator[ModelResponse]: - return completion_stream +from typing import Any, Dict + +from litellm import ChatCompletionRequest + +from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer + + +class OpenAIInputNormalizer(ModelInputNormalizer): + def __init__(self): + super().__init__() + + def normalize(self, data: Dict) -> ChatCompletionRequest: + """ + No normalizing needed, already OpenAI format + """ + return ChatCompletionRequest(**data) + + def denormalize(self, data: ChatCompletionRequest) -> Dict: + """ + No denormalizing needed, already OpenAI format + """ + return data + + +class OpenAIOutputNormalizer(ModelOutputNormalizer): + def __init__(self): + super().__init__() + + def normalize_streaming( + self, + model_reply: Any, + ) -> Any: + """ + No normalizing needed, already OpenAI format + """ + return model_reply + + def normalize(self, model_reply: Any) -> Any: + """ + No normalizing needed, already OpenAI format + """ + return model_reply + + def denormalize(self, normalized_reply: Any) -> Any: + """ + No denormalizing needed, already OpenAI format + """ + return normalized_reply + + def denormalize_streaming( + self, + normalized_reply: Any, + ) -> Any: + """ + No denormalizing needed, already OpenAI format + """ + return normalized_reply diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py index 16167c95..6d1e6c1d 100644 --- a/src/codegate/providers/openai/provider.py +++ b/src/codegate/providers/openai/provider.py @@ -3,15 +3,19 @@ from fastapi import Header, HTTPException, Request from codegate.providers.base import BaseProvider -from codegate.providers.litellmshim import LiteLLmShim -from codegate.providers.openai.adapter import OpenAIAdapter +from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator +from codegate.providers.openai.adapter import OpenAIInputNormalizer, OpenAIOutputNormalizer class OpenAIProvider(BaseProvider): def __init__(self, pipeline_processor=None): - adapter = OpenAIAdapter() - completion_handler = LiteLLmShim(adapter) - super().__init__(completion_handler, pipeline_processor) + completion_handler = LiteLLmShim(stream_generator=sse_stream_generator) + super().__init__( + OpenAIInputNormalizer(), + OpenAIOutputNormalizer(), + completion_handler, + pipeline_processor, + ) @property def provider_route_name(self) -> str: @@ -30,9 +34,7 @@ async def create_completion( authorization: str = Header(..., description="Bearer token"), ): if not authorization.startswith("Bearer "): - raise HTTPException( - status_code=401, detail="Invalid authorization header" - ) + raise HTTPException(status_code=401, detail="Invalid authorization header") api_key = authorization.split(" ")[1] body = await request.body() diff --git a/src/codegate/server.py b/src/codegate/server.py index 0db158f7..359425a2 100644 --- a/src/codegate/server.py +++ b/src/codegate/server.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, FastAPI from codegate import __description__, __version__ -from codegate.pipeline.base import SequentialPipelineProcessor, PipelineStep +from codegate.pipeline.base import PipelineStep, SequentialPipelineProcessor from codegate.pipeline.version.version import CodegateVersion from codegate.providers.anthropic.provider import AnthropicProvider from codegate.providers.llamacpp.provider import LlamaCppProvider diff --git a/tests/providers/anthropic/test_adapter.py b/tests/providers/anthropic/test_adapter.py index 8eab6667..9bb81e54 100644 --- a/tests/providers/anthropic/test_adapter.py +++ b/tests/providers/anthropic/test_adapter.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator, Dict, List, Union +from typing import List, Union import pytest from litellm import ModelResponse @@ -12,24 +12,24 @@ ) from litellm.types.utils import Delta, StreamingChoices -from codegate.providers.anthropic.adapter import AnthropicAdapter +from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer @pytest.fixture -def adapter(): - return AnthropicAdapter() +def input_normalizer(): + return AnthropicInputNormalizer() -def test_translate_completion_input_params(adapter): +def test_normalize_anthropic_input(input_normalizer): # Test input data completion_request = { "model": "claude-3-haiku-20240307", + "system": "You are an expert code reviewer", "max_tokens": 1024, "stream": True, "messages": [ { "role": "user", - "system": "You are an expert code reviewer", "content": [{"type": "text", "text": "Review this code"}], } ], @@ -37,19 +37,25 @@ def test_translate_completion_input_params(adapter): expected = { "max_tokens": 1024, "messages": [ - {"content": [{"text": "Review this code", "type": "text"}], "role": "user"} + {"content": "You are an expert code reviewer", "role": "system"}, + {"content": [{"text": "Review this code", "type": "text"}], "role": "user"}, ], "model": "claude-3-haiku-20240307", "stream": True, } # Get translation - result = adapter.translate_completion_input_params(completion_request) + result = input_normalizer.normalize(completion_request) assert result == expected +@pytest.fixture +def output_normalizer(): + return AnthropicOutputNormalizer() + + @pytest.mark.asyncio -async def test_translate_completion_output_params_streaming(adapter): +async def test_normalize_anthropic_output_stream(output_normalizer): # Test stream data async def mock_stream(): messages = [ @@ -129,7 +135,7 @@ async def mock_stream(): dict(type="message_stop"), ] - stream = adapter.translate_completion_output_params_streaming(mock_stream()) + stream = output_normalizer.denormalize_streaming(mock_stream()) assert isinstance(stream, AnthropicStreamWrapper) # just so that we can zip over the expected chunks @@ -139,20 +145,3 @@ async def mock_stream(): for chunk, expected_chunk in zip(stream_list, expected): assert chunk == expected_chunk - - -def test_stream_generator_initialization(adapter): - # Verify the default stream generator is set - from codegate.providers.litellmshim import anthropic_stream_generator - - assert adapter.stream_generator == anthropic_stream_generator - - -def test_custom_stream_generator(): - # Test that we can inject a custom stream generator - async def custom_generator(stream: AsyncIterator[Dict]) -> AsyncIterator[str]: - async for chunk in stream: - yield "custom: " + str(chunk) - - adapter = AnthropicAdapter(stream_generator=custom_generator) - assert adapter.stream_generator == custom_generator diff --git a/tests/providers/litellmshim/test_litellmshim.py b/tests/providers/litellmshim/test_litellmshim.py index 0e524220..73889a34 100644 --- a/tests/providers/litellmshim/test_litellmshim.py +++ b/tests/providers/litellmshim/test_litellmshim.py @@ -5,7 +5,7 @@ from fastapi.responses import StreamingResponse from litellm import ChatCompletionRequest, ModelResponse -from codegate.providers.litellmshim import BaseAdapter, LiteLLmShim +from codegate.providers.litellmshim import BaseAdapter, LiteLLmShim, sse_stream_generator class MockAdapter(BaseAdapter): @@ -38,24 +38,16 @@ async def modified_stream(): return modified_stream() -@pytest.fixture -def mock_adapter(): - return MockAdapter() - - -@pytest.fixture -def litellm_shim(mock_adapter): - return LiteLLmShim(mock_adapter) - - @pytest.mark.asyncio -async def test_complete_non_streaming(litellm_shim, mock_adapter): +async def test_complete_non_streaming(): # Mock response mock_response = ModelResponse(id="123", choices=[{"text": "test response"}]) mock_completion = AsyncMock(return_value=mock_response) # Create shim with mocked completion - litellm_shim = LiteLLmShim(mock_adapter, completion_func=mock_completion) + litellm_shim = LiteLLmShim( + stream_generator=sse_stream_generator, completion_func=mock_completion + ) # Test data data = { @@ -64,7 +56,7 @@ async def test_complete_non_streaming(litellm_shim, mock_adapter): } # Execute - result = await litellm_shim.execute_completion(data) + result = await litellm_shim.execute_completion(data, api_key=None) # Verify assert result == mock_response @@ -81,8 +73,9 @@ async def mock_stream() -> AsyncIterator[ModelResponse]: yield ModelResponse(id="123", choices=[{"text": "chunk2"}]) mock_completion = AsyncMock(return_value=mock_stream()) - mock_adapter = MockAdapter() - litellm_shim = LiteLLmShim(mock_adapter, completion_func=mock_completion) + litellm_shim = LiteLLmShim( + stream_generator=sse_stream_generator, completion_func=mock_completion + ) # Test data data = { @@ -92,7 +85,9 @@ async def mock_stream() -> AsyncIterator[ModelResponse]: } # Execute - result_stream = await litellm_shim.execute_completion(data) + result_stream = await litellm_shim.execute_completion( + ChatCompletionRequest(**data), api_key=None + ) # Verify stream contents and adapter processing chunks = [] @@ -112,7 +107,7 @@ async def mock_stream() -> AsyncIterator[ModelResponse]: @pytest.mark.asyncio -async def test_create_streaming_response(litellm_shim): +async def test_create_streaming_response(): # Create a simple async generator that we know works async def mock_stream_gen(): for msg in ["Hello", "World"]: @@ -121,6 +116,7 @@ async def mock_stream_gen(): # Create and verify the generator generator = mock_stream_gen() + litellm_shim = LiteLLmShim(stream_generator=sse_stream_generator) response = litellm_shim.create_streaming_response(generator) # Verify response metadata @@ -128,4 +124,4 @@ async def mock_stream_gen(): assert response.status_code == 200 assert response.headers["Cache-Control"] == "no-cache" assert response.headers["Connection"] == "keep-alive" - assert response.headers["Transfer-Encoding"] == "chunked" \ No newline at end of file + assert response.headers["Transfer-Encoding"] == "chunked" diff --git a/tests/providers/test_registry.py b/tests/providers/test_registry.py index 209a3651..8c957f13 100644 --- a/tests/providers/test_registry.py +++ b/tests/providers/test_registry.py @@ -1,10 +1,21 @@ -from typing import Any, AsyncIterator, Dict +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Dict, + Iterable, + Iterator, + Optional, + Union, +) import pytest from fastapi import FastAPI from fastapi.responses import StreamingResponse +from litellm import ChatCompletionRequest, ModelResponse from codegate.providers.base import BaseCompletionHandler, BaseProvider +from codegate.providers.normalizer import ModelInputNormalizer, ModelOutputNormalizer from codegate.providers.registry import ProviderRegistry @@ -26,7 +37,8 @@ def translate_streaming_response( def execute_completion( self, - request: Any, + request: ChatCompletionRequest, + api_key: Optional[str], stream: bool = False, ) -> Any: pass @@ -38,11 +50,45 @@ def create_streaming_response( return StreamingResponse(stream) +class MockInputNormalizer(ModelInputNormalizer): + def normalize(self, data: Dict) -> Dict: + return data + + def denormalize(self, data: Dict) -> Dict: + return data + + +class MockOutputNormalizer(ModelOutputNormalizer): + def normalize_streaming( + self, + model_reply: Union[AsyncIterable[Any], Iterable[Any]], + ) -> Union[AsyncIterator[ModelResponse], Iterator[ModelResponse]]: + pass + + def normalize(self, model_reply: Any) -> ModelResponse: + pass + + def denormalize(self, normalized_reply: ModelResponse) -> Any: + pass + + def denormalize_streaming( + self, + normalized_reply: Union[AsyncIterable[ModelResponse], Iterable[ModelResponse]], + ) -> Union[AsyncIterator[Any], Iterator[Any]]: + pass + + class MockProvider(BaseProvider): + def __init__( + self, + ): + super().__init__( + MockInputNormalizer(), MockOutputNormalizer(), MockCompletionHandler(), None + ) @property def provider_route_name(self) -> str: - return 'mock_provider' + return "mock_provider" def _setup_routes(self) -> None: @self.router.get(f"/{self.provider_route_name}/test") @@ -65,24 +111,24 @@ def registry(app): return ProviderRegistry(app) -def test_add_provider(registry, mock_completion_handler): - provider = MockProvider(mock_completion_handler) +def test_add_provider(registry): + provider = MockProvider() registry.add_provider("test", provider) assert "test" in registry.providers assert registry.providers["test"] == provider -def test_get_provider(registry, mock_completion_handler): - provider = MockProvider(mock_completion_handler) +def test_get_provider(registry): + provider = MockProvider() registry.add_provider("test", provider) assert registry.get_provider("test") == provider assert registry.get_provider("nonexistent") is None -def test_provider_routes_added(app, registry, mock_completion_handler): - provider = MockProvider(mock_completion_handler) +def test_provider_routes_added(app, registry): + provider = MockProvider() registry.add_provider("test", provider) routes = [route for route in app.routes if route.path == "/mock_provider/test"] diff --git a/tests/test_cli_prompts.py b/tests/test_cli_prompts.py index 88c743f6..2b5029a8 100644 --- a/tests/test_cli_prompts.py +++ b/tests/test_cli_prompts.py @@ -72,9 +72,7 @@ def test_serve_with_prompts(temp_prompts_file): """Test the serve command with prompts file.""" runner = CliRunner() # Use --help to avoid actually starting the server - result = runner.invoke( - cli, ["serve", "--prompts", str(temp_prompts_file), "--help"] - ) + result = runner.invoke(cli, ["serve", "--prompts", str(temp_prompts_file), "--help"]) assert result.exit_code == 0 assert "Path to YAML prompts file" in result.output diff --git a/tests/test_logging.py b/tests/test_logging.py index 97f906b8..d2160de9 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -2,7 +2,13 @@ import logging from io import StringIO -from codegate.codegate_logging import JSONFormatter, TextFormatter, setup_logging, LogFormat, LogLevel +from codegate.codegate_logging import ( + JSONFormatter, + LogFormat, + LogLevel, + TextFormatter, + setup_logging, +) def test_json_formatter(): diff --git a/tests/test_prompts.py b/tests/test_prompts.py index 5fef36b0..e28863f0 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -147,9 +147,7 @@ def test_environment_variable_override(temp_env_prompts_file, monkeypatch): assert config.prompts.another_env == "Another environment prompt" -def test_cli_override_takes_precedence( - temp_prompts_file, temp_env_prompts_file, monkeypatch -): +def test_cli_override_takes_precedence(temp_prompts_file, temp_env_prompts_file, monkeypatch): """Test that CLI prompts override config and environment.""" # Set environment variable monkeypatch.setenv("CODEGATE_PROMPTS_FILE", str(temp_env_prompts_file))