Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
PHILO-HE committed Sep 19, 2023
1 parent 017f1b2 commit 9bb46c7
Show file tree
Hide file tree
Showing 6 changed files with 250 additions and 6 deletions.
24 changes: 21 additions & 3 deletions velox/docs/functions/spark/math.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,29 @@ Mathematical Functions

.. spark:function:: rand() -> double
Returns a random value with independent and identically distributed uniformly distributed values in [0, 1). ::
Returns a random value with uniformly distributed values in [0, 1). ::

SELECT rand(); -- 0.9629742951434543
SELECT rand(0); -- 0.7604953758285915
SELECT rand(null); -- 0.7604953758285915

.. spark:function:: rand(seed, partitionIndex) -> double
Returns a random value with uniformly distributed values in [0, 1) using a seed formed
by combining user-specified ``seed`` and framework provided ``partitionIndex``. The
framework is responsible for deterministic partitioning of the data and assigning unique
``partitionIndex`` to each thread (in a deterministic way).
``seed`` must be constant. NULL ``seed`` is identical to zero ``seed``. ``partitionIndex``
cannot be NULL. ::

SELECT rand(0); -- 0.5488135024422883
SELECT rand(NULL); -- 0.5488135024422883

.. spark:function:: random() -> double
An alias for ``rand()``.

.. spark:function:: random(seed, partitionIndex) -> double
An alias for ``rand(seed, partitionIndex)``.

.. spark:function:: remainder(n, m) -> [same as n]
Expand Down
85 changes: 85 additions & 0 deletions velox/functions/sparksql/Rand.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "velox/functions/Macros.h"

namespace facebook::velox::functions::sparksql {

template <typename T>
struct RandFunction {
static constexpr bool is_deterministic = false;

std::optional<std::mt19937> generator;

FOLLY_ALWAYS_INLINE void call(double& result) {
result = folly::Random::randDouble01();
}

FOLLY_ALWAYS_INLINE void callNullable(
double& result,
const int32_t* seed,
const int32_t* partitionIndex) {
VELOX_USER_CHECK_NOT_NULL(partitionIndex, "partitionIndex cannot be null.");
if (!generator.has_value()) {
generator = std::mt19937{};
if (seed) {
generator->seed((uint64_t)*seed + *partitionIndex);
} else {
// For null input, 0 plus partitionIndex is the seed, consistent with
// Spark.
generator->seed(*partitionIndex);
}
}
result = folly::Random::randDouble01(*generator);
}

/// To differentiate generator for each thread, seed plus partitionIndex is
/// the actual seed used for generator.
FOLLY_ALWAYS_INLINE void callNullable(
double& result,
const int64_t* seed,
const int32_t* partitionIndex) {
VELOX_USER_CHECK_NOT_NULL(partitionIndex, "partitionIndex cannot be null.");
if (!generator.has_value()) {
generator = std::mt19937{};
if (seed) {
generator->seed((uint64_t)*seed + *partitionIndex);
} else {
// For null input, 0 plus partitionIndex is the seed, consistent with
// Spark.
generator->seed(*partitionIndex);
}
}
result = folly::Random::randDouble01(*generator);
}

// For NULL constant input of unknown type.
FOLLY_ALWAYS_INLINE void callNullable(
double& result,
const UnknownValue* seed,
const int32_t* partitionIndex) {
VELOX_USER_CHECK_NOT_NULL(partitionIndex, "partitionIndex cannot be null.");
if (!generator.has_value()) {
generator = std::mt19937{};
// For null input, 0 plus partitionIndex is the seed, consistent with
// Spark.
generator->seed(*partitionIndex);
}
result = folly::Random::randDouble01(*generator);
}
};
} // namespace facebook::velox::functions::sparksql
3 changes: 0 additions & 3 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include "velox/functions/lib/Re2Functions.h"
#include "velox/functions/lib/RegistrationHelpers.h"
#include "velox/functions/prestosql/JsonFunctions.h"
#include "velox/functions/prestosql/Rand.h"
#include "velox/functions/prestosql/StringFunctions.h"
#include "velox/functions/sparksql/Arithmetic.h"
#include "velox/functions/sparksql/ArraySort.h"
Expand Down Expand Up @@ -73,8 +72,6 @@ static void workAroundRegistrationMacro(const std::string& prefix) {
namespace sparksql {

void registerFunctions(const std::string& prefix) {
registerFunction<RandFunction, double>({prefix + "rand"});

// Register size functions
registerSize(prefix + "size");

Expand Down
24 changes: 24 additions & 0 deletions velox/functions/sparksql/RegisterArithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,32 @@
#include "velox/functions/prestosql/Arithmetic.h"
#include "velox/functions/prestosql/CheckedArithmetic.h"
#include "velox/functions/sparksql/Arithmetic.h"
#include "velox/functions/sparksql/Rand.h"

namespace facebook::velox::functions::sparksql {

void registerRandFunctions(const std::string& prefix) {
registerFunction<RandFunction, double>({prefix + "rand", prefix + "random"});
// Has seed & partition index as input.
registerFunction<
RandFunction,
double,
int32_t /*seed*/,
int32_t /*partition index*/>({prefix + "rand", prefix + "random"});
// Has seed & partition index as input.
registerFunction<
RandFunction,
double,
int64_t /*seed*/,
int32_t /*partition index*/>({prefix + "rand", prefix + "random"});
// NULL constant as seed of unknown type.
registerFunction<
RandFunction,
double,
UnknownValue /*seed*/,
int32_t /*partition index*/>({prefix + "rand", prefix + "random"});
}

void registerArithmeticFunctions(const std::string& prefix) {
// Operators.
registerBinaryNumeric<PlusFunction>({prefix + "add"});
Expand Down Expand Up @@ -63,6 +86,7 @@ void registerArithmeticFunctions(const std::string& prefix) {
VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_sub, prefix + "decimal_subtract");
VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_mul, prefix + "decimal_multiply");
VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_div, prefix + "decimal_divide");
registerRandFunctions(prefix);
}

} // namespace facebook::velox::functions::sparksql
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ add_executable(
LeastGreatestTest.cpp
MapTest.cpp
MightContainTest.cpp
RandTest.cpp
RegexFunctionsTest.cpp
SizeTest.cpp
SortArrayTest.cpp
Expand Down
119 changes: 119 additions & 0 deletions velox/functions/sparksql/tests/RandTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"

namespace facebook::velox::functions::sparksql::test {
namespace {

class RandTest : public SparkFunctionBaseTest {
public:
RandTest() {
// Allow for parsing literal integers as INTEGER, not BIGINT.
options_.parseIntegerAsBigint = false;
}

protected:
std::optional<double> rand(int32_t seed, int32_t partitionIndex = 0) {
return evaluateOnce<double>(
fmt::format("rand({}, {})", seed, partitionIndex),
makeRowVector(ROW({}), 1));
}

std::optional<double> randWithNullSeed(int32_t partitionIndex = 0) {
return evaluateOnce<double>(
fmt::format("rand(NULL, {})", partitionIndex),
makeRowVector(ROW({}), 1));
}

std::optional<double> randWithNoSeed() {
return evaluateOnce<double>("rand()", makeRowVector(ROW({}), 1));
}

VectorPtr randWithBatchInput(int32_t seed, int32_t partitionIndex = 0) {
auto exprSet = compileExpression(
fmt::format("rand({}, {})", seed, partitionIndex), ROW({}));
return evaluate(*exprSet, makeRowVector(ROW({}), 20));
}

void checkResult(const std::optional<double>& result) {
EXPECT_NE(result, std::nullopt);
EXPECT_GE(result.value(), 0.0);
EXPECT_LT(result.value(), 1.0);
}

// Check whether two vectors that have same size & type, but not all same
// values.
void assertNotEqualVectors(const VectorPtr& left, const VectorPtr& right) {
ASSERT_EQ(left->size(), right->size());
ASSERT_TRUE(left->type()->equivalent(*right->type()));
for (auto i = 0; i < left->size(); i++) {
if (!left->equalValueAt(right.get(), i, i)) {
return;
}
}
FAIL() << "Expect two different vectors are produced.";
}
};

TEST_F(RandTest, withSeed) {
checkResult(rand(0));
// With same default partitionIndex used, same seed always produces same
// result.
EXPECT_EQ(rand(0), rand(0));

checkResult(rand(1));
EXPECT_EQ(rand(1), rand(1));

checkResult(rand(20000));
EXPECT_EQ(rand(20000), rand(20000));

// Test with same seed, but different partitionIndex.
EXPECT_NE(rand(0, 0), rand(0, 1));
EXPECT_NE(rand(1000, 0), rand(1000, 1));

checkResult(randWithNullSeed());
// Null as seed is identical to 0 as seed.
EXPECT_EQ(randWithNullSeed(), rand(0));
// Same null as seed but different partition index.
EXPECT_NE(randWithNullSeed(0), randWithNullSeed(1));

// Test with batch input.
auto batchResult1 = randWithBatchInput(100);
auto batchResult2 = randWithBatchInput(100);
// Same seed & partition index produce same results.
velox::test::assertEqualVectors(batchResult1, batchResult2);
batchResult1 = randWithBatchInput(100, 0 /*partitionIndex*/);
batchResult2 = randWithBatchInput(100, 1 /*partitionIndex*/);
// Same seed but different partition index cannot produce absolutely same
// result.
assertNotEqualVectors(batchResult1, batchResult2);
}

TEST_F(RandTest, withoutSeed) {
auto result1 = randWithNoSeed();
auto result2 = randWithNoSeed();
auto result3 = randWithNoSeed();
checkResult(result1);
checkResult(result2);
checkResult(result3);
// It is impossible to get three same results by three separate callings.
EXPECT_FALSE(
(result1.value() == result2.value()) &&
(result1.value() == result3.value()));
}

} // namespace
} // namespace facebook::velox::functions::sparksql::test

0 comments on commit 9bb46c7

Please sign in to comment.