From d629d1157cab637d991fa4ef1d628d92aaae7005 Mon Sep 17 00:00:00 2001 From: Tyler Duong Date: Fri, 27 Oct 2023 14:09:58 -0700 Subject: [PATCH 1/5] remove embedding object and use dicts instead --- python/starpoint/_utils.py | 12 +++++++++++- python/starpoint/db.py | 14 ++++---------- python/starpoint/embedding.py | 9 --------- python/starpoint/reader.py | 9 ++++++--- python/starpoint/writer.py | 5 ++--- 5 files changed, 23 insertions(+), 26 deletions(-) diff --git a/python/starpoint/_utils.py b/python/starpoint/_utils.py index 8747cf1..5592754 100644 --- a/python/starpoint/_utils.py +++ b/python/starpoint/_utils.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, Optional +from typing import Dict, Optional, List from uuid import UUID import requests @@ -102,3 +102,13 @@ def _check_collection_identifier_collision( raise ValueError(NO_COLLECTION_VALUE_ERROR) elif collection_id and collection_name: raise ValueError(MULTI_COLLECTION_VALUE_ERROR) + + +def _ensure_embedding_dict(embeddings: List[float] | Dict[str, List[float] | int] | None): + if isinstance(embeddings, list): + dict_embeddings = { + "values": embeddings, + "dimensionality": len(embeddings) + } + return dict_embeddings + return embeddings diff --git a/python/starpoint/db.py b/python/starpoint/db.py index 5cddffb..38ab888 100644 --- a/python/starpoint/db.py +++ b/python/starpoint/db.py @@ -8,7 +8,6 @@ import validators from starpoint import reader, writer, _utils -from starpoint.embedding import Embedding LOGGER = logging.getLogger(__name__) @@ -88,7 +87,7 @@ def insert( def column_insert( self, - embeddings: List[Embedding], + embeddings: List[Dict[str, List[float] | int]], document_metadatas: List[Dict[Any, Any]], collection_id: Optional[str] = None, collection_name: Optional[str] = None, @@ -127,7 +126,7 @@ def query( sql: Optional[str] = None, collection_id: Optional[str] = None, collection_name: Optional[str] = None, - query_embedding: Optional[List[float] | Embedding] = None, + query_embedding: Optional[List[float] | Dict[str, List[float] | int]] = None, params: Optional[List[Any]] = None, text_search_query: Optional[List[str]] = None, text_search_weight: Optional[float] = None, @@ -143,6 +142,7 @@ def query( collection_name: The collection's name where the query will happen. This or the `collection_id` needs to be provided. query_embedding: An embedding to query against the collection using similarity search. + This is of the shape {"values": List[float], "dimensionality": int} params: values for parameterized sql Returns: @@ -154,12 +154,6 @@ def query( requests.exceptions.SSLError: Failure likely due to network issues. """ - # check if query embedding is a float, if it is, convert to a embedding object - if isinstance(query_embedding, list): - query_embedding = Embedding( - vectors=query_embedding, - dim=len(query_embedding)) - return self.reader.query( sql=sql, collection_id=collection_id, @@ -231,7 +225,7 @@ def update( def column_update( self, - embeddings: List[Embedding], + embeddings: List[Dict[str, List[float] | int]], document_metadatas: List[Dict[Any, Any]], collection_id: Optional[str] = None, collection_name: Optional[str] = None, diff --git a/python/starpoint/embedding.py b/python/starpoint/embedding.py index ff6be66..75c68c9 100644 --- a/python/starpoint/embedding.py +++ b/python/starpoint/embedding.py @@ -30,15 +30,6 @@ ) -class Embedding(object): - values: List[float] - dimensionality: int - - def __init__(self, values: List[float], dimensionality: Optional[int] = None): - self.values = values - self.dimensionality = len(values) if dimensionality is None else dimensionality - - class EmbeddingModel(Enum): MINILM = "MINI_LM" diff --git a/python/starpoint/reader.py b/python/starpoint/reader.py index 459faa6..9a0fed3 100644 --- a/python/starpoint/reader.py +++ b/python/starpoint/reader.py @@ -9,9 +9,9 @@ _build_header, _check_collection_identifier_collision, _validate_host, + _ensure_embedding_dict ) -from starpoint.embedding import Embedding LOGGER = logging.getLogger(__name__) @@ -49,7 +49,7 @@ def query( sql: Optional[str] = None, collection_id: Optional[str] = None, collection_name: Optional[str] = None, - query_embeddings: Optional[Embedding] = None, + query_embeddings: Optional[Dict[str, int | List[float]] | List[float]] = None, params: Optional[List[Any]] = None, text_search_query: Optional[List[str]] = None, text_search_weight: Optional[float] = None, @@ -89,10 +89,13 @@ def query( ) """ + # check if type of query embeddings is list of float, if so convert to a dict + query_embeddings = _ensure_embedding_dict(query_embeddings) + request_data = dict( collection_id=collection_id, collection_name=collection_name, - query_embeddings=query_embeddings, + query_embedding=query_embeddings, sql=sql, params=params, text_search_query=text_search_query, diff --git a/python/starpoint/writer.py b/python/starpoint/writer.py index af6dd5b..2b45259 100644 --- a/python/starpoint/writer.py +++ b/python/starpoint/writer.py @@ -10,7 +10,6 @@ _validate_host, ) -from starpoint.embedding import Embedding LOGGER = logging.getLogger(__name__) @@ -171,7 +170,7 @@ def insert( def column_insert( self, - embeddings: List[Embedding], + embeddings: List[Dict[str, List[float] | int]], document_metadatas: List[Dict[Any, Any]], collection_id: Optional[str] = None, collection_name: Optional[str] = None, @@ -282,7 +281,7 @@ def update( def column_update( self, - embeddings: List[Embedding], + embeddings: List[Dict[str, List[float] | int]], document_metadatas: List[Dict[Any, Any]], collection_id: Optional[str] = None, collection_name: Optional[str] = None, From 0448374276debe04803a002c15b13a869e3f91ec Mon Sep 17 00:00:00 2001 From: scottwey Date: Fri, 27 Oct 2023 15:59:51 -0700 Subject: [PATCH 2/5] pin to 3.11 for now --- .github/workflows/test_python.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml index 2cdc607..c31c9d2 100644 --- a/.github/workflows/test_python.yml +++ b/.github/workflows/test_python.yml @@ -21,8 +21,8 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: '3.x' - cache: 'pip' + python-version: "3.11" + cache: "pip" - name: Install dependencies run: | python -m pip install --upgrade pip From 7df05eddaef032d84f6575ae3b6fe8f94e3139f9 Mon Sep 17 00:00:00 2001 From: Tyler Duong Date: Fri, 27 Oct 2023 16:10:28 -0700 Subject: [PATCH 3/5] fix tests --- python/tests/test_db.py | 6 ++---- python/tests/test_writer.py | 36 +++++++++++++++++++++++++++--------- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/python/tests/test_db.py b/python/tests/test_db.py index 1dfc0c6..38d0588 100644 --- a/python/tests/test_db.py +++ b/python/tests/test_db.py @@ -1,7 +1,6 @@ from tempfile import NamedTemporaryFile from uuid import uuid4 from unittest.mock import MagicMock, patch -from starpoint.embedding import Embedding import pytest from _pytest.monkeypatch import MonkeyPatch @@ -47,7 +46,7 @@ def test_client_insert(mock_writer: MagicMock, mock_reader: MagicMock): def test_client_column_insert(mock_writer: MagicMock, mock_reader: MagicMock): client = db.Client(api_key=uuid4()) - client.column_insert(embeddings=[Embedding([1.1])], document_metadatas=[{"mock": "value"}]) + client.column_insert(embeddings=[{"values": [1.1], "dimensionality": 1}], document_metadatas=[{"mock": "value"}]) mock_reader.assert_called_once() # Only called during init mock_writer().column_insert.assert_called_once() @@ -91,7 +90,6 @@ def test_client_update(mock_writer: MagicMock, mock_reader: MagicMock): def test_client_column_update(mock_writer: MagicMock, mock_reader: MagicMock): client = db.Client(api_key=uuid4()) - client.column_update(embeddings=[Embedding([1.1])], document_metadatas=[{"mock": "value"}]) - + client.column_update(embeddings=[{"values": [1.1], "dimensionality": 1}], document_metadatas=[{"mock": "value"}]) mock_reader.assert_called_once() # Only called during init mock_writer().column_update.assert_called_once() diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 5763e00..729a599 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -6,7 +6,6 @@ from requests.exceptions import SSLError from starpoint import writer -from starpoint.embedding import Embedding @pytest.fixture(scope="session") @@ -161,7 +160,10 @@ def test_writer_insert_SSLError( @patch("starpoint.writer.Writer.insert") def test_writer_column_insert(insert_mock: MagicMock, mock_writer: writer.Writer): - test_embeddings = [Embedding([0.88]), Embedding([0.71])] + test_embeddings = [ + {"values": [0.88], "dimensionality": 1}, + {"values": [0.71], "dimensionality": 1} + ] test_document_metadatas = [{"mock": "metadata"}, {"mock2": "metadata2"}] expected_insert_document = [ { @@ -189,7 +191,10 @@ def test_writer_column_insert(insert_mock: MagicMock, mock_writer: writer.Writer def test_writer_column_insert_collection_id_collection_name_passed_through( insert_mock: MagicMock, mock_writer: writer.Writer ): - test_embeddings = [Embedding([0.88])] + test_embeddings = [ + {"values": [0.88], "dimensionality": 1}, + ] + test_document_metadatas = [{"mock": "metadata"}] expected_insert_document = [ { @@ -218,7 +223,10 @@ def test_writer_column_insert_collection_id_collection_name_passed_through( def test_writer_column_insert_shorter_metadatas_length( insert_mock: MagicMock, mock_writer: writer.Writer, monkeypatch: MonkeyPatch ): - test_embeddings = [Embedding([0.88]), Embedding([0.71])] + test_embeddings = [ + {"values": [0.88], "dimensionality": 1}, + {"values": [0.71], "dimensionality": 1} + ] test_document_metadatas = [{"mock": "metadata"}] expected_insert_document = [ { @@ -248,7 +256,9 @@ def test_writer_column_insert_shorter_metadatas_length( def test_writer_column_insert_shorter_embeddings_length( insert_mock: MagicMock, mock_writer: writer.Writer, monkeypatch: MonkeyPatch ): - test_embeddings = [Embedding([0.88])] + test_embeddings = [ + {"values": [0.88], "dimensionality": 1}, + ] test_document_metadatas = [{"mock": "metadata"}, {"mock2": "metadata2"}] expected_insert_document = [ { @@ -338,7 +348,10 @@ def test_writer_update_SSLError( @patch("starpoint.writer.Writer.update") def test_writer_column_update(update_mock: MagicMock, mock_writer: writer.Writer): - test_embeddings = [Embedding([0.88]), Embedding([0.71])] + test_embeddings = [ + {"values": [0.88], "dimensionality": 1}, + {"values": [0.71], "dimensionality": 1} + ] test_document_metadatas = [{"mock": "metadata"}, {"mock2": "metadata2"}] expected_update_document = [ { @@ -366,7 +379,7 @@ def test_writer_column_update(update_mock: MagicMock, mock_writer: writer.Writer def test_writer_column_update_collection_id_collection_name_passed_through( update_mock: MagicMock, mock_writer: writer.Writer ): - test_embeddings = [Embedding([0.88])] + test_embeddings = [{"values": [0.88], "dimensionality": 1}] test_document_metadatas = [{"mock": "metadata"}] expected_update_document = [ { @@ -395,7 +408,10 @@ def test_writer_column_update_collection_id_collection_name_passed_through( def test_writer_column_insert_shorter_metadatas_length( update_mock: MagicMock, mock_writer: writer.Writer, monkeypatch: MonkeyPatch ): - test_embeddings = [Embedding([0.88]), Embedding([0.71])] + test_embeddings = [ + {"values": [0.88], "dimensionality": 1}, + {"values": [0.71], "dimensionality": 1} + ] test_document_metadatas = [{"mock": "metadata"}] expected_update_document = [ { @@ -425,7 +441,9 @@ def test_writer_column_insert_shorter_metadatas_length( def test_writer_column_update_shorter_embeddings_length( update_mock: MagicMock, mock_writer: writer.Writer, monkeypatch: MonkeyPatch ): - test_embeddings = [Embedding([0.88])] + test_embeddings = [ + {"values": [0.88], "dimensionality": 1}, + ] test_document_metadatas = [{"mock": "metadata"}, {"mock2": "metadata2"}] expected_update_document = [ { From aaa6a56c9c2fa9f68038bebc5fe0e2aeca763f08 Mon Sep 17 00:00:00 2001 From: Tyler Duong Date: Fri, 27 Oct 2023 21:38:53 -0700 Subject: [PATCH 4/5] fix updates --- python/starpoint/db.py | 2 ++ python/starpoint/writer.py | 6 ++++-- python/tests/test_writer.py | 11 ++++++++--- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/python/starpoint/db.py b/python/starpoint/db.py index 38ab888..31fa764 100644 --- a/python/starpoint/db.py +++ b/python/starpoint/db.py @@ -225,6 +225,7 @@ def update( def column_update( self, + ids: List[str], embeddings: List[Dict[str, List[float] | int]], document_metadatas: List[Dict[Any, Any]], collection_id: Optional[str] = None, @@ -253,6 +254,7 @@ def column_update( requests.exceptions.SSLError: Failure likely due to network issues. """ return self.writer.column_update( + ids=ids, embeddings=embeddings, document_metadatas=document_metadatas, collection_id=collection_id, diff --git a/python/starpoint/writer.py b/python/starpoint/writer.py index 2b45259..5c72ad1 100644 --- a/python/starpoint/writer.py +++ b/python/starpoint/writer.py @@ -281,6 +281,7 @@ def update( def column_update( self, + ids: List[str], embeddings: List[Dict[str, List[float] | int]], document_metadatas: List[Dict[Any, Any]], collection_id: Optional[str] = None, @@ -307,15 +308,16 @@ def column_update( ValueError: If both collection id and collection name are provided. requests.exceptions.SSLError: Failure likely due to network issues. """ - if len(embeddings) != len(document_metadatas): + if len(embeddings) != len(document_metadatas) or len(embeddings) != len(ids): LOGGER.warning(EMBEDDING_METADATA_LENGTH_MISMATCH_WARNING) documents = [ { + "id": id, "embeddings": embedding, "metadata": document_metadata, } - for embedding, document_metadata in zip(embeddings, document_metadatas) + for embedding, document_metadata, id in zip(embeddings, document_metadatas, ids) ] return self.update( diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 729a599..2175943 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -348,6 +348,7 @@ def test_writer_update_SSLError( @patch("starpoint.writer.Writer.update") def test_writer_column_update(update_mock: MagicMock, mock_writer: writer.Writer): + ids = ["a", "b"] test_embeddings = [ {"values": [0.88], "dimensionality": 1}, {"values": [0.71], "dimensionality": 1} @@ -365,7 +366,7 @@ def test_writer_column_update(update_mock: MagicMock, mock_writer: writer.Writer ] mock_writer.column_update( - embeddings=test_embeddings, document_metadatas=test_document_metadatas + ids=ids, embeddings=test_embeddings, document_metadatas=test_document_metadatas ) update_mock.assert_called_once_with( @@ -379,6 +380,7 @@ def test_writer_column_update(update_mock: MagicMock, mock_writer: writer.Writer def test_writer_column_update_collection_id_collection_name_passed_through( update_mock: MagicMock, mock_writer: writer.Writer ): + ids = ["a", "b"] test_embeddings = [{"values": [0.88], "dimensionality": 1}] test_document_metadatas = [{"mock": "metadata"}] expected_update_document = [ @@ -391,6 +393,7 @@ def test_writer_column_update_collection_id_collection_name_passed_through( expected_collection_name = "mock_name" mock_writer.column_update( + ids=ids, embeddings=test_embeddings, document_metadatas=test_document_metadatas, collection_id=expected_collection_id, @@ -408,6 +411,7 @@ def test_writer_column_update_collection_id_collection_name_passed_through( def test_writer_column_insert_shorter_metadatas_length( update_mock: MagicMock, mock_writer: writer.Writer, monkeypatch: MonkeyPatch ): + ids = ["a", "b"] test_embeddings = [ {"values": [0.88], "dimensionality": 1}, {"values": [0.71], "dimensionality": 1} @@ -424,7 +428,7 @@ def test_writer_column_insert_shorter_metadatas_length( monkeypatch.setattr(writer, "LOGGER", logger_mock) mock_writer.column_update( - embeddings=test_embeddings, document_metadatas=test_document_metadatas + ids=ids, embeddings=test_embeddings, document_metadatas=test_document_metadatas ) logger_mock.warning.assert_called_once_with( @@ -441,6 +445,7 @@ def test_writer_column_insert_shorter_metadatas_length( def test_writer_column_update_shorter_embeddings_length( update_mock: MagicMock, mock_writer: writer.Writer, monkeypatch: MonkeyPatch ): + ids = ["a", "b"] test_embeddings = [ {"values": [0.88], "dimensionality": 1}, ] @@ -456,7 +461,7 @@ def test_writer_column_update_shorter_embeddings_length( monkeypatch.setattr(writer, "LOGGER", logger_mock) mock_writer.column_update( - embeddings=test_embeddings, document_metadatas=test_document_metadatas + ids=ids, embeddings=test_embeddings, document_metadatas=test_document_metadatas ) logger_mock.warning.assert_called_once_with( From 34b73fe5fcb94b317210151d93181c124639da93 Mon Sep 17 00:00:00 2001 From: Tyler Duong Date: Fri, 27 Oct 2023 21:50:50 -0700 Subject: [PATCH 5/5] fix tests --- python/tests/test_db.py | 2 +- python/tests/test_writer.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/python/tests/test_db.py b/python/tests/test_db.py index 38d0588..699933a 100644 --- a/python/tests/test_db.py +++ b/python/tests/test_db.py @@ -90,6 +90,6 @@ def test_client_update(mock_writer: MagicMock, mock_reader: MagicMock): def test_client_column_update(mock_writer: MagicMock, mock_reader: MagicMock): client = db.Client(api_key=uuid4()) - client.column_update(embeddings=[{"values": [1.1], "dimensionality": 1}], document_metadatas=[{"mock": "value"}]) + client.column_update(ids=["a"], embeddings=[{"values": [1.1], "dimensionality": 1}], document_metadatas=[{"mock": "value"}]) mock_reader.assert_called_once() # Only called during init mock_writer().column_update.assert_called_once() diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 2175943..7afbee2 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -356,10 +356,12 @@ def test_writer_column_update(update_mock: MagicMock, mock_writer: writer.Writer test_document_metadatas = [{"mock": "metadata"}, {"mock2": "metadata2"}] expected_update_document = [ { + "id": "a", "embeddings": test_embeddings[0], "metadata": test_document_metadatas[0], }, { + "id": "b", "embeddings": test_embeddings[1], "metadata": test_document_metadatas[1], }, @@ -380,11 +382,12 @@ def test_writer_column_update(update_mock: MagicMock, mock_writer: writer.Writer def test_writer_column_update_collection_id_collection_name_passed_through( update_mock: MagicMock, mock_writer: writer.Writer ): - ids = ["a", "b"] + ids = ["a"] test_embeddings = [{"values": [0.88], "dimensionality": 1}] test_document_metadatas = [{"mock": "metadata"}] expected_update_document = [ { + "id": "a", "embeddings": test_embeddings[0], "metadata": test_document_metadatas[0], }, @@ -419,6 +422,7 @@ def test_writer_column_insert_shorter_metadatas_length( test_document_metadatas = [{"mock": "metadata"}] expected_update_document = [ { + "id": "a", "embeddings": test_embeddings[0], "metadata": test_document_metadatas[0], }, @@ -452,6 +456,7 @@ def test_writer_column_update_shorter_embeddings_length( test_document_metadatas = [{"mock": "metadata"}, {"mock2": "metadata2"}] expected_update_document = [ { + "id": "a", "embeddings": test_embeddings[0], "metadata": test_document_metadatas[0], },