Support Spark-2.1 version. (#33)
petro-rudenko authored Jul 14, 2021
1 parent f582715 commit c04e36b
Showing 8 changed files with 507 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/sparkucx-ci.yml
@@ -9,7 +9,7 @@ jobs:
build-sparkucx:
strategy:
matrix:
-      spark_version: ["2.4", "3.0"]
+      spark_version: ["2.1", "2.4", "3.0"]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
2 changes: 1 addition & 1 deletion .github/workflows/sparkucx-release.yml
@@ -13,7 +13,7 @@ jobs:
release:
strategy:
matrix:
-      spark_version: ["2.4", "3.0"]
+      spark_version: ["2.1", "2.4", "3.0"]
runs-on: ubuntu-latest
steps:
- name: Checkout code
103 changes: 88 additions & 15 deletions pom.xml
@@ -5,8 +5,8 @@ See file LICENSE for terms.
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
- xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
- xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.openucx</groupId>
@@ -34,12 +34,68 @@ See file LICENSE for terms.
</properties>

<profiles>
<profile>
<id>spark-2.1</id>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<excludes>
<exclude>**/spark_3_0/**</exclude>
<exclude>**/spark_2_4/**</exclude>
</excludes>
</configuration>
</plugin>
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<configuration>
<excludes>
<exclude>**/spark_3_0/**</exclude>
<exclude>**/spark_2_4/**</exclude>
</excludes>
</configuration>
</plugin>
</plugins>
</build>
<properties>
<spark.version>2.1.0</spark.version>
<sonar.exclusions>**/spark_3_0/**, **/spark_2_4/**</sonar.exclusions>
<scala.version>2.11.12</scala.version>
<scala.compat.version>2.11</scala.compat.version>
</properties>
</profile>
<profile>
<id>spark-2.4</id>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<excludes>
<exclude>**/spark_3_0/**</exclude>
<exclude>**/spark_2_1/**</exclude>
</excludes>
</configuration>
</plugin>
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<configuration>
<excludes>
<exclude>**/spark_2_1/**</exclude>
<exclude>**/spark_3_0/**</exclude>
</excludes>
</configuration>
</plugin>
</plugins>
</build>
<properties>
<spark.version>2.4.0</spark.version>
- <project.excludes>**/spark_3_0/**</project.excludes>
- <sonar.exclusions>**/spark_3_0/**</sonar.exclusions>
+ <sonar.exclusions>**/spark_3_0/**, **/spark_2_1/**</sonar.exclusions>
<scala.version>2.11.12</scala.version>
<scala.compat.version>2.11</scala.compat.version>
</properties>
@@ -49,12 +105,35 @@ See file LICENSE for terms.
<activation>
<activeByDefault>true</activeByDefault>
</activation>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<excludes>
<exclude>**/spark_2_1/**</exclude>
<exclude>**/spark_2_4/**</exclude>
</excludes>
</configuration>
</plugin>
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<configuration>
<excludes>
<exclude>**/spark_2_1/**</exclude>
<exclude>**/spark_2_4/**</exclude>
</excludes>
</configuration>
</plugin>
</plugins>
</build>
<properties>
<spark.version>3.0.1</spark.version>
<scala.version>2.12.10</scala.version>
<scala.compat.version>2.12</scala.compat.version>
- <project.excludes>**/spark_2_4/**</project.excludes>
- <sonar.exclusions>**/spark_2_4/**</sonar.exclusions>
+ <sonar.exclusions>**/spark_2_1/**, **/spark_2_4/**</sonar.exclusions>
</properties>
</profile>
</profiles>
@@ -84,9 +163,6 @@ See file LICENSE for terms.
<configuration>
<source>1.8</source>
<target>1.8</target>
- <excludes>
-   <exclude>${project.excludes}</exclude>
- </excludes>
</configuration>
</plugin>
<plugin>
@@ -95,9 +171,6 @@ See file LICENSE for terms.
<version>4.3.0</version>
<configuration>
<recompileMode>all</recompileMode>
- <excludes>
-   <exclude>${project.excludes}</exclude>
- </excludes>
<args>
<arg>-nobootcp</arg>
<arg>-Xexperimental</arg>
@@ -111,9 +184,9 @@ See file LICENSE for terms.
<executions>
<execution>
<id>compile</id>
- <goals>
-   <goal>compile</goal>
- </goals>
+ <goals>
+   <goal>compile</goal>
+ </goals>
<phase>compile</phase>
</execution>
<execution>
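Usage note (not part of the commit itself): with the new spark-2.1 profile above, a Spark 2.1.0 / Scala 2.11 build can be selected the standard Maven way, e.g. mvn clean package -Pspark-2.1 (the exact flags the CI workflow passes are not shown here and are an assumption). The per-profile compiler and scala-maven-plugin excludes keep the spark_2_4 and spark_3_0 compat sources out of that build, and sonar.exclusions keeps them out of Sonar analysis.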
@@ -0,0 +1,70 @@
/*
* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/
package org.apache.spark.shuffle.ucx.reducer.compat.spark_2_1;

import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.shuffle.ucx.UnsafeUtils;
import org.apache.spark.shuffle.ucx.memory.RegisteredMemory;
import org.apache.spark.shuffle.ucx.reducer.OnBlocksFetchCallback;
import org.apache.spark.shuffle.ucx.reducer.ReducerCallback;
import org.apache.spark.storage.ShuffleBlockId;
import org.openucx.jucx.UcxUtils;
import org.openucx.jucx.ucp.UcpEndpoint;
import org.openucx.jucx.ucp.UcpRemoteKey;
import org.openucx.jucx.ucp.UcpRequest;

import java.nio.ByteBuffer;
import java.util.Map;

/**
 * Callback invoked once all block offsets have been fetched.
 */
public class OnOffsetsFetchCallback extends ReducerCallback {
private final RegisteredMemory offsetMemory;
private final long[] dataAddresses;
private Map<Integer, UcpRemoteKey> dataRkeysCache;

public OnOffsetsFetchCallback(ShuffleBlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener,
RegisteredMemory offsetMemory, long[] dataAddresses,
Map<Integer, UcpRemoteKey> dataRkeysCache) {
super(blockIds, endpoint, listener);
this.offsetMemory = offsetMemory;
this.dataAddresses = dataAddresses;
this.dataRkeysCache = dataRkeysCache;
}

@Override
public void onSuccess(UcpRequest request) {
ByteBuffer resultOffset = offsetMemory.getBuffer();
long totalSize = 0;
int[] sizes = new int[blockIds.length];
int offsetSize = UnsafeUtils.LONG_SIZE;
for (int i = 0; i < blockIds.length; i++) {
// Blocks in the metadata buffer are laid out as | blockOffsetStart | blockOffsetEnd | pairs.
long blockOffset = resultOffset.getLong(i * 2 * offsetSize);
long blockLength = resultOffset.getLong(i * 2 * offsetSize + offsetSize) - blockOffset;
assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE);
sizes[i] = (int) blockLength;
totalSize += blockLength;
dataAddresses[i] += blockOffset;
}

assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE);
mempool.put(offsetMemory);
RegisteredMemory blocksMemory = mempool.get((int) totalSize);

long offset = 0;
// Submit N block-fetch requests.
for (int i = 0; i < blockIds.length; i++) {
endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(((ShuffleBlockId)blockIds[i]).mapId()),
UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]);
offset += sizes[i];
}

// Process the blocks once they are all fetched.
// The flush guarantees that the callback is invoked only after all fetch requests have completed.
endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes));
}
}
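
For reference, a minimal standalone sketch (not part of this commit; class name and sample offsets are made up) of how the onSuccess loop above derives block lengths from the | blockOffsetStart | blockOffsetEnd | pairs in the offset buffer:

import java.nio.ByteBuffer;

public class OffsetLayoutSketch {
  public static void main(String[] args) {
    final int LONG_SIZE = 8; // same value as UnsafeUtils.LONG_SIZE
    // Metadata buffer for two blocks: [start0, end0, start1, end1]
    ByteBuffer offsets = ByteBuffer.allocate(2 * 2 * LONG_SIZE);
    offsets.putLong(0L).putLong(4096L)       // block 0 occupies [0, 4096)
           .putLong(4096L).putLong(10240L);  // block 1 occupies [4096, 10240)

    long totalSize = 0;
    for (int i = 0; i < 2; i++) {
      long blockOffset = offsets.getLong(i * 2 * LONG_SIZE);
      long blockLength = offsets.getLong(i * 2 * LONG_SIZE + LONG_SIZE) - blockOffset;
      totalSize += blockLength;
      System.out.printf("block %d: offset=%d length=%d%n", i, blockOffset, blockLength);
    }
    System.out.println("total bytes to fetch into blocksMemory: " + totalSize); // 10240
  }
}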
@@ -0,0 +1,111 @@
/*
* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/
package org.apache.spark.shuffle.ucx.reducer.compat.spark_2_1;

import org.apache.spark.SparkEnv;
import org.apache.spark.executor.TempShuffleReadMetrics;
import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.network.shuffle.ShuffleClient;
import org.apache.spark.shuffle.DriverMetadata;
import org.apache.spark.shuffle.UcxShuffleManager;
import org.apache.spark.shuffle.UcxWorkerWrapper;
import org.apache.spark.shuffle.ucx.UnsafeUtils;
import org.apache.spark.shuffle.ucx.memory.MemoryPool;
import org.apache.spark.shuffle.ucx.memory.RegisteredMemory;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.BlockManagerId;
import org.apache.spark.storage.ShuffleBlockId;
import org.openucx.jucx.UcxUtils;
import org.openucx.jucx.ucp.UcpEndpoint;
import org.openucx.jucx.ucp.UcpRemoteKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Option;

import java.util.Arrays;
import java.util.HashMap;

public class UcxShuffleClient extends ShuffleClient {
private final MemoryPool mempool;
private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class);
private final UcxShuffleManager ucxShuffleManager;
private final TempShuffleReadMetrics shuffleReadMetrics;
private final UcxWorkerWrapper workerWrapper;
final HashMap<Integer, UcpRemoteKey> offsetRkeysCache = new HashMap<>();
final HashMap<Integer, UcpRemoteKey> dataRkeysCache = new HashMap<>();

public UcxShuffleClient(TempShuffleReadMetrics shuffleReadMetrics,
UcxWorkerWrapper workerWrapper) {
this.ucxShuffleManager = (UcxShuffleManager) SparkEnv.get().shuffleManager();
this.mempool = ucxShuffleManager.ucxNode().getMemoryPool();
this.shuffleReadMetrics = shuffleReadMetrics;
this.workerWrapper = workerWrapper;
}

/**
 * Submits n non-blocking offset-fetch requests to obtain the offsets needed for n blocks.
 */
private void submitFetchOffsets(UcpEndpoint endpoint, ShuffleBlockId[] blockIds,
long[] dataAddresses, RegisteredMemory offsetMemory) {
DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(blockIds[0].shuffleId());
for (int i = 0; i < blockIds.length; i++) {
ShuffleBlockId blockId = blockIds[i];

long offsetAddress = driverMetadata.offsetAddress(blockId.mapId());
dataAddresses[i] = driverMetadata.dataAddress(blockId.mapId());

offsetRkeysCache.computeIfAbsent(blockId.mapId(), mapId ->
endpoint.unpackRemoteKey(driverMetadata.offsetRkey(blockId.mapId())));

dataRkeysCache.computeIfAbsent(blockId.mapId(), mapId ->
endpoint.unpackRemoteKey(driverMetadata.dataRkey(blockId.mapId())));

endpoint.getNonBlockingImplicit(
offsetAddress + blockId.reduceId() * UnsafeUtils.LONG_SIZE,
offsetRkeysCache.get(blockId.mapId()),
UcxUtils.getAddress(offsetMemory.getBuffer()) + (i * 2L * UnsafeUtils.LONG_SIZE),
2L * UnsafeUtils.LONG_SIZE);
}
}

/**
 * Reducer entry point. Fetches remote blocks using 2 ucp_get calls per block:
 * first the block offset from the index file, then the block data itself.
 * This method is called inside ShuffleFetchIterator's for loop over hosts.
 */
@Override
public void fetchBlocks(String host, int port, String execId,
String[] blockIds, BlockFetchingListener listener) {
long startTime = System.currentTimeMillis();

BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty());
UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId);

long[] dataAddresses = new long[blockIds.length];

// Need to fetch 2 long offsets (current block start and next block start) to calculate the exact block size.
RegisteredMemory offsetMemory = mempool.get(2 * UnsafeUtils.LONG_SIZE * blockIds.length);

ShuffleBlockId[] shuffleBlockIds = Arrays.stream(blockIds)
.map(blockId -> (ShuffleBlockId) BlockId.apply(blockId)).toArray(ShuffleBlockId[]::new);

// Submit N implicit get requests without callbacks.
submitFetchOffsets(endpoint, shuffleBlockIds, dataAddresses, offsetMemory);

// The flush guarantees that all of those requests have completed by the time the callback is invoked.
// TODO: fix https://github.com/openucx/ucx/issues/4267 and use endpoint flush.
workerWrapper.worker().flushNonBlocking(
new OnOffsetsFetchCallback(shuffleBlockIds, endpoint, listener, offsetMemory,
dataAddresses, dataRkeysCache));
shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime);
}

@Override
public void close() {
offsetRkeysCache.values().forEach(UcpRemoteKey::close);
dataRkeysCache.values().forEach(UcpRemoteKey::close);
logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime());
}
}
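
As a companion to submitFetchOffsets above, a small standalone sketch (not part of this commit; class name and sample values are made up) of the index-file arithmetic: Spark's shuffle index file stores cumulative offsets, so reading the two consecutive longs starting at reduceId * LONG_SIZE yields the [start, end) byte range of one reduce partition in the data file.

public class IndexReadSketch {
  public static void main(String[] args) {
    final int LONG_SIZE = 8; // same value as UnsafeUtils.LONG_SIZE
    // Cumulative offsets for one map output with 3 reduce partitions.
    long[] indexFileContents = {0L, 1024L, 3072L, 8192L};

    int reduceId = 1;
    long readOffsetInIndexFile = (long) reduceId * LONG_SIZE; // where the 2 * LONG_SIZE remote get starts

    long blockStart = indexFileContents[reduceId];   // first long read
    long blockEnd = indexFileContents[reduceId + 1]; // second long read
    System.out.println("read " + (2 * LONG_SIZE) + " bytes at index-file offset " + readOffsetInIndexFile);
    System.out.println("block spans [" + blockStart + ", " + blockEnd + ") = "
        + (blockEnd - blockStart) + " bytes");
  }
}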
@@ -0,0 +1,44 @@
/*
* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/
package org.apache.spark.shuffle.compat.spark_2_1

import java.io.{File, RandomAccessFile}

import org.apache.spark.SparkEnv
import org.apache.spark.shuffle.{CommonUcxShuffleBlockResolver, CommonUcxShuffleManager, IndexShuffleBlockResolver}
import org.apache.spark.storage.ShuffleIndexBlockId

/**
 * Mapper entry point for the UcxShuffle plugin. Performs memory registration
 * of the data and index files and publishes their addresses to the driver metadata buffer.
 */
class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager)
extends CommonUcxShuffleBlockResolver(ucxShuffleManager) {

private def getIndexFile(shuffleId: Int, mapId: Int): File = {
SparkEnv.get.blockManager
.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID))
}

/**
 * Mapper commit protocol extension. Registers the index and data files and
 * publishes all needed metadata to the driver.
 */
override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Int,
lengths: Array[Long], dataTmp: File): Unit = {
super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp)
val dataFile = getDataFile(shuffleId, mapId)
val dataBackFile = new RandomAccessFile(dataFile, "rw")

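// Nothing to register for an empty map output: close the handle and skip registration.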
if (dataBackFile.length() == 0) {
dataBackFile.close()
return
}

val indexFile = getIndexFile(shuffleId, mapId)
val indexBackFile = new RandomAccessFile(indexFile, "rw")
writeIndexFileAndCommitCommon(shuffleId, mapId, lengths, dataTmp, indexBackFile, dataBackFile)
}
}
