Skip to content

Commit

Permalink
Fix a flipped sign in the targeted score function.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 691839760
  • Loading branch information
sdenton4 authored and copybara-github committed Oct 31, 2024
1 parent 639b84f commit 982e225
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 19 deletions.
13 changes: 1 addition & 12 deletions chirp/projects/hoplite/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,7 @@ def from_db(
cls, db: interface.GraphSearchDBInterface, score_fn_name: str = 'dot'
) -> 'HopliteSearchIndex':
"""Create a VamanaSearchIndex from a GraphSearchDBInterface impl."""
# TODO(tomdenton): Use an enum for metric_name.
if score_fn_name in ('mip', 'dot'):
# mip == Max Inner Prouct
score_fn = score_functions.numpy_dot
elif score_fn_name in ('jax_mip', 'jax_dot'):
score_fn = score_functions.get_jax_dot()
elif score_fn_name == 'cosine':
score_fn = score_functions.numpy_cos
elif score_fn_name == 'euclidean':
score_fn = score_functions.numpy_euclidean
else:
raise ValueError(f'Unknown metric name: {score_fn_name}')
score_fn = score_functions.get_score_fn(score_fn_name)
return cls(db, score_fn=score_fn)

def initialize_index(self, out_degree: int, seed: int = 42) -> None:
Expand Down
15 changes: 8 additions & 7 deletions chirp/projects/hoplite/score_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def get_score_fn(
score_fn = numpy_dot
elif name == 'cos':
score_fn = numpy_cos
elif name == 'euclidean':
score_fn = numpy_euclidean
elif name == 'neg_euclidean':
score_fn = numpy_neg_euclidean
else:
raise ValueError('Unknown score function: ', name)

Expand All @@ -38,7 +38,8 @@ def get_score_fn(
bias_fn = score_fn

if target_score is not None:
targeted_fn = lambda x, y: np.abs(bias_fn(x, y) - target_score)
# We want 'up is good', so take the negative absolute value.
targeted_fn = lambda x, y: -np.abs(bias_fn(x, y) - target_score)
else:
targeted_fn = bias_fn

Expand Down Expand Up @@ -67,18 +68,18 @@ def numpy_cos(data: np.ndarray, query: np.ndarray) -> np.ndarray:
return np.dot(unit_data, unit_query)


def numpy_euclidean(data: np.ndarray, query: np.ndarray) -> np.ndarray:
"""Numpy L2 distance allowing multiple queries."""
def numpy_neg_euclidean(data: np.ndarray, query: np.ndarray) -> np.ndarray:
"""Negative L2 distance allowing multiple queries."""
data_norms = np.linalg.norm(data, axis=-1)
if len(query.shape) > 1:
query_norms = np.linalg.norm(query, axis=-1)
dot_products = np.tensordot(data, query, axes=(-1, -1))
pairs = data_norms[:, np.newaxis] + query_norms[np.newaxis, :]
return pairs - 2 * dot_products
return -pairs + 2 * dot_products

query_norm = np.linalg.norm(query)
dot_products = np.dot(data, query)
return data_norms - 2 * dot_products + query_norm
return -data_norms + 2 * dot_products + query_norm


def get_jax_dot():
Expand Down

0 comments on commit 982e225

Please sign in to comment.