Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

handle portkey virtkey in bgw #138

Merged
merged 8 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions .github/workflows/extension_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ on:
jobs:
dependencies:
name: Install dependencies
runs-on: ubuntu-22.04
runs-on: ubuntu-24.04
steps:
- uses: actions/checkout@v2

Expand Down Expand Up @@ -85,7 +85,7 @@ jobs:
test:
name: Run tests
needs: dependencies
runs-on: ubuntu-22.04
runs-on: ubuntu-24.04
services:
# Label used to access the service container
vector-serve:
Expand All @@ -100,10 +100,15 @@ jobs:
toolchain: stable
- uses: Swatinem/rust-cache@v2
with:
prefix-key: "pg-vectorize-extension-test"
workspaces: pg-vectorize
prefix-key: "extension-test"
workspaces: |
vectorize
# Additional directories to cache
cache-directories: /home/runner/.pgrx
cache-directories: |
/home/runner/.pgrx
- name: Install sys dependencies
run: |
sudo apt-get update && sudo apt-get install -y postgresql-server-dev-16 libopenblas-dev libreadline-dev
- uses: ./.github/actions/pgx-init
with:
working-directory: ./extension
Expand All @@ -126,10 +131,7 @@ jobs:
${{ runner.os }}-bins-
- name: setup-tests
run: |
make trunk-dependencies
make setup.urls
make setup.shared_preload_libraries
rm -rf ./target/pgrx-test-data-* || true
make setup
- name: unit-test
run: |
make test-unit
Expand All @@ -146,7 +148,7 @@ jobs:
publish:
if: github.event_name == 'release'
name: trunk publish
runs-on: ubuntu-22.04
runs-on: ubuntu-24.04
strategy:
matrix:
pg-version: [14, 15, 16]
Expand Down
1 change: 1 addition & 0 deletions core/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ pub struct JobParams {
pub api_key: Option<String>,
#[serde(default = "default_schedule")]
pub schedule: String,
pub args: Option<serde_json::Value>,
}

fn default_schedule() -> String {
Expand Down
8 changes: 7 additions & 1 deletion core/src/worker/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,17 @@ async fn execute_job(
let job_meta: VectorizeMeta = msg.message.job_meta;
let job_params: JobParams = serde_json::from_value(job_meta.params.clone())?;

let virtual_key = if let Some(args) = job_params.args.clone() {
args.get("virtual_key").map(|v| v.to_string())
} else {
None
};

let provider = providers::get_provider(
&job_meta.transformer.source,
job_params.api_key.clone(),
None,
None,
virtual_key,
)?;

let embedding_request =
Expand Down
2 changes: 1 addition & 1 deletion extension/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "vectorize"
version = "0.18.0"
version = "0.18.1"
edition = "2021"
publish = false

Expand Down
4 changes: 2 additions & 2 deletions extension/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ install-pgvector:
install-pgmq:
git clone https://github.com/tembo-io/pgmq.git && \
cd pgmq/pgmq-extension && \
PG_CONFIG=${PGRX_PG_CONFIG} make && \
PG_CONFIG=${PGRX_PG_CONFIG} make clean && \
PG_CONFIG=${PGRX_PG_CONFIG} make && \
PG_CONFIG=${PGRX_PG_CONFIG} make install && \
cd .. && rm -rf pgmq
cd ../.. && rm -rf pgmq

install-vectorscale:
@ARCH=$$(uname -m); \
Expand Down
2 changes: 1 addition & 1 deletion extension/Trunk.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ description = "The simplest way to orchestrate vector search on Postgres."
homepage = "https://github.com/tembo-io/pg_vectorize"
documentation = "https://github.com/tembo-io/pg_vectorize"
categories = ["orchestration", "machine_learning"]
version = "0.18.0"
version = "0.18.1"
loadable_libraries = [{ library_name = "vectorize", requires_restart = true }]

[build]
Expand Down
Empty file.
59 changes: 33 additions & 26 deletions extension/src/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,38 +39,19 @@ pub fn init_table(
init::init_pgmq()?;

let guc_configs = get_guc_configs(&transformer.source);
let provider = get_provider(
&transformer.source,
guc_configs.api_key.clone(),
guc_configs.service_url,
None,
)?;

//synchronous
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_io()
.enable_time()
.build()
.unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));
let model_dim =
match runtime.block_on(async { provider.model_dim(&transformer.api_name()).await }) {
Ok(e) => e,
Err(e) => {
error!("error getting model dim: {}", e);
}
};

// validate API key where necessary
info!("guc_configs: {:?}", guc_configs);
// validate API key where necessary and collect any optional arguments
// certain embedding services require an API key, e.g. openAI
// key can be set in a GUC, so if its required but not provided in args, and not in GUC, error
match transformer.source {
let optional_args = match transformer.source {
ModelSource::OpenAI => {
openai::validate_api_key(
&guc_configs
.api_key
.clone()
.context("OpenAI key is required")?,
)?;
None
}
ModelSource::Tembo => {
error!("Tembo not implemented for search yet");
Expand All @@ -85,15 +66,40 @@ pub fn init_table(
let res = check_model_host(&url);
match res {
Ok(_) => {
info!("Model host active!")
info!("Model host active!");
None
}
Err(e) => {
error!("Error with model host: {:?}", e)
}
}
}
_ => (),
}
ModelSource::Portkey => Some(serde_json::json!({
"virtual_key": guc_configs.virtual_key.clone().expect("Portkey virtual key is required")
})),
_ => None,
};

let provider = get_provider(
&transformer.source,
guc_configs.api_key.clone(),
guc_configs.service_url.clone(),
guc_configs.virtual_key.clone(),
)?;

// synchronous
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_io()
.enable_time()
.build()
.unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));
let model_dim =
match runtime.block_on(async { provider.model_dim(&transformer.api_name()).await }) {
Ok(e) => e,
Err(e) => {
error!("error getting model dim: {}", e);
}
};

let valid_params = types::JobParams {
schema: schema.to_string(),
Expand All @@ -105,6 +111,7 @@ pub fn init_table(
pkey_type,
api_key: guc_configs.api_key.clone(),
schedule: schedule.to_string(),
args: optional_args,
};
let params =
pgrx::JsonB(serde_json::to_value(valid_params.clone()).expect("error serializing params"));
Expand Down
14 changes: 8 additions & 6 deletions extension/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -788,10 +788,6 @@ async fn test_diskann_cosine() {
common::init_test_table(&test_table_name, &conn).await;
let job_name = format!("job_diskann_{}", test_num);

let _ = sqlx::query("CREATE EXTENSION IF NOT EXISTS vectorscale;")
.execute(&conn)
.await;

common::init_embedding_svc_url(&conn).await;
// initialize a job
let result = sqlx::query(&format!(
Expand All @@ -810,9 +806,15 @@ async fn test_diskann_cosine() {
assert!(result.is_ok());

let search_results: Vec<common::SearchJSON> =
util::common::search_with_retry(&conn, "mobile devices", &job_name, 10, 2, 3, None)
match util::common::search_with_retry(&conn, "mobile devices", &job_name, 10, 2, 3, None)
.await
.unwrap();
{
Ok(results) => results,
Err(e) => {
eprintln!("Error: {:?}", e);
panic!("failed to exec search on diskann");
}
};
assert_eq!(search_results.len(), 3);
}

Expand Down
9 changes: 5 additions & 4 deletions extension/tests/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ pub mod common {
.await
.expect("failed to create extension");

// Optional dependencies
let _ = sqlx::query("CREATE EXTENSION IF NOT EXISTS vectorscale CASCADE")
.execute(&conn)
.await
.expect("failed to create vectorscale extension");
conn
}

Expand All @@ -63,10 +68,6 @@ pub mod common {
28815
} else if cfg!(feature = "pg14") {
28814
} else if cfg!(feature = "pg13") {
28813
} else if cfg!(feature = "pg12") {
28812
} else {
5432
}
Expand Down
Loading