Skip to content

Commit 1788e9d

Browse files
authored
Fix NearestNeighbors _ensureIdCol to check whether id_col in df.columns instead of relying on isSet(idCol) (#642)
* fix ensureIdCol to avoid using isSet(idCol) * simplify the logic of ensureIdCol * try setting idCol to None --------- Signed-off-by: Jinfeng <[email protected]>
1 parent b244341 commit 1788e9d

File tree

3 files changed

+66
-34
lines changed

3 files changed

+66
-34
lines changed

python/src/spark_rapids_ml/knn.py

Lines changed: 51 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class _NearestNeighborsCumlParams(
8585

8686
def __init__(self) -> None:
8787
super().__init__()
88-
self._setDefault(idCol=alias.row_number)
88+
self._setDefault(idCol=None)
8989

9090
k = Param(
9191
Params._dummy(),
@@ -114,6 +114,16 @@ def getK(self: P) -> int:
114114
"""
115115
return self.getOrDefault("k")
116116

117+
def _getIdColOrDefault(self) -> str:
    """
    Return the id column name to use.

    Reads `idCol` through `getIdCol()`; when that value is None, the
    default name `alias.row_number` is returned instead.
    """
    configured = self.getIdCol()
    return alias.row_number if configured is None else configured
126+
117127
def setInputCol(self: P, value: Union[str, List[str]]) -> P:
118128
"""
119129
Sets the value of :py:attr:`inputCol` or :py:attr:`inputCols`.
@@ -142,19 +152,32 @@ def _ensureIdCol(self, df: DataFrame) -> DataFrame:
142152
Ensure an id column exists in the input dataframe. Add the column if not exists.
143153
Overwritten for knn assumption on error for not setting idCol and duplicate exists.
144154
"""
145-
if not self.isSet("idCol") and self.getIdCol() in df.columns:
146-
raise ValueError(
147-
f"Cannot create a default id column since a column with the default name '{self.getIdCol()}' already exists."
148-
+ "Please specify an id column"
149-
)
150155

151156
id_col_name = self.getIdCol()
152-
df_withid = (
153-
df
154-
if self.isSet("idCol")
155-
else df.select(monotonically_increasing_id().alias(id_col_name), "*")
156-
)
157-
return df_withid
157+
if id_col_name is None:
158+
if alias.row_number in df.columns:
159+
raise ValueError(
160+
f"Trying to create an id column with default name {alias.row_number}. But a column with the same name already exists."
161+
)
162+
else:
163+
get_logger(self.__class__).info(
164+
f"idCol not set. Spark Rapids ML will create one with default name {alias.row_number}."
165+
)
166+
df_withid = df.select(
167+
monotonically_increasing_id().alias(alias.row_number), "*"
168+
)
169+
return df_withid
170+
else:
171+
if id_col_name in df.columns:
172+
return df
173+
else:
174+
get_logger(self.__class__).info(
175+
f"column {id_col_name} does not exists in the input dataframe. Spark Rapids ML will create the {id_col_name} column."
176+
)
177+
df_withid = df.select(
178+
monotonically_increasing_id().alias(alias.row_number), "*"
179+
)
180+
return df_withid
158181

159182

160183
class NearestNeighbors(
@@ -179,7 +202,7 @@ class NearestNeighbors(
179202
* When the value is a string, the feature columns must be assembled into 1 column with vector or array type.
180203
* When the value is a list of strings, the feature columns must be numeric types.
181204
182-
idCol: str
205+
idCol: str (default = None)
183206
the name of the column in a dataframe that uniquely identifies each vector. idCol should be set
184207
if such a column exists in the dataframe. If idCol is not set, a column with the name `unique_id`
185208
will be automatically added to the dataframe and used as unique identifier for each vector.
@@ -400,7 +423,7 @@ def exactNearestNeighborsJoin(
400423
where item_vector v1 is one of the k nearest neighbors of query_vector v2 and their distance is dist(v1, v2).
401424
"""
402425

403-
id_col_name = self.getIdCol()
426+
id_col_name = self._getIdColOrDefault()
404427

405428
# call kneighbors then prepare return results
406429
(item_df_withid, query_df_withid, knn_df) = self.kneighbors(query_df)
@@ -471,7 +494,9 @@ def _out_schema(self) -> Union[StructType, str]: # type: ignore
471494
return StructType(
472495
[
473496
StructField(
474-
f"query_{self.getIdCol()}", ArrayType(LongType(), False), False
497+
f"query_{self._getIdColOrDefault()}",
498+
ArrayType(LongType(), False),
499+
False,
475500
),
476501
StructField(
477502
"indices", ArrayType(ArrayType(LongType(), False), False), False
@@ -509,11 +534,8 @@ def _pre_process_data( # type: ignore
509534

510535
select_cols.append(col(alias.label))
511536

512-
if self.hasParam("idCol") and self.isDefined("idCol"):
513-
id_col_name = self.getOrDefault("idCol")
514-
select_cols.append(col(id_col_name).alias(alias.row_number))
515-
else:
516-
select_cols.append(col(alias.row_number))
537+
id_col_name = self._getIdColOrDefault()
538+
select_cols.append(col(id_col_name).alias(alias.row_number))
517539

518540
return select_cols, multi_col_names, dimension, feature_type
519541

@@ -561,8 +583,8 @@ def kneighbors(self, query_df: DataFrame) -> Tuple[DataFrame, DataFrame, DataFra
561583
pipelinedrdd = self._call_cuml_fit_func(union_df, partially_collect=False)
562584
pipelinedrdd = pipelinedrdd.repartition(query_default_num_partitions) # type: ignore
563585

564-
query_id_col_name = f"query_{self.getIdCol()}"
565-
id_col_type = dict(union_df.dtypes)[self.getIdCol()]
586+
query_id_col_name = f"query_{self._getIdColOrDefault()}"
587+
id_col_type = dict(union_df.dtypes)[self._getIdColOrDefault()]
566588
knn_rdd = pipelinedrdd.flatMap(
567589
lambda row: list(
568590
zip(row[query_id_col_name], row["indices"], row["distances"])
@@ -584,7 +606,7 @@ def _get_cuml_fit_func(
584606
]:
585607
label_isdata = self._label_isdata
586608
label_isquery = self._label_isquery
587-
id_col_name = self.getIdCol()
609+
id_col_name = self._getIdColOrDefault()
588610

589611
def _cuml_fit(
590612
dfs: FitInputType,
@@ -849,7 +871,7 @@ class ApproximateNearestNeighbors(
849871
* When the value is a string, the feature columns must be assembled into 1 column with vector or array type.
850872
* When the value is a list of strings, the feature columns must be numeric types.
851873
852-
idCol: str
874+
idCol: str (default = None)
853875
the name of the column in a dataframe that uniquely identifies each vector. idCol should be set
854876
if such a column exists in the dataframe. If idCol is not set, a column with the name `unique_id`
855877
will be automatically added to the dataframe and used as unique identifier for each vector.
@@ -1037,9 +1059,7 @@ def __init__(
10371059
self.bcast_qfeatures: Optional[Broadcast] = None
10381060

10391061
def _out_schema(self) -> Union[StructType, str]: # type: ignore
1040-
return (
1041-
f"query_{self.getIdCol()} long, indices array<long>, distances array<float>"
1042-
)
1062+
return f"query_{self._getIdColOrDefault()} long, indices array<long>, distances array<float>"
10431063

10441064
def _pre_process_data(
10451065
self, dataset: DataFrame
@@ -1049,9 +1069,8 @@ def _pre_process_data(
10491069
dataset
10501070
)
10511071

1052-
if self.hasParam("idCol") and self.isDefined("idCol"):
1053-
id_col_name = self.getOrDefault("idCol")
1054-
dataset = dataset.withColumnRenamed(id_col_name, alias.row_number)
1072+
id_col_name = self._getIdColOrDefault()
1073+
dataset = dataset.withColumnRenamed(id_col_name, alias.row_number)
10551074

10561075
select_cols.append(alias.row_number)
10571076

@@ -1179,7 +1198,7 @@ def kneighbors(self, query_df: DataFrame) -> Tuple[DataFrame, DataFrame, DataFra
11791198
)
11801199
k = self.getK()
11811200

1182-
query_id_col_name = f"query_{self.getIdCol()}"
1201+
query_id_col_name = f"query_{self._getIdColOrDefault()}"
11831202

11841203
ascending = False if self.getMetric() == "inner_product" else True
11851204

@@ -1221,7 +1240,7 @@ def _construct_sgnn() -> CumlT:
12211240
row_number_col = alias.row_number
12221241
input_col, input_cols = self._get_input_columns()
12231242
assert input_col is not None or input_cols is not None
1224-
id_col_name = self.getIdCol()
1243+
id_col_name = self._getIdColOrDefault()
12251244

12261245
bcast_qids = self.bcast_qids
12271246
bcast_qfeatures = self.bcast_qfeatures

python/tests/test_approximate_nearest_neighbors.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,9 @@ def cal_avg_dist_gap(distances_ann: np.ndarray) -> float:
229229

230230
ascending = False if metric == "inner_product" else True
231231
reconstructed_knn_df = reconstruct_knn_df(
232-
knnjoin_df, row_identifier_col=knn_model.getIdCol(), ascending=ascending
232+
knnjoin_df,
233+
row_identifier_col=knn_model._getIdColOrDefault(),
234+
ascending=ascending,
233235
)
234236
reconstructed_collect = reconstructed_knn_df.collect()
235237

python/tests/test_nearest_neighbors.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,17 @@ def assert_knn_metadata_equal(knn_metadata: List[List[str]]) -> None:
234234
assert knnjoin_queries[i]["features"] == query[i][0]
235235
assert knnjoin_queries[i]["metadata"] == query[i][1]
236236

237+
# Test fit(dataset, ParamMap) that copies existing estimator
238+
# After copy, self.isSet("idCol") becomes true. But the added id column does not exist in the dataframe
239+
paramMap = gpu_knn.extractParamMap()
240+
gpu_model_v2 = gpu_knn.fit(data_df, paramMap)
241+
242+
assert gpu_knn.isSet("idCol") is False
243+
assert gpu_model_v2.isSet("idCol") is True
244+
245+
(_, _, knn_df_v2) = gpu_model_v2.kneighbors(query_df)
246+
assert knn_df_v2.collect() == knn_df.collect()
247+
237248
return gpu_knn, gpu_model
238249

239250

@@ -432,7 +443,7 @@ def test_nearest_neighbors(
432443
knn_model.setIdCol(item_df_withid.dtypes[0][0])
433444
knnjoin_df = knn_model.exactNearestNeighborsJoin(query_df_withid)
434445
reconstructed_knn_df = reconstruct_knn_df(
435-
knnjoin_df, row_identifier_col=knn_model.getIdCol()
446+
knnjoin_df, row_identifier_col=knn_model._getIdColOrDefault()
436447
)
437448
assert reconstructed_knn_df.collect() == knn_df.collect()
438449

0 commit comments

Comments
 (0)