forked from openucx/sparkucx
-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathUcxFetchCallBack.scala
More file actions
50 lines (41 loc) · 1.5 KB
/
UcxFetchCallBack.scala
File metadata and controls
50 lines (41 loc) · 1.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
package org.apache.spark.shuffle.ucx
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicBoolean

import scala.util.{Failure, Success, Try}

import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager}
import org.apache.spark.network.util.TransportConf
import org.apache.spark.shuffle.utils.UnsafeUtils
/**
 * Completion callback for a UCX block fetch whose payload is delivered as a memory block.
 *
 * When the operation completes, the memory block from `result.getData` is wrapped in a
 * `ByteBuffer` view (via `UnsafeUtils.getByteBufferView`) and handed to `listener` as a
 * `NioManagedBuffer`. Releasing that buffer closes the underlying memory block.
 *
 * @param blockId  id of the fetched block, reported back to `listener`
 * @param listener receives the fetched block on completion
 */
class UcxFetchCallBack(
    blockId: String, listener: BlockFetchingListener)
  extends OperationCallback {

  /**
   * Hands the fetched bytes to the listener as a managed buffer.
   *
   * The returned buffer's `release()` is idempotent: only the first call closes the
   * memory block, so repeated releases (e.g. from overlapping cleanup paths) cannot close
   * the same block twice.
   *
   * @param result the completed operation; its data holds the fetched block
   */
  override def onComplete(result: OperationResult): Unit = {
    val memBlock = result.getData
    // NOTE(review): if `size` is a Long, `toInt` silently truncates values above
    // Int.MaxValue; a ByteBuffer cannot address more than that anyway — confirm blocks
    // fetched this way always stay below 2 GiB.
    val buffer = UnsafeUtils.getByteBufferView(memBlock.address, memBlock.size.toInt)
    listener.onBlockFetchSuccess(blockId, new NioManagedBuffer(buffer) {
      // Set on the first release so the memory block is closed at most once.
      private[this] val released = new AtomicBoolean(false)

      override def release(): ManagedBuffer = {
        if (released.compareAndSet(false, true)) {
          memBlock.close()
        }
        this
      }
    })
  }
}
/**
 * Callback for a UCX block fetch that streams the payload into a temporary file.
 *
 * The temp file is created through `downloadFileManager` when the callback is constructed.
 * Each `onData` chunk is appended to it, and `onComplete` hands the file back to
 * `listener` as a managed buffer.
 *
 * @param blockId             id of the block being fetched, reported back to `listener`
 * @param listener            notified of the outcome (success or failure) for `blockId`
 * @param downloadFileManager creates the temp file and may take ownership of its cleanup
 * @param transportConf       configuration passed to `createTempFile`
 */
class UcxDownloadCallBack(
    blockId: String, listener: BlockFetchingListener,
    downloadFileManager: DownloadFileManager,
    transportConf: TransportConf)
  extends OperationCallback {

  private[this] val targetFile = downloadFileManager.createTempFile(transportConf)
  private[this] val channel = targetFile.openForWriting()

  /** Appends all remaining bytes of `buffer` to the temp file. */
  override def onData(buffer: ByteBuffer): Unit = {
    // A single channel write may be partial, so keep writing until the buffer is drained.
    while (buffer.hasRemaining()) {
      channel.write(buffer)
    }
  }

  /**
   * Closes the temp file and reports it to the listener.
   *
   * If closing/reading the file fails, the temp file is deleted and the listener is told via
   * `onBlockFetchFailure`. Otherwise the listener would never hear about this block and
   * the file would leak. After a successful hand-off, the file is registered for cleanup,
   * or deleted immediately if registration is refused. This happens even if the listener
   * itself throws.
   */
  override def onComplete(result: OperationResult): Unit = {
    Try(channel.closeAndRead()) match {
      case Success(managedBuffer) =>
        try {
          listener.onBlockFetchSuccess(blockId, managedBuffer)
        } finally {
          if (!downloadFileManager.registerTempFileToClean(targetFile)) {
            targetFile.delete()
          }
        }
      case Failure(e) =>
        targetFile.delete()
        listener.onBlockFetchFailure(blockId, e)
    }
  }
}