Skip to content

Commit 5ec5020

Browse files
authored
Fix issue of 24.12 RC [skip ci] (#826)
Additional changes to fix the 24.12 RC issue, based on #822. NOTE: this PR must be merged using the `Create a merge commit` option.
2 parents 7c4686c + f9624be commit 5ec5020

File tree

2 files changed

+66
-18
lines changed

2 files changed

+66
-18
lines changed

python/src/spark_rapids_ml/umap.py

Lines changed: 27 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -790,8 +790,9 @@ class UMAP(UMAPClass, _CumlEstimatorSupervised, _UMAPCumlParams):
790790
791791
sample_fraction : float (optional, default=1.0)
792792
The fraction of the dataset to be used for fitting the model. Since fitting is done on a single node, very large
793-
datasets must be subsampled to fit within the node's memory and execute in a reasonable time. Smaller fractions
794-
will result in faster training, but may result in sub-optimal embeddings.
793+
datasets must be subsampled to fit within the node's memory. Smaller fractions will result in faster training, but
794+
may decrease embedding quality. Note: this is not guaranteed to provide exactly the fraction specified of the total
795+
count of the given DataFrame.
795796
796797
featuresCol: str or List[str]
797798
The feature column names, spark-rapids-ml supports vector, array and columnar as the input.\n
@@ -1463,22 +1464,30 @@ def write_sparse_array(array: scipy.sparse.spmatrix, df_dir: str) -> None:
14631464
schema=indices_data_schema,
14641465
)
14651466

1466-
indptr_df.write.parquet(
1467-
os.path.join(df_dir, "indptr.parquet"), mode="overwrite"
1468-
)
1469-
indices_data_df.write.parquet(
1470-
os.path.join(df_dir, "indices_data.parquet"), mode="overwrite"
1471-
)
1467+
indptr_df.write.parquet(os.path.join(df_dir, "indptr.parquet"))
1468+
indices_data_df.write.parquet(os.path.join(df_dir, "indices_data.parquet"))
14721469

14731470
def write_dense_array(array: np.ndarray, df_path: str) -> None:
1471+
assert (
1472+
spark.conf.get("spark.sql.execution.arrow.pyspark.enabled") == "true"
1473+
), "spark.sql.execution.arrow.pyspark.enabled must be set to true to persist array attributes"
1474+
14741475
schema = StructType(
14751476
[
1476-
StructField(f"_{i}", FloatType(), False)
1477-
for i in range(1, array.shape[1] + 1)
1477+
StructField("row_id", LongType(), False),
1478+
StructField("data", ArrayType(FloatType(), False), False),
14781479
]
14791480
)
1480-
data_df = spark.createDataFrame(pd.DataFrame(array), schema=schema)
1481-
data_df.write.parquet(df_path, mode="overwrite")
1481+
data_df = spark.createDataFrame(
1482+
pd.DataFrame(
1483+
{
1484+
"row_id": range(array.shape[0]),
1485+
"data": list(array),
1486+
}
1487+
),
1488+
schema=schema,
1489+
)
1490+
data_df.write.parquet(df_path)
14821491

14831492
DefaultParamsWriter.saveMetadata(
14841493
self.instance,
@@ -1491,12 +1500,12 @@ def write_dense_array(array: np.ndarray, df_path: str) -> None:
14911500
},
14921501
)
14931502

1503+
# get a copy, since we're going to modify the array attributes
14941504
model_attributes = self.instance._get_model_attributes()
14951505
assert model_attributes is not None
1506+
model_attributes = model_attributes.copy()
14961507

14971508
data_path = os.path.join(path, "data")
1498-
if not os.path.exists(data_path):
1499-
os.makedirs(data_path)
15001509

15011510
for key in ["embedding_", "raw_data_"]:
15021511
array = model_attributes[key]
@@ -1547,8 +1556,10 @@ def read_sparse_array(
15471556
return scipy.sparse.csr_matrix((data, indices, indptr), shape=csr_shape)
15481557

15491558
def read_dense_array(df_path: str) -> np.ndarray:
1550-
data_df = spark.read.parquet(df_path)
1551-
return np.array(data_df.collect(), dtype=np.float32)
1559+
data_df = spark.read.parquet(df_path).orderBy("row_id")
1560+
pdf = data_df.toPandas()
1561+
assert type(pdf) == pd.DataFrame
1562+
return np.array(list(pdf.data), dtype=np.float32)
15521563

15531564
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
15541565
data_path = os.path.join(path, "data")

python/tests/test_umap.py

Lines changed: 39 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
#
2-
# Copyright (c) 2024, NVIDIA CORPORATION.
2+
# Copyright (c) 2025, NVIDIA CORPORATION.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
@@ -18,7 +18,6 @@
1818
from typing import Any, Dict, List, Optional, Tuple, Union
1919

2020
import cupy as cp
21-
import cupyx
2221
import numpy as np
2322
import pytest
2423
import scipy
@@ -415,6 +414,9 @@ def test_umap_copy() -> None:
415414
def test_umap_model_persistence(
416415
sparse_fit: bool, gpu_number: int, tmp_path: str
417416
) -> None:
417+
import os
418+
import re
419+
418420
import pyspark
419421
from packaging import version
420422

@@ -459,7 +461,42 @@ def test_umap_model_persistence(
459461
path = tmp_path + "/umap_tests"
460462
model_path = f"{path}/umap_model"
461463
umap_model.write().overwrite().save(model_path)
464+
465+
try:
466+
umap_model.write().save(model_path)
467+
assert False, "Overwriting should not be permitted"
468+
except Exception as e:
469+
assert re.search(r"Output directory .* already exists", str(e))
470+
471+
# double check expected files/directories
472+
model_dir_contents = os.listdir(model_path)
473+
data_dir_contents = os.listdir(f"{model_path}/data")
474+
assert set(model_dir_contents) == {"data", "metadata"}
475+
if sparse_fit:
476+
assert set(data_dir_contents) == {
477+
"metadata.json",
478+
"embedding_.parquet",
479+
"raw_data_csr",
480+
}
481+
assert set(os.listdir(f"{model_path}/data/raw_data_csr")) == {
482+
"indptr.parquet",
483+
"indices_data.parquet",
484+
}
485+
else:
486+
assert set(data_dir_contents) == {
487+
"metadata.json",
488+
"embedding_.parquet",
489+
"raw_data_.parquet",
490+
}
491+
492+
# make sure we can overwrite
493+
umap_model._cuml_params["n_neighbors"] = 10
494+
umap_model._cuml_params["set_op_mix_ratio"] = 0.4
495+
umap_model.write().overwrite().save(model_path)
496+
462497
umap_model_loaded = UMAPModel.load(model_path)
498+
assert umap_model_loaded._cuml_params["n_neighbors"] == 10
499+
assert umap_model_loaded._cuml_params["set_op_mix_ratio"] == 0.4
463500
_assert_umap_model(umap_model_loaded, input_raw_data)
464501

465502

0 commit comments

Comments
 (0)