Skip to content

Commit

Permalink
Feature: PADW-53 Accept New Tables (#6)
Browse files Browse the repository at this point in the history
* Accepting new tables: Incremental LLM requests focusing independently on business key attributes, naming, descriptor categorization (e.g., PII), and providing reasons with confidence scores to mitigate hallucinations and improve control.

* Remove excessive logging

* Enable transformer retries with hints if the appropriate JSON structure is not returned.
  • Loading branch information
analyzer1 authored Sep 10, 2024
1 parent 22eeaab commit 0907b7c
Show file tree
Hide file tree
Showing 5 changed files with 635 additions and 137 deletions.
1 change: 1 addition & 0 deletions extension/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ tokio = { version = "1", features = ["full"] }
uuid = { version = "1.1", features = ["v4", "v5", "serde"] }
chrono = { version = "0.4", features = ["serde"] }
anyhow = "1.0"
regex = "1.7"

[dev-dependencies]
pgrx-tests = "=0.11.4"
Expand Down
311 changes: 257 additions & 54 deletions extension/src/controller/bgw_transformer_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@ use pgrx::bgworkers::*;
use pgrx::prelude::*;

use std::time::Duration;
use std::collections::HashMap;
use tokio::runtime::Runtime;
use serde::de::DeserializeOwned;
use serde_json::from_value;
use serde::Deserialize;

use crate::queries;
use crate::model::source_objects;
use crate::utility::ollama_client;
use crate::utility::guc;
use regex::Regex;

const MAX_TRANSFORMER_RETRIES: u8 = 3; // TODO: Set in GUC

#[pg_guard]
#[no_mangle]
Expand All @@ -20,10 +24,8 @@ pub extern "C" fn background_worker_transformer_client(_arg: pg_sys::Datum) {
// Initialize Tokio runtime
let runtime = Runtime::new().expect("Failed to create Tokio runtime");


while BackgroundWorker::wait_latch(Some(Duration::from_secs(10))) {


// Load Prompts into Results
let result: Result<Vec<source_objects::SourceTablePrompt>, pgrx::spi::Error> = BackgroundWorker::transaction(|| {
Spi::connect(|client| {
Expand All @@ -40,7 +42,6 @@ pub extern "C" fn background_worker_transformer_client(_arg: pg_sys::Datum) {
table_column_links: table_column_links,
table_details: table_details
};

v_source_table_prompts.push(source_table_prompt)
}
Ok(v_source_table_prompts)
Expand All @@ -57,57 +58,209 @@ pub extern "C" fn background_worker_transformer_client(_arg: pg_sys::Datum) {

let table_column_link_json_str = serde_json::to_string_pretty(&source_table_prompt.table_column_links).expect("Failed to convert JSON Column Links to pretty string");
let table_column_links_o: Option<source_objects::TableLinks> = serde_json::from_str(&table_column_link_json_str).ok();

let columns = extract_column_numbers(&table_details_json_str);


// Define generation_json_o outside the runtime.block_on block
let mut generation_json_o: Option<serde_json::Value> = None;

// Run the async block
runtime.block_on(async {
// Get Generation
generation_json_o = match ollama_client::send_request(table_details_json_str.as_str()).await {
Ok(response_json) => {
// log!("Transformer client request successful. {:?}", response_json);
Some(response_json)
},

// Identity BK Ordinal Location
let mut generation_json_bk_identification: Option<serde_json::Value> = None;
let mut identified_business_key_opt: Option<IdentifiedBusinessKey> = None;
let mut retries = 0;
let mut hints = String::new();
while retries < MAX_TRANSFORMER_RETRIES {
runtime.block_on(async {
// Get Generation
generation_json_bk_identification = match ollama_client::send_request(table_details_json_str.as_str(), ollama_client::PromptTemplate::BKIdentification, &0, &hints).await {
Ok(mut response_json) => {

// TODO: Add a function to enable logging.
// let response_json_pretty = serde_json::to_string_pretty(&response_json)
// .expect("Failed to convert Response JSON to Pretty String.");
Some(response_json)
},
Err(e) => {
log!("Error in Ollama client request: {}", e);
hints = format!("Hint: Please ensure you provide a JSON response only. This is your {} attempt.", retries + 1);
None
}
};
});
// let identified_business_key: IdentifiedBusinessKey = serde_json::from_value(generation_json_bk_identification.unwrap()).expect("Not valid JSON");

match serde_json::from_value::<IdentifiedBusinessKey>(generation_json_bk_identification.clone().unwrap()) {
Ok(bk) => {
identified_business_key_opt = Some(bk);
break; // Successfully Decoded
}
Err(e) => {
log!("Error in Ollama client request: {}", e);
None
log!("Error JSON JSON Structure not of type IdentifiedBusinessKey: {}", e);
hints = format!("Hint: Please ensure the correct JSON key pair structure is given. Previously you gave a response but it errored. Error: {e}. Please try again.");
}
};
});
}
retries += 1;
}

let identified_business_key = match identified_business_key_opt {
Some(bk) => bk,
None => panic!("Failed to identify business key after {} retries", retries),
};

// Identity BK Name
let mut generation_json_bk_name: Option<serde_json::Value> = None;
let mut business_key_name_opt: Option<BusinessKeyName> = None;
let mut retries = 0;
let mut hints = String::new();
while retries < MAX_TRANSFORMER_RETRIES {
runtime.block_on(async {
// Get Generation
generation_json_bk_name = match ollama_client::send_request(table_details_json_str.as_str(), ollama_client::PromptTemplate::BKName, &0, &hints).await {
Ok(mut response_json) => {

// let response_json_pretty = serde_json::to_string_pretty(&response_json)
// .expect("Failed to convert Response JSON to Pretty String.");
Some(response_json)
},
Err(e) => {
log!("Error in Ollama client request: {}", e);
hints = format!("Hint: Please ensure you provide a JSON response only. This is your {} attempt.", retries + 1);
None
}
};
});

match serde_json::from_value::<BusinessKeyName>(generation_json_bk_name.clone().unwrap()) {
Ok(bk) => {
business_key_name_opt = Some(bk);
break; // Successfully Decoded
}
Err(e) => {
log!("Error JSON JSON Structure not of type BusinessKeyName: {}", e);
}
}
retries += 1;
}

let business_key_name = match business_key_name_opt {
Some(bk) => bk,
None => panic!("Failed to identify business key name after {} retries", retries),
};

let generation_table_detail_o: Option<source_objects::GenerationTableDetail> = deserialize_option(generation_json_o);
// Identity Descriptor - Sensitive
// let mut generation_json_descriptors_sensitive: HashMap<&u32, Option<serde_json::Value>> = HashMap::new();
let mut descriptors_sensitive: HashMap<&u32, DescriptorSensitive> = HashMap::new();
let mut generation_json_descriptor_sensitive: Option<serde_json::Value> = None;
for column in &columns {
let mut retries = 0;
let mut hints = String::new();
while retries < MAX_TRANSFORMER_RETRIES {
// Run the async block
runtime.block_on(async {
// Get Generation
generation_json_descriptor_sensitive =
match ollama_client::send_request(
table_details_json_str.as_str(),
ollama_client::PromptTemplate::DescriptorSensitive,
column,
&hints).await {
Ok(mut response_json) => {

// let response_json_pretty = serde_json::to_string_pretty(&response_json)
// .expect("Failed to convert Response JSON to Pretty String.");

Some(response_json)
},
Err(e) => {
log!("Error in Ollama client request: {}", e);
hints = format!("Hint: Please ensure you provide a JSON response only. This is your {} attempt.", retries + 1);
None
}
};
// generation_json_descriptors_sensitive.insert(column, generation_json_descriptor_sensitive);
});

match serde_json::from_value::<DescriptorSensitive>(generation_json_descriptor_sensitive.clone().unwrap()) {
Ok(des) => {
// business_key_name_opt = Some(des);
descriptors_sensitive.insert(column, des);
break; // Successfully Decoded
}
Err(e) => {
log!("Error JSON JSON Structure not of type DescriptorSensitive: {}", e);
}
}

retries += 1;
}
}

let table_column_links = table_column_links_o.unwrap();
let generation_table_detail = generation_table_detail_o.unwrap();

// Build the SQL INSERT statement
// Build the SQL INSERT statement
let mut insert_sql = String::from("INSERT INTO auto_dw.transformer_responses (fk_source_objects, model_name, category, business_key_name, confidence_score, reason) VALUES ");

for (index, column_link) in table_column_links.column_links.iter().enumerate() {

let not_last = index != table_column_links.column_links.len() - 1;

let index_o = generation_table_detail.response_column_details.iter().position(|r| r.column_no == column_link.column_ordinal_position);
match index_o {
Some(index) => {
let column_detail = &generation_table_detail.response_column_details[index];

let category = &column_detail.category.replace("'", "''");
let business_key_name = &column_detail.business_key_name.replace("'", "''");
let confidence_score = &column_detail.confidence;
let reason = &column_detail.reason.replace("'", "''");
let pk_source_objects = column_link.pk_source_objects;

let model_name = "Mixtral";

if not_last {
insert_sql.push_str(&format!("({}, '{}', '{}', '{}', {}, '{}'),", pk_source_objects, model_name, category, business_key_name, confidence_score, reason));
} else {
insert_sql.push_str(&format!("({}, '{}', '{}', '{}', {}, '{}');", pk_source_objects, model_name, category, business_key_name, confidence_score, reason));
for (index, column) in columns.iter().enumerate() {

let last = {index == table_column_links.column_links.len() - 1};

if column == &identified_business_key.identified_business_key_values.column_no {

let category = "Business Key Part";
let confidence_score = identified_business_key.identified_business_key_values.confidence_value * business_key_name.business_key_name_values.confidence_value;
let bk_name = &business_key_name.business_key_name_values.name;
let bk_identified_reason = &identified_business_key.identified_business_key_values.reason;
let bk_name_reason = &business_key_name.business_key_name_values.reason;
let reason = format!("BK Identified Reason: {}, BK Naming Reason: {}", bk_identified_reason, bk_name_reason);
let model_name_owned = guc::get_guc(guc::PgAutoDWGuc::Model).expect("MODEL GUC is not set.");
let model_name = model_name_owned.as_str();

let pk_source_objects: i32;
if let Some(pk_source_objects_temp) = table_column_links.find_pk_source_objects(column.clone() as i32) {
pk_source_objects = pk_source_objects_temp;
} else {
println!("No match found for column_ordinal_position: {}", column);
panic!()
}

if !last {
insert_sql.push_str(&format!("({}, '{}', '{}', '{}', {}, '{}'),", pk_source_objects, model_name, category, bk_name.replace(" ", "_"), confidence_score, reason.replace("'", "''")));
} else {
insert_sql.push_str(&format!("({}, '{}', '{}', '{}', {}, '{}');", pk_source_objects, model_name, category, bk_name.replace(" ", "_"), confidence_score, reason.replace("'", "''")));
}

} else {

let pk_source_objects: i32;
let mut category = "Descriptor";
let mut confidence_score: f64 = 1.0;
let bk_name = "NA";
let mut reason = "Defaulted of category 'Descriptor' maintained.".to_string();
let model_name_owned = guc::get_guc(guc::PgAutoDWGuc::Model).expect("MODEL GUC is not set.");
let model_name = model_name_owned.as_str();


if let Some(pk_source_objects_temp) = table_column_links.find_pk_source_objects(column.clone() as i32) {
pk_source_objects = pk_source_objects_temp;
} else {
println!("No match found for column_ordinal_position: {}", column);
panic!()
}

if let Some(descriptor_sensitive) = descriptors_sensitive.get(&column) {
if descriptor_sensitive.descriptor_sensitive_values.is_pii && (descriptor_sensitive.descriptor_sensitive_values.confidence_value > 0.5) {
category = "Descriptor - Sensitive";
confidence_score = descriptor_sensitive.descriptor_sensitive_values.confidence_value;
reason = descriptor_sensitive.descriptor_sensitive_values.reason.clone();
}
} else {
log!("Teseting Can't find a response for {} in Descriptors Sensitive Hashmap.", column);
}

if !last {
insert_sql.push_str(&format!("({}, '{}', '{}', '{}', {}, '{}'),", pk_source_objects, model_name, category, bk_name.replace(" ", "_"), confidence_score, reason.replace("'", "''")));
} else {
insert_sql.push_str(&format!("({}, '{}', '{}', '{}', {}, '{}');", pk_source_objects, model_name, category, bk_name.replace(" ", "_"), confidence_score, reason.replace("'", "''")));
}
None => {break;}
}
}

Expand All @@ -117,16 +270,66 @@ pub extern "C" fn background_worker_transformer_client(_arg: pg_sys::Datum) {
_ = client.update(insert_sql.as_str(), None, None);
})
});
}
}

}
}

/// Turns an optional JSON value into an optional `T`.
///
/// Returns `None` when `json_option` is `None`, and also when the contained
/// value cannot be deserialized into `T`; the deserialization error itself is
/// discarded.
fn deserialize_option<T>(json_option: Option<serde_json::Value>) -> Option<T>
where
    T: DeserializeOwned,
{
    // Nothing to decode: propagate the absence straight away.
    let raw_value = json_option?;

    match from_value::<T>(raw_value) {
        Ok(decoded) => Some(decoded),
        Err(_) => None,
    }
}
/// Collects every column number announced as `Column No: <digits>` in `json_str`.
///
/// Numbers are returned in the order they occur in the text; repeated numbers
/// are kept. A digit run that cannot be parsed as a `u32` is skipped instead of
/// panicking. That covers values too large for `u32` as well as non-ASCII
/// digits, which the `\d` class accepts by default.
///
/// # Panics
///
/// Only if the hard-coded pattern fails to compile.
fn extract_column_numbers(json_str: &str) -> Vec<u32> {
    // Capture group 1 holds the digits that follow the literal "Column No: ".
    let re = Regex::new(r"Column No: (\d+)").expect("Invalid regex");

    re.captures_iter(json_str)
        // `.ok()` drops unparsable captures; the previous `.unwrap()` aborted on them.
        .filter_map(|caps| caps.get(1)?.as_str().parse::<u32>().ok())
        .collect()
}

/// Outer JSON object of a business-key identification reply.
///
/// Deserialization expects a single entry keyed `"Identified Business Key"`
/// whose value carries the actual details.
#[derive(Deserialize, Debug)]
struct IdentifiedBusinessKey {
// Payload stored under the "Identified Business Key" JSON key.
#[serde(rename = "Identified Business Key")]
identified_business_key_values: IdentifiedBusinessKeyValues,
}

/// Details of the column identified as the business key.
///
/// Every field is mandatory; a reply missing any of the renamed JSON keys
/// fails to deserialize.
#[derive(Deserialize, Debug)]
struct IdentifiedBusinessKeyValues {
// Column number of the identified column (JSON key "Column No").
#[serde(rename = "Column No")]
column_no: u32,
// Confidence reported for the identification (JSON key "Confidence Value").
// NOTE(review): presumably in the range 0.0..=1.0; not validated here, confirm against the prompt.
#[serde(rename = "Confidence Value")]
confidence_value: f64,
// Free-text justification for the choice (JSON key "Reason").
#[serde(rename = "Reason")]
reason: String,
}

/// Outer JSON object of a business-key naming reply.
///
/// Deserialization expects a single entry keyed `"Business Key Name"` whose
/// value carries the actual details.
#[derive(Deserialize, Debug)]
struct BusinessKeyName {
// Payload stored under the "Business Key Name" JSON key.
#[serde(rename = "Business Key Name")]
business_key_name_values: BusinessKeyNameValues,
}

/// Details of the name proposed for the business key.
///
/// Every field is mandatory; a reply missing any of the renamed JSON keys
/// fails to deserialize.
#[derive(Deserialize, Debug)]
struct BusinessKeyNameValues {
// Proposed business key name (JSON key "Name").
#[serde(rename = "Name")]
name: String,
// Confidence reported for the proposed name (JSON key "Confidence Value").
// NOTE(review): presumably in the range 0.0..=1.0; not validated here, confirm against the prompt.
#[serde(rename = "Confidence Value")]
confidence_value: f64,
// Free-text justification for the name (JSON key "Reason").
#[serde(rename = "Reason")]
reason: String,
}

/// Outer JSON object of a sensitive-descriptor classification reply.
///
/// Deserialization expects a single entry keyed `"Descriptor - Sensitive"`
/// whose value carries the actual details.
#[derive(Deserialize, Debug)]
struct DescriptorSensitive {
// Payload stored under the "Descriptor - Sensitive" JSON key.
#[serde(rename = "Descriptor - Sensitive")]
descriptor_sensitive_values: DescriptorSensitiveValues,
}

/// Details of a sensitivity (PII) verdict.
///
/// Every field is mandatory; a reply missing any of the renamed JSON keys
/// fails to deserialize.
#[derive(Deserialize, Debug)]
struct DescriptorSensitiveValues {
// Whether the reply flags PII (JSON key "Is PII").
#[serde(rename = "Is PII")]
is_pii: bool,
// Confidence reported for the verdict (JSON key "Confidence Value").
// NOTE(review): presumably in the range 0.0..=1.0; not validated here, confirm against the prompt.
#[serde(rename = "Confidence Value")]
confidence_value: f64,
// Free-text justification for the verdict (JSON key "Reason").
#[serde(rename = "Reason")]
reason: String,
}

2 changes: 1 addition & 1 deletion extension/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ fn source_include( schema_pattern_include: &str,
}

#[pg_extern]
fn source_exlude( schema_pattern_exclude: &str,
fn source_exclude( schema_pattern_exclude: &str,
table_pattern_exclude: default!(Option<&str>, "NULL"),
column_pattern_exclude: default!(Option<&str>, "NULL")) -> &'static str {
let schema_pattern_include: &str = "a^";
Expand Down
Loading

0 comments on commit 0907b7c

Please sign in to comment.