Skip to content

Commit

Permalink
add top_k
Browse files Browse the repository at this point in the history
  • Loading branch information
scottwey committed Mar 7, 2024
1 parent 2ac5b21 commit a00538e
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "starpoint"
version = "0.6.1"
version = "0.6.2"
authors = [{ name = "pointable", email = "package-maintainers@pointable.ai" }]
description = "SDK for Starpoint DB"
readme = "README.md"
Expand Down
6 changes: 6 additions & 0 deletions starpoint/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def query(
text_search_query: Optional[List[str]] = None,
text_search_weight: Optional[float] = None,
tokenizer_type: Optional[reader.TokenizerType] = None,
top_k: Optional[int] = None,
) -> Dict[Any, Any]:
"""Queries a collection. This could be by sql or query embeddings.
`query()` method from [`Reader`](#reader-objects).
Expand All @@ -144,6 +145,10 @@ def query(
query_embedding: An embedding to query against the collection using similarity search.
This is of the shape {"values": List[float], "dimensionality": int}
params: values for parameterized sql
text_search_query: a list of strings to search for in the text column
text_search_weight: the weight to apply to the text search
tokenizer_type: the tokenizer to use for the text search
top_k: the number of results to return
Returns:
dict: query response json
Expand All @@ -163,6 +168,7 @@ def query(
text_search_query=text_search_query,
text_search_weight=text_search_weight,
tokenizer_type=tokenizer_type,
top_k=top_k,
)

def infer_schema(
Expand Down
4 changes: 4 additions & 0 deletions starpoint/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
class TokenizerType(Enum):
LLAMA2 = "llama2"
ENSTEM = "en_stem"
NAIVE = "naive"


class Reader(object):
Expand All @@ -54,6 +55,7 @@ def query(
text_search_query: Optional[List[str]] = None,
text_search_weight: Optional[float] = None,
tokenizer_type: Optional[TokenizerType] = None,
top_k: Optional[int] = None,
) -> Dict[Any, Any]:
"""Queries a collection. This could be by sql or query embeddings.
Expand All @@ -67,6 +69,7 @@ def query(
params: values for parameterized sql
text_search_weight: weight for text search
tokenizer_type: the type of tokenizer used to perform full text search
top_k: the number of results to return
Returns:
dict: query response json
Expand Down Expand Up @@ -101,6 +104,7 @@ def query(
text_search_query=text_search_query,
text_search_weight=text_search_weight,
tokenizer_type=tokenizer_type,
top_k=top_k,
)
try:
response = requests.post(
Expand Down

0 comments on commit a00538e

Please sign in to comment.