Skip to content

Commit 0e9a047

Browse files
authored
Merge pull request #10683 from leiysky/project-set
refactor(planner): Refactor set returning function
2 parents d80b860 + 8af954e commit 0e9a047

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+1324
-573
lines changed

src/query/expression/src/type_check.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ pub fn check_cast<Index: ColumnIndex>(
211211
}
212212
}
213213

214-
fn wrap_nullable_for_try_cast(span: Span, ty: &DataType) -> Result<DataType> {
214+
pub fn wrap_nullable_for_try_cast(span: Span, ty: &DataType) -> Result<DataType> {
215215
match ty {
216216
DataType::Null => Err(ErrorCode::from_string_no_backtrace(
217217
"TRY_CAST() to NULL is not supported".to_string(),

src/query/functions/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use scalars::BUILTIN_FUNCTIONS;
2222

2323
pub mod aggregates;
2424
pub mod scalars;
25+
pub mod srfs;
2526

2627
pub fn is_builtin_function(name: &str) -> bool {
2728
BUILTIN_FUNCTIONS.contains(name) || AggregateFunctionFactory::instance().contains(name)
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// Copyright 2022 Datafuse Labs.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use common_exception::ErrorCode;
16+
use common_exception::Result;
17+
use common_expression::type_check::check_cast;
18+
use common_expression::type_check::try_unify_signature;
19+
use common_expression::types::DataType;
20+
use common_expression::ColumnIndex;
21+
use common_expression::Expr;
22+
23+
use crate::scalars::BUILTIN_FUNCTIONS;
24+
25+
#[allow(clippy::type_complexity)]
26+
pub fn try_check_srf<Index: ColumnIndex>(
27+
expected_arg_types: &[DataType],
28+
expected_return_types: &[DataType],
29+
args: &[Expr<Index>],
30+
) -> Result<(Vec<Expr<Index>>, Vec<DataType>, Vec<DataType>)> {
31+
let subst = try_unify_signature(
32+
args.iter().map(Expr::data_type),
33+
expected_arg_types.iter(),
34+
&[],
35+
)?;
36+
37+
let checked_args = args
38+
.iter()
39+
.zip(expected_arg_types.iter())
40+
.map(|(arg, sig_type)| {
41+
let sig_type = subst.apply(sig_type)?;
42+
let is_try = BUILTIN_FUNCTIONS.is_auto_try_cast_rule(arg.data_type(), &sig_type);
43+
check_cast(
44+
arg.span(),
45+
is_try,
46+
arg.clone(),
47+
&sig_type,
48+
&BUILTIN_FUNCTIONS,
49+
)
50+
})
51+
.collect::<Result<Vec<_>>>()?;
52+
53+
let generics = subst
54+
.0
55+
.keys()
56+
.cloned()
57+
.max()
58+
.map(|max_generic_idx| {
59+
(0..max_generic_idx + 1)
60+
.map(|idx| match subst.0.get(&idx) {
61+
Some(ty) => Ok(ty.clone()),
62+
None => Err(ErrorCode::from_string_no_backtrace(format!(
63+
"unable to resolve generic T{idx}"
64+
))),
65+
})
66+
.collect::<Result<Vec<_>>>()
67+
})
68+
.unwrap_or_else(|| Ok(vec![]))?;
69+
70+
// TODO: we only support one return type for now, so we can reuse
71+
// the current implementation of `Substitution` to get the return type.
72+
let return_type = subst.apply(&expected_return_types[0])?;
73+
74+
Ok((checked_args, vec![return_type], generics))
75+
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Copyright 2022 Datafuse Labs.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use common_exception::Result;
16+
use common_expression::types::AnyType;
17+
use common_expression::DataBlock;
18+
use common_expression::Evaluator;
19+
use common_expression::FunctionContext;
20+
use common_expression::FunctionRegistry;
21+
use common_expression::Value;
22+
23+
use crate::srfs::SrfExpr;
24+
25+
pub struct SrfEvaluator<'a> {
26+
input_columns: &'a DataBlock,
27+
func_ctx: FunctionContext,
28+
fn_registry: &'a FunctionRegistry,
29+
}
30+
31+
impl<'a> SrfEvaluator<'a> {
32+
pub fn new(
33+
input_columns: &'a DataBlock,
34+
func_ctx: FunctionContext,
35+
fn_registry: &'a FunctionRegistry,
36+
) -> Self {
37+
SrfEvaluator {
38+
input_columns,
39+
func_ctx,
40+
fn_registry,
41+
}
42+
}
43+
44+
/// Evaluate an SRF. This will essentially returns a `DataBlock` for each row in the input `DataBlock`.
45+
pub fn run(&self, srf_expr: &SrfExpr) -> Result<Vec<(Vec<Value<AnyType>>, usize)>> {
46+
let evaluator = Evaluator::new(self.input_columns, self.func_ctx, self.fn_registry);
47+
let arg_results = srf_expr
48+
.args
49+
.iter()
50+
.map(|arg| evaluator.run(arg))
51+
.collect::<Result<Vec<_>>>()?;
52+
let arg_result_refs = arg_results
53+
.iter()
54+
.map(|arg| arg.as_ref())
55+
.collect::<Vec<_>>();
56+
57+
let results = (*srf_expr.srf.eval)(&arg_result_refs, self.input_columns.num_rows());
58+
59+
Ok(results)
60+
}
61+
}
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
// Copyright 2022 Datafuse Labs.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
mod check;
16+
mod evaluator;
17+
mod register;
18+
19+
use std::collections::HashMap;
20+
use std::fmt::Debug;
21+
use std::fmt::Formatter;
22+
use std::sync::Arc;
23+
24+
use common_exception::ErrorCode;
25+
use common_exception::Result;
26+
use common_expression::types::AnyType;
27+
use common_expression::types::DataType;
28+
use common_expression::ColumnIndex;
29+
use common_expression::Expr;
30+
use common_expression::RemoteExpr;
31+
use common_expression::Value;
32+
use common_expression::ValueRef;
33+
use ctor::ctor;
34+
pub use evaluator::SrfEvaluator;
35+
use serde::Deserialize;
36+
use serde::Serialize;
37+
38+
use crate::scalars::BUILTIN_FUNCTIONS;
39+
use crate::srfs::check::try_check_srf;
40+
use crate::srfs::register::builtin_set_returning_functions;
41+
42+
#[ctor]
43+
pub static BUILTIN_SET_RETURNING_FUNCTIONS: SetReturningFunctionRegistry =
44+
builtin_set_returning_functions();
45+
46+
pub type SetReturningFunctionID = (String, usize);
47+
48+
#[derive(Clone)]
49+
pub struct SrfExpr<Index: ColumnIndex = usize> {
50+
pub id: SetReturningFunctionID,
51+
pub srf: Arc<SetReturningFunction>,
52+
pub args: Vec<Expr<Index>>,
53+
pub generic_types: Vec<DataType>,
54+
pub return_types: Vec<DataType>,
55+
}
56+
57+
impl SrfExpr {
58+
pub fn sql_display(&self) -> String {
59+
let mut args = vec![];
60+
for arg in &self.args {
61+
args.push(arg.sql_display());
62+
}
63+
format!("{}({})", self.id.0, args.join(", "))
64+
}
65+
}
66+
67+
impl Debug for SrfExpr {
68+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
69+
f.debug_struct("SrfExpr")
70+
.field("id", &self.id)
71+
.field("args", &self.args)
72+
.field("generic_types", &self.generic_types)
73+
.field("return_types", &self.return_types)
74+
.finish()
75+
}
76+
}
77+
78+
impl SrfExpr {
79+
pub fn into_remote_srf_expr(self) -> RemoteSrfExpr {
80+
let mut args = vec![];
81+
for arg in self.args {
82+
args.push(arg.as_remote_expr());
83+
}
84+
RemoteSrfExpr {
85+
id: self.id,
86+
args,
87+
generic_types: self.generic_types,
88+
return_types: self.return_types,
89+
}
90+
}
91+
}
92+
93+
#[derive(Clone, Debug, Serialize, Deserialize)]
94+
pub struct RemoteSrfExpr<Index: ColumnIndex = usize> {
95+
pub id: SetReturningFunctionID,
96+
pub args: Vec<RemoteExpr<Index>>,
97+
pub generic_types: Vec<DataType>,
98+
pub return_types: Vec<DataType>,
99+
}
100+
101+
impl RemoteSrfExpr {
102+
pub fn into_srf_expr(self) -> SrfExpr {
103+
let mut args = vec![];
104+
for arg in self.args {
105+
args.push(arg.as_expr(&BUILTIN_FUNCTIONS));
106+
}
107+
SrfExpr {
108+
id: self.id.clone(),
109+
srf: BUILTIN_SET_RETURNING_FUNCTIONS.get(&self.id).unwrap(),
110+
args,
111+
generic_types: self.generic_types,
112+
return_types: self.return_types,
113+
}
114+
}
115+
}
116+
117+
#[allow(clippy::type_complexity)]
118+
pub struct SetReturningFunction {
119+
pub name: String,
120+
pub arg_types: Vec<DataType>,
121+
pub return_types: Vec<DataType>,
122+
123+
pub eval:
124+
Box<dyn Fn(&[ValueRef<AnyType>], usize) -> Vec<(Vec<Value<AnyType>>, usize)> + Send + Sync>,
125+
}
126+
127+
#[derive(Default)]
128+
pub struct SetReturningFunctionRegistry {
129+
pub srfs: HashMap<String, Vec<Arc<SetReturningFunction>>>,
130+
}
131+
132+
impl SetReturningFunctionRegistry {
133+
pub fn register<F>(
134+
&mut self,
135+
name: &str,
136+
arg_types: &[DataType],
137+
return_types: &[DataType],
138+
eval: F,
139+
) where
140+
F: Fn(&[ValueRef<AnyType>], usize) -> Vec<(Vec<Value<AnyType>>, usize)>
141+
+ Send
142+
+ Sync
143+
+ 'static,
144+
{
145+
if return_types.iter().any(|t| !t.is_nullable()) {
146+
panic!("return type of srf must be nullable");
147+
}
148+
149+
let name = name.to_string();
150+
let srf = Arc::new(SetReturningFunction {
151+
name: name.to_string(),
152+
arg_types: arg_types.to_vec(),
153+
return_types: return_types.to_vec(),
154+
eval: Box::new(eval),
155+
});
156+
self.srfs.entry(name).or_insert_with(Vec::new).push(srf);
157+
}
158+
159+
pub fn contains(&self, name: &str) -> bool {
160+
self.srfs.contains_key(name)
161+
}
162+
163+
pub fn search_candidates<Index: ColumnIndex>(
164+
&self,
165+
name: &str,
166+
args: &[Expr<Index>],
167+
) -> Vec<(Arc<SetReturningFunction>, usize)> {
168+
let mut candidates = vec![];
169+
self.srfs
170+
.get(name)
171+
.unwrap_or(&vec![])
172+
.iter()
173+
.enumerate()
174+
.for_each(|(index, srf)| {
175+
if srf.name == name && srf.arg_types.len() == args.len() {
176+
candidates.push((srf.clone(), index));
177+
}
178+
});
179+
180+
candidates
181+
}
182+
183+
pub fn get(&self, id: &SetReturningFunctionID) -> Option<Arc<SetReturningFunction>> {
184+
self.srfs
185+
.get(&id.0)
186+
.and_then(|srfs| srfs.get(id.1))
187+
.cloned()
188+
}
189+
}
190+
191+
pub fn check_srf<Index: ColumnIndex>(
192+
name: &str,
193+
args: &[Expr<Index>],
194+
registry: &SetReturningFunctionRegistry,
195+
) -> Result<SrfExpr<Index>> {
196+
let candidates = registry.search_candidates(name, args);
197+
198+
if candidates.is_empty() && !registry.contains(name) {
199+
return Err(ErrorCode::UnknownFunction(format!(
200+
"srf `{name}` does not exist"
201+
)));
202+
}
203+
204+
let mut fail_reasons = Vec::with_capacity(candidates.len());
205+
for (srf, index) in candidates.iter() {
206+
match try_check_srf(&srf.arg_types, &srf.return_types, args) {
207+
Ok((checked_args, return_types, generic_types)) => {
208+
return Ok(SrfExpr {
209+
id: (srf.name.clone(), *index),
210+
srf: srf.clone(),
211+
args: checked_args,
212+
generic_types,
213+
return_types,
214+
});
215+
}
216+
Err(err) => fail_reasons.push(err),
217+
}
218+
}
219+
220+
Err(ErrorCode::SemanticError(format!(
221+
"no overload satisfies {name}({}): {}",
222+
args.iter()
223+
.map(|x| x.data_type().to_string())
224+
.collect::<Vec<_>>()
225+
.join(", "),
226+
fail_reasons
227+
.into_iter()
228+
.map(|e| e.message())
229+
.collect::<Vec<_>>()
230+
.join(", ")
231+
)))
232+
}

0 commit comments

Comments
 (0)