forked from openucx/sparkucx
-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathCommonUcxShuffleBlockResolver.scala
More file actions
executable file
·74 lines (64 loc) · 2.5 KB
/
CommonUcxShuffleBlockResolver.scala
File metadata and controls
executable file
·74 lines (64 loc) · 2.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
/*
* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/
package org.apache.spark.shuffle.ucx
import java.io.RandomAccessFile
import java.nio.ByteBuffer
import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
import org.apache.spark.shuffle.IndexShuffleBlockResolver
import org.apache.spark.shuffle.utils.UnsafeUtils
/**
 * [[MemoryBlock]] that is a view into a memory-mapped file region.
 *
 * `(address, size)` describe the sub-range exposed to callers, while
 * `(baseAddress, baseSize)` describe the whole underlying mmap-ed region;
 * the full base mapping is what gets unmapped on close.
 *
 * NOTE(review): closing any one sub-block unmaps the entire base region —
 * presumably each base mapping backs exactly one FileBackedMemoryBlock;
 * verify against the allocation site.
 */
class FileBackedMemoryBlock(baseAddress: Long, baseSize: Long, address: Long, size: Long)
  extends MemoryBlock(address, size) {
  // Release the whole underlying mapping, not just the (address, size) window.
  override def close(): Unit = {
    UnsafeUtils.munmap(baseAddress, baseSize)
  }
}
/**
* Mapper entry point for UcxShuffle plugin. Performs memory registration
* of data and index files and publish addresses to driver metadata buffer.
*/
/**
 * Mapper entry point for the UcxShuffle plugin. Performs memory registration
 * of data and index files and publishes addresses to the driver metadata buffer.
 */
abstract class CommonUcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager)
  extends IndexShuffleBlockResolver(ucxShuffleManager.conf) {

  // Data files kept open per shuffle so registered blocks can keep serving
  // reads; closed (and the map entry removed) in removeShuffle().
  private val openFds = new ConcurrentHashMap[ShuffleId, ConcurrentLinkedQueue[RandomAccessFile]]()

  // Lazily resolved so construction does not block waiting for the transport.
  private[ucx] lazy val transport = ucxShuffleManager.awaitUcxTransport

  /**
   * Mapper commit protocol extension. Registers each non-empty reduce
   * partition of the map output with the UCX transport and publishes the
   * needed metadata to the driver.
   *
   * @param shuffleId    shuffle this map output belongs to
   * @param mapId        id of the mapper that produced `dataBackFile`
   * @param lengths      per-reduce-partition block lengths, indexed by reduceId
   * @param dataBackFile open data file backing the map output; retained in
   *                     `openFds` until `removeShuffle(shuffleId)` closes it
   */
  def writeIndexFileAndCommitCommon(shuffleId: ShuffleId, mapId: Int,
                                    lengths: Array[Long], dataBackFile: RandomAccessFile): Unit = {
    openFds.computeIfAbsent(shuffleId, (_: ShuffleId) => new ConcurrentLinkedQueue[RandomAccessFile]())
    openFds.get(shuffleId).add(dataBackFile)

    val channel = dataBackFile.getChannel
    var blockOffset = 0L
    for ((blockLength, reduceId) <- lengths.zipWithIndex) {
      // Zero-length partitions have no data to serve, so skip registration
      // (offsets of later partitions are unaffected: their length is 0).
      if (blockLength > 0) {
        val blockId = UcxShuffleBockId(shuffleId, mapId, reduceId)
        val block = new Block {
          // Snapshot the start of this partition; the outer cursor keeps
          // advancing across iterations. (Renamed getBlock's parameter from
          // `offset`, which shadowed the enclosing var of the same name.)
          private val fileOffset = blockOffset
          override def getBlock(byteBuffer: ByteBuffer, offsetInBlock: Long): Unit = {
            // NOTE(review): a positional FileChannel.read may return fewer
            // bytes than the buffer has room for; the returned count is
            // ignored here — confirm the Block contract tolerates short reads.
            channel.read(byteBuffer, fileOffset + offsetInBlock)
          }
          override def getSize: Long = blockLength
        }
        transport.register(blockId, block)
        blockOffset += blockLength
      }
    }
  }

  /**
   * Tears down per-shuffle state: closes every data file kept open for
   * `shuffleId` and unregisters its blocks from the transport (if one
   * was ever created).
   */
  def removeShuffle(shuffleId: Int): Unit = {
    val fds = openFds.remove(shuffleId)
    if (fds != null) {
      fds.forEach(f => f.close())
    }
    // Guard: transport may be null if no task ever forced the lazy init.
    if (ucxShuffleManager.ucxTransport != null) {
      ucxShuffleManager.ucxTransport.unregisterShuffle(shuffleId)
    }
  }

  /** Executor shutdown hook: drop every block registration at once. */
  override def stop(): Unit = {
    if (ucxShuffleManager.ucxTransport != null) {
      ucxShuffleManager.ucxTransport.unregisterAllBlocks()
    }
  }
}