
Commit 4685855

address a known mismatch with spark mllib linearregression when standardization is enabled. (#991)
- implement standardization in linear regression using cupy based data modification in a way that matches Spark (this greatly reduces the mismatch at the coefficient/intercept level to < 10% on a small existing unit test example - TBD to bridge the gap further)
- patches an existing bug for gpu optimized fitMultiple when changed parameters included standardization for LogisticRegression and LinearRegression (as this is somewhat of a corner case, the approach taken is to fall back to mllib fitMultiple, but other more optimized approaches are possible in the future)

Signed-off-by: Erik Ordentlich <[email protected]>
1 parent 2d38a33 commit 4685855
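For background on the approach, the fit-on-standardized-data / back-transform idea behind the linear regression change can be sketched with plain NumPy (illustrative variable names, ordinary least squares only; this is not the library code):

# Minimal NumPy sketch (illustrative only) of fitting on standardized data and
# mapping the solution back to the original scale, mirroring the back-transform
# in the regression.py change below.
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3)) * np.array([1.0, 5.0, 0.5])
y = X @ np.array([2.0, -1.0, 3.0]) + 4.0 + rng.normal(scale=0.1, size=200)

mean_x, std_x = X.mean(axis=0), X.std(axis=0, ddof=1)
mean_y, std_y = y.mean(), y.std(ddof=1)

Xs = (X - mean_x) / std_x  # standardized features
ys = (y - mean_y) / std_y  # standardized label

# least squares fit on the standardized data (column of ones for the intercept)
A = np.hstack([Xs, np.ones((Xs.shape[0], 1))])
sol, *_ = np.linalg.lstsq(A, ys, rcond=None)
coef_std, intercept_std = sol[:-1], sol[-1]

# back-transform the solution to the original scale
coef = coef_std * std_y / std_x
intercept = intercept_std * std_y - coef @ mean_x + mean_y
print(coef, intercept)  # approximately [2, -1, 3] and 4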

9 files changed: +439 -165 lines changed


python/benchmark/benchmark/bench_linear_regression.py

Lines changed: 2 additions & 2 deletions
@@ -114,8 +114,8 @@ def run_once(
     )
 
     # note: results for spark ML and spark rapids ml will currently match in all regularization
-    # cases only if features and labels were standardized in the original dataset. Otherwise,
-    # they will match only if regParam = 0 or elastNetParam = 1.0 (aka Lasso)
+    # cases only if features and labels were standardized in the original dataset or if standardization is enabled.
+    # Otherwise, they will match only if regParam = 0 or elasticNetParam = 1.0 (aka Lasso)
     print(
         f"RMSE: {rmse}, coefs l1: {coefs_l1}, coefs l2^2: {coefs_l2}, "
         f"full_objective: {full_objective}, intercept: {model.intercept}"

python/src/spark_rapids_ml/classification.py

Lines changed: 8 additions & 47 deletions
@@ -1018,53 +1018,14 @@ def _logistic_regression_fit(
     # Use cupy to standardize dataset as a workaround to gain better numeric stability
     standarization_with_cupy = standardization and not is_sparse
     if standarization_with_cupy is True:
-        import cupy as cp
-
-        if isinstance(concated, np.ndarray):
-            concated = cp.array(concated)
-        elif isinstance(concated, pd.DataFrame):
-            concated = cp.array(concated.values)
-        else:
-            assert isinstance(
-                concated, cp.ndarray
-            ), "only numpy array, cupy array, and pandas dataframe are supported when standardization_with_cupy is on"
-
-        mean_partial = concated.sum(axis=0) / pdesc.m
-
-        import json
-
-        from pyspark import BarrierTaskContext
-
-        context = BarrierTaskContext.get()
-
-        def all_gather_then_sum(
-            cp_array: cp.ndarray, dtype: Union[np.float32, np.float64]
-        ) -> cp.ndarray:
-            msgs = context.allGather(json.dumps(cp_array.tolist()))
-            arrays = [json.loads(p) for p in msgs]
-            array_sum = np.sum(arrays, axis=0).astype(dtype)
-            return cp.array(array_sum)
-
-        mean = all_gather_then_sum(mean_partial, concated.dtype)
-        concated -= mean
-
-        l2 = cp.linalg.norm(concated, ord=2, axis=0)
-
-        var_partial = l2 * l2 / (pdesc.m - 1)
-        var = all_gather_then_sum(var_partial, concated.dtype)
-
-        assert cp.all(
-            var >= 0
-        ), "numeric instable detected when calculating variance. Got negative variance"
-
-        stddev = cp.sqrt(var)
-
-        stddev_inv = cp.where(stddev != 0, 1.0 / stddev, 1.0)
-
-        if fit_intercept is False:
-            concated += mean
-
-        concated *= stddev_inv
+        from .utils import _standardize_dataset
+
+        # TODO: fix for multiple param sweep that change standardization and/or fit intercept (unlikely scenario) since
+        # data modification effects all params. currently not invoked in these cases by fitMultiple (see fitMultiple)
+        _tmp_data = [(concated, None, None)]
+        # this will modify concated in place through _tmp_data
+        mean, stddev = _standardize_dataset(_tmp_data, pdesc, fit_intercept)
+        concated = _tmp_data[0][0]
 
     def _single_fit(init_parameters: Dict[str, Any]) -> Dict[str, Any]:
         if standarization_with_cupy is True:

python/src/spark_rapids_ml/core.py

Lines changed: 17 additions & 8 deletions
@@ -76,12 +76,15 @@
 from .metrics import EvalMetricInfo
 from .params import _CumlParams
 from .utils import (
+    FitInputType,
     _ArrayOrder,
     _configure_memory_resource,
     _get_gpu_id,
     _get_spark_session,
     _is_local,
     _is_standalone_or_localcluster,
+    _SingleNpArrayBatchType,
+    _SinglePdDataFrameBatchType,
     dtype_to_pyspark_type,
     get_logger,
 )
@@ -95,14 +98,6 @@
 
 _CumlParamMap = Dict[str, Any]
 
-_SinglePdDataFrameBatchType = Tuple[
-    pd.DataFrame, Optional[pd.DataFrame], Optional[pd.DataFrame]
-]
-_SingleNpArrayBatchType = Tuple[np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]
-
-# FitInputType is type of [(feature, label), ...]
-FitInputType = Union[List[_SinglePdDataFrameBatchType], List[_SingleNpArrayBatchType]]
-
 # TransformInput type
 TransformInputType = Union["cudf.DataFrame", np.ndarray]
 
@@ -1170,13 +1165,27 @@ def fitMultiple(
     using `paramMaps[index]`. `index` values may not be sequential.
     """
 
+    logger = get_logger(self.__class__)
+
     if self._use_cpu_fallback():
         return super().fitMultiple(dataset, paramMaps)
 
     if self._enable_fit_multiple_in_single_pass():
        for paramMap in paramMaps:
            if self._use_cpu_fallback(paramMap):
                return super().fitMultiple(dataset, paramMaps)
+           # standardization and fitIntercept currently may modify the dataset and is done once outside the param loop.
+           # If either appears in a param map, fall back to regular multiple pass fitMultiple.
+           # TODO: sparse logistic regression does not modify data so ok in that case. Need logic to check dataset to detect that case.
+           # TODO: implement single pass with either of these by processing param maps with no
+           # standardization or fitIntercept before those with standardization or fitIntercept.
+           param_names = [p.name for p in paramMap.keys()]
+           for unsupported in ["standardization", "fitIntercept"]:
+               if unsupported in param_names:
+                   logger.warning(
+                       f"{unsupported} in param maps not supported for one pass GPU fitMultiple. Falling back to baseline fitMultiple."
+                   )
+                   return super().fitMultiple(dataset, paramMaps)
 
     # reach here if no cpu fallback
     estimator = self.copy()
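To illustrate the fallback condition added above, a minimal stand-in (plain dict keys play the role of the .name values of pyspark Param objects; the param maps below are hypothetical):

# Hypothetical param-map sweep: any map touching standardization or fitIntercept
# triggers the fall back to the baseline multi-pass fitMultiple.
param_maps = [
    {"regParam": 0.0},
    {"regParam": 0.1, "standardization": False},  # triggers the fallback
]
unsupported = ("standardization", "fitIntercept")
needs_fallback = any(
    name in unsupported for param_map in param_maps for name in param_map
)
print(needs_fallback)  # True -> use the baseline fitMultiple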

python/src/spark_rapids_ml/regression.py

Lines changed: 51 additions & 11 deletions
@@ -191,7 +191,7 @@ def _param_mapping(cls) -> Dict[str, Optional[str]]:
        "maxIter": "max_iter",
        "regParam": "alpha",
        "solver": "solver",
-       "standardization": "normalize",
+       "standardization": "normalize",  # TODO: standardization is carried out in cupy not cuml so need a new type of param mapped value to indicate that.
        "tol": "tol",
        "weightCol": None,
    }
@@ -309,9 +309,9 @@ class LinearRegression(
 
    Notes
    -----
-   Results for spark ML and spark rapids ml fit() will currently match in all regularization
-   cases only if features and labels are standardized in the input dataframe. Otherwise,
-   they will match only if regParam = 0 or elastNetParam = 1.0 (aka Lasso).
+   Results for spark ML and spark rapids ml fit() will currently be close in all regularization
+   cases only if features and labels are standardized in the input dataframe or when standardization is enabled. Otherwise,
+   they will be close only if regParam = 0 or elasticNetParam = 1.0 (aka Lasso).
 
    Parameters
    ----------
@@ -513,6 +513,10 @@ def _get_cuml_fit_func(
     [FitInputType, Dict[str, Any]],
     Dict[str, Any],
 ]:
+
+    standardization = self.getStandardization()
+    fit_intercept = self.getFitIntercept()
+
     def _linear_regression_fit(
         dfs: FitInputType,
         params: Dict[str, Any],
@@ -522,6 +526,20 @@ def _linear_regression_fit(
             params[param_alias.part_sizes], params[param_alias.num_cols]
         )
 
+        pdesc_labels = PartitionDescriptor.build(params[param_alias.part_sizes], 1)
+
+        if standardization:
+            from .utils import _standardize_dataset
+
+            # this modifies dfs in place by copying to gpu and standardazing in place on gpu
+            # TODO: fix for multiple param sweep that change standardization and/or fit intercept (unlikely scenario) since
+            # data modification effects all params. currently not invoked in these cases by fitMultiple (see fitMultiple)
+            mean, stddev = _standardize_dataset(dfs, pdesc, fit_intercept)
+            stddev_label = stddev[-1]
+            stddev_features = stddev[:-1]
+            mean_label = mean[-1]
+            mean_features = mean[:-1]
+
         def _single_fit(init_parameters: Dict[str, Any]) -> Dict[str, Any]:
             if init_parameters["alpha"] == 0:
                 # LR
@@ -532,7 +550,6 @@ def _single_fit(init_parameters: Dict[str, Any]) -> Dict[str, Any]:
                supported_params = [
                    "algorithm",
                    "fit_intercept",
-                   "normalize",
                    "verbose",
                    "copy_X",
                ]
@@ -547,18 +564,19 @@ def _single_fit(init_parameters: Dict[str, Any]) -> Dict[str, Any]:
                    "alpha",
                    "solver",
                    "fit_intercept",
-                   "normalize",
                    "verbose",
                ]
                # spark ML normalizes sample portion of objective by the number of examples
                # but cuml does not for RidgeRegression (l1_ratio=0). Induce similar behavior
                # to spark ml by scaling up the reg parameter by the number of examples.
                # With this, spark ML and spark rapids ML results match closely when features
-               # and label columns are all standardized.
+               # and label columns are all standardized, or when standardization is enabled.
                init_parameters = init_parameters.copy()
                if "alpha" in init_parameters.keys():
                    init_parameters["alpha"] *= (float)(pdesc.m)
-
+                   if standardization:
+                       # key to matching mllib when standardization is enabled
+                       init_parameters["alpha"] /= stddev_label
            else:
                # LR + L1, or LR + L1 + L2
                # Cuml uses Coordinate Descent algorithm to implement Lasso and ElasticNet
@@ -575,12 +593,15 @@ def _single_fit(init_parameters: Dict[str, Any]) -> Dict[str, Any]:
                    "l1_ratio",
                    "fit_intercept",
                    "max_iter",
-                   "normalize",
                    "tol",
                    "shuffle",
                    "verbose",
                ]
 
+               if standardization:
+                   # key to matching mllib when standardization is enabled
+                   init_parameters["alpha"] /= stddev_label
+
                # filter only supported params
                final_init_parameters = {
                    k: v for k, v in init_parameters.items() if k in supported_params
@@ -604,9 +625,28 @@ def _single_fit(init_parameters: Dict[str, Any]) -> Dict[str, Any]:
                pdesc.rank,
            )
 
+           coef_ = linear_regression.coef_
+           intercept_ = linear_regression.intercept_
+
+           if standardization is True:
+               import cupy as cp
+
+               coef_ = cp.where(
+                   stddev_features > 0,
+                   (coef_ / stddev_features) * stddev_label,
+                   coef_,
+               )
+               if init_parameters["fit_intercept"] is True:
+
+                   intercept_ = (
+                       intercept_ * stddev_label
+                       - cp.dot(coef_, mean_features)
+                       + mean_label
+                   ).tolist()
+
            return {
-               "coef_": linear_regression.coef_.get().tolist(),
-               "intercept_": linear_regression.intercept_,
+               "coef_": coef_.tolist(),
+               "intercept_": intercept_,
                "dtype": linear_regression.dtype.name,
                "n_cols": linear_regression.n_cols,
            }

python/src/spark_rapids_ml/utils.py

Lines changed: 116 additions & 0 deletions
@@ -43,6 +43,13 @@
 from pyspark.sql.types import ArrayType, FloatType
 
 _ArrayOrder = Literal["C", "F"]
+_SinglePdDataFrameBatchType = Tuple[
+    pd.DataFrame, Optional[pd.DataFrame], Optional[pd.DataFrame]
+]
+_SingleNpArrayBatchType = Tuple[np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]
+
+# FitInputType is type of [(feature, label), ...]
+FitInputType = Union[List[_SinglePdDataFrameBatchType], List[_SingleNpArrayBatchType]]
 
 
 def _method_names_from_param(spark_param_name: str) -> List[str]:
@@ -809,3 +816,112 @@ def getInputOrFeaturesCols(est: Union[Estimator, Transformer]) -> str:
        else getattr(est, "getInputCol")
    )
    return getter()
+
+
+def _standardize_dataset(
+    data: FitInputType, pdesc: PartitionDescriptor, fit_intercept: bool
+) -> Tuple["cp.ndarray", "cp.ndarray"]:
+    """Inplace standardize the dataset feature and optionally label columns
+
+    Args:
+        data: dataset to standardize (including features and label)
+        pdesc: Partition descriptor
+        fit_intercept: Whether to fit intercept in calling fit function.
+
+    Returns:
+        Mean and standard deviation of features and label columns (latter is last element if present)
+        Modifies data entries by replacing entries with standardized data on gpu.
+        If data is already on gpu, modifies in place (i.e. no copy is made).
+    """
+    import cupy as cp
+
+    mean_partials_labels = (
+        cp.zeros(1, dtype=data[0][1].dtype) if data[0][1] is not None else None
+    )
+    mean_partials = [cp.zeros(pdesc.n, dtype=data[0][0].dtype), mean_partials_labels]
+    for i in range(len(data)):
+        _data = []
+        for j in range(2):
+            if data[i][j] is not None:
+
+                if isinstance(data[i][j], cp.ndarray):
+                    _data.append(data[i][j])  # type: ignore
+                elif isinstance(data[i][j], np.ndarray):
+                    _data.append(cp.array(data[i][j]))  # type: ignore
+                elif isinstance(data[i][j], pd.DataFrame) or isinstance(
+                    data[i][j], pd.Series
+                ):
+                    _data.append(cp.array(data[i][j].values))  # type: ignore
+                else:
+                    raise ValueError("Unsupported data type: ", type(data[i][j]))
+                mean_partials[j] += _data[j].sum(axis=0) / pdesc.m  # type: ignore
+            else:
+                _data.append(None)
+        data[i] = (_data[0], _data[1], data[i][2])  # type: ignore
+
+    import json
+
+    from pyspark import BarrierTaskContext
+
+    context = BarrierTaskContext.get()
+
+    def all_gather_then_sum(
+        cp_array: cp.ndarray, dtype: Union[np.float32, np.float64]
+    ) -> cp.ndarray:
+        msgs = context.allGather(json.dumps(cp_array.tolist()))
+        arrays = [json.loads(p) for p in msgs]
+        array_sum = np.sum(arrays, axis=0).astype(dtype)
+        return cp.array(array_sum)
+
+    if mean_partials[1] is not None:
+        mean_partial = cp.concatenate(mean_partials)  # type: ignore
+    else:
+        mean_partial = mean_partials[0]
+    mean = all_gather_then_sum(mean_partial, mean_partial.dtype)
+
+    _mean = (mean[:-1], mean[-1]) if mean_partials[1] is not None else (mean, None)
+
+    var_partials_labels = (
+        cp.zeros(1, dtype=data[0][1].dtype) if data[0][1] is not None else None
+    )
+    var_partials = [cp.zeros(pdesc.n, dtype=data[0][0].dtype), var_partials_labels]
+    for i in range(len(data)):
+        for j in range(2):
+            if data[i][j] is not None and _mean[j] is not None:
+                __data = data[i][j]
+                __data -= _mean[j]  # type: ignore
+                l2 = cp.linalg.norm(__data, ord=2, axis=0)
+                var_partials[j] += l2 * l2 / (pdesc.m - 1)
+
+    if var_partials[1] is not None:
+        var_partial = cp.concatenate((var_partials[0], var_partials[1]))
+    else:
+        var_partial = var_partials[0]
+    var = all_gather_then_sum(var_partial, var_partial.dtype)
+
+    assert cp.all(
+        var >= 0
+    ), "numeric instable detected when calculating variance. Got negative variance"
+
+    stddev = cp.sqrt(var)
+    stddev_inv = cp.where(stddev != 0, 1.0 / stddev, 1.0)
+    _stddev_inv = (
+        (stddev_inv[:-1], stddev_inv[-1])
+        if var_partials[1] is not None
+        else (stddev_inv, None)
+    )
+
+    if fit_intercept is False:
+        for i in range(len(data)):
+            for j in range(2):
+                if data[i][j] is not None and _mean[j] is not None:
+                    __data = data[i][j]
+                    __data += _mean[j]  # type: ignore
+
+    for i in range(len(data)):
+        for j in range(2):
+            if data[i][j] is not None and _stddev_inv[j] is not None:
+                __data = data[i][j]
+                __data *= _stddev_inv[j]  # type: ignore
+
+    return mean, stddev
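For a rough single-process picture of what _standardize_dataset computes per column, a hedged NumPy sketch (the real helper runs inside a barrier task, aggregates partials across workers on GPU, and also handles an optional label column):

# Illustrative NumPy sketch, not the distributed helper: mean-center only when
# an intercept will be fit, scale by the ddof=1 standard deviation, and leave
# zero-variance columns unscaled (stddev_inv falls back to 1.0).
import numpy as np

def standardize_like(data: np.ndarray, fit_intercept: bool):
    mean = data.mean(axis=0)
    stddev = data.std(axis=0, ddof=1)
    stddev_inv = np.where(stddev != 0, 1.0 / stddev, 1.0)
    out = (data - mean) if fit_intercept else data.copy()
    out *= stddev_inv
    return out, mean, stddev

X = np.array([[1.0, 10.0], [2.0, 10.0], [3.0, 10.0]])
Xs, mean, stddev = standardize_like(X, fit_intercept=True)
print(mean, stddev)  # [ 2. 10.] [1. 0.]
print(Xs)            # first column centered and scaled; constant column centered only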
