Skip to content

Commit

Permalink
implement python api
Browse files Browse the repository at this point in the history
  • Loading branch information
goldmedal committed Oct 16, 2024
1 parent d42ba46 commit b9dd669
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 13 deletions.
4 changes: 4 additions & 0 deletions ibis-server/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class Config:
def __init__(self):
load_dotenv(override=True)
self.wren_engine_endpoint = os.getenv("WREN_ENGINE_ENDPOINT")
self.remote_function_list_path = os.getenv("REMOTE_FUNCTION_LIST_PATH")
self.validate_wren_engine_endpoint(self.wren_engine_endpoint)
self.diagnose = False
self.init_logger()
Expand Down Expand Up @@ -57,6 +58,9 @@ def update(self, diagnose: bool):
else:
self.init_logger()

def set_remote_function_list_path(self, path: str):
    """Override the location of the remote function list file.

    The attribute is initially populated from the
    ``REMOTE_FUNCTION_LIST_PATH`` environment variable; this replaces
    that value for the lifetime of the config object.
    """
    setattr(self, "remote_function_list_path", path)


config = Config()

Expand Down
12 changes: 8 additions & 4 deletions ibis-server/app/mdl/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def __init__(
self.manifest_str = manifest_str
self.data_source = data_source
if experiment:
self._rewriter = EmbeddedEngineRewriter(manifest_str)
config = get_config()
function_path = config.remote_function_list_path
self._rewriter = EmbeddedEngineRewriter(manifest_str, function_path)
else:
self._rewriter = ExternalEngineRewriter(manifest_str)

Expand Down Expand Up @@ -68,14 +70,16 @@ def rewrite(self, sql: str) -> str:


class EmbeddedEngineRewriter:
def __init__(self, manifest_str: str):
def __init__(self, manifest_str: str, function_path: str):
self.manifest_str = manifest_str
self.function_path = function_path

def rewrite(self, sql: str) -> str:
from wren_core import transform_sql
from wren_core import read_remote_function_list, transform_sql

try:
return transform_sql(self.manifest_str, sql)
functions = read_remote_function_list(self.function_path)
return transform_sql(self.manifest_str, functions, sql)
except Exception as e:
raise RewriteError(str(e))

Expand Down
2 changes: 2 additions & 0 deletions ibis-server/app/routers/v2/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from fastapi import APIRouter, Depends, Query, Response
from fastapi.responses import JSONResponse

from app.config import get_config
from app.dependencies import verify_query_dto
from app.mdl.rewriter import Rewriter
from app.model import (
Expand All @@ -18,6 +19,7 @@
from app.util import to_json

router = APIRouter(prefix="/connector")
config = get_config()


@router.post("/{data_source}/query", dependencies=[Depends(verify_query_dto)])
Expand Down
26 changes: 26 additions & 0 deletions ibis-server/tests/routers/v3/connector/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fastapi.testclient import TestClient
from testcontainers.postgres import PostgresContainer

from app.config import get_config
from app.main import app
from app.model.validator import rules
from tests.confest import file_path
Expand Down Expand Up @@ -369,6 +370,31 @@ def test_dry_plan():
assert response.text is not None


def test_query_with_remote_function(postgres: PostgresContainer):
    """Query using a function supplied through the remote function list.

    Points the config at ``resource/functions.csv`` (presumably the file
    that declares ``unistr`` -- confirm against the fixture) and checks that
    a query calling ``unistr`` succeeds and returns a single column/row.
    """
    config = get_config()
    # The config object looks to be shared across the test process, so
    # remember the current path and restore it afterwards; otherwise this
    # override would leak into every test that runs after this one.
    previous_path = config.remote_function_list_path
    config.set_remote_function_list_path(file_path("resource/functions.csv"))
    try:
        connection_info = to_connection_info(postgres)
        response = client.post(
            url=f"{base_url}/query",
            json={
                "connectionInfo": connection_info,
                "manifestStr": manifest_str,
                "sql": "SELECT unistr(o_orderstatus) FROM wren.public.orders LIMIT 1",
            },
        )
    finally:
        config.set_remote_function_list_path(previous_path)

    assert response.status_code == 200
    result = response.json()
    assert len(result["columns"]) == 1
    assert len(result["data"]) == 1
    assert result["data"][0] == [
        "O",
    ]
    assert result["dtypes"] == {
        "unistr": "object",
    }


def to_connection_info(pg: PostgresContainer):
return {
"host": pg.get_container_host_ip(),
Expand Down
96 changes: 96 additions & 0 deletions wren-modeling-py/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions wren-modeling-py/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ wren-core = { path = "../wren-modeling-rs/core" }
base64 = "0.22.1"
serde_json = "1.0.117"
thiserror = "1.0"
csv = "1.3.0"
serde = { version = "1.0.210", features = ["derive"] }
env_logger = "0.11.5"
log = "0.4.22"

[build-dependencies]
pyo3-build-config = "0.21.2"
58 changes: 50 additions & 8 deletions wren-modeling-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,64 @@ use std::sync::Arc;
use base64::prelude::*;
use pyo3::prelude::*;

use crate::errors::CoreError;
use crate::remote_functions::RemoteFunction;
use log::debug;
use wren_core::mdl;
use wren_core::mdl::manifest::Manifest;
use wren_core::mdl::AnalyzedWrenMDL;

use crate::errors::CoreError;

mod errors;
mod remote_functions;

#[pyfunction]
fn transform_sql(mdl_base64: &str, sql: &str) -> Result<String, CoreError> {
fn transform_sql(
mdl_base64: &str,
remote_functions: Vec<RemoteFunction>,
sql: &str,
) -> Result<String, CoreError> {
let mdl_json_bytes = BASE64_STANDARD
.decode(mdl_base64)
.map_err(CoreError::from)?;
let mdl_json = String::from_utf8(mdl_json_bytes).map_err(CoreError::from)?;
let manifest = serde_json::from_str::<Manifest>(&mdl_json)?;
let remote_functions: Vec<mdl::function::RemoteFunction> = remote_functions
.into_iter()
.map(|f| f.into())
.collect::<Vec<_>>();

let Ok(analyzed_mdl) = AnalyzedWrenMDL::analyze(manifest) else {
return Err(CoreError::new("Failed to analyze manifest"));
};
match mdl::transform_sql(Arc::new(analyzed_mdl), sql) {
match mdl::transform_sql(Arc::new(analyzed_mdl), &remote_functions, sql) {
Ok(transformed_sql) => Ok(transformed_sql),
Err(e) => Err(CoreError::new(&e.to_string())),
}
}

#[pyfunction]
fn read_remote_function_list(path: Option<&str>) -> Vec<RemoteFunction> {
debug!(
"Reading remote function list from {}",
path.unwrap_or("path is not provided")
);
if let Some(path) = path {
csv::Reader::from_path(path)
.unwrap()
.into_deserialize::<RemoteFunction>()
.filter_map(Result::ok)
.collect::<Vec<_>>()
} else {
vec![]
}
}

/// Python module definition for `wren_core`.
///
/// Sets up logging and registers the functions exported to Python:
/// `transform_sql` and `read_remote_function_list`.
#[pymodule]
#[pyo3(name = "wren_core")]
fn wren_core_wrapper(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // `env_logger::init()` panics when a global logger is already installed
    // (e.g. by another extension module in the same process). `try_init`
    // reports that case as an error, which is safe to ignore here: logging
    // simply keeps going through the logger that is already in place.
    let _ = env_logger::try_init();
    m.add_function(wrap_pyfunction!(transform_sql, m)?)?;
    m.add_function(wrap_pyfunction!(read_remote_function_list, m)?)?;
    Ok(())
}

Expand All @@ -41,7 +70,7 @@ mod tests {
use base64::Engine;
use serde_json::Value;

use crate::transform_sql;
use crate::{read_remote_function_list, transform_sql};

#[test]
fn test_transform_sql() {
Expand All @@ -66,13 +95,26 @@ mod tests {
}"#;
let v: Value = serde_json::from_str(data).unwrap();
let mdl_base64: String = BASE64_STANDARD.encode(v.to_string().as_bytes());
let transformed_sql =
transform_sql(&mdl_base64, "SELECT * FROM my_catalog.my_schema.customer")
.unwrap();
let transformed_sql = transform_sql(
&mdl_base64,
vec![],
"SELECT * FROM my_catalog.my_schema.customer",
)
.unwrap();
assert_eq!(
transformed_sql,
"SELECT customer.c_custkey, customer.c_name FROM \
(SELECT main.customer.c_custkey AS c_custkey, main.customer.c_name AS c_name FROM main.customer) AS customer"
);
}

/// The sample CSV yields three functions; no path yields none.
#[test]
fn test_read_remote_function_list() {
    // Reading the sample file shipped with the tests.
    let from_file = read_remote_function_list(Some("tests/functions.csv"));
    assert_eq!(from_file.len(), 3);

    // Without a path the list is empty.
    let without_path = read_remote_function_list(None);
    assert_eq!(without_path.len(), 0);
}
}
13 changes: 12 additions & 1 deletion wren-modeling-py/tests/test_modeling_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,19 @@

def test_transform_sql():
sql = "SELECT * FROM my_catalog.my_schema.customer"
rewritten_sql = wren_core.transform_sql(manifest_str, sql)
rewritten_sql = wren_core.transform_sql(manifest_str, [], sql)
assert (
rewritten_sql
== 'SELECT customer.c_custkey, customer.c_name FROM (SELECT main.customer.c_custkey AS c_custkey, main.customer.c_name AS c_name FROM main.customer) AS customer'
)

def test_read_function_list():
    """Load the sample function list and use it when transforming SQL."""
    loaded = wren_core.read_remote_function_list("tests/functions.csv")
    assert len(loaded) == 3

    sql = "SELECT add_two(c_custkey) FROM my_catalog.my_schema.customer"
    expected_sql = 'SELECT add_two(customer.c_custkey) FROM (SELECT customer.c_custkey FROM (SELECT main.customer.c_custkey AS c_custkey FROM main.customer) AS customer) AS customer'
    assert wren_core.transform_sql(manifest_str, loaded, sql) == expected_sql

    # Passing no path yields an empty function list.
    empty = wren_core.read_remote_function_list(None)
    assert len(empty) == 0

0 comments on commit b9dd669

Please sign in to comment.