Skip to content

Commit 2933398

Browse files
committed
follow ups
Signed-off-by: Rishi Chandra <rishic@nvidia.com>
1 parent 4324a25 commit 2933398

6 files changed

Lines changed: 41 additions & 42 deletions

File tree

skills/udf-gen-test/templates/scala/src/main/scala/com/udf/bench/MicroBenchRunner.scala

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,20 +85,19 @@ object MicroBenchRunner {
8585
def executeCpu(data: Array[AnyRef], numRows: Int): Unit = ???
8686

8787
/**
88-
* TODO: Execute the GPU UDF via evaluateColumnar.
88+
* TODO: Execute the GPU UDF via evaluateColumnar and close its result.
8989
*
9090
* Example:
9191
* {{{
9292
* val udf = new com.udf.PlaceholderRapidsUDFName()
93-
* udf.evaluateColumnar(numRows,
94-
* table.getColumn(0), table.getColumn(1))
93+
* withResource(udf.evaluateColumnar(numRows,
94+
* table.getColumn(0), table.getColumn(1))) { _ => }
9595
* }}}
9696
*
9797
* @param table the dataset loaded on GPU
9898
* @param numRows number of rows in the dataset
99-
* @return result ColumnVector (NOTE: caller must close)
10099
*/
101-
def executeGpu(table: Table, numRows: Int): ColumnVector = ???
100+
def executeGpu(table: Table, numRows: Int): Unit = ???
102101

103102
def main(args: Array[String]): Unit = {
104103
val parsed = parseArgs(args)
@@ -165,7 +164,7 @@ object MicroBenchRunner {
165164
if (runGpu) {
166165
try {
167166
val times = runBenchmark(warmup, measured, profile = profile) {
168-
withResource(executeGpu(table, numRows)) { _ => }
167+
executeGpu(table, numRows)
169168
}
170169
val medianMs = times(times.length / 2) / 1e6
171170
val minMs = times(0) / 1e6

skills/udf-gen-test/templates/scala/src/main/scala/com/udf/bench/SparkBenchRunner.scala

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -112,31 +112,26 @@ object SparkBenchRunner {
112112
val resultDir = new File(path).getParentFile
113113
if (resultDir != null) resultDir.mkdirs()
114114

115-
try {
116-
import java.util.{LinkedHashMap => JLinkedHashMap, Arrays => JArrays}
117-
val report = new JLinkedHashMap[String, AnyRef]()
118-
report.put("mode", mode)
119-
report.put("data_path", dataPath)
120-
report.put("status", status)
121-
report.put("e2e_runtime", java.lang.Double.valueOf(elapsed))
122-
report.put("cli_args", JArrays.asList(cliArgs: _*))
123-
errorMessage.foreach { msg =>
124-
val error = new JLinkedHashMap[String, String]()
125-
error.put("error_message", msg)
126-
errorLogFile.foreach(f => error.put("error_log_file", f))
127-
report.put("error", error)
128-
}
129-
130-
val mapper = new ObjectMapper()
131-
mapper.enable(SerializationFeature.INDENT_OUTPUT)
132-
val printer = new DefaultPrettyPrinter()
133-
printer.indentArraysWith(DefaultIndenter.SYSTEM_LINEFEED_INSTANCE)
134-
mapper.writer(printer).writeValue(new File(path), report)
135-
System.err.println(s"Report written to: $path")
136-
} catch {
137-
case e: Exception =>
138-
System.err.println(s"Failed to write report: ${e.getMessage}")
115+
import java.util.{LinkedHashMap => JLinkedHashMap, Arrays => JArrays}
116+
val report = new JLinkedHashMap[String, AnyRef]()
117+
report.put("mode", mode)
118+
report.put("data_path", dataPath)
119+
report.put("status", status)
120+
report.put("e2e_runtime", java.lang.Double.valueOf(elapsed))
121+
report.put("cli_args", JArrays.asList(cliArgs: _*))
122+
errorMessage.foreach { msg =>
123+
val error = new JLinkedHashMap[String, String]()
124+
error.put("error_message", msg)
125+
errorLogFile.foreach(f => error.put("error_log_file", f))
126+
report.put("error", error)
139127
}
128+
129+
val mapper = new ObjectMapper()
130+
mapper.enable(SerializationFeature.INDENT_OUTPUT)
131+
val printer = new DefaultPrettyPrinter()
132+
printer.indentArraysWith(DefaultIndenter.SYSTEM_LINEFEED_INSTANCE)
133+
mapper.writer(printer).writeValue(new File(path), report)
134+
System.err.println(s"Report written to: $path")
140135
}
141136

142137
/** Write an exception to an error log file. */

skills/udf-gen-test/templates/scala/src/test/scala/com/udf/CudfComparisonTest.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,18 @@ class CudfComparisonTest extends AnyFunSuite with BeforeAndAfterAll {
3232
def registerRapidsUDF(spark: SparkSession, udfName: String): Unit = ???
3333

3434
test("UDF vs RapidsUDF") {
35-
val testDF = UnitTest.createTestData(spark).repartition(1)
35+
// Repartition down to 2 partitions to ensure we exercise multi-row columns.
36+
val testDF = UnitTest.createTestData(spark).repartition(2)
3637

3738
// Run CPU UDF
3839
UnitTest.registerUDF(spark, "placeholder_udf_name")
3940
val cpuResultDF = UnitTest.executeUDF(spark, "placeholder_udf_name", testDF)
40-
UnitTest.verifyUDFResults(cpuResultDF, testDF)
41+
UnitTest.assertUDFResults(cpuResultDF, testDF)
4142

4243
// Run RapidsUDF
4344
registerRapidsUDF(spark, "placeholder_rapids_udf_name")
4445
val gpuResultDF = UnitTest.executeUDF(spark, "placeholder_rapids_udf_name", testDF)
45-
UnitTest.verifyUDFResults(gpuResultDF, testDF)
46+
UnitTest.assertUDFResults(gpuResultDF, testDF)
4647

4748
// Compare
4849
TestUtils.assertDataFrameEquals(actual = gpuResultDF, expected = cpuResultDF)

skills/udf-gen-test/templates/scala/src/test/scala/com/udf/SqlComparisonTest.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,20 @@ class SqlComparisonTest extends AnyFunSuite with BeforeAndAfterAll {
3030
}
3131

3232
test("UDF vs SQL expression") {
33-
val testDF = UnitTest.createTestData(spark).repartition(1)
33+
// Repartition down to 2 partitions to ensure we exercise multi-row columns.
34+
val testDF = UnitTest.createTestData(spark).repartition(2)
3435

3536
// Run CPU UDF
3637
UnitTest.registerUDF(spark, "placeholder_udf_name")
3738
val udfResultDF = UnitTest.executeUDF(spark, "placeholder_udf_name", testDF)
38-
UnitTest.verifyUDFResults(udfResultDF, testDF)
39+
UnitTest.assertUDFResults(udfResultDF, testDF)
3940

4041
// Read and execute SQL expression
4142
testDF.createOrReplaceTempView("test_table")
4243
val sqlSource = scala.io.Source.fromFile("src/main/resources/placeholder_udf_name.sql")
4344
val sqlContent = try sqlSource.mkString finally sqlSource.close()
4445
val sqlResultDF = spark.sql(sqlContent)
45-
UnitTest.verifyUDFResults(sqlResultDF, testDF)
46+
UnitTest.assertUDFResults(sqlResultDF, testDF)
4647

4748
// Compare results
4849
TestUtils.assertDataFrameEquals(actual = sqlResultDF, expected = udfResultDF)

skills/udf-gen-test/templates/scala/src/test/scala/com/udf/UnitTest.scala

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ import org.scalatest.BeforeAndAfterAll
1313

1414
object UnitTest extends Assertions {
1515
/**
16-
* TODO: Create a test DataFrame with diverse test cases including edge cases.
16+
* TODO: Create a test DataFrame with diverse test cases including edge cases
17+
* (at least 10 cases).
1718
*
1819
* Example:
1920
* {{{
@@ -25,6 +26,7 @@ object UnitTest extends Assertions {
2526
* Row(1, 800),
2627
* Row(2, 550),
2728
* Row(3, null)
29+
* // ...
2830
* )
2931
* spark.createDataFrame(spark.sparkContext.parallelize(testData), schema)
3032
* }}}
@@ -55,7 +57,7 @@ object UnitTest extends Assertions {
5557
def executeUDF(spark: SparkSession, udfName: String, testDF: DataFrame): DataFrame = ???
5658

5759
/**
58-
* TODO: Verify UDF results using assert statements.
60+
* TODO: Assert the UDF results match expectations.
5961
*
6062
* Example:
6163
* {{{
@@ -65,7 +67,7 @@ object UnitTest extends Assertions {
6567
* assert(results(2).getAs[String]("risk_level") === "UNKNOWN")
6668
* }}}
6769
*/
68-
def verifyUDFResults(resultDF: DataFrame, testDF: DataFrame): Unit = ???
70+
def assertUDFResults(resultDF: DataFrame, testDF: DataFrame): Unit = ???
6971
}
7072

7173
class UnitTest extends AnyFunSuite with BeforeAndAfterAll {
@@ -89,11 +91,12 @@ class UnitTest extends AnyFunSuite with BeforeAndAfterAll {
8991
}
9092

9193
test("UDF produces correct results") {
92-
val testDF = UnitTest.createTestData(spark).repartition(1)
94+
// Repartition down to 2 partitions to ensure we exercise multi-row columns.
95+
val testDF = UnitTest.createTestData(spark).repartition(2)
9396

9497
UnitTest.registerUDF(spark, "placeholder_udf_name")
9598
val resultDF = UnitTest.executeUDF(spark, "placeholder_udf_name", testDF)
9699

97-
UnitTest.verifyUDFResults(resultDF, testDF)
100+
UnitTest.assertUDFResults(resultDF, testDF)
98101
}
99102
}

skills/udf-judge-conversion/SKILL.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ Check that:
3939
- Assertions verify schema, row count, deterministic ordering, output values, null propagation, and exception/default behavior where applicable.
4040
- The test exercises visible CPU UDF branches. Coverage reports should support this when available.
4141
- Assertions reflect the CPU UDF's actual behavior and do not merely assert weak properties such as non-null output.
42-
- Extra unit tests outside the shared `verifyUDFResults` path are mirrored in the comparison test and run against both CPU and GPU/SQL paths.
42+
- Extra unit tests outside the shared `assertUDFResults` path are mirrored in the comparison test and run against both CPU and GPU/SQL paths.
4343

4444
## Comparison Test Checks
4545

0 commit comments

Comments
 (0)