diff --git a/core/Cargo.toml b/core/Cargo.toml index a941f7a..347e7da 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -21,6 +21,7 @@ lazy_static = "1.4.0" log = "0.4.21" ollama-rs = "=0.2.1" pgmq = "0.29" +pgrx = "=0.12.5" regex = "1.9.2" reqwest = {version = "0.11.18", features = ["json"] } serde = { version = "1.0.173", features = ["derive"] } diff --git a/core/src/types.rs b/core/src/types.rs index 3828a1a..8b3f86f 100644 --- a/core/src/types.rs +++ b/core/src/types.rs @@ -1,5 +1,6 @@ use chrono::serde::ts_seconds_option::deserialize as from_tsopt; +use pgrx::pg_sys::Oid; use serde::{Deserialize, Serialize}; use sqlx::types::chrono::Utc; use sqlx::FromRow; @@ -103,8 +104,7 @@ pub enum TableMethod { #[derive(Clone, Debug, Default, Serialize, Deserialize, FromRow)] pub struct JobParams { - pub schema: String, - pub table: String, + pub table: PgOid, pub columns: Vec, pub update_time_col: Option, pub table_method: TableMethod, diff --git a/extension/src/api.rs b/extension/src/api.rs index 7a35754..2dbd246 100644 --- a/extension/src/api.rs +++ b/extension/src/api.rs @@ -5,7 +5,6 @@ use crate::search::{self, init_table}; use crate::transformers::generic::env_interpolate_string; use crate::transformers::transform; use crate::types; -use crate::util::pg_oid_to_table_name; use anyhow::Result; use pgrx::prelude::*; @@ -28,10 +27,10 @@ fn table( schedule: default!(&str, "'* * * * *'"), ) -> Result { let model = Model::new(transformer)?; - let table_name_str = pg_oid_to_table_name(table_name); + init_table( job_name, - &table_name_str, + table_name, columns, primary_key, Some(update_col), @@ -107,7 +106,6 @@ fn init_rag( let transformer_model = Model::new(transformer)?; init_table( agent_name, - schema, table_name, columns, unique_record_id, diff --git a/extension/src/init.rs b/extension/src/init.rs index 78a8ee2..3ef91e7 100644 --- a/extension/src/init.rs +++ b/extension/src/init.rs @@ -256,14 +256,14 @@ pub fn get_column_datatype(table: &str, column: &str) -> Result { ) .map_err(|_| { anyhow!( - "One of schema:`{}`, table:`{}`, column:`{}` does not exist.", + "One of table:`{}`, column:`{}` does not exist.", table, column ) })? .ok_or_else(|| { anyhow!( - "An unknown error occurred while fetching the data type for column `{}` in `{}.{}`.", + "An unknown error occurred while fetching the data type for column `{}` in `{}`.", table, column ) diff --git a/extension/src/search.rs b/extension/src/search.rs index 9dba1d2..6cf7a55 100644 --- a/extension/src/search.rs +++ b/extension/src/search.rs @@ -36,7 +36,7 @@ pub fn init_table( } // get prim key type - let pkey_type = init::get_column_datatype(&table_name_str, primary_key)?; + let pkey_type = init::get_column_datatype(table_name, primary_key)?; init::init_pgmq()?; let guc_configs = get_guc_configs(&transformer.source); @@ -102,8 +102,7 @@ pub fn init_table( }; let valid_params = types::JobParams { - schema: schema.to_string(), - table: table.to_string(), + table: table_name_str.clone(), columns: columns.clone(), update_time_col: update_col, table_method: table_method.clone(), @@ -168,8 +167,8 @@ pub fn init_table( // setup triggers // create the trigger if not exists let trigger_handler = create_trigger_handler(job_name, &columns, primary_key); - let insert_trigger = create_event_trigger(job_name, schema, table, "INSERT"); - let update_trigger = create_event_trigger(job_name, schema, table, "UPDATE"); + let insert_trigger = create_event_trigger(job_name, table_name_str.clone(), "INSERT"); + let update_trigger = create_event_trigger(job_name, table_name_str.clone(), "UPDATE"); let _: Result<_, spi::Error> = Spi::connect(|mut c| { let _r = c.update(&trigger_handler, None, None)?; let _r = c.update(&insert_trigger, None, None)?; diff --git a/extension/src/util.rs b/extension/src/util.rs index b030991..25e8f16 100644 --- a/extension/src/util.rs +++ b/extension/src/util.rs @@ -5,7 +5,6 @@ use pgrx::*; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use sqlx::{Pool, Postgres}; use std::env; -use std::ffi::CStr; use url::{ParseError, Url}; use crate::guc; @@ -207,12 +206,14 @@ pub fn get_pg_options(cfg: Config) -> Result { } pub fn pg_oid_to_table_name(oid: PgOid) -> String { - unsafe { - let regclass_cstring = regclassout(oid.value() as Oid); - CStr::from_ptr(regclass_cstring) - .to_string_lossy() - .into_owned() - } + let query = "SELECT relname FROM pg_class WHERE oid = $1"; + let table_name: String = Spi::get_one_with_args( + query, + vec![(PgBuiltInOids::REGCLASSOID.oid(), oid.into_datum())] + ) + .expect("Failed to fetch table name from oid") + .unwrap_or_else(|| panic!("Table name not found for oid: {}", oid.value())); + table_name } pub async fn ready(conn: &Pool) -> bool {