pass through transformer name
ChuckHend committed Jan 24, 2024
1 parent 0178728 commit 3c23663
Showing 10 changed files with 102 additions and 128 deletions.
8 changes: 5 additions & 3 deletions Cargo.toml
@@ -9,16 +9,18 @@ crate-type = ["cdylib"]

[features]
default = ["pg15"]
+pg14 = ["pgrx/pg14", "pgrx-tests/pg14"]
pg15 = ["pgrx/pg15", "pgrx-tests/pg15"]
+pg16 = ["pgrx/pg16", "pgrx-tests/pg16"]
pg_test = []

[dependencies]
anyhow = "1.0.72"
chrono = {version = "0.4.26", features = ["serde"] }
lazy_static = "1.4.0"
log = "0.4.19"
pgmq = "0.24.0"
pgrx = "0.11.0"
pgmq = "0.26.0"
pgrx = "0.11.2"
postgres-types = "0.2.5"
regex = "1.9.2"
reqwest = {version = "0.11.18", features = ["json"] }
@@ -35,7 +37,7 @@ tokio = {version = "1.29.1", features = ["rt-multi-thread"] }
url = "2.4.0"

[dev-dependencies]
-pgrx-tests = "0.11.0"
+pgrx-tests = "0.11.2"
rand = "0.8.5"
whoami = "1.4.1"

28 changes: 11 additions & 17 deletions src/api.rs
@@ -19,7 +19,7 @@ fn table(
args: default!(pgrx::Json, "'{}'"),
schema: default!(String, "'public'"),
update_col: default!(String, "'last_updated_at'"),
-transformer: default!(types::Transformer, "'text_embedding_ada_002'"),
+transformer: default!(String, "'text_embedding_ada_002'"),
search_alg: default!(types::SimilarityAlg, "'pgv_cosine_similarity'"),
table_method: default!(types::TableMethod, "'append'"),
schedule: default!(String, "'* * * * *'"),
@@ -41,9 +41,9 @@ fn table(

// certain embedding services require an API key, e.g. openAI
// key can be set in a GUC, so if its required but not provided in args, and not in GUC, error
-init::init_pgmq(&transformer)?;
-match transformer {
-types::Transformer::text_embedding_ada_002 => {
+init::init_pgmq()?;
+match transformer.as_ref() {
+"text-embedding-ada-002" => {
let openai_key = match api_key {
Some(k) => serde_json::from_value::<String>(k.clone())?,
None => match guc::get_guc(guc::VectorizeGuc::OpenAIKey) {
@@ -55,8 +55,8 @@ fn table(
};
openai::validate_api_key(&openai_key)?;
}
-// no-op
-types::Transformer::all_MiniLM_L12_v2 => (),
+// todo: make sure model exists
+_ => panic!("check"),
}

let valid_params = types::JobParams {
@@ -105,14 +105,8 @@ fn table(
if ran.is_err() {
error!("error creating job");
}
-let init_embed_q = init::init_embedding_table_query(
-&job_name,
-&schema,
-table,
-&transformer,
-&search_alg,
-&table_method,
-);
+let init_embed_q =
+init::init_embedding_table_query(&job_name, &schema, table, &transformer, &table_method);

let ran: Result<_, spi::Error> = Spi::connect(|mut c| {
for q in init_embed_q {
@@ -152,7 +146,7 @@ fn search(
let schema = proj_params.schema;
let table = proj_params.table;

-let embeddings = transform(query, project_meta.transformer, api_key);
+let embeddings = transform(query, &project_meta.transformer, api_key);

let search_results = match project_meta.search_alg {
types::SimilarityAlg::pgv_cosine_similarity => cosine_similarity_search(
@@ -171,8 +165,8 @@
#[pg_extern]
fn transform_embeddings(
input: &str,
-model_name: default!(types::Transformer, "'text_embedding_ada_002'"),
+model_name: default!(String, "'text_embedding_ada_002'"),
api_key: default!(Option<String>, "NULL"),
) -> Result<Vec<f64>, spi::Error> {
-Ok(transform(input, model_name, api_key).remove(0))
+Ok(transform(input, &model_name, api_key).remove(0))
}
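
Because transform_embeddings is exposed via #[pg_extern], the string-typed model_name can be passed straight from SQL once this change lands. A minimal sketch, assuming the extension's functions are installed in the vectorize schema and an OpenAI key is configured for the default model:

SELECT vectorize.transform_embeddings(
    input => 'the quick brown fox',
    model_name => 'text-embedding-ada-002'
);
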
9 changes: 3 additions & 6 deletions src/executor.rs
@@ -2,7 +2,7 @@ use pgrx::prelude::*;

use crate::errors::DatabaseError;
use crate::guc::BATCH_SIZE;
-use crate::init::QUEUE_MAPPING;
+use crate::init::VECTORIZE_QUEUE;
use crate::query::check_input;
use crate::transformers::types::Inputs;
use crate::types;
@@ -23,7 +23,7 @@ pub struct VectorizeMeta {
pub job_id: i64,
pub name: String,
pub job_type: types::JobType,
-pub transformer: types::Transformer,
+pub transformer: String,
pub search_alg: types::SimilarityAlg,
pub params: serde_json::Value,
#[serde(deserialize_with = "from_tsopt")]
@@ -112,11 +112,8 @@ fn job_execute(job_name: String) {
job_meta: meta.clone(),
inputs: b,
};
-let queue_name = QUEUE_MAPPING
-.get(&meta.transformer)
-.expect("invalid transformer");
let msg_id = queue
-.send(queue_name, &msg)
+.send(VECTORIZE_QUEUE, &msg)
.await
.unwrap_or_else(|e| error!("failed to send message updates: {}", e));
log!("message sent: {}", msg_id);
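
With the per-model queues replaced by the single VECTORIZE_QUEUE ("vectorize_jobs"), jobs for every transformer land in one pgmq queue. A quick way to peek at pending work, assuming pgmq's standard read function (note the second argument is a visibility timeout applied to the messages that are read):

SELECT * FROM pgmq.read('vectorize_jobs', 30, 5);
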
69 changes: 31 additions & 38 deletions src/init.rs
@@ -1,33 +1,31 @@
-use crate::{query::check_input, types, types::TableMethod, types::Transformer};
+use crate::{
+guc::{self, VectorizeGuc},
+query::check_input,
+transformers::{http_handler::sync_get_model_info, types::TransformerMetadata},
+types,
+types::TableMethod,
+};
use pgrx::prelude::*;
-use std::collections::HashMap;

use anyhow::{Context, Result};
-use lazy_static::lazy_static;

-lazy_static! {
-// each model has its own job queue
-// maintain the mapping of transformer to queue name here
-pub static ref QUEUE_MAPPING: HashMap<Transformer, &'static str> = {
-let mut m = HashMap::new();
-m.insert(Transformer::text_embedding_ada_002, "v_openai");
-m.insert(Transformer::all_MiniLM_L12_v2, "v_all_MiniLM_L12_v2");
-m
-};
-}
+pub static VECTORIZE_QUEUE: &str = "vectorize_jobs";

-pub fn init_pgmq(transformer: &Transformer) -> Result<()> {
-let qname = QUEUE_MAPPING.get(transformer).expect("invalid transformer");
+pub fn init_pgmq() -> Result<()> {
// check if queue already created:
let queue_exists: bool = Spi::get_one(&format!(
"SELECT EXISTS (SELECT 1 FROM pgmq.meta WHERE queue_name = '{qname}');",
"SELECT EXISTS (SELECT 1 FROM pgmq.meta WHERE queue_name = '{VECTORIZE_QUEUE}');",
))?
.context("error checking if queue exists")?;
if queue_exists {
return Ok(());
} else {
let ran: Result<_, spi::Error> = Spi::connect(|mut c| {
-let _r = c.update(&format!("SELECT pgmq.create('{qname}');"), None, None)?;
+let _r = c.update(
+&format!("SELECT pgmq.create('{VECTORIZE_QUEUE}');"),
+None,
+None,
+)?;
Ok(())
});
if let Err(e) = ran {
@@ -69,38 +67,38 @@ pub fn init_embedding_table_query(
job_name: &str,
schema: &str,
table: &str,
-transformer: &types::Transformer,
-search_alg: &types::SimilarityAlg,
+transformer: &str,
transform_method: &TableMethod,
) -> Vec<String> {
-// TODO: when adding support for other models, add the output dimension to the transformer attributes
-// so that they can be read here, not hard-coded here below
-// currently only supports the text-embedding-ada-002 embedding model - output dim 1536
-// https://platform.openai.com/docs/guides/embeddings/what-are-embeddings

check_input(job_name).expect("invalid job name");
-let col_type = match (transformer, search_alg) {
+let col_type = match transformer {
// TODO: when adding support for other models, add the output dimension to the transformer attributes
// so that they can be read here, not hard-coded here below
// currently only supports the text-embedding-ada-002 embedding model - output dim 1536
// https://platform.openai.com/docs/guides/embeddings/what-are-embeddings
-(
-types::Transformer::text_embedding_ada_002,
-types::SimilarityAlg::pgv_cosine_similarity,
-) => "vector(1536)",
-(types::Transformer::all_MiniLM_L12_v2, types::SimilarityAlg::pgv_cosine_similarity) => {
-"vector(384)"

+// for anything but OpenAI, first call info endpoint to get the embedding dim of the model
+"text-embedding-ada-002" => "vector(1536)".to_owned(),
+_ => {
+log!("getting model info");
+// all-MiniLM-L12-v2
+let svc_url = guc::get_guc(VectorizeGuc::EmbeddingServiceUrl)
+.expect("vectorize.embedding_service_url must be set to a valid service");
+let model_info: TransformerMetadata = sync_get_model_info(transformer, &svc_url)
+.expect("failed to call vectorize.embedding_service_url");
+let dim = model_info.embedding_dimension;
+format!("vector({dim})")
}
};
match transform_method {
TableMethod::append => {
vec![
-append_embedding_column(job_name, schema, table, col_type),
+append_embedding_column(job_name, schema, table, &col_type),
create_hnsw_cosine_index(job_name, schema, table),
]
}
TableMethod::join => {
-vec![create_embedding_table(job_name, col_type)]
+vec![create_embedding_table(job_name, &col_type)]
}
}
}
@@ -125,11 +123,6 @@ fn create_hnsw_cosine_index(job_name: &str, schema: &str, table: &str) -> String
}

fn append_embedding_column(job_name: &str, schema: &str, table: &str, col_type: &str) -> String {
-// TODO: when adding support for other models, add the output dimension to the transformer attributes
-// so that they can be read here, not hard-coded here below
-// currently only supports the text-embedding-ada-002 embedding model - output dim 1536
-// https://platform.openai.com/docs/guides/embeddings/what-are-embeddings

check_input(job_name).expect("invalid job name");
format!(
"
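
Since the non-OpenAI branch now asks the embedding service for the model's dimension at job-creation time, vectorize.embedding_service_url has to point at a reachable service before the table is configured. An illustrative setup, with a placeholder URL for a locally hosted model server:

ALTER SYSTEM SET vectorize.embedding_service_url TO 'http://localhost:3000';
SELECT pg_reload_conf();
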
31 changes: 31 additions & 0 deletions src/transformers/http_handler.rs
@@ -5,6 +5,8 @@ use crate::transformers::types::{
};
use pgrx::prelude::*;

+use super::types::TransformerMetadata;

pub async fn handle_response<T: for<'de> serde::Deserialize<'de>>(
resp: reqwest::Response,
method: &'static str,
@@ -59,3 +61,32 @@ pub fn merge_input_output(inputs: Vec<Inputs>, values: Vec<Vec<f64>>) -> Vec<Pai
})
.collect()
}

+#[pg_extern]
+pub fn mod_info(model_name: &str, url: &str) -> pgrx::JsonB {
+let meta = sync_get_model_info(model_name, url).unwrap();
+pgrx::JsonB(serde_json::to_value(meta).unwrap())
+}

+pub fn sync_get_model_info(model_name: &str, url: &str) -> Result<TransformerMetadata> {
+let runtime = tokio::runtime::Builder::new_current_thread()
+.enable_io()
+.enable_time()
+.build()
+.unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));
+let meta = match runtime.block_on(async { get_model_info(model_name, url).await }) {
+Ok(e) => e,
+Err(e) => {
+error!("error getting embeddings: {}", e);
+}
+};
+Ok(meta)
+}

+pub async fn get_model_info(model_name: &str, url: &str) -> Result<TransformerMetadata> {
+let client = reqwest::Client::new();
+let req = client.get(url).query(&[("model_name", model_name)]);
+let resp = req.send().await?;
+let meta_response = handle_response::<TransformerMetadata>(resp, "info").await?;
+Ok(meta_response)
+}
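
Because mod_info is also a #[pg_extern], the metadata that init_embedding_table_query depends on can be checked by hand. A hypothetical call, assuming the function is exposed in the vectorize schema and a model server answers at the URL shown; it returns jsonb carrying the model, max_seq_len, and embedding_dimension fields:

SELECT vectorize.mod_info(
    model_name => 'all-MiniLM-L12-v2',
    url => 'http://localhost:3000'
);
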
7 changes: 3 additions & 4 deletions src/transformers/mod.rs
@@ -5,22 +5,21 @@ pub mod tembo;
pub mod types;

use crate::guc;
-use crate::types::Transformer;
use generic::get_generic_svc_url;
use http_handler::openai_embedding_request;
use openai::{OPENAI_EMBEDDING_MODEL, OPENAI_EMBEDDING_URL};
use pgrx::prelude::*;
use types::{EmbeddingPayload, EmbeddingRequest};

-pub fn transform(input: &str, transformer: Transformer, api_key: Option<String>) -> Vec<Vec<f64>> {
+pub fn transform(input: &str, transformer: &str, api_key: Option<String>) -> Vec<Vec<f64>> {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_io()
.enable_time()
.build()
.unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));

let embedding_request = match transformer {
-Transformer::text_embedding_ada_002 => {
+"text-embedding-ada-002" => {
let openai_key = match api_key {
Some(k) => k,
None => match guc::get_guc(guc::VectorizeGuc::OpenAIKey) {
@@ -41,7 +40,7 @@ pub fn transform(input: &str, transformer: Transformer, api_key: Option<String>)
api_key: Some(openai_key),
}
}
-Transformer::all_MiniLM_L12_v2 => {
+_ => {
let url = get_generic_svc_url().expect("failed to get embedding service url from GUC");
let embedding_request = EmbeddingPayload {
input: vec![input.to_string()],
7 changes: 7 additions & 0 deletions src/transformers/types.rs
@@ -37,3 +37,10 @@ pub struct PairedEmbeddings {
pub primary_key: String,
pub embeddings: Vec<f64>,
}

+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct TransformerMetadata {
+pub model: String,
+pub max_seq_len: i32,
+pub embedding_dimension: i32,
+}
37 changes: 0 additions & 37 deletions src/types.rs
@@ -5,43 +5,6 @@ use std::str::FromStr;

pub const VECTORIZE_SCHEMA: &str = "vectorize";

-#[allow(non_camel_case_types)]
-#[derive(Clone, Debug, Serialize, Deserialize, Eq, Hash, PartialEq, PostgresEnum)]
-pub enum Transformer {
-text_embedding_ada_002,
-all_MiniLM_L12_v2,
-}

-impl FromStr for Transformer {
-type Err = String;

-fn from_str(s: &str) -> Result<Self, Self::Err> {
-match s {
-"text_embedding_ada_002" => Ok(Transformer::text_embedding_ada_002),
-"all_MiniLM_L12_v2" => Ok(Transformer::all_MiniLM_L12_v2),
-_ => Err(format!("Invalid value: {}", s)),
-}
-}
-}

-impl From<String> for Transformer {
-fn from(s: String) -> Self {
-match s.as_str() {
-"text_embedding_ada_002" => Transformer::text_embedding_ada_002,
-"all_MiniLM_L12_v2" => Transformer::all_MiniLM_L12_v2,
-_ => panic!("Invalid value for Transformer: {}", s), // or handle this case differently
-}
-}
-}

-impl Display for Transformer {
-fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
-match self {
-Transformer::text_embedding_ada_002 => write!(f, "text_embedding_ada_002"),
-Transformer::all_MiniLM_L12_v2 => write!(f, "all_MiniLM_L12_v2"),
-}
-}
-}

#[allow(non_camel_case_types)]
#[derive(Clone, Debug, Serialize, Deserialize, PostgresEnum)]