@@ -85,7 +85,7 @@ class _NearestNeighborsCumlParams(
8585
8686 def __init__ (self ) -> None :
8787 super ().__init__ ()
88- self ._setDefault (idCol = alias . row_number )
88+ self ._setDefault (idCol = None )
8989
9090 k = Param (
9191 Params ._dummy (),
@@ -114,6 +114,16 @@ def getK(self: P) -> int:
114114 """
115115 return self .getOrDefault ("k" )
116116
117+ def _getIdColOrDefault (self ) -> str :
118+ """
119+ Gets the value of `idCol`.
120+ """
121+
122+ res = self .getIdCol ()
123+ if res is None :
124+ res = alias .row_number
125+ return res
126+
117127 def setInputCol (self : P , value : Union [str , List [str ]]) -> P :
118128 """
119129 Sets the value of :py:attr:`inputCol` or :py:attr:`inputCols`.
@@ -142,19 +152,32 @@ def _ensureIdCol(self, df: DataFrame) -> DataFrame:
142152 Ensure an id column exists in the input dataframe. Add the column if not exists.
143153 Overwritten for knn assumption on error for not setting idCol and duplicate exists.
144154 """
145- if not self .isSet ("idCol" ) and self .getIdCol () in df .columns :
146- raise ValueError (
147- f"Cannot create a default id column since a column with the default name '{ self .getIdCol ()} ' already exists."
148- + "Please specify an id column"
149- )
150155
151156 id_col_name = self .getIdCol ()
152- df_withid = (
153- df
154- if self .isSet ("idCol" )
155- else df .select (monotonically_increasing_id ().alias (id_col_name ), "*" )
156- )
157- return df_withid
157+ if id_col_name is None :
158+ if alias .row_number in df .columns :
159+ raise ValueError (
160+ f"Trying to create an id column with default name { alias .row_number } . But a column with the same name already exists."
161+ )
162+ else :
163+ get_logger (self .__class__ ).info (
164+ f"idCol not set. Spark Rapids ML will create one with default name { alias .row_number } ."
165+ )
166+ df_withid = df .select (
167+ monotonically_increasing_id ().alias (alias .row_number ), "*"
168+ )
169+ return df_withid
170+ else :
171+ if id_col_name in df .columns :
172+ return df
173+ else :
174+ get_logger (self .__class__ ).info (
175+ f"column { id_col_name } does not exists in the input dataframe. Spark Rapids ML will create the { id_col_name } column."
176+ )
177+ df_withid = df .select (
178+ monotonically_increasing_id ().alias (alias .row_number ), "*"
179+ )
180+ return df_withid
158181
159182
160183class NearestNeighbors (
@@ -179,7 +202,7 @@ class NearestNeighbors(
179202 * When the value is a string, the feature columns must be assembled into 1 column with vector or array type.
180203 * When the value is a list of strings, the feature columns must be numeric types.
181204
182- idCol: str
205+ idCol: str (default = None)
183206 the name of the column in a dataframe that uniquely identifies each vector. idCol should be set
184207 if such a column exists in the dataframe. If idCol is not set, a column with the name `unique_id`
185208 will be automatically added to the dataframe and used as unique identifier for each vector.
@@ -400,7 +423,7 @@ def exactNearestNeighborsJoin(
400423 where item_vector v1 is one of the k nearest neighbors of query_vector v2 and their distance is dist(v1, v2).
401424 """
402425
403- id_col_name = self .getIdCol ()
426+ id_col_name = self ._getIdColOrDefault ()
404427
405428 # call kneighbors then prepare return results
406429 (item_df_withid , query_df_withid , knn_df ) = self .kneighbors (query_df )
@@ -471,7 +494,9 @@ def _out_schema(self) -> Union[StructType, str]: # type: ignore
471494 return StructType (
472495 [
473496 StructField (
474- f"query_{ self .getIdCol ()} " , ArrayType (LongType (), False ), False
497+ f"query_{ self ._getIdColOrDefault ()} " ,
498+ ArrayType (LongType (), False ),
499+ False ,
475500 ),
476501 StructField (
477502 "indices" , ArrayType (ArrayType (LongType (), False ), False ), False
@@ -509,11 +534,8 @@ def _pre_process_data( # type: ignore
509534
510535 select_cols .append (col (alias .label ))
511536
512- if self .hasParam ("idCol" ) and self .isDefined ("idCol" ):
513- id_col_name = self .getOrDefault ("idCol" )
514- select_cols .append (col (id_col_name ).alias (alias .row_number ))
515- else :
516- select_cols .append (col (alias .row_number ))
537+ id_col_name = self ._getIdColOrDefault ()
538+ select_cols .append (col (id_col_name ).alias (alias .row_number ))
517539
518540 return select_cols , multi_col_names , dimension , feature_type
519541
@@ -561,8 +583,8 @@ def kneighbors(self, query_df: DataFrame) -> Tuple[DataFrame, DataFrame, DataFra
561583 pipelinedrdd = self ._call_cuml_fit_func (union_df , partially_collect = False )
562584 pipelinedrdd = pipelinedrdd .repartition (query_default_num_partitions ) # type: ignore
563585
564- query_id_col_name = f"query_{ self .getIdCol ()} "
565- id_col_type = dict (union_df .dtypes )[self .getIdCol ()]
586+ query_id_col_name = f"query_{ self ._getIdColOrDefault ()} "
587+ id_col_type = dict (union_df .dtypes )[self ._getIdColOrDefault ()]
566588 knn_rdd = pipelinedrdd .flatMap (
567589 lambda row : list (
568590 zip (row [query_id_col_name ], row ["indices" ], row ["distances" ])
@@ -584,7 +606,7 @@ def _get_cuml_fit_func(
584606 ]:
585607 label_isdata = self ._label_isdata
586608 label_isquery = self ._label_isquery
587- id_col_name = self .getIdCol ()
609+ id_col_name = self ._getIdColOrDefault ()
588610
589611 def _cuml_fit (
590612 dfs : FitInputType ,
@@ -849,7 +871,7 @@ class ApproximateNearestNeighbors(
849871 * When the value is a string, the feature columns must be assembled into 1 column with vector or array type.
850872 * When the value is a list of strings, the feature columns must be numeric types.
851873
852- idCol: str
874+ idCol: str (default = None)
853875 the name of the column in a dataframe that uniquely identifies each vector. idCol should be set
854876 if such a column exists in the dataframe. If idCol is not set, a column with the name `unique_id`
855877 will be automatically added to the dataframe and used as unique identifier for each vector.
@@ -1037,9 +1059,7 @@ def __init__(
10371059 self .bcast_qfeatures : Optional [Broadcast ] = None
10381060
10391061 def _out_schema (self ) -> Union [StructType , str ]: # type: ignore
1040- return (
1041- f"query_{ self .getIdCol ()} long, indices array<long>, distances array<float>"
1042- )
1062+ return f"query_{ self ._getIdColOrDefault ()} long, indices array<long>, distances array<float>"
10431063
10441064 def _pre_process_data (
10451065 self , dataset : DataFrame
@@ -1049,9 +1069,8 @@ def _pre_process_data(
10491069 dataset
10501070 )
10511071
1052- if self .hasParam ("idCol" ) and self .isDefined ("idCol" ):
1053- id_col_name = self .getOrDefault ("idCol" )
1054- dataset = dataset .withColumnRenamed (id_col_name , alias .row_number )
1072+ id_col_name = self ._getIdColOrDefault ()
1073+ dataset = dataset .withColumnRenamed (id_col_name , alias .row_number )
10551074
10561075 select_cols .append (alias .row_number )
10571076
@@ -1179,7 +1198,7 @@ def kneighbors(self, query_df: DataFrame) -> Tuple[DataFrame, DataFrame, DataFra
11791198 )
11801199 k = self .getK ()
11811200
1182- query_id_col_name = f"query_{ self .getIdCol ()} "
1201+ query_id_col_name = f"query_{ self ._getIdColOrDefault ()} "
11831202
11841203 ascending = False if self .getMetric () == "inner_product" else True
11851204
@@ -1221,7 +1240,7 @@ def _construct_sgnn() -> CumlT:
12211240 row_number_col = alias .row_number
12221241 input_col , input_cols = self ._get_input_columns ()
12231242 assert input_col is not None or input_cols is not None
1224- id_col_name = self .getIdCol ()
1243+ id_col_name = self ._getIdColOrDefault ()
12251244
12261245 bcast_qids = self .bcast_qids
12271246 bcast_qfeatures = self .bcast_qfeatures
0 commit comments