Skip to content

Commit ad5eccd

Browse files
committed
Fix shuffle test helper construction
1 parent ba030e2 commit ad5eccd

5 files changed

Lines changed: 14 additions & 14 deletions

File tree

tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClientSuite.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2020-2026, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -33,7 +33,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper {
3333
def prepareBufferReceiveState(
3434
tableMeta: TableMeta,
3535
bounceBuffer: BounceBuffer): BufferReceiveState = {
36-
val ptr = PendingTransferRequest(client, tableMeta, mockHandler)
36+
val ptr = new PendingTransferRequest(client, tableMeta, mockHandler)
3737
spy(new BufferReceiveState(123L, bounceBuffer, Seq(ptr), () => {}))
3838
}
3939

@@ -42,7 +42,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper {
4242
bounceBuffer: BounceBuffer): BufferReceiveState = {
4343

4444
val ptrs = tableMetas.map { tm =>
45-
PendingTransferRequest(client, tm, mockHandler)
45+
new PendingTransferRequest(client, tm, mockHandler)
4646
}
4747

4848
spy(new BufferReceiveState(123L, bounceBuffer, ptrs, () => {}))

tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2020-2026, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -172,7 +172,7 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper {
172172
val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler])
173173
when(mockTransport.makeClient(any())).thenReturn(client)
174174
doNothing().when(client).doFetch(any(), ac.capture())
175-
val mockBuffer = RapidsShuffleHandle(mock[SpillableDeviceBufferHandle], null)
175+
val mockBuffer = new RapidsShuffleHandle(mock[SpillableDeviceBufferHandle], null)
176176
when(mockBuffer.spillable.sizeInBytes).thenReturn(123L)
177177

178178
val cb = new ColumnarBatch(Array.empty, 10)

tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServerSuite.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2020-2026, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -57,7 +57,7 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper {
5757
fillBuffer(hostBuff)
5858
deviceBuffer.copyFromHostBuffer(hostBuff)
5959
val mockMeta = RapidsShuffleTestHelper.mockTableMeta(100000)
60-
RapidsShuffleHandle(SpillableDeviceBufferHandle(deviceBuffer), mockMeta)
60+
new RapidsShuffleHandle(SpillableDeviceBufferHandle(deviceBuffer), mockMeta)
6161
}
6262
}
6363
new MockRapidsShuffleRequestHandler(mockBuffers)
@@ -208,7 +208,7 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper {
208208
withResource(new RefCountedDirectByteBuffer(bb)) { _ =>
209209
val tableMeta = MetaUtils.buildTableMeta(1, 456, bb, 100)
210210
val testHandle = SpillableDeviceBufferHandle(DeviceMemoryBuffer.allocate(456))
211-
val rapidsBuffer = RapidsShuffleHandle(testHandle, tableMeta)
211+
val rapidsBuffer = new RapidsShuffleHandle(testHandle, tableMeta)
212212
when(mockRequestHandler.getShuffleHandle(ArgumentMatchers.eq(1)))
213213
.thenReturn(rapidsBuffer)
214214

@@ -277,8 +277,8 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper {
277277
val ex = new IllegalStateException("something happened")
278278
when(mockHandleThatThrows.materialize()).thenThrow(ex)
279279

280-
val rapidsBuffer = RapidsShuffleHandle(mockHandle, tableMeta)
281-
val rapidsBufferThatThrows = RapidsShuffleHandle(mockHandleThatThrows, tableMeta)
280+
val rapidsBuffer = new RapidsShuffleHandle(mockHandle, tableMeta)
281+
val rapidsBufferThatThrows = new RapidsShuffleHandle(mockHandleThatThrows, tableMeta)
282282

283283
when(mockRequestHandler.getShuffleHandle(ArgumentMatchers.eq(1)))
284284
.thenReturn(rapidsBuffer)
@@ -359,7 +359,7 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper {
359359
val tableMeta = MetaUtils.buildTableMeta(tableId, 456, bb, 100)
360360
val rapidsBuffer = if (error) {
361361
val mockHandle = mock[SpillableDeviceBufferHandle]
362-
val rapidsBuffer = RapidsShuffleHandle(mockHandle, tableMeta)
362+
val rapidsBuffer = new RapidsShuffleHandle(mockHandle, tableMeta)
363363
when(mockHandle.sizeInBytes).thenReturn(tableMeta.bufferMeta().size())
364364
// mock an error with the copy
365365
when(rapidsBuffer.spillable.materialize())
@@ -369,7 +369,7 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper {
369369
rapidsBuffer
370370
} else {
371371
val testHandle = spy(SpillableDeviceBufferHandle(spy(DeviceMemoryBuffer.allocate(456))))
372-
RapidsShuffleHandle(testHandle, tableMeta)
372+
new RapidsShuffleHandle(testHandle, tableMeta)
373373
}
374374
when(mockRequestHandler.getShuffleHandle(ArgumentMatchers.eq(tableId)))
375375
.thenAnswer(_ => rapidsBuffer)

tests/src/test/spark330/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ abstract class RapidsShuffleTestHelper
117117

118118
def getSendBounceBuffer(size: Long): SendBounceBuffers = {
119119
val db = DeviceMemoryBuffer.allocate(size)
120-
SendBounceBuffers(new BounceBuffer(db) {
120+
new SendBounceBuffers(new BounceBuffer(db) {
121121
override def free(bb: BounceBuffer): Unit = {
122122
db.close()
123123
}

tests/src/test/spark340/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ abstract class RapidsShuffleTestHelper
131131

132132
def getSendBounceBuffer(size: Long): SendBounceBuffers = {
133133
val db = DeviceMemoryBuffer.allocate(size)
134-
SendBounceBuffers(new BounceBuffer(db) {
134+
new SendBounceBuffers(new BounceBuffer(db) {
135135
override def free(bb: BounceBuffer): Unit = {
136136
db.close()
137137
}

0 commit comments

Comments
 (0)