From 9976ea722036617e95316252387e60df5f4ff0d8 Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Thu, 22 Aug 2024 14:14:17 -0500 Subject: [PATCH 01/11] move ollama interface --- core/src/transformers/ollama.rs | 99 ----------------------- core/src/transformers/providers/mod.rs | 8 +- core/src/transformers/providers/ollama.rs | 46 ++++++++--- core/src/transformers/providers/openai.rs | 7 +- extension/src/chat/ops.rs | 11 +-- extension/src/search.rs | 2 +- 6 files changed, 51 insertions(+), 122 deletions(-) delete mode 100644 core/src/transformers/ollama.rs diff --git a/core/src/transformers/ollama.rs b/core/src/transformers/ollama.rs deleted file mode 100644 index 1a27b22..0000000 --- a/core/src/transformers/ollama.rs +++ /dev/null @@ -1,99 +0,0 @@ -use anyhow::Result; -use ollama_rs::{generation::completion::request::GenerationRequest, Ollama}; -use url::Url; - -use super::types::EmbeddingRequest; - -pub struct OllamaInstance { - pub model_name: String, - pub instance: Ollama, -} - -pub trait LLMFunctions { - fn new(model_name: String, url: String) -> Self; - #[allow(async_fn_in_trait)] - async fn generate_reponse(&self, prompt_text: String) -> Result; - #[allow(async_fn_in_trait)] - async fn generate_embedding(&self, inputs: String) -> Result, String>; -} - -impl LLMFunctions for OllamaInstance { - fn new(model_name: String, url: String) -> Self { - let parsed_url = Url::parse(&url).unwrap_or_else(|_| panic!("invalid url: {}", url)); - let instance = Ollama::new( - format!( - "{}://{}", - parsed_url.scheme(), - parsed_url.host_str().expect("parsed url missing") - ), - parsed_url.port().expect("parsed port missing"), - ); - OllamaInstance { - model_name, - instance, - } - } - async fn generate_reponse(&self, prompt_text: String) -> Result { - let req = GenerationRequest::new(self.model_name.clone(), prompt_text); - let res = self.instance.generate(req).await; - match res { - Ok(res) => Ok(res.response), - Err(e) => Err(e.to_string()), - } - } - async fn generate_embedding(&self, input: String) -> Result, String> { - let embed = self - .instance - .generate_embeddings(self.model_name.clone(), input, None) - .await; - match embed { - Ok(res) => Ok(res.embeddings), - Err(e) => Err(e.to_string()), - } - } -} - -pub fn ollama_embedding_dim(model_name: &str) -> i32 { - match model_name { - "llama2" => 5192, - _ => 1536, - } -} - -pub fn check_model_host(url: &str) -> Result { - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_io() - .enable_time() - .build() - .unwrap_or_else(|e| panic!("failed to initialize tokio runtime: {}", e)); - - runtime.block_on(async { - let response = reqwest::get(url).await.unwrap(); - match response.status() { - reqwest::StatusCode::OK => Ok(format!("Success! {:?}", response)), - _ => Err(format!("Error! 
{:?}", response)), - } - }) -} - -pub fn generate_embeddings(request: EmbeddingRequest) -> Result>> { - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_io() - .enable_time() - .build() - .unwrap_or_else(|e| panic!("failed to initialize tokio runtime: {}", e)); - - runtime.block_on(async { - let instance = OllamaInstance::new(request.payload.model, request.url); - let mut embeddings: Vec> = vec![]; - for input in request.payload.input { - let response = instance.generate_embedding(input).await; - let embedding = match response { - Ok(embed) => embed, - Err(e) => panic!("Unable to generate embeddings.\nError: {:?}", e), - }; - embeddings.push(embedding); - } - Ok(embeddings) - }) -} diff --git a/core/src/transformers/providers/mod.rs b/core/src/transformers/providers/mod.rs index 98cd7c4..3c9cc40 100644 --- a/core/src/transformers/providers/mod.rs +++ b/core/src/transformers/providers/mod.rs @@ -1,6 +1,7 @@ pub mod cohere; pub mod ollama; pub mod openai; +pub mod portkey; pub mod vector_serve; use anyhow::Result; @@ -62,8 +63,13 @@ pub fn get_provider( ModelSource::SentenceTransformers => Ok(Box::new( providers::vector_serve::VectorServeProvider::new(url, api_key), )), - ModelSource::Ollama | ModelSource::Tembo => Err(anyhow::anyhow!( + ModelSource::Ollama => Ok(Box::new(providers::ollama::OllamaProvider::new(url))), + ModelSource::Tembo => Err(anyhow::anyhow!( "Ollama/Tembo transformer not implemented yet" ))?, } } + +fn split_vector(vec: Vec, chunk_size: usize) -> Vec> { + vec.chunks(chunk_size).map(|chunk| chunk.to_vec()).collect() +} diff --git a/core/src/transformers/providers/ollama.rs b/core/src/transformers/providers/ollama.rs index a12d3e1..f1bc632 100644 --- a/core/src/transformers/providers/ollama.rs +++ b/core/src/transformers/providers/ollama.rs @@ -8,20 +8,19 @@ use url::Url; pub const OLLAMA_BASE_URL: &str = "http://localhost:3001"; pub struct OllamaProvider { - pub model_name: String, pub instance: Ollama, } #[derive(Clone, Debug, Serialize, Deserialize)] struct ModelInfo { - model: String, embedding_dimension: u32, max_seq_len: u32, } impl OllamaProvider { - pub fn new(model_name: String, url: String) -> Self { - let parsed_url = Url::parse(&url).unwrap_or_else(|_| panic!("invalid url: {}", url)); + pub fn new(url: Option) -> Self { + let url_in = url.unwrap_or_else(|| OLLAMA_BASE_URL.to_string()); + let parsed_url = Url::parse(&url_in).unwrap_or_else(|_| panic!("invalid url: {}", url_in)); let instance = Ollama::new( format!( "{}://{}", @@ -30,10 +29,7 @@ impl OllamaProvider { ), parsed_url.port().expect("parsed port missing"), ); - OllamaProvider { - model_name, - instance, - } + OllamaProvider { instance } } } @@ -44,10 +40,11 @@ impl EmbeddingProvider for OllamaProvider { request: &'a GenericEmbeddingRequest, ) -> Result { let mut all_embeddings: Vec> = Vec::with_capacity(request.input.len()); + let model_name = request.model.clone(); for ipt in request.input.iter() { let embed = self .instance - .generate_embeddings(self.model_name.clone(), ipt.clone(), None) + .generate_embeddings(model_name.clone(), ipt.clone(), None) .await?; all_embeddings.push(embed.embeddings); } @@ -66,9 +63,36 @@ impl EmbeddingProvider for OllamaProvider { } impl OllamaProvider { - pub async fn generate_response(&self, prompt_text: String) -> Result { - let req = GenerationRequest::new(self.model_name.clone(), prompt_text); + pub async fn generate_response( + &self, + model_name: String, + prompt_text: &str, + ) -> Result { + let req = GenerationRequest::new(model_name, 
prompt_text.to_owned()); let res = self.instance.generate(req).await?; Ok(res.response) } } + +pub fn check_model_host(url: &str) -> Result { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .unwrap_or_else(|e| panic!("failed to initialize tokio runtime: {}", e)); + + runtime.block_on(async { + let response = reqwest::get(url).await.unwrap(); + match response.status() { + reqwest::StatusCode::OK => Ok(format!("Success! {:?}", response)), + _ => Err(format!("Error! {:?}", response)), + } + }) +} + +pub fn ollama_embedding_dim(model_name: &str) -> i32 { + match model_name { + "llama2" => 5192, + _ => 1536, + } +} diff --git a/core/src/transformers/providers/openai.rs b/core/src/transformers/providers/openai.rs index 82a8f63..cd9339d 100644 --- a/core/src/transformers/providers/openai.rs +++ b/core/src/transformers/providers/openai.rs @@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize}; use super::{EmbeddingProvider, GenericEmbeddingRequest, GenericEmbeddingResponse}; use crate::errors::VectorizeError; use crate::transformers::http_handler::handle_response; +use crate::transformers::providers; use crate::transformers::types::Inputs; use async_trait::async_trait; use std::env; @@ -79,7 +80,7 @@ impl EmbeddingProvider for OpenAIProvider { let req = OpenAIEmbeddingBody::from(request.clone()); let num_inputs = request.input.len(); let todo_requests: Vec = if num_inputs > 2048 { - split_vector(req.input, 2048) + providers::split_vector(req.input, 2048) .iter() .map(|chunk| OpenAIEmbeddingBody { input: chunk.clone(), @@ -128,10 +129,6 @@ pub fn openai_embedding_dim(model_name: &str) -> i32 { } } -fn split_vector(vec: Vec, chunk_size: usize) -> Vec> { - vec.chunks(chunk_size).map(|chunk| chunk.to_vec()).collect() -} - #[cfg(test)] mod integration_tests { use super::*; diff --git a/extension/src/chat/ops.rs b/extension/src/chat/ops.rs index 0c3bb57..306bf3f 100644 --- a/extension/src/chat/ops.rs +++ b/extension/src/chat/ops.rs @@ -10,8 +10,7 @@ use handlebars::Handlebars; use openai_api_rs::v1::api::Client; use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; use pgrx::prelude::*; -use vectorize_core::transformers::ollama::LLMFunctions; -use vectorize_core::transformers::ollama::OllamaInstance; +use vectorize_core::transformers::providers::ollama::OllamaProvider; use vectorize_core::types::Model; use vectorize_core::types::ModelSource; @@ -231,11 +230,13 @@ fn call_ollama_chat_completions(prompts: RenderedPrompt, model: &str) -> Result< .build() .unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e)); - let instance = OllamaInstance::new(model.to_string(), url.to_string()); + let ollama_provider = OllamaProvider::new(Some(url)); + + let prompt = prompts.sys_rendered + "\n" + &prompts.user_rendered; let response = runtime.block_on(async { - instance - .generate_reponse(prompts.sys_rendered + "\n" + &prompts.user_rendered) + ollama_provider + .generate_response(model.to_owned(), &prompt) .await }); diff --git a/extension/src/search.rs b/extension/src/search.rs index ebe4508..ab86790 100644 --- a/extension/src/search.rs +++ b/extension/src/search.rs @@ -8,8 +8,8 @@ use crate::util; use anyhow::{Context, Result}; use pgrx::prelude::*; -use vectorize_core::transformers::ollama::check_model_host; use vectorize_core::transformers::providers::get_provider; +use vectorize_core::transformers::providers::ollama::check_model_host; use vectorize_core::types::{self, Model, ModelSource, TableMethod, VectorizeMeta}; 
#[allow(clippy::too_many_arguments)] From ae4dd5b55e5c90afe3dae606b1494847b7e91f68 Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Thu, 22 Aug 2024 14:14:29 -0500 Subject: [PATCH 02/11] add portkey embedding source --- core/src/transformers/mod.rs | 1 - core/src/transformers/providers/portkey.rs | 216 +++++++++++++++++++++ 2 files changed, 216 insertions(+), 1 deletion(-) create mode 100644 core/src/transformers/providers/portkey.rs diff --git a/core/src/transformers/mod.rs b/core/src/transformers/mod.rs index 670f2d2..cb739e0 100644 --- a/core/src/transformers/mod.rs +++ b/core/src/transformers/mod.rs @@ -1,5 +1,4 @@ pub mod generic; pub mod http_handler; -pub mod ollama; pub mod providers; pub mod types; diff --git a/core/src/transformers/providers/portkey.rs b/core/src/transformers/providers/portkey.rs new file mode 100644 index 0000000..691f5a2 --- /dev/null +++ b/core/src/transformers/providers/portkey.rs @@ -0,0 +1,216 @@ +use reqwest::Client; +use serde::{Deserialize, Serialize}; + +use super::{EmbeddingProvider, GenericEmbeddingRequest, GenericEmbeddingResponse}; +use crate::errors::VectorizeError; +use crate::transformers::http_handler::handle_response; +use crate::transformers::providers; +use crate::transformers::providers::openai; +use crate::transformers::types::Inputs; +use async_trait::async_trait; +use std::env; + +pub const PORTKEY_BASE_URL: &str = "https://api.portkey.ai/v1"; +pub const MAX_TOKEN_LEN: usize = 8192; + +pub struct PortkeyProvider { + pub url: String, + pub api_key: String, +} + +impl PortkeyProvider { + pub fn new(url: Option, api_key: Option) -> Self { + let final_url = match url { + Some(url) => url, + None => PORTKEY_BASE_URL.to_string(), + }; + let final_api_key = match api_key { + Some(api_key) => api_key, + None => env::var("PORTKEY_API_KEY").expect("PORTKEY_API_KEY not set"), + }; + PortkeyProvider { + url: final_url, + api_key: final_api_key, + } + } +} + +#[async_trait] +impl EmbeddingProvider for PortkeyProvider { + async fn generate_embedding<'a>( + &self, + request: &'a GenericEmbeddingRequest, + ) -> Result { + let client = Client::new(); + + let req = openai::OpenAIEmbeddingBody::from(request.clone()); + let num_inputs = request.input.len(); + let todo_requests: Vec = if num_inputs > 2048 { + providers::split_vector(req.input, 2048) + .iter() + .map(|chunk| openai::OpenAIEmbeddingBody { + input: chunk.clone(), + model: request.model.clone(), + }) + .collect() + } else { + vec![req] + }; + + let mut all_embeddings: Vec> = Vec::with_capacity(num_inputs); + + for request_payload in todo_requests.iter() { + let payload_val = serde_json::to_value(request_payload)?; + let embeddings_url = format!("{}/embeddings", self.url); + let response = client + .post(&embeddings_url) + .timeout(std::time::Duration::from_secs(120_u64)) + .header("Accept", "application/json") + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", self.api_key)) + .json(&payload_val) + .send() + .await?; + + let embeddings = + handle_response::(response, "embeddings").await?; + all_embeddings.extend(embeddings.data.iter().map(|x| x.embedding.clone())); + } + Ok(GenericEmbeddingResponse { + embeddings: all_embeddings, + }) + } + + async fn model_dim(&self, model_name: &str) -> Result { + // determine embedding dim by generating an embedding and getting length of array + let req = GenericEmbeddingRequest { + input: vec!["hello world".to_string()], + model: model_name.to_string(), + }; + let embedding = self.generate_embedding(&req).await?; + 
let dim = embedding.embeddings[0].len(); + Ok(dim as u32) + } +} + +#[cfg(test)] +mod integration_tests { + use super::*; + use tokio::test as async_test; + + #[async_test] + async fn test_generate_embedding() { + let provider = PortkeyProvider::new(Some(PORTKEY_BASE_URL.to_string()), None); + let request = GenericEmbeddingRequest { + model: "text-embedding-ada-002".to_string(), + input: vec!["hello world".to_string()], + }; + + let embeddings = provider.generate_embedding(&request).await.unwrap(); + assert!( + !embeddings.embeddings.is_empty(), + "Embeddings should not be empty" + ); + assert!( + embeddings.embeddings.len() == 1, + "Embeddings should have length 1" + ); + assert!( + embeddings.embeddings[0].len() == 1536, + "Embeddings should have length 1536" + ); + } +} + +// OpenAI embedding model has a limit of 8192 tokens per input +// there can be a number of ways condense the inputs +pub fn trim_inputs(inputs: &[Inputs]) -> Vec { + inputs + .iter() + .map(|input| { + if input.token_estimate as usize > MAX_TOKEN_LEN { + // not example taking tokens, but naive way to trim input + let tokens: Vec<&str> = input.inputs.split_whitespace().collect(); + tokens + .into_iter() + .take(MAX_TOKEN_LEN) + .collect::>() + .join(" ") + } else { + input.inputs.clone() + } + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_trim_inputs_no_trimming_required() { + let data = vec![ + Inputs { + record_id: "1".to_string(), + inputs: "token1 token2".to_string(), + token_estimate: 2, + }, + Inputs { + record_id: "2".to_string(), + inputs: "token3 token4".to_string(), + token_estimate: 2, + }, + ]; + + let trimmed = trim_inputs(&data); + assert_eq!(trimmed, vec!["token1 token2", "token3 token4"]); + } + + #[test] + fn test_trim_inputs_trimming_required() { + let token_len = 1000000; + let long_input = (0..token_len) + .map(|i| format!("token{}", i)) + .collect::>() + .join(" "); + + let num_tokens = long_input.split_whitespace().count(); + assert_eq!(num_tokens, token_len); + + let data = vec![Inputs { + record_id: "1".to_string(), + inputs: long_input.clone(), + token_estimate: token_len as i32, + }]; + + let trimmed = trim_inputs(&data); + let trimmed_input = trimmed[0].clone(); + let trimmed_length = trimmed_input.split_whitespace().count(); + assert_eq!(trimmed_length, MAX_TOKEN_LEN); + } + + #[test] + fn test_trim_inputs_mixed_cases() { + let num_tokens_in = 1000000; + let long_input = (0..num_tokens_in) + .map(|i| format!("token{}", i)) + .collect::>() + .join(" "); + let data = vec![ + Inputs { + record_id: "1".to_string(), + inputs: "token1 token2".to_string(), + token_estimate: 2, + }, + Inputs { + record_id: "2".to_string(), + inputs: long_input.clone(), + token_estimate: num_tokens_in, + }, + ]; + + let trimmed = trim_inputs(&data); + assert_eq!(trimmed[0].split_whitespace().count(), 2); + assert_eq!(trimmed[1].split_whitespace().count(), MAX_TOKEN_LEN); + } +} From 6f737259e3a8fce0664ce8c1dc48bdbded165d42 Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Thu, 22 Aug 2024 17:04:19 -0500 Subject: [PATCH 03/11] Add portkey provider --- core/src/transformers/providers/mod.rs | 3 + core/src/transformers/providers/openai.rs | 1 - core/src/transformers/providers/portkey.rs | 168 +++++++++------------ core/src/types.rs | 14 ++ 4 files changed, 85 insertions(+), 101 deletions(-) diff --git a/core/src/transformers/providers/mod.rs b/core/src/transformers/providers/mod.rs index 3c9cc40..e4475bf 100644 --- a/core/src/transformers/providers/mod.rs +++ 
b/core/src/transformers/providers/mod.rs @@ -60,6 +60,9 @@ pub fn get_provider( ModelSource::Cohere => Ok(Box::new(providers::cohere::CohereProvider::new( url, api_key, ))), + ModelSource::Portkey => Ok(Box::new(providers::portkey::PortkeyProvider::new( + url, api_key, None, + ))), ModelSource::SentenceTransformers => Ok(Box::new( providers::vector_serve::VectorServeProvider::new(url, api_key), )), diff --git a/core/src/transformers/providers/openai.rs b/core/src/transformers/providers/openai.rs index cd9339d..44e2df4 100644 --- a/core/src/transformers/providers/openai.rs +++ b/core/src/transformers/providers/openai.rs @@ -76,7 +76,6 @@ impl EmbeddingProvider for OpenAIProvider { request: &'a GenericEmbeddingRequest, ) -> Result { let client = Client::new(); - let req = OpenAIEmbeddingBody::from(request.clone()); let num_inputs = request.input.len(); let todo_requests: Vec = if num_inputs > 2048 { diff --git a/core/src/transformers/providers/portkey.rs b/core/src/transformers/providers/portkey.rs index 691f5a2..be5a054 100644 --- a/core/src/transformers/providers/portkey.rs +++ b/core/src/transformers/providers/portkey.rs @@ -1,13 +1,12 @@ use reqwest::Client; -use serde::{Deserialize, Serialize}; use super::{EmbeddingProvider, GenericEmbeddingRequest, GenericEmbeddingResponse}; use crate::errors::VectorizeError; use crate::transformers::http_handler::handle_response; use crate::transformers::providers; use crate::transformers::providers::openai; -use crate::transformers::types::Inputs; use async_trait::async_trait; +use serde::Deserialize; use std::env; pub const PORTKEY_BASE_URL: &str = "https://api.portkey.ai/v1"; @@ -16,10 +15,11 @@ pub const MAX_TOKEN_LEN: usize = 8192; pub struct PortkeyProvider { pub url: String, pub api_key: String, + pub virtual_key: String, } impl PortkeyProvider { - pub fn new(url: Option, api_key: Option) -> Self { + pub fn new(url: Option, api_key: Option, virtual_key: Option) -> Self { let final_url = match url { Some(url) => url, None => PORTKEY_BASE_URL.to_string(), @@ -28,9 +28,14 @@ impl PortkeyProvider { Some(api_key) => api_key, None => env::var("PORTKEY_API_KEY").expect("PORTKEY_API_KEY not set"), }; + let final_virtual_key = match virtual_key { + Some(vkey) => vkey, + None => env::var("PORTKEY_VIRTUAL_KEY").expect("PORTKEY_VIRTUAL_KEY not set"), + }; PortkeyProvider { url: final_url, api_key: final_api_key, + virtual_key: final_virtual_key, } } } @@ -56,18 +61,19 @@ impl EmbeddingProvider for PortkeyProvider { } else { vec![req] }; + let embeddings_url = format!("{}/embeddings", self.url); let mut all_embeddings: Vec> = Vec::with_capacity(num_inputs); - for request_payload in todo_requests.iter() { let payload_val = serde_json::to_value(request_payload)?; - let embeddings_url = format!("{}/embeddings", self.url); let response = client .post(&embeddings_url) .timeout(std::time::Duration::from_secs(120_u64)) .header("Accept", "application/json") .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", self.api_key)) + .header("x-portkey-virtual-key", self.virtual_key.clone()) + .header("x-portkey-api-key", &self.api_key) + // .header("x-portkey-provider", portkey_provider) .json(&payload_val) .send() .await?; @@ -93,14 +99,60 @@ impl EmbeddingProvider for PortkeyProvider { } } +#[derive(Deserialize, Debug)] +struct ChatResponse { + choices: Vec, +} + +#[derive(Deserialize, Debug)] +struct Choice { + message: ChatMessage, +} + +#[derive(Deserialize, Debug)] +struct ChatMessage { + content: String, +} + +impl 
PortkeyProvider { + pub async fn generate_response( + &self, + model_name: String, + prompt_text: &str, + ) -> Result { + let client = Client::new(); + let message = serde_json::json!({ + "model": model_name, + "messages": [{"role": "user", "content": prompt_text}], + }); + let chat_url = format!("{}/chat/completions", self.url); + let response = client + .post(&chat_url) + .timeout(std::time::Duration::from_secs(120_u64)) + .header("Accept", "application/json") + .header("Content-Type", "application/json") + .header("x-portkey-virtual-key", self.virtual_key.clone()) + .header("x-portkey-api-key", &self.api_key) + // .header("x-portkey-provider", portkey_provider) + .json(&message) + .send() + .await?; + let chat_response = handle_response::(response, "embeddings").await?; + Ok(chat_response.choices[0].message.content.clone()) + } +} + #[cfg(test)] -mod integration_tests { +mod portkey_integration_tests { use super::*; use tokio::test as async_test; #[async_test] - async fn test_generate_embedding() { - let provider = PortkeyProvider::new(Some(PORTKEY_BASE_URL.to_string()), None); + async fn test_portkey_openai() { + let portkey_api_key = env::var("PORTKEY_API_KEY").expect("PORTKEY_API_KEY not set"); + let portkey_virtual_key = + env::var("PORTKEY_VIRTUAL_KEY_OPENAI").expect("PORTKEY_VIRTUAL_KEY_OPENAI not set"); + let provider = PortkeyProvider::new(None, Some(portkey_api_key), Some(portkey_virtual_key)); let request = GenericEmbeddingRequest { model: "text-embedding-ada-002".to_string(), input: vec!["hello world".to_string()], @@ -119,98 +171,14 @@ mod integration_tests { embeddings.embeddings[0].len() == 1536, "Embeddings should have length 1536" ); - } -} - -// OpenAI embedding model has a limit of 8192 tokens per input -// there can be a number of ways condense the inputs -pub fn trim_inputs(inputs: &[Inputs]) -> Vec { - inputs - .iter() - .map(|input| { - if input.token_estimate as usize > MAX_TOKEN_LEN { - // not example taking tokens, but naive way to trim input - let tokens: Vec<&str> = input.inputs.split_whitespace().collect(); - tokens - .into_iter() - .take(MAX_TOKEN_LEN) - .collect::>() - .join(" ") - } else { - input.inputs.clone() - } - }) - .collect() -} -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_trim_inputs_no_trimming_required() { - let data = vec![ - Inputs { - record_id: "1".to_string(), - inputs: "token1 token2".to_string(), - token_estimate: 2, - }, - Inputs { - record_id: "2".to_string(), - inputs: "token3 token4".to_string(), - token_estimate: 2, - }, - ]; - - let trimmed = trim_inputs(&data); - assert_eq!(trimmed, vec!["token1 token2", "token3 token4"]); - } - - #[test] - fn test_trim_inputs_trimming_required() { - let token_len = 1000000; - let long_input = (0..token_len) - .map(|i| format!("token{}", i)) - .collect::>() - .join(" "); - - let num_tokens = long_input.split_whitespace().count(); - assert_eq!(num_tokens, token_len); - - let data = vec![Inputs { - record_id: "1".to_string(), - inputs: long_input.clone(), - token_estimate: token_len as i32, - }]; - - let trimmed = trim_inputs(&data); - let trimmed_input = trimmed[0].clone(); - let trimmed_length = trimmed_input.split_whitespace().count(); - assert_eq!(trimmed_length, MAX_TOKEN_LEN); - } + let dim = provider.model_dim("text-embedding-ada-002").await.unwrap(); + assert_eq!(dim, 1536); - #[test] - fn test_trim_inputs_mixed_cases() { - let num_tokens_in = 1000000; - let long_input = (0..num_tokens_in) - .map(|i| format!("token{}", i)) - .collect::>() - .join(" "); - let data = vec![ - 
Inputs { - record_id: "1".to_string(), - inputs: "token1 token2".to_string(), - token_estimate: 2, - }, - Inputs { - record_id: "2".to_string(), - inputs: long_input.clone(), - token_estimate: num_tokens_in, - }, - ]; - - let trimmed = trim_inputs(&data); - assert_eq!(trimmed[0].split_whitespace().count(), 2); - assert_eq!(trimmed[1].split_whitespace().count(), MAX_TOKEN_LEN); + let response = provider + .generate_response("gpt-3.5-turbo".to_string(), "Hello-World") + .await + .unwrap(); + assert!(!response.is_empty(), "Response should not be empty"); } } diff --git a/core/src/types.rs b/core/src/types.rs index 4254f0f..67419d1 100644 --- a/core/src/types.rs +++ b/core/src/types.rs @@ -160,6 +160,7 @@ impl Model { ModelSource::Ollama => self.name.clone(), ModelSource::Tembo => self.name.clone(), ModelSource::Cohere => self.name.clone(), + ModelSource::Portkey => self.name.clone(), } } } @@ -236,6 +237,7 @@ pub enum ModelSource { Ollama, Tembo, Cohere, + Portkey, } impl FromStr for ModelSource { @@ -248,6 +250,7 @@ impl FromStr for ModelSource { "sentence-transformers" => Ok(ModelSource::SentenceTransformers), "tembo" => Ok(ModelSource::Tembo), "cohere" => Ok(ModelSource::Cohere), + "portkey" => Ok(ModelSource::Portkey), _ => Ok(ModelSource::SentenceTransformers), } } @@ -261,6 +264,7 @@ impl Display for ModelSource { ModelSource::SentenceTransformers => write!(f, "sentence-transformers"), ModelSource::Tembo => write!(f, "tembo"), ModelSource::Cohere => write!(f, "cohere"), + ModelSource::Portkey => write!(f, "portkey"), } } } @@ -273,6 +277,7 @@ impl From for ModelSource { "sentence-transformers" => ModelSource::SentenceTransformers, "tembo" => ModelSource::Tembo, "cohere" => ModelSource::Cohere, + "portkey" => ModelSource::Portkey, // other cases are assumed to be private sentence-transformer compatible model // and can be hot-loaded _ => ModelSource::SentenceTransformers, @@ -285,6 +290,15 @@ impl From for ModelSource { mod model_tests { use super::*; + #[test] + fn test_portkey_parsing() { + let model = Model::new("portkey/openai/text-embedding-ada-002").unwrap(); + assert_eq!(model.source, ModelSource::Portkey); + assert_eq!(model.fullname, "portkey/openai/text-embedding-ada-002"); + assert_eq!(model.name, "text-embedding-ada-002"); + assert_eq!(model.api_name(), "text-embedding-ada-002"); + } + #[test] fn test_tembo_parsing() { let model = Model::new("tembo/meta-llama/Meta-Llama-3-8B-Instruct").unwrap(); From cf05bf1c594cac7445fc242da5e9c800d607d6e8 Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Thu, 22 Aug 2024 17:05:18 -0500 Subject: [PATCH 04/11] sql api --- extension/src/api.rs | 7 +++++- extension/src/chat/ops.rs | 51 +++++++++++++++------------------------ extension/src/guc.rs | 40 +++++++++++++++++++++++++++++- 3 files changed, 65 insertions(+), 33 deletions(-) diff --git a/extension/src/api.rs b/extension/src/api.rs index 982ba12..0718fff 100644 --- a/extension/src/api.rs +++ b/extension/src/api.rs @@ -1,5 +1,6 @@ use crate::chat::ops::{call_chat, get_chat_response}; use crate::chat::types::RenderedPrompt; +use crate::guc::get_guc_configs; use crate::search::{self, init_table}; use crate::transformers::generic::env_interpolate_string; use crate::transformers::transform; @@ -162,7 +163,11 @@ fn generate( sys_rendered: "".to_string(), user_rendered: input.to_string(), }; - get_chat_response(prompt, &model, api_key) + let mut guc_configs = get_guc_configs(&model.source); + if let Some(api_key) = api_key { + guc_configs.api_key = Some(api_key); + } + get_chat_response(prompt, 
&model, &guc_configs) } #[pg_extern] diff --git a/extension/src/chat/ops.rs b/extension/src/chat/ops.rs index 306bf3f..1ef5ad0 100644 --- a/extension/src/chat/ops.rs +++ b/extension/src/chat/ops.rs @@ -5,6 +5,7 @@ use crate::search; use crate::transformers::generic::get_env_interpolated_guc; use crate::util::get_vectorize_meta_spi; +use anyhow::Context; use anyhow::{anyhow, Result}; use handlebars::Handlebars; use openai_api_rs::v1::api::Client; @@ -49,6 +50,9 @@ pub fn call_chat( ModelSource::SentenceTransformers | ModelSource::Cohere => { error!("SentenceTransformers and Cohere not yet supported for chat completions") } + ModelSource::Portkey => { + get_bpe_from_model(&chat_model.name).expect("failed to get BPE from model") + } }; // can only be 1 column in a chat job, for now, so safe to grab first element @@ -122,7 +126,8 @@ pub fn call_chat( )?; // http request to chat completions - let chat_response = get_chat_response(rendered_prompt, chat_model, api_key)?; + let guc_configs = guc::get_guc_configs(&chat_model.source); + let chat_response = get_chat_response(rendered_prompt, chat_model, &guc_configs)?; Ok(ChatResponse { context: search_results, @@ -133,10 +138,12 @@ pub fn call_chat( pub fn get_chat_response( prompt: RenderedPrompt, model: &Model, - api_key: Option, + guc_configs: &guc::ModelGucConfig, ) -> Result { match model.source { - ModelSource::OpenAI | ModelSource::Tembo => call_chat_completions(prompt, model, api_key), + ModelSource::OpenAI | ModelSource::Tembo | ModelSource::Portkey => { + call_chat_completions(prompt, model, &guc_configs) + } ModelSource::Ollama => call_ollama_chat_completions(prompt, &model.name), ModelSource::SentenceTransformers | ModelSource::Cohere => { error!("SentenceTransformers and Cohere not yet supported for chat completions"); @@ -157,38 +164,20 @@ fn render_user_message(user_prompt_template: &str, context: &str, query: &str) - fn call_chat_completions( prompts: RenderedPrompt, model: &Model, - api_key: Option, + guc_configs: &guc::ModelGucConfig, ) -> Result { - let api_key = match api_key { - Some(k) => k.to_string(), - None => { - let this_guc = match model.source { - ModelSource::Tembo => guc::VectorizeGuc::TemboAIKey, - ModelSource::OpenAI => guc::VectorizeGuc::OpenAIKey, - _ => { - error!("API key not found for model source"); - } - }; - match guc::get_guc(this_guc) { - Some(k) => k, - None => { - error!("failed to get API key from GUC"); - } - } - } - }; + let base_url = guc_configs + .service_url + .clone() + .context("service url request for chat completions")?; - let base_url = match model.source { - ModelSource::Tembo => get_env_interpolated_guc(guc::VectorizeGuc::TemboServiceUrl) - .expect("vectorize.tembo_service_url must be set"), - ModelSource::OpenAI => get_env_interpolated_guc(guc::VectorizeGuc::OpenAIServiceUrl) - .expect("vectorize.openai_service_url must be set"), - _ => { - error!("API key not found for model source"); - } - }; // set the url for openai client env::set_var("OPENAI_API_BASE", base_url); + let api_key = if let Some(k) = guc_configs.api_key.clone() { + k + } else { + "None".to_owned() + }; let client = Client::new(api_key); let sys_msg = chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::system, diff --git a/extension/src/guc.rs b/extension/src/guc.rs index bbd2d73..30dc9ba 100644 --- a/extension/src/guc.rs +++ b/extension/src/guc.rs @@ -4,6 +4,8 @@ use pgrx::*; use anyhow::Result; use vectorize_core::types::ModelSource; +use crate::transformers::generic::env_interpolate_string; + pub 
static VECTORIZE_HOST: GucSetting> = GucSetting::>::new(None); pub static VECTORIZE_DATABASE_NAME: GucSetting> = GucSetting::>::new(None); @@ -23,6 +25,8 @@ pub static OLLAMA_SERVICE_HOST: GucSetting> = GucSetting::> = GucSetting::>::new(None); pub static TEMBO_API_KEY: GucSetting> = GucSetting::>::new(None); pub static COHERE_API_KEY: GucSetting> = GucSetting::>::new(None); +pub static PORTKEY_API_KEY: GucSetting> = GucSetting::>::new(None); +pub static PORTKEY_VIRTUAL_KEY: GucSetting> = GucSetting::>::new(None); // initialize GUCs pub fn init_guc() { @@ -143,6 +147,24 @@ pub fn init_guc() { GucContext::Suset, GucFlags::default(), ); + + GucRegistry::define_string_guc( + "vectorize.portkey_api_key", + "API Key for the Portkey platform", + "API Key for the Portkey platform", + &PORTKEY_API_KEY, + GucContext::Suset, + GucFlags::default(), + ); + + GucRegistry::define_string_guc( + "vectorize.portkey_virtual_key", + "Virtual Key for the Portkey platform", + "Virtual Key for the Portkey platform", + &PORTKEY_API_KEY, + GucContext::Suset, + GucFlags::default(), + ); } // for handling of GUCs that can be error prone @@ -158,6 +180,8 @@ pub enum VectorizeGuc { OllamaServiceUrl, TemboServiceUrl, CohereApiKey, + PortkeyApiKey, + PortkeyVirtualKey, } /// a convenience function to get this project's GUCs @@ -173,10 +197,13 @@ pub fn get_guc(guc: VectorizeGuc) -> Option { VectorizeGuc::OpenAIServiceUrl => OPENAI_BASE_URL.get(), VectorizeGuc::EmbeddingServiceApiKey => EMBEDDING_SERVICE_API_KEY.get(), VectorizeGuc::CohereApiKey => COHERE_API_KEY.get(), + VectorizeGuc::PortkeyApiKey => PORTKEY_API_KEY.get(), + VectorizeGuc::PortkeyVirtualKey => PORTKEY_VIRTUAL_KEY.get(), }; if let Some(cstr) = val { if let Ok(s) = handle_cstr(cstr) { - Some(s) + let interpolated = env_interpolate_string(&s).unwrap(); + Some(interpolated) } else { error!("failed to convert CStr to str"); } @@ -199,6 +226,7 @@ fn handle_cstr(cstr: &CStr) -> Result { pub struct ModelGucConfig { pub api_key: Option, pub service_url: Option, + pub virtual_key: Option, } pub fn get_guc_configs(model_source: &ModelSource) -> ModelGucConfig { @@ -206,22 +234,32 @@ pub fn get_guc_configs(model_source: &ModelSource) -> ModelGucConfig { ModelSource::OpenAI => ModelGucConfig { api_key: get_guc(VectorizeGuc::OpenAIKey), service_url: get_guc(VectorizeGuc::OpenAIServiceUrl), + virtual_key: None, }, ModelSource::Tembo => ModelGucConfig { api_key: get_guc(VectorizeGuc::TemboAIKey), service_url: get_guc(VectorizeGuc::TemboServiceUrl), + virtual_key: None, }, ModelSource::SentenceTransformers => ModelGucConfig { api_key: get_guc(VectorizeGuc::EmbeddingServiceApiKey), service_url: get_guc(VectorizeGuc::EmbeddingServiceUrl), + virtual_key: None, }, ModelSource::Cohere => ModelGucConfig { api_key: get_guc(VectorizeGuc::CohereApiKey), service_url: None, + virtual_key: None, }, ModelSource::Ollama => ModelGucConfig { api_key: None, service_url: get_guc(VectorizeGuc::OllamaServiceUrl), + virtual_key: None, + }, + ModelSource::Portkey => ModelGucConfig { + api_key: get_guc(VectorizeGuc::PortkeyApiKey), + service_url: None, + virtual_key: get_guc(VectorizeGuc::PortkeyVirtualKey), }, } } From 3b2e2762d20caace037f476b658b7b6ff3429bac Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Thu, 22 Aug 2024 22:11:19 -0500 Subject: [PATCH 05/11] add portkey --- core/src/transformers/providers/mod.rs | 26 +++++++- core/src/transformers/providers/ollama.rs | 11 +++- core/src/transformers/providers/openai.rs | 73 +++++++++++++++------- core/src/transformers/providers/portkey.rs 
| 32 ++++------ core/src/worker/base.rs | 1 + 5 files changed, 96 insertions(+), 47 deletions(-) diff --git a/core/src/transformers/providers/mod.rs b/core/src/transformers/providers/mod.rs index e4475bf..b743c2a 100644 --- a/core/src/transformers/providers/mod.rs +++ b/core/src/transformers/providers/mod.rs @@ -52,6 +52,7 @@ pub fn get_provider( model_source: &ModelSource, api_key: Option, url: Option, + virtual_key: Option, ) -> Result, VectorizeError> { match model_source { ModelSource::OpenAI => Ok(Box::new(providers::openai::OpenAIProvider::new( @@ -61,7 +62,9 @@ pub fn get_provider( url, api_key, ))), ModelSource::Portkey => Ok(Box::new(providers::portkey::PortkeyProvider::new( - url, api_key, None, + url, + api_key, + virtual_key, ))), ModelSource::SentenceTransformers => Ok(Box::new( providers::vector_serve::VectorServeProvider::new(url, api_key), @@ -76,3 +79,24 @@ pub fn get_provider( fn split_vector(vec: Vec, chunk_size: usize) -> Vec> { vec.chunks(chunk_size).map(|chunk| chunk.to_vec()).collect() } + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ChatMessageRequest { + pub role: String, + pub content: String, +} + +#[derive(Deserialize, Debug)] +struct ChatResponse { + choices: Vec, +} + +#[derive(Deserialize, Debug)] +struct Choice { + message: ResponseMessage, +} + +#[derive(Deserialize, Debug)] +struct ResponseMessage { + content: String, +} diff --git a/core/src/transformers/providers/ollama.rs b/core/src/transformers/providers/ollama.rs index f1bc632..9180580 100644 --- a/core/src/transformers/providers/ollama.rs +++ b/core/src/transformers/providers/ollama.rs @@ -1,4 +1,4 @@ -use super::{EmbeddingProvider, GenericEmbeddingRequest, GenericEmbeddingResponse}; +use super::{ChatMessageRequest, EmbeddingProvider, GenericEmbeddingRequest, GenericEmbeddingResponse}; use crate::errors::VectorizeError; use async_trait::async_trait; use ollama_rs::{generation::completion::request::GenerationRequest, Ollama}; @@ -66,9 +66,14 @@ impl OllamaProvider { pub async fn generate_response( &self, model_name: String, - prompt_text: &str, + prompt_text: &[ChatMessageRequest], ) -> Result { - let req = GenerationRequest::new(model_name, prompt_text.to_owned()); + let single_prompt: String = prompt_text + .iter() + .map(|x| x.content.clone()) + .collect::>() + .join("\n\n"); + let req = GenerationRequest::new(model_name, single_prompt.to_owned()); let res = self.instance.generate(req).await?; Ok(res.response) } diff --git a/core/src/transformers/providers/openai.rs b/core/src/transformers/providers/openai.rs index 44e2df4..9e1c1b9 100644 --- a/core/src/transformers/providers/openai.rs +++ b/core/src/transformers/providers/openai.rs @@ -1,7 +1,10 @@ use reqwest::Client; use serde::{Deserialize, Serialize}; -use super::{EmbeddingProvider, GenericEmbeddingRequest, GenericEmbeddingResponse}; +use super::{ + ChatMessageRequest, ChatResponse, EmbeddingProvider, GenericEmbeddingRequest, + GenericEmbeddingResponse, +}; use crate::errors::VectorizeError; use crate::transformers::http_handler::handle_response; use crate::transformers::providers; @@ -128,6 +131,53 @@ pub fn openai_embedding_dim(model_name: &str) -> i32 { } } +impl OpenAIProvider { + pub async fn generate_response( + &self, + model_name: String, + messages: &[ChatMessageRequest], + ) -> Result { + let client = Client::new(); + let chat_url = format!("{}/chat/completions", self.url); + let message = serde_json::json!({ + "model": model_name, + "messages": messages, + }); + let response = client + .post(&chat_url) + 
.timeout(std::time::Duration::from_secs(120_u64)) + .header("Accept", "application/json") + .header("Content-Type", "application/json") + .header("Authorization", &format!("Bearer {}", self.api_key)) + .json(&message) + .send() + .await?; + let chat_response = handle_response::(response, "embeddings").await?; + Ok(chat_response.choices[0].message.content.clone()) + } +} + +// OpenAI embedding model has a limit of 8192 tokens per input +// there can be a number of ways condense the inputs +pub fn trim_inputs(inputs: &[Inputs]) -> Vec { + inputs + .iter() + .map(|input| { + if input.token_estimate as usize > MAX_TOKEN_LEN { + // not example taking tokens, but naive way to trim input + let tokens: Vec<&str> = input.inputs.split_whitespace().collect(); + tokens + .into_iter() + .take(MAX_TOKEN_LEN) + .collect::>() + .join(" ") + } else { + input.inputs.clone() + } + }) + .collect() +} + #[cfg(test)] mod integration_tests { use super::*; @@ -157,27 +207,6 @@ mod integration_tests { } } -// OpenAI embedding model has a limit of 8192 tokens per input -// there can be a number of ways condense the inputs -pub fn trim_inputs(inputs: &[Inputs]) -> Vec { - inputs - .iter() - .map(|input| { - if input.token_estimate as usize > MAX_TOKEN_LEN { - // not example taking tokens, but naive way to trim input - let tokens: Vec<&str> = input.inputs.split_whitespace().collect(); - tokens - .into_iter() - .take(MAX_TOKEN_LEN) - .collect::>() - .join(" ") - } else { - input.inputs.clone() - } - }) - .collect() -} - #[cfg(test)] mod tests { use super::*; diff --git a/core/src/transformers/providers/portkey.rs b/core/src/transformers/providers/portkey.rs index be5a054..84c5e51 100644 --- a/core/src/transformers/providers/portkey.rs +++ b/core/src/transformers/providers/portkey.rs @@ -1,12 +1,14 @@ use reqwest::Client; -use super::{EmbeddingProvider, GenericEmbeddingRequest, GenericEmbeddingResponse}; +use super::{ + ChatMessageRequest, ChatResponse, EmbeddingProvider, GenericEmbeddingRequest, + GenericEmbeddingResponse, +}; use crate::errors::VectorizeError; use crate::transformers::http_handler::handle_response; use crate::transformers::providers; use crate::transformers::providers::openai; use async_trait::async_trait; -use serde::Deserialize; use std::env; pub const PORTKEY_BASE_URL: &str = "https://api.portkey.ai/v1"; @@ -73,7 +75,6 @@ impl EmbeddingProvider for PortkeyProvider { .header("Content-Type", "application/json") .header("x-portkey-virtual-key", self.virtual_key.clone()) .header("x-portkey-api-key", &self.api_key) - // .header("x-portkey-provider", portkey_provider) .json(&payload_val) .send() .await?; @@ -99,31 +100,16 @@ impl EmbeddingProvider for PortkeyProvider { } } -#[derive(Deserialize, Debug)] -struct ChatResponse { - choices: Vec, -} - -#[derive(Deserialize, Debug)] -struct Choice { - message: ChatMessage, -} - -#[derive(Deserialize, Debug)] -struct ChatMessage { - content: String, -} - impl PortkeyProvider { pub async fn generate_response( &self, model_name: String, - prompt_text: &str, + messages: &[ChatMessageRequest], ) -> Result { let client = Client::new(); let message = serde_json::json!({ "model": model_name, - "messages": [{"role": "user", "content": prompt_text}], + "messages": messages, }); let chat_url = format!("{}/chat/completions", self.url); let response = client @@ -175,8 +161,12 @@ mod portkey_integration_tests { let dim = provider.model_dim("text-embedding-ada-002").await.unwrap(); assert_eq!(dim, 1536); + let chatmessage = ChatMessageRequest { + role: "user".to_string(), + 
content: "hello world".to_string(), + }; let response = provider - .generate_response("gpt-3.5-turbo".to_string(), "Hello-World") + .generate_response("gpt-3.5-turbo".to_string(), &[chatmessage]) .await .unwrap(); assert!(!response.is_empty(), "Response should not be empty"); diff --git a/core/src/worker/base.rs b/core/src/worker/base.rs index 7e2ba68..2908e1a 100644 --- a/core/src/worker/base.rs +++ b/core/src/worker/base.rs @@ -98,6 +98,7 @@ async fn execute_job( &job_meta.transformer.source, job_params.api_key.clone(), None, + None, )?; let embedding_request = From e63f4f977fedc835105a265bc02c20d2c3769a4b Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Thu, 22 Aug 2024 22:12:53 -0500 Subject: [PATCH 06/11] extension impl --- extension/Cargo.toml | 1 - extension/src/api.rs | 4 +- extension/src/chat/ops.rs | 133 +++++++++++------------------- extension/src/guc.rs | 16 +++- extension/src/search.rs | 1 + extension/src/transformers/mod.rs | 9 +- extension/src/workers/mod.rs | 1 + 7 files changed, 72 insertions(+), 93 deletions(-) diff --git a/extension/Cargo.toml b/extension/Cargo.toml index 2dc5cda..2b3be00 100644 --- a/extension/Cargo.toml +++ b/extension/Cargo.toml @@ -20,7 +20,6 @@ chrono = {version = "0.4.26", features = ["serde"] } handlebars = "5.1.0" lazy_static = "1.4.0" log = "0.4.21" -openai-api-rs = "4.0.6" pgmq = "0.26.1" pgrx = "0.11.4" postgres-types = "0.2.5" diff --git a/extension/src/api.rs b/extension/src/api.rs index 0718fff..d7575a8 100644 --- a/extension/src/api.rs +++ b/extension/src/api.rs @@ -1,4 +1,4 @@ -use crate::chat::ops::{call_chat, get_chat_response}; +use crate::chat::ops::{call_chat, call_chat_completions}; use crate::chat::types::RenderedPrompt; use crate::guc::get_guc_configs; use crate::search::{self, init_table}; @@ -167,7 +167,7 @@ fn generate( if let Some(api_key) = api_key { guc_configs.api_key = Some(api_key); } - get_chat_response(prompt, &model, &guc_configs) + call_chat_completions(prompt, &model, &guc_configs) } #[pg_extern] diff --git a/extension/src/chat/ops.rs b/extension/src/chat/ops.rs index 1ef5ad0..cb52ad2 100644 --- a/extension/src/chat/ops.rs +++ b/extension/src/chat/ops.rs @@ -1,17 +1,14 @@ -use std::env; - use crate::guc; use crate::search; -use crate::transformers::generic::get_env_interpolated_guc; use crate::util::get_vectorize_meta_spi; -use anyhow::Context; use anyhow::{anyhow, Result}; use handlebars::Handlebars; -use openai_api_rs::v1::api::Client; -use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; use pgrx::prelude::*; use vectorize_core::transformers::providers::ollama::OllamaProvider; +use vectorize_core::transformers::providers::openai::OpenAIProvider; +use vectorize_core::transformers::providers::portkey::PortkeyProvider; +use vectorize_core::transformers::providers::ChatMessageRequest; use vectorize_core::types::Model; use vectorize_core::types::ModelSource; @@ -127,7 +124,7 @@ pub fn call_chat( // http request to chat completions let guc_configs = guc::get_guc_configs(&chat_model.source); - let chat_response = get_chat_response(rendered_prompt, chat_model, &guc_configs)?; + let chat_response = call_chat_completions(rendered_prompt, chat_model, &guc_configs)?; Ok(ChatResponse { context: search_results, @@ -135,22 +132,6 @@ pub fn call_chat( }) } -pub fn get_chat_response( - prompt: RenderedPrompt, - model: &Model, - guc_configs: &guc::ModelGucConfig, -) -> Result { - match model.source { - ModelSource::OpenAI | ModelSource::Tembo | ModelSource::Portkey => { - call_chat_completions(prompt, model, 
&guc_configs) - } - ModelSource::Ollama => call_ollama_chat_completions(prompt, &model.name), - ModelSource::SentenceTransformers | ModelSource::Cohere => { - error!("SentenceTransformers and Cohere not yet supported for chat completions"); - } - } -} - fn render_user_message(user_prompt_template: &str, context: &str, query: &str) -> Result { let handlebars = Handlebars::new(); let render_vals = serde_json::json!({ @@ -161,80 +142,60 @@ fn render_user_message(user_prompt_template: &str, context: &str, query: &str) - Ok(user_rendered) } -fn call_chat_completions( +pub fn call_chat_completions( prompts: RenderedPrompt, model: &Model, guc_configs: &guc::ModelGucConfig, ) -> Result { - let base_url = guc_configs - .service_url - .clone() - .context("service url request for chat completions")?; - - // set the url for openai client - env::set_var("OPENAI_API_BASE", base_url); - let api_key = if let Some(k) = guc_configs.api_key.clone() { - k - } else { - "None".to_owned() - }; - let client = Client::new(api_key); - let sys_msg = chat_completion::ChatCompletionMessage { - role: chat_completion::MessageRole::system, - content: chat_completion::Content::Text(prompts.sys_rendered), - name: None, - }; - let usr_msg = chat_completion::ChatCompletionMessage { - role: chat_completion::MessageRole::user, - content: chat_completion::Content::Text(prompts.user_rendered), - name: None, - }; - - let req = ChatCompletionRequest::new(model.name.clone(), vec![sys_msg, usr_msg]); - let result = client.chat_completion(req)?; - // currently we only support single query, and not a conversation - // so we can safely select the first response for now - let responses = &result.choices[0]; - let chat_response: String = responses - .message - .content - .clone() - .expect("no response from chat model"); - Ok(chat_response) -} - -fn call_ollama_chat_completions(prompts: RenderedPrompt, model: &str) -> Result { - // get url from guc - - let url = match guc::get_guc(guc::VectorizeGuc::OllamaServiceUrl) { - Some(k) => k, - None => { - error!("failed to get Ollama url from GUC"); - } - }; - + let messages = vec![ + ChatMessageRequest { + role: "system".to_owned(), + content: prompts.sys_rendered.clone(), + }, + ChatMessageRequest { + role: "user".to_owned(), + content: prompts.user_rendered.clone(), + }, + ]; let runtime = tokio::runtime::Builder::new_current_thread() .enable_io() .enable_time() .build() .unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e)); - let ollama_provider = OllamaProvider::new(Some(url)); - - let prompt = prompts.sys_rendered + "\n" + &prompts.user_rendered; - - let response = runtime.block_on(async { - ollama_provider - .generate_response(model.to_owned(), &prompt) - .await - }); - - match response { - Ok(k) => Ok(k), - Err(k) => { - error!("Unable to generate response. 
Error: {k}"); + let chat_response: String = runtime.block_on(async { + match model.source { + ModelSource::OpenAI | ModelSource::Tembo => { + let provider = OpenAIProvider::new( + guc_configs.service_url.clone(), + guc_configs.api_key.clone(), + ); + provider + .generate_response(model.api_name(), &messages) + .await + } + ModelSource::Portkey => { + let provider = PortkeyProvider::new( + guc_configs.service_url.clone(), + guc_configs.api_key.clone(), + guc_configs.virtual_key.clone(), + ); + provider + .generate_response(model.api_name(), &messages) + .await + } + ModelSource::Ollama => { + let provider = OllamaProvider::new(guc_configs.service_url.clone()); + provider + .generate_response(model.api_name(), &messages) + .await + } + ModelSource::SentenceTransformers | ModelSource::Cohere => { + error!("SentenceTransformers and Cohere not yet supported for chat completions") + } } - } + })?; + Ok(chat_response) } // Trims the context to fit within the token limit when force_trim = True diff --git a/extension/src/guc.rs b/extension/src/guc.rs index 30dc9ba..fea0470 100644 --- a/extension/src/guc.rs +++ b/extension/src/guc.rs @@ -27,6 +27,7 @@ pub static TEMBO_API_KEY: GucSetting> = GucSetting:: pub static COHERE_API_KEY: GucSetting> = GucSetting::>::new(None); pub static PORTKEY_API_KEY: GucSetting> = GucSetting::>::new(None); pub static PORTKEY_VIRTUAL_KEY: GucSetting> = GucSetting::>::new(None); +pub static PORTKEY_SERVICE_URL: GucSetting> = GucSetting::>::new(None); // initialize GUCs pub fn init_guc() { @@ -148,6 +149,15 @@ pub fn init_guc() { GucFlags::default(), ); + GucRegistry::define_string_guc( + "vectorize.portkey_service_url", + "Base url for the Portkey platform", + "Base url for the Portkey platform", + &PORTKEY_SERVICE_URL, + GucContext::Suset, + GucFlags::default(), + ); + GucRegistry::define_string_guc( "vectorize.portkey_api_key", "API Key for the Portkey platform", @@ -161,7 +171,7 @@ pub fn init_guc() { "vectorize.portkey_virtual_key", "Virtual Key for the Portkey platform", "Virtual Key for the Portkey platform", - &PORTKEY_API_KEY, + &PORTKEY_VIRTUAL_KEY, GucContext::Suset, GucFlags::default(), ); @@ -182,6 +192,7 @@ pub enum VectorizeGuc { CohereApiKey, PortkeyApiKey, PortkeyVirtualKey, + PortkeyServiceUrl, } /// a convenience function to get this project's GUCs @@ -199,6 +210,7 @@ pub fn get_guc(guc: VectorizeGuc) -> Option { VectorizeGuc::CohereApiKey => COHERE_API_KEY.get(), VectorizeGuc::PortkeyApiKey => PORTKEY_API_KEY.get(), VectorizeGuc::PortkeyVirtualKey => PORTKEY_VIRTUAL_KEY.get(), + VectorizeGuc::PortkeyServiceUrl => PORTKEY_SERVICE_URL.get(), }; if let Some(cstr) = val { if let Ok(s) = handle_cstr(cstr) { @@ -258,7 +270,7 @@ pub fn get_guc_configs(model_source: &ModelSource) -> ModelGucConfig { }, ModelSource::Portkey => ModelGucConfig { api_key: get_guc(VectorizeGuc::PortkeyApiKey), - service_url: None, + service_url: get_guc(VectorizeGuc::PortkeyServiceUrl), virtual_key: get_guc(VectorizeGuc::PortkeyVirtualKey), }, } diff --git a/extension/src/search.rs b/extension/src/search.rs index ab86790..c901f13 100644 --- a/extension/src/search.rs +++ b/extension/src/search.rs @@ -43,6 +43,7 @@ pub fn init_table( &transformer.source, guc_configs.api_key.clone(), guc_configs.service_url, + None, )?; //synchronous diff --git a/extension/src/transformers/mod.rs b/extension/src/transformers/mod.rs index e1fe033..a589333 100644 --- a/extension/src/transformers/mod.rs +++ b/extension/src/transformers/mod.rs @@ -23,8 +23,13 @@ pub fn transform(input: &str, transformer: 
&Model, api_key: Option) -> V guc_configs.api_key }; - let provider = providers::get_provider(&transformer.source, api_key, guc_configs.service_url) - .expect("failed to get provider"); + let provider = providers::get_provider( + &transformer.source, + api_key, + guc_configs.service_url, + guc_configs.virtual_key, + ) + .expect("failed to get provider"); let input = Inputs { record_id: "".to_string(), inputs: input.to_string(), diff --git a/extension/src/workers/mod.rs b/extension/src/workers/mod.rs index d846e64..16df1c3 100644 --- a/extension/src/workers/mod.rs +++ b/extension/src/workers/mod.rs @@ -82,6 +82,7 @@ async fn execute_job(dbclient: Pool, msg: Message) &job_meta.transformer.source, job_params.api_key.clone(), guc_configs.service_url, + guc_configs.virtual_key, )?; let embedding_response = provider.generate_embedding(&embedding_request).await?; From 1f15c1fed85bb2e982caccae0a1e43e857d230a4 Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Thu, 22 Aug 2024 22:15:58 -0500 Subject: [PATCH 07/11] remove unused --- core/src/transformers/providers/portkey.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/transformers/providers/portkey.rs b/core/src/transformers/providers/portkey.rs index 84c5e51..ea3b2cf 100644 --- a/core/src/transformers/providers/portkey.rs +++ b/core/src/transformers/providers/portkey.rs @@ -119,7 +119,6 @@ impl PortkeyProvider { .header("Content-Type", "application/json") .header("x-portkey-virtual-key", self.virtual_key.clone()) .header("x-portkey-api-key", &self.api_key) - // .header("x-portkey-provider", portkey_provider) .json(&message) .send() .await?; From 5ea47c3442b4a1a4ef9e3e213d3b47cd64e9fd43 Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Thu, 22 Aug 2024 22:25:23 -0500 Subject: [PATCH 08/11] fmt --- core/src/transformers/providers/ollama.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/transformers/providers/ollama.rs b/core/src/transformers/providers/ollama.rs index 9180580..349508b 100644 --- a/core/src/transformers/providers/ollama.rs +++ b/core/src/transformers/providers/ollama.rs @@ -1,4 +1,6 @@ -use super::{ChatMessageRequest, EmbeddingProvider, GenericEmbeddingRequest, GenericEmbeddingResponse}; +use super::{ + ChatMessageRequest, EmbeddingProvider, GenericEmbeddingRequest, GenericEmbeddingResponse, +}; use crate::errors::VectorizeError; use async_trait::async_trait; use ollama_rs::{generation::completion::request::GenerationRequest, Ollama}; From 5e4808c6212d1f0b70e4d972cddc402558a06568 Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Fri, 23 Aug 2024 07:42:35 -0500 Subject: [PATCH 09/11] env --- .github/workflows/extension_ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/extension_ci.yml b/.github/workflows/extension_ci.yml index 5650114..3ad4ccd 100644 --- a/.github/workflows/extension_ci.yml +++ b/.github/workflows/extension_ci.yml @@ -135,6 +135,8 @@ jobs: env: HF_API_KEY: ${{ secrets.HF_API_KEY }} CO_API_KEY: ${{ secrets.CO_API_KEY }} + PORTKEY_API_KEY: ${{ secrets.PORTKEY_API_KEY }} + PORTKEY_VIRTUAL_KEY: ${{ secrets.PORTKEY_VIRTUAL_KEY }} run: | echo "\q" | make run make test-integration From 09793b17294abc4fbcc5703c64fb423a71648f28 Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Fri, 23 Aug 2024 08:15:47 -0500 Subject: [PATCH 10/11] env for core --- .github/workflows/extension_ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/extension_ci.yml b/.github/workflows/extension_ci.yml index 3ad4ccd..7123e37 100644 --- 
a/.github/workflows/extension_ci.yml +++ b/.github/workflows/extension_ci.yml @@ -111,6 +111,8 @@ jobs: env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} CO_API_KEY: ${{ secrets.CO_API_KEY }} + PORTKEY_API_KEY: ${{ secrets.PORTKEY_API_KEY }} + PORTKEY_VIRTUAL_KEY: ${{ secrets.PORTKEY_VIRTUAL_KEY }} run: | cd ../core && cargo test - name: Restore cached binaries From 74a6294531384e4615fa11adaf409b1a568f4168 Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Fri, 23 Aug 2024 08:34:25 -0500 Subject: [PATCH 11/11] fix env var --- .github/workflows/extension_ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/extension_ci.yml b/.github/workflows/extension_ci.yml index 7123e37..053ba13 100644 --- a/.github/workflows/extension_ci.yml +++ b/.github/workflows/extension_ci.yml @@ -112,7 +112,7 @@ jobs: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} CO_API_KEY: ${{ secrets.CO_API_KEY }} PORTKEY_API_KEY: ${{ secrets.PORTKEY_API_KEY }} - PORTKEY_VIRTUAL_KEY: ${{ secrets.PORTKEY_VIRTUAL_KEY }} + PORTKEY_VIRTUAL_KEY_OPENAI: ${{ secrets.PORTKEY_VIRTUAL_KEY_OPENAI }} run: | cd ../core && cargo test - name: Restore cached binaries @@ -138,7 +138,7 @@ jobs: HF_API_KEY: ${{ secrets.HF_API_KEY }} CO_API_KEY: ${{ secrets.CO_API_KEY }} PORTKEY_API_KEY: ${{ secrets.PORTKEY_API_KEY }} - PORTKEY_VIRTUAL_KEY: ${{ secrets.PORTKEY_VIRTUAL_KEY }} + PORTKEY_VIRTUAL_KEY_OPENAI: ${{ secrets.PORTKEY_VIRTUAL_KEY_OPENAI }} run: | echo "\q" | make run make test-integration
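For reference, the PortkeyProvider introduced in this series can be driven directly from the core crate, much like the portkey_integration_tests hunk above; the sketch below is not part of the patch. It assumes PORTKEY_API_KEY and PORTKEY_VIRTUAL_KEY are exported (PortkeyProvider::new falls back to those environment variables when its Option arguments are None), that tokio with the macros feature is available, and it uses only paths the patch itself imports (vectorize_core::transformers::providers::portkey::PortkeyProvider and providers::ChatMessageRequest). Inside the extension, the equivalent configuration comes from the new GUCs vectorize.portkey_api_key, vectorize.portkey_virtual_key, and vectorize.portkey_service_url.

```rust
// Minimal sketch, not part of the patch series: exercises the new PortkeyProvider
// the same way the portkey integration test does. Assumes PORTKEY_API_KEY and
// PORTKEY_VIRTUAL_KEY are set in the environment; passing None for the url falls
// back to PORTKEY_BASE_URL.
use vectorize_core::transformers::providers::portkey::PortkeyProvider;
use vectorize_core::transformers::providers::{
    ChatMessageRequest, EmbeddingProvider, GenericEmbeddingRequest,
};

#[tokio::main]
async fn main() {
    // url = None -> PORTKEY_BASE_URL; keys = None -> read from the environment
    let provider = PortkeyProvider::new(None, None, None);

    // Embeddings are routed through the Portkey gateway to whichever upstream
    // provider the virtual key points at (OpenAI in the CI setup above).
    let request = GenericEmbeddingRequest {
        model: "text-embedding-ada-002".to_string(),
        input: vec!["hello world".to_string()],
    };
    let embeddings = provider.generate_embedding(&request).await.unwrap();
    println!("embedding dim: {}", embeddings.embeddings[0].len());

    // Chat completions go through the same gateway via generate_response.
    let messages = [ChatMessageRequest {
        role: "user".to_string(),
        content: "hello world".to_string(),
    }];
    let response = provider
        .generate_response("gpt-3.5-turbo".to_string(), &messages)
        .await
        .unwrap();
    println!("{}", response);
}
```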