Skip to content

Commit

Permalink
Arbitrary transformer backends (#26)
Browse files Browse the repository at this point in the history
* add test

* pluggable http handler

* switch on search

* handle worker

* add generic adapter

* handle a generic webserver

* handle response

* update migration script
  • Loading branch information
ChuckHend authored Nov 20, 2023
1 parent 589bcca commit 650c56f
Show file tree
Hide file tree
Showing 19 changed files with 490 additions and 196 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ url = "2.4.0"

[dev-dependencies]
pgrx-tests = "0.11.0"
rand = "0.8.5"
whoami = "1.4.1"

[profile.dev]
panic = "unwind"
Expand Down
2 changes: 1 addition & 1 deletion sql/vectorize--0.5.0--0.6.0.sql
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ CREATE FUNCTION vectorize."table"(
"update_col" TEXT DEFAULT 'last_updated_at', /* alloc::string::String */
"transformer" vectorize.Transformer DEFAULT 'openai', /* vectorize::types::Transformer */
"search_alg" vectorize.SimilarityAlg DEFAULT 'pgv_cosine_similarity', /* vectorize::types::SimilarityAlg */
"table_method" vectorize.TableMethod DEFAULT 'append', /* vectorize::init::TableMethod */
"table_method" vectorize.TableMethod DEFAULT 'append', /* vectorize::types::TableMethod */
"schedule" TEXT DEFAULT '* * * * *' /* alloc::string::String */
) RETURNS TEXT /* core::result::Result<alloc::string::String, anyhow::Error> */
STRICT
Expand Down
105 changes: 68 additions & 37 deletions src/api.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use crate::executor::VectorizeMeta;
use crate::guc;
use crate::guc::get_guc;
use crate::init;
use crate::search::cosine_similarity_search;
use crate::transformers::openai;
use crate::transformers::{
http_handler::openai_embedding_request, openai, openai::OPENAI_EMBEDDING_MODEL,
openai::OPENAI_EMBEDDING_URL, types::EmbeddingPayload, types::EmbeddingRequest,
};
use crate::types;
use crate::types::JobParams;
use crate::util;
Expand Down Expand Up @@ -55,7 +60,7 @@ fn table(
openai::validate_api_key(&openai_key)?;
}
// no-op
types::Transformer::allMiniLML12v2 => (),
types::Transformer::all_MiniLM_L12_v2 => (),
}

let valid_params = types::JobParams {
Expand Down Expand Up @@ -138,57 +143,83 @@ fn search(
return_columns: default!(Vec<String>, "ARRAY['*']::text[]"),
num_results: default!(i32, 10),
) -> Result<TableIterator<'static, (name!(search_results, pgrx::JsonB),)>, spi::Error> {
// note: this is not the most performant implementation
// this requires a query to metadata table to get the projects schema and table, which has a cost
// this does ensure consistency between the model used to generate the stored embeddings and the query embeddings, which is crucial

// get project metadata
let _project_meta = if let Some(js) = util::get_vectorize_meta_spi(job_name) {
let project_meta: VectorizeMeta = if let Ok(Some(js)) = util::get_vectorize_meta_spi(job_name) {
js
} else {
error!("failed to get project metadata");
};
let project_meta: JobParams =
serde_json::from_value(serde_json::to_value(_project_meta).unwrap_or_else(|e| {
let proj_params: JobParams = serde_json::from_value(
serde_json::to_value(project_meta.params).unwrap_or_else(|e| {
error!("failed to serialize metadata: {}", e);
}))
.unwrap_or_else(|e| error!("failed to serialize metadata: {}", e));
// assuming default openai API for now
}),
)
.unwrap_or_else(|e| error!("failed to deserialize metadata: {}", e));

// get embeddings
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_io()
.enable_time()
.build()
.unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));

let schema = project_meta.schema;
let table = project_meta.table;
let schema = proj_params.schema;
let table = proj_params.table;

let embedding_request = match project_meta.transformer {
types::Transformer::openai => {
let openai_key = match api_key {
Some(k) => k,
None => match guc::get_guc(guc::VectorizeGuc::OpenAIKey) {
Some(k) => k,
None => {
error!("failed to get API key from GUC");
}
},
};

let openai_key = match api_key {
Some(k) => k,
None => match guc::get_guc(guc::VectorizeGuc::OpenAIKey) {
Some(k) => k,
None => {
error!("failed to get API key from GUC");
let embedding_request = EmbeddingPayload {
input: vec![query.to_string()],
model: OPENAI_EMBEDDING_MODEL.to_string(),
};
EmbeddingRequest {
url: OPENAI_EMBEDDING_URL.to_owned(),
payload: embedding_request,
api_key: Some(openai_key),
}
}
types::Transformer::all_MiniLM_L12_v2 => {
let url: String = get_guc(guc::VectorizeGuc::EmbeddingServiceUrl)
.expect("failed to get embedding service url from GUC");
let embedding_request = EmbeddingPayload {
input: vec![query.to_string()],
model: project_meta.transformer.to_string(),
};
EmbeddingRequest {
url,
payload: embedding_request,
api_key: None,
}
},
}
};

let embeddings = match runtime
.block_on(async { openai::openai_embeddings(&vec![query.to_string()], &openai_key).await })
{
Ok(e) => e,
Err(e) => {
error!("error getting embeddings: {}", e);
}
let embeddings =
match runtime.block_on(async { openai_embedding_request(embedding_request).await }) {
Ok(e) => e,
Err(e) => {
error!("error getting embeddings: {}", e);
}
};

let search_results = match project_meta.search_alg {
types::SimilarityAlg::pgv_cosine_similarity => cosine_similarity_search(
job_name,
&schema,
&table,
&return_columns,
num_results,
&embeddings[0],
)?,
};
let search_results = cosine_similarity_search(
job_name,
&schema,
&table,
&return_columns,
num_results,
&embeddings[0],
)?;

Ok(TableIterator::new(search_results))
}
10 changes: 2 additions & 8 deletions src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::errors::DatabaseError;
use crate::guc::BATCH_SIZE;
use crate::init::QUEUE_MAPPING;
use crate::query::check_input;
use crate::transformers::types::Inputs;
use crate::types;
use crate::util::{from_env_default, get_pg_conn};
use chrono::serde::ts_seconds_option::deserialize as from_tsopt;
Expand All @@ -17,7 +18,7 @@ use tiktoken_rs::cl100k_base;

// schema for every job
// also schema for the vectorize.vectorize_meta table
#[derive(Clone, Debug, Deserialize, FromRow, Serialize, PostgresType)]
#[derive(Clone, Debug, Deserialize, FromRow, Serialize)]
pub struct VectorizeMeta {
pub job_id: i64,
pub name: String,
Expand Down Expand Up @@ -148,13 +149,6 @@ pub async fn get_vectorize_meta(
Ok(row)
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Inputs {
pub record_id: String, // the value to join the record
pub inputs: String, // concatenation of input columns
pub token_estimate: i32, // estimated token count
}

// queries a table and returns rows that need new embeddings
// used for the TableMethod::append, which has source and embedding on the same table
pub async fn get_new_updates_append(
Expand Down
11 changes: 11 additions & 0 deletions src/guc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use anyhow::Result;
pub static VECTORIZE_HOST: GucSetting<Option<&CStr>> = GucSetting::<Option<&CStr>>::new(None);
pub static OPENAI_KEY: GucSetting<Option<&CStr>> = GucSetting::<Option<&CStr>>::new(None);
pub static BATCH_SIZE: GucSetting<i32> = GucSetting::<i32>::new(10000);
pub static EMBEDDING_SERVICE_HOST: GucSetting<Option<&CStr>> =
GucSetting::<Option<&CStr>>::new(None);

// initialize GUCs
pub fn init_guc() {
Expand Down Expand Up @@ -35,20 +37,29 @@ pub fn init_guc() {
GucContext::Suset,
GucFlags::default(),
);

GucRegistry::define_string_guc(
"vectorize.embedding_service_url",
"Url for an OpenAI compatible embedding service",
"Url to a service with request and response schema consistent with OpenAI's embeddings API.",
&EMBEDDING_SERVICE_HOST,
GucContext::Suset, GucFlags::default());
}

// for handling of GUCs that can be error prone
#[derive(Debug)]
pub enum VectorizeGuc {
Host,
OpenAIKey,
EmbeddingServiceUrl,
}

/// a convenience function to get this project's GUCs
pub fn get_guc(guc: VectorizeGuc) -> Option<String> {
let val = match guc {
VectorizeGuc::Host => VECTORIZE_HOST.get(),
VectorizeGuc::OpenAIKey => OPENAI_KEY.get(),
VectorizeGuc::EmbeddingServiceUrl => EMBEDDING_SERVICE_HOST.get(),
};
if let Some(cstr) = val {
if let Ok(s) = handle_cstr(cstr) {
Expand Down
4 changes: 2 additions & 2 deletions src/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ lazy_static! {
pub static ref QUEUE_MAPPING: HashMap<Transformer, &'static str> = {
let mut m = HashMap::new();
m.insert(Transformer::openai, "v_openai");
m.insert(Transformer::allMiniLML12v2, "v_all_MiniLM_L12_v2");
m.insert(Transformer::all_MiniLM_L12_v2, "v_all_MiniLM_L12_v2");
m
};
}
Expand Down Expand Up @@ -76,7 +76,7 @@ pub fn init_embedding_table_query(
// currently only supports the text-embedding-ada-002 embedding model - output dim 1536
// https://platform.openai.com/docs/guides/embeddings/what-are-embeddings
(types::Transformer::openai, types::SimilarityAlg::pgv_cosine_similarity) => "vector(1536)",
(types::Transformer::allMiniLML12v2, types::SimilarityAlg::pgv_cosine_similarity) => {
(types::Transformer::all_MiniLM_L12_v2, types::SimilarityAlg::pgv_cosine_similarity) => {
"vector(384)"
}
};
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,6 @@ pub mod pg_test {

pub fn postgresql_conf_options() -> Vec<&'static str> {
// return any postgresql.conf settings that are required for your tests
vec![]
vec!["cron.database_name='vectorize_test'"]
}
}
29 changes: 29 additions & 0 deletions src/transformers/generic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use anyhow::{Context, Result};

use crate::{
executor::VectorizeMeta,
guc,
transformers::types::{EmbeddingPayload, EmbeddingRequest, Inputs},
};

use super::openai::trim_inputs;

/// Builds an [`EmbeddingRequest`] for a generic, OpenAI-compatible embedding
/// service, using the job's configured transformer as the model name.
///
/// The target URL is read from the `vectorize.embedding_service_url` GUC.
///
/// # Errors
/// Returns an error if the `vectorize.embedding_service_url` GUC is not set.
pub fn prepare_generic_embedding_request(
    job_meta: VectorizeMeta,
    inputs: &[Inputs],
) -> Result<EmbeddingRequest> {
    // Trim inputs to fit the model's token budget, same as the OpenAI path.
    let text_inputs = trim_inputs(inputs);
    let payload = EmbeddingPayload {
        input: text_inputs,
        model: job_meta.transformer.to_string(),
    };

    // Message previously said "embedding_Service_url", which does not match
    // the GUC name registered in guc.rs ("vectorize.embedding_service_url").
    let svc_host = guc::get_guc(guc::VectorizeGuc::EmbeddingServiceUrl)
        .context("vectorize.embedding_service_url is not set")?;

    Ok(EmbeddingRequest {
        url: svc_host,
        payload,
        // Generic/self-hosted backends take no API key.
        api_key: None,
    })
}
61 changes: 61 additions & 0 deletions src/transformers/http_handler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use anyhow::Result;

use crate::transformers::types::{
EmbeddingPayload, EmbeddingRequest, EmbeddingResponse, Inputs, PairedEmbeddings,
};
use pgrx::prelude::*;

/// Deserializes a successful HTTP response body into `T`.
///
/// On a non-2xx status, the body is consumed for diagnostics, a warning is
/// logged, and an error naming the failed `method` is returned.
pub async fn handle_response<T: for<'de> serde::Deserialize<'de>>(
    resp: reqwest::Response,
    method: &'static str,
) -> Result<T> {
    let status = resp.status();
    if status.is_success() {
        // 2xx: parse the JSON body directly into the caller's type.
        return Ok(resp.json::<T>().await?);
    }
    // Failure path: pull the body text so the log/error carries the details.
    let body = resp.text().await?;
    let errmsg = format!(
        "Failed to call method '{}', received response with status code:{} and body: {}",
        method, status, body
    );
    warning!("pg-vectorize: error handling response: {}", errmsg);
    Err(anyhow::anyhow!(errmsg))
}

/// Sends an OpenAI-compatible embedding request and returns one embedding
/// vector per input, in request order.
///
/// When `request.api_key` is `Some`, it is sent as a `Bearer` token; generic
/// self-hosted backends omit it.
///
/// # Errors
/// Returns an error if the HTTP call fails or the response is not a 2xx with
/// a body deserializable as [`EmbeddingResponse`].
pub async fn openai_embedding_request(request: EmbeddingRequest) -> Result<Vec<Vec<f64>>> {
    log!(
        "pg-vectorize: openai request size: {}",
        request.payload.input.len()
    );
    let client = reqwest::Client::new();
    let mut req = client
        .post(request.url)
        .json::<EmbeddingPayload>(&request.payload)
        .header("Content-Type", "application/json");
    // Bearer auth is optional: only the OpenAI path supplies a key.
    if let Some(key) = request.api_key {
        req = req.header("Authorization", format!("Bearer {}", key));
    }
    let resp = req.send().await?;
    let embedding_resp = handle_response::<EmbeddingResponse>(resp, "embeddings").await?;

    // Move the embedding vectors out of the owned response instead of
    // cloning each one; the response is not used afterwards.
    let embeddings = embedding_resp
        .data
        .into_iter()
        .map(|d| d.embedding)
        .collect();
    Ok(embeddings)
}

/// Pairs each input's record id with its embedding vector, positionally.
/// If the lengths differ, the extra elements of the longer vec are dropped
/// (standard `zip` truncation).
pub fn merge_input_output(inputs: Vec<Inputs>, values: Vec<Vec<f64>>) -> Vec<PairedEmbeddings> {
    // Preallocate to the exact output length to avoid regrowth.
    let mut paired = Vec::with_capacity(inputs.len().min(values.len()));
    for (input, embedding) in inputs.into_iter().zip(values) {
        paired.push(PairedEmbeddings {
            primary_key: input.record_id,
            embeddings: embedding,
        });
    }
    paired
}
4 changes: 4 additions & 0 deletions src/transformers/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
// Transformer backends and shared embedding-request plumbing.
pub mod generic; // adapter for OpenAI-compatible generic embedding services
pub mod http_handler; // shared HTTP request/response handling for embeddings
pub mod openai; // OpenAI embeddings API backend
pub mod tembo; // Tembo-specific backend (contents not shown in this diff)
pub mod types; // shared request/payload/response/input types
Loading

0 comments on commit 650c56f

Please sign in to comment.