diff --git a/.sqlx/query-2570c83389d229bee10f84a4c566feafb8464aebced1ee62dc5e2b0042f0ab9c.json b/.sqlx/query-406fb387c8a2d04c3a5362ae2da0154ff0295f078171f59b904e97fdd2933989.json similarity index 83% rename from .sqlx/query-2570c83389d229bee10f84a4c566feafb8464aebced1ee62dc5e2b0042f0ab9c.json rename to .sqlx/query-406fb387c8a2d04c3a5362ae2da0154ff0295f078171f59b904e97fdd2933989.json index 8625aa6..fe72344 100644 --- a/.sqlx/query-2570c83389d229bee10f84a4c566feafb8464aebced1ee62dc5e2b0042f0ab9c.json +++ b/.sqlx/query-406fb387c8a2d04c3a5362ae2da0154ff0295f078171f59b904e97fdd2933989.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT *\n FROM vectorize.vectorize_meta\n WHERE name = $1\n ", + "query": "\n SELECT *\n FROM vectorize.job\n WHERE name = $1\n ", "describe": { "columns": [ { @@ -54,5 +54,5 @@ true ] }, - "hash": "2570c83389d229bee10f84a4c566feafb8464aebced1ee62dc5e2b0042f0ab9c" + "hash": "406fb387c8a2d04c3a5362ae2da0154ff0295f078171f59b904e97fdd2933989" } diff --git a/Cargo.toml b/Cargo.toml index abed55e..b6bd47e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "vectorize" -version = "0.5.0" +version = "0.6.0" edition = "2021" publish = false @@ -15,6 +15,7 @@ pg_test = [] [dependencies] anyhow = "1.0.72" chrono = {version = "0.4.26", features = ["serde"] } +lazy_static = "1.4.0" log = "0.4.19" pgmq = "0.24.0" pgrx = "0.11.0" @@ -29,6 +30,7 @@ sqlx = { version = "0.7.2", features = [ "chrono", ] } thiserror = "1.0.44" +tiktoken-rs = "0.5.7" tokio = {version = "1.29.1", features = ["rt-multi-thread"] } url = "2.4.0" diff --git a/Trunk.toml b/Trunk.toml index a8c1c3b..62f3c17 100644 --- a/Trunk.toml +++ b/Trunk.toml @@ -6,7 +6,7 @@ description = "The simplest way to orchestrate vector search on Postgres." 
 homepage = "https://github.com/tembo-io/pg_vectorize"
 documentation = "https://github.com/tembo-io/pg_vectorize"
 categories = ["orchestration", "machine_learning"]
-version = "0.5.0"
+version = "0.6.0"
 
 [build]
 postgres_version = "15"
diff --git a/sql/meta.sql b/sql/meta.sql
index 823d378..2d206e2 100644
--- a/sql/meta.sql
+++ b/sql/meta.sql
@@ -1,4 +1,4 @@
-CREATE TABLE vectorize_meta (
+CREATE TABLE vectorize.job (
     job_id bigserial,
     name TEXT NOT NULL UNIQUE,
     job_type TEXT NOT NULL,
@@ -7,5 +7,3 @@
     params jsonb NOT NULL,
     last_completion TIMESTAMP WITH TIME ZONE
 );
-
-CREATE EXTENSION IF NOT EXISTS pgmq CASCADE;
\ No newline at end of file
diff --git a/sql/vectorize--0.5.0--0.6.0.sql b/sql/vectorize--0.5.0--0.6.0.sql
new file mode 100644
index 0000000..aeb7d00
--- /dev/null
+++ b/sql/vectorize--0.5.0--0.6.0.sql
@@ -0,0 +1,21 @@
+DROP function vectorize."table";
+
+-- vectorize::api::table
+CREATE FUNCTION vectorize."table"(
+    "table" TEXT, /* &str */
+    "columns" TEXT[], /* alloc::vec::Vec<alloc::string::String> */
+    "job_name" TEXT, /* alloc::string::String */
+    "primary_key" TEXT, /* alloc::string::String */
+    "args" json DEFAULT '{}', /* pgrx::datum::json::Json */
+    "schema" TEXT DEFAULT 'public', /* alloc::string::String */
+    "update_col" TEXT DEFAULT 'last_updated_at', /* alloc::string::String */
+    "transformer" vectorize.Transformer DEFAULT 'openai', /* vectorize::types::Transformer */
+    "search_alg" vectorize.SimilarityAlg DEFAULT 'pgv_cosine_similarity', /* vectorize::types::SimilarityAlg */
+    "table_method" vectorize.TableMethod DEFAULT 'append', /* vectorize::init::TableMethod */
+    "schedule" TEXT DEFAULT '* * * * *' /* alloc::string::String */
+) RETURNS TEXT /* core::result::Result<alloc::string::String, anyhow::Error> */
+STRICT
+LANGUAGE c /* Rust */
+AS 'MODULE_PATHNAME', 'table_wrapper';
+
+ALTER TABLE vectorize.vectorize_meta RENAME TO job;
diff --git a/sqlx-data.json b/sqlx-data.json
deleted file mode 100644
index 38d346a..0000000
--- a/sqlx-data.json
+++ /dev/null
@@ -1,59 +0,0 @@
-{
-  "db": "PostgreSQL",
-  "2570c83389d229bee10f84a4c566feafb8464aebced1ee62dc5e2b0042f0ab9c": {
-    "describe": {
-      "columns": [
-        {
-          "name": "job_id",
-          "ordinal": 0,
-          "type_info": "Int8"
-        },
-        {
-          "name": "name",
-          "ordinal": 1,
-          "type_info": "Text"
-        },
-        {
-          "name": "job_type",
-          "ordinal": 2,
-          "type_info": "Text"
-        },
-        {
-          "name": "transformer",
-          "ordinal": 3,
-          "type_info": "Text"
-        },
-        {
-          "name": "search_alg",
-          "ordinal": 4,
-          "type_info": "Text"
-        },
-        {
-          "name": "params",
-          "ordinal": 5,
-          "type_info": "Jsonb"
-        },
-        {
-          "name": "last_completion",
-          "ordinal": 6,
-          "type_info": "Timestamptz"
-        }
-      ],
-      "nullable": [
-        false,
-        false,
-        false,
-        false,
-        false,
-        false,
-        true
-      ],
-      "parameters": {
-        "Left": [
-          "Text"
-        ]
-      }
-    },
-    "query": "\n SELECT *\n FROM vectorize.vectorize_meta\n WHERE name = $1\n "
-  }
-}
\ No newline at end of file
diff --git a/src/api.rs b/src/api.rs
index 76756db..5fb5a9c 100644
--- a/src/api.rs
+++ b/src/api.rs
@@ -1,8 +1,9 @@
-use crate::executor::ColumnJobParams;
+use crate::guc;
 use crate::init;
-use crate::openai;
 use crate::search::cosine_similarity_search;
+use crate::transformers::openai;
 use crate::types;
+use crate::types::JobParams;
 use crate::util;
 use anyhow::Result;
 use pgrx::prelude::*;
@@ -19,11 +20,9 @@ fn table(
     update_col: default!(String, "'last_updated_at'"),
     transformer: default!(types::Transformer, "'openai'"),
     search_alg: default!(types::SimilarityAlg, "'pgv_cosine_similarity'"),
-    table_method: default!(init::TableMethod, "'append'"),
+    
table_method: default!(types::TableMethod, "'append'"), schedule: default!(String, "'* * * * *'"), ) -> Result { - // initialize pgmq - init::init_pgmq()?; let job_type = types::JobType::Columns; // write job to table @@ -41,11 +40,12 @@ fn table( // certain embedding services require an API key, e.g. openAI // key can be set in a GUC, so if its required but not provided in args, and not in GUC, error + init::init_pgmq(&transformer)?; match transformer { types::Transformer::openai => { let openai_key = match api_key { Some(k) => serde_json::from_value::(k.clone())?, - None => match util::get_guc(util::VectorizeGuc::OpenAIKey) { + None => match guc::get_guc(guc::VectorizeGuc::OpenAIKey) { Some(k) => k, None => { error!("failed to get API key from GUC"); @@ -54,19 +54,22 @@ fn table( }; openai::validate_api_key(&openai_key)?; } + // no-op + types::Transformer::allMiniLML12v2 => (), } - // TODO: implement a struct for these params - let params = pgrx::JsonB(serde_json::json!({ - "schema": schema, - "table": table, - "columns": columns, - "update_time_col": update_col, - "table_method": table_method, - "primary_key": primary_key, - "pkey_type": pkey_type, - "api_key": api_key - })); + let valid_params = types::JobParams { + schema: schema.clone(), + table: table.to_string(), + columns: columns.clone(), + update_time_col: update_col, + table_method: table_method.clone(), + primary_key, + pkey_type, + api_key: api_key + .map(|k| serde_json::from_value::(k.clone()).expect("error parsing api key")), + }; + let params = pgrx::JsonB(serde_json::to_value(valid_params).expect("error serializing params")); // using SPI here because it is unlikely that this code will be run anywhere but inside the extension. // background worker will likely be moved to an external container or service in near future @@ -145,7 +148,7 @@ fn search( } else { error!("failed to get project metadata"); }; - let project_meta: ColumnJobParams = + let project_meta: JobParams = serde_json::from_value(serde_json::to_value(_project_meta).unwrap_or_else(|e| { error!("failed to serialize metadata: {}", e); })) @@ -163,7 +166,7 @@ fn search( let openai_key = match api_key { Some(k) => k, - None => match util::get_guc(util::VectorizeGuc::OpenAIKey) { + None => match guc::get_guc(guc::VectorizeGuc::OpenAIKey) { Some(k) => k, None => { error!("failed to get API key from GUC"); diff --git a/src/executor.rs b/src/executor.rs index d8a37fb..3e5d3c2 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -1,7 +1,8 @@ use pgrx::prelude::*; use crate::errors::DatabaseError; -use crate::init::{TableMethod, PGMQ_QUEUE_NAME}; +use crate::guc::BATCH_SIZE; +use crate::init::QUEUE_MAPPING; use crate::query::check_input; use crate::types; use crate::util::{from_env_default, get_pg_conn}; @@ -12,6 +13,7 @@ use sqlx::error::Error; use sqlx::postgres::PgRow; use sqlx::types::chrono::Utc; use sqlx::{FromRow, PgPool, Pool, Postgres, Row}; +use tiktoken_rs::cl100k_base; // schema for every job // also schema for the vectorize.vectorize_meta table @@ -27,46 +29,6 @@ pub struct VectorizeMeta { pub last_completion: Option>, } -// temporary struct for deserializing from db -// not needed when sqlx 0.7.x -#[derive(Clone, Debug, Deserialize, FromRow, Serialize, PostgresType)] -pub struct _VectorizeMeta { - pub job_id: i64, - pub name: String, - pub job_type: String, - pub transformer: String, - pub search_alg: String, - pub params: serde_json::Value, - #[serde(deserialize_with = "from_tsopt")] - pub last_completion: Option>, -} - -impl From<_VectorizeMeta> for 
VectorizeMeta { - fn from(val: _VectorizeMeta) -> Self { - VectorizeMeta { - job_id: val.job_id, - name: val.name, - job_type: types::JobType::from(val.job_type), - transformer: types::Transformer::from(val.transformer), - search_alg: types::SimilarityAlg::from(val.search_alg), - params: val.params, - last_completion: val.last_completion, - } - } -} - -#[derive(Clone, Deserialize, Debug, Serialize)] -pub struct ColumnJobParams { - pub schema: String, - pub table: String, - pub columns: Vec, - pub primary_key: String, - pub pkey_type: String, - pub update_time_col: String, - pub api_key: Option, - pub table_method: TableMethod, -} - // creates batches based on total token count // batch_size is the max token count per batch fn create_batches(data: Vec, batch_size: i32) -> Vec> { @@ -112,9 +74,7 @@ fn job_execute(job_name: String) { .build() .unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e)); - // TODO: move into a config - // 100k tokens per batch - let max_batch_size = 100000; + let max_batch_size = BATCH_SIZE.get(); runtime.block_on(async { let conn = get_pg_conn() @@ -126,7 +86,7 @@ fn job_execute(job_name: String) { let meta = get_vectorize_meta(&job_name, &conn) .await .unwrap_or_else(|e| error!("failed to get job metadata: {}", e)); - let job_params = serde_json::from_value::(meta.params.clone()) + let job_params = serde_json::from_value::(meta.params.clone()) .unwrap_or_else(|e| error!("failed to deserialize job params: {}", e)); let _last_completion = match meta.last_completion { Some(t) => t, @@ -151,8 +111,11 @@ fn job_execute(job_name: String) { job_meta: meta.clone(), inputs: b, }; + let queue_name = QUEUE_MAPPING + .get(&meta.transformer) + .expect("invalid transformer"); let msg_id = queue - .send(PGMQ_QUEUE_NAME, &msg) + .send(queue_name, &msg) .await .unwrap_or_else(|e| error!("failed to send message updates: {}", e)); log!("message sent: {}", msg_id); @@ -172,17 +135,17 @@ pub async fn get_vectorize_meta( ) -> Result { log!("fetching job: {}", job_name); let row = sqlx::query_as!( - _VectorizeMeta, + VectorizeMeta, " SELECT * - FROM vectorize.vectorize_meta + FROM vectorize.job WHERE name = $1 ", job_name.to_string(), ) .fetch_one(conn) .await?; - Ok(row.into()) + Ok(row) } #[derive(Clone, Debug, Deserialize, Serialize)] @@ -197,7 +160,7 @@ pub struct Inputs { pub async fn get_new_updates_append( pool: &Pool, job_name: &str, - job_params: ColumnJobParams, + job_params: types::JobParams, ) -> Result>, DatabaseError> { let cols = collapse_to_csv(&job_params.columns); @@ -225,10 +188,11 @@ pub async fn get_new_updates_append( match rows { Ok(rows) => { if !rows.is_empty() { + let bpe = cl100k_base().unwrap(); let mut new_inputs: Vec = Vec::new(); for r in rows { let ipt: String = r.get("input_text"); - let token_estimate = ipt.split_whitespace().count() as i32; + let token_estimate = bpe.encode_with_special_tokens(&ipt).len() as i32; new_inputs.push(Inputs { record_id: r.get("record_id"), inputs: ipt, @@ -249,7 +213,7 @@ pub async fn get_new_updates_append( // queries a table and returns rows that need new embeddings #[allow(dead_code)] pub async fn get_new_updates_shared( - job_params: ColumnJobParams, + job_params: types::JobParams, last_completion: chrono::DateTime, ) -> Result>, DatabaseError> { let pool = PgPool::connect(&from_env_default( @@ -280,9 +244,10 @@ pub async fn get_new_updates_shared( let rows: Result, Error> = sqlx::query(&new_rows_query).fetch_all(&pool).await; match rows { Ok(rows) => { + let bpe = cl100k_base().unwrap(); for r in rows 
{
                 let ipt: String = r.get("input_text");
-                let token_estimate = ipt.split_whitespace().count() as i32;
+                let token_estimate = bpe.encode_with_special_tokens(&ipt).len() as i32;
                 new_inputs.push(Inputs {
                     record_id: r.get("record_id"),
                     inputs: ipt,
diff --git a/src/guc.rs b/src/guc.rs
new file mode 100644
index 0000000..31880ee
--- /dev/null
+++ b/src/guc.rs
@@ -0,0 +1,72 @@
+use core::ffi::CStr;
+use pgrx::*;
+
+use anyhow::Result;
+
+pub static VECTORIZE_HOST: GucSetting<Option<&'static CStr>> = GucSetting::<Option<&'static CStr>>::new(None);
+pub static OPENAI_KEY: GucSetting<Option<&'static CStr>> = GucSetting::<Option<&'static CStr>>::new(None);
+pub static BATCH_SIZE: GucSetting<i32> = GucSetting::<i32>::new(10000);
+
+// initialize GUCs
+pub fn init_guc() {
+    GucRegistry::define_string_guc(
+        "vectorize.host",
+        "unix socket url for Postgres",
+        "unix socket path to the Postgres instance. Optional. Can also be set in environment variable.",
+        &VECTORIZE_HOST,
+        GucContext::Suset, GucFlags::default());
+
+    GucRegistry::define_string_guc(
+        "vectorize.openai_key",
+        "API key from OpenAI",
+        "API key from OpenAI. Optional. Overridden by any values provided in function calls.",
+        &OPENAI_KEY,
+        GucContext::Suset,
+        GucFlags::SUPERUSER_ONLY,
+    );
+
+    GucRegistry::define_int_guc(
+        "vectorize.batch_size",
+        "Vectorize job batch size",
+        "Number of records that can be included in a single vectorize job.",
+        &BATCH_SIZE,
+        1,
+        100000,
+        GucContext::Suset,
+        GucFlags::default(),
+    );
+}
+
+// for handling of GUCs that can be error prone
+#[derive(Debug)]
+pub enum VectorizeGuc {
+    Host,
+    OpenAIKey,
+}
+
+/// a convenience function to get this project's GUCs
+pub fn get_guc(guc: VectorizeGuc) -> Option<String> {
+    let val = match guc {
+        VectorizeGuc::Host => VECTORIZE_HOST.get(),
+        VectorizeGuc::OpenAIKey => OPENAI_KEY.get(),
+    };
+    if let Some(cstr) = val {
+        if let Ok(s) = handle_cstr(cstr) {
+            Some(s)
+        } else {
+            error!("failed to convert CStr to str");
+        }
+    } else {
+        info!("no value set for GUC: {:?}", guc);
+        None
+    }
+}
+
+#[allow(dead_code)]
+fn handle_cstr(cstr: &CStr) -> Result<String> {
+    if let Ok(s) = cstr.to_str() {
+        Ok(s.to_owned())
+    } else {
+        Err(anyhow::anyhow!("failed to convert CStr to str"))
+    }
+}
diff --git a/src/init.rs b/src/init.rs
index f320dee..5911c72 100644
--- a/src/init.rs
+++ b/src/init.rs
@@ -1,31 +1,29 @@
-use crate::{query::check_input, types};
+use crate::{query::check_input, types, types::TableMethod, types::Transformer};
 use pgrx::prelude::*;
-use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
 
 use anyhow::Result;
+use lazy_static::lazy_static;
 
-pub const PGMQ_QUEUE_NAME: &str = "vectorize_queue";
-
-#[allow(non_camel_case_types)]
-#[derive(Clone, Debug, Serialize, Deserialize, PostgresEnum)]
-pub enum TableMethod {
-    // append a new column to the existing table
-    append,
-    // join existing table to a new table with embeddings
-    join,
+lazy_static!
{ + // each model has its own job queue + // maintain the mapping of transformer to queue name here + pub static ref QUEUE_MAPPING: HashMap = { + let mut m = HashMap::new(); + m.insert(Transformer::openai, "v_openai"); + m.insert(Transformer::allMiniLML12v2, "v_all_MiniLM_L12_v2"); + m + }; } -pub fn init_pgmq() -> Result<()> { +pub fn init_pgmq(transformer: &Transformer) -> Result<()> { + let qname = QUEUE_MAPPING.get(transformer).expect("invalid transformer"); let ran: Result<_, spi::Error> = Spi::connect(|mut c| { - let _r = c.update( - &format!("SELECT pgmq.create('{PGMQ_QUEUE_NAME}');"), - None, - None, - )?; + let _r = c.update(&format!("SELECT pgmq.create('{qname}');"), None, None)?; Ok(()) }); if let Err(e) = ran { - error!("error creating embedding table: {}", e); + error!("error creating job queue: {}", e); } Ok(()) } @@ -46,7 +44,7 @@ pub fn init_cron(cron: &str, job_name: &str) -> Result, spi::Error> pub fn init_job_query() -> String { format!( " - INSERT INTO {schema}.vectorize_meta (name, job_type, transformer, search_alg, params) + INSERT INTO {schema}.job (name, job_type, transformer, search_alg, params) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (name) DO UPDATE SET job_type = EXCLUDED.job_type, @@ -78,6 +76,9 @@ pub fn init_embedding_table_query( // currently only supports the text-embedding-ada-002 embedding model - output dim 1536 // https://platform.openai.com/docs/guides/embeddings/what-are-embeddings (types::Transformer::openai, types::SimilarityAlg::pgv_cosine_similarity) => "vector(1536)", + (types::Transformer::allMiniLML12v2, types::SimilarityAlg::pgv_cosine_similarity) => { + "vector(384)" + } }; match transform_method { TableMethod::append => { diff --git a/src/lib.rs b/src/lib.rs index 6a93192..3c8c2fc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,13 +3,14 @@ use pgrx::prelude::*; mod api; mod errors; mod executor; +mod guc; mod init; -mod openai; mod query; mod search; +mod transformers; mod types; mod util; -mod worker; +mod workers; pgrx::pg_module_magic!(); diff --git a/src/transformers/mod.rs b/src/transformers/mod.rs new file mode 100644 index 0000000..d8c3087 --- /dev/null +++ b/src/transformers/mod.rs @@ -0,0 +1 @@ +pub mod openai; diff --git a/src/openai.rs b/src/transformers/openai.rs similarity index 94% rename from src/openai.rs rename to src/transformers/openai.rs index f8ee329..426fec2 100644 --- a/src/openai.rs +++ b/src/transformers/openai.rs @@ -1,19 +1,17 @@ use pgrx::prelude::*; use serde_json::json; -use crate::util::OPENAI_KEY; use anyhow::Result; use crate::{ - executor::{ColumnJobParams, Inputs}, - worker::PairedEmbeddings, + executor::Inputs, + guc::OPENAI_KEY, + types::{JobParams, PairedEmbeddings}, }; // max token length is 8192 // however, depending on content of text, token count can be higher than -// token count returned by split_whitespace() -// TODO: wrap openai toktoken's tokenizer to estimate token count? 
-pub const MAX_TOKEN_LEN: usize = 7500; +pub const MAX_TOKEN_LEN: usize = 8192; pub const OPENAI_EMBEDDING_RL: &str = "https://api.openai.com/v1/embeddings"; #[derive(serde::Deserialize, Debug)] @@ -36,6 +34,7 @@ pub fn trim_inputs(inputs: &[Inputs]) -> Vec { .iter() .map(|input| { if input.token_estimate as usize > MAX_TOKEN_LEN { + // not example taking tokens, but naive way to trim input let tokens: Vec<&str> = input.inputs.split_whitespace().collect(); tokens .into_iter() @@ -72,10 +71,7 @@ pub async fn openai_embeddings(inputs: &Vec, key: &str) -> Result Result>> { +pub async fn openai_transform(job_params: JobParams, inputs: &[Inputs]) -> Result>> { log!("pg-vectorize: OpenAI transformer"); // handle retrieval of API key. order of precedence: diff --git a/src/types.rs b/src/types.rs index e4daafb..c31ccd0 100644 --- a/src/types.rs +++ b/src/types.rs @@ -6,10 +6,10 @@ use std::str::FromStr; pub const VECTORIZE_SCHEMA: &str = "vectorize"; #[allow(non_camel_case_types)] -#[derive(Clone, Debug, Serialize, Deserialize, PostgresEnum)] +#[derive(Clone, Debug, Serialize, Deserialize, Eq, Hash, PartialEq, PostgresEnum)] pub enum Transformer { openai, - // bert, + allMiniLML12v2, } impl FromStr for Transformer { @@ -27,6 +27,7 @@ impl From for Transformer { fn from(s: String) -> Self { match s.as_str() { "openai" => Transformer::openai, + "all_MiniLM_L12_v2" => Transformer::allMiniLML12v2, _ => panic!("Invalid value for Transformer: {}", s), // or handle this case differently } } @@ -36,6 +37,7 @@ impl Display for Transformer { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { match self { Transformer::openai => write!(f, "openai"), + Transformer::allMiniLML12v2 => write!(f, "all_MiniLM_L12_v2"), } } } @@ -108,3 +110,29 @@ impl Display for JobType { } } } + +pub struct PairedEmbeddings { + pub primary_key: String, + pub embeddings: Vec, +} + +#[allow(non_camel_case_types)] +#[derive(Clone, Debug, Serialize, Deserialize, PostgresEnum)] +pub enum TableMethod { + // append a new column to the existing table + append, + // join existing table to a new table with embeddings + join, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct JobParams { + pub schema: String, + pub table: String, + pub columns: Vec, + pub update_time_col: String, + pub table_method: TableMethod, + pub primary_key: String, + pub pkey_type: String, + pub api_key: Option, +} diff --git a/src/util.rs b/src/util.rs index 6726cd0..215fe27 100644 --- a/src/util.rs +++ b/src/util.rs @@ -5,16 +5,8 @@ use std::env; use url::{ParseError, Url}; use anyhow::Result; -use core::ffi::CStr; -pub static VECTORIZE_HOST: GucSetting> = GucSetting::>::new(None); -pub static OPENAI_KEY: GucSetting> = GucSetting::>::new(None); - -#[derive(Debug)] -pub enum VectorizeGuc { - Host, - OpenAIKey, -} +use crate::guc; #[derive(Clone, Debug)] pub struct Config { @@ -68,37 +60,10 @@ pub fn from_env_default(key: &str, default: &str) -> String { env::var(key).unwrap_or_else(|_| default.to_owned()) } -/// a convenience function to get this project's GUCs -pub fn get_guc(guc: VectorizeGuc) -> Option { - let val = match guc { - VectorizeGuc::Host => VECTORIZE_HOST.get(), - VectorizeGuc::OpenAIKey => OPENAI_KEY.get(), - }; - if let Some(cstr) = val { - if let Ok(s) = handle_cstr(cstr) { - Some(s) - } else { - error!("failed to convert CStr to str"); - } - } else { - warning!("no value set for GU: {:?}", guc); - None - } -} - -#[allow(dead_code)] -fn handle_cstr(cstr: &CStr) -> Result { - if let Ok(s) = cstr.to_str() { - 
Ok(s.to_owned()) - } else { - Err(anyhow::anyhow!("failed to convert CStr to str")) - } -} - pub fn get_vectorize_meta_spi(job_name: &str) -> Option { let query = " SELECT params::jsonb - FROM vectorize.vectorize_meta + FROM vectorize.job WHERE name = $1 "; let resultset: Result, spi::Error> = Spi::get_one_with_args( @@ -115,7 +80,7 @@ pub fn get_vectorize_meta_spi(job_name: &str) -> Option { pub async fn get_pg_conn() -> Result> { let mut cfg = Config::default(); - if let Some(host) = get_guc(VectorizeGuc::Host) { + if let Some(host) = guc::get_guc(guc::VectorizeGuc::Host) { log!("Using socket url from GUC: {:?}", host); cfg.vectorize_socket_url = Some(host); }; diff --git a/src/worker.rs b/src/worker.rs deleted file mode 100644 index b4fa594..0000000 --- a/src/worker.rs +++ /dev/null @@ -1,243 +0,0 @@ -use crate::executor::{ColumnJobParams, JobMessage}; -use crate::init::{TableMethod, PGMQ_QUEUE_NAME}; -use crate::openai; -use crate::types; -use crate::util::{get_pg_conn, OPENAI_KEY, VECTORIZE_HOST}; -use anyhow::Result; -use pgmq::Message; -use pgrx::bgworkers::*; -use pgrx::*; -use sqlx::{Pool, Postgres}; -use std::time::Duration; - -// initialize GUCs -fn init_guc() { - GucRegistry::define_string_guc( - "vectorize.host", - "unix socket url for Postgres", - "unix socket path to the Postgres instance. Optional. Can also be set in environment variable.", - &VECTORIZE_HOST, - GucContext::Suset, GucFlags::default()); - - GucRegistry::define_string_guc( - "vectorize.openai_key", - "API key from OpenAI", - "API key from OpenAI. Optional. Overridden by any values provided in function calls.", - &OPENAI_KEY, - GucContext::Suset, - GucFlags::SUPERUSER_ONLY, - ); -} - -#[pg_guard] -pub extern "C" fn _PG_init() { - init_guc(); - BackgroundWorkerBuilder::new("PG Vectorize Background Worker") - .set_function("background_worker_main") - .set_library("vectorize") - .enable_spi_access() - .load(); -} - -#[pg_guard] -#[no_mangle] -pub extern "C" fn background_worker_main(_arg: pg_sys::Datum) { - BackgroundWorker::attach_signal_handlers(SignalWakeFlags::SIGHUP | SignalWakeFlags::SIGTERM); - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_io() - .enable_time() - .build() - .unwrap(); - - // specify database - let (conn, queue) = runtime.block_on(async { - let con = get_pg_conn().await.expect("failed to connect to database"); - let queue = pgmq::PGMQueueExt::new_with_pool(con.clone()) - .await - .expect("failed to init db connection"); - (con, queue) - }); - - log!("Starting BG Workers {}", BackgroundWorker::get_name(),); - - while BackgroundWorker::wait_latch(Some(Duration::from_secs(5))) { - if BackgroundWorker::sighup_received() { - // on SIGHUP, you might want to reload configurations and env vars - } - let _: Result<()> = runtime.block_on(async { - let msg: Message = match queue.pop::(PGMQ_QUEUE_NAME).await { - Ok(Some(msg)) => msg, - Ok(None) => { - log!("pg-vectorize: No messages in queue"); - return Ok(()); - } - Err(e) => { - warning!("pg-vectorize: Error reading message: {e}"); - return Ok(()); - } - }; - - let msg_id = msg.msg_id; - let read_ct = msg.read_ct; - log!( - "pg-vectorize: received message for job: {:?}", - msg.message.job_name - ); - let job_success = execute_job(conn.clone(), msg).await; - let delete_it = if job_success.is_ok() { - true - } else { - read_ct > 2 - }; - - // delete message from queue - if delete_it { - match queue.delete(PGMQ_QUEUE_NAME, msg_id).await { - Ok(_) => { - log!("pg-vectorize: deleted message: {}", msg_id); - } - Err(e) => { - 
warning!("pg-vectorize: Error deleting message: {}", e); - } - } - } - // TODO: update job meta updated_timestamp - Ok(()) - }); - } - log!("pg-vectorize: shutting down"); -} - -pub struct PairedEmbeddings { - pub primary_key: String, - pub embeddings: Vec, -} - -async fn upsert_embedding_table( - conn: &Pool, - schema: &str, - project: &str, - embeddings: Vec, -) -> Result<()> { - let (query, bindings) = build_upsert_query(schema, project, embeddings); - let mut q = sqlx::query(&query); - for (record_id, embeddings) in bindings { - q = q.bind(record_id).bind(embeddings); - } - match q.execute(conn).await { - Ok(_) => Ok(()), - Err(e) => { - log!("Error: {}", e); - Err(anyhow::anyhow!("failed to execute query")) - } - } -} - -// returns query and bindings -// only compatible with pg-vector data types -fn build_upsert_query( - schema: &str, - project: &str, - embeddings: Vec, -) -> (String, Vec<(String, String)>) { - let mut query = format!( - " - INSERT INTO {schema}.{project}_embeddings (record_id, embeddings) VALUES" - ); - let mut bindings: Vec<(String, String)> = Vec::new(); - - for (index, pair) in embeddings.into_iter().enumerate() { - if index > 0 { - query.push(','); - } - query.push_str(&format!( - " (${}, ${}::vector)", - 2 * index + 1, - 2 * index + 2 - )); - - let embedding = - serde_json::to_string(&pair.embeddings).expect("failed to serialize embedding"); - bindings.push((pair.primary_key, embedding)); - } - - query.push_str(" ON CONFLICT (record_id) DO UPDATE SET embeddings = EXCLUDED.embeddings"); - (query, bindings) -} - -use serde_json::to_string; - -async fn update_append_table( - pool: &Pool, - embeddings: Vec, - schema: &str, - table: &str, - project: &str, - pkey: &str, - pkey_type: &str, -) -> anyhow::Result<()> { - for embed in embeddings { - // Serialize the Vec to a JSON string - let embedding = to_string(&embed.embeddings).expect("failed to serialize embedding"); - - // TODO: pkey might not always be integer type - let update_query = format!( - " - UPDATE {schema}.{table} - SET - {project}_embeddings = $1::vector, - {project}_updated_at = (NOW() at time zone 'utc') - WHERE {pkey} = $2::{pkey_type} - " - ); - // Prepare and execute the update statement for this pair within the transaction - sqlx::query(&update_query) - .bind(embedding) - .bind(embed.primary_key) - .execute(pool) - .await?; - } - Ok(()) -} - -async fn execute_job(dbclient: Pool, msg: Message) -> Result<()> { - let job_meta = msg.message.job_meta; - let job_params: ColumnJobParams = serde_json::from_value(job_meta.params)?; - let embeddings: Result> = match job_meta.transformer { - types::Transformer::openai => { - log!("pg-vectorize: OpenAI transformer"); - - let embeddings = - openai::openai_transform(job_params.clone(), &msg.message.inputs).await?; - // TODO: validate returned embeddings order is same as the input order - let emb: Vec = - openai::merge_input_output(msg.message.inputs, embeddings); - Ok(emb) - } - }; - // write embeddings to result table - match job_params.table_method { - TableMethod::append => { - update_append_table( - &dbclient, - embeddings.expect("failed to get embeddings"), - &job_params.schema, - &job_params.table, - &job_meta.name, - &job_params.primary_key, - &job_params.pkey_type, - ) - .await?; - } - TableMethod::join => { - upsert_embedding_table( - &dbclient, - &job_params.schema, - &job_meta.name, - embeddings.expect("failed to get embeddings"), - ) - .await? 
- } - }; - Ok(()) -} diff --git a/src/workers/mod.rs b/src/workers/mod.rs new file mode 100644 index 0000000..2145e9d --- /dev/null +++ b/src/workers/mod.rs @@ -0,0 +1,182 @@ +pub mod pg_bgw; + +use crate::executor::JobMessage; +use crate::transformers::openai; +use crate::types; +use anyhow::Result; +use pgmq::{Message, PGMQueueExt}; +use pgrx::*; +use sqlx::{Pool, Postgres}; + +pub async fn run_worker(queue: PGMQueueExt, conn: Pool, queue_name: &str) -> Result<()> { + let msg: Message = match queue.pop::(queue_name).await { + Ok(Some(msg)) => msg, + Ok(None) => { + log!("pg-vectorize: No messages in queue"); + return Ok(()); + } + Err(e) => { + warning!("pg-vectorize: Error reading message: {e}"); + return Ok(()); + } + }; + + let msg_id = msg.msg_id; + let read_ct = msg.read_ct; + log!( + "pg-vectorize: received message for job: {:?}", + msg.message.job_name + ); + let job_success = execute_job(conn.clone(), msg).await; + let delete_it = if job_success.is_ok() { + true + } else { + read_ct > 2 + }; + + // delete message from queue + if delete_it { + match queue.delete(queue_name, msg_id).await { + Ok(_) => { + log!("pg-vectorize: deleted message: {}", msg_id); + } + Err(e) => { + warning!("pg-vectorize: Error deleting message: {}", e); + } + } + } + Ok(()) +} + +async fn upsert_embedding_table( + conn: &Pool, + schema: &str, + project: &str, + embeddings: Vec, +) -> Result<()> { + let (query, bindings) = build_upsert_query(schema, project, embeddings); + let mut q = sqlx::query(&query); + for (record_id, embeddings) in bindings { + q = q.bind(record_id).bind(embeddings); + } + match q.execute(conn).await { + Ok(_) => Ok(()), + Err(e) => { + log!("Error: {}", e); + Err(anyhow::anyhow!("failed to execute query")) + } + } +} + +// returns query and bindings +// only compatible with pg-vector data types +fn build_upsert_query( + schema: &str, + project: &str, + embeddings: Vec, +) -> (String, Vec<(String, String)>) { + let mut query = format!( + " + INSERT INTO {schema}.{project}_embeddings (record_id, embeddings) VALUES" + ); + let mut bindings: Vec<(String, String)> = Vec::new(); + + for (index, pair) in embeddings.into_iter().enumerate() { + if index > 0 { + query.push(','); + } + query.push_str(&format!( + " (${}, ${}::vector)", + 2 * index + 1, + 2 * index + 2 + )); + + let embedding = + serde_json::to_string(&pair.embeddings).expect("failed to serialize embedding"); + bindings.push((pair.primary_key, embedding)); + } + + query.push_str(" ON CONFLICT (record_id) DO UPDATE SET embeddings = EXCLUDED.embeddings"); + (query, bindings) +} + +use serde_json::to_string; + +async fn update_append_table( + pool: &Pool, + embeddings: Vec, + schema: &str, + table: &str, + project: &str, + pkey: &str, + pkey_type: &str, +) -> anyhow::Result<()> { + for embed in embeddings { + // Serialize the Vec to a JSON string + let embedding = to_string(&embed.embeddings).expect("failed to serialize embedding"); + + // TODO: pkey might not always be integer type + let update_query = format!( + " + UPDATE {schema}.{table} + SET + {project}_embeddings = $1::vector, + {project}_updated_at = (NOW() at time zone 'utc') + WHERE {pkey} = $2::{pkey_type} + " + ); + // Prepare and execute the update statement for this pair within the transaction + sqlx::query(&update_query) + .bind(embedding) + .bind(embed.primary_key) + .execute(pool) + .await?; + } + Ok(()) +} + +async fn execute_job(dbclient: Pool, msg: Message) -> Result<()> { + let job_meta = msg.message.job_meta; + let job_params: types::JobParams = 
serde_json::from_value(job_meta.params)?;
+    let embeddings: Result<Vec<types::PairedEmbeddings>> = match job_meta.transformer {
+        types::Transformer::openai => {
+            log!("pg-vectorize: OpenAI transformer");
+
+            let embeddings =
+                openai::openai_transform(job_params.clone(), &msg.message.inputs).await?;
+            // TODO: validate returned embeddings order is same as the input order
+            let emb: Vec<types::PairedEmbeddings> =
+                openai::merge_input_output(msg.message.inputs, embeddings);
+            Ok(emb)
+        }
+        types::Transformer::allMiniLML12v2 => {
+            log!("pg-vectorize: allMiniLML12v2 transformer");
+            todo!()
+        }
+    };
+    // write embeddings to result table
+    match job_params.table_method {
+        types::TableMethod::append => {
+            update_append_table(
+                &dbclient,
+                embeddings.expect("failed to get embeddings"),
+                &job_params.schema,
+                &job_params.table,
+                &job_meta.name,
+                &job_params.primary_key,
+                &job_params.pkey_type,
+            )
+            .await?;
+        }
+        types::TableMethod::join => {
+            upsert_embedding_table(
+                &dbclient,
+                &job_params.schema,
+                &job_meta.name,
+                embeddings.expect("failed to get embeddings"),
+            )
+            .await?
+        }
+    };
+    Ok(())
+}
diff --git a/src/workers/pg_bgw.rs b/src/workers/pg_bgw.rs
new file mode 100644
index 0000000..a058498
--- /dev/null
+++ b/src/workers/pg_bgw.rs
@@ -0,0 +1,55 @@
+use crate::guc::init_guc;
+use crate::init::QUEUE_MAPPING;
+use crate::types::Transformer;
+use crate::util::get_pg_conn;
+use anyhow::Result;
+use pgrx::bgworkers::*;
+use pgrx::*;
+use std::time::Duration;
+
+use crate::workers::run_worker;
+
+#[pg_guard]
+pub extern "C" fn _PG_init() {
+    init_guc();
+    BackgroundWorkerBuilder::new("PG Vectorize Background Worker")
+        .set_function("background_worker_main")
+        .set_library("vectorize")
+        .enable_spi_access()
+        .load();
+}
+
+#[pg_guard]
+#[no_mangle]
+pub extern "C" fn background_worker_main(_arg: pg_sys::Datum) {
+    BackgroundWorker::attach_signal_handlers(SignalWakeFlags::SIGHUP | SignalWakeFlags::SIGTERM);
+    let runtime = tokio::runtime::Builder::new_current_thread()
+        .enable_io()
+        .enable_time()
+        .build()
+        .unwrap();
+
+    // specify database
+    let (conn, queue) = runtime.block_on(async {
+        let con = get_pg_conn().await.expect("failed to connect to database");
+        let queue = pgmq::PGMQueueExt::new_with_pool(con.clone())
+            .await
+            .expect("failed to init db connection");
+        (con, queue)
+    });
+
+    log!("Starting BG Workers {}", BackgroundWorker::get_name(),);
+
+    // bgw only supports the OpenAI transformer case
+    let queue_name = QUEUE_MAPPING
+        .get(&Transformer::openai)
+        .expect("invalid transformer");
+    while BackgroundWorker::wait_latch(Some(Duration::from_secs(5))) {
+        if BackgroundWorker::sighup_received() {
+            // on SIGHUP, you might want to reload configurations and env vars
+        }
+        let _: Result<()> =
+            runtime.block_on(async { run_worker(queue.clone(), conn.clone(), queue_name).await });
+    }
+    log!("pg-vectorize: shutting down");
+}
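
For reference, a minimal upgrade-and-configure session against an existing 0.5.0 install might look like the following. It is a sketch only: the batch-size value is arbitrary, and `vectorize.batch_size` is the new GUC registered in `src/guc.rs` (default 10000, range 1 to 100000, superuser-settable).

```sql
-- pick up the 0.5.0 -> 0.6.0 migration, including the vectorize_meta -> vectorize.job rename
ALTER EXTENSION vectorize UPDATE TO '0.6.0';

-- cap how much input is grouped into one queue message; the executor uses this
-- value as the max token count per batch in create_batches (GucContext::Suset)
SET vectorize.batch_size = 50000;

-- existing job definitions now live in vectorize.job
SELECT name, transformer, search_alg, last_completion FROM vectorize.job;
```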
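A hypothetical call exercising the new transformer variant is sketched below. The table, column, and job names are made up; the source table is assumed to have a `last_updated_at` column (the `update_col` default), and the `allMiniLML12v2` enum label is assumed to exist in `vectorize.Transformer` (it follows the Rust variant name, while its Display form is `all_MiniLM_L12_v2`, the suffix of the `v_all_MiniLM_L12_v2` queue). Note that in this change only the queue and type plumbing lands; the non-OpenAI path is still stubbed with `todo!()` in `src/workers/mod.rs`.

```sql
SELECT vectorize.table(
    "table"     => 'products',
    columns     => ARRAY['product_name', 'description'],
    job_name    => 'product_search',
    primary_key => 'product_id',
    transformer => 'allMiniLML12v2',
    schedule    => '* * * * *'
);
```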
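The executor now estimates tokens with tiktoken-rs's `cl100k_base` encoding (the tokenizer used by `text-embedding-ada-002`) instead of `split_whitespace()`, which is why `MAX_TOKEN_LEN` can rise to the full 8192. A standalone sketch of the same estimate, using a made-up input string:

```rust
use tiktoken_rs::cl100k_base;

fn main() {
    let bpe = cl100k_base().expect("failed to load cl100k_base");
    let ipt = "pg_vectorize orchestrates vector search on Postgres"; // hypothetical row text
    // same call the executor uses to size batches against vectorize.batch_size
    let token_estimate = bpe.encode_with_special_tokens(ipt).len() as i32;
    let whitespace_estimate = ipt.split_whitespace().count() as i32;
    println!("tokens: {token_estimate}, whitespace words: {whitespace_estimate}");
}
```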