Skip to content

Commit

Permalink
initial version with PinotJsonContainsPredicate
Browse files Browse the repository at this point in the history
resolved style and deferred dependency issue
  • Loading branch information
robertzych committed Oct 31, 2024
1 parent 7006352 commit 5fcc0fd
Show file tree
Hide file tree
Showing 13 changed files with 331 additions and 40 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.pinot;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.airlift.slice.Slice;
import io.trino.spi.block.IntArrayBlock;
import io.trino.spi.block.VariableWidthBlock;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Constant;
import io.trino.spi.expression.FunctionName;
import io.trino.spi.expression.Variable;

import java.util.ArrayList;
import java.util.List;

import static com.google.common.base.MoreObjects.toStringHelper;
import static java.util.Objects.requireNonNull;

/**
 * Pushdown representation of {@code contains(ARRAY[...], json_extract_scalar(column, path))},
 * optionally with a {@code $cast} wrapped around the {@code json_extract_scalar} call.
 * <p>
 * The predicate is rendered by {@link #toPQL()} as a Pinot {@code JSON_MATCH} filter with an
 * {@code in (...)} list. Instances are immutable and JSON-serializable.
 */
public class PinotJsonContainsPredicate
        implements PinotJsonPredicate
{
    private static final String TYPE = "contains";
    private static final FunctionName CONTAINS_FUNCTION = new FunctionName("contains");
    private static final FunctionName CAST_FUNCTION = new FunctionName("$cast");
    private static final FunctionName JSON_EXTRACT_SCALAR_FUNCTION = new FunctionName("json_extract_scalar");

    private final String columnName;
    private final String jsonPath;
    private final List<String> values;
    private final boolean valuesContainsStrings;
    private final String type;

    /**
     * Deserialization constructor.
     *
     * @param columnName name of the column holding the JSON document
     * @param jsonPath JSON path that is matched against {@code values}
     * @param values textual form of the candidate values; copied defensively
     * @param valuesContainsStrings whether the values must be quoted as string literals
     * @param type accepted for JSON round-tripping only; the stored type is always {@code "contains"}
     * @throws NullPointerException if {@code columnName}, {@code jsonPath} or {@code values} is null
     */
    @JsonCreator
    public PinotJsonContainsPredicate(
            @JsonProperty("columnName") String columnName,
            @JsonProperty("jsonPath") String jsonPath,
            @JsonProperty("values") List<String> values,
            @JsonProperty("valuesContainsStrings") boolean valuesContainsStrings,
            @JsonProperty("type") String type)
    {
        this.columnName = requireNonNull(columnName, "columnName is null");
        this.jsonPath = requireNonNull(jsonPath, "jsonPath is null");
        this.values = List.copyOf(requireNonNull(values, "values is null"));
        this.valuesContainsStrings = valuesContainsStrings;
        this.type = TYPE;
    }

    /**
     * Builds the predicate from a {@code contains} call. Callers are expected to check
     * {@link #supportsCall(Call)} first; the casts below rely on the shape it verifies.
     *
     * @param call the {@code contains(array constant, json_extract_scalar(...))} expression
     * @throws IllegalArgumentException if the array constant is of an unsupported block type,
     * is empty, or contains a null element
     */
    public PinotJsonContainsPredicate(Call call)
    {
        requireNonNull(call, "call is null");
        List<ConnectorExpression> containsCallArgs = call.getArguments();
        Object arrayValue = ((Constant) containsCallArgs.getFirst()).getValue();

        List<String> elements = new ArrayList<>();
        boolean stringElements;
        if (arrayValue instanceof VariableWidthBlock stringArray) {
            stringElements = true;
            for (int index = 0; index < stringArray.getPositionCount(); index++) {
                checkElementNotNull(stringArray.isNull(index), index);
                elements.add(stringArray.getSlice(index).toStringUtf8());
            }
        }
        else if (arrayValue instanceof IntArrayBlock intArray) {
            // NOTE(review): the block class alone does not prove the element type is integer;
            // confirm whether the constant's declared type should be checked as well.
            stringElements = false;
            for (int index = 0; index < intArray.getPositionCount(); index++) {
                checkElementNotNull(intArray.isNull(index), index);
                elements.add(String.valueOf(intArray.getInt(index)));
            }
        }
        else {
            throw new IllegalArgumentException("Unsupported array argument type: " + arrayValue);
        }
        if (elements.isEmpty()) {
            // An empty list cannot be rendered as a valid "in (...)" clause
            throw new IllegalArgumentException("Array argument is empty");
        }

        Call innerCall = (Call) containsCallArgs.get(1);
        Call jsonExtractScalarCall = innerCall;
        if (CAST_FUNCTION.equals(innerCall.getFunctionName())) {
            jsonExtractScalarCall = (Call) innerCall.getArguments().getFirst();
        }
        List<ConnectorExpression> args = jsonExtractScalarCall.getArguments();

        this.columnName = ((Variable) args.get(0)).getName();
        this.jsonPath = ((Slice) ((Constant) args.get(1)).getValue()).toStringUtf8();
        this.values = List.copyOf(elements);
        this.valuesContainsStrings = stringElements;
        this.type = TYPE;
    }

    @JsonProperty
    public String getColumnName()
    {
        return columnName;
    }

    @JsonProperty
    public String getJsonPath()
    {
        return jsonPath;
    }

    /**
     * @return an unmodifiable list of the candidate values in textual form
     */
    @JsonProperty
    public List<String> getValues()
    {
        return values;
    }

    @JsonProperty
    public boolean getValuesContainsStrings()
    {
        return valuesContainsStrings;
    }

    @JsonProperty
    public String getType()
    {
        return type;
    }

    /**
     * Tells whether {@code call} has exactly the shape {@link #PinotJsonContainsPredicate(Call)}
     * can convert: a two-argument {@code contains} whose first argument is a non-empty,
     * null-free string or int array constant and whose second argument is a supported
     * {@code json_extract_scalar} call, optionally wrapped in {@code $cast}.
     *
     * @param call the expression to inspect
     * @return {@code true} if the call can be turned into this predicate
     */
    public static boolean supportsCall(Call call)
    {
        if (!CONTAINS_FUNCTION.equals(call.getFunctionName())) {
            return false;
        }

        List<ConnectorExpression> arguments = call.getArguments();
        if (arguments.size() != 2) {
            return false;
        }
        if (!(arguments.get(0) instanceof Constant arrayArg) || !(arguments.get(1) instanceof Call innerCall)) {
            return false;
        }
        if (!isSupportedArray(arrayArg.getValue())) {
            return false;
        }

        if (CAST_FUNCTION.equals(innerCall.getFunctionName())) {
            List<ConnectorExpression> castArguments = innerCall.getArguments();
            if (castArguments.isEmpty() || !(castArguments.getFirst() instanceof Call jsonExtractScalarCall)) {
                return false;
            }
            return isSupportedJsonExtractScalarCall(jsonExtractScalarCall);
        }
        return isSupportedJsonExtractScalarCall(innerCall);
    }

    /**
     * Accepts only non-empty string or int array blocks without null elements: a null element
     * has no value to render, and an empty list has no valid {@code in (...)} form.
     */
    private static boolean isSupportedArray(Object value)
    {
        if (value instanceof VariableWidthBlock stringArray) {
            for (int index = 0; index < stringArray.getPositionCount(); index++) {
                if (stringArray.isNull(index)) {
                    return false;
                }
            }
            return stringArray.getPositionCount() > 0;
        }
        if (value instanceof IntArrayBlock intArray) {
            for (int index = 0; index < intArray.getPositionCount(); index++) {
                if (intArray.isNull(index)) {
                    return false;
                }
            }
            return intArray.getPositionCount() > 0;
        }
        return false;
    }

    /**
     * Checks for {@code json_extract_scalar(variable, constant)} where the constant holds a
     * {@link Slice}, which is what the converting constructor casts the path value to.
     */
    private static boolean isSupportedJsonExtractScalarCall(Call call)
    {
        if (!JSON_EXTRACT_SCALAR_FUNCTION.equals(call.getFunctionName())) {
            return false;
        }

        List<ConnectorExpression> arguments = call.getArguments();
        if (arguments.size() != 2) {
            return false;
        }

        // TODO: resolve dependency issues to allow usage of io.trino.type.JsonType
        // return arguments.get(0).getType() instanceof JsonType;

        return arguments.get(0) instanceof Variable
                && arguments.get(1) instanceof Constant pathConstant
                && pathConstant.getValue() instanceof Slice;
    }

    private static void checkElementNotNull(boolean isNull, int index)
    {
        if (isNull) {
            throw new IllegalArgumentException("Array argument contains a null element at position " + index);
        }
    }

    /**
     * Renders the predicate as {@code JSON_MATCH(column, '"path" in (v1,v2,...)')}.
     * <p>
     * The filter is itself a single-quoted literal, so string values are wrapped in doubled
     * quotes ({@code ''value''}). A single quote inside a string value is doubled once for the
     * inner literal and once more for the enclosing literal, giving four quotes; a single quote
     * in the path is doubled once for the enclosing literal.
     */
    @Override
    public String toPQL()
    {
        List<String> literals = new ArrayList<>(values.size());
        for (String value : values) {
            literals.add(valuesContainsStrings ? "''" + value.replace("'", "''''") + "''" : value);
        }
        return String.format("JSON_MATCH(%s, '\"%s\" in (%s)')",
                columnName, jsonPath.replace("'", "''"), String.join(",", literals));
    }

    // Value-based equality: instances are held by the PinotTableHandle record,
    // whose own equals/hashCode compare components with equals().
    @Override
    public boolean equals(Object other)
    {
        if (this == other) {
            return true;
        }
        if (!(other instanceof PinotJsonContainsPredicate that)) {
            return false;
        }
        return valuesContainsStrings == that.valuesContainsStrings
                && columnName.equals(that.columnName)
                && jsonPath.equals(that.jsonPath)
                && values.equals(that.values);
    }

    @Override
    public int hashCode()
    {
        int result = columnName.hashCode();
        result = 31 * result + jsonPath.hashCode();
        result = 31 * result + values.hashCode();
        result = 31 * result + Boolean.hashCode(valuesContainsStrings);
        return result;
    }

    @Override
    public String toString()
    {
        return toStringHelper(this)
                .add("columnName", columnName)
                .add("jsonPath", jsonPath)
                .add("values", values)
                .add("valuesContainsStrings", valuesContainsStrings)
                .add("type", type)
                .toString();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.pinot;

import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import io.trino.spi.expression.Call;

/**
 * A JSON predicate that can be rendered as a Pinot query fragment.
 * <p>
 * Implementations are serialized polymorphically: the {@code type} property carries the
 * subtype name registered below.
 */
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type")
@JsonSubTypes(@JsonSubTypes.Type(name = "contains", value = PinotJsonContainsPredicate.class))
public interface PinotJsonPredicate
{
    /**
     * @return the query text for this predicate
     */
    String toPQL();

    /**
     * Interface-level default: always reports the call as unsupported. Static interface
     * methods are not inherited, so implementations expose their own check.
     *
     * @param call the expression to inspect
     * @return {@code false}
     */
    static boolean supportsCall(Call call)
    {
        return false;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.trino.plugin.base.aggregation.AggregateFunctionRewriter;
import io.trino.plugin.base.aggregation.AggregateFunctionRule;
import io.trino.plugin.base.expression.ConnectorExpressionRewriter;
import io.trino.plugin.base.expression.ConnectorExpressions;
import io.trino.plugin.pinot.client.PinotClient;
import io.trino.plugin.pinot.query.AggregateExpression;
import io.trino.plugin.pinot.query.DynamicTable;
Expand Down Expand Up @@ -51,6 +52,7 @@
import io.trino.spi.connector.RelationColumnsMetadata;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.connector.TableNotFoundException;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;
import io.trino.spi.predicate.Domain;
Expand All @@ -60,6 +62,7 @@
import io.trino.spi.type.Type;
import org.apache.pinot.spi.data.Schema;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
Expand Down Expand Up @@ -162,7 +165,7 @@ public PinotTableHandle getTableHandle(ConnectorSession session, SchemaTableName

if (tableName.getTableName().trim().contains("select ")) {
DynamicTable dynamicTable = DynamicTableBuilder.buildFromPql(this, tableName, pinotClient, typeConverter);
return new PinotTableHandle(tableName.getSchemaName(), dynamicTable.tableName(), false, TupleDomain.all(), OptionalLong.empty(), Optional.of(dynamicTable));
return new PinotTableHandle(tableName.getSchemaName(), dynamicTable.tableName(), false, TupleDomain.all(), OptionalLong.empty(), Optional.of(dynamicTable), List.of());
}
String pinotTableName = pinotClient.getPinotTableNameFromTrinoTableNameIfExists(tableName.getTableName());
if (pinotTableName == null) {
Expand All @@ -174,7 +177,8 @@ public PinotTableHandle getTableHandle(ConnectorSession session, SchemaTableName
getFromCache(pinotTableSchemaCache, pinotTableName).isEnableColumnBasedNullHandling(),
TupleDomain.all(),
OptionalLong.empty(),
Optional.empty());
Optional.empty(),
List.of());
}

@Override
Expand Down Expand Up @@ -288,7 +292,8 @@ public Optional<LimitApplicationResult<ConnectorTableHandle>> applyLimit(Connect
handle.enableNullHandling(),
handle.constraint(),
OptionalLong.of(limit),
dynamicTable);
dynamicTable,
List.of());
boolean singleSplit = dynamicTable.isPresent();
return Optional.of(new LimitApplicationResult<>(handle, singleSplit, false));
}
Expand Down Expand Up @@ -330,7 +335,40 @@ else if (isFilterPushdownUnsupported(entry.getValue())) {
remainingFilter = TupleDomain.withColumnDomains(unsupported);
}

if (oldDomain.equals(newDomain)) {
ConnectorExpression expression = constraint.getExpression();
List<PinotJsonPredicate> jsonPredicates = new ArrayList<>();
List<ConnectorExpression> notHandledExpressions = new ArrayList<>();
if (expression instanceof Call call) {
if (call.getFunctionName().getName().equals("$and")) {
List<ConnectorExpression> innerExpressions = ConnectorExpressions.extractConjuncts(constraint.getExpression());
for (ConnectorExpression innerExpression : innerExpressions) {
if (innerExpression instanceof Call innerCall) {
Optional<PinotJsonPredicate> jsonPredicate = getJsonPredicate(innerCall);
if (jsonPredicate.isPresent()) {
jsonPredicates.add(jsonPredicate.get());
}
else {
notHandledExpressions.add(innerExpression);
}
}
}
}
else {
Optional<PinotJsonPredicate> jsonPredicate = getJsonPredicate(call);
if (jsonPredicate.isPresent()) {
jsonPredicates.add(jsonPredicate.get());
}
else {
notHandledExpressions.add(expression);
}
}
}
else {
notHandledExpressions.add(expression);
}
ConnectorExpression newExpression = ConnectorExpressions.and(notHandledExpressions);

if (oldDomain.equals(newDomain) && expression.equals(newExpression)) {
return Optional.empty();
}

Expand All @@ -340,8 +378,23 @@ else if (isFilterPushdownUnsupported(entry.getValue())) {
handle.enableNullHandling(),
newDomain,
handle.limit(),
handle.query());
return Optional.of(new ConstraintApplicationResult<>(handle, remainingFilter, constraint.getExpression(), false));
handle.query(),
jsonPredicates);

return Optional.of(new ConstraintApplicationResult<>(handle, remainingFilter, newExpression, false));
}

/**
 * Converts a {@code contains} call into its Pinot JSON predicate without checking
 * {@code PinotJsonContainsPredicate.supportsCall} first.
 * NOTE(review): no caller is visible in this part of the file - confirm it is still needed.
 */
private static PinotJsonContainsPredicate toJsonPredicate(Call call)
{
    PinotJsonContainsPredicate predicate = new PinotJsonContainsPredicate(call);
    return predicate;
}

/**
 * Converts {@code call} into a JSON predicate when it has a supported shape.
 *
 * @param call the candidate expression
 * @return the predicate, or empty when the call is not a supported {@code contains} call
 */
private Optional<PinotJsonPredicate> getJsonPredicate(Call call)
{
    return PinotJsonContainsPredicate.supportsCall(call)
            ? Optional.of(new PinotJsonContainsPredicate(call))
            : Optional.empty();
}

// IS NULL and IS NOT NULL are handled differently in Pinot, pushing down would lead to inconsistent results.
Expand Down Expand Up @@ -472,7 +525,8 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
tableHandle.enableNullHandling(),
tableHandle.constraint(),
tableHandle.limit(),
Optional.of(dynamicTable));
Optional.of(dynamicTable),
tableHandle.jsonPredicates());

return Optional.of(new AggregationApplicationResult<>(tableHandle, projections.build(), resultAssignments.build(), ImmutableMap.of(), false));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public ConnectorPageSource createPageSource(
if (pinotTableHandle.query().isPresent()) {
DynamicTable dynamicTable = pinotTableHandle.query().get();
pinotQueryInfo = new PinotQueryInfo(dynamicTable.tableName(),
extractPql(dynamicTable, pinotTableHandle.constraint()),
extractPql(dynamicTable, pinotTableHandle.constraint(), pinotTableHandle.jsonPredicates()),
dynamicTable.groupingColumns().size());
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.trino.spi.connector.ConnectorTableHandle;
import io.trino.spi.predicate.TupleDomain;

import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalLong;
Expand All @@ -30,7 +31,8 @@ public record PinotTableHandle(
boolean enableNullHandling,
TupleDomain<ColumnHandle> constraint,
OptionalLong limit,
Optional<DynamicTable> query)
Optional<DynamicTable> query,
List<PinotJsonPredicate> jsonPredicates)
implements ConnectorTableHandle
{
public PinotTableHandle
Expand All @@ -40,6 +42,7 @@ public record PinotTableHandle(
requireNonNull(constraint, "constraint is null");
requireNonNull(limit, "limit is null");
requireNonNull(query, "query is null");
requireNonNull(jsonPredicates, "jsonPredicates is null");
}

@Override
Expand Down
Loading

0 comments on commit 5fcc0fd

Please sign in to comment.