From a523a3a12ac45ebedd3b5a7072313baa70a1a5dc Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Mon, 26 Feb 2024 08:52:28 -0600 Subject: [PATCH] optimization for bulk update (#58) * configure multiple bgw * optimization for bulk load --- src/guc.rs | 12 ++++ src/transformers/http_handler.rs | 1 - src/workers/mod.rs | 95 ++++++++++++++++++++++++++++++-- src/workers/pg_bgw.rs | 18 ++++-- 4 files changed, 115 insertions(+), 11 deletions(-) diff --git a/src/guc.rs b/src/guc.rs index b4f9e87..8368192 100644 --- a/src/guc.rs +++ b/src/guc.rs @@ -6,6 +6,7 @@ use anyhow::Result; pub static VECTORIZE_HOST: GucSetting> = GucSetting::>::new(None); pub static OPENAI_KEY: GucSetting> = GucSetting::>::new(None); pub static BATCH_SIZE: GucSetting = GucSetting::::new(10000); +pub static NUM_BGW_PROC: GucSetting = GucSetting::::new(1); pub static EMBEDDING_SERVICE_HOST: GucSetting> = GucSetting::>::new(None); @@ -44,6 +45,17 @@ pub fn init_guc() { "Url to a service with request and response schema consistent with OpenAI's embeddings API.", &EMBEDDING_SERVICE_HOST, GucContext::Suset, GucFlags::default()); + + GucRegistry::define_int_guc( + "vectorize.num_bgw_proc", + "Number of bgw processes", + "Number of parallel background worker processes to run. Default is 1.", + &NUM_BGW_PROC, + 1, + 10, + GucContext::Suset, + GucFlags::default(), + ); } // for handling of GUCs that can be error prone diff --git a/src/transformers/http_handler.rs b/src/transformers/http_handler.rs index dacfdce..465b442 100644 --- a/src/transformers/http_handler.rs +++ b/src/transformers/http_handler.rs @@ -41,7 +41,6 @@ pub async fn openai_embedding_request(request: EmbeddingRequest) -> Result(resp, "embeddings").await?; - let embeddings = embedding_resp .data .iter() diff --git a/src/workers/mod.rs b/src/workers/mod.rs index a2eee0a..55753a5 100644 --- a/src/workers/mod.rs +++ b/src/workers/mod.rs @@ -3,10 +3,13 @@ pub mod pg_bgw; use crate::executor::JobMessage; use crate::transformers::{generic, http_handler, openai, types::PairedEmbeddings}; use crate::types; + use anyhow::Result; use pgmq::{Message, PGMQueueExt}; use pgrx::*; +use serde_json::to_string; use sqlx::{Pool, Postgres}; +use std::fmt::Write; pub async fn run_worker( queue: PGMQueueExt, @@ -106,7 +109,90 @@ fn build_upsert_query( (query, bindings) } -use serde_json::to_string; +async fn update_embeddings( + pool: &Pool, + schema: &str, + table: &str, + project: &str, + pkey: &str, + pkey_type: &str, + embeddings: Vec, +) -> anyhow::Result<()> { + if embeddings.len() > 10 { + bulk_update_embeddings(pool, schema, table, project, pkey, pkey_type, embeddings).await + } else { + update_append_table(pool, embeddings, schema, table, project, pkey, pkey_type).await + } +} + +// creates a temporary table, inserts all new values into the temporary table, and then performs an update by join +async fn bulk_update_embeddings( + pool: &Pool, + schema: &str, + table: &str, + project: &str, + pkey: &str, + pkey_type: &str, + embeddings: Vec, +) -> anyhow::Result<()> { + let mut tx = pool.begin().await?; + + let tmp_table = format!("temp_embeddings_{project}"); + + let temp_table_query = format!( + "CREATE TEMP TABLE IF NOT EXISTS {tmp_table} ( + pkey {pkey_type} PRIMARY KEY, + embeddings vector + ) ON COMMIT DROP;", // note, dropping on commit + ); + + sqlx::query(&temp_table_query).execute(&mut *tx).await?; + + // insert all new values into the temporary table + let mut insert_query = format!("INSERT INTO {tmp_table} (pkey, embeddings) VALUES "); + let mut params: Vec<(String, String)> = Vec::new(); + + for embed in &embeddings { + let embedding_json = to_string(&embed.embeddings).expect("failed to serialize embedding"); + params.push((embed.primary_key.to_string(), embedding_json)); + } + + // Constructing query values part and collecting bind parameters + for (i, (_pkey, _embedding)) in params.iter().enumerate() { + if i > 0 { + insert_query.push_str(", "); + } + write!( + &mut insert_query, + "(${}::{}, ${}::vector)", + i * 2 + 1, + pkey_type, + i * 2 + 2 + ) + .expect("Failed to write to query string"); + } + + let mut insert_statement = sqlx::query(&insert_query); + + for (pkey, embedding) in params { + insert_statement = insert_statement.bind(pkey).bind(embedding); + } + // insert to the temp table + insert_statement.execute(&mut *tx).await?; + + let update_query = format!( + "UPDATE {schema}.{table} SET + {project}_embeddings = temp.embeddings, + {project}_updated_at = (NOW()) + FROM {tmp_table} temp + WHERE {schema}.{table}.{pkey}::{pkey_type} = temp.pkey::{pkey_type};" + ); + + sqlx::query(&update_query).execute(&mut *tx).await?; + tx.commit().await?; + + Ok(()) +} async fn update_append_table( pool: &Pool, @@ -127,7 +213,7 @@ async fn update_append_table( UPDATE {schema}.{table} SET {project}_embeddings = $1::vector, - {project}_updated_at = (NOW() at time zone 'utc') + {project}_updated_at = (NOW()) WHERE {pkey} = $2::{pkey_type} " ); @@ -158,17 +244,18 @@ async fn execute_job(dbclient: Pool, msg: Message) -> Resu let paired_embeddings: Vec = http_handler::merge_input_output(msg.message.inputs, embeddings); + log!("pg-vectorize: embeddings size: {}", paired_embeddings.len()); // write embeddings to result table match job_params.clone().table_method { types::TableMethod::append => { - update_append_table( + update_embeddings( &dbclient, - paired_embeddings, &job_params.schema, &job_params.table, &job_meta.clone().name, &job_params.primary_key, &job_params.pkey_type, + paired_embeddings, ) .await?; } diff --git a/src/workers/pg_bgw.rs b/src/workers/pg_bgw.rs index 9e657bf..a78e1a6 100644 --- a/src/workers/pg_bgw.rs +++ b/src/workers/pg_bgw.rs @@ -1,4 +1,4 @@ -use crate::guc::init_guc; +use crate::guc::{init_guc, NUM_BGW_PROC}; use crate::init::VECTORIZE_QUEUE; use crate::util::{get_pg_conn, ready}; use anyhow::Result; @@ -11,11 +11,17 @@ use crate::workers::run_worker; #[pg_guard] pub extern "C" fn _PG_init() { init_guc(); - BackgroundWorkerBuilder::new("PG Vectorize Background Worker") - .set_function("background_worker_main") - .set_library("vectorize") - .enable_spi_access() - .load(); + + let num_bgw = NUM_BGW_PROC.get(); + for i in 0..num_bgw { + log!("pg-vectorize: starting background worker {}", i); + let bginst = format!("pg-vectorize-bgw-{}", i); + BackgroundWorkerBuilder::new(&bginst) + .set_function("background_worker_main") + .set_library("vectorize") + .enable_spi_access() + .load(); + } } #[pg_guard]