Skip to content

Commit

Permalink
implement remote function register
Browse files Browse the repository at this point in the history
  • Loading branch information
goldmedal committed Oct 16, 2024
1 parent afead04 commit d42ba46
Show file tree
Hide file tree
Showing 9 changed files with 165 additions and 33 deletions.
2 changes: 1 addition & 1 deletion wren-modeling-rs/benchmarks/src/tpch/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl RunOpt {
let start = Instant::now();
let sql = &get_query_sql(query_id)?;
for query in sql {
transform_sql_with_ctx(&ctx, Arc::clone(&mdl), query).await?;
transform_sql_with_ctx(&ctx, Arc::clone(&mdl), &[], query).await?;
}

let elapsed = start.elapsed(); //.as_secs_f64() * 1000.0;
Expand Down
2 changes: 2 additions & 0 deletions wren-modeling-rs/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ path = "src/lib.rs"

[dependencies]
async-trait = { workspace = true }
csv = "1.3.0"
datafusion = { workspace = true, default-features = false, features = [
"nested_expressions",
"crypto_expressions",
Expand All @@ -22,6 +23,7 @@ datafusion = { workspace = true, default-features = false, features = [
"regex_expressions",
"unicode_expressions",
] }
env_logger = { workspace = true }
log = { workspace = true }
parking_lot = "0.12.3"
petgraph = "0.6.5"
Expand Down
43 changes: 43 additions & 0 deletions wren-modeling-rs/core/src/mdl/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,50 @@ use datafusion::logical_expr::{
Accumulator, AggregateUDFImpl, ColumnarValue, PartitionEvaluator, ScalarUDFImpl,
Signature, TypeSignature, Volatility, WindowUDFImpl,
};
use serde::{Deserialize, Serialize};
use std::any::Any;
use std::fmt::Display;
use std::str::FromStr;

#[derive(Serialize, Deserialize, Debug)]
pub struct RemoteFunction {
pub function_type: FunctionType,
pub name: String,
pub return_type: String,
pub description: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum FunctionType {
Scalar,
Aggregate,
Window,
}

impl Display for FunctionType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let str = match self {
FunctionType::Scalar => "scalar".to_string(),
FunctionType::Aggregate => "aggregate".to_string(),
FunctionType::Window => "window".to_string(),
};
write!(f, "{}", str)
}
}

impl FromStr for FunctionType {
type Err = String;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"scalar" => Ok(FunctionType::Scalar),
"aggregate" => Ok(FunctionType::Aggregate),
"window" => Ok(FunctionType::Window),
_ => Err(format!("Unknown function type: {}", s)),
}
}
}

/// A scalar UDF that will be bypassed when planning logical plan.
/// This is used to register the remote function to the context. The function should not be
Expand Down
136 changes: 109 additions & 27 deletions wren-modeling-rs/core/src/mdl/mod.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
use crate::logical_plan::analyze::expand_view::ExpandWrenViewRule;
use crate::logical_plan::analyze::model_anlayze::ModelAnalyzeRule;
use crate::logical_plan::analyze::model_generation::ModelGenerationRule;
use crate::logical_plan::utils::from_qualified_name_str;
use crate::logical_plan::utils::{from_qualified_name_str, map_data_type};
use crate::mdl::context::{create_ctx_with_mdl, register_table_with_mdl, WrenDataSource};
use crate::mdl::function::{
ByPassAggregateUDF, ByPassScalarUDF, ByPassWindowFunction, FunctionType,
RemoteFunction,
};
use crate::mdl::manifest::{Column, Manifest, Model, View};
use datafusion::datasource::TableProvider;
use datafusion::error::Result;
use datafusion::execution::context::SessionState;
use datafusion::logical_expr::sqlparser::keywords::ALL_KEYWORDS;
use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
use datafusion::prelude::SessionContext;
use datafusion::sql::unparser::dialect::{Dialect, IntervalStyle};
use datafusion::sql::unparser::Unparser;
Expand Down Expand Up @@ -227,24 +232,34 @@ impl WrenMDL {
}

/// Transform the SQL based on the MDL
pub fn transform_sql(analyzed_mdl: Arc<AnalyzedWrenMDL>, sql: &str) -> Result<String> {
pub fn transform_sql(
analyzed_mdl: Arc<AnalyzedWrenMDL>,
remote_functions: &[RemoteFunction],
sql: &str,
) -> Result<String> {
let runtime = tokio::runtime::Runtime::new().unwrap();
runtime.block_on(transform_sql_with_ctx(
&SessionContext::new(),
analyzed_mdl,
remote_functions,
sql,
))
}

/// Transform the SQL based on the MDL with the SessionContext
/// Wren engine will normalize the SQL to the lower case to solve the case sensitive
/// Wren engine will normalize the SQL to the lower case to solve the case-sensitive
/// issue for the Wren view
pub async fn transform_sql_with_ctx(
ctx: &SessionContext,
analyzed_mdl: Arc<AnalyzedWrenMDL>,
remote_functions: &[RemoteFunction],
sql: &str,
) -> Result<String> {
info!("wren-core received SQL: {}", sql);
remote_functions.iter().for_each(|remote_function| {
debug!("Registering remote function: {:?}", remote_function);
register_remote_function(ctx, remote_function);
});
let ctx = create_ctx_with_mdl(ctx, Arc::clone(&analyzed_mdl), false).await?;
let plan = ctx.state().create_logical_plan(sql).await?;
debug!("wren-core original plan:\n {plan}");
Expand All @@ -266,6 +281,29 @@ pub async fn transform_sql_with_ctx(
}
}

fn register_remote_function(ctx: &SessionContext, remote_function: &RemoteFunction) {
match &remote_function.function_type {
FunctionType::Scalar => {
ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new(
&remote_function.name,
map_data_type(&remote_function.return_type),
)))
}
FunctionType::Aggregate => {
ctx.register_udaf(AggregateUDF::new_from_impl(ByPassAggregateUDF::new(
&remote_function.name,
map_data_type(&remote_function.return_type),
)))
}
FunctionType::Window => {
ctx.register_udwf(WindowUDF::new_from_impl(ByPassWindowFunction::new(
&remote_function.name,
map_data_type(&remote_function.return_type),
)))
}
}
}

/// WrenDialect is a dialect for Wren engine. Handle the identifier quote style based on the
/// original Datafusion Dialect implementation but with more strict rules.
/// If the identifier isn't lowercase, it will be quoted.
Expand Down Expand Up @@ -294,29 +332,6 @@ fn non_lowercase(sql: &str) -> bool {
lowercase != sql
}

/// Apply Wren Rules to a given session context with a WrenMDL
///
/// TODO: There're some issue about apply the rule with the native optimize rules of datafusion
/// Recommend to use [transform_sql_with_ctx] generated the SQL text instead.
pub async fn apply_wren_rules(
ctx: &SessionContext,
analyzed_wren_mdl: Arc<AnalyzedWrenMDL>,
) -> Result<()> {
// expand the view should be the first rule
ctx.add_analyzer_rule(Arc::new(ExpandWrenViewRule::new(
Arc::clone(&analyzed_wren_mdl),
ctx.state_ref(),
)));
ctx.add_analyzer_rule(Arc::new(ModelAnalyzeRule::new(
Arc::clone(&analyzed_wren_mdl),
ctx.state_ref(),
)));
ctx.add_analyzer_rule(Arc::new(ModelGenerationRule::new(Arc::clone(
&analyzed_wren_mdl,
))));
register_table_with_mdl(ctx, analyzed_wren_mdl.wren_mdl()).await
}

/// Analyze the decision point. It's same as the /v1/analysis/sql API in wren engine
pub fn decision_point_analyze(_wren_mdl: Arc<WrenMDL>, _sql: &str) {}

Expand Down Expand Up @@ -344,8 +359,9 @@ mod test {
use std::sync::Arc;

use crate::mdl::builder::{ColumnBuilder, ManifestBuilder, ModelBuilder};
use crate::mdl::function::RemoteFunction;
use crate::mdl::manifest::Manifest;
use crate::mdl::{self, AnalyzedWrenMDL};
use crate::mdl::{self, transform_sql_with_ctx, AnalyzedWrenMDL};
use datafusion::arrow::array::{ArrayRef, Int64Array, RecordBatch, StringArray};
use datafusion::common::not_impl_err;
use datafusion::common::Result;
Expand All @@ -366,6 +382,7 @@ mod test {
let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(mdl)?);
let _ = mdl::transform_sql(
Arc::clone(&analyzed_mdl),
&vec![],
"select o_orderkey + o_orderkey from test.test.orders",
)?;
Ok(())
Expand Down Expand Up @@ -401,6 +418,7 @@ mod test {
let actual = mdl::transform_sql_with_ctx(
&SessionContext::new(),
Arc::clone(&analyzed_mdl),
&vec![],
sql,
)
.await?;
Expand Down Expand Up @@ -428,6 +446,7 @@ mod test {
let actual = mdl::transform_sql_with_ctx(
&SessionContext::new(),
Arc::clone(&analyzed_mdl),
&vec![],
sql,
)
.await?;
Expand Down Expand Up @@ -455,6 +474,7 @@ mod test {
let actual = mdl::transform_sql_with_ctx(
&SessionContext::new(),
Arc::clone(&analyzed_mdl),
&vec![],
sql,
)
.await?;
Expand All @@ -465,6 +485,68 @@ mod test {
Ok(())
}

#[tokio::test]
async fn test_remote_function() -> Result<()> {
env_logger::init();
let test_data: PathBuf =
[env!("CARGO_MANIFEST_DIR"), "tests", "data", "functions.csv"]
.iter()
.collect();
let ctx = SessionContext::new();
let functions = csv::Reader::from_path(test_data)
.unwrap()
.into_deserialize::<RemoteFunction>()
.filter_map(Result::ok)
.collect::<Vec<_>>();
dbg!(&functions);
let manifest = ManifestBuilder::new()
.catalog("CTest")
.schema("STest")
.model(
ModelBuilder::new("Customer")
.table_reference("datafusion.public.customer")
.column(ColumnBuilder::new("Custkey", "int").build())
.column(ColumnBuilder::new("Name", "string").build())
.build(),
)
.build();
let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?);
let actual = transform_sql_with_ctx(
&ctx,
Arc::clone(&analyzed_mdl),
&functions,
r#"select add_two("Custkey") from "Customer""#,
)
.await?;
assert_eq!(actual, "SELECT add_two(\"Customer\".\"Custkey\") FROM \
(SELECT datafusion.public.customer.\"Custkey\" AS \"Custkey\" FROM \
(SELECT datafusion.public.customer.\"Custkey\" FROM datafusion.public.customer)) AS \"Customer\"");

let actual = transform_sql_with_ctx(
&ctx,
Arc::clone(&analyzed_mdl),
&functions,
r#"select median("Custkey") from "CTest"."STest"."Customer" group by "Name""#,
)
.await?;
assert_eq!(actual, "SELECT median(\"Customer\".\"Custkey\") FROM \
(SELECT datafusion.public.customer.\"Custkey\" AS \"Custkey\", \
datafusion.public.customer.\"Name\" AS \"Name\" FROM \
(SELECT datafusion.public.customer.\"Custkey\", datafusion.public.customer.\"Name\" \
FROM datafusion.public.customer)) AS \"Customer\" GROUP BY \"Customer\".\"Name\"");

// TODO: support window functions analysis
// let actual = transform_sql_with_ctx(
// &ctx,
// Arc::clone(&analyzed_mdl),
// &functions,
// r#"select max_if("Custkey") OVER (PARTITION BY "Name") from "Customer""#,
// ).await?;
// assert_eq!(actual, "");

Ok(())
}

async fn assert_sql_valid_executable(sql: &str) -> Result<()> {
let ctx = SessionContext::new();
// To roundtrip testing, we should register the mock table for the planned sql.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ async fn main() -> Result<()> {
let transformed = match transform_sql_with_ctx(
&ctx,
Arc::clone(&analyzed_mdl),
vec![],
"select totalprice from wrenai.public.customers",
)
.await
Expand All @@ -104,6 +105,7 @@ async fn main() -> Result<()> {
let transformed = match transform_sql_with_ctx(
&ctx,
Arc::clone(&analyzed_mdl),
&vec![],
"select customer_state_cf from wrenai.public.order_items",
)
.await
Expand Down
2 changes: 1 addition & 1 deletion wren-modeling-rs/wren-example/examples/datafusion-apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async fn main() -> Result<()> {
// TODO: there're some issue for optimize rules
// let ctx = create_ctx_with_mdl(&ctx, analyzed_mdl).await?;
let sql = "select * from wrenai.public.order_items";
let sql = transform_sql_with_ctx(&ctx, analyzed_mdl, sql).await?;
let sql = transform_sql_with_ctx(&ctx, analyzed_mdl, &vec![], sql).await?;
println!("Wren engine generated SQL: \n{}", sql);
// create a plan to run a SQL query
let df = match ctx.sql(&sql).await {
Expand Down
3 changes: 2 additions & 1 deletion wren-modeling-rs/wren-example/examples/plan-sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ async fn main() -> datafusion::common::Result<()> {

let sql = "select customer_state from wrenai.public.orders_model";
println!("Original SQL: \n{}", sql);
let sql = transform_sql_with_ctx(&SessionContext::new(), analyzed_mdl, sql).await?;
let sql =
transform_sql_with_ctx(&SessionContext::new(), analyzed_mdl, vec![], sql).await?;
println!("Wren engine generated SQL: \n{}", sql);
Ok(())
}
Expand Down
5 changes: 3 additions & 2 deletions wren-modeling-rs/wren-example/examples/to-many-calculation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ use datafusion::prelude::{CsvReadOptions, SessionContext};
use wren_core::mdl::builder::{
ColumnBuilder, ManifestBuilder, ModelBuilder, RelationshipBuilder,
};
use wren_core::mdl::context::create_ctx_with_mdl;
use wren_core::mdl::manifest::{JoinType, Manifest};
use wren_core::mdl::{apply_wren_rules, AnalyzedWrenMDL};
use wren_core::mdl::AnalyzedWrenMDL;

#[tokio::main]
async fn main() -> Result<()> {
Expand Down Expand Up @@ -75,7 +76,7 @@ async fn main() -> Result<()> {
]);
let analyzed_mdl =
Arc::new(AnalyzedWrenMDL::analyze_with_tables(manifest, register)?);
apply_wren_rules(&ctx, analyzed_mdl).await?;
let ctx = create_ctx_with_mdl(&ctx, analyzed_mdl).await?;
let df = match ctx
.sql("select totalprice from wrenai.public.customers")
.await
Expand Down
3 changes: 2 additions & 1 deletion wren-modeling-rs/wren-example/examples/view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ async fn main() -> datafusion::common::Result<()> {

let sql = "select * from wrenai.public.customers_view";
println!("Original SQL: \n{}", sql);
let sql = transform_sql_with_ctx(&SessionContext::new(), analyzed_mdl, sql).await?;
let sql =
transform_sql_with_ctx(&SessionContext::new(), analyzed_mdl, vec![], sql).await?;
println!("Wren engine generated SQL: \n{}", sql);
Ok(())
}
Expand Down

0 comments on commit d42ba46

Please sign in to comment.