Feature/PADW 65 GUC Based Transformer Server Type (#11)
* Integrate transformer_client to enable GUC-based switching between transformer server types (Ollama or OpenAI), making the transformer server configurable at runtime.

* Refactor TransformerServerType with FromStr implementation
- Implemented FromStr for TransformerServerType to simplify string parsing.
- Improved error handling by returning &'static str for invalid server types.

* Logging cleanup: removed ad hoc log! debug statements from the client modules.
analyzer1 authored Sep 27, 2024
1 parent 5d9de8b commit 978b73f
Showing 8 changed files with 59 additions and 33 deletions.
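The switch is driven by the new pg_auto_dw.transformer_server_type GUC (registered in extension/src/utility/guc.rs below, default "ollama", accepted values "ollama" and "openai"). As a rough sketch of the intended behaviour, a pgrx-style test along these lines could check the default and the runtime switch; the harness and assertions are illustrative and not part of this commit:

```rust
// Hypothetical pgrx test sketch (not part of this commit) exercising the
// GUC-based switch. The GUC name, its "ollama" default, and the accepted
// values come from guc.rs below; everything else is illustrative.
#[cfg(any(test, feature = "pg_test"))]
#[pgrx::pg_schema]
mod tests {
    use pgrx::prelude::*;

    #[pgrx::pg_test]
    fn transformer_server_type_guc_switches_at_runtime() {
        // Default registered in init_guc() is "ollama".
        let default = Spi::get_one::<String>(
            "SELECT current_setting('pg_auto_dw.transformer_server_type')",
        )
        .expect("SPI call failed");
        assert_eq!(default.as_deref(), Some("ollama"));

        // Switch to OpenAI for the current session, then read the setting back.
        Spi::run("SET pg_auto_dw.transformer_server_type = 'openai'")
            .expect("SET failed");
        let current = Spi::get_one::<String>(
            "SELECT current_setting('pg_auto_dw.transformer_server_type')",
        )
        .expect("SPI call failed");
        assert_eq!(current.as_deref(), Some("openai"));
    }
}
```

Because the GUC is registered with GucContext::Suset, only a superuser session can change it at runtime.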
14 changes: 4 additions & 10 deletions extension/src/controller/bgw_transformer_client.rs
@@ -7,8 +7,7 @@ use tokio::runtime::Runtime;
use serde::Deserialize;

use crate::model::*;
// use crate::utility::ollama_client;
use crate::utility::openai_client;
use crate::utility::transformer_client;
use crate::utility::guc;
use regex::Regex;

@@ -74,13 +73,8 @@ pub extern "C" fn background_worker_transformer_client(_arg: pg_sys::Datum) {
while retries < MAX_TRANSFORMER_RETRIES {
runtime.block_on(async {
// Get Generation
generation_json_bk_identification = match openai_client::send_request(table_details_json_str.as_str(), prompt_template::PromptTemplate::BKIdentification, &0, &hints).await {
generation_json_bk_identification = match transformer_client::send_request(table_details_json_str.as_str(), prompt_template::PromptTemplate::BKIdentification, &0, &hints).await {
Ok(response_json) => {

// TODO: Add a function to enable logging.
let response_json_pretty = serde_json::to_string_pretty(&response_json)
.expect("Failed to convert Response JSON to Pretty String.");
log!("Response: {}", response_json_pretty);
Some(response_json)
},
Err(e) => {
@@ -122,7 +116,7 @@ pub extern "C" fn background_worker_transformer_client(_arg: pg_sys::Datum) {
while retries < MAX_TRANSFORMER_RETRIES {
runtime.block_on(async {
// Get Generation
generation_json_bk_name = match openai_client::send_request(table_details_json_str.as_str(), prompt_template::PromptTemplate::BKName, &0, &hints).await {
generation_json_bk_name = match transformer_client::send_request(table_details_json_str.as_str(), prompt_template::PromptTemplate::BKName, &0, &hints).await {
Ok(response_json) => {

// let response_json_pretty = serde_json::to_string_pretty(&response_json)
@@ -171,7 +165,7 @@ pub extern "C" fn background_worker_transformer_client(_arg: pg_sys::Datum) {
runtime.block_on(async {
// Get Generation
generation_json_descriptor_sensitive =
match openai_client::send_request(
match transformer_client::send_request(
table_details_json_str.as_str(),
prompt_template::PromptTemplate::DescriptorSensitive,
column,
2 changes: 0 additions & 2 deletions extension/src/controller/dv_builder.rs
@@ -159,8 +159,6 @@ pub fn build_dv(build_id: &String, dv_objects_query: &str) {
dv_ddl_sql.push_str(&dv_business_key_ddl_sql);
}

// log!("DDL Full: {}", &dv_ddl_sql);

// Build Tables using DDL
Spi::connect( |mut client| {
_ = client.update(&dv_ddl_sql, None, None);
3 changes: 1 addition & 2 deletions extension/src/controller/dv_loader.rs
@@ -26,7 +26,6 @@ pub fn dv_load_schema_from_build_id(build_id: &String) -> Option<DVSchema> {
let deserialized_schema: Result<DVSchema, serde_json::Error> = serde_json::from_value(schema_json.0);
match deserialized_schema {
Ok(deserialized_schema) => {
// log!("Schema deserialized correctly: JSON{:?}", &deserialized_schema);
schema_result = Some(deserialized_schema);
},
Err(_) => {
@@ -55,7 +54,7 @@ pub fn dv_data_loader(dv_schema: &DVSchema) {

// Run SQL
let dv_dml = hub_dml + &sat_dml;
// log!("DV DML: {}", &dv_dml);

// Build Tables using DDL
Spi::connect( |mut client| {
// client.select(dv_objects_query, None, None);
18 changes: 16 additions & 2 deletions extension/src/utility/guc.rs
@@ -12,6 +12,11 @@ pub static PG_AUTO_DW_DATABASE_NAME: GucSetting<Option<&CStr>> = GucSetting::<Op
// Default not set, as this will make direct changes to the database
pub static PG_AUTO_DW_DW_SCHEMA: GucSetting<Option<&CStr>> = GucSetting::<Option<&CStr>>::new(None);

// Default set to Ollama
pub static PG_AUTO_DW_TRANSFORMER_SERVER_TYPE: GucSetting<Option<&CStr>> = GucSetting::<Option<&CStr>>::new(Some(unsafe {
CStr::from_bytes_with_nul_unchecked(b"ollama\0")
}));

// Default Transformer Server URL
pub static PG_AUTO_DW_TRANSFORMER_SERVER_URL: GucSetting<Option<&CStr>> = GucSetting::<Option<&CStr>>::new(Some(unsafe {
CStr::from_bytes_with_nul_unchecked(b"http://localhost:11434/api/generate\0")
@@ -25,8 +30,6 @@ pub static PG_AUTO_DW_MODEL: GucSetting<Option<&CStr>> = GucSetting::<Option<&CS
CStr::from_bytes_with_nul_unchecked(b"mistral\0")
}));



// Default confidence level value is 0.8
// pub static PG_AUTO_DW_CONFIDENCE_LEVEL: GucSetting<f64> = GucSetting::<f64>::new(0.8);

@@ -51,6 +54,15 @@ pub fn init_guc() {
GucFlags::default(),
);

GucRegistry::define_string_guc(
"pg_auto_dw.transformer_server_type",
"Transformer server type for the pg_auto_dw extension.",
"Specifies the server type used by the pg_auto_dw extension. Current available server types include, ollama and openai.",
&PG_AUTO_DW_TRANSFORMER_SERVER_TYPE,
GucContext::Suset,
GucFlags::default(),
);

GucRegistry::define_string_guc(
"pg_auto_dw.transformer_server_url",
"Transformer URL for the pg_auto_dw extension.",
@@ -96,6 +108,7 @@ pub fn init_guc() {
pub enum PgAutoDWGuc {
DatabaseName,
DwSchema,
TransformerServerType,
TransformerServerUrl,
TransformerServerToken,
Model,
@@ -108,6 +121,7 @@ pub fn get_guc(guc: PgAutoDWGuc) -> Option<String> {
let val = match guc {
PgAutoDWGuc::DatabaseName => PG_AUTO_DW_DATABASE_NAME.get(),
PgAutoDWGuc::DwSchema => PG_AUTO_DW_DW_SCHEMA.get(),
PgAutoDWGuc::TransformerServerType => PG_AUTO_DW_TRANSFORMER_SERVER_TYPE.get(),
PgAutoDWGuc::TransformerServerUrl => PG_AUTO_DW_TRANSFORMER_SERVER_URL.get(),
PgAutoDWGuc::TransformerServerToken => PG_AUTO_DW_TRANSFORMER_SERVER_TOKEN.get(),
PgAutoDWGuc::Model => PG_AUTO_DW_MODEL.get(),
5 changes: 3 additions & 2 deletions extension/src/utility/mod.rs
@@ -1,4 +1,5 @@
pub mod ollama_client;
pub mod openai_client;
pub mod transformer_client;
mod ollama_client;
mod openai_client;
pub mod setup;
pub mod guc;
4 changes: 0 additions & 4 deletions extension/src/utility/ollama_client.rs
@@ -5,8 +5,6 @@ use std::time::Duration;
use crate::utility::guc;
use crate::model::prompt_template::PromptTemplate;

use pgrx::prelude::*;

#[derive(Serialize, Debug)]
pub struct GenerateRequest {
pub model: String,
@@ -43,8 +41,6 @@ pub async fn send_request(new_json: &str, template_type: PromptTemplate, col: &u
.replace("{column_no}", &column_number)
.replace("{hints}", &hints);

log!("Prompt: {prompt}");

// GUC Values for the transformer server
let transformer_server_url = guc::get_guc(guc::PgAutoDWGuc::TransformerServerUrl).ok_or("GUC: Transformer Server URL is not set")?;
let model = guc::get_guc(guc::PgAutoDWGuc::Model).ok_or("MODEL GUC is not set.")?;
11 changes: 0 additions & 11 deletions extension/src/utility/openai_client.rs
@@ -4,7 +4,6 @@ use std::time::Duration;

use crate::utility::guc;
use crate::model::prompt_template::PromptTemplate;
use pgrx::prelude::*;

#[derive(Serialize, Debug)]
pub struct Request {
@@ -93,12 +92,6 @@ pub async fn send_request(new_json: &str, template_type: PromptTemplate, col: &u
response_format,
};

log!("Request URL: {}", transformer_server_url);
log!("Request Headers:");
// log!(" Authorization: Bearer {}", transformer_server_token);
log!(" Content-Type: application/json");
log!("Request Body: {}", serde_json::to_string(&request).unwrap());

let response = client
.post(&transformer_server_url) // Ensure this is updated to OpenAI's URL
.header("Authorization", format!("Bearer {}", transformer_server_token)) // Add Bearer token here
@@ -109,10 +102,6 @@ pub async fn send_request(new_json: &str, template_type: PromptTemplate, col: &u
.json::<Response>() // Await the response and parse it as JSON
.await?;

log!("Response: {}", serde_json::to_string(&response).unwrap());

// let response_json: serde_json::Value = serde_json::to_value(&response)?;

// Extract the content string
let content_str = &response
.choices
35 changes: 35 additions & 0 deletions extension/src/utility/transformer_client.rs
@@ -0,0 +1,35 @@
use crate::model::prompt_template::PromptTemplate;
use super::{guc, openai_client, ollama_client};
use TransformerServerType::{OpenAI, Ollama};
use std::str::FromStr;

pub enum TransformerServerType {
OpenAI,
Ollama
}

impl FromStr for TransformerServerType {
type Err = &'static str;

fn from_str(s: &str) -> Result<TransformerServerType, Self::Err> {
match s.to_lowercase().as_str() {
"openai" => Ok(OpenAI),
"ollama" => Ok(Ollama),
_ => Err("Invalid Transformer Server Type"),
}
}
}

pub async fn send_request(new_json: &str, template_type: PromptTemplate, col: &u32, hints: &str) -> Result<serde_json::Value, Box<dyn std::error::Error>> {

let transformer_server_type_str = guc::get_guc(guc::PgAutoDWGuc::TransformerServerType).ok_or("GUC: Transformer Server Type is not set.")?;

let transformer_server_type = transformer_server_type_str.parse::<TransformerServerType>()
.map_err(|e| format!("Error parsing Transformer Server Type: {}", e))?;

match transformer_server_type {
OpenAI => openai_client::send_request(new_json, template_type, col, hints).await,
Ollama => ollama_client::send_request(new_json, template_type, col, hints).await,
}
}
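Since mod.rs now makes ollama_client and openai_client private, callers reach both backends only through this module. A plain unit test can pin down the FromStr behaviour shown above; the test module below is an illustrative sketch, not part of the commit:

```rust
// Illustrative unit tests for the FromStr implementation above; the module
// and test names are hypothetical and not part of this commit.
#[cfg(test)]
mod tests {
    use super::TransformerServerType;
    use std::str::FromStr;

    #[test]
    fn parses_known_server_types_case_insensitively() {
        // Parsing lowercases the input, so mixed case is accepted.
        assert!(matches!(
            TransformerServerType::from_str("OpenAI"),
            Ok(TransformerServerType::OpenAI)
        ));
        assert!(matches!(
            "OLLAMA".parse::<TransformerServerType>(),
            Ok(TransformerServerType::Ollama)
        ));
    }

    #[test]
    fn rejects_unknown_server_types() {
        // Unknown values surface the &'static str error from from_str.
        assert_eq!(
            TransformerServerType::from_str("azure").err(),
            Some("Invalid Transformer Server Type")
        );
    }
}
```

matches! is used here because the enum derives neither Debug nor PartialEq.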
