Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions jvm/buf-gen.sh
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What runs this script?

Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#! /bin/bash

pushd src/main/
buf generate --debug
popd
38 changes: 38 additions & 0 deletions jvm/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
<scala.version>2.13.14</scala.version>
<scala.binary.version>2.13</scala.binary.version>
<spark.version>4.0.1-SNAPSHOT</spark.version>
<protobuf.version>4.29.3</protobuf.version>
<protoc-jar-maven-plugin.version>3.11.4</protoc-jar-maven-plugin.version>
<io.grpc.version>1.67.1</io.grpc.version>
<!-- Extra JVM arguments for tests (required for Java 17 module system) -->
<extraJavaTestArgs>
-XX:+IgnoreUnrecognizedVMOptions
Expand Down Expand Up @@ -81,6 +84,19 @@
<version>3.2.19</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>${protobuf.version}</version>
</dependency>

<!-- <dependency>-->
<!-- <groupId>io.grpc</groupId>-->
<!-- <artifactId>grpc-stub</artifactId>-->
<!-- <version>${io.grpc.version}</version>-->
<!-- </dependency>-->

</dependencies>

<build>
Expand All @@ -89,6 +105,28 @@

<plugins>

<plugin>
<groupId>com.github.os72</groupId>
<artifactId>protoc-jar-maven-plugin</artifactId>
<version>${protoc-jar-maven-plugin.version}</version>
<executions>
<execution>
<phase>generate-sources</phase>
<goals>
<goal>run</goal>
</goals>
<configuration>
<protocArtifact>com.google.protobuf:protoc:${protobuf.version}</protocArtifact>
<protocVersion>${protobuf.version}</protocVersion>
<inputDirectories>
<include>src/main/protobuf</include>
</inputDirectories>
<includeMavenTypes>direct</includeMavenTypes>
</configuration>
</execution>
</executions>
</plugin>

<!-- Scala Maven Plugin (adds Scala sources before compile phase) -->
<plugin>
<groupId>net.alchim31.maven</groupId>
Expand Down
26 changes: 26 additions & 0 deletions jvm/src/main/buf.gen.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
version: v1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is this version used and does it have to align with the jvm side?

plugins:
# Building the Python build and building the mypy interfaces.
- plugin: buf.build/protocolbuffers/python:v28.3
out: ../../../python/src/spark_rapids_ml/proto
- plugin: buf.build/grpc/python:v1.67.0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do these versions (v28.3 and v1.67.0) need to match exactly the java side in the pom.xml file ?

out: ../../../python/src/spark_rapids_ml/proto
- name: mypy
out: ../../../python/src/spark_rapids_ml/proto

19 changes: 19 additions & 0 deletions jvm/src/main/buf.work.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
version: v1
directories:
- protobuf
58 changes: 58 additions & 0 deletions jvm/src/main/protobuf/relations.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
syntax = 'proto3';

// Must set the package into spark.connect if importing spark/connect/relations.proto
// package spark.connect;
package com.nvidia.rapids.ml.proto;

option java_multiple_files = true;
option java_package = "com.nvidia.rapids.ml.proto";
option java_generate_equals_and_hash = true;

message TuningRelation {
oneof relation_type {
CrossValidatorRelation cv = 1;
}
}

message CrossValidatorRelation {
// (Required) Unique id of the ML operator
string uid = 1;
// (Required) the estimator info
MlOperator estimator = 2;
// (Required) the estimator parameter maps info
string estimator_param_maps = 3;
// (Required) the evaluator info
MlOperator evaluator = 4;
// parameters of CrossValidator
optional string params = 5;
// Can't use Relation directly due to shading issue in spark connect
optional bytes dataset = 6;
}

// MLOperator represents the ML operators like (Estimator, Transformer or Evaluator)
message MlOperator {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this be pulled in from Spark?

// (Required) The qualified name of the ML operator.
string name = 1;

// (Required) Unique id of the ML operator
string uid = 2;

// (Required) Represents what the ML operator is
OperatorType type = 3;

// (Optional) parameters of the operator which is a json string
optional string params = 4;

enum OperatorType {
OPERATOR_TYPE_UNSPECIFIED = 0;
// ML estimator
OPERATOR_TYPE_ESTIMATOR = 1;
// ML transformer (non-model)
OPERATOR_TYPE_TRANSFORMER = 2;
// ML evaluator
OPERATOR_TYPE_EVALUATOR = 3;
// ML model
OPERATOR_TYPE_MODEL = 4;
}

}
10 changes: 4 additions & 6 deletions jvm/src/main/scala/com/nvidia/rapids/ml/Plugin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package com.nvidia.rapids.ml

import org.apache.spark.ml.rapids.RapidsUtils
import org.apache.spark.sql.connect.plugin.MLBackendPlugin

import java.util.Optional
Expand All @@ -26,12 +27,9 @@ import java.util.Optional
class Plugin extends MLBackendPlugin {

override def transform(mlName: String): Optional[String] = {
mlName match {
case "org.apache.spark.ml.classification.LogisticRegression" =>
Optional.of("com.nvidia.rapids.ml.RapidsLogisticRegression")
case "org.apache.spark.ml.classification.LogisticRegressionModel" =>
Optional.of("org.apache.spark.ml.rapids.RapidsLogisticRegressionModel")
case _ => Optional.empty()
RapidsUtils.transform(mlName) match {
case Some(v) => Optional.of(v)
case None => Optional.empty()
}
}
}
109 changes: 109 additions & 0 deletions jvm/src/main/scala/com/nvidia/rapids/ml/RapidsCrossValidator.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/**
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.rapids.ml

import org.apache.spark.ml.Estimator
import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator}
import org.apache.spark.ml.rapids.{Fit, PythonEstimatorRunner, RapidsUtils, TrainedModel}
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.connect.ml.rapids.RapidsConnectUtils

class RapidsCrossValidator(override val uid: String) extends CrossValidator with RapidsEstimator {

def this() = this(Identifiable.randomUID("cv"))

override def fit(dataset: Dataset[_]): CrossValidatorModel = {
val trainedModel = trainOnPython(dataset)

val bestModel = RapidsUtils.createModel(getName(getEstimator.getClass.getName),
getEstimator.uid, getEstimator, trainedModel)
copyValues(RapidsUtils.createCrossValidatorModel(this.uid, bestModel))
}

private def getName(name: String): String = {
RapidsUtils.transform(name).getOrElse(name)
}

/**
* The estimator name
*
* @return
*/
override def name: String = "CrossValidator"

override def trainOnPython(dataset: Dataset[_]): TrainedModel = {
logger.info(s"Training $name ...")

val estimatorName = getName(getEstimator.getClass.getName)
// TODO estimator could be a PipeLine which contains multiple stages.
val cvParams = RapidsUtils.getJson(Map(
"estimator" -> RapidsUtils.getUserDefinedParams(getEstimator,
extra = Map(
"estimator_name" -> estimatorName,
"uid" -> getEstimator.uid)),
"evaluator" -> RapidsUtils.getUserDefinedParams(getEvaluator,
extra = Map(
"evaluator_name" -> getName(getEvaluator.getClass.getName),
"uid" -> getEvaluator.uid)),
"estimatorParaMaps" -> RapidsUtils.getEstimatorParamMapsJson(getEstimatorParamMaps),
"cv" -> RapidsUtils.getUserDefinedParams(this,
List("estimator", "evaluator", "estimatorParamMaps"))
))
val runner = new PythonEstimatorRunner(
Fit(name, cvParams),
dataset.toDF)

val trainedModel = Arm.withResource(runner) { _ =>
runner.runInPython(useDaemon = false)
}

logger.info(s"Finished $name training.")
trainedModel
}
}

object RapidsCrossValidator {

def fit(cvProto: proto.CrossValidatorRelation, dataset: Dataset[_]): CrossValidatorModel = {

val estProto = cvProto.getEstimator
var estimator: Option[Estimator[_]] = None
if (estProto.getName == "LogisticRegression") {
estimator = Some(new RapidsLogisticRegression(uid = estProto.getUid))
val estParams = estProto.getParams
RapidsUtils.setParams(estimator.get, estParams)

}
val evalProto = cvProto.getEvaluator
var evaluator: Option[Evaluator] = None
if (evalProto.getName == "MulticlassClassificationEvaluator") {
evaluator = Some(new MulticlassClassificationEvaluator(uid = evalProto.getUid))
val evalParams = evalProto.getParams
RapidsUtils.setParams(evaluator.get, evalParams)
}

val cv = new RapidsCrossValidator(uid = cvProto.getUid)
RapidsUtils.setParams(cv, cvProto.getParams)

cv.setEstimator(estimator.get).setEvaluator(evaluator.get)
val paramGrid = RapidsUtils.extractParamMap(cv, cvProto.getEstimatorParamMaps)
cv.setEstimatorParamMaps(paramGrid)
cv.fit(dataset)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
package com.nvidia.rapids.ml

import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.rapids.RapidsLogisticRegressionModel
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.rapids.{RapidsLogisticRegressionModel, RapidsUtils}
import org.apache.spark.sql.Dataset

/**
Expand All @@ -36,9 +36,7 @@ class RapidsLogisticRegression(override val uid: String) extends LogisticRegress

override def train(dataset: Dataset[_]): RapidsLogisticRegressionModel = {
val trainedModel = trainOnPython(dataset)
val cpuModel = copyValues(trainedModel.model.asInstanceOf[LogisticRegressionModel])
val isMultinomial = cpuModel.numClasses != 2
copyValues(new RapidsLogisticRegressionModel(uid, cpuModel, trainedModel.modelAttributes, isMultinomial))
RapidsUtils.createModel(name, uid, this, trainedModel).asInstanceOf[RapidsLogisticRegressionModel]
}

/**
Expand Down
56 changes: 56 additions & 0 deletions jvm/src/main/scala/com/nvidia/rapids/ml/RapidsRelationPlugin.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/**
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.rapids.ml

import org.apache.commons.logging.LogFactory
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.Row
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.connect.plugin.RelationPlugin
import org.apache.spark.connect.{proto => sparkProto}
import org.apache.spark.sql.connect.ml.rapids.RapidsConnectUtils
import org.apache.spark.sql.types.{StringType, StructField, StructType}

import java.util.Optional
import scala.jdk.CollectionConverters.SeqHasAsJava

class RapidsRelationPlugin extends RelationPlugin {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be registered with spark connect via a spark conf? Would be helpful to add usage info?

protected val logger = LogFactory.getLog("Spark-Rapids-ML RapidsRelationPlugin")

override def transform(bytes: Array[Byte], sparkConnectPlanner: SparkConnectPlanner): Optional[LogicalPlan] = {
logger.info("In RapidsRelationPlugin")

val rel = com.google.protobuf.Any.parseFrom(bytes)
val sparkSession = sparkConnectPlanner.session

// CrossValidation
if (rel.is(classOf[proto.CrossValidatorRelation])) {
val cvProto = rel.unpack(classOf[proto.CrossValidatorRelation])
val dataLogicalPlan = sparkProto.Plan.parseFrom(cvProto.getDataset.toByteArray)
val dataset = RapidsConnectUtils.ofRows(sparkSession,
sparkConnectPlanner.transformRelation(dataLogicalPlan.getRoot))
val cvModel = RapidsCrossValidator.fit(cvProto, dataset)
val modelId = RapidsConnectUtils.cache(sparkConnectPlanner.sessionHolder, cvModel.bestModel)
val resultDf = sparkSession.createDataFrame(
List(Row(s"$modelId")).asJava,
StructType(Seq(StructField("best_model_id", StringType))))
Optional.of(RapidsConnectUtils.getLogicalPlan(resultDf))
} else {
Optional.empty()
}
}
}
Loading
Loading