Skip to content

Commit

Permalink
Merge branch 'main' into TEM-3229-2
Browse files Browse the repository at this point in the history
  • Loading branch information
ChuckHend committed Feb 26, 2024
2 parents 669d021 + bf68a89 commit 1c85f76
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 8 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/extension_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ jobs:
if: github.event_name == 'release'
name: trunk publish
runs-on: ubuntu-22.04
strategy:
matrix:
pg-version: [14, 15, 16]
steps:
- uses: actions/checkout@v2
- name: Install Rust stable toolchain
Expand All @@ -158,7 +161,7 @@ jobs:
cargo install pg-trunk
- name: trunk build
working-directory: ./
run: ~/.cargo/bin/trunk build
run: ~/.cargo/bin/trunk build --pg-version ${{ matrix.pg-version }}
- name: trunk publish
working-directory: ./
env:
Expand Down
12 changes: 12 additions & 0 deletions src/guc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub static BATCH_SIZE: GucSetting<i32> = GucSetting::<i32>::new(10000);
pub static NUM_BGW_PROC: GucSetting<i32> = GucSetting::<i32>::new(1);
pub static EMBEDDING_SERVICE_HOST: GucSetting<Option<&CStr>> =
GucSetting::<Option<&CStr>>::new(None);
pub static EMBEDDING_REQ_TIMEOUT_SEC: GucSetting<i32> = GucSetting::<i32>::new(120);

// initialize GUCs
pub fn init_guc() {
Expand Down Expand Up @@ -56,6 +57,17 @@ pub fn init_guc() {
GucContext::Suset,
GucFlags::default(),
);

GucRegistry::define_int_guc(
"vectorize.embedding_req_timeout_sec",
"Timeout, in seconds, for embedding transform requests",
"Number of seconds to wait for an embedding http request to complete. Default is 120 seconds.",
&EMBEDDING_REQ_TIMEOUT_SEC,
1,
1800,
GucContext::Suset,
GucFlags::default(),
);
}

// for handling of GUCs that can be error prone
Expand Down
9 changes: 8 additions & 1 deletion src/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,14 @@ pub fn search(
let schema = proj_params.schema;
let table = proj_params.table;

let embeddings = transform(query, &project_meta.transformer, api_key);
let proj_api_key = match api_key {
// if api passed in the function call, use that
Some(k) => Some(k),
// if not, use the one from the project metadata
None => proj_params.api_key,
};

let embeddings = transform(query, &project_meta.transformer, proj_api_key);

match project_meta.search_alg {
types::SimilarityAlg::pgv_cosine_similarity => cosine_similarity_search(
Expand Down
5 changes: 4 additions & 1 deletion src/transformers/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::{
executor::VectorizeMeta,
guc,
transformers::types::{EmbeddingPayload, EmbeddingRequest, Inputs},
types,
};

use super::openai::trim_inputs;
Expand Down Expand Up @@ -63,12 +64,14 @@ pub fn prepare_generic_embedding_request(
model: job_meta.transformer.to_string(),
};

let job_params: types::JobParams = serde_json::from_value(job_meta.params)?;

let svc_host = get_generic_svc_url().context("failed to get embedding service url from GUC")?;

Ok(EmbeddingRequest {
url: svc_host,
payload,
api_key: None,
api_key: job_params.api_key,
})
}

Expand Down
8 changes: 7 additions & 1 deletion src/transformers/http_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::generic::get_generic_svc_url;
use super::types::TransformerMetadata;
use anyhow::Result;

use crate::guc::EMBEDDING_REQ_TIMEOUT_SEC;
use crate::transformers::types::{
EmbeddingPayload, EmbeddingRequest, EmbeddingResponse, Inputs, PairedEmbeddings,
};
Expand Down Expand Up @@ -32,8 +33,10 @@ pub async fn openai_embedding_request(request: EmbeddingRequest) -> Result<Vec<V
request.payload.input.len()
);
let client = reqwest::Client::new();
let timeout = EMBEDDING_REQ_TIMEOUT_SEC.get();
let mut req = client
.post(request.url)
.timeout(std::time::Duration::from_secs(timeout as u64))
.json::<EmbeddingPayload>(&request.payload)
.header("Content-Type", "application/json");
if let Some(key) = request.api_key {
Expand Down Expand Up @@ -92,7 +95,10 @@ pub async fn get_model_info(
let svc_url = get_generic_svc_url()?;
let info_url = svc_url.replace("/embeddings", "/info");
let client = reqwest::Client::new();
let mut req = client.get(info_url).query(&[("model_name", model_name)]);
let mut req = client
.get(info_url)
.query(&[("model_name", model_name)])
.timeout(std::time::Duration::from_secs(5)); // model info must always be fast
if let Some(key) = api_key {
req = req.header("Authorization", format!("Bearer {}", key));
}
Expand Down
1 change: 1 addition & 0 deletions src/transformers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ pub fn validate_api_key(key: &str) -> Result<()> {
.get("https://api.openai.com/v1/models")
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", key))
.timeout(std::time::Duration::from_secs(5)) // api validation should be fast
.send()
.await
.unwrap_or_else(|e| error!("failed to make Open AI key validation call: {}", e));
Expand Down
13 changes: 9 additions & 4 deletions src/workers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,15 @@ pub async fn run_worker(
msg.message.job_name
);
let job_success = execute_job(conn.clone(), msg).await;
let delete_it = if job_success.is_ok() {
true
} else {
read_ct > 2
let delete_it = match job_success {
Ok(_) => {
info!("pg-vectorize: job success");
true
}
Err(e) => {
warning!("pg-vectorize: job failed: {:?}", e);
read_ct > 2
}
};

// delete message from queue
Expand Down

0 comments on commit 1c85f76

Please sign in to comment.