Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/codegate/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,12 @@ async def get_workspace_messages(workspace_name: str) -> List[v1_models.Conversa
raise HTTPException(status_code=500, detail="Internal server error")

try:
prompts_outputs = await dbreader.get_prompts_with_output(ws.id)
conversations, _ = await v1_processing.parse_messages_in_conversations(prompts_outputs)
prompts_with_output_alerts_usage = (
await dbreader.get_prompts_with_output_alerts_usage_by_workspace_id(ws.id)
)
conversations, _ = await v1_processing.parse_messages_in_conversations(
prompts_with_output_alerts_usage
)
return conversations
except Exception:
logger.exception("Error while getting messages")
Expand Down
28 changes: 28 additions & 0 deletions src/codegate/api/v1_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,32 @@ def add_model_token_usage(self, model_token_usage: TokenUsageByModel) -> None:
self.token_usage += model_token_usage.token_usage


class Alert(pydantic.BaseModel):
    """
    Represents an alert.

    The fields mirror the attributes that ``from_db_model`` reads from a
    ``db_models.Alert`` instance.
    """

    id: str
    prompt_id: str
    # Optional: an alert may have no associated code snippet.
    code_snippet: Optional[CodeSnippet]
    # Optional: either a plain string or a dict payload.
    trigger_string: Optional[Union[str, dict]]
    trigger_type: str
    trigger_category: Optional[str]
    timestamp: datetime.datetime

    @classmethod
    def from_db_model(cls, db_model: db_models.Alert) -> "Alert":
        """
        Build an API-level alert from its database counterpart.

        Implemented as a classmethod (rather than a staticmethod that
        hard-codes ``Alert``) so that subclasses calling this constructor
        get an instance of their own type.

        Args:
            db_model: The database alert whose attributes are copied over.

        Returns:
            A new instance populated field-by-field from ``db_model``.
        """
        return cls(
            id=db_model.id,
            prompt_id=db_model.prompt_id,
            code_snippet=db_model.code_snippet,
            trigger_string=db_model.trigger_string,
            trigger_type=db_model.trigger_type,
            trigger_category=db_model.trigger_category,
            timestamp=db_model.timestamp,
        )


class PartialQuestionAnswer(pydantic.BaseModel):
"""
Represents a partial conversation.
Expand All @@ -155,6 +181,7 @@ class PartialQuestionAnswer(pydantic.BaseModel):
partial_questions: PartialQuestions
answer: Optional[ChatMessage]
model_token_usage: TokenUsageByModel
alerts: List[Alert] = []


class Conversation(pydantic.BaseModel):
Expand All @@ -168,6 +195,7 @@ class Conversation(pydantic.BaseModel):
chat_id: str
conversation_timestamp: datetime.datetime
token_usage_agg: Optional[TokenUsageAggregate]
alerts: List[Alert] = []


class AlertConversation(pydantic.BaseModel):
Expand Down
9 changes: 9 additions & 0 deletions src/codegate/api/v1_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import requests
import structlog

from codegate.api import v1_models
from codegate.api.v1_models import (
AlertConversation,
ChatMessage,
Expand Down Expand Up @@ -200,10 +201,15 @@ async def _get_partial_question_answer(
model=model, token_usage=token_usage, provider_type=provider
)

alerts: List[v1_models.Alert] = [
v1_models.Alert.from_db_model(db_alert) for db_alert in row.alerts
]

return PartialQuestionAnswer(
partial_questions=request_message,
answer=output_message,
model_token_usage=model_token_usage,
alerts=alerts,
)


Expand Down Expand Up @@ -367,6 +373,7 @@ async def match_conversations(
for group in grouped_partial_questions:
questions_answers: List[QuestionAnswer] = []
token_usage_agg = TokenUsageAggregate(tokens_by_model={}, token_usage=TokenUsage())
alerts: List[v1_models.Alert] = []
first_partial_qa = None
for partial_question in sorted(group, key=lambda x: x.timestamp):
# Partial questions don't contain the answer, so we need to find the corresponding
Expand All @@ -385,6 +392,7 @@ async def match_conversations(
qa = _get_question_answer_from_partial(selected_partial_qa)
qa.question.message = parse_question_answer(qa.question.message)
questions_answers.append(qa)
alerts.extend(selected_partial_qa.alerts)
token_usage_agg.add_model_token_usage(selected_partial_qa.model_token_usage)

# only add conversation if we have some answers
Expand All @@ -398,6 +406,7 @@ async def match_conversations(
chat_id=first_partial_qa.partial_questions.message_id,
conversation_timestamp=first_partial_qa.partial_questions.timestamp,
token_usage_agg=token_usage_agg,
alerts=alerts,
)
for qa in questions_answers:
map_q_id_to_conversation[qa.question.message_id] = conversation
Expand Down
63 changes: 60 additions & 3 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import uuid
from pathlib import Path
from typing import List, Optional, Type
from typing import Dict, List, Optional, Type

import structlog
from alembic import command as alembic_command
Expand All @@ -19,6 +19,7 @@
Alert,
GetPromptWithOutputsRow,
GetWorkspaceByNameConditions,
IntermediatePromptWithOutputUsageAlerts,
MuxRule,
Output,
Prompt,
Expand Down Expand Up @@ -89,7 +90,6 @@ def does_db_exist(self):


class DbRecorder(DbCodeGate):

def __init__(self, sqlite_path: Optional[str] = None):
super().__init__(sqlite_path)

Expand Down Expand Up @@ -517,7 +517,6 @@ async def add_mux(self, mux: MuxRule) -> MuxRule:


class DbReader(DbCodeGate):

def __init__(self, sqlite_path: Optional[str] = None):
super().__init__(sqlite_path)

Expand Down Expand Up @@ -586,6 +585,64 @@ async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithO
)
return prompts

async def get_prompts_with_output_alerts_usage_by_workspace_id(
    self, workspace_id: str
) -> List[GetPromptWithOutputsRow]:
    """
    Get all prompts with their outputs, alerts and token usage by workspace_id.

    The query LEFT JOINs both ``outputs`` and ``alerts`` onto ``prompts``, so
    it returns one flat row per (prompt, output, alert) combination. Those
    rows are folded into one ``GetPromptWithOutputsRow`` per prompt:

    * the output/usage columns are taken from the first row seen for the
      prompt (rows are ordered by output timestamp, descending);
    * every distinct alert of the prompt is collected into ``alerts``.

    Args:
        workspace_id: Id of the workspace whose prompts are fetched.

    Returns:
        One entry per prompt, in the order prompts first appear in the
        query result.
    """

    sql = text(
        """
        SELECT
            p.id as prompt_id, p.timestamp as prompt_timestamp, p.provider, p.request, p.type,
            o.id as output_id, o.output, o.timestamp as output_timestamp, o.input_tokens, o.output_tokens, o.input_cost, o.output_cost,
            a.id as alert_id, a.code_snippet, a.trigger_string, a.trigger_type, a.trigger_category, a.timestamp as alert_timestamp
        FROM prompts p
        LEFT JOIN outputs o ON p.id = o.prompt_id
        LEFT JOIN alerts a ON p.id = a.prompt_id
        WHERE p.workspace_id = :workspace_id
        ORDER BY o.timestamp DESC, a.timestamp DESC
        """  # noqa: E501
    )
    conditions = {"workspace_id": workspace_id}
    rows = await self._exec_select_conditions_to_pydantic(
        IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True
    )

    prompts_dict: Dict[str, GetPromptWithOutputsRow] = {}
    # Joining two child tables multiplies rows: a prompt with several outputs
    # repeats each of its alerts once per output. Track what has already been
    # attached so that every alert ends up in the result exactly once.
    seen_alerts = set()
    for row in rows:
        prompt_id = row.prompt_id
        if prompt_id not in prompts_dict:
            prompts_dict[prompt_id] = GetPromptWithOutputsRow(
                id=row.prompt_id,
                timestamp=row.prompt_timestamp,
                provider=row.provider,
                request=row.request,
                type=row.type,
                output_id=row.output_id,
                output=row.output,
                output_timestamp=row.output_timestamp,
                input_tokens=row.input_tokens,
                output_tokens=row.output_tokens,
                input_cost=row.input_cost,
                output_cost=row.output_cost,
                alerts=[],
            )

        # LEFT JOIN yields a NULL alert_id for prompts without alerts.
        if not row.alert_id:
            continue

        alert_key = (prompt_id, row.alert_id)
        if alert_key in seen_alerts:
            continue
        seen_alerts.add(alert_key)

        prompts_dict[prompt_id].alerts.append(
            Alert(
                id=row.alert_id,
                prompt_id=row.prompt_id,
                code_snippet=row.code_snippet,
                trigger_string=row.trigger_string,
                trigger_type=row.trigger_type,
                trigger_category=row.trigger_category,
                timestamp=row.alert_timestamp,
            )
        )

    return list(prompts_dict.values())

async def get_alerts_by_workspace(
self, workspace_id: str, trigger_category: Optional[str] = None
) -> List[Alert]:
Expand Down
29 changes: 28 additions & 1 deletion src/codegate/db/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
from enum import Enum
from typing import Annotated, Any, Dict, Optional
from typing import Annotated, Any, Dict, List, Optional

from pydantic import BaseModel, StringConstraints

Expand Down Expand Up @@ -131,6 +131,32 @@ class ProviderType(str, Enum):
openrouter = "openai"


class IntermediatePromptWithOutputUsageAlerts(BaseModel):
    """
    Flat, single-row shape of a query that selects a prompt together with
    its output, usage statistics and alert columns.

    Everything except the prompt id, timestamp, request and type is optional.
    """

    # --- prompt columns ---
    prompt_id: Any
    prompt_timestamp: Any
    provider: Optional[Any]
    request: Any
    type: Any

    # --- output and usage columns ---
    output_id: Optional[Any]
    output: Optional[Any]
    output_timestamp: Optional[Any]
    input_tokens: Optional[int]
    output_tokens: Optional[int]
    input_cost: Optional[float]
    output_cost: Optional[float]

    # --- alert columns ---
    alert_id: Optional[Any]
    code_snippet: Optional[Any]
    trigger_string: Optional[Any]
    trigger_type: Optional[Any]
    trigger_category: Optional[Any]
    alert_timestamp: Optional[Any]


class GetPromptWithOutputsRow(BaseModel):
id: Any
timestamp: Any
Expand All @@ -144,6 +170,7 @@ class GetPromptWithOutputsRow(BaseModel):
output_tokens: Optional[int]
input_cost: Optional[float]
output_cost: Optional[float]
alerts: List[Alert] = []


class WorkspaceWithSessionInfo(BaseModel):
Expand Down