diff --git a/python/starpoint/_utils.py b/python/starpoint/_utils.py index 8747cf1..5592754 100644 --- a/python/starpoint/_utils.py +++ b/python/starpoint/_utils.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, Optional +from typing import Dict, Optional, List from uuid import UUID import requests @@ -102,3 +102,13 @@ def _check_collection_identifier_collision( raise ValueError(NO_COLLECTION_VALUE_ERROR) elif collection_id and collection_name: raise ValueError(MULTI_COLLECTION_VALUE_ERROR) + + +def _ensure_embedding_dict(embeddings: List[float] | Dict[str, List[float] | int] | None): + if isinstance(embeddings, list): + dict_embeddings = { + "values": embeddings, + "dimensionality": len(embeddings) + } + return dict_embeddings + return embeddings diff --git a/python/starpoint/db.py b/python/starpoint/db.py index 5cddffb..38ab888 100644 --- a/python/starpoint/db.py +++ b/python/starpoint/db.py @@ -8,7 +8,6 @@ import validators from starpoint import reader, writer, _utils -from starpoint.embedding import Embedding LOGGER = logging.getLogger(__name__) @@ -88,7 +87,7 @@ def insert( def column_insert( self, - embeddings: List[Embedding], + embeddings: List[Dict[str, List[float] | int]], document_metadatas: List[Dict[Any, Any]], collection_id: Optional[str] = None, collection_name: Optional[str] = None, @@ -127,7 +126,7 @@ def query( sql: Optional[str] = None, collection_id: Optional[str] = None, collection_name: Optional[str] = None, - query_embedding: Optional[List[float] | Embedding] = None, + query_embedding: Optional[List[float] | Dict[str, List[float] | int]] = None, params: Optional[List[Any]] = None, text_search_query: Optional[List[str]] = None, text_search_weight: Optional[float] = None, @@ -143,6 +142,7 @@ def query( collection_name: The collection's name where the query will happen. This or the `collection_id` needs to be provided. query_embedding: An embedding to query against the collection using similarity search. + This is of the shape {"values": List[float], "dimensionality": int} params: values for parameterized sql Returns: @@ -154,12 +154,6 @@ def query( requests.exceptions.SSLError: Failure likely due to network issues. """ - # check if query embedding is a float, if it is, convert to a embedding object - if isinstance(query_embedding, list): - query_embedding = Embedding( - vectors=query_embedding, - dim=len(query_embedding)) - return self.reader.query( sql=sql, collection_id=collection_id, @@ -231,7 +225,7 @@ def update( def column_update( self, - embeddings: List[Embedding], + embeddings: List[Dict[str, List[float] | int]], document_metadatas: List[Dict[Any, Any]], collection_id: Optional[str] = None, collection_name: Optional[str] = None, diff --git a/python/starpoint/embedding.py b/python/starpoint/embedding.py index ff6be66..75c68c9 100644 --- a/python/starpoint/embedding.py +++ b/python/starpoint/embedding.py @@ -30,15 +30,6 @@ ) -class Embedding(object): - values: List[float] - dimensionality: int - - def __init__(self, values: List[float], dimensionality: Optional[int] = None): - self.values = values - self.dimensionality = len(values) if dimensionality is None else dimensionality - - class EmbeddingModel(Enum): MINILM = "MINI_LM" diff --git a/python/starpoint/reader.py b/python/starpoint/reader.py index 459faa6..9a0fed3 100644 --- a/python/starpoint/reader.py +++ b/python/starpoint/reader.py @@ -9,9 +9,9 @@ _build_header, _check_collection_identifier_collision, _validate_host, + _ensure_embedding_dict ) -from starpoint.embedding import Embedding LOGGER = logging.getLogger(__name__) @@ -49,7 +49,7 @@ def query( sql: Optional[str] = None, collection_id: Optional[str] = None, collection_name: Optional[str] = None, - query_embeddings: Optional[Embedding] = None, + query_embeddings: Optional[Dict[str, int | List[float]] | List[float]] = None, params: Optional[List[Any]] = None, text_search_query: Optional[List[str]] = None, text_search_weight: Optional[float] = None, @@ -89,10 +89,13 @@ def query( ) """ + # check if type of query embeddings is list of float, if so convert to a dict + query_embeddings = _ensure_embedding_dict(query_embeddings) + request_data = dict( collection_id=collection_id, collection_name=collection_name, - query_embeddings=query_embeddings, + query_embedding=query_embeddings, sql=sql, params=params, text_search_query=text_search_query, diff --git a/python/starpoint/writer.py b/python/starpoint/writer.py index af6dd5b..2b45259 100644 --- a/python/starpoint/writer.py +++ b/python/starpoint/writer.py @@ -10,7 +10,6 @@ _validate_host, ) -from starpoint.embedding import Embedding LOGGER = logging.getLogger(__name__) @@ -171,7 +170,7 @@ def insert( def column_insert( self, - embeddings: List[Embedding], + embeddings: List[Dict[str, List[float] | int]], document_metadatas: List[Dict[Any, Any]], collection_id: Optional[str] = None, collection_name: Optional[str] = None, @@ -282,7 +281,7 @@ def update( def column_update( self, - embeddings: List[Embedding], + embeddings: List[Dict[str, List[float] | int]], document_metadatas: List[Dict[Any, Any]], collection_id: Optional[str] = None, collection_name: Optional[str] = None,