From 6f416d71886b31e13ed939b8e9473840c3374749 Mon Sep 17 00:00:00 2001 From: Itai Smith Date: Sat, 12 Oct 2024 15:11:47 -0700 Subject: [PATCH] xAI EF --- .../xai_embedding_function.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 chromadb/utils/embedding_functions/xai_embedding_function.py diff --git a/chromadb/utils/embedding_functions/xai_embedding_function.py b/chromadb/utils/embedding_functions/xai_embedding_function.py new file mode 100644 index 00000000000..22f87f2c796 --- /dev/null +++ b/chromadb/utils/embedding_functions/xai_embedding_function.py @@ -0,0 +1,37 @@ +import asyncio +import logging +from typing import Optional + +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings + +logger = logging.getLogger(__name__) + +class XAIEmbeddingFunction(EmbeddingFunction[Documents]): + def __init__(self, model_name: str, api_key: Optional[str] = None, host: str = "api.x.ai"): + """ + Initialize the XAIEmbeddingFunction. + Args: + model_name (str): The name of the model to use for embedding. + api_key (str, optional): Your API key for the xai-sdk. If not + provided, it will raise an error to provide an xAI API key. + host (str, optional): Hostname of the xAI API server. + """ + try: + import xai_sdk + except ImportError: + raise ValueError( + "The xai-sdk python package is not installed. Please install it with `pip install xai-sdk`" + ) + + if api_key is None: + raise ValueError("Please provide an OpenAI API key. You can get one at https://developers.x.ai/api/api-key/") + + self._api_key = api_key + self._host = host + self._model_name = model_name + self._client = xai_sdk.Client(api_key=self._api_key, api_host=self._host) + + def __call__(self, input: Documents) -> Embeddings: + # embed() returns a list of tuples, where each contains the embedding and its shape + embeddings = asyncio.run(self._client.embedder.embed(texts=input, model_name=self._model_name)) + return [embedding for embedding, _ in embeddings]