Skip to content

Commit

Permalink
Refactor the model planner and fix view planning (#727)
Browse files Browse the repository at this point in the history
* fix Wren view

* enable and enhance tests

* fix clippy

* cargo fmt

* fix ibis test
  • Loading branch information
goldmedal authored Aug 2, 2024
1 parent 9111d51 commit 39e2025
Show file tree
Hide file tree
Showing 8 changed files with 315 additions and 232 deletions.
2 changes: 1 addition & 1 deletion ibis-server/tests/routers/v3/connector/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def test_dry_plan():
assert response.status_code == 200
assert (
response.text
== "\"SELECT orders.orderkey, orders.order_cust_key FROM (SELECT orders.order_cust_key, orders.orderkey FROM (SELECT CONCAT(orders.o_orderkey, '_', orders.o_custkey) AS order_cust_key, orders.o_orderkey AS orderkey FROM public.orders) AS orders) AS orders LIMIT 1\""
== "\"SELECT orders.orderkey, orders.order_cust_key FROM (SELECT orders.order_cust_key, orders.orderkey FROM (SELECT CONCAT(public.orders.o_orderkey, '_', public.orders.o_custkey) AS order_cust_key, public.orders.o_orderkey AS orderkey FROM public.orders) AS orders) AS orders LIMIT 1\""
)


Expand Down
4 changes: 4 additions & 0 deletions wren-modeling-rs/core/src/logical_plan/analyze/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ impl ModelPlanNode {
original_table_scan,
)
}

/// Returns the name stored in this node's `plan_name` field, borrowed as a
/// string slice for the lifetime of `&self`.
pub fn plan_name(&self) -> &str {
&self.plan_name
}
}

/// The builder of [ModelPlanNode] to build the plan for the model.
Expand Down
436 changes: 247 additions & 189 deletions wren-modeling-rs/core/src/logical_plan/analyze/rule.rs

Large diffs are not rendered by default.

5 changes: 1 addition & 4 deletions wren-modeling-rs/core/src/mdl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@ use datafusion::logical_expr::Expr;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;

use crate::logical_plan::analyze::rule::{
ModelAnalyzeRule, ModelGenerationRule, RemoveWrenPrefixRule,
};
use crate::logical_plan::analyze::rule::{ModelAnalyzeRule, ModelGenerationRule};
use crate::logical_plan::utils::create_schema;
use crate::mdl::manifest::Model;
use crate::mdl::{AnalyzedWrenMDL, WrenMDL};
Expand All @@ -34,7 +32,6 @@ pub async fn create_ctx_with_mdl(
ctx.state_ref(),
)),
Arc::new(ModelGenerationRule::new(Arc::clone(&analyzed_mdl))),
Arc::new(RemoveWrenPrefixRule::new(Arc::clone(&analyzed_mdl))),
])
// TODO: there are some issues with the optimizer rules, so they are disabled for now.
.with_optimizer_rules(vec![]);
Expand Down
71 changes: 41 additions & 30 deletions wren-modeling-rs/core/src/mdl/mod.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
use crate::logical_plan::utils::from_qualified_name_str;
use crate::mdl::context::{create_ctx_with_mdl, register_table_with_mdl};
use crate::mdl::manifest::{Column, Manifest, Model};
use std::{collections::HashMap, sync::Arc};

use datafusion::execution::context::SessionState;
use datafusion::prelude::SessionContext;
use datafusion::{error::Result, sql::unparser::plan_to_sql};
use log::{debug, info};
use manifest::Relationship;
use parking_lot::RwLock;
use std::{collections::HashMap, sync::Arc};

pub use dataset::Dataset;
use manifest::Relationship;

use crate::logical_plan::analyze::rule::{ModelAnalyzeRule, ModelGenerationRule};
use crate::logical_plan::utils::from_qualified_name_str;
use crate::mdl::context::{create_ctx_with_mdl, register_table_with_mdl};
use crate::mdl::manifest::{Column, Manifest, Model};

pub mod builder;
pub mod context;
Expand All @@ -16,11 +21,6 @@ pub mod lineage;
pub mod manifest;
pub mod utils;

use crate::logical_plan::analyze::rule::{
ModelAnalyzeRule, ModelGenerationRule, RemoveWrenPrefixRule,
};
pub use dataset::Dataset;

pub type SessionStateRef = Arc<RwLock<SessionState>>;

pub struct AnalyzedWrenMDL {
Expand Down Expand Up @@ -195,6 +195,11 @@ pub async fn transform_sql_with_ctx(
analyzed_mdl: Arc<AnalyzedWrenMDL>,
sql: &str,
) -> Result<String> {
let catalog_schema = format!(
"{}.{}.",
analyzed_mdl.wren_mdl().catalog(),
analyzed_mdl.wren_mdl().schema()
);
info!("wren-core received SQL: {}", sql);
let ctx = create_ctx_with_mdl(ctx, analyzed_mdl).await?;
let plan = ctx.state().create_logical_plan(sql).await?;
Expand All @@ -205,8 +210,10 @@ pub async fn transform_sql_with_ctx(
// show the planned sql
match plan_to_sql(&analyzed) {
Ok(sql) => {
info!("wren-core planned SQL: {}", sql.to_string());
Ok(sql.to_string())
// TODO: workaround that strips the MDL's `<catalog>.<schema>.` prefix from the generated SQL
let replaced = sql.to_string().replace(&catalog_schema, "");
info!("wren-core planned SQL: {}", replaced);
Ok(replaced)
}
Err(e) => Err(e),
}
Expand All @@ -224,9 +231,6 @@ pub async fn apply_wren_rules(
ctx.add_analyzer_rule(Arc::new(ModelGenerationRule::new(Arc::clone(
&analyzed_wren_mdl,
))));
ctx.add_analyzer_rule(Arc::new(RemoveWrenPrefixRule::new(Arc::clone(
&analyzed_wren_mdl,
))));
register_table_with_mdl(ctx, analyzed_wren_mdl.wren_mdl()).await
}

Expand Down Expand Up @@ -256,13 +260,12 @@ mod test {
use std::path::PathBuf;
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray};
use datafusion::arrow::array::{ArrayRef, Int64Array, RecordBatch, StringArray};
use datafusion::common::not_impl_err;
use datafusion::common::Result;
use datafusion::prelude::SessionContext;
use datafusion::sql::unparser::plan_to_sql;

use crate::mdl::context::create_ctx_with_mdl;
use crate::mdl::manifest::Manifest;
use crate::mdl::{self, AnalyzedWrenMDL};

Expand Down Expand Up @@ -317,8 +320,8 @@ mod test {
sql,
)
.await?;
let after_roundtrip = plan_sql(&actual, Arc::clone(&analyzed_mdl)).await?;
println!("After roundtrip: {}", after_roundtrip);
println!("After transform: {}", actual);
assert_sql_valid_executable(&actual).await?;
}

Ok(())
Expand All @@ -344,25 +347,33 @@ mod test {
sql,
)
.await?;
let after_roundtrip = plan_sql(&actual, Arc::clone(&analyzed_mdl)).await?;
println!("After roundtrip: {}", after_roundtrip);
assert_sql_valid_executable(&actual).await?;
Ok(())
}

async fn plan_sql(sql: &str, analyzed_mdl: Arc<AnalyzedWrenMDL>) -> Result<String> {
let ctx = create_ctx_with_mdl(&SessionContext::new(), analyzed_mdl).await?;
async fn assert_sql_valid_executable(sql: &str) -> Result<()> {
let ctx = SessionContext::new();
// For roundtrip testing, we must register the mock tables for the planned SQL.
ctx.register_batch("orders", orders())?;
ctx.register_batch("customer", customer())?;
ctx.register_batch("profile", profile())?;
let plan = ctx.state().create_logical_plan(sql).await?;
let df = ctx.sql(sql).await?;
let plan = df.into_optimized_plan()?;
// show the planned sql
plan_to_sql(&plan).map(|sql| sql.to_string())
let after_roundtrip = plan_to_sql(&plan).map(|sql| sql.to_string())?;
println!("After roundtrip: {}", after_roundtrip);
match ctx.sql(sql).await?.collect().await {
Ok(_) => Ok(()),
Err(e) => {
eprintln!("Error: {e}");
Err(e)
}
}
}

/// Return a RecordBatch with made up data about customer
fn customer() -> RecordBatch {
let custkey: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
let custkey: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
let name: ArrayRef =
Arc::new(StringArray::from_iter_values(["Gura", "Azki", "Ina"]));
RecordBatch::try_from_iter(vec![("c_custkey", custkey), ("c_name", name)])
Expand All @@ -371,7 +382,7 @@ mod test {

/// Return a RecordBatch with made up data about profile
fn profile() -> RecordBatch {
let custkey: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
let custkey: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
let phone: ArrayRef = Arc::new(StringArray::from_iter_values([
"123456", "234567", "345678",
]));
Expand All @@ -386,9 +397,9 @@ mod test {

/// Return a RecordBatch with made up data about orders
fn orders() -> RecordBatch {
let orderkey: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
let custkey: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
let totalprice: ArrayRef = Arc::new(Int32Array::from(vec![100, 200, 300]));
let orderkey: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
let custkey: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
let totalprice: ArrayRef = Arc::new(Int64Array::from(vec![100, 200, 300]));
RecordBatch::try_from_iter(vec![
("o_orderkey", orderkey),
("o_custkey", custkey),
Expand Down
12 changes: 5 additions & 7 deletions wren-modeling-rs/sqllogictest/src/test_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,13 +254,11 @@ async fn register_ecommerce_mdl(
.statement("select * from wrenai.public.customers")
.build(),
)
// TODO: support expression without alias inside view
// .view(ViewBuilder::new("revenue_orders").statement("select order_id, sum(price) from wrenai.public.order_items group by order_id").build())
// TODO: fix view with calculation
// .view(
// ViewBuilder::new("revenue_orders")
// .statement("select order_id, sum(price) as totalprice from wrenai.public.order_items group by order_id")
// .build())
.view(ViewBuilder::new("revenue_orders").statement("select order_id, sum(price) from wrenai.public.order_items group by order_id").build())
.view(
ViewBuilder::new("revenue_orders_alias")
.statement("select order_id as order_id, sum(price) as totalprice from wrenai.public.order_items group by order_id")
.build())
.build();
let mut register_tables = HashMap::new();
register_tables.insert(
Expand Down
11 changes: 10 additions & 1 deletion wren-modeling-rs/sqllogictest/test_sql_files/model.slt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@ statement ok
select cast(freight_value as int) + cast(price as int) from wrenai.public.order_items

statement ok
select product_id from wrenai.public.order_items where cast(freight_value as double) > 10
select product_id from wrenai.public.order_items where cast(freight_value as double) > 10.0

statement ok
select product_id from wrenai.public.order_items where freight_value > 10.0

statement ok
select product_id, min(price) from wrenai.public.order_items group by product_id

statement ok
select product_id, min(price) from wrenai.public.order_items where freight_value > 10.0 group by product_id

statement ok
select * from wrenai.public.order_items;
Expand Down
6 changes: 6 additions & 0 deletions wren-modeling-rs/sqllogictest/test_sql_files/view.slt
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
statement ok
SELECT * FROM wrenai.public.customer_view

statement ok
SELECT * FROM wrenai.public.revenue_orders

statement ok
SELECT * FROM wrenai.public.revenue_orders_alias

#query TR
#SELECT totalprice FROM wrenai.public.revenue_orders where order_id = '76754c0e642c8f99a8c3fcb8a14ac700'
#----
Expand Down

0 comments on commit 39e2025

Please sign in to comment.