Skip to content

Commit 3b44e83

Browse files
authored
Merge pull request #10645 from BohuTANG/dev-refine-ai
chore(openai): add openai.rs to prepare more AI functions
2 parents 0f7e24b + 6015090 commit 3b44e83

File tree

3 files changed

+86
-30
lines changed

3 files changed

+86
-30
lines changed

src/query/service/src/table_functions/openai/gpt_to_sql.rs

Lines changed: 9 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@
1414

1515
use std::any::Any;
1616
use std::sync::Arc;
17-
use std::time::Duration;
1817

19-
use async_openai::types::CreateCompletionRequestArgs;
20-
use async_openai::Client;
2118
use chrono::NaiveDateTime;
2219
use chrono::TimeZone;
2320
use chrono::Utc;
@@ -50,6 +47,9 @@ use common_storages_fuse::TableContext;
5047
use common_storages_view::view_table::VIEW_ENGINE;
5148
use tracing::info;
5249

50+
use crate::table_functions::openai::AIModel;
51+
use crate::table_functions::openai::OpenAI;
52+
5353
pub struct GPT2SQLTable {
5454
prompt: String,
5555
api_key: String,
@@ -233,37 +233,16 @@ impl AsyncSource for GPT2SQLSource {
233233
template.push("#".to_string());
234234
template.push("SELECT".to_string());
235235

236-
let model = "code-davinci-002";
237-
let timeout = Duration::from_secs(30);
238236
let prompt = template.join("");
239-
let api_key = self.api_key.clone();
240237
info!("openai request prompt: {}", prompt);
241238

242-
// Client
243-
let http_client = reqwest::ClientBuilder::new()
244-
.user_agent("databend")
245-
.timeout(timeout)
246-
.build()
247-
.map_err(|e| ErrorCode::Internal(format!("openai http error: {:?}", e)))?;
248-
let client = Client::new()
249-
.with_api_key(api_key)
250-
.with_http_client(http_client);
251-
252-
// Request
253-
let request = CreateCompletionRequestArgs::default()
254-
.model(model)
255-
.prompt(prompt)
256-
.temperature(0.0)
257-
.max_tokens(150_u16)
258-
.top_p(1.0)
259-
.frequency_penalty(0.0)
260-
.presence_penalty(0.0)
261-
.stop(["#", ";"])
262-
.build()
263-
.map_err(|e| ErrorCode::Internal(format!("openai request error: {:?}", e)))?;
264-
265239
// Response.
266-
let response = client
240+
let api_key = self.api_key.clone();
241+
let openai = OpenAI::create(api_key, AIModel::CodeDavinci002);
242+
let request = openai.completion_request(prompt)?;
243+
244+
let response = openai
245+
.client()?
267246
.completions()
268247
.create(request)
269248
.await

src/query/service/src/table_functions/openai/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,9 @@
1313
// limitations under the License.
1414

1515
mod gpt_to_sql;
16+
#[allow(clippy::module_inception)]
17+
mod openai;
1618

1719
pub use gpt_to_sql::GPT2SQLTable;
20+
pub use openai::AIModel;
21+
pub use openai::OpenAI;
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// Copyright 2023 Datafuse Labs.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use std::time::Duration;
16+
17+
use async_openai::types::CreateCompletionRequest;
18+
use async_openai::types::CreateCompletionRequestArgs;
19+
use async_openai::Client;
20+
use common_exception::ErrorCode;
21+
use common_exception::Result;
22+
23+
pub enum AIModel {
24+
CodeDavinci002,
25+
}
26+
27+
// https://platform.openai.com/examples
28+
impl ToString for AIModel {
29+
fn to_string(&self) -> String {
30+
match self {
31+
AIModel::CodeDavinci002 => "code-davinci-002".to_string(),
32+
}
33+
}
34+
}
35+
36+
pub struct OpenAI {
37+
api_key: String,
38+
model: AIModel,
39+
}
40+
41+
impl OpenAI {
42+
pub fn create(api_key: String, model: AIModel) -> Self {
43+
OpenAI { api_key, model }
44+
}
45+
46+
pub fn client(&self) -> Result<Client> {
47+
let timeout = Duration::from_secs(30);
48+
// Client
49+
let http_client = reqwest::ClientBuilder::new()
50+
.user_agent("databend")
51+
.timeout(timeout)
52+
.build()
53+
.map_err(|e| ErrorCode::Internal(format!("openai http error: {:?}", e)))?;
54+
55+
Ok(Client::new()
56+
.with_api_key(self.api_key.clone())
57+
.with_http_client(http_client))
58+
}
59+
60+
pub fn completion_request(&self, prompt: String) -> Result<CreateCompletionRequest> {
61+
CreateCompletionRequestArgs::default()
62+
.model(self.model.to_string())
63+
.prompt(prompt)
64+
.temperature(0.0)
65+
.max_tokens(150_u16)
66+
.top_p(1.0)
67+
.frequency_penalty(0.0)
68+
.presence_penalty(0.0)
69+
.stop(["#", ";"])
70+
.build()
71+
.map_err(|e| ErrorCode::Internal(format!("openai completion request error: {:?}", e)))
72+
}
73+
}

0 commit comments

Comments
 (0)