|
1 | 1 | import asyncio |
| 2 | +import copy |
2 | 3 | import datetime |
| 4 | +import json |
3 | 5 | import uuid |
4 | 6 | from pathlib import Path |
5 | | -from typing import Optional |
| 7 | +from typing import AsyncGenerator, AsyncIterator, Optional |
6 | 8 |
|
7 | 9 | import structlog |
8 | | -from litellm import ChatCompletionRequest |
| 10 | +from litellm import ChatCompletionRequest, ModelResponse |
| 11 | +from pydantic import BaseModel |
9 | 12 | from sqlalchemy import create_engine, text |
10 | 13 | from sqlalchemy.ext.asyncio import create_async_engine |
11 | 14 |
|
12 | | -from codegate.db.models import Prompt |
| 15 | +from codegate.db.models import Output, Prompt |
13 | 16 |
|
14 | 17 | logger = structlog.get_logger("codegate") |
15 | 18 |
|
@@ -68,64 +71,112 @@ async def init_db(self): |
68 | 71 | finally: |
69 | 72 | await self._async_db_engine.dispose() |
70 | 73 |
|
    async def _insert_pydantic_model(
        self, model: BaseModel, sql_insert: text
    ) -> Optional[BaseModel]:
        """Execute an INSERT ... RETURNING statement bound from a Pydantic model.

        The model's fields (via ``model_dump()``) supply the statement's named
        parameters, and the RETURNING row is rebuilt into a new instance of the
        same model class. Returns ``None`` when the insert yields no row.
        """
        # There are create methods in queries.py automatically generated by sqlc.
        # However, those methods are buggy for Pydantic and don't work as expected.
        # Manually writing the SQL query to insert Pydantic models instead.
        async with self._async_db_engine.begin() as conn:
            result = await conn.execute(sql_insert, model.model_dump())
            row = result.first()
            if row is None:
                return None

            # Get the class of the Pydantic object to create a new object
            model_class = model.__class__
            return model_class(**row._asdict())
| 89 | + |
71 | 90 | async def record_request( |
72 | 91 | self, normalized_request: ChatCompletionRequest, is_fim_request: bool, provider_str: str |
73 | 92 | ) -> Optional[Prompt]: |
74 | | - # Extract system prompt and user prompt from the messages |
75 | | - messages = normalized_request.get("messages", []) |
76 | | - system_prompt = [] |
77 | | - user_prompt = [] |
78 | | - |
79 | | - for msg in messages: |
80 | | - if msg.get("role") == "system": |
81 | | - system_prompt.append(msg.get("content")) |
82 | | - elif msg.get("role") == "user": |
83 | | - user_prompt.append(msg.get("content")) |
84 | | - |
85 | | - # If no user prompt found in messages, try to get from the prompt field |
86 | | - # (for non-chat completions) |
87 | | - if not user_prompt: |
88 | | - prompt = normalized_request.get("prompt") |
89 | | - if prompt: |
90 | | - user_prompt.append(prompt) |
91 | | - |
92 | | - if not user_prompt: |
93 | | - logger.warning("No user prompt found in request.") |
94 | | - return None |
| 93 | + request_str = None |
| 94 | + if isinstance(normalized_request, BaseModel): |
| 95 | + request_str = normalized_request.model_dump_json(exclude_none=True, exclude_unset=True) |
| 96 | + else: |
| 97 | + try: |
| 98 | + request_str = json.dumps(normalized_request) |
| 99 | + except Exception as e: |
| 100 | + logger.error(f"Failed to serialize output: {normalized_request}", error=str(e)) |
| 101 | + |
| 102 | + if request_str is None: |
| 103 | + logger.warning("No request found to record.") |
| 104 | + return |
95 | 105 |
|
96 | 106 | # Create a new prompt record |
97 | 107 | prompt_params = Prompt( |
98 | 108 | id=str(uuid.uuid4()), # Generate a new UUID for the prompt |
99 | 109 | timestamp=datetime.datetime.now(datetime.timezone.utc), |
100 | 110 | provider=provider_str, |
101 | 111 | type="fim" if is_fim_request else "chat", |
102 | | - user_prompt="<|>".join(user_prompt), |
103 | | - system_prompt="<|>".join(system_prompt), |
| 112 | + request=request_str, |
104 | 113 | ) |
105 | | - # There is a `create_prompt` method in queries.py automatically generated by sqlc |
106 | | - # However, the method is is buggy and doesn't work as expected. |
107 | | - # Manually writing the SQL query to insert the prompt record. |
108 | | - async with self._async_db_engine.begin() as conn: |
109 | | - sql = text( |
| 114 | + sql = text( |
| 115 | + """ |
| 116 | + INSERT INTO prompts (id, timestamp, provider, request, type) |
| 117 | + VALUES (:id, :timestamp, :provider, :request, :type) |
| 118 | + RETURNING * |
110 | 119 | """ |
111 | | - INSERT INTO prompts (id, timestamp, provider, system_prompt, user_prompt, type) |
112 | | - VALUES (:id, :timestamp, :provider, :system_prompt, :user_prompt, :type) |
| 120 | + ) |
| 121 | + return await self._insert_pydantic_model(prompt_params, sql) |
| 122 | + |
| 123 | + async def _record_output(self, prompt: Prompt, output_str: str) -> Optional[Output]: |
| 124 | + output_params = Output( |
| 125 | + id=str(uuid.uuid4()), |
| 126 | + prompt_id=prompt.id, |
| 127 | + timestamp=datetime.datetime.now(datetime.timezone.utc), |
| 128 | + output=output_str, |
| 129 | + ) |
| 130 | + sql = text( |
| 131 | + """ |
| 132 | + INSERT INTO outputs (id, prompt_id, timestamp, output) |
| 133 | + VALUES (:id, :prompt_id, :timestamp, :output) |
113 | 134 | RETURNING * |
114 | 135 | """ |
115 | | - ) |
116 | | - result = await conn.execute(sql, prompt_params.model_dump()) |
117 | | - row = result.first() |
118 | | - if row is None: |
119 | | - return None |
| 136 | + ) |
| 137 | + return await self._insert_pydantic_model(output_params, sql) |
| 138 | + |
| 139 | + async def record_output_stream( |
| 140 | + self, prompt: Prompt, model_response: AsyncIterator |
| 141 | + ) -> AsyncGenerator: |
| 142 | + output_chunks = [] |
| 143 | + async for chunk in model_response: |
| 144 | + if isinstance(chunk, BaseModel): |
| 145 | + chunk_to_record = chunk.model_dump(exclude_none=True, exclude_unset=True) |
| 146 | + output_chunks.append(chunk_to_record) |
| 147 | + elif isinstance(chunk, dict): |
| 148 | + output_chunks.append(copy.deepcopy(chunk)) |
| 149 | + else: |
| 150 | + output_chunks.append({"chunk": str(chunk)}) |
| 151 | + yield chunk |
| 152 | + |
| 153 | + if output_chunks: |
| 154 | + # Record the output chunks |
| 155 | + output_str = json.dumps(output_chunks) |
| 156 | + logger.info(f"Recorded chunks: {output_chunks}. Str: {output_str}") |
| 157 | + await self._record_output(prompt, output_str) |
| 158 | + |
| 159 | + async def record_output_non_stream( |
| 160 | + self, prompt: Optional[Prompt], model_response: ModelResponse |
| 161 | + ) -> Optional[Output]: |
| 162 | + if prompt is None: |
| 163 | + logger.warning("No prompt found to record output.") |
| 164 | + return |
| 165 | + |
| 166 | + output_str = None |
| 167 | + if isinstance(model_response, BaseModel): |
| 168 | + output_str = model_response.model_dump_json(exclude_none=True, exclude_unset=True) |
| 169 | + else: |
| 170 | + try: |
| 171 | + output_str = json.dumps(model_response) |
| 172 | + except Exception as e: |
| 173 | + logger.error(f"Failed to serialize output: {model_response}", error=str(e)) |
| 174 | + |
| 175 | + if output_str is None: |
| 176 | + logger.warning("No output found to record.") |
| 177 | + return |
120 | 178 |
|
121 | | - return Prompt( |
122 | | - id=row.id, |
123 | | - timestamp=row.timestamp, |
124 | | - provider=row.provider, |
125 | | - system_prompt=row.system_prompt, |
126 | | - user_prompt=row.user_prompt, |
127 | | - type=row.type, |
128 | | - ) |
| 179 | + return await self._record_output(prompt, output_str) |
129 | 180 |
|
130 | 181 |
|
131 | 182 | def init_db_sync(): |
|
0 commit comments