diff --git a/src/api.rs b/src/api.rs
index 2def1aa..39bdf7a 100644
--- a/src/api.rs
+++ b/src/api.rs
@@ -2,6 +2,7 @@ use crate::executor::VectorizeMeta;
 use crate::guc;
 use crate::init;
 use crate::search::cosine_similarity_search;
+use crate::transformers::http_handler::sync_get_model_info;
 use crate::transformers::{openai, transform};
 use crate::types;
 use crate::types::JobParams;
@@ -38,10 +39,10 @@ fn table(
     // get prim key type
     let pkey_type = init::get_column_datatype(&schema, table, &primary_key);
+    init::init_pgmq()?;
 
     // certain embedding services require an API key, e.g. openAI
     // key can be set in a GUC, so if its required but not provided in args, and not in GUC, error
-    init::init_pgmq()?;
     match transformer.as_ref() {
         "text-embedding-ada-002" => {
             let openai_key = match api_key {
@@ -56,7 +57,10 @@ fn table(
             openai::validate_api_key(&openai_key)?;
         }
         // todo: make sure model exists
-        _ => panic!("check"),
+        t => {
+            // TODO: parse svc_url so that we can send GET to /info endpoint here, and in table create
+            let _ = sync_get_model_info(t).expect("transformer does not exist");
+        }
     }
 
     let valid_params = types::JobParams {
diff --git a/src/init.rs b/src/init.rs
index 7f813aa..bf1bb37 100644
--- a/src/init.rs
+++ b/src/init.rs
@@ -1,5 +1,4 @@
 use crate::{
-    guc::{self, VectorizeGuc},
     query::check_input,
     transformers::{http_handler::sync_get_model_info, types::TransformerMetadata},
     types,
@@ -18,8 +17,10 @@ pub fn init_pgmq() -> Result<()> {
         ))?
         .context("error checking if queue exists")?;
     if queue_exists {
+        info!("queue already exists");
         return Ok(());
     } else {
+        info!("creating queue;");
         let ran: Result<_, spi::Error> = Spi::connect(|mut c| {
             let _r = c.update(
                 &format!("SELECT pgmq.create('{VECTORIZE_QUEUE}');"),
@@ -80,11 +81,7 @@ pub fn init_embedding_table_query(
         // for anything but OpenAI, first call info endpoint to get the embedding dim of the model
         "text-embedding-ada-002" => "vector(1536)".to_owned(),
         _ => {
-            log!("getting model info");
-            // all-MiniLM-L12-v2
-            let svc_url = guc::get_guc(VectorizeGuc::EmbeddingServiceUrl)
-                .expect("vectorize.embedding_service_url must be set to a valid service");
-            let model_info: TransformerMetadata = sync_get_model_info(transformer, &svc_url)
+            let model_info: TransformerMetadata = sync_get_model_info(transformer)
                 .expect("failed to call vectorize.embedding_service_url");
             let dim = model_info.embedding_dimension;
             format!("vector({dim})")
diff --git a/src/search.rs b/src/search.rs
index c50af52..2a0175f 100644
--- a/src/search.rs
+++ b/src/search.rs
@@ -8,13 +8,12 @@ pub fn cosine_similarity_search(
     num_results: i32,
     embeddings: &[f64],
 ) -> Result<Vec<(pgrx::JsonB,)>, spi::Error> {
-    let emb = serde_json::to_string(&embeddings).expect("failed to serialize embeddings");
     let query = format!(
         "
     SELECT to_jsonb(t) as results
     FROM (
         SELECT
-        1 - ({project}_embeddings <=> '{emb}'::vector) AS similarity_score,
+        1 - ({project}_embeddings <=> $1::vector) AS similarity_score,
         {cols}
     FROM {schema}.{table}
     WHERE {project}_updated_at is NOT NULL
@@ -24,10 +23,16 @@ pub fn cosine_similarity_search(
     ",
         cols = return_columns.join(", "),
     );
-    log!("query: {}", query);
     Spi::connect(|client| {
         let mut results: Vec<(pgrx::JsonB,)> = Vec::new();
-        let tup_table = client.select(&query, None, None)?;
+        let tup_table = client.select(
+            &query,
+            None,
+            Some(vec![(
+                PgBuiltInOids::FLOAT8ARRAYOID.oid(),
+                embeddings.into_datum(),
+            )]),
+        )?;
         for row in tup_table {
             match row["results"].value()? {
                 Some(r) => results.push((r,)),
diff --git a/src/transformers/http_handler.rs b/src/transformers/http_handler.rs
index 77fff1a..b61f8bb 100644
--- a/src/transformers/http_handler.rs
+++ b/src/transformers/http_handler.rs
@@ -1,5 +1,6 @@
 use anyhow::Result;
 
+use crate::guc;
 use crate::transformers::types::{
     EmbeddingPayload, EmbeddingRequest, EmbeddingResponse, Inputs, PairedEmbeddings,
 };
@@ -63,18 +64,18 @@ pub fn merge_input_output(inputs: Vec<Inputs>, values: Vec<Vec<f64>>) -> Vec<PairedEmbeddings> {
 }
 
-pub fn mod_info(model_name: &str, url: &str) -> pgrx::JsonB {
-    let meta = sync_get_model_info(model_name, url).unwrap();
+pub fn mod_info(model_name: &str) -> pgrx::JsonB {
+    let meta = sync_get_model_info(model_name).unwrap();
     pgrx::JsonB(serde_json::to_value(meta).unwrap())
 }
 
-pub fn sync_get_model_info(model_name: &str, url: &str) -> Result<TransformerMetadata> {
+pub fn sync_get_model_info(model_name: &str) -> Result<TransformerMetadata> {
     let runtime = tokio::runtime::Builder::new_current_thread()
         .enable_io()
         .enable_time()
         .build()
         .unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));
 
-    let meta = match runtime.block_on(async { get_model_info(model_name, url).await }) {
+    let meta = match runtime.block_on(async { get_model_info(model_name).await }) {
         Ok(e) => e,
         Err(e) => {
             error!("error getting embeddings: {}", e);
@@ -83,9 +84,12 @@ pub fn sync_get_model_info(model_name: &str, url: &str) -> Result<TransformerMetadata> {
     Ok(meta)
 }
 
-pub async fn get_model_info(model_name: &str, url: &str) -> Result<TransformerMetadata> {
+pub async fn get_model_info(model_name: &str) -> Result<TransformerMetadata> {
+    let svc_url = guc::get_guc(guc::VectorizeGuc::EmbeddingServiceUrl)
+        .expect("vectorize.embedding_service_url must be set to a valid service");
+    let info_url = svc_url.replace("/embeddings", "/info");
     let client = reqwest::Client::new();
-    let req = client.get(url).query(&[("model_name", model_name)]);
+    let req = client.get(info_url).query(&[("model_name", model_name)]);
     let resp = req.send().await?;
     let meta_response = handle_response::<TransformerMetadata>(resp, "info").await?;
     Ok(meta_response)
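
Note (outside the diff): a minimal sketch of how the reworked get_model_info() derives its /info endpoint from the vectorize.embedding_service_url GUC. The service URL and model name below are illustrative assumptions for this example only; the .replace("/embeddings", "/info") derivation and the model_name query parameter come from the change above.

fn main() {
    // Assumed example setting, e.g.:
    //   SET vectorize.embedding_service_url = 'http://localhost:3000/v1/embeddings';
    let svc_url = "http://localhost:3000/v1/embeddings";

    // Same derivation as get_model_info(): swap the /embeddings route for /info.
    let info_url = svc_url.replace("/embeddings", "/info");
    assert_eq!(info_url, "http://localhost:3000/v1/info");

    // The handler then issues a GET with the transformer as a query parameter,
    // e.g. for the all-MiniLM-L12-v2 model mentioned in the removed comment:
    println!("GET {}?model_name=all-MiniLM-L12-v2", info_url);
}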