From 4b126bfa7b91ed71c0632c7128edeacfbf9b03d6 Mon Sep 17 00:00:00 2001
From: Manfred Moser
Date: Mon, 4 Nov 2024 10:49:06 -0800
Subject: [PATCH] Add cosine_distance for sparse vectors

---
 .../trino/operator/scalar/MathFunctions.java   | 31 +++++++++++++++++++++++++++++++
 .../operator/scalar/TestMathFunctions.java     | 19 +++++++++++++++++++
 docs/src/main/sphinx/functions/math.md         | 10 ++++++++++
 3 files changed, 60 insertions(+)

diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MathFunctions.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MathFunctions.java
index 3b3cf4c5d8a03..b1c7f778be26e 100644
--- a/core/trino-main/src/main/java/io/trino/operator/scalar/MathFunctions.java
+++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MathFunctions.java
@@ -1404,6 +1404,37 @@ public static Double cosineSimilarity(
         return dotProduct / (normLeftMap * normRightMap);
     }
 
+    /**
+     * Cosine distance between two sparse vectors given as {@code map(varchar,double)},
+     * computed as {@code 1 - cosineSimilarity(leftMap, rightMap)}.
+     *
+     * @return the distance, or {@code null} (SQL NULL) when the similarity is {@code null}
+     */
+    @Description("Calculates the cosine distance between the given sparse vectors")
+    @ScalarFunction
+    @SqlNullable
+    @SqlType(StandardTypes.DOUBLE)
+    public static Double cosineDistance(
+            @OperatorDependency(
+                    operator = IDENTICAL,
+                    argumentTypes = {"varchar", "varchar"},
+                    convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = NULLABLE_RETURN)) BlockPositionIsIdentical varcharIdentical,
+            @OperatorDependency(
+                    operator = HASH_CODE,
+                    argumentTypes = "varchar",
+                    convention = @Convention(arguments = BLOCK_POSITION, result = FAIL_ON_NULL)) BlockPositionHashCode varcharHashCode,
+            @SqlType("map(varchar,double)") SqlMap leftMap,
+            @SqlType("map(varchar,double)") SqlMap rightMap)
+    {
+        // cosineSimilarity is declared to return a boxed Double; unboxing a null result
+        // directly in "1.0 - ..." would throw NullPointerException, so propagate it instead.
+        Double similarity = cosineSimilarity(varcharIdentical, varcharHashCode, leftMap, rightMap);
+        if (similarity == null) {
+            return null;
+        }
+        return 1.0 - similarity;
+    }
+
     private static double mapDotProduct(BlockPositionIsIdentical varcharIdentical, BlockPositionHashCode varcharHashCode, SqlMap leftMap, SqlMap rightMap)
     {
         int leftRawOffset = leftMap.getRawOffset();
diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestMathFunctions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestMathFunctions.java
index 20e50b065b1d0..33905b3339e7a 100644
--- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestMathFunctions.java
+++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestMathFunctions.java
@@ -3457,6 +3457,25 @@ public void testCosineSimilarity()
                 .isNull(DOUBLE);
     }
 
+    @Test
+    public void testCosineDistance()
+    {
+        assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b'], ARRAY[1.0E0, 2.0E0])", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
+                .isEqualTo(1 - (2 * 3 / (Math.sqrt(5) * Math.sqrt(10))));
+
+        assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b', 'c'], ARRAY[1.0E0, 2.0E0, -1.0E0])", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
+                .isEqualTo(1 - ((2 * 3 + -1 * 1) / (Math.sqrt(1 + 4 + 1) * Math.sqrt(1 + 9))));
+
+        assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b', 'c'], ARRAY[1.0E0, 2.0E0, -1.0E0])", "map(ARRAY['d', 'e'], ARRAY[1.0E0, 3.0E0])"))
+                .isEqualTo(1.0);
+
+        assertThat(assertions.function("cosine_distance", "null", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
+                .isNull(DOUBLE);
+
+        assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b'], ARRAY[1.0E0, null])", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
+                .isNull(DOUBLE);
+    }
+
     @Test
     public void testInverseNormalCdf()
     {
diff --git a/docs/src/main/sphinx/functions/math.md b/docs/src/main/sphinx/functions/math.md
index 58823823e668f..762f49d183f04 100644
--- a/docs/src/main/sphinx/functions/math.md
+++ b/docs/src/main/sphinx/functions/math.md
@@ -205,6 +205,16 @@ SELECT cosine_distance(ARRAY[1.0, 2.0], ARRAY[3.0, 4.0]);
 ```
 :::
 
+:::{function} cosine_distance(x, y) -> double
+:no-index:
+Calculates the cosine distance between two sparse vectors:
+
+```sql
+SELECT cosine_distance(MAP(ARRAY['a'], ARRAY[1.0]), MAP(ARRAY['a'], ARRAY[2.0]));
+-- 0.0
+```
+:::
+
 :::{function} cosine_similarity(array(double), array(double)) -> double
 Calculates the cosine similarity of two dense vectors: