diff --git a/src/guc.rs b/src/guc.rs index 8368192..5d24e86 100644 --- a/src/guc.rs +++ b/src/guc.rs @@ -9,6 +9,7 @@ pub static BATCH_SIZE: GucSetting = GucSetting::::new(10000); pub static NUM_BGW_PROC: GucSetting = GucSetting::::new(1); pub static EMBEDDING_SERVICE_HOST: GucSetting> = GucSetting::>::new(None); +pub static EMBEDDING_REQ_TIMEOUT_SEC: GucSetting = GucSetting::::new(120); // initialize GUCs pub fn init_guc() { @@ -56,6 +57,17 @@ pub fn init_guc() { GucContext::Suset, GucFlags::default(), ); + + GucRegistry::define_int_guc( + "vectorize.embedding_req_timeout_sec", + "Timeout, in seconds, for embedding transform requests", + "Number of seconds to wait for an embedding http request to complete. Default is 120 seconds.", + &EMBEDDING_REQ_TIMEOUT_SEC, + 1, + 1800, + GucContext::Suset, + GucFlags::default(), + ); } // for handling of GUCs that can be error prone diff --git a/src/search.rs b/src/search.rs index d00375d..86de9c2 100644 --- a/src/search.rs +++ b/src/search.rs @@ -230,7 +230,14 @@ pub fn search( let schema = proj_params.schema; let table = proj_params.table; - let embeddings = transform(query, &project_meta.transformer, api_key); + let proj_api_key = match api_key { + // if api passed in the function call, use that + Some(k) => Some(k), + // if not, use the one from the project metadata + None => proj_params.api_key, + }; + + let embeddings = transform(query, &project_meta.transformer, proj_api_key); match project_meta.search_alg { types::SimilarityAlg::pgv_cosine_similarity => cosine_similarity_search( diff --git a/src/transformers/generic.rs b/src/transformers/generic.rs index 3c5e1fe..3175b7c 100644 --- a/src/transformers/generic.rs +++ b/src/transformers/generic.rs @@ -6,6 +6,7 @@ use crate::{ executor::VectorizeMeta, guc, transformers::types::{EmbeddingPayload, EmbeddingRequest, Inputs}, + types, }; use super::openai::trim_inputs; @@ -63,12 +64,14 @@ pub fn prepare_generic_embedding_request( model: job_meta.transformer.to_string(), }; + let job_params: types::JobParams = serde_json::from_value(job_meta.params)?; + let svc_host = get_generic_svc_url().context("failed to get embedding service url from GUC")?; Ok(EmbeddingRequest { url: svc_host, payload, - api_key: None, + api_key: job_params.api_key, }) } diff --git a/src/transformers/http_handler.rs b/src/transformers/http_handler.rs index 465b442..067800e 100644 --- a/src/transformers/http_handler.rs +++ b/src/transformers/http_handler.rs @@ -2,6 +2,7 @@ use super::generic::get_generic_svc_url; use super::types::TransformerMetadata; use anyhow::Result; +use crate::guc::EMBEDDING_REQ_TIMEOUT_SEC; use crate::transformers::types::{ EmbeddingPayload, EmbeddingRequest, EmbeddingResponse, Inputs, PairedEmbeddings, }; @@ -32,8 +33,10 @@ pub async fn openai_embedding_request(request: EmbeddingRequest) -> Result(&request.payload) .header("Content-Type", "application/json"); if let Some(key) = request.api_key { @@ -92,7 +95,10 @@ pub async fn get_model_info( let svc_url = get_generic_svc_url()?; let info_url = svc_url.replace("/embeddings", "/info"); let client = reqwest::Client::new(); - let mut req = client.get(info_url).query(&[("model_name", model_name)]); + let mut req = client + .get(info_url) + .query(&[("model_name", model_name)]) + .timeout(std::time::Duration::from_secs(5)); // model info must always be fast if let Some(key) = api_key { req = req.header("Authorization", format!("Bearer {}", key)); } diff --git a/src/transformers/openai.rs b/src/transformers/openai.rs index ad9c0a0..494119f 100644 --- a/src/transformers/openai.rs +++ b/src/transformers/openai.rs @@ -82,6 +82,7 @@ pub fn validate_api_key(key: &str) -> Result<()> { .get("https://api.openai.com/v1/models") .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", key)) + .timeout(std::time::Duration::from_secs(5)) // api validation should be fast .send() .await .unwrap_or_else(|e| error!("failed to make Open AI key validation call: {}", e)); diff --git a/src/workers/mod.rs b/src/workers/mod.rs index 55753a5..f371eed 100644 --- a/src/workers/mod.rs +++ b/src/workers/mod.rs @@ -35,10 +35,15 @@ pub async fn run_worker( msg.message.job_name ); let job_success = execute_job(conn.clone(), msg).await; - let delete_it = if job_success.is_ok() { - true - } else { - read_ct > 2 + let delete_it = match job_success { + Ok(_) => { + info!("pg-vectorize: job success"); + true + } + Err(e) => { + warning!("pg-vectorize: job failed: {:?}", e); + read_ct > 2 + } }; // delete message from queue