-
Notifications
You must be signed in to change notification settings - Fork 31
Support CrossValidator on connect plugin #907
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: branch-25.04
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| #! /bin/bash | ||
|
|
||
| pushd src/main/ | ||
| buf generate --debug | ||
| popd |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| # | ||
| # Licensed to the Apache Software Foundation (ASF) under one or more | ||
| # contributor license agreements. See the NOTICE file distributed with | ||
| # this work for additional information regarding copyright ownership. | ||
| # The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| # (the "License"); you may not use this file except in compliance with | ||
| # the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # | ||
| version: v1 | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How is this version used and does it have to align with the jvm side? |
||
| plugins: | ||
| # Building the Python build and building the mypy interfaces. | ||
| - plugin: buf.build/protocolbuffers/python:v28.3 | ||
| out: ../../../python/src/spark_rapids_ml/proto | ||
| - plugin: buf.build/grpc/python:v1.67.0 | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do these versions (v28.3 and v1.67.0) need to match exactly the java side in the pom.xml file ? |
||
| out: ../../../python/src/spark_rapids_ml/proto | ||
| - name: mypy | ||
| out: ../../../python/src/spark_rapids_ml/proto | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| # | ||
| # Licensed to the Apache Software Foundation (ASF) under one or more | ||
| # contributor license agreements. See the NOTICE file distributed with | ||
| # this work for additional information regarding copyright ownership. | ||
| # The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| # (the "License"); you may not use this file except in compliance with | ||
| # the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # | ||
| version: v1 | ||
| directories: | ||
| - protobuf |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| syntax = 'proto3'; | ||
|
|
||
| // Must set the package into spark.connect if importing spark/connect/relations.proto | ||
| // package spark.connect; | ||
| package com.nvidia.rapids.ml.proto; | ||
|
|
||
| option java_multiple_files = true; | ||
| option java_package = "com.nvidia.rapids.ml.proto"; | ||
| option java_generate_equals_and_hash = true; | ||
|
|
||
| message TuningRelation { | ||
| oneof relation_type { | ||
| CrossValidatorRelation cv = 1; | ||
| } | ||
| } | ||
|
|
||
| message CrossValidatorRelation { | ||
| // (Required) Unique id of the ML operator | ||
| string uid = 1; | ||
| // (Required) the estimator info | ||
| MlOperator estimator = 2; | ||
| // (Required) the estimator parameter maps info | ||
| string estimator_param_maps = 3; | ||
| // (Required) the evaluator info | ||
| MlOperator evaluator = 4; | ||
| // parameters of CrossValidator | ||
| optional string params = 5; | ||
| // Can't use Relation directly due to shading issue in spark connect | ||
| optional bytes dataset = 6; | ||
| } | ||
|
|
||
| // MLOperator represents the ML operators like (Estimator, Transformer or Evaluator) | ||
| message MlOperator { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will this be pulled in from Spark? |
||
| // (Required) The qualified name of the ML operator. | ||
| string name = 1; | ||
|
|
||
| // (Required) Unique id of the ML operator | ||
| string uid = 2; | ||
|
|
||
| // (Required) Represents what the ML operator is | ||
| OperatorType type = 3; | ||
|
|
||
| // (Optional) parameters of the operator which is a json string | ||
| optional string params = 4; | ||
|
|
||
| enum OperatorType { | ||
| OPERATOR_TYPE_UNSPECIFIED = 0; | ||
| // ML estimator | ||
| OPERATOR_TYPE_ESTIMATOR = 1; | ||
| // ML transformer (non-model) | ||
| OPERATOR_TYPE_TRANSFORMER = 2; | ||
| // ML evaluator | ||
| OPERATOR_TYPE_EVALUATOR = 3; | ||
| // ML model | ||
| OPERATOR_TYPE_MODEL = 4; | ||
| } | ||
|
|
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| /** | ||
| * Copyright (c) 2025, NVIDIA CORPORATION. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package com.nvidia.rapids.ml | ||
|
|
||
| import org.apache.spark.ml.Estimator | ||
| import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator} | ||
| import org.apache.spark.ml.rapids.{Fit, PythonEstimatorRunner, RapidsUtils, TrainedModel} | ||
| import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel} | ||
| import org.apache.spark.ml.util.Identifiable | ||
| import org.apache.spark.sql.Dataset | ||
| import org.apache.spark.sql.connect.ml.rapids.RapidsConnectUtils | ||
|
|
||
| class RapidsCrossValidator(override val uid: String) extends CrossValidator with RapidsEstimator { | ||
|
|
||
| def this() = this(Identifiable.randomUID("cv")) | ||
|
|
||
| override def fit(dataset: Dataset[_]): CrossValidatorModel = { | ||
| val trainedModel = trainOnPython(dataset) | ||
|
|
||
| val bestModel = RapidsUtils.createModel(getName(getEstimator.getClass.getName), | ||
| getEstimator.uid, getEstimator, trainedModel) | ||
| copyValues(RapidsUtils.createCrossValidatorModel(this.uid, bestModel)) | ||
| } | ||
|
|
||
| private def getName(name: String): String = { | ||
| RapidsUtils.transform(name).getOrElse(name) | ||
| } | ||
|
|
||
| /** | ||
| * The estimator name | ||
| * | ||
| * @return | ||
| */ | ||
| override def name: String = "CrossValidator" | ||
|
|
||
| override def trainOnPython(dataset: Dataset[_]): TrainedModel = { | ||
| logger.info(s"Training $name ...") | ||
|
|
||
| val estimatorName = getName(getEstimator.getClass.getName) | ||
| // TODO estimator could be a PipeLine which contains multiple stages. | ||
| val cvParams = RapidsUtils.getJson(Map( | ||
| "estimator" -> RapidsUtils.getUserDefinedParams(getEstimator, | ||
| extra = Map( | ||
| "estimator_name" -> estimatorName, | ||
| "uid" -> getEstimator.uid)), | ||
| "evaluator" -> RapidsUtils.getUserDefinedParams(getEvaluator, | ||
| extra = Map( | ||
| "evaluator_name" -> getName(getEvaluator.getClass.getName), | ||
| "uid" -> getEvaluator.uid)), | ||
| "estimatorParaMaps" -> RapidsUtils.getEstimatorParamMapsJson(getEstimatorParamMaps), | ||
| "cv" -> RapidsUtils.getUserDefinedParams(this, | ||
| List("estimator", "evaluator", "estimatorParamMaps")) | ||
| )) | ||
| val runner = new PythonEstimatorRunner( | ||
| Fit(name, cvParams), | ||
| dataset.toDF) | ||
|
|
||
| val trainedModel = Arm.withResource(runner) { _ => | ||
| runner.runInPython(useDaemon = false) | ||
| } | ||
|
|
||
| logger.info(s"Finished $name training.") | ||
| trainedModel | ||
| } | ||
| } | ||
|
|
||
| object RapidsCrossValidator { | ||
|
|
||
| def fit(cvProto: proto.CrossValidatorRelation, dataset: Dataset[_]): CrossValidatorModel = { | ||
|
|
||
| val estProto = cvProto.getEstimator | ||
| var estimator: Option[Estimator[_]] = None | ||
| if (estProto.getName == "LogisticRegression") { | ||
| estimator = Some(new RapidsLogisticRegression(uid = estProto.getUid)) | ||
| val estParams = estProto.getParams | ||
| RapidsUtils.setParams(estimator.get, estParams) | ||
|
|
||
| } | ||
| val evalProto = cvProto.getEvaluator | ||
| var evaluator: Option[Evaluator] = None | ||
| if (evalProto.getName == "MulticlassClassificationEvaluator") { | ||
| evaluator = Some(new MulticlassClassificationEvaluator(uid = evalProto.getUid)) | ||
| val evalParams = evalProto.getParams | ||
| RapidsUtils.setParams(evaluator.get, evalParams) | ||
| } | ||
|
|
||
| val cv = new RapidsCrossValidator(uid = cvProto.getUid) | ||
| RapidsUtils.setParams(cv, cvProto.getParams) | ||
|
|
||
| cv.setEstimator(estimator.get).setEvaluator(evaluator.get) | ||
| val paramGrid = RapidsUtils.extractParamMap(cv, cvProto.getEstimatorParamMaps) | ||
| cv.setEstimatorParamMaps(paramGrid) | ||
| cv.fit(dataset) | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| /** | ||
| * Copyright (c) 2025, NVIDIA CORPORATION. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package com.nvidia.rapids.ml | ||
|
|
||
| import org.apache.commons.logging.LogFactory | ||
| import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan | ||
| import org.apache.spark.sql.Row | ||
| import org.apache.spark.sql.connect.planner.SparkConnectPlanner | ||
| import org.apache.spark.sql.connect.plugin.RelationPlugin | ||
| import org.apache.spark.connect.{proto => sparkProto} | ||
| import org.apache.spark.sql.connect.ml.rapids.RapidsConnectUtils | ||
| import org.apache.spark.sql.types.{StringType, StructField, StructType} | ||
|
|
||
| import java.util.Optional | ||
| import scala.jdk.CollectionConverters.SeqHasAsJava | ||
|
|
||
| class RapidsRelationPlugin extends RelationPlugin { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this need to be registered with spark connect via a spark conf? Would be helpful to add usage info? |
||
| protected val logger = LogFactory.getLog("Spark-Rapids-ML RapidsRelationPlugin") | ||
|
|
||
| override def transform(bytes: Array[Byte], sparkConnectPlanner: SparkConnectPlanner): Optional[LogicalPlan] = { | ||
| logger.info("In RapidsRelationPlugin") | ||
|
|
||
| val rel = com.google.protobuf.Any.parseFrom(bytes) | ||
| val sparkSession = sparkConnectPlanner.session | ||
|
|
||
| // CrossValidation | ||
| if (rel.is(classOf[proto.CrossValidatorRelation])) { | ||
| val cvProto = rel.unpack(classOf[proto.CrossValidatorRelation]) | ||
| val dataLogicalPlan = sparkProto.Plan.parseFrom(cvProto.getDataset.toByteArray) | ||
| val dataset = RapidsConnectUtils.ofRows(sparkSession, | ||
| sparkConnectPlanner.transformRelation(dataLogicalPlan.getRoot)) | ||
| val cvModel = RapidsCrossValidator.fit(cvProto, dataset) | ||
| val modelId = RapidsConnectUtils.cache(sparkConnectPlanner.sessionHolder, cvModel.bestModel) | ||
| val resultDf = sparkSession.createDataFrame( | ||
| List(Row(s"$modelId")).asJava, | ||
| StructType(Seq(StructField("best_model_id", StringType)))) | ||
| Optional.of(RapidsConnectUtils.getLogicalPlan(resultDf)) | ||
| } else { | ||
| Optional.empty() | ||
| } | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What runs this script?