Skip to content

Commit

Permalink
Add cosine_distance for sparse vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
mosabua committed Nov 4, 2024
1 parent bd9fee7 commit 4b126bf
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1404,6 +1404,24 @@ public static Double cosineSimilarity(
return dotProduct / (normLeftMap * normRightMap);
}

@Description("Calculates the cosine distance between the give sparse vectors")
@ScalarFunction
@SqlType(StandardTypes.DOUBLE)
public static double cosineDistance(
@OperatorDependency(
operator = IDENTICAL,
argumentTypes = {"varchar", "varchar"},
convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = NULLABLE_RETURN)) BlockPositionIsIdentical varcharIdentical,
@OperatorDependency(
operator = HASH_CODE,
argumentTypes = "varchar",
convention = @Convention(arguments = BLOCK_POSITION, result = FAIL_ON_NULL)) BlockPositionHashCode varcharHashCode,
@SqlType("map(varchar,double)") SqlMap leftMap,
@SqlType("map(varchar,double)") SqlMap rightMap)
{
return 1.0 - cosineSimilarity(varcharIdentical, varcharHashCode, leftMap, rightMap);
}

private static double mapDotProduct(BlockPositionIsIdentical varcharIdentical, BlockPositionHashCode varcharHashCode, SqlMap leftMap, SqlMap rightMap)
{
int leftRawOffset = leftMap.getRawOffset();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3457,6 +3457,25 @@ public void testCosineSimilarity()
.isNull(DOUBLE);
}

@Test
public void testCosineDistance()
{
assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b'], ARRAY[1.0E0, 2.0E0])", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
.isEqualTo(1 - (2 * 3 / (Math.sqrt(5) * Math.sqrt(10))));

assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b', 'c'], ARRAY[1.0E0, 2.0E0, -1.0E0])", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
.isEqualTo(1 - ((2 * 3 + -1 * 1) / (Math.sqrt(1 + 4 + 1) * Math.sqrt(1 + 9))));

assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b', 'c'], ARRAY[1.0E0, 2.0E0, -1.0E0])", "map(ARRAY['d', 'e'], ARRAY[1.0E0, 3.0E0])"))
.isEqualTo(1.0);

assertThat(assertions.function("cosine_distance", "null", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
.isNull(DOUBLE);

assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b'], ARRAY[1.0E0, null])", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))

Check failure on line 3475 in core/trino-main/src/test/java/io/trino/operator/scalar/TestMathFunctions.java

View workflow job for this annotation

GitHub Actions / test (core/trino-main)

TestMathFunctions.testCosineDistance

Cannot invoke "java.lang.Double.doubleValue()" because the return value of "io.trino.operator.scalar.MathFunctions.cosineSimilarity(io.trino.type.BlockTypeOperators$BlockPositionIsIdentical, io.trino.type.BlockTypeOperators$BlockPositionHashCode, io.trino.spi.block.SqlMap, io.trino.spi.block.SqlMap)" is null
.isNull(DOUBLE);
}

@Test
public void testInverseNormalCdf()
{
Expand Down
10 changes: 10 additions & 0 deletions docs/src/main/sphinx/functions/math.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,16 @@ SELECT cosine_distance(ARRAY[1.0, 2.0], ARRAY[3.0, 4.0]);
```
:::

:::{function} cosine_distance(x, y) -> double
:no-index:
Calculates the cosine distance between two sparse vectors:

```sql
SELECT cosine_distance(MAP(ARRAY['a'], ARRAY[1.0]), MAP(ARRAY['a'], ARRAY[2.0]));
-- 0.0
```
:::

:::{function} cosine_similarity(array(double), array(double)) -> double
Calculates the cosine similarity of two dense vectors:

Expand Down

0 comments on commit 4b126bf

Please sign in to comment.