From 978b73f44bbd4e261a3bf23b901e4be00a6303fe Mon Sep 17 00:00:00 2001 From: analyzer1 Date: Fri, 27 Sep 2024 10:28:55 -0400 Subject: [PATCH] Feature/PADW 65 GUC Based Transformer Server Type (#11) * Integrating transformer_client to enable GUC-based switching between different transformer server types (Ollama or OpenAI), enhancing flexibility for transformer server configuration. * Refactor TransformerServerType with FromStr implementation - Implemented FromStr for TransformerServerType to simplify string parsing. - Improved error handling by returning &'static str for invalid server types. * Logging Cleanup --- .../src/controller/bgw_transformer_client.rs | 14 +++----- extension/src/controller/dv_builder.rs | 2 -- extension/src/controller/dv_loader.rs | 3 +- extension/src/utility/guc.rs | 18 ++++++++-- extension/src/utility/mod.rs | 5 +-- extension/src/utility/ollama_client.rs | 4 --- extension/src/utility/openai_client.rs | 11 ------ extension/src/utility/transformer_client.rs | 35 +++++++++++++++++++ 8 files changed, 59 insertions(+), 33 deletions(-) create mode 100644 extension/src/utility/transformer_client.rs diff --git a/extension/src/controller/bgw_transformer_client.rs b/extension/src/controller/bgw_transformer_client.rs index 339c61e..8e2568e 100644 --- a/extension/src/controller/bgw_transformer_client.rs +++ b/extension/src/controller/bgw_transformer_client.rs @@ -7,8 +7,7 @@ use tokio::runtime::Runtime; use serde::Deserialize; use crate::model::*; -// use crate::utility::ollama_client; -use crate::utility::openai_client; +use crate::utility::transformer_client; use crate::utility::guc; use regex::Regex; @@ -74,13 +73,8 @@ pub extern "C" fn background_worker_transformer_client(_arg: pg_sys::Datum) { while retries < MAX_TRANSFORMER_RETRIES { runtime.block_on(async { // Get Generation - generation_json_bk_identification = match openai_client::send_request(table_details_json_str.as_str(), prompt_template::PromptTemplate::BKIdentification, &0, &hints).await { + generation_json_bk_identification = match transformer_client::send_request(table_details_json_str.as_str(), prompt_template::PromptTemplate::BKIdentification, &0, &hints).await { Ok(response_json) => { - - // TODO: Add a function to enable logging. - let response_json_pretty = serde_json::to_string_pretty(&response_json) - .expect("Failed to convert Response JSON to Pretty String."); - log!("Response: {}", response_json_pretty); Some(response_json) }, Err(e) => { @@ -122,7 +116,7 @@ pub extern "C" fn background_worker_transformer_client(_arg: pg_sys::Datum) { while retries < MAX_TRANSFORMER_RETRIES { runtime.block_on(async { // Get Generation - generation_json_bk_name = match openai_client::send_request(table_details_json_str.as_str(), prompt_template::PromptTemplate::BKName, &0, &hints).await { + generation_json_bk_name = match transformer_client::send_request(table_details_json_str.as_str(), prompt_template::PromptTemplate::BKName, &0, &hints).await { Ok(response_json) => { // let response_json_pretty = serde_json::to_string_pretty(&response_json) @@ -171,7 +165,7 @@ pub extern "C" fn background_worker_transformer_client(_arg: pg_sys::Datum) { runtime.block_on(async { // Get Generation generation_json_descriptor_sensitive = - match openai_client::send_request( + match transformer_client::send_request( table_details_json_str.as_str(), prompt_template::PromptTemplate::DescriptorSensitive, column, diff --git a/extension/src/controller/dv_builder.rs b/extension/src/controller/dv_builder.rs index 042eef4..9bc6a87 100644 --- a/extension/src/controller/dv_builder.rs +++ b/extension/src/controller/dv_builder.rs @@ -159,8 +159,6 @@ pub fn build_dv(build_id: &String, dv_objects_query: &str) { dv_ddl_sql.push_str(&dv_business_key_ddl_sql); } - // log!("DDL Full: {}", &dv_ddl_sql); - // Build Tables using DDL Spi::connect( |mut client| { _ = client.update(&dv_ddl_sql, None, None); diff --git a/extension/src/controller/dv_loader.rs b/extension/src/controller/dv_loader.rs index c2a8edc..498321f 100644 --- a/extension/src/controller/dv_loader.rs +++ b/extension/src/controller/dv_loader.rs @@ -26,7 +26,6 @@ pub fn dv_load_schema_from_build_id(build_id: &String) -> Option { let deserialized_schema: Result = serde_json::from_value(schema_json.0); match deserialized_schema { Ok(deserialized_schema) => { - // log!("Schema deserialized correctly: JSON{:?}", &deserialized_schema); schema_result = Some(deserialized_schema); }, Err(_) => { @@ -55,7 +54,7 @@ pub fn dv_data_loader(dv_schema: &DVSchema) { // Run SQL let dv_dml = hub_dml + &sat_dml; - // log!("DV DML: {}", &dv_dml); + // Build Tables using DDL Spi::connect( |mut client| { // client.select(dv_objects_query, None, None); diff --git a/extension/src/utility/guc.rs b/extension/src/utility/guc.rs index 2bd3838..75d7c79 100644 --- a/extension/src/utility/guc.rs +++ b/extension/src/utility/guc.rs @@ -12,6 +12,11 @@ pub static PG_AUTO_DW_DATABASE_NAME: GucSetting> = GucSetting::> = GucSetting::>::new(None); +// Default set to Ollama +pub static PG_AUTO_DW_TRANSFORMER_SERVER_TYPE: GucSetting> = GucSetting::>::new(Some(unsafe { + CStr::from_bytes_with_nul_unchecked(b"ollama\0") +})); + // Default Transformer Server URL pub static PG_AUTO_DW_TRANSFORMER_SERVER_URL: GucSetting> = GucSetting::>::new(Some(unsafe { CStr::from_bytes_with_nul_unchecked(b"http://localhost:11434/api/generate\0") @@ -25,8 +30,6 @@ pub static PG_AUTO_DW_MODEL: GucSetting> = GucSetting:: = GucSetting::::new(0.8); @@ -51,6 +54,15 @@ pub fn init_guc() { GucFlags::default(), ); + GucRegistry::define_string_guc( + "pg_auto_dw.transformer_server_type", + "Transformer server type for the pg_auto_dw extension.", + "Specifies the server type used by the pg_auto_dw extension. Current available server types include, ollama and openai.", + &PG_AUTO_DW_TRANSFORMER_SERVER_TYPE, + GucContext::Suset, + GucFlags::default(), + ); + GucRegistry::define_string_guc( "pg_auto_dw.transformer_server_url", "Transformer URL for the pg_auto_dw extension.", @@ -96,6 +108,7 @@ pub fn init_guc() { pub enum PgAutoDWGuc { DatabaseName, DwSchema, + TransformerServerType, TransformerServerUrl, TransformerServerToken, Model, @@ -108,6 +121,7 @@ pub fn get_guc(guc: PgAutoDWGuc) -> Option { let val = match guc { PgAutoDWGuc::DatabaseName => PG_AUTO_DW_DATABASE_NAME.get(), PgAutoDWGuc::DwSchema => PG_AUTO_DW_DW_SCHEMA.get(), + PgAutoDWGuc::TransformerServerType => PG_AUTO_DW_TRANSFORMER_SERVER_TYPE.get(), PgAutoDWGuc::TransformerServerUrl => PG_AUTO_DW_TRANSFORMER_SERVER_URL.get(), PgAutoDWGuc::TransformerServerToken => PG_AUTO_DW_TRANSFORMER_SERVER_TOKEN.get(), PgAutoDWGuc::Model => PG_AUTO_DW_MODEL.get(), diff --git a/extension/src/utility/mod.rs b/extension/src/utility/mod.rs index c089672..46f088e 100644 --- a/extension/src/utility/mod.rs +++ b/extension/src/utility/mod.rs @@ -1,4 +1,5 @@ -pub mod ollama_client; -pub mod openai_client; +pub mod transformer_client; +mod ollama_client; +mod openai_client; pub mod setup; pub mod guc; \ No newline at end of file diff --git a/extension/src/utility/ollama_client.rs b/extension/src/utility/ollama_client.rs index e5694a4..75af8bb 100644 --- a/extension/src/utility/ollama_client.rs +++ b/extension/src/utility/ollama_client.rs @@ -5,8 +5,6 @@ use std::time::Duration; use crate::utility::guc; use crate::model::prompt_template::PromptTemplate; -use pgrx::prelude::*; - #[derive(Serialize, Debug)] pub struct GenerateRequest { pub model: String, @@ -43,8 +41,6 @@ pub async fn send_request(new_json: &str, template_type: PromptTemplate, col: &u .replace("{column_no}", &column_number) .replace("{hints}", &hints); - log!("Prompt: {prompt}"); - // GUC Values for the transformer server let transformer_server_url = guc::get_guc(guc::PgAutoDWGuc::TransformerServerUrl).ok_or("GUC: Transformer Server URL is not set")?; let model = guc::get_guc(guc::PgAutoDWGuc::Model).ok_or("MODEL GUC is not set.")?; diff --git a/extension/src/utility/openai_client.rs b/extension/src/utility/openai_client.rs index 09f9083..bc1f41b 100644 --- a/extension/src/utility/openai_client.rs +++ b/extension/src/utility/openai_client.rs @@ -4,7 +4,6 @@ use std::time::Duration; use crate::utility::guc; use crate::model::prompt_template::PromptTemplate; -use pgrx::prelude::*; #[derive(Serialize, Debug)] pub struct Request { @@ -93,12 +92,6 @@ pub async fn send_request(new_json: &str, template_type: PromptTemplate, col: &u response_format, }; - log!("Request URL: {}", transformer_server_url); - log!("Request Headers:"); - // log!(" Authorization: Bearer {}", transformer_server_token); - log!(" Content-Type: application/json"); - log!("Request Body: {}", serde_json::to_string(&request).unwrap()); - let response = client .post(&transformer_server_url) // Ensure this is updated to OpenAI's URL .header("Authorization", format!("Bearer {}", transformer_server_token)) // Add Bearer token here @@ -109,10 +102,6 @@ pub async fn send_request(new_json: &str, template_type: PromptTemplate, col: &u .json::() // Await the response and parse it as JSON .await?; - log!("Response: {}", serde_json::to_string(&response).unwrap()); - - // let response_json: serde_json::Value = serde_json::to_value(&response)?; - // Extract the content string let content_str = &response .choices diff --git a/extension/src/utility/transformer_client.rs b/extension/src/utility/transformer_client.rs new file mode 100644 index 0000000..2b54ebd --- /dev/null +++ b/extension/src/utility/transformer_client.rs @@ -0,0 +1,35 @@ +use crate::model::prompt_template::PromptTemplate; +use super::{guc, openai_client, ollama_client}; +use TransformerServerType::{OpenAI, Ollama}; +use std::str::FromStr; + +pub enum TransformerServerType { + OpenAI, + Ollama +} + +impl FromStr for TransformerServerType { + type Err = &'static str; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "openai" => Ok(OpenAI), + "ollama" => Ok(Ollama), + _ => Err("Invalid Transformer Server Type"), + } + } +} + +pub async fn send_request(new_json: &str, template_type: PromptTemplate, col: &u32, hints: &str) -> Result> { + + let transformer_server_type_str = guc::get_guc(guc::PgAutoDWGuc::TransformerServerType).ok_or("GUC: Transformer Server Type is not set.")?; + + let transformer_server_type = transformer_server_type_str.parse::() + .map_err(|e| format!("Error parsing Transformer Server Type: {}", e))?; + + match transformer_server_type { + OpenAI => openai_client::send_request(new_json, template_type, col, hints).await, + Ollama => ollama_client::send_request(new_json, template_type, col, hints).await, + } +} +