Skip to content

Commit

Permalink
any SentenceTransformer
Browse files Browse the repository at this point in the history
  • Loading branch information
ChuckHend committed Jan 24, 2024
1 parent a5fbc6c commit 493d147
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 18 deletions.
8 changes: 6 additions & 2 deletions src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::executor::VectorizeMeta;
use crate::guc;
use crate::init;
use crate::search::cosine_similarity_search;
use crate::transformers::http_handler::sync_get_model_info;
use crate::transformers::{openai, transform};
use crate::types;
use crate::types::JobParams;
Expand Down Expand Up @@ -38,10 +39,10 @@ fn table(

// get prim key type
let pkey_type = init::get_column_datatype(&schema, table, &primary_key);
init::init_pgmq()?;

// certain embedding services require an API key, e.g. openAI
// key can be set in a GUC, so if its required but not provided in args, and not in GUC, error
init::init_pgmq()?;
match transformer.as_ref() {
"text-embedding-ada-002" => {
let openai_key = match api_key {
Expand All @@ -56,7 +57,10 @@ fn table(
openai::validate_api_key(&openai_key)?;
}
// todo: make sure model exists
_ => panic!("check"),
t => {
// TODO: parse svc_url so that we can send GET to /info endpoint here, and in table create
let _ = sync_get_model_info(t).expect("transformer does not exist");
}
}

let valid_params = types::JobParams {
Expand Down
9 changes: 3 additions & 6 deletions src/init.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::{
guc::{self, VectorizeGuc},
query::check_input,
transformers::{http_handler::sync_get_model_info, types::TransformerMetadata},
types,
Expand All @@ -18,8 +17,10 @@ pub fn init_pgmq() -> Result<()> {
))?
.context("error checking if queue exists")?;
if queue_exists {
info!("queue already exists");
return Ok(());
} else {
info!("creating queue;");
let ran: Result<_, spi::Error> = Spi::connect(|mut c| {
let _r = c.update(
&format!("SELECT pgmq.create('{VECTORIZE_QUEUE}');"),
Expand Down Expand Up @@ -80,11 +81,7 @@ pub fn init_embedding_table_query(
// for anything but OpenAI, first call info endpoint to get the embedding dim of the model
"text-embedding-ada-002" => "vector(1536)".to_owned(),
_ => {
log!("getting model info");
// all-MiniLM-L12-v2
let svc_url = guc::get_guc(VectorizeGuc::EmbeddingServiceUrl)
.expect("vectorize.embedding_service_url must be set to a valid service");
let model_info: TransformerMetadata = sync_get_model_info(transformer, &svc_url)
let model_info: TransformerMetadata = sync_get_model_info(transformer)
.expect("failed to call vectorize.embedding_service_url");
let dim = model_info.embedding_dimension;
format!("vector({dim})")
Expand Down
13 changes: 9 additions & 4 deletions src/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@ pub fn cosine_similarity_search(
num_results: i32,
embeddings: &[f64],
) -> Result<Vec<(pgrx::JsonB,)>, spi::Error> {
let emb = serde_json::to_string(&embeddings).expect("failed to serialize embeddings");
let query = format!(
"
SELECT to_jsonb(t)
as results FROM (
SELECT
1 - ({project}_embeddings <=> '{emb}'::vector) AS similarity_score,
1 - ({project}_embeddings <=> $1::vector) AS similarity_score,
{cols}
FROM {schema}.{table}
WHERE {project}_updated_at is NOT NULL
Expand All @@ -24,10 +23,16 @@ pub fn cosine_similarity_search(
",
cols = return_columns.join(", "),
);
log!("query: {}", query);
Spi::connect(|client| {
let mut results: Vec<(pgrx::JsonB,)> = Vec::new();
let tup_table = client.select(&query, None, None)?;
let tup_table = client.select(
&query,
None,
Some(vec![(
PgBuiltInOids::FLOAT8ARRAYOID.oid(),
embeddings.into_datum(),
)]),
)?;
for row in tup_table {
match row["results"].value()? {
Some(r) => results.push((r,)),
Expand Down
16 changes: 10 additions & 6 deletions src/transformers/http_handler.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use anyhow::Result;

use crate::guc;
use crate::transformers::types::{
EmbeddingPayload, EmbeddingRequest, EmbeddingResponse, Inputs, PairedEmbeddings,
};
Expand Down Expand Up @@ -63,18 +64,18 @@ pub fn merge_input_output(inputs: Vec<Inputs>, values: Vec<Vec<f64>>) -> Vec<Pai
}

#[pg_extern]
pub fn mod_info(model_name: &str, url: &str) -> pgrx::JsonB {
let meta = sync_get_model_info(model_name, url).unwrap();
pub fn mod_info(model_name: &str) -> pgrx::JsonB {
let meta = sync_get_model_info(model_name).unwrap();
pgrx::JsonB(serde_json::to_value(meta).unwrap())
}

pub fn sync_get_model_info(model_name: &str, url: &str) -> Result<TransformerMetadata> {
pub fn sync_get_model_info(model_name: &str) -> Result<TransformerMetadata> {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_io()
.enable_time()
.build()
.unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));
let meta = match runtime.block_on(async { get_model_info(model_name, url).await }) {
let meta = match runtime.block_on(async { get_model_info(model_name).await }) {
Ok(e) => e,
Err(e) => {
error!("error getting embeddings: {}", e);
Expand All @@ -83,9 +84,12 @@ pub fn sync_get_model_info(model_name: &str, url: &str) -> Result<TransformerMet
Ok(meta)
}

pub async fn get_model_info(model_name: &str, url: &str) -> Result<TransformerMetadata> {
pub async fn get_model_info(model_name: &str) -> Result<TransformerMetadata> {
let svc_url = guc::get_guc(guc::VectorizeGuc::EmbeddingServiceUrl)
.expect("vectorize.embedding_service_url must be set to a valid service");
let info_url = svc_url.replace("/embeddings", "/info");
let client = reqwest::Client::new();
let req = client.get(url).query(&[("model_name", model_name)]);
let req = client.get(info_url).query(&[("model_name", model_name)]);
let resp = req.send().await?;
let meta_response = handle_response::<TransformerMetadata>(resp, "info").await?;
Ok(meta_response)
Expand Down

0 comments on commit 493d147

Please sign in to comment.