diff --git a/.github/workflows/extension_ci.yml b/.github/workflows/extension_ci.yml index 5650114..053ba13 100644 --- a/.github/workflows/extension_ci.yml +++ b/.github/workflows/extension_ci.yml @@ -111,6 +111,8 @@ jobs: env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} CO_API_KEY: ${{ secrets.CO_API_KEY }} + PORTKEY_API_KEY: ${{ secrets.PORTKEY_API_KEY }} + PORTKEY_VIRTUAL_KEY_OPENAI: ${{ secrets.PORTKEY_VIRTUAL_KEY_OPENAI }} run: | cd ../core && cargo test - name: Restore cached binaries @@ -135,6 +137,8 @@ jobs: env: HF_API_KEY: ${{ secrets.HF_API_KEY }} CO_API_KEY: ${{ secrets.CO_API_KEY }} + PORTKEY_API_KEY: ${{ secrets.PORTKEY_API_KEY }} + PORTKEY_VIRTUAL_KEY_OPENAI: ${{ secrets.PORTKEY_VIRTUAL_KEY_OPENAI }} run: | echo "\q" | make run make test-integration diff --git a/core/src/transformers/mod.rs b/core/src/transformers/mod.rs index 670f2d2..cb739e0 100644 --- a/core/src/transformers/mod.rs +++ b/core/src/transformers/mod.rs @@ -1,5 +1,4 @@ pub mod generic; pub mod http_handler; -pub mod ollama; pub mod providers; pub mod types; diff --git a/core/src/transformers/ollama.rs b/core/src/transformers/ollama.rs deleted file mode 100644 index 1a27b22..0000000 --- a/core/src/transformers/ollama.rs +++ /dev/null @@ -1,99 +0,0 @@ -use anyhow::Result; -use ollama_rs::{generation::completion::request::GenerationRequest, Ollama}; -use url::Url; - -use super::types::EmbeddingRequest; - -pub struct OllamaInstance { - pub model_name: String, - pub instance: Ollama, -} - -pub trait LLMFunctions { - fn new(model_name: String, url: String) -> Self; - #[allow(async_fn_in_trait)] - async fn generate_reponse(&self, prompt_text: String) -> Result; - #[allow(async_fn_in_trait)] - async fn generate_embedding(&self, inputs: String) -> Result, String>; -} - -impl LLMFunctions for OllamaInstance { - fn new(model_name: String, url: String) -> Self { - let parsed_url = Url::parse(&url).unwrap_or_else(|_| panic!("invalid url: {}", url)); - let instance = Ollama::new( - format!( - "{}://{}", - parsed_url.scheme(), - parsed_url.host_str().expect("parsed url missing") - ), - parsed_url.port().expect("parsed port missing"), - ); - OllamaInstance { - model_name, - instance, - } - } - async fn generate_reponse(&self, prompt_text: String) -> Result { - let req = GenerationRequest::new(self.model_name.clone(), prompt_text); - let res = self.instance.generate(req).await; - match res { - Ok(res) => Ok(res.response), - Err(e) => Err(e.to_string()), - } - } - async fn generate_embedding(&self, input: String) -> Result, String> { - let embed = self - .instance - .generate_embeddings(self.model_name.clone(), input, None) - .await; - match embed { - Ok(res) => Ok(res.embeddings), - Err(e) => Err(e.to_string()), - } - } -} - -pub fn ollama_embedding_dim(model_name: &str) -> i32 { - match model_name { - "llama2" => 5192, - _ => 1536, - } -} - -pub fn check_model_host(url: &str) -> Result { - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_io() - .enable_time() - .build() - .unwrap_or_else(|e| panic!("failed to initialize tokio runtime: {}", e)); - - runtime.block_on(async { - let response = reqwest::get(url).await.unwrap(); - match response.status() { - reqwest::StatusCode::OK => Ok(format!("Success! {:?}", response)), - _ => Err(format!("Error! {:?}", response)), - } - }) -} - -pub fn generate_embeddings(request: EmbeddingRequest) -> Result>> { - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_io() - .enable_time() - .build() - .unwrap_or_else(|e| panic!("failed to initialize tokio runtime: {}", e)); - - runtime.block_on(async { - let instance = OllamaInstance::new(request.payload.model, request.url); - let mut embeddings: Vec> = vec![]; - for input in request.payload.input { - let response = instance.generate_embedding(input).await; - let embedding = match response { - Ok(embed) => embed, - Err(e) => panic!("Unable to generate embeddings.\nError: {:?}", e), - }; - embeddings.push(embedding); - } - Ok(embeddings) - }) -} diff --git a/core/src/transformers/providers/mod.rs b/core/src/transformers/providers/mod.rs index 98cd7c4..b743c2a 100644 --- a/core/src/transformers/providers/mod.rs +++ b/core/src/transformers/providers/mod.rs @@ -1,6 +1,7 @@ pub mod cohere; pub mod ollama; pub mod openai; +pub mod portkey; pub mod vector_serve; use anyhow::Result; @@ -51,6 +52,7 @@ pub fn get_provider( model_source: &ModelSource, api_key: Option, url: Option, + virtual_key: Option, ) -> Result, VectorizeError> { match model_source { ModelSource::OpenAI => Ok(Box::new(providers::openai::OpenAIProvider::new( @@ -59,11 +61,42 @@ pub fn get_provider( ModelSource::Cohere => Ok(Box::new(providers::cohere::CohereProvider::new( url, api_key, ))), + ModelSource::Portkey => Ok(Box::new(providers::portkey::PortkeyProvider::new( + url, + api_key, + virtual_key, + ))), ModelSource::SentenceTransformers => Ok(Box::new( providers::vector_serve::VectorServeProvider::new(url, api_key), )), - ModelSource::Ollama | ModelSource::Tembo => Err(anyhow::anyhow!( + ModelSource::Ollama => Ok(Box::new(providers::ollama::OllamaProvider::new(url))), + ModelSource::Tembo => Err(anyhow::anyhow!( "Ollama/Tembo transformer not implemented yet" ))?, } } + +fn split_vector(vec: Vec, chunk_size: usize) -> Vec> { + vec.chunks(chunk_size).map(|chunk| chunk.to_vec()).collect() +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ChatMessageRequest { + pub role: String, + pub content: String, +} + +#[derive(Deserialize, Debug)] +struct ChatResponse { + choices: Vec, +} + +#[derive(Deserialize, Debug)] +struct Choice { + message: ResponseMessage, +} + +#[derive(Deserialize, Debug)] +struct ResponseMessage { + content: String, +} diff --git a/core/src/transformers/providers/ollama.rs b/core/src/transformers/providers/ollama.rs index a12d3e1..349508b 100644 --- a/core/src/transformers/providers/ollama.rs +++ b/core/src/transformers/providers/ollama.rs @@ -1,4 +1,6 @@ -use super::{EmbeddingProvider, GenericEmbeddingRequest, GenericEmbeddingResponse}; +use super::{ + ChatMessageRequest, EmbeddingProvider, GenericEmbeddingRequest, GenericEmbeddingResponse, +}; use crate::errors::VectorizeError; use async_trait::async_trait; use ollama_rs::{generation::completion::request::GenerationRequest, Ollama}; @@ -8,20 +10,19 @@ use url::Url; pub const OLLAMA_BASE_URL: &str = "http://localhost:3001"; pub struct OllamaProvider { - pub model_name: String, pub instance: Ollama, } #[derive(Clone, Debug, Serialize, Deserialize)] struct ModelInfo { - model: String, embedding_dimension: u32, max_seq_len: u32, } impl OllamaProvider { - pub fn new(model_name: String, url: String) -> Self { - let parsed_url = Url::parse(&url).unwrap_or_else(|_| panic!("invalid url: {}", url)); + pub fn new(url: Option) -> Self { + let url_in = url.unwrap_or_else(|| OLLAMA_BASE_URL.to_string()); + let parsed_url = Url::parse(&url_in).unwrap_or_else(|_| panic!("invalid url: {}", url_in)); let instance = Ollama::new( format!( "{}://{}", @@ -30,10 +31,7 @@ impl OllamaProvider { ), parsed_url.port().expect("parsed port missing"), ); - OllamaProvider { - model_name, - instance, - } + OllamaProvider { instance } } } @@ -44,10 +42,11 @@ impl EmbeddingProvider for OllamaProvider { request: &'a GenericEmbeddingRequest, ) -> Result { let mut all_embeddings: Vec> = Vec::with_capacity(request.input.len()); + let model_name = request.model.clone(); for ipt in request.input.iter() { let embed = self .instance - .generate_embeddings(self.model_name.clone(), ipt.clone(), None) + .generate_embeddings(model_name.clone(), ipt.clone(), None) .await?; all_embeddings.push(embed.embeddings); } @@ -66,9 +65,41 @@ impl EmbeddingProvider for OllamaProvider { } impl OllamaProvider { - pub async fn generate_response(&self, prompt_text: String) -> Result { - let req = GenerationRequest::new(self.model_name.clone(), prompt_text); + pub async fn generate_response( + &self, + model_name: String, + prompt_text: &[ChatMessageRequest], + ) -> Result { + let single_prompt: String = prompt_text + .iter() + .map(|x| x.content.clone()) + .collect::>() + .join("\n\n"); + let req = GenerationRequest::new(model_name, single_prompt.to_owned()); let res = self.instance.generate(req).await?; Ok(res.response) } } + +pub fn check_model_host(url: &str) -> Result { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .unwrap_or_else(|e| panic!("failed to initialize tokio runtime: {}", e)); + + runtime.block_on(async { + let response = reqwest::get(url).await.unwrap(); + match response.status() { + reqwest::StatusCode::OK => Ok(format!("Success! {:?}", response)), + _ => Err(format!("Error! {:?}", response)), + } + }) +} + +pub fn ollama_embedding_dim(model_name: &str) -> i32 { + match model_name { + "llama2" => 5192, + _ => 1536, + } +} diff --git a/core/src/transformers/providers/openai.rs b/core/src/transformers/providers/openai.rs index 82a8f63..9e1c1b9 100644 --- a/core/src/transformers/providers/openai.rs +++ b/core/src/transformers/providers/openai.rs @@ -1,9 +1,13 @@ use reqwest::Client; use serde::{Deserialize, Serialize}; -use super::{EmbeddingProvider, GenericEmbeddingRequest, GenericEmbeddingResponse}; +use super::{ + ChatMessageRequest, ChatResponse, EmbeddingProvider, GenericEmbeddingRequest, + GenericEmbeddingResponse, +}; use crate::errors::VectorizeError; use crate::transformers::http_handler::handle_response; +use crate::transformers::providers; use crate::transformers::types::Inputs; use async_trait::async_trait; use std::env; @@ -75,11 +79,10 @@ impl EmbeddingProvider for OpenAIProvider { request: &'a GenericEmbeddingRequest, ) -> Result { let client = Client::new(); - let req = OpenAIEmbeddingBody::from(request.clone()); let num_inputs = request.input.len(); let todo_requests: Vec = if num_inputs > 2048 { - split_vector(req.input, 2048) + providers::split_vector(req.input, 2048) .iter() .map(|chunk| OpenAIEmbeddingBody { input: chunk.clone(), @@ -128,8 +131,51 @@ pub fn openai_embedding_dim(model_name: &str) -> i32 { } } -fn split_vector(vec: Vec, chunk_size: usize) -> Vec> { - vec.chunks(chunk_size).map(|chunk| chunk.to_vec()).collect() +impl OpenAIProvider { + pub async fn generate_response( + &self, + model_name: String, + messages: &[ChatMessageRequest], + ) -> Result { + let client = Client::new(); + let chat_url = format!("{}/chat/completions", self.url); + let message = serde_json::json!({ + "model": model_name, + "messages": messages, + }); + let response = client + .post(&chat_url) + .timeout(std::time::Duration::from_secs(120_u64)) + .header("Accept", "application/json") + .header("Content-Type", "application/json") + .header("Authorization", &format!("Bearer {}", self.api_key)) + .json(&message) + .send() + .await?; + let chat_response = handle_response::(response, "embeddings").await?; + Ok(chat_response.choices[0].message.content.clone()) + } +} + +// OpenAI embedding model has a limit of 8192 tokens per input +// there can be a number of ways condense the inputs +pub fn trim_inputs(inputs: &[Inputs]) -> Vec { + inputs + .iter() + .map(|input| { + if input.token_estimate as usize > MAX_TOKEN_LEN { + // not example taking tokens, but naive way to trim input + let tokens: Vec<&str> = input.inputs.split_whitespace().collect(); + tokens + .into_iter() + .take(MAX_TOKEN_LEN) + .collect::>() + .join(" ") + } else { + input.inputs.clone() + } + }) + .collect() } #[cfg(test)] @@ -161,27 +207,6 @@ mod integration_tests { } } -// OpenAI embedding model has a limit of 8192 tokens per input -// there can be a number of ways condense the inputs -pub fn trim_inputs(inputs: &[Inputs]) -> Vec { - inputs - .iter() - .map(|input| { - if input.token_estimate as usize > MAX_TOKEN_LEN { - // not example taking tokens, but naive way to trim input - let tokens: Vec<&str> = input.inputs.split_whitespace().collect(); - tokens - .into_iter() - .take(MAX_TOKEN_LEN) - .collect::>() - .join(" ") - } else { - input.inputs.clone() - } - }) - .collect() -} - #[cfg(test)] mod tests { use super::*; diff --git a/core/src/transformers/providers/portkey.rs b/core/src/transformers/providers/portkey.rs new file mode 100644 index 0000000..ea3b2cf --- /dev/null +++ b/core/src/transformers/providers/portkey.rs @@ -0,0 +1,173 @@ +use reqwest::Client; + +use super::{ + ChatMessageRequest, ChatResponse, EmbeddingProvider, GenericEmbeddingRequest, + GenericEmbeddingResponse, +}; +use crate::errors::VectorizeError; +use crate::transformers::http_handler::handle_response; +use crate::transformers::providers; +use crate::transformers::providers::openai; +use async_trait::async_trait; +use std::env; + +pub const PORTKEY_BASE_URL: &str = "https://api.portkey.ai/v1"; +pub const MAX_TOKEN_LEN: usize = 8192; + +pub struct PortkeyProvider { + pub url: String, + pub api_key: String, + pub virtual_key: String, +} + +impl PortkeyProvider { + pub fn new(url: Option, api_key: Option, virtual_key: Option) -> Self { + let final_url = match url { + Some(url) => url, + None => PORTKEY_BASE_URL.to_string(), + }; + let final_api_key = match api_key { + Some(api_key) => api_key, + None => env::var("PORTKEY_API_KEY").expect("PORTKEY_API_KEY not set"), + }; + let final_virtual_key = match virtual_key { + Some(vkey) => vkey, + None => env::var("PORTKEY_VIRTUAL_KEY").expect("PORTKEY_VIRTUAL_KEY not set"), + }; + PortkeyProvider { + url: final_url, + api_key: final_api_key, + virtual_key: final_virtual_key, + } + } +} + +#[async_trait] +impl EmbeddingProvider for PortkeyProvider { + async fn generate_embedding<'a>( + &self, + request: &'a GenericEmbeddingRequest, + ) -> Result { + let client = Client::new(); + + let req = openai::OpenAIEmbeddingBody::from(request.clone()); + let num_inputs = request.input.len(); + let todo_requests: Vec = if num_inputs > 2048 { + providers::split_vector(req.input, 2048) + .iter() + .map(|chunk| openai::OpenAIEmbeddingBody { + input: chunk.clone(), + model: request.model.clone(), + }) + .collect() + } else { + vec![req] + }; + let embeddings_url = format!("{}/embeddings", self.url); + + let mut all_embeddings: Vec> = Vec::with_capacity(num_inputs); + for request_payload in todo_requests.iter() { + let payload_val = serde_json::to_value(request_payload)?; + let response = client + .post(&embeddings_url) + .timeout(std::time::Duration::from_secs(120_u64)) + .header("Accept", "application/json") + .header("Content-Type", "application/json") + .header("x-portkey-virtual-key", self.virtual_key.clone()) + .header("x-portkey-api-key", &self.api_key) + .json(&payload_val) + .send() + .await?; + + let embeddings = + handle_response::(response, "embeddings").await?; + all_embeddings.extend(embeddings.data.iter().map(|x| x.embedding.clone())); + } + Ok(GenericEmbeddingResponse { + embeddings: all_embeddings, + }) + } + + async fn model_dim(&self, model_name: &str) -> Result { + // determine embedding dim by generating an embedding and getting length of array + let req = GenericEmbeddingRequest { + input: vec!["hello world".to_string()], + model: model_name.to_string(), + }; + let embedding = self.generate_embedding(&req).await?; + let dim = embedding.embeddings[0].len(); + Ok(dim as u32) + } +} + +impl PortkeyProvider { + pub async fn generate_response( + &self, + model_name: String, + messages: &[ChatMessageRequest], + ) -> Result { + let client = Client::new(); + let message = serde_json::json!({ + "model": model_name, + "messages": messages, + }); + let chat_url = format!("{}/chat/completions", self.url); + let response = client + .post(&chat_url) + .timeout(std::time::Duration::from_secs(120_u64)) + .header("Accept", "application/json") + .header("Content-Type", "application/json") + .header("x-portkey-virtual-key", self.virtual_key.clone()) + .header("x-portkey-api-key", &self.api_key) + .json(&message) + .send() + .await?; + let chat_response = handle_response::(response, "embeddings").await?; + Ok(chat_response.choices[0].message.content.clone()) + } +} + +#[cfg(test)] +mod portkey_integration_tests { + use super::*; + use tokio::test as async_test; + + #[async_test] + async fn test_portkey_openai() { + let portkey_api_key = env::var("PORTKEY_API_KEY").expect("PORTKEY_API_KEY not set"); + let portkey_virtual_key = + env::var("PORTKEY_VIRTUAL_KEY_OPENAI").expect("PORTKEY_VIRTUAL_KEY_OPENAI not set"); + let provider = PortkeyProvider::new(None, Some(portkey_api_key), Some(portkey_virtual_key)); + let request = GenericEmbeddingRequest { + model: "text-embedding-ada-002".to_string(), + input: vec!["hello world".to_string()], + }; + + let embeddings = provider.generate_embedding(&request).await.unwrap(); + assert!( + !embeddings.embeddings.is_empty(), + "Embeddings should not be empty" + ); + assert!( + embeddings.embeddings.len() == 1, + "Embeddings should have length 1" + ); + assert!( + embeddings.embeddings[0].len() == 1536, + "Embeddings should have length 1536" + ); + + let dim = provider.model_dim("text-embedding-ada-002").await.unwrap(); + assert_eq!(dim, 1536); + + let chatmessage = ChatMessageRequest { + role: "user".to_string(), + content: "hello world".to_string(), + }; + let response = provider + .generate_response("gpt-3.5-turbo".to_string(), &[chatmessage]) + .await + .unwrap(); + assert!(!response.is_empty(), "Response should not be empty"); + } +} diff --git a/core/src/types.rs b/core/src/types.rs index 4254f0f..67419d1 100644 --- a/core/src/types.rs +++ b/core/src/types.rs @@ -160,6 +160,7 @@ impl Model { ModelSource::Ollama => self.name.clone(), ModelSource::Tembo => self.name.clone(), ModelSource::Cohere => self.name.clone(), + ModelSource::Portkey => self.name.clone(), } } } @@ -236,6 +237,7 @@ pub enum ModelSource { Ollama, Tembo, Cohere, + Portkey, } impl FromStr for ModelSource { @@ -248,6 +250,7 @@ impl FromStr for ModelSource { "sentence-transformers" => Ok(ModelSource::SentenceTransformers), "tembo" => Ok(ModelSource::Tembo), "cohere" => Ok(ModelSource::Cohere), + "portkey" => Ok(ModelSource::Portkey), _ => Ok(ModelSource::SentenceTransformers), } } @@ -261,6 +264,7 @@ impl Display for ModelSource { ModelSource::SentenceTransformers => write!(f, "sentence-transformers"), ModelSource::Tembo => write!(f, "tembo"), ModelSource::Cohere => write!(f, "cohere"), + ModelSource::Portkey => write!(f, "portkey"), } } } @@ -273,6 +277,7 @@ impl From for ModelSource { "sentence-transformers" => ModelSource::SentenceTransformers, "tembo" => ModelSource::Tembo, "cohere" => ModelSource::Cohere, + "portkey" => ModelSource::Portkey, // other cases are assumed to be private sentence-transformer compatible model // and can be hot-loaded _ => ModelSource::SentenceTransformers, @@ -285,6 +290,15 @@ impl From for ModelSource { mod model_tests { use super::*; + #[test] + fn test_portkey_parsing() { + let model = Model::new("portkey/openai/text-embedding-ada-002").unwrap(); + assert_eq!(model.source, ModelSource::Portkey); + assert_eq!(model.fullname, "portkey/openai/text-embedding-ada-002"); + assert_eq!(model.name, "text-embedding-ada-002"); + assert_eq!(model.api_name(), "text-embedding-ada-002"); + } + #[test] fn test_tembo_parsing() { let model = Model::new("tembo/meta-llama/Meta-Llama-3-8B-Instruct").unwrap(); diff --git a/core/src/worker/base.rs b/core/src/worker/base.rs index 7e2ba68..2908e1a 100644 --- a/core/src/worker/base.rs +++ b/core/src/worker/base.rs @@ -98,6 +98,7 @@ async fn execute_job( &job_meta.transformer.source, job_params.api_key.clone(), None, + None, )?; let embedding_request = diff --git a/extension/Cargo.toml b/extension/Cargo.toml index 2dc5cda..2b3be00 100644 --- a/extension/Cargo.toml +++ b/extension/Cargo.toml @@ -20,7 +20,6 @@ chrono = {version = "0.4.26", features = ["serde"] } handlebars = "5.1.0" lazy_static = "1.4.0" log = "0.4.21" -openai-api-rs = "4.0.6" pgmq = "0.26.1" pgrx = "0.11.4" postgres-types = "0.2.5" diff --git a/extension/src/api.rs b/extension/src/api.rs index 982ba12..d7575a8 100644 --- a/extension/src/api.rs +++ b/extension/src/api.rs @@ -1,5 +1,6 @@ -use crate::chat::ops::{call_chat, get_chat_response}; +use crate::chat::ops::{call_chat, call_chat_completions}; use crate::chat::types::RenderedPrompt; +use crate::guc::get_guc_configs; use crate::search::{self, init_table}; use crate::transformers::generic::env_interpolate_string; use crate::transformers::transform; @@ -162,7 +163,11 @@ fn generate( sys_rendered: "".to_string(), user_rendered: input.to_string(), }; - get_chat_response(prompt, &model, api_key) + let mut guc_configs = get_guc_configs(&model.source); + if let Some(api_key) = api_key { + guc_configs.api_key = Some(api_key); + } + call_chat_completions(prompt, &model, &guc_configs) } #[pg_extern] diff --git a/extension/src/chat/ops.rs b/extension/src/chat/ops.rs index 0c3bb57..cb52ad2 100644 --- a/extension/src/chat/ops.rs +++ b/extension/src/chat/ops.rs @@ -1,17 +1,14 @@ -use std::env; - use crate::guc; use crate::search; -use crate::transformers::generic::get_env_interpolated_guc; use crate::util::get_vectorize_meta_spi; use anyhow::{anyhow, Result}; use handlebars::Handlebars; -use openai_api_rs::v1::api::Client; -use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; use pgrx::prelude::*; -use vectorize_core::transformers::ollama::LLMFunctions; -use vectorize_core::transformers::ollama::OllamaInstance; +use vectorize_core::transformers::providers::ollama::OllamaProvider; +use vectorize_core::transformers::providers::openai::OpenAIProvider; +use vectorize_core::transformers::providers::portkey::PortkeyProvider; +use vectorize_core::transformers::providers::ChatMessageRequest; use vectorize_core::types::Model; use vectorize_core::types::ModelSource; @@ -50,6 +47,9 @@ pub fn call_chat( ModelSource::SentenceTransformers | ModelSource::Cohere => { error!("SentenceTransformers and Cohere not yet supported for chat completions") } + ModelSource::Portkey => { + get_bpe_from_model(&chat_model.name).expect("failed to get BPE from model") + } }; // can only be 1 column in a chat job, for now, so safe to grab first element @@ -123,7 +123,8 @@ pub fn call_chat( )?; // http request to chat completions - let chat_response = get_chat_response(rendered_prompt, chat_model, api_key)?; + let guc_configs = guc::get_guc_configs(&chat_model.source); + let chat_response = call_chat_completions(rendered_prompt, chat_model, &guc_configs)?; Ok(ChatResponse { context: search_results, @@ -131,20 +132,6 @@ pub fn call_chat( }) } -pub fn get_chat_response( - prompt: RenderedPrompt, - model: &Model, - api_key: Option, -) -> Result { - match model.source { - ModelSource::OpenAI | ModelSource::Tembo => call_chat_completions(prompt, model, api_key), - ModelSource::Ollama => call_ollama_chat_completions(prompt, &model.name), - ModelSource::SentenceTransformers | ModelSource::Cohere => { - error!("SentenceTransformers and Cohere not yet supported for chat completions"); - } - } -} - fn render_user_message(user_prompt_template: &str, context: &str, query: &str) -> Result { let handlebars = Handlebars::new(); let render_vals = serde_json::json!({ @@ -155,96 +142,60 @@ fn render_user_message(user_prompt_template: &str, context: &str, query: &str) - Ok(user_rendered) } -fn call_chat_completions( +pub fn call_chat_completions( prompts: RenderedPrompt, model: &Model, - api_key: Option, + guc_configs: &guc::ModelGucConfig, ) -> Result { - let api_key = match api_key { - Some(k) => k.to_string(), - None => { - let this_guc = match model.source { - ModelSource::Tembo => guc::VectorizeGuc::TemboAIKey, - ModelSource::OpenAI => guc::VectorizeGuc::OpenAIKey, - _ => { - error!("API key not found for model source"); - } - }; - match guc::get_guc(this_guc) { - Some(k) => k, - None => { - error!("failed to get API key from GUC"); - } - } - } - }; - - let base_url = match model.source { - ModelSource::Tembo => get_env_interpolated_guc(guc::VectorizeGuc::TemboServiceUrl) - .expect("vectorize.tembo_service_url must be set"), - ModelSource::OpenAI => get_env_interpolated_guc(guc::VectorizeGuc::OpenAIServiceUrl) - .expect("vectorize.openai_service_url must be set"), - _ => { - error!("API key not found for model source"); - } - }; - // set the url for openai client - env::set_var("OPENAI_API_BASE", base_url); - let client = Client::new(api_key); - let sys_msg = chat_completion::ChatCompletionMessage { - role: chat_completion::MessageRole::system, - content: chat_completion::Content::Text(prompts.sys_rendered), - name: None, - }; - let usr_msg = chat_completion::ChatCompletionMessage { - role: chat_completion::MessageRole::user, - content: chat_completion::Content::Text(prompts.user_rendered), - name: None, - }; - - let req = ChatCompletionRequest::new(model.name.clone(), vec![sys_msg, usr_msg]); - let result = client.chat_completion(req)?; - // currently we only support single query, and not a conversation - // so we can safely select the first response for now - let responses = &result.choices[0]; - let chat_response: String = responses - .message - .content - .clone() - .expect("no response from chat model"); - Ok(chat_response) -} - -fn call_ollama_chat_completions(prompts: RenderedPrompt, model: &str) -> Result { - // get url from guc - - let url = match guc::get_guc(guc::VectorizeGuc::OllamaServiceUrl) { - Some(k) => k, - None => { - error!("failed to get Ollama url from GUC"); - } - }; - + let messages = vec![ + ChatMessageRequest { + role: "system".to_owned(), + content: prompts.sys_rendered.clone(), + }, + ChatMessageRequest { + role: "user".to_owned(), + content: prompts.user_rendered.clone(), + }, + ]; let runtime = tokio::runtime::Builder::new_current_thread() .enable_io() .enable_time() .build() .unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e)); - let instance = OllamaInstance::new(model.to_string(), url.to_string()); - - let response = runtime.block_on(async { - instance - .generate_reponse(prompts.sys_rendered + "\n" + &prompts.user_rendered) - .await - }); - - match response { - Ok(k) => Ok(k), - Err(k) => { - error!("Unable to generate response. Error: {k}"); + let chat_response: String = runtime.block_on(async { + match model.source { + ModelSource::OpenAI | ModelSource::Tembo => { + let provider = OpenAIProvider::new( + guc_configs.service_url.clone(), + guc_configs.api_key.clone(), + ); + provider + .generate_response(model.api_name(), &messages) + .await + } + ModelSource::Portkey => { + let provider = PortkeyProvider::new( + guc_configs.service_url.clone(), + guc_configs.api_key.clone(), + guc_configs.virtual_key.clone(), + ); + provider + .generate_response(model.api_name(), &messages) + .await + } + ModelSource::Ollama => { + let provider = OllamaProvider::new(guc_configs.service_url.clone()); + provider + .generate_response(model.api_name(), &messages) + .await + } + ModelSource::SentenceTransformers | ModelSource::Cohere => { + error!("SentenceTransformers and Cohere not yet supported for chat completions") + } } - } + })?; + Ok(chat_response) } // Trims the context to fit within the token limit when force_trim = True diff --git a/extension/src/guc.rs b/extension/src/guc.rs index bbd2d73..fea0470 100644 --- a/extension/src/guc.rs +++ b/extension/src/guc.rs @@ -4,6 +4,8 @@ use pgrx::*; use anyhow::Result; use vectorize_core::types::ModelSource; +use crate::transformers::generic::env_interpolate_string; + pub static VECTORIZE_HOST: GucSetting> = GucSetting::>::new(None); pub static VECTORIZE_DATABASE_NAME: GucSetting> = GucSetting::>::new(None); @@ -23,6 +25,9 @@ pub static OLLAMA_SERVICE_HOST: GucSetting> = GucSetting::> = GucSetting::>::new(None); pub static TEMBO_API_KEY: GucSetting> = GucSetting::>::new(None); pub static COHERE_API_KEY: GucSetting> = GucSetting::>::new(None); +pub static PORTKEY_API_KEY: GucSetting> = GucSetting::>::new(None); +pub static PORTKEY_VIRTUAL_KEY: GucSetting> = GucSetting::>::new(None); +pub static PORTKEY_SERVICE_URL: GucSetting> = GucSetting::>::new(None); // initialize GUCs pub fn init_guc() { @@ -143,6 +148,33 @@ pub fn init_guc() { GucContext::Suset, GucFlags::default(), ); + + GucRegistry::define_string_guc( + "vectorize.portkey_service_url", + "Base url for the Portkey platform", + "Base url for the Portkey platform", + &PORTKEY_SERVICE_URL, + GucContext::Suset, + GucFlags::default(), + ); + + GucRegistry::define_string_guc( + "vectorize.portkey_api_key", + "API Key for the Portkey platform", + "API Key for the Portkey platform", + &PORTKEY_API_KEY, + GucContext::Suset, + GucFlags::default(), + ); + + GucRegistry::define_string_guc( + "vectorize.portkey_virtual_key", + "Virtual Key for the Portkey platform", + "Virtual Key for the Portkey platform", + &PORTKEY_VIRTUAL_KEY, + GucContext::Suset, + GucFlags::default(), + ); } // for handling of GUCs that can be error prone @@ -158,6 +190,9 @@ pub enum VectorizeGuc { OllamaServiceUrl, TemboServiceUrl, CohereApiKey, + PortkeyApiKey, + PortkeyVirtualKey, + PortkeyServiceUrl, } /// a convenience function to get this project's GUCs @@ -173,10 +208,14 @@ pub fn get_guc(guc: VectorizeGuc) -> Option { VectorizeGuc::OpenAIServiceUrl => OPENAI_BASE_URL.get(), VectorizeGuc::EmbeddingServiceApiKey => EMBEDDING_SERVICE_API_KEY.get(), VectorizeGuc::CohereApiKey => COHERE_API_KEY.get(), + VectorizeGuc::PortkeyApiKey => PORTKEY_API_KEY.get(), + VectorizeGuc::PortkeyVirtualKey => PORTKEY_VIRTUAL_KEY.get(), + VectorizeGuc::PortkeyServiceUrl => PORTKEY_SERVICE_URL.get(), }; if let Some(cstr) = val { if let Ok(s) = handle_cstr(cstr) { - Some(s) + let interpolated = env_interpolate_string(&s).unwrap(); + Some(interpolated) } else { error!("failed to convert CStr to str"); } @@ -199,6 +238,7 @@ fn handle_cstr(cstr: &CStr) -> Result { pub struct ModelGucConfig { pub api_key: Option, pub service_url: Option, + pub virtual_key: Option, } pub fn get_guc_configs(model_source: &ModelSource) -> ModelGucConfig { @@ -206,22 +246,32 @@ pub fn get_guc_configs(model_source: &ModelSource) -> ModelGucConfig { ModelSource::OpenAI => ModelGucConfig { api_key: get_guc(VectorizeGuc::OpenAIKey), service_url: get_guc(VectorizeGuc::OpenAIServiceUrl), + virtual_key: None, }, ModelSource::Tembo => ModelGucConfig { api_key: get_guc(VectorizeGuc::TemboAIKey), service_url: get_guc(VectorizeGuc::TemboServiceUrl), + virtual_key: None, }, ModelSource::SentenceTransformers => ModelGucConfig { api_key: get_guc(VectorizeGuc::EmbeddingServiceApiKey), service_url: get_guc(VectorizeGuc::EmbeddingServiceUrl), + virtual_key: None, }, ModelSource::Cohere => ModelGucConfig { api_key: get_guc(VectorizeGuc::CohereApiKey), service_url: None, + virtual_key: None, }, ModelSource::Ollama => ModelGucConfig { api_key: None, service_url: get_guc(VectorizeGuc::OllamaServiceUrl), + virtual_key: None, + }, + ModelSource::Portkey => ModelGucConfig { + api_key: get_guc(VectorizeGuc::PortkeyApiKey), + service_url: get_guc(VectorizeGuc::PortkeyServiceUrl), + virtual_key: get_guc(VectorizeGuc::PortkeyVirtualKey), }, } } diff --git a/extension/src/search.rs b/extension/src/search.rs index ebe4508..c901f13 100644 --- a/extension/src/search.rs +++ b/extension/src/search.rs @@ -8,8 +8,8 @@ use crate::util; use anyhow::{Context, Result}; use pgrx::prelude::*; -use vectorize_core::transformers::ollama::check_model_host; use vectorize_core::transformers::providers::get_provider; +use vectorize_core::transformers::providers::ollama::check_model_host; use vectorize_core::types::{self, Model, ModelSource, TableMethod, VectorizeMeta}; #[allow(clippy::too_many_arguments)] @@ -43,6 +43,7 @@ pub fn init_table( &transformer.source, guc_configs.api_key.clone(), guc_configs.service_url, + None, )?; //synchronous diff --git a/extension/src/transformers/mod.rs b/extension/src/transformers/mod.rs index e1fe033..a589333 100644 --- a/extension/src/transformers/mod.rs +++ b/extension/src/transformers/mod.rs @@ -23,8 +23,13 @@ pub fn transform(input: &str, transformer: &Model, api_key: Option) -> V guc_configs.api_key }; - let provider = providers::get_provider(&transformer.source, api_key, guc_configs.service_url) - .expect("failed to get provider"); + let provider = providers::get_provider( + &transformer.source, + api_key, + guc_configs.service_url, + guc_configs.virtual_key, + ) + .expect("failed to get provider"); let input = Inputs { record_id: "".to_string(), inputs: input.to_string(), diff --git a/extension/src/workers/mod.rs b/extension/src/workers/mod.rs index d846e64..16df1c3 100644 --- a/extension/src/workers/mod.rs +++ b/extension/src/workers/mod.rs @@ -82,6 +82,7 @@ async fn execute_job(dbclient: Pool, msg: Message) &job_meta.transformer.source, job_params.api_key.clone(), guc_configs.service_url, + guc_configs.virtual_key, )?; let embedding_response = provider.generate_embedding(&embedding_request).await?;