Commit

refactor: Declarative datawriter, options implementation (#23)
Anush008 committed Mar 24, 2024
1 parent a7ebd41 commit 53d1d2e
Showing 13 changed files with 160 additions and 250 deletions.
22 changes: 11 additions & 11 deletions README.md
@@ -2,31 +2,31 @@

[Apache Spark](https://spark.apache.org/) is a distributed computing framework designed for big data processing and analytics. This connector enables [Qdrant](https://qdrant.tech/) to be a storage destination in Spark.

## Installation 🚀
## Installation

> [!IMPORTANT]
> Requires Java 8 or above.
### GitHub Releases 📦
### GitHub Releases

The packaged `jar` file can be found [here](https://github.com/qdrant/qdrant-spark/releases).

### Building from source 🛠️
### Building from source

To build the `jar` from source, you need [JDK@8](https://www.azul.com/downloads/#zulu) and [Maven](https://maven.apache.org/) installed.
Once the requirements have been satisfied, run the following command in the project root. 🛠️
Once the requirements have been satisfied, run the following command in the project root.

```bash
mvn package
```

This will build and store the fat JAR in the `target` directory by default.

### Maven Central 📚
### Maven Central

For use with Java and Scala projects, the package can be found [here](https://central.sonatype.com/artifact/io.qdrant/spark).

## Usage 📝
## Usage

### Creating a Spark session (Single-node) with Qdrant support

@@ -42,7 +42,7 @@ spark = SparkSession.builder.config(
.getOrCreate()
```
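
For context, a minimal sketch of the full single-node session setup, assuming a locally built fat JAR (the path and app name below are placeholders, not the file's actual values):

```python
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    # Placeholder path to the connector fat JAR produced by `mvn package`.
    .config("spark.jars", "target/spark-VERSION.jar")
    .master("local[*]")
    .appName("qdrant")
    .getOrCreate()
)
```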

### Loading data 📊
### Loading data

> [!IMPORTANT]
> Before loading the data using this connector, a collection has to be [created](https://qdrant.tech/documentation/concepts/collections/#create-a-collection) in advance with the appropriate vector dimensions and configurations.
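
For illustration, the collection can be created with the Python client before the Spark job runs; this is a sketch, and the collection name, vector size, and distance below are assumptions to adapt:

```python
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

client = QdrantClient(url="http://localhost:6333")

# Placeholder collection name and vector parameters; match your embeddings.
client.create_collection(
    collection_name="qdrant_spark",
    vectors_config=VectorParams(size=384, distance=Distance.COSINE),
)
```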
@@ -191,11 +191,11 @@ You can use the connector as a library in Databricks to ingest data into Qdrant.

<img width="1064" alt="Screenshot 2024-01-05 at 17 20 01 (1)" src="https://github.com/qdrant/qdrant-spark/assets/46051506/d95773e0-c5c6-4ff2-bf50-8055bb08fd1b">

## Datatype support 📋
## Datatype support

Qdrant supports all the Spark data types. The appropriate types are mapped based on the provided `schema`.
The appropriate Spark data types are mapped to the Qdrant payload based on the provided `schema`.

## Options and Spark types 🛠️
## Options and Spark types

| Option | Description | Column DataType | Required |
| :--------------------------- | :------------------------------------------------------------------ | :---------------------------- | :------- |
@@ -215,6 +215,6 @@ Qdrant supports all the Spark data types. The appropriate types are mapped based
| `sparse_vector_names` | Comma-separated names of the sparse vectors in the collection. | - | ❌ |
| `shard_key_selector` | Comma-separated names of custom shard keys to use during upsert. | - | ❌ |
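
Putting the options together: given an existing DataFrame `df`, a minimal write needs only the three options that `Qdrant.java` validates as required (`schema`, `collection_name`, `qdrant_url`). A sketch, with the endpoint and collection name as placeholders:

```python
# Minimal write sketch; qdrant_url and collection_name are placeholders.
(
    df.write.format("io.qdrant.spark.Qdrant")
    .option("schema", df.schema.json())  # serialize the DataFrame schema as JSON
    .option("collection_name", "qdrant_spark")
    .option("qdrant_url", "http://localhost:6334")  # assumed gRPC endpoint
    .mode("append")
    .save()
)
```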

## LICENSE 📜
## LICENSE

Apache 2.0 © [2024](https://github.com/qdrant/qdrant-spark/blob/master/LICENSE)
2 changes: 0 additions & 2 deletions pom.xml
@@ -190,8 +190,6 @@
<configuration>
<includeStale>false</includeStale>
<style>GOOGLE</style>
<formatMain>true</formatMain>
<formatTest>true</formatTest>
<filterModified>false</filterModified>
<skip>false</skip>
<fixImports>true</fixImports>
36 changes: 15 additions & 21 deletions src/main/java/io/qdrant/spark/Qdrant.java
@@ -8,46 +8,40 @@
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;

/** A class that implements the TableProvider and DataSourceRegister interfaces. */
/** Qdrant datasource for Apache Spark. */
public class Qdrant implements TableProvider, DataSourceRegister {

private final String[] requiredFields = new String[] {"schema", "collection_name", "qdrant_url"};
private static final String[] REQUIRED_FIELDS = {"schema", "collection_name", "qdrant_url"};

/**
* Returns the short name of the data source.
*
* @return The short name of the data source.
*/
/** Returns the short name of the data source. */
@Override
public String shortName() {
return "qdrant";
}

/**
* Infers the schema of the data source based on the provided options.
* Validates and infers the schema from the provided options.
*
* @param options The options used to infer the schema.
* @return The inferred schema.
* @throws IllegalArgumentException if required options are missing.
*/
@Override
public StructType inferSchema(CaseInsensitiveStringMap options) {
for (String fieldName : requiredFields) {
if (!options.containsKey(fieldName)) {
throw new IllegalArgumentException(fieldName.concat(" option is required"));
validateOptions(options);
return (StructType) StructType.fromJson(options.get("schema"));
}

private void validateOptions(CaseInsensitiveStringMap options) {
for (String field : REQUIRED_FIELDS) {
if (!options.containsKey(field)) {
throw new IllegalArgumentException(String.format("%s option is required", field));
}
}
StructType schema = (StructType) StructType.fromJson(options.get("schema"));

return schema;
}

/**
* Returns a table for the data source based on the provided schema, partitioning, and properties.
* Creates a Qdrant table instance with validated options.
*
* @param schema The schema of the table.
* @param partitioning The partitioning of the table.
* @param properties The properties of the table.
* @return The table for the data source.
* @throws IllegalArgumentException if options are invalid.
*/
@Override
public Table getTable(
11 changes: 3 additions & 8 deletions src/main/java/io/qdrant/spark/QdrantBatchWriter.java
@@ -6,7 +6,7 @@
import org.apache.spark.sql.connector.write.WriterCommitMessage;
import org.apache.spark.sql.types.StructType;

/** QdrantBatchWriter class implements the BatchWrite interface. */
/** Qdrant batch writer for Apache Spark. */
public class QdrantBatchWriter implements BatchWrite {

private final QdrantOptions options;
@@ -23,13 +23,8 @@ public DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) {
}

@Override
public void commit(WriterCommitMessage[] messages) {
// TODO Auto-generated method stub

}
public void commit(WriterCommitMessage[] messages) {}

@Override
public void abort(WriterCommitMessage[] messages) {
// TODO Auto-generated method stub
}
public void abort(WriterCommitMessage[] messages) {}
}
16 changes: 6 additions & 10 deletions src/main/java/io/qdrant/spark/QdrantCluster.java
@@ -1,25 +1,21 @@
package io.qdrant.spark;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.EnumSet;
import java.util.Set;
import org.apache.spark.sql.connector.catalog.SupportsWrite;
import org.apache.spark.sql.connector.catalog.TableCapability;
import org.apache.spark.sql.connector.write.LogicalWriteInfo;
import org.apache.spark.sql.connector.write.WriteBuilder;
import org.apache.spark.sql.types.StructType;

/** QdrantCluster class implements the SupportsWrite interface. */
/** Qdrant cluster implementation supporting batch writes. */
public class QdrantCluster implements SupportsWrite {

private final StructType schema;
private final QdrantOptions options;

private static final Set<TableCapability> TABLE_CAPABILITY_SET =
Collections.unmodifiableSet(
new HashSet<>(
Arrays.asList(TableCapability.BATCH_WRITE, TableCapability.STREAMING_WRITE)));
private static final Set<TableCapability> CAPABILITIES = EnumSet.of(TableCapability.BATCH_WRITE);

public QdrantCluster(QdrantOptions options, StructType schema) {
this.options = options;
@@ -28,7 +24,7 @@ public QdrantCluster(QdrantOptions options, StructType schema) {

@Override
public WriteBuilder newWriteBuilder(LogicalWriteInfo info) {
return new QdrantWriteBuilder(this.options, this.schema);
return new QdrantWriteBuilder(options, schema);
}

@Override
@@ -38,11 +34,11 @@ public String name() {

@Override
public StructType schema() {
return this.schema;
return schema;
}

@Override
public Set<TableCapability> capabilities() {
return TABLE_CAPABILITY_SET;
return Collections.unmodifiableSet(CAPABILITIES);
}
}
92 changes: 43 additions & 49 deletions src/main/java/io/qdrant/spark/QdrantDataWriter.java
@@ -1,94 +1,88 @@
package io.qdrant.spark;

import io.qdrant.client.grpc.JsonWithInt.Value;
import io.qdrant.client.grpc.Points.PointId;
import io.qdrant.client.grpc.Points.PointStruct;
import io.qdrant.client.grpc.Points.Vectors;
import java.io.Serializable;
import java.net.URL;
import java.util.ArrayList;
import java.util.Map;
import java.util.List;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.write.DataWriter;
import org.apache.spark.sql.connector.write.WriterCommitMessage;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** A DataWriter implementation that writes data to Qdrant. */
/** DataWriter implementation for writing data to Qdrant. */
public class QdrantDataWriter implements DataWriter<InternalRow>, Serializable {

private static final Logger LOG = LoggerFactory.getLogger(QdrantDataWriter.class);

private final QdrantOptions options;
private final StructType schema;
private final String qdrantUrl;
private final String apiKey;
private final Logger LOG = LoggerFactory.getLogger(QdrantDataWriter.class);

private final ArrayList<PointStruct> points = new ArrayList<>();
private final List<PointStruct> pointsBuffer = new ArrayList<>();

public QdrantDataWriter(QdrantOptions options, StructType schema) {
this.options = options;
this.schema = schema;
this.qdrantUrl = options.qdrantUrl;
this.apiKey = options.apiKey;
}

@Override
public void write(InternalRow record) {
PointStruct.Builder pointBuilder = PointStruct.newBuilder();

PointId pointId = QdrantPointIdHandler.preparePointId(record, this.schema, this.options);
pointBuilder.setId(pointId);

Vectors vectors = QdrantVectorHandler.prepareVectors(record, this.schema, this.options);
pointBuilder.setVectors(vectors);

Map<String, Value> payload =
QdrantPayloadHandler.preparePayload(record, this.schema, this.options);
pointBuilder.putAllPayload(payload);

this.points.add(pointBuilder.build());
PointStruct point = createPointStruct(record);
pointsBuffer.add(point);

if (this.points.size() >= this.options.batchSize) {
this.write(this.options.retries);
if (pointsBuffer.size() >= options.batchSize) {
writeBatch(options.retries);
}
}

@Override
public WriterCommitMessage commit() {
this.write(this.options.retries);
return new WriterCommitMessage() {
@Override
public String toString() {
return "point committed to Qdrant";
}
};
private PointStruct createPointStruct(InternalRow record) {
PointStruct.Builder pointBuilder = PointStruct.newBuilder();
pointBuilder.setId(QdrantPointIdHandler.preparePointId(record, schema, options));
pointBuilder.setVectors(QdrantVectorHandler.prepareVectors(record, schema, options));
pointBuilder.putAllPayload(QdrantPayloadHandler.preparePayload(record, schema, options));
return pointBuilder.build();
}

public void write(int retries) {
LOG.info(
String.join(
"", "Uploading batch of ", Integer.toString(this.points.size()), " points to Qdrant"));

if (this.points.isEmpty()) {
private void writeBatch(int retries) {
if (pointsBuffer.isEmpty()) {
return;
}

try {
// Instantiate a new QdrantGrpc object to maintain serializability
QdrantGrpc qdrant = new QdrantGrpc(new URL(this.qdrantUrl), this.apiKey);
qdrant.upsert(this.options.collectionName, this.points, this.options.shardKeySelector);
qdrant.close();
this.points.clear();
doWriteBatch();
pointsBuffer.clear();
} catch (Exception e) {
LOG.error(String.join("", "Exception while uploading batch to Qdrant: ", e.getMessage()));
LOG.error("Exception while uploading batch to Qdrant: {}", e.getMessage());
if (retries > 0) {
LOG.info("Retrying upload batch to Qdrant");
write(retries - 1);
writeBatch(retries - 1);
} else {
throw new RuntimeException(e);
}
}
}

private void doWriteBatch() throws Exception {
LOG.info("Uploading batch of {} points to Qdrant", pointsBuffer.size());

// Instantiate QdrantGrpc client for each batch to maintain serializability
QdrantGrpc qdrant = new QdrantGrpc(new URL(options.qdrantUrl), options.apiKey);
qdrant.upsert(options.collectionName, pointsBuffer, options.shardKeySelector);
qdrant.close();
}

@Override
public WriterCommitMessage commit() {
writeBatch(options.retries);
return new WriterCommitMessage() {
@Override
public String toString() {
return "point committed to Qdrant";
}
};
}

@Override
public void abort() {}

17 changes: 4 additions & 13 deletions src/main/java/io/qdrant/spark/QdrantDataWriterFactory.java
@@ -4,35 +4,26 @@
import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory;
import org.apache.spark.sql.types.StructType;

/** Factory class for creating QdrantDataWriter instances for Spark Structured Streaming. */
/** Factory class for creating QdrantDataWriter instances for Spark data sources. */
public class QdrantDataWriterFactory implements StreamingDataWriterFactory, DataWriterFactory {

private final QdrantOptions options;
private final StructType schema;

/**
* Constructor for QdrantDataWriterFactory.
*
* @param options QdrantOptions instance containing configuration options for Qdrant.
* @param schema StructType instance containing schema information for the data being written.
*/
public QdrantDataWriterFactory(QdrantOptions options, StructType schema) {
this.options = options;
this.schema = schema;
}

@Override
public QdrantDataWriter createWriter(int partitionId, long taskId, long epochId) {
try {
return new QdrantDataWriter(this.options, this.schema);
} catch (Exception e) {
throw new RuntimeException(e);
}
return createWriter(partitionId, taskId);
}

@Override
public QdrantDataWriter createWriter(int partitionId, long taskId) {
try {
return new QdrantDataWriter(this.options, this.schema);
return new QdrantDataWriter(options, schema);
} catch (Exception e) {
throw new RuntimeException(e);
}