Skip to content

Commit

Permalink
trigger initial job on all schedule methods
Browse files Browse the repository at this point in the history
  • Loading branch information
ChuckHend committed Mar 1, 2024
1 parent 073eb62 commit 1bca960
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 69 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ install-pgmq:

test-integration:
echo "\q" | make run
cargo test -- --ignored --test-threads=1
cargo test -- --ignored --test-threads=1 --nocapture

test-unit:
cargo pgrx test
Expand Down
72 changes: 69 additions & 3 deletions src/job.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use anyhow::Result;

use crate::executor::{create_batches, new_rows_query, JobMessage, VectorizeMeta};
use crate::guc::BATCH_SIZE;
use crate::init::VECTORIZE_QUEUE;
use pgrx::prelude::*;

use crate::executor::{JobMessage, VectorizeMeta};
use crate::transformers::types::Inputs;
use crate::types::{self, JobParams, JobType};
use crate::util;

use pgrx::prelude::*;
use tiktoken_rs::cl100k_base;

/// called by the trigger function when a table is updated
Expand Down Expand Up @@ -120,6 +122,70 @@ fn generate_input_concat(inputs: &[String]) -> String {
.join(" || ' ' || ")
}

// creates batches of embedding jobs
// typically used on table init
pub fn initalize_table_job(
job_name: &str,
job_params: &JobParams,
job_type: &JobType,
transformer: &str,
search_alg: types::SimilarityAlg,
) -> Result<()> {
// start with initial batch load
let rows_need_update_query: String = new_rows_query(job_name, job_params);
let mut inputs: Vec<Inputs> = Vec::new();
let bpe = cl100k_base().unwrap();
let _: Result<_, spi::Error> = Spi::connect(|c| {
let rows = c.select(&rows_need_update_query, None, None)?;
for row in rows {
let ipt = row["input_text"]
.value::<String>()?
.expect("input_text is null");
let token_estimate = bpe.encode_with_special_tokens(&ipt).len() as i32;
inputs.push(Inputs {
record_id: row["record_id"]
.value::<String>()?
.expect("record_id is null"),
inputs: ipt,
token_estimate,
});
}
Ok(())
});

let max_batch_size = BATCH_SIZE.get();
let batches = create_batches(inputs, max_batch_size);
let vectorize_meta = VectorizeMeta {
name: job_name.to_string(),
// TODO: in future, lookup job id once this gets put into use
// job_id is currently not used, job_name is unique
job_id: 0,
job_type: job_type.clone(),
params: serde_json::to_value(job_params.clone()).unwrap(),
transformer: transformer.to_string(),
search_alg: search_alg.clone(),
last_completion: None,
};
for b in batches {
let job_message = JobMessage {
job_name: job_name.to_string(),
job_meta: vectorize_meta.clone(),
inputs: b,
};
let query = format!(
"select pgmq.send('{VECTORIZE_QUEUE}', '{}');",
serde_json::to_string(&job_message)
.unwrap()
.replace('\'', "''")
);
let _ran: Result<_, spi::Error> = Spi::connect(|mut c| {
let _r = c.update(&query, None, None)?;
Ok(())
});
}
Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
66 changes: 8 additions & 58 deletions src/search.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
use crate::executor::{create_batches, new_rows_query, JobMessage, VectorizeMeta};
use crate::guc::{self, BATCH_SIZE};
use crate::init::{self, VECTORIZE_QUEUE};
use crate::job::{create_insert_trigger, create_trigger_handler, create_update_trigger};
use crate::executor::VectorizeMeta;
use crate::guc;
use crate::init;
use crate::job::{
create_insert_trigger, create_trigger_handler, create_update_trigger, initalize_table_job,
};
use crate::transformers::http_handler::sync_get_model_info;
use crate::transformers::openai;
use crate::transformers::transform;
use crate::transformers::types::Inputs;
use crate::types;
use crate::util;

use anyhow::Result;
use pgrx::prelude::*;
use tiktoken_rs::cl100k_base;

#[allow(clippy::too_many_arguments)]
pub fn init_table(
Expand Down Expand Up @@ -146,65 +146,15 @@ pub fn init_table(
let _r = c.update(&update_trigger, None, None)?;
Ok(())
});

// start with initial batch load
let rows_need_update_query: String = new_rows_query(job_name, &valid_params);
let mut inputs: Vec<Inputs> = Vec::new();
let bpe = cl100k_base().unwrap();
let _: Result<_, spi::Error> = Spi::connect(|c| {
let rows = c.select(&rows_need_update_query, None, None)?;
for row in rows {
let ipt = row["input_text"]
.value::<String>()?
.expect("input_text is null");
let token_estimate = bpe.encode_with_special_tokens(&ipt).len() as i32;
inputs.push(Inputs {
record_id: row["record_id"]
.value::<String>()?
.expect("record_id is null"),
inputs: ipt,
token_estimate,
});
}
Ok(())
});
let max_batch_size = BATCH_SIZE.get();
let batches = create_batches(inputs, max_batch_size);
let vectorize_meta = VectorizeMeta {
name: job_name.to_string(),
// TODO: in future, lookup job id once this gets put into use
// job_id is currently not used, job_name is unique
job_id: 0,
job_type: job_type.clone(),
params: serde_json::to_value(valid_params.clone()).unwrap(),
transformer: transformer.to_string(),
search_alg: search_alg.clone(),
last_completion: None,
};
for b in batches {
let job_message = JobMessage {
job_name: job_name.to_string(),
job_meta: vectorize_meta.clone(),
inputs: b,
};
let query = format!(
"select pgmq.send('{VECTORIZE_QUEUE}', '{}');",
serde_json::to_string(&job_message)
.unwrap()
.replace('\'', "''")
);
let _ran: Result<_, spi::Error> = Spi::connect(|mut c| {
let _r = c.update(&query, None, None)?;
Ok(())
});
}
}
_ => {
// initialize cron
init::init_cron(schedule, job_name)?;
log!("Initialized cron job");
}
}
// start with initial batch load
initalize_table_job(job_name, &valid_params, &job_type, transformer, search_alg)?;
Ok(format!("Successfully created job: {job_name}"))
}

Expand Down
8 changes: 1 addition & 7 deletions tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,7 @@ async fn test_scheduled_job() {
.await
.expect("failed to init job");

// manually trigger a job
let _ = sqlx::query(&format!("SELECT vectorize.job_execute('{job_name}');"))
.execute(&conn)
.await
.expect("failed to select from test_table");

// should 1 job in the queue
// should be exactly 1 job in the queue
let rowcount = common::row_count(&format!("pgmq.q_vectorize_jobs"), &conn).await;
assert!(rowcount >= 1);

Expand Down

0 comments on commit 1bca960

Please sign in to comment.