From 1143cff7dee0622014ca8fc3c55dd20d63a34d2b Mon Sep 17 00:00:00 2001
From: Tom Denton
Date: Fri, 9 Aug 2024 08:55:38 -0700
Subject: [PATCH] Jax/GPU Indexing. Assumes max-inner-product score function.

PiperOrigin-RevId: 661280495
---
 chirp/projects/hoplite/index_jax.py           | 304 ++++++++++++++++++
 .../projects/hoplite/tests/index_jax_test.py  | 146 +++++++++
 2 files changed, 450 insertions(+)
 create mode 100644 chirp/projects/hoplite/index_jax.py
 create mode 100644 chirp/projects/hoplite/tests/index_jax_test.py

diff --git a/chirp/projects/hoplite/index_jax.py b/chirp/projects/hoplite/index_jax.py
new file mode 100644
index 00000000..7432a6ae
--- /dev/null
+++ b/chirp/projects/hoplite/index_jax.py
@@ -0,0 +1,304 @@
+# coding=utf-8
+# Copyright 2024 The Perch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tooling for building a Hoplite index using GPU."""
+
+import dataclasses
+import itertools
+
+import jax
+from jax import numpy as jnp
+import numpy as np
+
+
+@dataclasses.dataclass
+class IndexingData:
+  step_num: int
+  targets: jnp.ndarray
+  candidates: jnp.ndarray
+  embeddings: jnp.ndarray
+  edges: jnp.ndarray
+  delegate_lists: jnp.ndarray
+  delegate_scores: jnp.ndarray
+  batch_size: int
+  alpha: float
+  max_violations: int
+  num_steps: int
+
+
+# Note that this registration creates an import side-effect.
+jax.tree_util.register_dataclass(
+    IndexingData,
+    data_fields=[
+        'step_num',
+        'targets',
+        'candidates',
+        'embeddings',
+        'edges',
+        'delegate_lists',
+        'delegate_scores',
+        'alpha',
+        'max_violations',
+    ],
+    meta_fields=['batch_size', 'num_steps'],
+)
+
+
+def delegate_indexing(
+    embeddings: jnp.ndarray,
+    edges: jnp.ndarray,
+    delegate_lists: jnp.ndarray | None,
+    delegate_scores: jnp.ndarray | None,
+    sample_size: int,
+    max_delegates: int,
+    alpha: float,
+    max_violations: int,
+) -> IndexingData:
+  """Create an edge set using delegated pruning.
+
+  This is the main indexing method in this library. It uses the JIT-compiled
+  prune_delegate_jax function as the inner loop to prune edges and surface
+  delegates.
+
+  Args:
+    embeddings: The embedding matrix.
+    edges: The edge matrix.
+    delegate_lists: The delegate lists for each node.
+    delegate_scores: The scores for each delegate.
+    sample_size: The number of additional random candidates to score at each
+      step.
+    max_delegates: The maximum number of delegates to keep for each node.
+    alpha: The pruning parameter.
+    max_violations: Maximum number of pruning condition violations to allow.
+
+  Returns:
+    Final IndexingData.
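+
+  Example (a minimal sketch; the sizes and parameter values here are
+  illustrative, mirroring the unit tests):
+    embeddings = jnp.float16(np.random.normal(size=[1024, 16]))
+    edges = jnp.zeros([1024, 8], dtype=jnp.int32)
+    idx_data = delegate_indexing(
+        embeddings, edges, None, None, sample_size=32,
+        max_delegates=32, alpha=0.5, max_violations=1)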
+ """ + corpus_size = embeddings.shape[0] + targets = np.arange(corpus_size) + np.random.shuffle(targets) + candidates = np.arange(corpus_size) + np.random.shuffle(candidates) + if delegate_lists is None or delegate_scores is None: + delegate_lists, delegate_scores = make_delegate_lists( + corpus_size, max_delegates + ) + + batch_size = sample_size + edges.shape[1] + max_delegates + initial_data = IndexingData( + step_num=0, + targets=targets, + candidates=candidates, + embeddings=embeddings, + edges=edges, + delegate_lists=delegate_lists, + delegate_scores=delegate_scores, + batch_size=batch_size, + alpha=alpha, + max_violations=max_violations, + num_steps=corpus_size, + ) + updated_data = unrolled_prune_delegate(initial_data) + return updated_data + + +@jax.jit +def unrolled_prune_delegate(idx_data: IndexingData): + """Wrapper for running prune_delegate_jax in a loop.""" + cond_fn = lambda idx_data: idx_data.step_num < idx_data.num_steps + return jax.lax.while_loop(cond_fn, prune_delegate_jax, idx_data) + + +def prune_delegate_jax(idx_data: IndexingData): + """Select a new set of edges for the target node.""" + target_idx = idx_data.targets[idx_data.step_num].astype(int) + target_emb = idx_data.embeddings[target_idx] + + rolled_candidates = jnp.roll(idx_data.candidates, idx_data.batch_size + 37) + random_candidates = rolled_candidates[: idx_data.batch_size] + candidates = assemble_batch( + target_idx, + idx_data.edges[target_idx], + idx_data.delegate_lists[target_idx], + random_candidates, + ) + candidate_embs = idx_data.embeddings[candidates] + + p_out, candidates, scores_c_c = prune_jax( + target_emb, + candidates, + candidate_embs, + idx_data.alpha, + idx_data.max_violations, + ) + + # Update target edges and delete the old delegate list. + degree_bound = idx_data.edges.shape[1] + edges = idx_data.edges.at[target_idx].set(p_out[:degree_bound]) + + # Update delegate lists with high-scoring elements. 
+  new_delegates, new_scores = update_delegates(
+      idx_data.delegate_lists, idx_data.delegate_scores, candidates, scores_c_c
+  )
+  delegate_lists = idx_data.delegate_lists.at[candidates].set(new_delegates)
+  delegate_scores = idx_data.delegate_scores.at[candidates].set(new_scores)
+
+  return IndexingData(
+      step_num=idx_data.step_num + 1,
+      targets=idx_data.targets,
+      candidates=rolled_candidates,
+      embeddings=idx_data.embeddings,
+      edges=edges,
+      delegate_lists=delegate_lists,
+      delegate_scores=delegate_scores,
+      batch_size=idx_data.batch_size,
+      alpha=idx_data.alpha,
+      max_violations=idx_data.max_violations,
+      num_steps=idx_data.num_steps,
+  )
+
+
+def assemble_batch(
+    target_idx: jnp.ndarray,
+    target_edges: jnp.ndarray,
+    target_delegates: jnp.ndarray,
+    random_candidates: jnp.ndarray,
+):
+  """Assemble a batch of candidates for the target node."""
+  max_edges = target_edges.shape[0]
+  max_dels = target_delegates.shape[0]
+  joined = jnp.concatenate([target_edges, target_delegates], axis=0)
+  joined = unique1d(joined)
+  joined_mask = jnp.logical_and(joined >= 0, joined != target_idx)
+  joined_candidates = random_candidates[: max_edges + max_dels]
+  joined_candidates = jnp.where(joined_mask, joined, joined_candidates)
+
+  remainder = random_candidates[max_edges + max_dels :]
+  return jnp.concatenate([joined_candidates, remainder], axis=0)
+
+
+def update_delegates(
+    delegate_lists: jnp.ndarray,
+    delegate_scores: jnp.ndarray,
+    candidates: jnp.ndarray,
+    candidate_scores: jnp.ndarray,
+) -> tuple[jnp.ndarray, jnp.ndarray]:
+  """Update the delegate lists with the highest-scoring candidates."""
+  max_delegates = delegate_lists.shape[1]
+
+  prev_scores = delegate_scores[candidates, :]
+  safe_scores = jnp.fill_diagonal(candidate_scores, -jnp.inf, inplace=False)
+  combined_scores = jnp.concatenate([prev_scores, safe_scores], axis=1)
+  stacked_candidates = jnp.tile(
+      candidates[jnp.newaxis, :], [candidates.shape[0], 1]
+  )
+  combined_delegates = jnp.concatenate(
+      [delegate_lists[candidates], stacked_candidates], axis=1
+  )
+
+  # Eliminate repeated delegates.
+  combined_delegates = unique1d(combined_delegates)
+  combined_scores = jnp.where(
+      combined_delegates == -1, -jnp.inf, combined_scores
+  )
+
+  # Sort the combined delegates by score.
+  combined_scores, combined_delegates = cosort(  # pylint: disable=unbalanced-tuple-unpacking
+      combined_scores, combined_delegates, descending=True
+  )
+  new_scores = combined_scores[:, :max_delegates]
+  new_delegates = combined_delegates[:, :max_delegates]
+  return new_delegates, new_scores
+
+
+def prune_jax(
+    target_emb: jnp.ndarray,
+    candidates: jnp.ndarray,
+    candidate_embs: jnp.ndarray,
+    alpha: float,
+    max_violations: int,
+) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
+  """Prune a set of candidates for the target embedding."""
+  scores_c_t = jnp.dot(candidate_embs, target_emb)
+
+  scores_c_t, candidates, candidate_embs = cosort(
+      scores_c_t, candidates, candidate_embs, axis=0, descending=True
+  )
+  scores_c_c = jnp.tensordot(candidate_embs, candidate_embs, axes=(-1, -1))
+
+  # Sparse neighborhood condition.
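+  # Candidates are sorted by descending target score, so triu(k=1) compares
+  # each candidate j only against better-scoring candidates i. Candidate j
+  # accrues a violation whenever dot(i, j) >= dot(j, target) + alpha, and its
+  # edge is dropped (set to -1) once it exceeds max_violations.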
+  mask = scores_c_c >= scores_c_t[np.newaxis, :] + alpha
+  mask = jnp.triu(mask, k=1)
+  violation_counts = mask.sum(axis=0)
+  violation_mask = violation_counts > max_violations
+
+  _, p_out, violation_mask = cosort(  # pylint: disable=unbalanced-tuple-unpacking
+      violation_counts, candidates, violation_mask, axis=-1
+  )
+
+  empty_edges = -1 * jnp.ones_like(p_out)
+  p_out = jnp.where(violation_mask, empty_edges, p_out)
+  return p_out, candidates, scores_c_c
+
+
+def make_delegate_lists(
+    num_embeddings: int, max_delegates: int
+) -> tuple[jnp.ndarray, jnp.ndarray]:
+  """Create empty delegate lists (-1) and scores (-inf) for each node."""
+  delegate_lists = -1 * jnp.ones([num_embeddings, max_delegates], jnp.int32)
+  delegate_scores = -jnp.inf * jnp.ones(
+      [num_embeddings, max_delegates], jnp.float16
+  )
+  return delegate_lists, delegate_scores
+
+
+def unique1d(v: jnp.ndarray) -> jnp.ndarray:
+  """Deduplicate v along the last axis, replacing dupes with -1.
+
+  For example, if v = [1, 2, 3, 1, 4, 2, 5],
+  then unique1d(v) = [1, 2, 3, -1, 4, -1, 5].
+  Obviously, it's a bad idea to use this if -1 might appear in the array.
+
+  Args:
+    v: an array of values.
+
+  Returns:
+    v with duplicates replaced by -1.
+  """
+  v_sort_locs = jnp.argsort(v, axis=-1)
+  v_sorted = jnp.take_along_axis(v, v_sort_locs, axis=-1)
+  dv_sorted = jnp.concatenate(
+      [
+          jnp.ones_like(v[..., :1]),  # keep first even if diff == 0.
+          jnp.diff(v_sorted),
+      ],
+      axis=-1,
+  )
+  masked = jnp.where(dv_sorted == 0, -1, v_sorted)
+  inverse_perm = jnp.argsort(v_sort_locs, axis=-1)
+  return jnp.take_along_axis(masked, inverse_perm, axis=-1)
+
+
+def cosort(
+    s: jnp.ndarray, *others, axis: int = -1, descending: bool = False
+) -> tuple[jnp.ndarray, ...]:
+  """Sort s and others by the values of s."""
+  sort_locs = jnp.argsort(s, axis=axis, descending=descending)
+  if len(s.shape) == 1 and axis == 0:
+    return tuple(a[sort_locs] for a in itertools.chain((s,), others))
+
+  return tuple(
+      jnp.take_along_axis(a, sort_locs, axis=axis)
+      for a in itertools.chain((s,), others)
+  )
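+
+
+# A minimal usage sketch for the helpers above (values are illustrative):
+#   cosort(jnp.array([3, 1, 2]), jnp.array([30, 10, 20]))
+#     -> (Array([1, 2, 3]), Array([10, 20, 30]))
+#   unique1d(jnp.array([7, 7, 8]))
+#     -> Array([7, -1, 8])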
+ +"""Tests for jax-based indexing functionality.""" + +import functools + +from chirp.projects.hoplite import index +from chirp.projects.hoplite import index_jax +from chirp.projects.hoplite.tests import test_utils +from jax import numpy as jnp +import numpy as np + +from absl.testing import absltest + + +class IndexJaxTest(absltest.TestCase): + + def test_unique1d(self): + with self.subTest('rank_one'): + v = jnp.array([1, 2, 3, 1, 4, 2, 5]) + unique = index_jax.unique1d(v) + expected = np.array([1, 2, 3, -1, 4, -1, 5]) + np.testing.assert_array_equal(unique, expected) + + with self.subTest('rank_two'): + v = jnp.array([[1, 2, 3, 1, 4, 2, 5], [33, 44, 55, 55, 66, 77, 66]]) + unique = index_jax.unique1d(v) + expected = np.array( + [[1, 2, 3, -1, 4, -1, 5], [33, 44, 55, -1, 66, 77, -1]] + ) + np.testing.assert_array_equal(unique, expected) + + def test_cosort(self): + scores = jnp.array([10, 9, 8, 7, 6, 5, 4, 3, 2, 1]) + values = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + expected_scores = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + expected_values = jnp.array([10, 9, 8, 7, 6, 5, 4, 3, 2, 1]) + + with self.subTest('two_arrays'): + sorted_scores, sorted_values = index_jax.cosort(scores, values) + np.testing.assert_array_equal(sorted_scores, expected_scores) + np.testing.assert_array_equal(sorted_values, expected_values) + + with self.subTest('three_arrays'): + other_values = jnp.array( + [1.1, 3.1, 5.1, 7.1, 9.1, 2.1, 4.1, 6.1, 8.1, 10.1] + ) + expected_other_values = jnp.array( + [10.1, 8.1, 6.1, 4.1, 2.1, 9.1, 7.1, 5.1, 3.1, 1.1] + ) + sorted_scores, sorted_values, sorted_other_values = index_jax.cosort( + scores, values, other_values + ) + np.testing.assert_array_equal(sorted_scores, expected_scores) + np.testing.assert_array_equal(sorted_values, expected_values) + np.testing.assert_array_equal(sorted_other_values, expected_other_values) + + def test_update_delegates(self): + n = 16 + embedding_dim = 32 + degree_bound = 2 + np.random.seed(42) + embs = jnp.float16(np.random.normal(size=[n, embedding_dim])) + d_lists, d_scores = index_jax.make_delegate_lists(n, degree_bound) + + m = 8 + candidates = jnp.arange(m) + embs_c = embs[candidates] + scores_c_c = jnp.tensordot(embs_c, embs_c, axes=(-1, -1)) + new_delegates, new_scores = index_jax.update_delegates( + d_lists, d_scores, candidates, scores_c_c + ) + # Check shapes. + np.testing.assert_array_equal(new_delegates.shape, [m, degree_bound]) + np.testing.assert_array_equal(new_scores.shape, [m, degree_bound]) + # There should be no -1 delegates, since we have enough candidates. + self.assertEqual(np.sum(new_delegates >= 0), m * degree_bound) + for c in candidates: + for i, d in enumerate(new_delegates[c]): + score = jnp.float16(jnp.dot(embs[c], embs[d])) + self.assertEqual(score, new_scores[c, i]) + d_lists = d_lists.at[candidates].set(new_delegates) + d_scores = d_scores.at[candidates].set(new_scores) + + # Now update with the full set of candidates. + candidates = jnp.arange(n) + scores_c_c = jnp.tensordot(embs, embs, axes=(-1, -1)) + new_delegates, _ = index_jax.update_delegates( + d_lists, d_scores, candidates, scores_c_c + ) + # The new_delegates should contain the top two scores for each embedding, + # excluding the diagonal. 
+    safe_scores = jnp.fill_diagonal(scores_c_c, -jnp.inf, inplace=False)
+    for c in range(n):
+      top_idxes = jnp.argsort(-safe_scores[c])[:2]
+      self.assertEqual(new_delegates[c, 0], top_idxes[0])
+      self.assertEqual(new_delegates[c, 1], top_idxes[1])
+
+  def test_run_e2e(self):
+    rng = np.random.default_rng(seed=22)
+    db = test_utils.make_db('test_db', 'in_mem', 1024, rng, embedding_dim=16)
+    embs = jnp.float16(db.embeddings[:1024])
+    edges = jnp.zeros([1024, 8], dtype=jnp.int32)
+    max_delegates = 32
+    output_data = index_jax.delegate_indexing(
+        embs,
+        edges,
+        None,
+        None,
+        sample_size=32,
+        max_delegates=max_delegates,
+        alpha=0.5,
+        max_violations=1,
+    )
+    np.testing.assert_array_equal(
+        output_data.delegate_lists.shape, [1024, max_delegates]
+    )
+    np.testing.assert_array_equal(
+        output_data.delegate_scores.shape, [1024, max_delegates]
+    )
+    db.edges = np.asarray(output_data.edges).copy()
+
+    v = index.HopliteSearchIndex.from_db(db, score_fn_name='dot')
+    search_partial_fn = functools.partial(
+        v.greedy_search, start_node=0, search_list_size=128
+    )
+    search_fn = lambda q: search_partial_fn(q)[0]
+    recall = v.multi_test_recall(search_fn)
+    self.assertGreater(recall, 0.9)
+
+
+if __name__ == '__main__':
+  absltest.main()