-
Notifications
You must be signed in to change notification settings - Fork 0
(DRAFT) Text to sql #6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,3 +21,11 @@ class Message(models.Model): | |
| deleted = models.BooleanField(default=False) | ||
| tokens = models.IntegerField(default=0) | ||
| created_at = models.DateTimeField(auto_now_add=True) | ||
|
|
||
|
|
||
| class SQLQuery(models.Model): | ||
| chat = models.ForeignKey(Chat, on_delete=models.CASCADE, related_name='sql_queries') | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no need to link it to the chat. |
||
| question = models.TextField() | ||
| sql_query = models.TextField() | ||
| results = models.JSONField() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure if we really wanna keep the results in the db... that could end up being very large. |
||
| created_at = models.DateTimeField(auto_now_add=True) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,45 @@ | |
| import google.generativeai as genai | ||
|
|
||
| import markdown | ||
| import psycopg2 | ||
| from psycopg2.extras import RealDictCursor | ||
|
|
||
|
|
||
| def natural_language_to_sql(question, db_schema): | ||
| GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY', "get_one_from_google") | ||
| genai.configure(api_key=GOOGLE_API_KEY) | ||
| model = genai.GenerativeModel('gemini-1.5-pro') | ||
|
|
||
| prompt = f""" | ||
| Given the following database schema: | ||
| {db_schema} | ||
|
|
||
| Convert the following natural language question into a SQL query: | ||
| "{question}" | ||
|
|
||
| Return only the SQL query, without any additional explanation. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See how I set the model instructions in the other code. Prompt needs some work too. |
||
| """ | ||
|
|
||
| response = model.generate_content(prompt) | ||
| return response.text.strip() | ||
|
|
||
|
|
||
| def execute_sql_query(query): | ||
| # Use your database connection parameters | ||
| conn = psycopg2.connect( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be connecting to an external db, not our internal application db. |
||
| dbname=os.getenv('DB_NAME'), | ||
| user=os.getenv('DB_USER'), | ||
| password=os.getenv('DB_PASSWORD'), | ||
| host=os.getenv('DB_HOST'), | ||
| port=os.getenv('DB_PORT') | ||
| ) | ||
|
|
||
| with conn.cursor(cursor_factory=RealDictCursor) as cur: | ||
| cur.execute(query) | ||
| results = cur.fetchall() | ||
|
|
||
| conn.close() | ||
| return results | ||
|
|
||
|
|
||
| def run_llm(chat_history, context): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,7 +7,7 @@ | |
|
|
||
| from chat.models import Chat, Message | ||
| from chat.serializers import ChatSerializer | ||
| from .utils import run_llm, to_markdown | ||
| from .utils import run_llm, to_markdown, natural_language_to_sql, execute_sql_query | ||
|
|
||
|
|
||
| class ChatViewSet(viewsets.GenericViewSet): | ||
|
|
@@ -61,6 +61,27 @@ def star_message(self, request, pk=None, msg_id=None): | |
| message.save() | ||
| return Response({"message": f"Message starred: {message.starred}"}) | ||
|
|
||
| @action(detail=False, methods=['POST']) | ||
| def sql_query(self, request): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good start, as I said above, this should be a 2 part process. Step 1. Load in a db and create the schema for it. Big thing here is gonna be getting a nice schema that can express the db arch well. |
||
| question = request.data.get('question') | ||
| db_schema = """ | ||
| # Insert your database schema here | ||
| # For example: | ||
| # Table: sales | ||
| # Columns: id, date, amount, product_id | ||
| # Table: products | ||
| # Columns: id, name, price | ||
| """ | ||
|
|
||
| sql_query = natural_language_to_sql(question, db_schema) | ||
| results = execute_sql_query(sql_query) | ||
|
|
||
| return Response({ | ||
| 'question': question, | ||
| 'sql_query': sql_query, | ||
| 'results': results | ||
| }) | ||
|
|
||
|
|
||
| def _build_summary(context): | ||
| mes_history = [{"role": "user", "parts": [context]}] | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would also create a model for the DB connection.
That way a user can configure a new connection from the UI, set up the urls and create a schema.
Then you can save the schema in the db and use it when making the query