Skip to content

Commit

Permalink
remove embedding object and use dicts instead
Browse files Browse the repository at this point in the history
  • Loading branch information
DuongTyler committed Oct 27, 2023
1 parent 215b9a9 commit d629d11
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 26 deletions.
12 changes: 11 additions & 1 deletion python/starpoint/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Dict, Optional
from typing import Dict, Optional, List
from uuid import UUID

import requests
Expand Down Expand Up @@ -102,3 +102,13 @@ def _check_collection_identifier_collision(
raise ValueError(NO_COLLECTION_VALUE_ERROR)
elif collection_id and collection_name:
raise ValueError(MULTI_COLLECTION_VALUE_ERROR)


def _ensure_embedding_dict(embeddings: List[float] | Dict[str, List[float] | int] | None):
if isinstance(embeddings, list):
dict_embeddings = {
"values": embeddings,
"dimensionality": len(embeddings)
}
return dict_embeddings
return embeddings
14 changes: 4 additions & 10 deletions python/starpoint/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import validators

from starpoint import reader, writer, _utils
from starpoint.embedding import Embedding

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -88,7 +87,7 @@ def insert(

def column_insert(
self,
embeddings: List[Embedding],
embeddings: List[Dict[str, List[float] | int]],
document_metadatas: List[Dict[Any, Any]],
collection_id: Optional[str] = None,
collection_name: Optional[str] = None,
Expand Down Expand Up @@ -127,7 +126,7 @@ def query(
sql: Optional[str] = None,
collection_id: Optional[str] = None,
collection_name: Optional[str] = None,
query_embedding: Optional[List[float] | Embedding] = None,
query_embedding: Optional[List[float] | Dict[str, List[float] | int]] = None,
params: Optional[List[Any]] = None,
text_search_query: Optional[List[str]] = None,
text_search_weight: Optional[float] = None,
Expand All @@ -143,6 +142,7 @@ def query(
collection_name: The collection's name where the query will happen.
This or the `collection_id` needs to be provided.
query_embedding: An embedding to query against the collection using similarity search.
This is of the shape {"values": List[float], "dimensionality": int}
params: values for parameterized sql
Returns:
Expand All @@ -154,12 +154,6 @@ def query(
requests.exceptions.SSLError: Failure likely due to network issues.
"""

# check if the query embedding is a list of floats; if it is, convert it to an Embedding object
if isinstance(query_embedding, list):
query_embedding = Embedding(
vectors=query_embedding,
dim=len(query_embedding))

return self.reader.query(
sql=sql,
collection_id=collection_id,
Expand Down Expand Up @@ -231,7 +225,7 @@ def update(

def column_update(
self,
embeddings: List[Embedding],
embeddings: List[Dict[str, List[float] | int]],
document_metadatas: List[Dict[Any, Any]],
collection_id: Optional[str] = None,
collection_name: Optional[str] = None,
Expand Down
9 changes: 0 additions & 9 deletions python/starpoint/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,6 @@
)


class Embedding:
    """Container pairing an embedding's values with its dimensionality."""

    # The raw embedding values, stored exactly as passed in.
    values: List[float]
    # Explicit dimensionality, or the length of ``values`` when not given.
    dimensionality: int

    def __init__(self, values: List[float], dimensionality: Optional[int] = None):
        """Store ``values`` and resolve the dimensionality.

        Args:
            values: The embedding values.
            dimensionality: Optional explicit dimensionality. When None,
                ``len(values)`` is used instead.
        """
        self.values = values
        if dimensionality is None:
            self.dimensionality = len(values)
        else:
            self.dimensionality = dimensionality


class EmbeddingModel(Enum):
MINILM = "MINI_LM"

Expand Down
9 changes: 6 additions & 3 deletions python/starpoint/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
_build_header,
_check_collection_identifier_collision,
_validate_host,
_ensure_embedding_dict
)

from starpoint.embedding import Embedding

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,7 +49,7 @@ def query(
sql: Optional[str] = None,
collection_id: Optional[str] = None,
collection_name: Optional[str] = None,
query_embeddings: Optional[Embedding] = None,
query_embeddings: Optional[Dict[str, int | List[float]] | List[float]] = None,
params: Optional[List[Any]] = None,
text_search_query: Optional[List[str]] = None,
text_search_weight: Optional[float] = None,
Expand Down Expand Up @@ -89,10 +89,13 @@ def query(
)
"""

# if the query embeddings are a list of floats, convert them to a dict
query_embeddings = _ensure_embedding_dict(query_embeddings)

request_data = dict(
collection_id=collection_id,
collection_name=collection_name,
query_embeddings=query_embeddings,
query_embedding=query_embeddings,
sql=sql,
params=params,
text_search_query=text_search_query,
Expand Down
5 changes: 2 additions & 3 deletions python/starpoint/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
_validate_host,
)

from starpoint.embedding import Embedding

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -171,7 +170,7 @@ def insert(

def column_insert(
self,
embeddings: List[Embedding],
embeddings: List[Dict[str, List[float] | int]],
document_metadatas: List[Dict[Any, Any]],
collection_id: Optional[str] = None,
collection_name: Optional[str] = None,
Expand Down Expand Up @@ -282,7 +281,7 @@ def update(

def column_update(
self,
embeddings: List[Embedding],
embeddings: List[Dict[str, List[float] | int]],
document_metadatas: List[Dict[Any, Any]],
collection_id: Optional[str] = None,
collection_name: Optional[str] = None,
Expand Down

0 comments on commit d629d11

Please sign in to comment.