diff --git a/docs/src/main/sphinx/connector/mongodb.md b/docs/src/main/sphinx/connector/mongodb.md index ebdd4e219991a..83306597e2288 100644 --- a/docs/src/main/sphinx/connector/mongodb.md +++ b/docs/src/main/sphinx/connector/mongodb.md @@ -476,6 +476,8 @@ statements, the connector supports the following features: - {doc}`/sql/insert` - {doc}`/sql/delete` +- {doc}`/sql/update` +- {doc}`/sql/merge` - {doc}`/sql/create-table` - {doc}`/sql/create-table-as` - {doc}`/sql/drop-table` diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeSink.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeSink.java new file mode 100644 index 0000000000000..7148e502d6bc0 --- /dev/null +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeSink.java @@ -0,0 +1,283 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.mongodb; + +import com.google.common.collect.ImmutableList; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.BulkWriteOptions; +import com.mongodb.client.model.UpdateOneModel; +import com.mongodb.client.model.WriteModel; +import io.airlift.slice.Slice; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.connector.ConnectorMergeSink; +import io.trino.spi.connector.ConnectorPageSink; +import io.trino.spi.connector.ConnectorPageSinkId; +import org.bson.Document; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.stream.IntStream; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.mongodb.client.model.Filters.in; +import static io.trino.plugin.mongodb.MongoMetadata.MERGE_ROW_ID_BASE_NAME; +import static io.trino.spi.type.TinyintType.TINYINT; +import static java.util.Objects.requireNonNull; + +public class MongoMergeSink + implements ConnectorMergeSink +{ + private static final String SET = "$set"; + + private final int columnCount; + private final ConnectorPageSink insertSink; + private final ConnectorPageSink updateSink; + private final ConnectorPageSink deleteSink; + + public MongoMergeSink( + MongoSession mongoSession, + RemoteTableName remoteTableName, + List columns, + MongoColumnHandle mergeColumnHandle, + Optional deleteOutputTableHandle, + Optional updateOutputTableHandle, + String implicitPrefix, + Optional pageSinkIdColumnName, + ConnectorPageSinkId pageSinkId) + { + requireNonNull(mongoSession, "mongoSession is null"); + requireNonNull(remoteTableName, "remoteTableName is null"); + requireNonNull(columns, "columns is null"); + requireNonNull(pageSinkId, "pageSinkId is null"); + requireNonNull(mergeColumnHandle, "mergeColumnHandle is null"); + + this.columnCount = columns.size(); + + this.insertSink = new MongoPageSink(mongoSession, remoteTableName, columns, implicitPrefix, 
pageSinkIdColumnName, pageSinkId); + this.deleteSink = createDeleteSink(mongoSession, deleteOutputTableHandle, remoteTableName, mergeColumnHandle, implicitPrefix, pageSinkIdColumnName, pageSinkId); + + // Update should always include the id column explicitly + List<MongoColumnHandle> updateColumns = ImmutableList.<MongoColumnHandle>builderWithExpectedSize(columns.size() + 1) + .addAll(columns) + .add(mergeColumnHandle) + .build(); + this.updateSink = createUpdateSink(mongoSession, updateOutputTableHandle, updateColumns, remoteTableName, implicitPrefix, pageSinkIdColumnName, pageSinkId); + } + + private static ConnectorPageSink createDeleteSink( + MongoSession mongoSession, + Optional<MongoOutputTableHandle> outputTableHandle, + RemoteTableName remoteTableName, + MongoColumnHandle mergeColumnHandle, + String implicitPrefix, + Optional<String> pageSinkIdColumnName, + ConnectorPageSinkId pageSinkId) + { + if (outputTableHandle.isEmpty()) { + return new MongoDeleteSink(mongoSession, remoteTableName, ImmutableList.of(mergeColumnHandle), implicitPrefix, pageSinkIdColumnName, pageSinkId); + } + + MongoOutputTableHandle mongoOutputTableHandle = outputTableHandle.get(); + checkState(mongoOutputTableHandle.getTemporaryRemoteTableName().isPresent(), "temporary table does not exist"); + + return new MongoPageSink(mongoSession, + mongoOutputTableHandle.getTemporaryRemoteTableName().get(), + mongoOutputTableHandle.columns(), + implicitPrefix, + pageSinkIdColumnName, + pageSinkId); + } + + private static ConnectorPageSink createUpdateSink( + MongoSession mongoSession, + Optional<MongoOutputTableHandle> outputTableHandle, + List<MongoColumnHandle> columnHandles, + RemoteTableName remoteTableName, + String implicitPrefix, + Optional<String> pageSinkIdColumnName, + ConnectorPageSinkId pageSinkId) + { + if (outputTableHandle.isEmpty()) { + return new MongoUpdateSink(mongoSession, remoteTableName, columnHandles, implicitPrefix, pageSinkIdColumnName, pageSinkId); + } + + MongoOutputTableHandle mongoOutputTableHandle = outputTableHandle.get(); + checkState(mongoOutputTableHandle.getTemporaryRemoteTableName().isPresent(), "temporary table does not exist"); + + return new MongoPageSink( + mongoSession, + mongoOutputTableHandle.getTemporaryRemoteTableName().get(), + mongoOutputTableHandle.columns(), + implicitPrefix, + pageSinkIdColumnName, + pageSinkId); + } + + @Override + public void storeMergedRows(Page page) + { + checkArgument(page.getChannelCount() == 2 + columnCount, "The page size should be 2 + columnCount (%s), but is %s", columnCount, page.getChannelCount()); + int positionCount = page.getPositionCount(); + Block operationBlock = page.getBlock(columnCount); + Block rowId = page.getBlock(columnCount + 1); + + int[] dataChannel = IntStream.range(0, columnCount + 1).toArray(); + dataChannel[columnCount] = columnCount + 1; + Page dataPage = page.getColumns(dataChannel); + + int[] insertPositions = new int[positionCount]; + int insertPositionCount = 0; + int[] deletePositions = new int[positionCount]; + int deletePositionCount = 0; + int[] updatePositions = new int[positionCount]; + int updatePositionCount = 0; + + for (int position = 0; position < positionCount; position++) { + int operation = TINYINT.getByte(operationBlock, position); + switch (operation) { + case INSERT_OPERATION_NUMBER -> { + insertPositions[insertPositionCount] = position; + insertPositionCount++; + } + case DELETE_OPERATION_NUMBER -> { + deletePositions[deletePositionCount] = position; + deletePositionCount++; + } + case UPDATE_OPERATION_NUMBER -> { + updatePositions[updatePositionCount] = position; + updatePositionCount++; + } + default -> throw new
IllegalStateException("Unexpected value: " + operation); + } + } + + if (deletePositionCount > 0) { + Block positions = rowId.getPositions(deletePositions, 0, deletePositionCount); + deleteSink.appendPage(new Page(deletePositionCount, positions)); + } + + if (updatePositionCount > 0) { + updateSink.appendPage(dataPage.getPositions(updatePositions, 0, updatePositionCount)); + } + + if (insertPositionCount > 0) { + // Insert page should not include _id column by default, unless the insert columns include it explicitly + Page insertPage = dataPage.getColumns(Arrays.copyOf(dataChannel, columnCount)); + insertSink.appendPage(insertPage.getPositions(insertPositions, 0, insertPositionCount)); + } + } + + @Override + public CompletableFuture> finish() + { + return insertSink.finish(); + } + + private static class MongoUpdateSink + implements ConnectorPageSink + { + private final MongoPageSink delegate; + private final MongoSession mongoSession; + private final RemoteTableName remoteTableName; + + public MongoUpdateSink( + MongoSession mongoSession, + RemoteTableName remoteTableName, + List columns, + String implicitPrefix, + Optional pageSinkIdColumnName, + ConnectorPageSinkId pageSinkId) + { + this.delegate = new MongoPageSink(mongoSession, remoteTableName, columns, implicitPrefix, pageSinkIdColumnName, pageSinkId); + this.mongoSession = requireNonNull(mongoSession, "mongoSession is null"); + this.remoteTableName = requireNonNull(remoteTableName, "remoteTableName is null"); + } + + @Override + public CompletableFuture appendPage(Page page) + { + MongoCollection collection = mongoSession.getCollection(remoteTableName); + List> bulkWrites = new ArrayList<>(); + for (Document document : delegate.buildBatchDocumentsFromPage(page)) { + Document filter = new Document(MERGE_ROW_ID_BASE_NAME, document.get(MERGE_ROW_ID_BASE_NAME)); + bulkWrites.add(new UpdateOneModel<>(filter, new Document(SET, document))); + } + collection.bulkWrite(bulkWrites, new BulkWriteOptions().ordered(false)); + return NOT_BLOCKED; + } + + @Override + public CompletableFuture> finish() + { + return delegate.finish(); + } + + @Override + public void abort() + { + delegate.abort(); + } + } + + private static class MongoDeleteSink + implements ConnectorPageSink + { + private final MongoPageSink delegate; + private final MongoSession mongoSession; + private final RemoteTableName remoteTableName; + + public MongoDeleteSink( + MongoSession mongoSession, + RemoteTableName remoteTableName, + List columns, + String implicitPrefix, + Optional pageSinkIdColumnName, + ConnectorPageSinkId pageSinkId) + { + this.delegate = new MongoPageSink(mongoSession, remoteTableName, columns, implicitPrefix, pageSinkIdColumnName, pageSinkId); + this.mongoSession = requireNonNull(mongoSession, "mongoSession is null"); + this.remoteTableName = requireNonNull(remoteTableName, "remoteTableName is null"); + } + + @Override + public CompletableFuture appendPage(Page page) + { + MongoCollection collection = mongoSession.getCollection(remoteTableName); + ImmutableList.Builder idsToDeleteBuilder = ImmutableList.builder(); + for (Document document : delegate.buildBatchDocumentsFromPage(page)) { + idsToDeleteBuilder.add(document.get(MERGE_ROW_ID_BASE_NAME)); + } + collection.deleteMany(in(MERGE_ROW_ID_BASE_NAME, idsToDeleteBuilder.build())); + return NOT_BLOCKED; + } + + @Override + public CompletableFuture> finish() + { + return delegate.finish(); + } + + @Override + public void abort() + { + delegate.abort(); + } + } +} diff --git 
a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeTableHandle.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeTableHandle.java new file mode 100644 index 0000000000000..3eb15b96bd956 --- /dev/null +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMergeTableHandle.java @@ -0,0 +1,76 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.mongodb; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ConnectorMergeTableHandle; +import io.trino.spi.connector.ConnectorTableHandle; +import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.predicate.TupleDomain; + +import java.util.List; +import java.util.Optional; +import java.util.OptionalInt; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public record MongoMergeTableHandle( + RemoteTableName remoteTableName, + List columns, + MongoColumnHandle mergeRowIdColumn, + Optional filter, + TupleDomain constraint, + Optional deleteOutputTableHandle, + Optional updateOutputTableHandle, + Optional temporaryTableName, + Optional pageSinkIdColumnName) + implements ConnectorMergeTableHandle +{ + public MongoMergeTableHandle + { + requireNonNull(remoteTableName, "remoteTableName is null"); + columns = ImmutableList.copyOf(requireNonNull(columns, "columns is null")); + requireNonNull(filter, "filter is null"); + requireNonNull(mergeRowIdColumn, "mergeRowIdColumn is null"); + requireNonNull(constraint, "constraint is null"); + requireNonNull(deleteOutputTableHandle, "deleteOutputTableHandle is null"); + requireNonNull(updateOutputTableHandle, "updateOutputTableHandle is null"); + requireNonNull(temporaryTableName, "temporaryTableName is null"); + requireNonNull(pageSinkIdColumnName, "pageSinkIdColumnName is null"); + checkArgument(temporaryTableName.isPresent() == pageSinkIdColumnName.isPresent(), + "temporaryTableName.isPresent is not equal to pageSinkIdColumnName.isPresent"); + } + + @JsonIgnore + public Optional getTemporaryRemoteTableName() + { + return temporaryTableName.map(tableName -> new RemoteTableName(remoteTableName.databaseName(), tableName)); + } + + @Override + public ConnectorTableHandle getTableHandle() + { + return new MongoTableHandle( + new SchemaTableName(remoteTableName.databaseName(), remoteTableName.collectionName()), + remoteTableName, + filter, + constraint, + ImmutableSet.of(), + OptionalInt.empty()); + } +} diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java index e023e61e72110..46a418f9e35fa 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java +++ 
b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java @@ -19,6 +19,7 @@ import com.google.common.collect.Streams; import com.google.common.io.Closer; import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.MergeOptions; import io.airlift.log.Logger; import io.airlift.slice.Slice; import io.trino.plugin.base.projection.ApplyProjectionUtil; @@ -29,6 +30,7 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; import io.trino.spi.connector.ConnectorInsertTableHandle; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorMetadata; import io.trino.spi.connector.ConnectorOutputMetadata; import io.trino.spi.connector.ConnectorOutputTableHandle; @@ -46,6 +48,7 @@ import io.trino.spi.connector.ProjectionApplicationResult; import io.trino.spi.connector.RelationColumnsMetadata; import io.trino.spi.connector.RetryMode; +import io.trino.spi.connector.RowChangeParadigm; import io.trino.spi.connector.SaveMode; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; @@ -71,6 +74,7 @@ import org.bson.Document; import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.Iterator; @@ -85,16 +89,21 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Strings.isNullOrEmpty; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getLast; import static com.google.common.collect.MoreCollectors.onlyElement; +import static com.mongodb.client.model.Aggregates.addFields; import static com.mongodb.client.model.Aggregates.lookup; import static com.mongodb.client.model.Aggregates.match; import static com.mongodb.client.model.Aggregates.merge; import static com.mongodb.client.model.Aggregates.project; +import static com.mongodb.client.model.Filters.in; import static com.mongodb.client.model.Filters.ne; import static com.mongodb.client.model.Projections.exclude; +import static com.mongodb.client.model.Projections.fields; import static io.trino.plugin.base.TemporaryTables.generateTemporaryTableName; import static io.trino.plugin.base.projection.ApplyProjectionUtil.ProjectedColumnRepresentation; import static io.trino.plugin.base.projection.ApplyProjectionUtil.extractSupportedProjectedColumns; @@ -107,6 +116,7 @@ import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.connector.RelationColumnsMetadata.forTable; +import static io.trino.spi.connector.RowChangeParadigm.CHANGE_ONLY_UPDATED_COLUMNS; import static io.trino.spi.connector.SaveMode.REPLACE; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DoubleType.DOUBLE; @@ -126,6 +136,8 @@ public class MongoMetadata implements ConnectorMetadata { + public static final String MERGE_ROW_ID_BASE_NAME = "_id"; + private static final Logger log = Logger.get(MongoMetadata.class); private static final Type TRINO_PAGE_SINK_ID_COLUMN_TYPE = BigintType.BIGINT; @@ -133,7 +145,9 @@ public class MongoMetadata private final MongoSession mongoSession; - private final AtomicReference rollbackAction = new 
AtomicReference<>(); + private final AtomicReference insertRollbackAction = new AtomicReference<>(); + private final AtomicReference deleteRollbackAction = new AtomicReference<>(); + private final AtomicReference updateRollBackAction = new AtomicReference<>(); public MongoMetadata(MongoSession mongoSession) { @@ -406,7 +420,7 @@ public ConnectorOutputTableHandle beginCreateTable( Closer closer = Closer.create(); closer.register(() -> mongoSession.dropTable(remoteTableName)); - setRollback(() -> { + setInsertRollback(() -> { try { closer.close(); } @@ -446,7 +460,7 @@ public Optional finishCreateTable(ConnectorSession sess if (handle.temporaryTableName().isPresent()) { finishInsert(session, handle.remoteTableName(), handle.getTemporaryRemoteTableName().get(), handle.pageSinkIdColumnName().get(), fragments); } - clearRollback(); + clearInsertRollback(); return Optional.empty(); } @@ -477,7 +491,7 @@ public ConnectorInsertTableHandle beginInsert(ConnectorSession session, Connecto RemoteTableName temporaryTable = new RemoteTableName(handle.schemaTableName().getSchemaName(), generateTemporaryTableName(session)); mongoSession.createTable(temporaryTable, allColumns, Optional.empty()); - setRollback(() -> mongoSession.dropTable(temporaryTable)); + setInsertRollback(() -> mongoSession.dropTable(temporaryTable)); return new MongoInsertTableHandle( handle.remoteTableName(), @@ -498,7 +512,7 @@ public Optional finishInsert( if (handle.temporaryTableName().isPresent()) { finishInsert(session, handle.remoteTableName(), handle.getTemporaryRemoteTableName().get(), handle.pageSinkIdColumnName().get(), fragments); } - clearRollback(); + clearInsertRollback(); return Optional.empty(); } @@ -508,30 +522,173 @@ private void finishInsert( RemoteTableName temporaryTable, String pageSinkIdColumnName, Collection fragments) + { + try (Closer closer = Closer.create()) { + closer.register(() -> mongoSession.dropTable(temporaryTable)); + RemoteTableName pageSinkIdsTable = getPageSinkIdsTable(session, temporaryTable, pageSinkIdColumnName, fragments, closer); + + MongoCollection temporaryCollection = mongoSession.getCollection(temporaryTable); + temporaryCollection.aggregate(ImmutableList.of( + lookup(pageSinkIdsTable.collectionName(), pageSinkIdColumnName, pageSinkIdColumnName, "page_sink_id"), + match(ne("page_sink_id", ImmutableList.of())), + project(exclude("page_sink_id")), + merge(targetTable.collectionName()))) + .toCollection(); + } + catch (IOException e) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, e); + } + } + + private RemoteTableName getPageSinkIdsTable(ConnectorSession session, RemoteTableName temporaryTable, String pageSinkIdColumnName, Collection fragments, Closer closer) + { + // Create the temporary page sink ID table + RemoteTableName pageSinkIdsTable = new RemoteTableName(temporaryTable.databaseName(), generateTemporaryTableName(session)); + MongoColumnHandle pageSinkIdColumn = new MongoColumnHandle(pageSinkIdColumnName, ImmutableList.of(), TRINO_PAGE_SINK_ID_COLUMN_TYPE, false, false, Optional.empty()); + mongoSession.createTable(pageSinkIdsTable, ImmutableList.of(pageSinkIdColumn), Optional.empty()); + closer.register(() -> mongoSession.dropTable(pageSinkIdsTable)); + + // Insert all the page sink IDs into the page sink ID table + MongoCollection pageSinkIdsCollection = mongoSession.getCollection(pageSinkIdsTable); + List pageSinkIds = fragments.stream() + .map(slice -> new Document(pageSinkIdColumnName, slice.getLong(0))) + .collect(toImmutableList()); + 
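+ // Each fragment is the page sink ID reported by a committed writer; the calling finish* methods $lookup this table against the temporary data table and keep only rows whose page sink ID matches, so rows written by failed or retried tasks are dropped.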
pageSinkIdsCollection.insertMany(pageSinkIds); + return pageSinkIdsTable; + } + + private Optional beginDeleteForMerge(ConnectorSession session, MongoTableHandle tableHandle, RetryMode retryMode) + { + if (retryMode != RetryMode.RETRIES_ENABLED) { + return Optional.empty(); + } + + MongoColumnHandle rowIdColumn = (MongoColumnHandle) getMergeRowIdColumnHandle(session, tableHandle); + + String rowIdTmpValueColumnName = getNonDuplicatedColumnName(MERGE_ROW_ID_BASE_NAME, ImmutableSet.of(MERGE_ROW_ID_BASE_NAME)); + MongoColumnHandle pageSinkIdColumn = buildPageSinkIdColumn(ImmutableSet.of(MERGE_ROW_ID_BASE_NAME, rowIdTmpValueColumnName)); + MongoColumnHandle rowIdTmpValueColumn = new MongoColumnHandle(rowIdTmpValueColumnName, ImmutableList.of(), rowIdColumn.type(), false, false, Optional.empty()); + List allColumns = ImmutableList.builder() + .add(rowIdTmpValueColumn) + .add(pageSinkIdColumn) + .build(); + + RemoteTableName temporaryTable = new RemoteTableName(tableHandle.schemaTableName().getSchemaName(), generateTemporaryTableName(session)); + mongoSession.createTable(temporaryTable, allColumns, Optional.empty()); + + setDeleteRollback(() -> mongoSession.dropTable(temporaryTable)); + + return Optional.of(new MongoOutputTableHandle( + tableHandle.remoteTableName(), + ImmutableList.of(rowIdTmpValueColumn), + Optional.of(temporaryTable.collectionName()), + Optional.of(pageSinkIdColumn.baseName()))); + } + + private void finishDeleteForMerge( + ConnectorSession session, + RemoteTableName targetTable, + RemoteTableName temporaryTable, + String pageSinkIdColumnName, + List columnHandles, + Collection fragments) + { + String rowIdName = getLast(columnHandles).baseName(); + checkArgument(!isNullOrEmpty(rowIdName) && rowIdName.startsWith(MERGE_ROW_ID_BASE_NAME), "Error rowId name: " + rowIdName); + + try (Closer closer = Closer.create()) { + closer.register(() -> mongoSession.dropTable(temporaryTable)); + RemoteTableName pageSinkIdsTable = getPageSinkIdsTable(session, temporaryTable, pageSinkIdColumnName, fragments, closer); + + MongoCollection temporaryCollection = mongoSession.getCollection(temporaryTable); + List idsToDelete = new ArrayList<>(); + temporaryCollection.aggregate(ImmutableList.of( + lookup(pageSinkIdsTable.collectionName(), pageSinkIdColumnName, pageSinkIdColumnName, "page_sink_id"), + match(ne("page_sink_id", ImmutableList.of())))) + .map(document -> document.get(rowIdName)).into(idsToDelete); + + MongoCollection targetCollection = mongoSession.getCollection(targetTable); + targetCollection.deleteMany(in(MERGE_ROW_ID_BASE_NAME, idsToDelete)); + } + catch (IOException e) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, e); + } + } + + private Optional beginUpdateForMerge(ConnectorSession session, MongoTableHandle tableHandle, RetryMode retryMode) + { + if (retryMode != RetryMode.RETRIES_ENABLED) { + return Optional.empty(); + } + + MongoColumnHandle rowIdColumn = (MongoColumnHandle) getMergeRowIdColumnHandle(session, tableHandle); + + MongoTable table = mongoSession.getTable(tableHandle.schemaTableName()); + List columns = table.columns(); + + Set allColumnNames = columns.stream() + .map(MongoColumnHandle::baseName) + .collect(toImmutableSet()); + + MongoColumnHandle pageSinkIdColumn = buildPageSinkIdColumn(allColumnNames); + MongoColumnHandle rowIdValueColumn = new MongoColumnHandle(getNonDuplicatedColumnName(MERGE_ROW_ID_BASE_NAME, allColumnNames), ImmutableList.of(), rowIdColumn.type(), false, false, Optional.empty()); + + List allColumns = 
ImmutableList.builderWithExpectedSize(columns.size() + 2) + .addAll(columns) + .add(rowIdValueColumn) + .add(pageSinkIdColumn) + .build(); + RemoteTableName temporaryTable = new RemoteTableName(tableHandle.schemaTableName().getSchemaName(), generateTemporaryTableName(session)); + mongoSession.createTable(temporaryTable, allColumns, Optional.empty()); + + setUpdateRollBack(() -> mongoSession.dropTable(temporaryTable)); + + List nonHiddenColumns = columns.stream() + .filter(column -> !column.hidden()) + .peek(column -> validateColumnNameForInsert(column.baseName())) + .collect(toImmutableList()); + + List handleColumns = ImmutableList.builderWithExpectedSize(nonHiddenColumns.size() + 1) + .addAll(nonHiddenColumns) + .add(rowIdValueColumn) + .build(); + return Optional.of(new MongoOutputTableHandle( + tableHandle.remoteTableName(), + handleColumns, + Optional.of(temporaryTable.collectionName()), + Optional.of(pageSinkIdColumn.baseName()))); + } + + private void finishUpdateForMerge( + ConnectorSession session, + RemoteTableName targetTable, + RemoteTableName temporaryTable, + String pageSinkIdColumnName, + List columnHandles, + Collection fragments) { Closer closer = Closer.create(); closer.register(() -> mongoSession.dropTable(temporaryTable)); - try { - // Create the temporary page sink ID table - RemoteTableName pageSinkIdsTable = new RemoteTableName(temporaryTable.databaseName(), generateTemporaryTableName(session)); - MongoColumnHandle pageSinkIdColumn = new MongoColumnHandle(pageSinkIdColumnName, ImmutableList.of(), TRINO_PAGE_SINK_ID_COLUMN_TYPE, false, false, Optional.empty()); - mongoSession.createTable(pageSinkIdsTable, ImmutableList.of(pageSinkIdColumn), Optional.empty()); - closer.register(() -> mongoSession.dropTable(pageSinkIdsTable)); - - // Insert all the page sink IDs into the page sink ID table - MongoCollection pageSinkIdsCollection = mongoSession.getCollection(pageSinkIdsTable); - List pageSinkIds = fragments.stream() - .map(slice -> new Document(pageSinkIdColumnName, slice.getLong(0))) - .collect(toImmutableList()); - pageSinkIdsCollection.insertMany(pageSinkIds); + String rowIdName = getLast(columnHandles).baseName(); + checkArgument(!isNullOrEmpty(rowIdName) && rowIdName.startsWith(MERGE_ROW_ID_BASE_NAME), "Error rowId name: " + rowIdName); + try { MongoCollection temporaryCollection = mongoSession.getCollection(temporaryTable); + + RemoteTableName pageSinkIdsTable = getPageSinkIdsTable(session, temporaryTable, pageSinkIdColumnName, fragments, closer); + MergeOptions mergeOptions = new MergeOptions() + .whenMatched(MergeOptions.WhenMatched.REPLACE) + .whenNotMatched(MergeOptions.WhenNotMatched.FAIL) + .uniqueIdentifier(MERGE_ROW_ID_BASE_NAME); + temporaryCollection.aggregate(ImmutableList.of( lookup(pageSinkIdsTable.collectionName(), pageSinkIdColumnName, pageSinkIdColumnName, "page_sink_id"), match(ne("page_sink_id", ImmutableList.of())), - project(exclude("page_sink_id")), - merge(targetTable.collectionName()))) + // Replace the unique key with the rowIdName reference value + addFields(new com.mongodb.client.model.Field<>(MERGE_ROW_ID_BASE_NAME, "$" + rowIdName)), + project(fields(exclude(rowIdName), exclude("page_sink_id"))), + merge(targetTable.collectionName(), mergeOptions))) .toCollection(); } finally { @@ -544,10 +701,73 @@ private void finishInsert( } } + @Override + public RowChangeParadigm getRowChangeParadigm(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return CHANGE_ONLY_UPDATED_COLUMNS; + } + + @Override + public 
ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, RetryMode retryMode) + { + MongoTableHandle table = (MongoTableHandle) tableHandle; + MongoInsertTableHandle insertTableHandle = (MongoInsertTableHandle) beginInsert(session, tableHandle, ImmutableList.of(), retryMode); + return new MongoMergeTableHandle( + insertTableHandle.remoteTableName(), + insertTableHandle.columns(), + (MongoColumnHandle) getMergeRowIdColumnHandle(session, tableHandle), + table.filter(), + table.constraint(), + beginDeleteForMerge(session, table, retryMode), + beginUpdateForMerge(session, table, retryMode), + insertTableHandle.temporaryTableName(), + insertTableHandle.pageSinkIdColumnName()); + } + + @Override + public void finishMerge( + ConnectorSession session, + ConnectorMergeTableHandle mergeTableHandle, + List<ConnectorTableHandle> sourceTableHandles, + Collection<Slice> fragments, + Collection<ComputedStatistics> computedStatistics) + { + MongoMergeTableHandle tableHandle = (MongoMergeTableHandle) mergeTableHandle; + MongoInsertTableHandle insertTableHandle = new MongoInsertTableHandle( + tableHandle.remoteTableName(), + ImmutableList.copyOf(tableHandle.columns()), + tableHandle.temporaryTableName(), + tableHandle.pageSinkIdColumnName()); + finishInsert(session, insertTableHandle, ImmutableList.of(), fragments, computedStatistics); + + tableHandle.deleteOutputTableHandle().ifPresent(deleteOutputHandle -> + finishDeleteForMerge( + session, + deleteOutputHandle.remoteTableName(), + deleteOutputHandle.getTemporaryRemoteTableName().get(), + deleteOutputHandle.pageSinkIdColumnName().get(), + deleteOutputHandle.columns(), + fragments)); + clearDeleteRollback(); + + tableHandle.updateOutputTableHandle().ifPresent(updateOutputHandle -> + finishUpdateForMerge( + session, + updateOutputHandle.remoteTableName(), + updateOutputHandle.getTemporaryRemoteTableName().get(), + updateOutputHandle.pageSinkIdColumnName().get(), + updateOutputHandle.columns(), + fragments)); + clearUpdateRollback(); + } + @Override public ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle) { - return new MongoColumnHandle("$merge_row_id", ImmutableList.of(), BIGINT, true, false, Optional.empty()); + Map<String, ColumnHandle> columnHandles = getColumnHandles(session, tableHandle); + checkState(columnHandles.containsKey(MERGE_ROW_ID_BASE_NAME), "id column %s does not exist", MERGE_ROW_ID_BASE_NAME); + Type idColumnType = ((MongoColumnHandle) columnHandles.get(MERGE_ROW_ID_BASE_NAME)).type(); + return new MongoColumnHandle(MERGE_ROW_ID_BASE_NAME, ImmutableList.of(), idColumnType, true, false, Optional.empty()); } @Override @@ -834,19 +1054,41 @@ public Optional> applyTable return Optional.of(new TableFunctionApplicationResult<>(tableHandle, columnHandles)); } - private void setRollback(Runnable action) + private void setInsertRollback(Runnable action) { - checkState(rollbackAction.compareAndSet(null, action), "rollback action is already set"); + checkState(insertRollbackAction.compareAndSet(null, action), "insertRollbackAction action is already set"); } - private void clearRollback() + private void clearInsertRollback() { - rollbackAction.set(null); + insertRollbackAction.set(null); + } + + private void setDeleteRollback(Runnable action) + { + checkState(deleteRollbackAction.compareAndSet(null, action), "deleteRollbackAction action is already set"); + } + + private void clearDeleteRollback() + { + deleteRollbackAction.set(null); + } + + private void setUpdateRollBack(Runnable action) + { +
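+ // As with the insert and delete rollback actions above, at most one update rollback action can be registered per transaction; rollback() runs whichever actions were set and drops the corresponding temporary tables.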
checkState(updateRollBackAction.compareAndSet(null, action), "updateRollBackAction action is already set"); + } + + private void clearUpdateRollback() + { + updateRollBackAction.set(null); } public void rollback() { - Optional.ofNullable(rollbackAction.getAndSet(null)).ifPresent(Runnable::run); + Optional.ofNullable(insertRollbackAction.getAndSet(null)).ifPresent(Runnable::run); + Optional.ofNullable(deleteRollbackAction.getAndSet(null)).ifPresent(Runnable::run); + Optional.ofNullable(updateRollBackAction.getAndSet(null)).ifPresent(Runnable::run); } private static SchemaTableName getTableName(ConnectorTableHandle tableHandle) @@ -880,16 +1122,20 @@ private static void validateColumnNameForInsert(String columnName) } private static MongoColumnHandle buildPageSinkIdColumn(Set otherColumnNames) + { + return new MongoColumnHandle(getNonDuplicatedColumnName("trino_page_sink_id", otherColumnNames), ImmutableList.of(), TRINO_PAGE_SINK_ID_COLUMN_TYPE, false, false, Optional.empty()); + } + + private static String getNonDuplicatedColumnName(String baseColumnName, Set otherColumnNames) { // While it's unlikely this column name will collide with client table columns, // guarantee it will not by appending a deterministic suffix to it. - String baseColumnName = "trino_page_sink_id"; String columnName = baseColumnName; int suffix = 1; while (otherColumnNames.contains(columnName)) { columnName = baseColumnName + "_" + suffix; suffix++; } - return new MongoColumnHandle(columnName, ImmutableList.of(), TRINO_PAGE_SINK_ID_COLUMN_TYPE, false, false, Optional.empty()); + return columnName; } } diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java index 06496aac7b09c..7d21e426c2b83 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSink.java @@ -110,6 +110,12 @@ public MongoPageSink( public CompletableFuture appendPage(Page page) { MongoCollection collection = mongoSession.getCollection(remoteTableName); + collection.insertMany(buildBatchDocumentsFromPage(page), new InsertManyOptions().ordered(true)); + return NOT_BLOCKED; + } + + protected List buildBatchDocumentsFromPage(Page page) + { List batch = new ArrayList<>(page.getPositionCount()); for (int position = 0; position < page.getPositionCount(); position++) { @@ -122,9 +128,7 @@ public CompletableFuture appendPage(Page page) } batch.add(doc); } - - collection.insertMany(batch, new InsertManyOptions().ordered(true)); - return NOT_BLOCKED; + return batch; } private Object getObjectValue(Type type, Block block, int position) diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSinkProvider.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSinkProvider.java index 7ebf8667b4de0..a7cab892af474 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSinkProvider.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSinkProvider.java @@ -15,6 +15,8 @@ import com.google.inject.Inject; import io.trino.spi.connector.ConnectorInsertTableHandle; +import io.trino.spi.connector.ConnectorMergeSink; +import io.trino.spi.connector.ConnectorMergeTableHandle; import io.trino.spi.connector.ConnectorOutputTableHandle; import io.trino.spi.connector.ConnectorPageSink; import io.trino.spi.connector.ConnectorPageSinkId; @@ -50,4 +52,20 @@ public 
ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa MongoInsertTableHandle handle = (MongoInsertTableHandle) insertTableHandle; return new MongoPageSink(mongoSession, handle.getTemporaryRemoteTableName().orElseGet(handle::remoteTableName), handle.columns(), implicitPrefix, handle.pageSinkIdColumnName(), pageSinkId); } + + @Override + public ConnectorMergeSink createMergeSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorMergeTableHandle mergeHandle, ConnectorPageSinkId pageSinkId) + { + MongoMergeTableHandle handle = (MongoMergeTableHandle) mergeHandle; + return new MongoMergeSink( + mongoSession, + handle.getTemporaryRemoteTableName().orElseGet(handle::remoteTableName), + handle.columns(), + handle.mergeRowIdColumn(), + handle.deleteOutputTableHandle(), + handle.updateOutputTableHandle(), + implicitPrefix, + handle.pageSinkIdColumnName(), + pageSinkId); + } } diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoConnectorSmokeTest.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoConnectorSmokeTest.java index f39404a42ea08..27ad160ab8fcd 100644 --- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoConnectorSmokeTest.java +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoConnectorSmokeTest.java @@ -30,11 +30,9 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) return switch (connectorBehavior) { case SUPPORTS_CREATE_MATERIALIZED_VIEW, SUPPORTS_CREATE_VIEW, - SUPPORTS_MERGE, SUPPORTS_NOT_NULL_CONSTRAINT, SUPPORTS_RENAME_SCHEMA, - SUPPORTS_TRUNCATE, - SUPPORTS_UPDATE -> false; + SUPPORTS_TRUNCATE -> false; default -> super.hasBehavior(connectorBehavior); }; } diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoFailureRecoveryTest.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoFailureRecoveryTest.java index 8b1d9ca99d1b4..a1b36d34f0e8c 100644 --- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoFailureRecoveryTest.java +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/BaseMongoFailureRecoveryTest.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; +import java.util.Optional; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.Assumptions.abort; @@ -69,24 +70,16 @@ protected void testAnalyzeTable() @Override protected void testDelete() { - assertThatThrownBy(super::testDeleteWithSubquery).hasMessageContaining("This connector does not support modifying table rows"); - abort("skipped"); - } - - @Test - @Override - protected void testDeleteWithSubquery() - { - assertThatThrownBy(super::testDeleteWithSubquery).hasMessageContaining("This connector does not support modifying table rows"); - abort("skipped"); - } + // This simple delete on Mongo ends up as a very simple, single-fragment, coordinator-only plan, + // which has no ability to recover from errors. This test simply verifies that's still the case. + Optional setupQuery = Optional.of("CREATE TABLE AS SELECT * FROM orders"); + String testQuery = "DELETE FROM
<table> WHERE orderkey = 1"; + Optional<String> cleanupQuery = Optional.of("DROP TABLE <table>
"); - @Test - @Override - protected void testMerge() - { - assertThatThrownBy(super::testMerge).hasMessageContaining("This connector does not support modifying table rows"); - abort("skipped"); + assertThatQuery(testQuery) + .withSetupQuery(setupQuery) + .withCleanupQuery(cleanupQuery) + .isCoordinatorOnly(); } @Test @@ -98,22 +91,6 @@ protected void testRefreshMaterializedView() abort("skipped"); } - @Test - @Override - protected void testUpdate() - { - assertThatThrownBy(super::testUpdate).hasMessageContaining("This connector does not support modifying table rows"); - abort("skipped"); - } - - @Test - @Override - protected void testUpdateWithSubquery() - { - assertThatThrownBy(super::testUpdateWithSubquery).hasMessageContaining("This connector does not support modifying table rows"); - abort("skipped"); - } - @Override protected boolean areWriteRetriesSupported() { diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java index fd85c4abb8a64..d1fd7b8cb979b 100644 --- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java @@ -53,7 +53,6 @@ import static com.mongodb.client.model.CollationStrength.PRIMARY; import static io.trino.plugin.mongodb.MongoQueryRunner.createMongoClient; import static io.trino.plugin.mongodb.TypeUtils.isPushdownSupportedType; -import static io.trino.spi.connector.ConnectorMetadata.MODIFYING_ROWS_MESSAGE; import static io.trino.testing.TestingNames.randomNameSuffix; import static java.lang.String.format; import static java.nio.charset.StandardCharsets.UTF_8; @@ -103,13 +102,11 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) SUPPORTS_CREATE_MATERIALIZED_VIEW, SUPPORTS_CREATE_VIEW, SUPPORTS_DROP_FIELD, - SUPPORTS_MERGE, SUPPORTS_NOT_NULL_CONSTRAINT, SUPPORTS_RENAME_FIELD, SUPPORTS_RENAME_SCHEMA, SUPPORTS_SET_FIELD_TYPE, - SUPPORTS_TRUNCATE, - SUPPORTS_UPDATE -> false; + SUPPORTS_TRUNCATE -> false; default -> super.hasBehavior(connectorBehavior); }; } @@ -120,6 +117,29 @@ protected TestTable createTableWithDefaultColumns() return abort("MongoDB connector does not support column default values"); } + @Test + public void testMergeWithCustomizeTypeIdColumn() + { + String targetTable = "merge_with_customize_type_id_column_target_" + randomNameSuffix(); + String sourceTable = "merget_with_customize_type_id_column_source_" + randomNameSuffix(); + assertUpdate(createTableForWrites("CREATE TABLE %s (_id INT, customer VARCHAR, purchases INT, address VARCHAR)".formatted(targetTable))); + + assertUpdate("INSERT INTO %s (_id, customer, purchases, address) VALUES (1, 'Aaron', 5, 'Antioch'), (2, 'Bill', 7, 'Buena'), (3, 'Carol', 3, 'Cambridge'), (4, 'Dave', 11, 'Devon')".formatted(targetTable), 4); + + assertUpdate(createTableForWrites("CREATE TABLE %s (_id INT, customer VARCHAR, purchases INT, address VARCHAR)".formatted(sourceTable))); + + assertUpdate("INSERT INTO %s (_id, customer, purchases, address) VALUES (1, 'Aaron', 6, 'Arches'), (3, 'Ed', 7, 'Etherville'), (4, 'Carol', 9, 'Centreville'), (7, 'Dave', 11, 'Darbyshire')".formatted(sourceTable), 4); + + assertUpdate("MERGE INTO %s t USING %s s ON (t._id = s._id)".formatted(targetTable, sourceTable) + + " WHEN MATCHED AND s.address = 'Centreville' THEN DELETE" + + " WHEN MATCHED THEN UPDATE SET purchases = s.purchases + t.purchases, address = 
s.address" + + " WHEN NOT MATCHED THEN INSERT (_id, customer, purchases, address) VALUES(s._id, s.customer, s.purchases, s.address)", 4); + assertQuery("SELECT _id, customer, purchases, address FROM " + targetTable, "VALUES (1, 'Aaron', 11, 'Arches'), (2, 'Bill', 7, 'Buena'), (3, 'Carol', 10, 'Etherville'), (7, 'Dave', 11, 'Darbyshire')"); + + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + @Test @Override public void testColumnName() @@ -295,46 +315,6 @@ public void testInsertWithEveryType() assertThat(getQueryRunner().tableExists(getSession(), tableName)).isFalse(); } - @Test - @Override - public void testDeleteWithComplexPredicate() - { - assertThatThrownBy(super::testDeleteWithComplexPredicate) - .hasStackTraceContaining("TrinoException: " + MODIFYING_ROWS_MESSAGE); - } - - @Test - @Override - public void testDeleteWithLike() - { - assertThatThrownBy(super::testDeleteWithLike) - .hasStackTraceContaining("TrinoException: " + MODIFYING_ROWS_MESSAGE); - } - - @Test - @Override - public void testDeleteWithSemiJoin() - { - assertThatThrownBy(super::testDeleteWithSemiJoin) - .hasStackTraceContaining("TrinoException: " + MODIFYING_ROWS_MESSAGE); - } - - @Test - @Override - public void testDeleteWithSubquery() - { - assertThatThrownBy(super::testDeleteWithSubquery) - .hasStackTraceContaining("TrinoException: " + MODIFYING_ROWS_MESSAGE); - } - - @Test - @Override - public void testExplainAnalyzeWithDeleteWithSubquery() - { - assertThatThrownBy(super::testExplainAnalyzeWithDeleteWithSubquery) - .hasStackTraceContaining("TrinoException: " + MODIFYING_ROWS_MESSAGE); - } - @Test public void testPredicatePushdown() { @@ -1828,6 +1808,13 @@ public void testProjectionPushdownPhysicalInputSize() abort("MongoDB connector does not calculate physical data input size"); } + @Test + @Override + public void testUpdateRowConcurrently() + { + abort("Mongo doesn't support concurrent update of different columns in a row"); + } + @Override protected OptionalInt maxSchemaNameLength() {