Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions chat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,11 @@ class Message(models.Model):
deleted = models.BooleanField(default=False)
tokens = models.IntegerField(default=0)
created_at = models.DateTimeField(auto_now_add=True)


class SQLQuery(models.Model):
    """One natural-language-to-SQL interaction: the user's question, the
    SQL generated for it, and the rows the query returned.

    NOTE(review): consider a separate DB-connection model so a user can
    configure a new connection from the UI, derive its schema, and store
    that schema for later query generation.
    """

    # NOTE(review): linking to the chat may be unnecessary — confirm.
    chat = models.ForeignKey(Chat, on_delete=models.CASCADE, related_name='sql_queries')
    # The user's original natural-language question.
    question = models.TextField()
    # The SQL produced by the LLM for that question.
    sql_query = models.TextField()
    # NOTE(review): persisting full result sets can grow very large;
    # consider storing only metadata (row count, execution time) instead.
    results = models.JSONField()
    created_at = models.DateTimeField(auto_now_add=True)
39 changes: 39 additions & 0 deletions chat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,45 @@
import google.generativeai as genai

import markdown
import psycopg2
from psycopg2.extras import RealDictCursor


def natural_language_to_sql(question, db_schema):
    """Translate a natural-language question into a SQL query via Gemini.

    Args:
        question: The user's question in plain English.
        db_schema: Textual description of the target database schema,
            interpolated verbatim into the prompt.

    Returns:
        The model's response text with surrounding whitespace stripped —
        assumed to be a bare SQL query (nothing enforces this here).
    """
    # TODO(review): move the "SQL only" instruction into the model's
    # system instructions (as done elsewhere in the project) for stronger
    # enforcement; the prompt itself also needs work.
    GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY', "get_one_from_google")
    genai.configure(api_key=GOOGLE_API_KEY)
    model = genai.GenerativeModel('gemini-1.5-pro')

    prompt = f"""
    Given the following database schema:
    {db_schema}

    Convert the following natural language question into a SQL query:
    "{question}"

    Return only the SQL query, without any additional explanation.
    """

    response = model.generate_content(prompt)
    return response.text.strip()


def execute_sql_query(query):
    """Execute *query* against the Postgres database configured via env vars.

    Args:
        query: SQL text to run. WARNING: this is LLM-generated and
            effectively untrusted input — run it under a read-only DB
            role with a statement timeout.

    Returns:
        All result rows as a list of dicts (RealDictCursor).

    Raises:
        psycopg2.Error: on connection or execution failure.
    """
    # TODO(review): this should connect to an external analytics DB
    # (per-connection config), not the internal application database.
    conn = psycopg2.connect(
        dbname=os.getenv('DB_NAME'),
        user=os.getenv('DB_USER'),
        password=os.getenv('DB_PASSWORD'),
        host=os.getenv('DB_HOST'),
        port=os.getenv('DB_PORT'),
    )
    try:
        with conn.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute(query)
            results = cur.fetchall()
    finally:
        # Always release the connection — the original leaked it when
        # execute()/fetchall() raised, since close() was only reached on
        # the success path.
        conn.close()
    return results


def run_llm(chat_history, context):
Expand Down
5 changes: 4 additions & 1 deletion chat/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from django.db.models import Q, F


from .models import Chat, Message
from .models import Chat, Message, SQLQuery
from .utils import to_markdown, count_tokens
from django.conf import settings

Expand Down Expand Up @@ -68,4 +68,7 @@ def get_context_data(self, **kwargs):
context["size_count"] = size_count_tokens
context["size_count_percentage"] = (size_count_tokens/1000000)*100

sql_queries = SQLQuery.objects.filter(chat=context["chat"]).order_by("-created_at")
context["sql_queries"] = sql_queries

return context
23 changes: 22 additions & 1 deletion chat/viewsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from chat.models import Chat, Message
from chat.serializers import ChatSerializer
from .utils import run_llm, to_markdown
from .utils import run_llm, to_markdown, natural_language_to_sql, execute_sql_query


class ChatViewSet(viewsets.GenericViewSet):
Expand Down Expand Up @@ -61,6 +61,27 @@ def star_message(self, request, pk=None, msg_id=None):
message.save()
return Response({"message": f"Message starred: {message.starred}"})

@action(detail=False, methods=['POST'])
def sql_query(self, request):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good start, as I said above, this should be a 2 part process.

Step 1. Load in a db and create the schema for it.
Step 2. Use that schema to build the query.

Big thing here is gonna be getting a nice schema that can express the db arch well.

question = request.data.get('question')
db_schema = """
# Insert your database schema here
# For example:
# Table: sales
# Columns: id, date, amount, product_id
# Table: products
# Columns: id, name, price
"""

sql_query = natural_language_to_sql(question, db_schema)
results = execute_sql_query(sql_query)

return Response({
'question': question,
'sql_query': sql_query,
'results': results
})


def _build_summary(context):
mes_history = [{"role": "user", "parts": [context]}]
Expand Down