Skip to content

Commit

Permalink
Add threaded brute search to Hoplite, and moar tests.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 658467055
  • Loading branch information
sdenton4 authored and copybara-github committed Aug 1, 2024
1 parent b2260be commit 12bee29
Show file tree
Hide file tree
Showing 11 changed files with 408 additions and 219 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,7 @@ jobs:
run: poetry run python -m unittest discover -s chirp/train_tests -p "*test.py"
- name: Test inference with unittest
run: poetry run python -m unittest discover -s chirp/inference/tests -p "*test.py"
- name: Test agile2 with unittest
run: poetry run python -m unittest discover -s chirp/projects/agile2/tests -p "*test.py"
- name: Test hoplite with unittest
run: poetry run python -m unittest discover -s chirp/projects/hoplite/tests -p "*test.py"
144 changes: 144 additions & 0 deletions chirp/projects/hoplite/brutalism.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# coding=utf-8
# Copyright 2024 The Perch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Brute force search and reranking utilities."""

import concurrent
import threading
from typing import Any, Callable, Sequence

from chirp.projects.hoplite import interface
from chirp.projects.hoplite import search_results
import numpy as np


def worker_initializer(state):
  """Stores a per-thread DB handle in the shared state dict.

  Calls `thread_split()` on `state['db']` and saves the result under the key
  formed by the current thread's name followed by 'db'.

  Args:
    state: Mutable dict containing a 'db' entry that exposes `thread_split()`.
      It is updated in place.
  """
  thread_db = state['db'].thread_split()
  thread_key = threading.current_thread().name + 'db'
  state[thread_key] = thread_db


def brute_search_worker_fn(emb_ids: Sequence[int], state: dict[str, Any]):
  """Scores one batch of embeddings against the query embedding.

  Looks up the DB handle stored for the current thread (key: thread name +
  'db'), fetches the embeddings for `emb_ids`, and scores them by calling
  `state['score_fn'](embeddings, state['query_embedding'])`.

  Every id/score pair of the batch is returned. When the batch has more than
  `state['search_list_size']` rows, the pairs are reordered with
  `np.argpartition`; otherwise they are returned in the order the DB gave them.

  Args:
    emb_ids: Embedding ids to score in this batch.
    state: Shared dict holding 'score_fn', 'query_embedding',
      'search_list_size' and the per-thread DB handle.

  Returns:
    A tuple `(embedding_ids, scores)` with aligned entries.
  """
  name = threading.current_thread().name
  emb_ids, embeddings = state[name + 'db'].get_embeddings(emb_ids)
  emb_ids = np.asarray(emb_ids)
  scores = np.asarray(state['score_fn'](embeddings, state['query_embedding']))
  kth = state['search_list_size']
  if scores.shape[-1] <= kth:
    # np.argpartition raises ValueError when kth is out of bounds, which
    # happens for a short final batch or a DB smaller than the search list.
    # Nothing needs partitioning in that case, so return the batch as-is.
    return emb_ids, scores
  top_locs = np.argpartition(scores, kth, axis=-1)
  return emb_ids[top_locs], scores[top_locs]


def threaded_brute_search(
    db: interface.GraphSearchDBInterface,
    query_embedding: np.ndarray,
    search_list_size: int,
    score_fn: Callable[[np.ndarray, np.ndarray], float],
    batch_size: int = 1024,
    max_workers: int = 8,
) -> tuple[search_results.TopKSearchResults, np.ndarray]:
  """Performs a brute-force search for neighbors of the query embedding.

  The embedding ids are split into batches of `batch_size`, and each batch is
  scored by `brute_search_worker_fn` on a thread pool. Each worker thread gets
  its own DB handle via `worker_initializer`.

  Args:
    db: Graph DB instance.
    query_embedding: Query embedding vector.
    search_list_size: Number of results to return.
    score_fn: Scoring function to use for ranking results. The worker calls it
      as `score_fn(batch_embeddings, query_embedding)`, so it must accept a
      batch of embeddings as its first argument.
    batch_size: Number of embeddings to score in each thread.
    max_workers: Maximum number of threads to use for the search.

  Returns:
    A TopKSearchResults object containing the search results, and an array of
    all scores computed during the search (empty if the DB has no embeddings).
  """
  # `import concurrent` alone does not load the `futures` submodule, so the
  # attribute lookup below can fail unless some other module imported it.
  import concurrent.futures  # pylint: disable=g-import-not-at-top

  state = {
      'search_list_size': search_list_size,
      'db': db,
      'query_embedding': query_embedding,
      'score_fn': score_fn,
  }

  results = search_results.TopKSearchResults(search_list_size)
  # Commit the DB, since we are about to create views in multiple threads.
  db.commit()
  all_scores = []
  with concurrent.futures.ThreadPoolExecutor(
      max_workers=max_workers,
      initializer=worker_initializer,
      initargs=(state,),
  ) as executor:
    ids = db.get_embedding_ids()
    futures = [
        executor.submit(brute_search_worker_fn, ids[q : q + batch_size], state)
        for q in range(0, ids.shape[0], batch_size)
    ]
    for f in futures:
      idxes, scores = f.result()
      all_scores.append(scores)
      for idx, score in zip(idxes, scores):
        # Check filtering first so that no SearchResult object is built for
        # candidates that would be discarded anyway.
        if not results.will_filter(idx, score):
          results.update(
              search_results.SearchResult(idx, score), force_insert=True
          )
  if not all_scores:
    # np.concatenate rejects an empty list; match brute_search, which returns
    # an empty array when the DB holds no embeddings.
    return results, np.array([])
  return results, np.concatenate(all_scores)


def brute_search(
    db: interface.GraphSearchDBInterface,
    query_embedding: np.ndarray,
    search_list_size: int,
    score_fn: Callable[[np.ndarray, np.ndarray], float],
) -> tuple[search_results.TopKSearchResults, np.ndarray]:
  """Scores every embedding in the DB against the query, one at a time.

  Args:
    db: Graph DB instance.
    query_embedding: Query embedding vector.
    search_list_size: Number of results to return.
    score_fn: Scoring function, called as
      `score_fn(query_embedding, target_embedding)`.

  Returns:
    A TopKSearchResults object containing the search results, and an array of
    all scores computed during the search, in DB id iteration order.
  """
  top_k = search_results.TopKSearchResults(search_list_size)
  scores_seen = []
  for emb_id in db.get_embedding_ids():
    emb_score = score_fn(query_embedding, db.get_embedding(emb_id))
    scores_seen.append(emb_score)
    # Skip building a SearchResult for candidates that would be dropped; the
    # filter check has already been done, so the insert below is forced.
    if results_filtered := top_k.will_filter(emb_id, emb_score):
      del results_filtered
      continue
    top_k.update(
        search_results.SearchResult(emb_id, emb_score), force_insert=True
    )
  return top_k, np.array(scores_seen)


def rerank(
    query_embedding: np.ndarray,
    results: search_results.TopKSearchResults,
    db: interface.GraphSearchDBInterface,
    score_fn: Callable[[np.ndarray, np.ndarray], float],
) -> search_results.TopKSearchResults:
  """Builds a new result set by rescoring existing results with `score_fn`.

  Args:
    query_embedding: Query embedding vector.
    results: Existing search results; its `top_k` sizes the new result set.
    db: Graph DB instance used to fetch each result's embedding.
    score_fn: Scoring function, called as
      `score_fn(query_embedding, target_embedding)`.

  Returns:
    A new TopKSearchResults holding the rescored entries.
  """
  rescored = search_results.TopKSearchResults(results.top_k)
  for previous in results:
    emb_id = previous.embedding_id
    new_score = score_fn(query_embedding, db.get_embedding(emb_id))
    rescored.update(search_results.SearchResult(emb_id, new_score))
  return rescored
39 changes: 3 additions & 36 deletions chirp/projects/hoplite/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@

"""Utility functions for graph operations."""

from collections.abc import Callable, Iterator
from typing import Iterator

from chirp.projects.hoplite import interface
from chirp.projects.hoplite import search_results
import numpy as np
import tqdm

Expand Down Expand Up @@ -47,10 +46,10 @@ def insert_random_embeddings(
):
"""Insert randomly generated embedding vectors into the DB."""
rng = np.random.default_rng(seed=seed)
np_alpha = list('ABCDEFGHIJKLMNOPQRSTUVWXYZ')
np_alpha = list('ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789')
dataset_names = ('a', 'b', 'c')
for _ in tqdm.tqdm(range(num_embeddings)):
embedding = np.float32(rng.normal(size=emb_dim, loc=0, scale=1.0))
embedding = np.float32(rng.normal(size=emb_dim, loc=0, scale=0.1))
dataset_name = rng.choice(dataset_names)
source_name = ''.join(
[str(a) for a in np.random.choice(np_alpha, size=3, replace=False)]
Expand Down Expand Up @@ -156,35 +155,3 @@ def random_walk(
break
idx = edges[rng.integers(0, len(edges))]
return idx


def brute_search(
    db: interface.GraphSearchDBInterface,
    query_embedding: np.ndarray,
    search_list_size: int,
    score_fn: Callable[[np.ndarray, np.ndarray], float],
) -> tuple[search_results.TopKSearchResults, np.ndarray]:
  """Exhaustively scores all DB embeddings against the query embedding.

  Args:
    db: Graph DB instance.
    query_embedding: Query embedding vector.
    search_list_size: Number of results to return.
    score_fn: Scoring function, called as
      `score_fn(query_embedding, target_embedding)`.

  Returns:
    A TopKSearchResults object containing the search results, and an array of
    all scores computed during the search.
  """
  best = search_results.TopKSearchResults(search_list_size)
  every_score = []
  for emb_id in db.get_embedding_ids():
    candidate_score = score_fn(query_embedding, db.get_embedding(emb_id))
    every_score.append(candidate_score)
    # Only construct a SearchResult when the candidate will be kept; since
    # filtering was already checked, the insertion is forced.
    keep = not best.will_filter(emb_id, candidate_score)
    if keep:
      best.update(
          search_results.SearchResult(emb_id, candidate_score),
          force_insert=True,
      )
  return best, np.array(every_score)
5 changes: 5 additions & 0 deletions chirp/projects/hoplite/in_mem_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ def setup(self):
# Dropping all edges initializes the edge table.
self.drop_all_edges()

def thread_split(self) -> interface.GraphSearchDBInterface:
  """Return a readable instance of the database.

  Returns:
    This same object; no copy is made.
  """
  # Since numpy arrays are in shared memory, we can reuse the same object.
  return self

def count_embeddings(self) -> int:
"""Return a count of all embeddings in the database."""
return len(self.embedding_ids)
Expand Down
4 changes: 3 additions & 1 deletion chirp/projects/hoplite/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
"""Vamana implementation."""

import collections
import concurrent
import dataclasses
from typing import Callable

from chirp.projects.hoplite import brutalism
from chirp.projects.hoplite import graph_utils
from chirp.projects.hoplite import interface
from chirp.projects.hoplite import score_functions
Expand Down Expand Up @@ -454,7 +456,7 @@ def test_recall(
graph_results = search_fn(query)
graph_keys = set(r.embedding_id for r in graph_results)

brute_results, _ = graph_utils.brute_search(
brute_results, _ = brutalism.brute_search(
self.db, query, search_list_size=eval_top_k, score_fn=self.score_fn
)
brute_keys = set(r.embedding_id for r in brute_results)
Expand Down
7 changes: 7 additions & 0 deletions chirp/projects/hoplite/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ def setup(self):
def commit(self) -> None:
"""Commit any pending transactions to the database."""

@abc.abstractmethod
def thread_split(self) -> 'GraphSearchDBInterface':
  """Get a new instance of the database with the same contents.

  For example, SQLite DBs need a distinct object in each thread.

  Returns:
    A database object exposing the same contents. Implementations may return
    a fresh object or, where sharing is safe, the same object.
  """

@abc.abstractmethod
def insert_metadata(self, key: str, value: config_dict.ConfigDict) -> None:
"""Insert a key-value pair into the metadata table."""
Expand Down
9 changes: 7 additions & 2 deletions chirp/projects/hoplite/sqlite_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class SQLiteGraphSearchDB(interface.GraphSearchDBInterface):
"""SQLite implementation of graph search database."""

db: sqlite3.Connection
db_path: str
embedding_dim: int
embedding_dtype: type[Any] = np.float16
_cursor: sqlite3.Cursor | None = None
Expand All @@ -50,7 +51,11 @@ def create(
cursor = db.cursor()
cursor.execute('PRAGMA journal_mode=WAL;') # Enable WAL mode
db.commit()
return SQLiteGraphSearchDB(db, embedding_dim)
return SQLiteGraphSearchDB(db, db_path, embedding_dim, embedding_dtype)

def thread_split(self):
  """Get a new instance of the SQLite DB.

  Returns:
    The result of calling `create` with this instance's `db_path`,
    `embedding_dim` and `embedding_dtype`.
  """
  return self.create(self.db_path, self.embedding_dim, self.embedding_dtype)

def _get_cursor(self) -> sqlite3.Cursor:
if self._cursor is None:
Expand Down Expand Up @@ -189,7 +194,7 @@ def get_metadata(self, key: str | None) -> config_dict.ConfigDict:
)
result = cursor.fetchone()
if result is None:
raise ValueError(f'Metadata key not found: {key}')
raise KeyError(f'Metadata key not found: {key}')
return config_dict.ConfigDict(json.loads(result[0]))

def get_dataset_names(self) -> tuple[str, ...]:
Expand Down
77 changes: 77 additions & 0 deletions chirp/projects/hoplite/tests/brutalism_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# coding=utf-8
# Copyright 2024 The Perch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for brute search functionality."""

import shutil
import tempfile

from chirp.projects.hoplite import brutalism
from chirp.projects.hoplite.tests import test_utils
import numpy as np

from absl.testing import absltest
from absl.testing import parameterized

# Passed as the last argument to test_utils.make_db; presumably the
# dimensionality of the generated test embeddings (confirm in test_utils).
EMBEDDING_SIZE = 8


class BrutalismTest(parameterized.TestCase):
  """Checks that serial and threaded brute-force search agree."""

  def setUp(self):
    super().setUp()
    # Scratch directory handed to test_utils.make_db; removed in tearDown.
    self.tempdir = tempfile.mkdtemp()

  def tearDown(self):
    super().tearDown()
    shutil.rmtree(self.tempdir)

  @parameterized.product(db_type=('in_mem', 'sqlite'))
  def test_threaded_brute_search(self, db_type):
    num_embeddings = 1000
    top_k = 10
    rng = np.random.default_rng(42)
    db = test_utils.make_db(
        self.tempdir, db_type, num_embeddings, rng, EMBEDDING_SIZE
    )
    # Query with an embedding taken from the DB, so that its own id is
    # expected to show up in the results.
    query_idx = db.get_one_embedding_id()
    query = db.get_embedding(query_idx)

    serial_results, serial_scores = brutalism.brute_search(
        db, query, search_list_size=top_k, score_fn=np.dot
    )
    self.assertSequenceEqual(serial_scores.shape, (num_embeddings,))
    self.assertLen(serial_results.search_results, top_k)
    serial_ids = [hit.embedding_id for hit in serial_results]
    self.assertIn(query_idx, serial_ids)

    # Check agreement of threaded brute search with the non-threaded version.
    threaded_results, threaded_scores = brutalism.threaded_brute_search(
        db,
        query,
        search_list_size=top_k,
        batch_size=128,
        score_fn=np.dot,
    )
    # Scores may arrive in a different order, so compare them sorted.
    np.testing.assert_equal(np.sort(threaded_scores), np.sort(serial_scores))
    threaded_ids = [hit.embedding_id for hit in threaded_results]
    self.assertSequenceEqual(serial_ids, threaded_ids)


# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
Loading

0 comments on commit 12bee29

Please sign in to comment.