Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

handle api key in bgw #59

Merged
merged 8 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/guc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub static BATCH_SIZE: GucSetting<i32> = GucSetting::<i32>::new(10000);
pub static NUM_BGW_PROC: GucSetting<i32> = GucSetting::<i32>::new(1);
pub static EMBEDDING_SERVICE_HOST: GucSetting<Option<&CStr>> =
GucSetting::<Option<&CStr>>::new(None);
pub static EMBEDDING_REQ_TIMEOUT_SEC: GucSetting<i32> = GucSetting::<i32>::new(120);

// initialize GUCs
pub fn init_guc() {
Expand Down Expand Up @@ -56,6 +57,17 @@ pub fn init_guc() {
GucContext::Suset,
GucFlags::default(),
);

GucRegistry::define_int_guc(
"vectorize.embedding_req_timeout_sec",
"Timeout, in seconds, for embedding transform requests",
"Number of seconds to wait for an embedding http request to complete. Default is 120 seconds.",
&EMBEDDING_REQ_TIMEOUT_SEC,
1,
1800,
GucContext::Suset,
GucFlags::default(),
);
}

// for handling of GUCs that can be error prone
Expand Down
9 changes: 8 additions & 1 deletion src/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,14 @@ pub fn search(
let schema = proj_params.schema;
let table = proj_params.table;

let embeddings = transform(query, &project_meta.transformer, api_key);
let proj_api_key = match api_key {
// if api passed in the function call, use that
Some(k) => Some(k),
// if not, use the one from the project metadata
None => proj_params.api_key,
};

let embeddings = transform(query, &project_meta.transformer, proj_api_key);

match project_meta.search_alg {
types::SimilarityAlg::pgv_cosine_similarity => cosine_similarity_search(
Expand Down
5 changes: 4 additions & 1 deletion src/transformers/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::{
executor::VectorizeMeta,
guc,
transformers::types::{EmbeddingPayload, EmbeddingRequest, Inputs},
types,
};

use super::openai::trim_inputs;
Expand Down Expand Up @@ -63,12 +64,14 @@ pub fn prepare_generic_embedding_request(
model: job_meta.transformer.to_string(),
};

let job_params: types::JobParams = serde_json::from_value(job_meta.params)?;

let svc_host = get_generic_svc_url().context("failed to get embedding service url from GUC")?;

Ok(EmbeddingRequest {
url: svc_host,
payload,
api_key: None,
api_key: job_params.api_key,
})
}

Expand Down
8 changes: 7 additions & 1 deletion src/transformers/http_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::generic::get_generic_svc_url;
use super::types::TransformerMetadata;
use anyhow::Result;

use crate::guc::EMBEDDING_REQ_TIMEOUT_SEC;
use crate::transformers::types::{
EmbeddingPayload, EmbeddingRequest, EmbeddingResponse, Inputs, PairedEmbeddings,
};
Expand Down Expand Up @@ -32,8 +33,10 @@ pub async fn openai_embedding_request(request: EmbeddingRequest) -> Result<Vec<V
request.payload.input.len()
);
let client = reqwest::Client::new();
let timeout = EMBEDDING_REQ_TIMEOUT_SEC.get();
let mut req = client
.post(request.url)
.timeout(std::time::Duration::from_secs(timeout as u64))
.json::<EmbeddingPayload>(&request.payload)
.header("Content-Type", "application/json");
if let Some(key) = request.api_key {
Expand Down Expand Up @@ -92,7 +95,10 @@ pub async fn get_model_info(
let svc_url = get_generic_svc_url()?;
let info_url = svc_url.replace("/embeddings", "/info");
let client = reqwest::Client::new();
let mut req = client.get(info_url).query(&[("model_name", model_name)]);
let mut req = client
.get(info_url)
.query(&[("model_name", model_name)])
.timeout(std::time::Duration::from_secs(5)); // model info must always be fast
if let Some(key) = api_key {
req = req.header("Authorization", format!("Bearer {}", key));
}
Expand Down
1 change: 1 addition & 0 deletions src/transformers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ pub fn validate_api_key(key: &str) -> Result<()> {
.get("https://api.openai.com/v1/models")
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", key))
.timeout(std::time::Duration::from_secs(5)) // api validation should be fast
.send()
.await
.unwrap_or_else(|e| error!("failed to make Open AI key validation call: {}", e));
Expand Down
13 changes: 9 additions & 4 deletions src/workers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,15 @@ pub async fn run_worker(
msg.message.job_name
);
let job_success = execute_job(conn.clone(), msg).await;
let delete_it = if job_success.is_ok() {
true
} else {
read_ct > 2
let delete_it = match job_success {
Ok(_) => {
info!("pg-vectorize: job success");
true
}
Err(e) => {
warning!("pg-vectorize: job failed: {:?}", e);
read_ct > 2
}
};

// delete message from queue
Expand Down
Loading