Support UPDATE and MERGE for Mongo connector
chenjian2664 committed Nov 4, 2024
1 parent b58e505 commit 1a8bb9e
Showing 9 changed files with 392 additions and 56 deletions.
2 changes: 2 additions & 0 deletions docs/src/main/sphinx/connector/mongodb.md
@@ -476,6 +476,8 @@ statements, the connector supports the following features:

- {doc}`/sql/insert`
- {doc}`/sql/delete`
- {doc}`/sql/update`
- {doc}`/sql/merge`
- {doc}`/sql/create-table`
- {doc}`/sql/create-table-as`
- {doc}`/sql/drop-table`
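For context, a minimal sketch of the newly supported statements as issued through Trino against a MongoDB-backed catalog; the table and column names (`orders`, `order_updates`, `order_id`, `status`) are hypothetical:

```sql
-- Rewrite selected fields of matching documents in place.
UPDATE orders
SET status = 'SHIPPED'
WHERE order_id = 1001;

-- Merge a source table into the MongoDB-backed target:
-- delete, update, or insert depending on the match.
MERGE INTO orders t
USING order_updates s
    ON t.order_id = s.order_id
WHEN MATCHED AND s.status = 'CANCELLED' THEN DELETE
WHEN MATCHED THEN UPDATE SET status = s.status
WHEN NOT MATCHED THEN INSERT (order_id, status) VALUES (s.order_id, s.status);
```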
MongoMergeSink.java (new file)
@@ -0,0 +1,224 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.mongodb;

import com.google.common.collect.ImmutableList;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.model.BulkWriteOptions;
import com.mongodb.client.model.UpdateOneModel;
import com.mongodb.client.model.WriteModel;
import io.airlift.slice.Slice;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.connector.ConnectorMergeSink;
import io.trino.spi.connector.ConnectorPageSink;
import io.trino.spi.connector.ConnectorPageSinkId;
import org.bson.Document;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.mongodb.client.model.Filters.in;
import static io.trino.plugin.mongodb.MongoMetadata.MERGE_ROW_ID_BASE_NAME;
import static io.trino.spi.type.TinyintType.TINYINT;
import static java.util.Objects.requireNonNull;

public class MongoMergeSink
implements ConnectorMergeSink
{
private static final String SET = "$set";

private final int columnCount;
private final ConnectorPageSink insertSink;
private final ConnectorPageSink updateSink;
private final ConnectorPageSink deleteSink;

public MongoMergeSink(MongoSession mongoSession, RemoteTableName remoteTableName, List<MongoColumnHandle> columns, MongoColumnHandle mergeColumnHandle, String implicitPrefix, Optional<String> pageSinkIdColumnName, ConnectorPageSinkId pageSinkId)
{
requireNonNull(mongoSession, "mongoSession is null");
requireNonNull(remoteTableName, "remoteTableName is null");
requireNonNull(columns, "columns is null");
requireNonNull(pageSinkId, "pageSinkId is null");
requireNonNull(mergeColumnHandle, "mergeColumnHandle is null");

this.columnCount = columns.size();

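// Three delegate sinks: inserts use the table columns, deletes only need the merge
// row id (_id) column, and updates get the data columns plus the _id.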
this.insertSink = new MongoPageSink(mongoSession, remoteTableName, columns, implicitPrefix, pageSinkIdColumnName, pageSinkId);
this.deleteSink = new MongoDeleteSink(mongoSession, remoteTableName, ImmutableList.of(mergeColumnHandle), implicitPrefix, pageSinkIdColumnName, pageSinkId);

// Update should always include the _id column explicitly
List<MongoColumnHandle> updateColumns = ImmutableList.<MongoColumnHandle>builderWithExpectedSize(columns.size() + 1)
.addAll(columns)
.add(mergeColumnHandle)
.build();
this.updateSink = new MongoUpdateSink(mongoSession, remoteTableName, updateColumns, implicitPrefix, pageSinkIdColumnName, pageSinkId);
}

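// The merged page layout is: the data columns first, then a TINYINT operation
// channel (insert/update/delete), then the row id channel carrying the _id.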
@Override
public void storeMergedRows(Page page)
{
checkArgument(page.getChannelCount() == 2 + columnCount, "The page size should be 2 + columnCount (%s), but is %s", columnCount, page.getChannelCount());
int positionCount = page.getPositionCount();
Block operationBlock = page.getBlock(columnCount);
Block rowId = page.getBlock(columnCount + 1);

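// Build the data page: the data columns plus the row id channel appended in place
// of the operation channel, so update rows carry the _id value.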
int[] dataChannel = IntStream.range(0, columnCount + 1).toArray();
dataChannel[columnCount] = columnCount + 1;
Page dataPage = page.getColumns(dataChannel);

int[] insertPositions = new int[positionCount];
int insertPositionCount = 0;
int[] deletePositions = new int[positionCount];
int deletePositionCount = 0;
int[] updatePositions = new int[positionCount];
int updatePositionCount = 0;

for (int position = 0; position < positionCount; position++) {
int operation = TINYINT.getByte(operationBlock, position);
switch (operation) {
case INSERT_OPERATION_NUMBER -> {
insertPositions[insertPositionCount] = position;
insertPositionCount++;
}
case DELETE_OPERATION_NUMBER -> {
deletePositions[deletePositionCount] = position;
deletePositionCount++;
}
case UPDATE_OPERATION_NUMBER -> {
updatePositions[updatePositionCount] = position;
updatePositionCount++;
}
default -> throw new IllegalStateException("Unexpected value: " + operation);
}
}

if (deletePositionCount > 0) {
Block positions = rowId.getPositions(deletePositions, 0, deletePositionCount);
deleteSink.appendPage(new Page(deletePositionCount, positions));
}

if (updatePositionCount > 0) {
updateSink.appendPage(dataPage.getPositions(updatePositions, 0, updatePositionCount));
}

if (insertPositionCount > 0) {
// Insert page should not include _id column by default, unless the insert columns include it explicitly
Page insertPage = dataPage.getColumns(Arrays.copyOf(dataChannel, columnCount));
insertSink.appendPage(insertPage.getPositions(insertPositions, 0, insertPositionCount));
}
}

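// Update and delete writes are applied eagerly in appendPage; only the insert sink
// produces fragments that finishMerge consumes.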
@Override
public CompletableFuture<Collection<Slice>> finish()
{
return insertSink.finish();
}

private static class MongoUpdateSink
implements ConnectorPageSink
{
private final MongoPageSink delegate;
private final MongoSession mongoSession;
private final RemoteTableName remoteTableName;

public MongoUpdateSink(
MongoSession mongoSession,
RemoteTableName remoteTableName,
List<MongoColumnHandle> columns,
String implicitPrefix,
Optional<String> pageSinkIdColumnName,
ConnectorPageSinkId pageSinkId)
{
this.delegate = new MongoPageSink(mongoSession, remoteTableName, columns, implicitPrefix, pageSinkIdColumnName, pageSinkId);
this.mongoSession = requireNonNull(mongoSession, "mongoSession is null");
this.remoteTableName = requireNonNull(remoteTableName, "remoteTableName is null");
}

@Override
public CompletableFuture<?> appendPage(Page page)
{
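// For each row, add an UpdateOne to an unordered bulk write: filter on the _id from
// the merge row id field and $set the row's column values.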
MongoCollection<Document> collection = mongoSession.getCollection(remoteTableName);
List<WriteModel<Document>> bulkWrites = new ArrayList<>();
for (Document document : delegate.buildBatchDocumentsFromPage(page)) {
Document filter = new Document(MERGE_ROW_ID_BASE_NAME, document.get(MERGE_ROW_ID_BASE_NAME));
bulkWrites.add(new UpdateOneModel<>(filter, new Document(SET, document)));
}
collection.bulkWrite(bulkWrites, new BulkWriteOptions().ordered(false));
return NOT_BLOCKED;
}

@Override
public CompletableFuture<Collection<Slice>> finish()
{
return delegate.finish();
}

@Override
public void abort()
{
delegate.abort();
}
}

private static class MongoDeleteSink
implements ConnectorPageSink
{
private final MongoPageSink delegate;
private final MongoSession mongoSession;
private final RemoteTableName remoteTableName;

public MongoDeleteSink(
MongoSession mongoSession,
RemoteTableName remoteTableName,
List<MongoColumnHandle> columns,
String implicitPrefix,
Optional<String> pageSinkIdColumnName,
ConnectorPageSinkId pageSinkId)
{
this.delegate = new MongoPageSink(mongoSession, remoteTableName, columns, implicitPrefix, pageSinkIdColumnName, pageSinkId);
this.mongoSession = requireNonNull(mongoSession, "mongoSession is null");
this.remoteTableName = requireNonNull(remoteTableName, "remoteTableName is null");
}

@Override
public CompletableFuture<?> appendPage(Page page)
{
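// Collect the _id values from the page and remove all matching documents with a single deleteMany.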
MongoCollection<Document> collection = mongoSession.getCollection(remoteTableName);
ImmutableList.Builder<Object> idsToDeleteBuilder = ImmutableList.builder();
for (Document document : delegate.buildBatchDocumentsFromPage(page)) {
idsToDeleteBuilder.add(document.get(MERGE_ROW_ID_BASE_NAME));
}
collection.deleteMany(in(MERGE_ROW_ID_BASE_NAME, idsToDeleteBuilder.build()));
return NOT_BLOCKED;
}

@Override
public CompletableFuture<Collection<Slice>> finish()
{
return delegate.finish();
}

@Override
public void abort()
{
delegate.abort();
}
}
}
MongoMergeTableHandle.java (new file)
@@ -0,0 +1,71 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.mongodb;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorMergeTableHandle;
import io.trino.spi.connector.ConnectorTableHandle;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.predicate.TupleDomain;

import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;

import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;

public record MongoMergeTableHandle(
RemoteTableName remoteTableName,
List<MongoColumnHandle> columns,
MongoColumnHandle mergeRowIdColumn,
Optional<String> filter,
TupleDomain<ColumnHandle> constraint,
Optional<String> temporaryTableName,
Optional<String> pageSinkIdColumnName)
implements ConnectorMergeTableHandle
{
public MongoMergeTableHandle
{
requireNonNull(remoteTableName, "remoteTableName is null");
columns = ImmutableList.copyOf(requireNonNull(columns, "columns is null"));
requireNonNull(filter, "filter is null");
requireNonNull(constraint, "constraint is null");
requireNonNull(temporaryTableName, "temporaryTableName is null");
requireNonNull(pageSinkIdColumnName, "pageSinkIdColumnName is null");
checkArgument(temporaryTableName.isPresent() == pageSinkIdColumnName.isPresent(),
"temporaryTableName.isPresent is not equal to pageSinkIdColumnName.isPresent");
}

@JsonIgnore
public Optional<RemoteTableName> getTemporaryRemoteTableName()
{
return temporaryTableName.map(tableName -> new RemoteTableName(remoteTableName.databaseName(), tableName));
}

@Override
public ConnectorTableHandle getTableHandle()
{
return new MongoTableHandle(
new SchemaTableName(remoteTableName.databaseName(), remoteTableName.collectionName()),
remoteTableName,
filter,
constraint,
ImmutableSet.of(),
OptionalInt.empty());
}
}
MongoMetadata.java
@@ -29,6 +29,7 @@
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.connector.ConnectorInsertTableHandle;
import io.trino.spi.connector.ConnectorMergeTableHandle;
import io.trino.spi.connector.ConnectorMetadata;
import io.trino.spi.connector.ConnectorOutputMetadata;
import io.trino.spi.connector.ConnectorOutputTableHandle;
@@ -46,6 +47,7 @@
import io.trino.spi.connector.ProjectionApplicationResult;
import io.trino.spi.connector.RelationColumnsMetadata;
import io.trino.spi.connector.RetryMode;
import io.trino.spi.connector.RowChangeParadigm;
import io.trino.spi.connector.SaveMode;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.connector.SchemaTablePrefix;
@@ -107,6 +109,7 @@
import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.connector.RelationColumnsMetadata.forTable;
import static io.trino.spi.connector.RowChangeParadigm.CHANGE_ONLY_UPDATED_COLUMNS;
import static io.trino.spi.connector.SaveMode.REPLACE;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.DoubleType.DOUBLE;
@@ -126,6 +129,8 @@
public class MongoMetadata
implements ConnectorMetadata
{
public static final String MERGE_ROW_ID_BASE_NAME = "_id";

private static final Logger log = Logger.get(MongoMetadata.class);
private static final Type TRINO_PAGE_SINK_ID_COLUMN_TYPE = BigintType.BIGINT;

@@ -544,10 +549,46 @@ private void finishInsert(
}
}

@Override
public RowChangeParadigm getRowChangeParadigm(ConnectorSession session, ConnectorTableHandle tableHandle)
{
return CHANGE_ONLY_UPDATED_COLUMNS;
}

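// MERGE builds on the INSERT path: beginMerge obtains its column handles, temporary
// table name, and page sink id column from beginInsert.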
@Override
public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, RetryMode retryMode)
{
MongoTableHandle table = (MongoTableHandle) tableHandle;
MongoInsertTableHandle insertTableHandle = (MongoInsertTableHandle) beginInsert(session, tableHandle, ImmutableList.of(), retryMode);
return new MongoMergeTableHandle(
insertTableHandle.remoteTableName(),
insertTableHandle.columns(),
(MongoColumnHandle) getMergeRowIdColumnHandle(session, tableHandle),
table.filter(),
table.constraint(),
insertTableHandle.temporaryTableName(),
insertTableHandle.pageSinkIdColumnName());
}

@Override
public void finishMerge(ConnectorSession session,
ConnectorMergeTableHandle mergeTableHandle,
List<ConnectorTableHandle> sourceTableHandles,
Collection<Slice> fragments,
Collection<ComputedStatistics> computedStatistics)
{
MongoMergeTableHandle tableHandle = (MongoMergeTableHandle) mergeTableHandle;
MongoInsertTableHandle insertTableHandle = new MongoInsertTableHandle(tableHandle.remoteTableName(), ImmutableList.copyOf(tableHandle.columns()), tableHandle.temporaryTableName(), tableHandle.pageSinkIdColumnName());
finishInsert(session, insertTableHandle, ImmutableList.of(), fragments, computedStatistics);
}

@Override
public ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle)
{
return new MongoColumnHandle("$merge_row_id", ImmutableList.of(), BIGINT, true, false, Optional.empty());
Map<String, ColumnHandle> columnHandles = getColumnHandles(session, tableHandle);
checkState(columnHandles.containsKey(MERGE_ROW_ID_BASE_NAME), "id column %s does not exist", MERGE_ROW_ID_BASE_NAME);
Type idColumnType = ((MongoColumnHandle) columnHandles.get(MERGE_ROW_ID_BASE_NAME)).type();
return new MongoColumnHandle(MERGE_ROW_ID_BASE_NAME, ImmutableList.of(), idColumnType, true, false, Optional.empty());
}

@Override