@@ -61,17 +61,16 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
   val uuid: UUID = ObjectStoreWriter.dfToId.getOrElseUpdate(df, UUID.randomUUID())

   def writeToRay(
-      data: Array[Byte],
+      root: VectorSchemaRoot,
       numRecords: Int,
       queue: ObjectRefHolder.Queue,
       ownerName: String): RecordBatch = {
-
-    var objectRef: ObjectRef[Array[Byte]] = null
+    var objectRef: ObjectRef[VectorSchemaRoot] = null
     if (ownerName == "") {
-      objectRef = Ray.put(data)
+      objectRef = Ray.put(root)
     } else {
       var dataOwner: PyActorHandle = Ray.getActor(ownerName).get()
-      objectRef = Ray.put(data, dataOwner)
+      objectRef = Ray.put(root, dataOwner)
     }

     // add the objectRef to the objectRefHolder to avoid reference GC
@@ -111,21 +110,15 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
     val root = VectorSchemaRoot.create(arrowSchema, allocator)
     val results = new ArrayBuffer[RecordBatch]()

-    val byteOut = new ByteArrayOutputStream()
     val arrowWriter = ArrowWriter.create(root)
     var numRecords: Int = 0

     Utils.tryWithSafeFinally {
       while (batchIter.hasNext) {
         // reset the state
         numRecords = 0
-        byteOut.reset()
         arrowWriter.reset()

-        // write out the schema meta data
-        val writer = new ArrowStreamWriter(root, null, byteOut)
-        writer.start()
-
         // get the next record batch
         val nextBatch = batchIter.next()

@@ -136,19 +129,11 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {

         // set the write record count
         arrowWriter.finish()
-        // write out the record batch to the underlying out
-        writer.writeBatch()
-
-        // get the wrote ByteArray and save to Ray ObjectStore
-        val byteArray = byteOut.toByteArray
-        results += writeToRay(byteArray, numRecords, queue, ownerName)
-        // end writes footer to the output stream and doesn't clean any resources.
-        // It could throw exception if the output stream is closed, so it should be
-        // in the try block.
-        writer.end()
+
+        // write the schema root directly and save to Ray ObjectStore
+        results += writeToRay(root, numRecords, queue, ownerName)
       }
       arrowWriter.reset()
-      byteOut.close()
     } {
       // If we close root and allocator in TaskCompletionListener, there could be a race
       // condition where the writer thread keeps writing to the VectorSchemaRoot while
@@ -173,7 +158,7 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
   /**
    * For test.
    */
-  def getRandomRef(): List[Array[Byte]] = {
+  def getRandomRef(): List[VectorSchemaRoot] = {

     df.queryExecution.toRdd.mapPartitions { _ =>
       Iterator(ObjectRefHolder.getRandom(uuid))
@@ -233,7 +218,7 @@ object ObjectStoreWriter {
     var executorIds = df.sqlContext.sparkContext.getExecutorIds.toArray
     val numExecutors = executorIds.length
     val appMasterHandle = Ray.getActor(RayAppMaster.ACTOR_NAME)
-        .get.asInstanceOf[ActorHandle[RayAppMaster]]
+      .get.asInstanceOf[ActorHandle[RayAppMaster]]
     val restartedExecutors = RayAppMasterUtils.getRestartedExecutors(appMasterHandle)
     // Check if there is any restarted executors
     if (!restartedExecutors.isEmpty) {
@@ -251,8 +236,8 @@ object ObjectStoreWriter {
     val refs = new Array[ObjectRef[Array[Byte]]](numPartitions)
     val handles = executorIds.map { id =>
       Ray.getActor("raydp-executor-" + id)
-          .get
-          .asInstanceOf[ActorHandle[RayDPExecutor]]
+        .get
+        .asInstanceOf[ActorHandle[RayDPExecutor]]
     }
     val handlesMap = (executorIds zip handles).toMap
     val locations = RayExecutorUtils.getBlockLocations(
@@ -261,18 +246,15 @@ object ObjectStoreWriter {
       // TODO use getPreferredLocs, but we don't have a host ip to actor table now
       refs(i) = RayExecutorUtils.getRDDPartition(
         handlesMap(locations(i)), rdd.id, i, schema, driverAgentUrl)
-      queue.add(refs(i))
-    }
-    for (i <- 0 until numPartitions) {
+      queue.add(RayDPUtils.readBinary(refs(i).get(), classOf[VectorSchemaRoot]))
       results(i) = RayDPUtils.convert(refs(i)).getId.getBytes
     }
     results
   }
-
 }

 object ObjectRefHolder {
-  type Queue = ConcurrentLinkedQueue[ObjectRef[Array[Byte]]]
+  type Queue = ConcurrentLinkedQueue[ObjectRef[VectorSchemaRoot]]
   private val dfToQueue = new ConcurrentHashMap[UUID, Queue]()

   def getQueue(df: UUID): Queue = {
@@ -297,7 +279,7 @@ object ObjectRefHolder {
     queue.size()
   }

-  def getRandom(df: UUID): Array[Byte] = {
+  def getRandom(df: UUID): VectorSchemaRoot = {
     val queue = checkQueueExists(df)
     val ref = RayDPUtils.convert(queue.peek())
     ref.get()
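
For readers unfamiliar with the byte path this patch removes, below is a minimal standalone sketch of the old per-batch round trip, using only plain Arrow Java APIs (the object name, column name, and row values are illustrative and not part of RayDP):

import java.io.ByteArrayOutputStream
import java.util.Arrays

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{FieldVector, IntVector, VectorSchemaRoot}
import org.apache.arrow.vector.ipc.ArrowStreamWriter

object ByteSerializationSketch {
  def main(args: Array[String]): Unit = {
    val allocator = new RootAllocator(Long.MaxValue)

    // Build a one-column root with three rows; names and values are illustrative.
    val vector = new IntVector("id", allocator)
    vector.allocateNew(3)
    (0 until 3).foreach(i => vector.set(i, i))
    vector.setValueCount(3)
    val root = new VectorSchemaRoot(Arrays.asList[FieldVector](vector))
    root.setRowCount(3)

    // This is the round trip the removed code performed per batch: stream the
    // schema plus one record batch into an in-memory buffer, then hand the
    // resulting byte array to Ray.put. The patched code skips this step and
    // puts the VectorSchemaRoot into the object store directly.
    val byteOut = new ByteArrayOutputStream()
    val writer = new ArrowStreamWriter(root, null, byteOut)
    writer.start()
    writer.writeBatch()
    writer.end()
    println(s"serialized batch: ${byteOut.toByteArray.length} bytes")

    root.close()
    allocator.close()
  }
}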