diff --git a/Makefile b/Makefile index 9a99f75..a91d8bc 100644 --- a/Makefile +++ b/Makefile @@ -64,7 +64,7 @@ install-pgmq: test-integration: echo "\q" | make run - cargo test -- --ignored --test-threads=1 + cargo test -- --ignored --test-threads=1 --nocapture test-unit: cargo pgrx test diff --git a/src/job.rs b/src/job.rs index 99b711d..4c2e238 100644 --- a/src/job.rs +++ b/src/job.rs @@ -1,11 +1,13 @@ use anyhow::Result; +use crate::executor::{create_batches, new_rows_query, JobMessage, VectorizeMeta}; +use crate::guc::BATCH_SIZE; use crate::init::VECTORIZE_QUEUE; -use pgrx::prelude::*; - -use crate::executor::{JobMessage, VectorizeMeta}; use crate::transformers::types::Inputs; +use crate::types::{self, JobParams, JobType}; use crate::util; + +use pgrx::prelude::*; use tiktoken_rs::cl100k_base; /// called by the trigger function when a table is updated @@ -120,6 +122,70 @@ fn generate_input_concat(inputs: &[String]) -> String { .join(" || ' ' || ") } +// creates batches of embedding jobs +// typically used on table init +pub fn initialize_table_job( + job_name: &str, + job_params: &JobParams, + job_type: &JobType, + transformer: &str, + search_alg: types::SimilarityAlg, +) -> Result<()> { + // start with initial batch load + let rows_need_update_query: String = new_rows_query(job_name, job_params); + let mut inputs: Vec = Vec::new(); + let bpe = cl100k_base().unwrap(); + let _: Result<_, spi::Error> = Spi::connect(|c| { + let rows = c.select(&rows_need_update_query, None, None)?; + for row in rows { + let ipt = row["input_text"] + .value::()? + .expect("input_text is null"); + let token_estimate = bpe.encode_with_special_tokens(&ipt).len() as i32; + inputs.push(Inputs { + record_id: row["record_id"] + .value::()? 
+ .expect("record_id is null"), + inputs: ipt, + token_estimate, + }); + } + Ok(()) + }); + + let max_batch_size = BATCH_SIZE.get(); + let batches = create_batches(inputs, max_batch_size); + let vectorize_meta = VectorizeMeta { + name: job_name.to_string(), + // TODO: in future, lookup job id once this gets put into use + // job_id is currently not used, job_name is unique + job_id: 0, + job_type: job_type.clone(), + params: serde_json::to_value(job_params.clone()).unwrap(), + transformer: transformer.to_string(), + search_alg: search_alg.clone(), + last_completion: None, + }; + for b in batches { + let job_message = JobMessage { + job_name: job_name.to_string(), + job_meta: vectorize_meta.clone(), + inputs: b, + }; + let query = format!( + "select pgmq.send('{VECTORIZE_QUEUE}', '{}');", + serde_json::to_string(&job_message) + .unwrap() + .replace('\'', "''") + ); + let _ran: Result<_, spi::Error> = Spi::connect(|mut c| { + let _r = c.update(&query, None, None)?; + Ok(()) + }); + } + Ok(()) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/search.rs b/src/search.rs index 86de9c2..6fddffc 100644 --- a/src/search.rs +++ b/src/search.rs @@ -1,17 +1,17 @@ -use crate::executor::{create_batches, new_rows_query, JobMessage, VectorizeMeta}; -use crate::guc::{self, BATCH_SIZE}; -use crate::init::{self, VECTORIZE_QUEUE}; -use crate::job::{create_insert_trigger, create_trigger_handler, create_update_trigger}; +use crate::executor::VectorizeMeta; +use crate::guc; +use crate::init; +use crate::job::{ + create_insert_trigger, create_trigger_handler, create_update_trigger, initialize_table_job, +}; use crate::transformers::http_handler::sync_get_model_info; use crate::transformers::openai; use crate::transformers::transform; -use crate::transformers::types::Inputs; use crate::types; use crate::util; use anyhow::Result; use pgrx::prelude::*; -use tiktoken_rs::cl100k_base; #[allow(clippy::too_many_arguments)] pub fn init_table( @@ -146,58 +146,6 @@ pub fn init_table( let 
_r = c.update(&update_trigger, None, None)?; Ok(()) }); - - // start with initial batch load - let rows_need_update_query: String = new_rows_query(job_name, &valid_params); - let mut inputs: Vec = Vec::new(); - let bpe = cl100k_base().unwrap(); - let _: Result<_, spi::Error> = Spi::connect(|c| { - let rows = c.select(&rows_need_update_query, None, None)?; - for row in rows { - let ipt = row["input_text"] - .value::()? - .expect("input_text is null"); - let token_estimate = bpe.encode_with_special_tokens(&ipt).len() as i32; - inputs.push(Inputs { - record_id: row["record_id"] - .value::()? - .expect("record_id is null"), - inputs: ipt, - token_estimate, - }); - } - Ok(()) - }); - let max_batch_size = BATCH_SIZE.get(); - let batches = create_batches(inputs, max_batch_size); - let vectorize_meta = VectorizeMeta { - name: job_name.to_string(), - // TODO: in future, lookup job id once this gets put into use - // job_id is currently not used, job_name is unique - job_id: 0, - job_type: job_type.clone(), - params: serde_json::to_value(valid_params.clone()).unwrap(), - transformer: transformer.to_string(), - search_alg: search_alg.clone(), - last_completion: None, - }; - for b in batches { - let job_message = JobMessage { - job_name: job_name.to_string(), - job_meta: vectorize_meta.clone(), - inputs: b, - }; - let query = format!( - "select pgmq.send('{VECTORIZE_QUEUE}', '{}');", - serde_json::to_string(&job_message) - .unwrap() - .replace('\'', "''") - ); - let _ran: Result<_, spi::Error> = Spi::connect(|mut c| { - let _r = c.update(&query, None, None)?; - Ok(()) - }); - } } _ => { // initialize cron @@ -205,6 +153,8 @@ pub fn init_table( log!("Initialized cron job"); } } + // start with initial batch load + initialize_table_job(job_name, &valid_params, &job_type, transformer, search_alg)?; Ok(format!("Successfully created job: {job_name}")) } diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 0a057df..e3f6cbf 100644 --- a/tests/integration_tests.rs 
+++ b/tests/integration_tests.rs @@ -30,13 +30,7 @@ async fn test_scheduled_job() { .await .expect("failed to init job"); - // manually trigger a job - let _ = sqlx::query(&format!("SELECT vectorize.job_execute('{job_name}');")) .execute(&conn) .await .expect("failed to select from test_table"); - - // should 1 job in the queue + // should be at least 1 job in the queue let rowcount = common::row_count(&format!("pgmq.q_vectorize_jobs"), &conn).await; assert!(rowcount >= 1);