diff --git a/Cargo.lock b/Cargo.lock index 10d9cc725..09a6a746c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4523,6 +4523,24 @@ dependencies = [ "substring", ] +[[package]] +name = "sparrow-logical" +version = "0.11.0" +dependencies = [ + "arrow-schema", + "derive_more", + "error-stack", + "hashbrown 0.14.0", + "insta", + "itertools 0.11.0", + "serde", + "serde_yaml", + "sparrow-arrow", + "sparrow-types", + "static_init", + "uuid 1.4.1", +] + [[package]] name = "sparrow-main" version = "0.11.0" @@ -4877,6 +4895,7 @@ dependencies = [ "hashbrown 0.14.0", "itertools 0.11.0", "regex", + "serde", "static_init", ] diff --git a/crates/sparrow-logical/Cargo.toml b/crates/sparrow-logical/Cargo.toml new file mode 100644 index 000000000..60ee6f093 --- /dev/null +++ b/crates/sparrow-logical/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "sparrow-logical" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +publish = false +description = """ +Logical representation of Kaskada queries. 
+""" + +[dependencies] +arrow-schema.workspace = true +derive_more.workspace = true +error-stack.workspace = true +hashbrown.workspace = true +itertools.workspace = true +serde.workspace = true +sparrow-arrow = { path = "../sparrow-arrow" } +sparrow-types = { path = "../sparrow-types" } +static_init.workspace = true +serde_yaml.workspace = true +uuid.workspace = true + +[dev-dependencies] +insta.workspace = true + +[lib] +doctest = false diff --git a/crates/sparrow-logical/src/error.rs b/crates/sparrow-logical/src/error.rs new file mode 100644 index 000000000..08f54f18e --- /dev/null +++ b/crates/sparrow-logical/src/error.rs @@ -0,0 +1,25 @@ +use crate::{ExprRef, Grouping}; +use arrow_schema::DataType; + +use sparrow_types::DisplayFenlType; + +#[derive(derive_more::Display, Debug)] +pub enum Error { + #[display(fmt = "internal error: {_0}")] + Internal(&'static str), + #[display(fmt = "invalid non-struct type: {}", "_0.display()")] + InvalidNonStructType(DataType), + #[display(fmt = "invalid non-string literal: {_0:?}")] + InvalidNonStringLiteral(ExprRef), + // TODO: Include nearest matches? + #[display(fmt = "invalid field name '{name}'")] + InvalidFieldName { name: String }, + #[display(fmt = "invalid types")] + InvalidTypes, + #[display(fmt = "incompatible groupings {_0:?}")] + IncompatibleGroupings(Vec), + #[display(fmt = "invalid function: '{_0}'")] + InvalidFunction(String), +} + +impl error_stack::Context for Error {} diff --git a/crates/sparrow-logical/src/expr.rs b/crates/sparrow-logical/src/expr.rs new file mode 100644 index 000000000..4a688b072 --- /dev/null +++ b/crates/sparrow-logical/src/expr.rs @@ -0,0 +1,223 @@ +use crate::{Error, Grouping}; +use arrow_schema::{DataType, TimeUnit}; +use sparrow_types::Types; +use std::borrow::Cow; +use std::sync::Arc; +use uuid::Uuid; + +/// Represents an operation applied to 0 or more arguments. +#[derive(Debug)] +pub struct Expr { + /// The instruction being applied by this expression. 
+ pub name: Cow<'static, str>, + /// Zero or more literal-valued arguments. + pub literal_args: Vec, + /// Arguments to the expression. + pub args: Vec, + /// The type produced by the expression. + pub result_type: DataType, + /// The grouping associated with the expression. + pub grouping: Grouping, +} + +#[derive(Debug)] +pub enum Literal { + Null, + Bool(bool), + String(String), + Int64(i64), + UInt64(u64), + Float64(f64), + Timedelta { seconds: i64, nanos: i64 }, + Uuid(Uuid), +} + +impl Expr { + pub fn try_new( + name: Cow<'static, str>, + args: Vec, + ) -> error_stack::Result { + let Types { + arguments: arg_types, + result: result_type, + } = crate::typecheck::typecheck(name.as_ref(), &args)?; + + // If any of the types are different, we'll need to create new arguments. + let args = args + .into_iter() + .zip(arg_types) + .map(|(arg, arg_type)| arg.cast(arg_type)) + .collect::, _>>()?; + + let grouping = Grouping::from_args(&args)?; + + Ok(Self { + name, + literal_args: vec![], + args, + result_type, + grouping, + }) + } + + /// Create a new literal node referencing a UUID. + /// + /// This can be used for sources, UDFs, etc. + /// + /// Generally, the `name` should identify the kind of thing being referenced (source, UDF, etc.) + /// and the `uuid` should identify the specific thing being referenced. + pub fn new_uuid( + name: &'static str, + uuid: Uuid, + result_type: DataType, + grouping: Grouping, + ) -> Self { + Self { + name: Cow::Borrowed(name), + literal_args: vec![Literal::Uuid(uuid)], + args: vec![], + result_type, + grouping, + } + } + + pub fn new_literal(literal: Literal) -> Self { + let result_type = match literal { + Literal::Null => DataType::Null, + Literal::Bool(_) => DataType::Boolean, + Literal::String(_) => DataType::Utf8, + Literal::Int64(_) => DataType::Int64, + Literal::UInt64(_) => DataType::UInt64, + Literal::Float64(_) => DataType::Float64, + Literal::Timedelta { .. 
} => DataType::Duration(TimeUnit::Nanosecond), + Literal::Uuid(_) => DataType::FixedSizeBinary(BYTES_IN_UUID), + }; + Self { + name: Cow::Borrowed("literal"), + literal_args: vec![literal], + args: vec![], + result_type, + grouping: Grouping::Literal, + } + } + + /// Create a new cast expression to the given type. + pub fn cast(self: Arc, data_type: DataType) -> error_stack::Result, Error> { + if self.result_type == data_type { + Ok(self) + } else { + let grouping = self.grouping.clone(); + Ok(Arc::new(Expr { + name: Cow::Borrowed("cast"), + literal_args: vec![], + args: vec![self], + result_type: data_type, + grouping, + })) + } + } + + /// If this expression is a literal, return the corresponding scalar value. + pub fn literal_opt(&self) -> Option<&Literal> { + if self.name == "literal" { + debug_assert_eq!(self.literal_args.len(), 1); + Some(&self.literal_args[0]) + } else { + None + } + } + + /// If this expression is a literal string, return it. + /// + /// This returns `None` if: + /// 1. This expression is not a literal. + /// 2. This expression is not a string literal. + pub fn literal_str_opt(&self) -> Option<&str> { + self.literal_opt().and_then(|scalar| match scalar { + Literal::String(str) => Some(str.as_str()), + _ => None, + }) + } +} + +const BYTES_IN_UUID: i32 = (std::mem::size_of::() / std::mem::size_of::()) as i32; + +/// Reference counted expression. 
+pub type ExprRef = Arc; + +#[cfg(test)] +mod tests { + use arrow_schema::Field; + + use super::*; + + #[test] + fn test_literal() { + let literal = Expr::new_literal(Literal::String("hello".to_owned())); + insta::assert_debug_snapshot!(literal); + } + + #[test] + fn test_fieldref() { + let uuid = Uuid::from_u64_pair(42, 84); + let source = Arc::new(Expr::new_uuid( + "source", + uuid, + DataType::Struct( + vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Float64, false), + ] + .into(), + ), + Grouping::Literal, + )); + let field = Expr::try_new( + "fieldref".into(), + vec![ + source, + Arc::new(Expr::new_literal(Literal::String("a".to_owned()))), + ], + ) + .unwrap(); + insta::assert_debug_snapshot!(field); + } + + #[test] + fn test_function() { + let uuid = Uuid::from_u64_pair(42, 84); + let source = Arc::new(Expr::new_uuid( + "source", + uuid, + DataType::Struct( + vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Float64, false), + ] + .into(), + ), + Grouping::Literal, + )); + let a_i32 = Arc::new( + Expr::try_new( + "fieldref".into(), + vec![ + source, + Arc::new(Expr::new_literal(Literal::String("a".to_owned()))), + ], + ) + .unwrap(), + ); + + // i32 + f64 literal => f64 + let a_i32_plus_1 = Expr::try_new( + "add".into(), + vec![ + a_i32.clone(), + Arc::new(Expr::new_literal(Literal::Float64(1.0))), + ], + ) + .unwrap(); + insta::assert_debug_snapshot!(a_i32_plus_1); + } +} diff --git a/crates/sparrow-logical/src/functions.yml b/crates/sparrow-logical/src/functions.yml new file mode 100644 index 000000000..953608b18 --- /dev/null +++ b/crates/sparrow-logical/src/functions.yml @@ -0,0 +1,6 @@ +- name: add + signature: "(x: T, y: T) -> T" +- name: sub + signature: "(x: T, y: T) -> T" +- name: eq + signature: "(x: T, y: T) -> bool" \ No newline at end of file diff --git a/crates/sparrow-logical/src/grouping.rs b/crates/sparrow-logical/src/grouping.rs new file mode 100644 index 000000000..d8a8f9ec5 --- 
/dev/null
+++ b/crates/sparrow-logical/src/grouping.rs
@@ -0,0 +1,46 @@
+use itertools::Itertools;
+
+use crate::{Error, ExprRef};
+
+/// A wrapper around a u32 identifying a distinct grouping.
+#[derive(Debug, Eq, PartialEq, Clone, Copy, Hash, Ord, PartialOrd)]
+#[repr(transparent)]
+pub struct GroupId(u32);
+
+/// The grouping associated with an expression.
+#[derive(Debug, PartialEq, Eq, Hash, Clone)]
+pub enum Grouping {
+    Literal,
+    Group(GroupId),
+}
+
+impl Grouping {
+    /// Determine the grouping of an expression from the groupings of its arguments.
+    ///
+    /// Literal arguments are compatible with every grouping, so they are ignored.
+    /// If no argument has a non-literal grouping the result is `Grouping::Literal`;
+    /// if exactly one distinct non-literal grouping is present, that grouping is returned.
+    ///
+    /// # Errors
+    ///
+    /// Returns `Error::IncompatibleGroupings` listing the distinct non-literal
+    /// groupings when the arguments use more than one of them.
+    pub fn from_args(args: &[ExprRef]) -> error_stack::Result<Grouping, Error> {
+        let groupings = args
+            .iter()
+            .map(|arg| &arg.grouping)
+            .unique()
+            // Keep only the non-literal groupings: literals combine with anything.
+            .filter(|g| **g != Grouping::Literal)
+            .cloned();
+
+        match groupings.at_most_one() {
+            Ok(None) => Ok(Grouping::Literal),
+            Ok(Some(grouping)) => Ok(grouping),
+            Err(groupings) => {
+                let groupings: Vec<_> = groupings.collect();
+                error_stack::bail!(Error::IncompatibleGroupings(groupings))
+            }
+        }
+    }
+}
diff --git a/crates/sparrow-logical/src/lib.rs b/crates/sparrow-logical/src/lib.rs
new file mode 100644
index 000000000..eec2ef5cb
--- /dev/null
+++ b/crates/sparrow-logical/src/lib.rs
@@ -0,0 +1,19 @@
+#![warn(
+    rust_2018_idioms,
+    nonstandard_style,
+    future_incompatible,
+    clippy::mod_module_files,
+    clippy::print_stdout,
+    clippy::print_stderr,
+    clippy::undocumented_unsafe_blocks
+)]
+
+//! Logical execution plans for Kaskada queries.
+mod error; +mod expr; +mod grouping; +mod typecheck; + +pub use error::*; +pub use expr::*; +pub use grouping::*; diff --git a/crates/sparrow-logical/src/snapshots/sparrow_logical__expr__tests__fieldref.snap b/crates/sparrow-logical/src/snapshots/sparrow_logical__expr__tests__fieldref.snap new file mode 100644 index 000000000..735a24d5d --- /dev/null +++ b/crates/sparrow-logical/src/snapshots/sparrow_logical__expr__tests__fieldref.snap @@ -0,0 +1,53 @@ +--- +source: crates/sparrow-logical/src/expr.rs +expression: field +--- +Expr { + name: "fieldref", + literal_args: [], + args: [ + Expr { + name: "source", + literal_args: [ + Uuid( + 00000000-0000-002a-0000-000000000054, + ), + ], + args: [], + result_type: Struct( + [ + Field { + name: "a", + data_type: Int64, + nullable: true, + dict_id: 0, + dict_is_ordered: false, + metadata: {}, + }, + Field { + name: "b", + data_type: Float64, + nullable: false, + dict_id: 0, + dict_is_ordered: false, + metadata: {}, + }, + ], + ), + grouping: Literal, + }, + Expr { + name: "literal", + literal_args: [ + String( + "a", + ), + ], + args: [], + result_type: Utf8, + grouping: Literal, + }, + ], + result_type: Int64, + grouping: Literal, +} diff --git a/crates/sparrow-logical/src/snapshots/sparrow_logical__expr__tests__function.snap b/crates/sparrow-logical/src/snapshots/sparrow_logical__expr__tests__function.snap new file mode 100644 index 000000000..b32019de4 --- /dev/null +++ b/crates/sparrow-logical/src/snapshots/sparrow_logical__expr__tests__function.snap @@ -0,0 +1,80 @@ +--- +source: crates/sparrow-logical/src/expr.rs +expression: a_i32_plus_1 +--- +Expr { + name: "add", + literal_args: [], + args: [ + Expr { + name: "cast", + literal_args: [], + args: [ + Expr { + name: "fieldref", + literal_args: [], + args: [ + Expr { + name: "source", + literal_args: [ + Uuid( + 00000000-0000-002a-0000-000000000054, + ), + ], + args: [], + result_type: Struct( + [ + Field { + name: "a", + data_type: Int32, + nullable: true, + 
dict_id: 0, + dict_is_ordered: false, + metadata: {}, + }, + Field { + name: "b", + data_type: Float64, + nullable: false, + dict_id: 0, + dict_is_ordered: false, + metadata: {}, + }, + ], + ), + grouping: Literal, + }, + Expr { + name: "literal", + literal_args: [ + String( + "a", + ), + ], + args: [], + result_type: Utf8, + grouping: Literal, + }, + ], + result_type: Int32, + grouping: Literal, + }, + ], + result_type: Float64, + grouping: Literal, + }, + Expr { + name: "literal", + literal_args: [ + Float64( + 1.0, + ), + ], + args: [], + result_type: Float64, + grouping: Literal, + }, + ], + result_type: Float64, + grouping: Literal, +} diff --git a/crates/sparrow-logical/src/snapshots/sparrow_logical__expr__tests__literal.snap b/crates/sparrow-logical/src/snapshots/sparrow_logical__expr__tests__literal.snap new file mode 100644 index 000000000..d75de463b --- /dev/null +++ b/crates/sparrow-logical/src/snapshots/sparrow_logical__expr__tests__literal.snap @@ -0,0 +1,15 @@ +--- +source: crates/sparrow-logical/src/expr.rs +expression: literal +--- +Expr { + name: "literal", + literal_args: [ + String( + "hello", + ), + ], + args: [], + result_type: Utf8, + grouping: Literal, +} diff --git a/crates/sparrow-logical/src/typecheck.rs b/crates/sparrow-logical/src/typecheck.rs new file mode 100644 index 000000000..bfaf1d119 --- /dev/null +++ b/crates/sparrow-logical/src/typecheck.rs @@ -0,0 +1,71 @@ +use crate::{Error, ExprRef}; +use arrow_schema::DataType; +use error_stack::ResultExt; +use hashbrown::HashMap; +use sparrow_types::{Signature, Types}; + +/// Type-check the given function name. 
+pub(crate) fn typecheck(name: &str, args: &[ExprRef]) -> error_stack::Result { + match name { + "fieldref" => { + error_stack::ensure!( + args.len() == 2, + Error::Internal("invalid arguments for fieldref") + ); + + let DataType::Struct(fields) = &args[0].result_type else { + error_stack::bail!(Error::InvalidNonStructType(args[0].result_type.clone())) + }; + let Some(name) = args[1].literal_str_opt() else { + error_stack::bail!(Error::InvalidNonStringLiteral(args[0].clone())) + }; + + let Some((_, field)) = fields.find(name) else { + error_stack::bail!(Error::InvalidFieldName { + name: name.to_owned(), + }) + }; + let types = Types { + arguments: vec![args[0].result_type.clone(), DataType::Utf8], + result: field.data_type().clone(), + }; + Ok(types) + } + _ => { + let signature = get_signature(name)?; + // TODO: Ideally, instantiate would accept references so we didn't need to clone. + let arguments = args.iter().map(|arg| arg.result_type.clone()).collect(); + let result_type = signature + .instantiate(arguments) + .change_context(Error::InvalidTypes)?; + Ok(result_type) + } + } +} + +#[derive(serde::Deserialize, Debug)] +struct Function { + name: &'static str, + signature: Signature, +} + +#[static_init::dynamic] +static FUNCTION_SIGNATURES: HashMap<&'static str, &'static Function> = { + let functions: Vec = + serde_yaml::from_str(include_str!("functions.yml")).expect("failed to parse functions.yml"); + + functions + .into_iter() + .map(|function| { + let function: &'static Function = Box::leak(Box::new(function)); + (function.name, function) + }) + .collect() +}; + +fn get_signature(name: &str) -> error_stack::Result<&'static Signature, Error> { + let function = FUNCTION_SIGNATURES + .get(name) + .ok_or_else(|| Error::InvalidFunction(name.to_owned()))?; + Ok(&function.signature) +} diff --git a/crates/sparrow-types/Cargo.toml b/crates/sparrow-types/Cargo.toml index c1b5440dc..43124bd1a 100644 --- a/crates/sparrow-types/Cargo.toml +++ 
b/crates/sparrow-types/Cargo.toml @@ -18,6 +18,7 @@ error-stack.workspace = true hashbrown.workspace = true itertools.workspace = true regex.workspace = true +serde.workspace = true static_init.workspace = true [dev-dependencies] diff --git a/crates/sparrow-types/src/signature.rs b/crates/sparrow-types/src/signature.rs index ba5a1cfeb..60beac932 100644 --- a/crates/sparrow-types/src/signature.rs +++ b/crates/sparrow-types/src/signature.rs @@ -18,6 +18,32 @@ pub struct Signature { pub(super) variadic: bool, } +impl<'a> serde::de::Deserialize<'a> for Signature { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'a>, + { + deserializer.deserialize_str(SignatureDeserializer) + } +} + +struct SignatureDeserializer; + +impl<'a> serde::de::Visitor<'a> for SignatureDeserializer { + type Value = Signature; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("a function signature") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + Signature::parse(v).map_err(|e| E::custom(format!("{:?}", e))) + } +} + /// A type-parameter within a signature. #[derive(Debug, PartialEq, Eq)] pub(super) struct TypeParameter {