From 705900fc52a8dfa3fbb1d23d21f66d420800a402 Mon Sep 17 00:00:00 2001 From: scottwey Date: Wed, 25 Oct 2023 11:58:23 -0700 Subject: [PATCH 1/7] multivector support --- python/starpoint/db.py | 2 +- python/starpoint/reader.py | 4 +-- python/starpoint/writer.py | 51 +++----------------------------------- 3 files changed, 6 insertions(+), 51 deletions(-) diff --git a/python/starpoint/db.py b/python/starpoint/db.py index c535ac7..801bdb7 100644 --- a/python/starpoint/db.py +++ b/python/starpoint/db.py @@ -192,7 +192,7 @@ def query( sql=sql, collection_id=collection_id, collection_name=collection_name, - query_embedding=query_embedding, + query_embeddings=query_embedding, params=params, text_search_query=text_search_query, text_search_weight=text_search_weight, diff --git a/python/starpoint/reader.py b/python/starpoint/reader.py index b6841dd..a161e63 100644 --- a/python/starpoint/reader.py +++ b/python/starpoint/reader.py @@ -48,7 +48,7 @@ def query( sql: Optional[str] = None, collection_id: Optional[str] = None, collection_name: Optional[str] = None, - query_embedding: Optional[List[float]] = None, + query_embeddings: Optional[Dict[Any, Any]] = None, params: Optional[List[Any]] = None, text_search_query: Optional[List[str]] = None, text_search_weight: Optional[float] = None, @@ -91,7 +91,7 @@ def query( request_data = dict( collection_id=collection_id, collection_name=collection_name, - query_embedding=query_embedding, + query_embeddings=query_embeddings, sql=sql, params=params, text_search_query=text_search_query, diff --git a/python/starpoint/writer.py b/python/starpoint/writer.py index 76b009f..2aaec8d 100644 --- a/python/starpoint/writer.py +++ b/python/starpoint/writer.py @@ -100,51 +100,6 @@ def delete( return {} return response.json() - def column_delete( - self, - embeddings: List[List[float]], - document_metadatas: List[Dict[Any, Any]], - collection_id: Optional[str] = None, - collection_name: Optional[str] = None, - ) -> Dict[Any, Any]: - """Deletes documents from an existing collection by embedding and document metadata arrays. - The arrays are zipped together and updates the document in the order of the two arrays. - - Args: - embeddings: A list of embeddings. - Order of the embeddings should match the document_metadatas. - document_metadatas: A list of metadata to be associated with embeddings. - Order of these metadatas should match the embeddings. - collection_id: The collection's id where the documents will be deleted. - This or the `collection_name` needs to be provided. - collection_name: The collection's name where the documents will be deleted. - This or the `collection_id` needs to be provided. - - Returns: - dict: delete response json - - Raises: - ValueError: If neither collection id and collection name are provided. - ValueError: If both collection id and collection name are provided. - requests.exceptions.SSLError: Failure likely due to network issues. - """ - if len(embeddings) != len(document_metadatas): - LOGGER.warning(EMBEDDING_METADATA_LENGTH_MISMATCH_WARNING) - - documents = [ - { - "embedding": embedding, - "metadata": document_metadata, - } - for embedding, document_metadata in zip(embeddings, document_metadatas) - ] - - return self.delete( - documents=documents, - collection_id=collection_id, - collection_name=collection_name, - ) - def insert( self, documents: List[Dict[Any, Any]], @@ -214,7 +169,7 @@ def insert( def column_insert( self, - embeddings: List[List[float]], + embeddings: List[Dict[Any, Any]], document_metadatas: List[Dict[Any, Any]], collection_id: Optional[str] = None, collection_name: Optional[str] = None, @@ -245,7 +200,7 @@ def column_insert( documents = [ { - "embedding": embedding, + "embeddings": embedding, "metadata": document_metadata, } for embedding, document_metadata in zip(embeddings, document_metadatas) @@ -325,7 +280,7 @@ def update( def column_update( self, - embeddings: List[List[float]], + embeddings: List[Dict[Any, Any]], document_metadatas: List[Dict[Any, Any]], collection_id: Optional[str] = None, collection_name: Optional[str] = None, From f6bfe555cc3f7b78a28b5e231cb718cf3812a57b Mon Sep 17 00:00:00 2001 From: scottwey Date: Wed, 25 Oct 2023 11:59:56 -0700 Subject: [PATCH 2/7] fix update --- python/starpoint/writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/starpoint/writer.py b/python/starpoint/writer.py index 2aaec8d..7885269 100644 --- a/python/starpoint/writer.py +++ b/python/starpoint/writer.py @@ -311,7 +311,7 @@ def column_update( documents = [ { - "embedding": embedding, + "embeddings": embedding, "metadata": document_metadata, } for embedding, document_metadata in zip(embeddings, document_metadatas) From b73499b4d0c40777ea769f5bd2fcedc99f39bdc9 Mon Sep 17 00:00:00 2001 From: Tyler Duong Date: Wed, 25 Oct 2023 23:46:27 -0700 Subject: [PATCH 3/7] add embedding type and fix tests --- python/starpoint/db.py | 50 +++++------------------ python/starpoint/embedding.py | 9 +++++ python/starpoint/reader.py | 3 +- python/starpoint/writer.py | 6 ++- python/tests/test_db.py | 16 ++------ python/tests/test_writer.py | 76 +++++------------------------------ 6 files changed, 40 insertions(+), 120 deletions(-) diff --git a/python/starpoint/db.py b/python/starpoint/db.py index 801bdb7..5cddffb 100644 --- a/python/starpoint/db.py +++ b/python/starpoint/db.py @@ -8,6 +8,7 @@ import validators from starpoint import reader, writer, _utils +from starpoint.embedding import Embedding LOGGER = logging.getLogger(__name__) @@ -56,42 +57,6 @@ def delete( collection_name=collection_name, ) - def column_delete( - self, - embeddings: List[List[float]], - document_metadatas: List[Dict[Any, Any]], - collection_id: Optional[str] = None, - collection_name: Optional[str] = None, - ) -> Dict[Any, Any]: - """Deletes documents from an existing collection by embedding and document metadata arrays. - The arrays are zipped together and updates the document in the order of the two arrays. - `column_delete()` method from [`Writer`](#writer-objects). - - Args: - embeddings: A list of embeddings. - Order of the embeddings should match the document_metadatas. - document_metadatas: A list of metadata to be associated with embeddings. - Order of these metadatas should match the embeddings. - collection_id: The collection's id where the documents will be deleted. - This or the `collection_name` needs to be provided. - collection_name: The collection's name where the documents will be deleted. - This or the `collection_id` needs to be provided. - - Returns: - dict: delete response json - - Raises: - ValueError: If neither collection id and collection name are provided. - ValueError: If both collection id and collection name are provided. - requests.exceptions.SSLError: Failure likely due to network issues. - """ - return self.writer.column_delete( - embeddings=embeddings, - document_metadatas=document_metadatas, - collection_id=collection_id, - collection_name=collection_name, - ) - def insert( self, documents: List[Dict[Any, Any]], @@ -123,7 +88,7 @@ def insert( def column_insert( self, - embeddings: List[List[float]], + embeddings: List[Embedding], document_metadatas: List[Dict[Any, Any]], collection_id: Optional[str] = None, collection_name: Optional[str] = None, @@ -162,7 +127,7 @@ def query( sql: Optional[str] = None, collection_id: Optional[str] = None, collection_name: Optional[str] = None, - query_embedding: Optional[List[float]] = None, + query_embedding: Optional[List[float] | Embedding] = None, params: Optional[List[Any]] = None, text_search_query: Optional[List[str]] = None, text_search_weight: Optional[float] = None, @@ -188,6 +153,13 @@ def query( ValueError: If both collection id and collection name are provided. requests.exceptions.SSLError: Failure likely due to network issues. """ + + # check if query embedding is a float, if it is, convert to a embedding object + if isinstance(query_embedding, list): + query_embedding = Embedding( + vectors=query_embedding, + dim=len(query_embedding)) + return self.reader.query( sql=sql, collection_id=collection_id, @@ -259,7 +231,7 @@ def update( def column_update( self, - embeddings: List[List[float]], + embeddings: List[Embedding], document_metadatas: List[Dict[Any, Any]], collection_id: Optional[str] = None, collection_name: Optional[str] = None, diff --git a/python/starpoint/embedding.py b/python/starpoint/embedding.py index 75c68c9..ec77fef 100644 --- a/python/starpoint/embedding.py +++ b/python/starpoint/embedding.py @@ -30,6 +30,15 @@ ) +class Embedding(object): + vectors: List[float] + dim: int + + def __init__(self, vectors: List[float], dim: Optional[int] = None): + self.vectors = vectors + self.dim = len(vectors) if dim is None else dim + + class EmbeddingModel(Enum): MINILM = "MINI_LM" diff --git a/python/starpoint/reader.py b/python/starpoint/reader.py index a161e63..459faa6 100644 --- a/python/starpoint/reader.py +++ b/python/starpoint/reader.py @@ -11,6 +11,7 @@ _validate_host, ) +from starpoint.embedding import Embedding LOGGER = logging.getLogger(__name__) @@ -48,7 +49,7 @@ def query( sql: Optional[str] = None, collection_id: Optional[str] = None, collection_name: Optional[str] = None, - query_embeddings: Optional[Dict[Any, Any]] = None, + query_embeddings: Optional[Embedding] = None, params: Optional[List[Any]] = None, text_search_query: Optional[List[str]] = None, text_search_weight: Optional[float] = None, diff --git a/python/starpoint/writer.py b/python/starpoint/writer.py index 7885269..af6dd5b 100644 --- a/python/starpoint/writer.py +++ b/python/starpoint/writer.py @@ -10,6 +10,8 @@ _validate_host, ) +from starpoint.embedding import Embedding + LOGGER = logging.getLogger(__name__) # Host @@ -169,7 +171,7 @@ def insert( def column_insert( self, - embeddings: List[Dict[Any, Any]], + embeddings: List[Embedding], document_metadatas: List[Dict[Any, Any]], collection_id: Optional[str] = None, collection_name: Optional[str] = None, @@ -280,7 +282,7 @@ def update( def column_update( self, - embeddings: List[Dict[Any, Any]], + embeddings: List[Embedding], document_metadatas: List[Dict[Any, Any]], collection_id: Optional[str] = None, collection_name: Optional[str] = None, diff --git a/python/tests/test_db.py b/python/tests/test_db.py index 4ac953b..1dfc0c6 100644 --- a/python/tests/test_db.py +++ b/python/tests/test_db.py @@ -1,6 +1,7 @@ from tempfile import NamedTemporaryFile from uuid import uuid4 from unittest.mock import MagicMock, patch +from starpoint.embedding import Embedding import pytest from _pytest.monkeypatch import MonkeyPatch @@ -30,17 +31,6 @@ def test_client_delete(mock_writer: MagicMock, mock_reader: MagicMock): mock_writer().delete.assert_called_once() -@patch("starpoint.reader.Reader") -@patch("starpoint.writer.Writer") -def test_client_column_delete(mock_writer: MagicMock, mock_reader: MagicMock): - client = db.Client(api_key=uuid4()) - - client.column_delete(embeddings=[1.1], document_metadatas={"mock": "value"}) - - mock_reader.assert_called_once() # Only called during init - mock_writer().column_delete.assert_called_once() - - @patch("starpoint.reader.Reader") @patch("starpoint.writer.Writer") def test_client_insert(mock_writer: MagicMock, mock_reader: MagicMock): @@ -57,7 +47,7 @@ def test_client_insert(mock_writer: MagicMock, mock_reader: MagicMock): def test_client_column_insert(mock_writer: MagicMock, mock_reader: MagicMock): client = db.Client(api_key=uuid4()) - client.column_insert(embeddings=[1.1], document_metadatas={"mock": "value"}) + client.column_insert(embeddings=[Embedding([1.1])], document_metadatas=[{"mock": "value"}]) mock_reader.assert_called_once() # Only called during init mock_writer().column_insert.assert_called_once() @@ -101,7 +91,7 @@ def test_client_update(mock_writer: MagicMock, mock_reader: MagicMock): def test_client_column_update(mock_writer: MagicMock, mock_reader: MagicMock): client = db.Client(api_key=uuid4()) - client.column_update(embeddings=[1.1], document_metadatas={"mock": "value"}) + client.column_update(embeddings=[Embedding([1.1])], document_metadatas=[{"mock": "value"}]) mock_reader.assert_called_once() # Only called during init mock_writer().column_update.assert_called_once() diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index c99b06c..1e19d60 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -6,6 +6,7 @@ from requests.exceptions import SSLError from starpoint import writer +from starpoint.embedding import Embedding @pytest.fixture(scope="session") @@ -160,7 +161,7 @@ def test_writer_insert_SSLError( @patch("starpoint.writer.Writer.insert") def test_writer_column_insert(insert_mock: MagicMock, mock_writer: writer.Writer): - test_embeddings = [0.88, 0.71] + test_embeddings = [Embedding([0.88]), Embedding([0.71])] test_document_metadatas = [{"mock": "metadata"}, {"mock2": "metadata2"}] expected_insert_document = [ { @@ -188,7 +189,7 @@ def test_writer_column_insert(insert_mock: MagicMock, mock_writer: writer.Writer def test_writer_column_insert_collection_id_collection_name_passed_through( insert_mock: MagicMock, mock_writer: writer.Writer ): - test_embeddings = [0.88] + test_embeddings = [Embedding([0.88])] test_document_metadatas = [{"mock": "metadata"}] expected_insert_document = [ { @@ -217,7 +218,7 @@ def test_writer_column_insert_collection_id_collection_name_passed_through( def test_writer_column_insert_shorter_metadatas_length( insert_mock: MagicMock, mock_writer: writer.Writer, monkeypatch: MonkeyPatch ): - test_embeddings = [0.88, 0.71] + test_embeddings = [Embedding([0.88]), Embedding([0.71])] test_document_metadatas = [{"mock": "metadata"}] expected_insert_document = [ { @@ -247,7 +248,7 @@ def test_writer_column_insert_shorter_metadatas_length( def test_writer_column_insert_shorter_embeddings_length( insert_mock: MagicMock, mock_writer: writer.Writer, monkeypatch: MonkeyPatch ): - test_embeddings = [0.88] + test_embeddings = [Embedding([0.88])] test_document_metadatas = [{"mock": "metadata"}, {"mock2": "metadata2"}] expected_insert_document = [ { @@ -337,7 +338,7 @@ def test_writer_update_SSLError( @patch("starpoint.writer.Writer.update") def test_writer_column_update(update_mock: MagicMock, mock_writer: writer.Writer): - test_embeddings = [0.88, 0.71] + test_embeddings = [Embedding([0.88]), Embedding([0.71])] test_document_metadatas = [{"mock": "metadata"}, {"mock2": "metadata2"}] expected_update_document = [ { @@ -365,7 +366,7 @@ def test_writer_column_update(update_mock: MagicMock, mock_writer: writer.Writer def test_writer_column_update_collection_id_collection_name_passed_through( update_mock: MagicMock, mock_writer: writer.Writer ): - test_embeddings = [0.88] + test_embeddings = [Embedding([0.88])] test_document_metadatas = [{"mock": "metadata"}] expected_update_document = [ { @@ -394,7 +395,7 @@ def test_writer_column_update_collection_id_collection_name_passed_through( def test_writer_column_insert_shorter_metadatas_length( update_mock: MagicMock, mock_writer: writer.Writer, monkeypatch: MonkeyPatch ): - test_embeddings = [0.88, 0.71] + test_embeddings = [Embedding([0.88]), Embedding([0.71])] test_document_metadatas = [{"mock": "metadata"}] expected_update_document = [ { @@ -424,7 +425,7 @@ def test_writer_column_insert_shorter_metadatas_length( def test_writer_column_update_shorter_embeddings_length( update_mock: MagicMock, mock_writer: writer.Writer, monkeypatch: MonkeyPatch ): - test_embeddings = [0.88] + test_embeddings = [Embedding([0.88])] test_document_metadatas = [{"mock": "metadata"}, {"mock2": "metadata2"}] expected_update_document = [ { @@ -450,66 +451,11 @@ def test_writer_column_update_shorter_embeddings_length( ) -@patch("starpoint.writer.Writer.delete") -def test_writer_column_delete(delete_mock: MagicMock, mock_writer: writer.Writer): - test_embeddings = [0.88, 0.71] - test_document_metadatas = [{"mock": "metadata"}, {"mock2": "metadata2"}] - expected_delete_document = [ - { - "embedding": test_embeddings[0], - "metadata": test_document_metadatas[0], - }, - { - "embedding": test_embeddings[1], - "metadata": test_document_metadatas[1], - }, - ] - - mock_writer.column_delete( - embeddings=test_embeddings, document_metadatas=test_document_metadatas - ) - - delete_mock.assert_called_once_with( - documents=expected_delete_document, - collection_id=None, - collection_name=None, - ) - - -@patch("starpoint.writer.Writer.delete") -def test_writer_column_delete_collection_id_collection_name_passed_through( - delete_mock: MagicMock, mock_writer: writer.Writer -): - test_embeddings = [0.88] - test_document_metadatas = [{"mock": "metadata"}] - expected_delete_document = [ - { - "embedding": test_embeddings[0], - "metadata": test_document_metadatas[0], - }, - ] - expected_collection_id = "mock_id" - expected_collection_name = "mock_name" - - mock_writer.column_delete( - embeddings=test_embeddings, - document_metadatas=test_document_metadatas, - collection_id=expected_collection_id, - collection_name=expected_collection_name, - ) - - delete_mock.assert_called_once_with( - documents=expected_delete_document, - collection_id=expected_collection_id, - collection_name=expected_collection_name, - ) - - @patch("starpoint.writer.Writer.delete") def test_writer_column_insert_shorter_metadatas_length( delete_mock: MagicMock, mock_writer: writer.Writer, monkeypatch: MonkeyPatch ): - test_embeddings = [0.88, 0.71] + test_embeddings = [Embedding([0.88]), Embedding([0.71])] test_document_metadatas = [{"mock": "metadata"}] expected_delete_document = [ { @@ -539,7 +485,7 @@ def test_writer_column_insert_shorter_metadatas_length( def test_writer_column_delete_shorter_embeddings_length( delete_mock: MagicMock, mock_writer: writer.Writer, monkeypatch: MonkeyPatch ): - test_embeddings = [0.88] + test_embeddings = [Embedding([0.88])] test_document_metadatas = [{"mock": "metadata"}, {"mock2": "metadata2"}] expected_delete_document = [ { From 890a1c12dca07658e665b733f27547e294781a55 Mon Sep 17 00:00:00 2001 From: scottwey Date: Thu, 26 Oct 2023 08:19:57 -0700 Subject: [PATCH 4/7] fix up tests --- python/tests/test_pandas.py | 40 ------------------------- python/tests/test_writer.py | 60 ------------------------------------- 2 files changed, 100 deletions(-) diff --git a/python/tests/test_pandas.py b/python/tests/test_pandas.py index 1ac5278..31e034f 100644 --- a/python/tests/test_pandas.py +++ b/python/tests/test_pandas.py @@ -132,43 +132,3 @@ def test_update_by_dataframe_missing_embedding_column(): pandas.MISSING_COLUMN.substitute(column_name="embedding") in excinfo.value.__notes__ ) - - -@patch("starpoint.pandas._check_column_length") -def test_delete_by_dataframe_success(check_column_mock: MagicMock): - """Tests a successful delete operation.""" - mock_startpoint_client = MagicMock() - pandas_client = pandas.PandasClient(mock_startpoint_client) - - test_dataframe = pd.DataFrame( - [[1, 2]], - columns=["embedding", "metadata"], - ) - - pandas_client.delete_by_dataframe(test_dataframe) - - check_column_mock.assert_called_once_with(test_dataframe) - mock_startpoint_client.column_delete.assert_called_once_with( - embeddings=[1], - document_metadatas=[{"metadata": 2}], - collection_id=None, - collection_name=None, - ) - - -def test_delete_by_dataframe_missing_embedding_column(): - mock_startpoint_client = MagicMock() - pandas_client = pandas.PandasClient(mock_startpoint_client) - - missing_embedding_column_dataframe = pd.DataFrame( - [[1, 2]], - columns=["metadata", "extra"], - ) - - with pytest.raises(KeyError) as excinfo: - pandas_client.delete_by_dataframe(missing_embedding_column_dataframe) - - assert ( - pandas.MISSING_COLUMN.substitute(column_name="embedding") - in excinfo.value.__notes__ - ) diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 1e19d60..a5dc417 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -449,63 +449,3 @@ def test_writer_column_update_shorter_embeddings_length( collection_id=None, collection_name=None, ) - - -@patch("starpoint.writer.Writer.delete") -def test_writer_column_insert_shorter_metadatas_length( - delete_mock: MagicMock, mock_writer: writer.Writer, monkeypatch: MonkeyPatch -): - test_embeddings = [Embedding([0.88]), Embedding([0.71])] - test_document_metadatas = [{"mock": "metadata"}] - expected_delete_document = [ - { - "embedding": test_embeddings[0], - "metadata": test_document_metadatas[0], - }, - ] - - logger_mock = MagicMock() - monkeypatch.setattr(writer, "LOGGER", logger_mock) - - mock_writer.column_delete( - embeddings=test_embeddings, document_metadatas=test_document_metadatas - ) - - logger_mock.warning.assert_called_once_with( - writer.EMBEDDING_METADATA_LENGTH_MISMATCH_WARNING - ) - delete_mock.assert_called_once_with( - documents=expected_delete_document, - collection_id=None, - collection_name=None, - ) - - -@patch("starpoint.writer.Writer.delete") -def test_writer_column_delete_shorter_embeddings_length( - delete_mock: MagicMock, mock_writer: writer.Writer, monkeypatch: MonkeyPatch -): - test_embeddings = [Embedding([0.88])] - test_document_metadatas = [{"mock": "metadata"}, {"mock2": "metadata2"}] - expected_delete_document = [ - { - "embedding": test_embeddings[0], - "metadata": test_document_metadatas[0], - }, - ] - - logger_mock = MagicMock() - monkeypatch.setattr(writer, "LOGGER", logger_mock) - - mock_writer.column_delete( - embeddings=test_embeddings, document_metadatas=test_document_metadatas - ) - - logger_mock.warning.assert_called_once_with( - writer.EMBEDDING_METADATA_LENGTH_MISMATCH_WARNING - ) - delete_mock.assert_called_once_with( - documents=expected_delete_document, - collection_id=None, - collection_name=None, - ) From 4ce49458e0cd361a6bc2067da7e4427cd39407f2 Mon Sep 17 00:00:00 2001 From: scottwey Date: Thu, 26 Oct 2023 08:21:06 -0700 Subject: [PATCH 5/7] remove delete_by_datafame --- python/starpoint/pandas.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/python/starpoint/pandas.py b/python/starpoint/pandas.py index 9f030cf..0fb6056 100644 --- a/python/starpoint/pandas.py +++ b/python/starpoint/pandas.py @@ -112,27 +112,3 @@ def update_by_dataframe( collection_name=collection_name, ) - def delete_by_dataframe( - self, - dataframe: pd.DataFrame, - collection_id: Optional[str] = None, - collection_name: Optional[str] = None, - embedding_column_name: str = EMBEDDING_COLUMN_NAME, - ) -> Dict[Any, Any]: - _check_column_length(dataframe) - embedding_column_values = _get_column_value_from_dataframe( - dataframe, - embedding_column_name, - ) - - metadata_column_values = _get_aggregate_column_values_from_dataframe( - dataframe, - [embedding_column_name], - ) - - self.starpoint.column_delete( - embeddings=embedding_column_values, - document_metadatas=metadata_column_values, - collection_id=collection_id, - collection_name=collection_name, - ) From 557e342ddb27cc868acb772fcc12ab1f69a6228b Mon Sep 17 00:00:00 2001 From: scottwey Date: Thu, 26 Oct 2023 08:24:21 -0700 Subject: [PATCH 6/7] embedding -> embeddings --- python/README.md | 4 +++- python/starpoint/openai.py | 2 +- python/starpoint/pandas.py | 6 +++--- python/tests/test_openai.py | 8 ++++---- python/tests/test_pandas.py | 8 ++++---- python/tests/test_writer.py | 20 ++++++++++---------- 6 files changed, 25 insertions(+), 23 deletions(-) diff --git a/python/README.md b/python/README.md index 679f09c..b5cca1e 100644 --- a/python/README.md +++ b/python/README.md @@ -15,7 +15,7 @@ client = Client(api_key="YOUR_API_KEY_HERE") documents = [ { - "embedding": [0.1, 0.2, 0.3, 0.4, 0.5], + "embeddings": [0.1, 0.2, 0.3, 0.4, 0.5], "metadata": { "label1": "0", "label2": "1", @@ -31,11 +31,13 @@ client.insert(documents=documents, collection_name="COLLECTION_NAME") ## Contributing Make sure you have installed dev requirements + ``` pip install -r dev-requirements.txt ``` Unit tests should be passing. You can run them via + ``` pytest ./tests ``` diff --git a/python/starpoint/openai.py b/python/starpoint/openai.py index 1999fda..0bcbef9 100644 --- a/python/starpoint/openai.py +++ b/python/starpoint/openai.py @@ -114,7 +114,7 @@ def build_and_insert_embeddings( # Return the embedding response no matter what issues/bugs we might run into in the sdk try: sorted_embedding_data = sorted(embedding_data, key=lambda x: x["index"]) - embeddings = map(lambda x: x.get("embedding"), sorted_embedding_data) + embeddings = map(lambda x: x.get("embeddings"), sorted_embedding_data) starpoint_response = self.starpoint.column_insert( embeddings=embeddings, document_metadatas=document_metadatas, diff --git a/python/starpoint/pandas.py b/python/starpoint/pandas.py index 0fb6056..1ab33f7 100644 --- a/python/starpoint/pandas.py +++ b/python/starpoint/pandas.py @@ -8,7 +8,7 @@ LOGGER = logging.getLogger(__name__) -EMBEDDING_COLUMN_NAME = "embedding" +EMBEDDING_COLUMN_NAME = "embeddings" TOO_FEW_COLUMN_ERROR = """Not enough columns in dataframe provided. Please make sure to provide a column for at least embeddings. For examples of what this should look like visit: @@ -29,10 +29,10 @@ def _check_column_length(dataframe: pd.DataFrame): def _get_aggregate_column_values_from_dataframe( dataframe: pd.DataFrame, exclude_column_names: List[str] ) -> List[Dict]: - """Gets a dataframe of everything except for the "embedding" column then produce + """Gets a dataframe of everything except for the "embeddings" column then produce a list of row-wise dicts that will be loaded as the metadata. For example: - df = DataFrame([[1,2,3], [4,5,6]], columns=["embedding","b","c"] + df = DataFrame([[1,2,3], [4,5,6]], columns=["embeddings","b","c"] metadata_column_values will be [{'b': 2, 'c': 3}, {'b': 5, 'c': 6}] """ if not all((True if name in dataframe else False for name in exclude_column_names)): diff --git a/python/tests/test_openai.py b/python/tests/test_openai.py index 931a527..688b4fc 100644 --- a/python/tests/test_openai.py +++ b/python/tests/test_openai.py @@ -80,7 +80,7 @@ def test_client_build_and_insert_embeddings_input_string_success( expected_embedding_response = { "data": [ { - "embedding": mock_embedding, + "embeddings": mock_embedding, "index": 0, } ] @@ -128,11 +128,11 @@ def test_client_build_and_insert_embeddings_input_list_success( expected_embedding_response = { "data": [ { - "embedding": 0.77, + "embeddings": 0.77, "index": 0, }, { - "embedding": 0.88, + "embeddings": 0.88, "index": 1, }, ] @@ -224,7 +224,7 @@ def test_client_build_and_insert_embeddings_exception_during_write( expected_embedding_response = { "data": [ { - "embedding": 0.77, + "embeddings": 0.77, "index": 0, } ] diff --git a/python/tests/test_pandas.py b/python/tests/test_pandas.py index 31e034f..47ad819 100644 --- a/python/tests/test_pandas.py +++ b/python/tests/test_pandas.py @@ -62,7 +62,7 @@ def test_insert_by_dataframe_success(check_column_mock: MagicMock): test_dataframe = pd.DataFrame( [[1, 2]], - columns=["embedding", "metadata"], + columns=["embeddings", "metadata"], ) pandas_client.insert_by_dataframe(test_dataframe) @@ -89,7 +89,7 @@ def test_insert_by_dataframe_missing_embedding_column(): pandas_client.insert_by_dataframe(missing_embedding_column_dataframe) assert ( - pandas.MISSING_COLUMN.substitute(column_name="embedding") + pandas.MISSING_COLUMN.substitute(column_name="embeddings") in excinfo.value.__notes__ ) @@ -102,7 +102,7 @@ def test_update_by_dataframe_success(check_column_mock: MagicMock): test_dataframe = pd.DataFrame( [[1, 2]], - columns=["embedding", "metadata"], + columns=["embeddings", "metadata"], ) pandas_client.update_by_dataframe(test_dataframe) @@ -129,6 +129,6 @@ def test_update_by_dataframe_missing_embedding_column(): pandas_client.update_by_dataframe(missing_embedding_column_dataframe) assert ( - pandas.MISSING_COLUMN.substitute(column_name="embedding") + pandas.MISSING_COLUMN.substitute(column_name="embeddings") in excinfo.value.__notes__ ) diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index a5dc417..5763e00 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -165,11 +165,11 @@ def test_writer_column_insert(insert_mock: MagicMock, mock_writer: writer.Writer test_document_metadatas = [{"mock": "metadata"}, {"mock2": "metadata2"}] expected_insert_document = [ { - "embedding": test_embeddings[0], + "embeddings": test_embeddings[0], "metadata": test_document_metadatas[0], }, { - "embedding": test_embeddings[1], + "embeddings": test_embeddings[1], "metadata": test_document_metadatas[1], }, ] @@ -193,7 +193,7 @@ def test_writer_column_insert_collection_id_collection_name_passed_through( test_document_metadatas = [{"mock": "metadata"}] expected_insert_document = [ { - "embedding": test_embeddings[0], + "embeddings": test_embeddings[0], "metadata": test_document_metadatas[0], }, ] @@ -222,7 +222,7 @@ def test_writer_column_insert_shorter_metadatas_length( test_document_metadatas = [{"mock": "metadata"}] expected_insert_document = [ { - "embedding": test_embeddings[0], + "embeddings": test_embeddings[0], "metadata": test_document_metadatas[0], }, ] @@ -252,7 +252,7 @@ def test_writer_column_insert_shorter_embeddings_length( test_document_metadatas = [{"mock": "metadata"}, {"mock2": "metadata2"}] expected_insert_document = [ { - "embedding": test_embeddings[0], + "embeddings": test_embeddings[0], "metadata": test_document_metadatas[0], }, ] @@ -342,11 +342,11 @@ def test_writer_column_update(update_mock: MagicMock, mock_writer: writer.Writer test_document_metadatas = [{"mock": "metadata"}, {"mock2": "metadata2"}] expected_update_document = [ { - "embedding": test_embeddings[0], + "embeddings": test_embeddings[0], "metadata": test_document_metadatas[0], }, { - "embedding": test_embeddings[1], + "embeddings": test_embeddings[1], "metadata": test_document_metadatas[1], }, ] @@ -370,7 +370,7 @@ def test_writer_column_update_collection_id_collection_name_passed_through( test_document_metadatas = [{"mock": "metadata"}] expected_update_document = [ { - "embedding": test_embeddings[0], + "embeddings": test_embeddings[0], "metadata": test_document_metadatas[0], }, ] @@ -399,7 +399,7 @@ def test_writer_column_insert_shorter_metadatas_length( test_document_metadatas = [{"mock": "metadata"}] expected_update_document = [ { - "embedding": test_embeddings[0], + "embeddings": test_embeddings[0], "metadata": test_document_metadatas[0], }, ] @@ -429,7 +429,7 @@ def test_writer_column_update_shorter_embeddings_length( test_document_metadatas = [{"mock": "metadata"}, {"mock2": "metadata2"}] expected_update_document = [ { - "embedding": test_embeddings[0], + "embeddings": test_embeddings[0], "metadata": test_document_metadatas[0], }, ] From c9ac661b8b8b121f535cb7bc2d2cefcd50d4d30a Mon Sep 17 00:00:00 2001 From: scottwey Date: Thu, 26 Oct 2023 08:28:43 -0700 Subject: [PATCH 7/7] vectors -> values, dim -> dimensionality --- python/starpoint/embedding.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/starpoint/embedding.py b/python/starpoint/embedding.py index ec77fef..ff6be66 100644 --- a/python/starpoint/embedding.py +++ b/python/starpoint/embedding.py @@ -31,12 +31,12 @@ class Embedding(object): - vectors: List[float] - dim: int + values: List[float] + dimensionality: int - def __init__(self, vectors: List[float], dim: Optional[int] = None): - self.vectors = vectors - self.dim = len(vectors) if dim is None else dim + def __init__(self, values: List[float], dimensionality: Optional[int] = None): + self.values = values + self.dimensionality = len(values) if dimensionality is None else dimensionality class EmbeddingModel(Enum):