4 changes: 2 additions & 2 deletions python/benchmark/benchmark/bench_linear_regression.py
@@ -114,8 +114,8 @@ def run_once(
)

# note: results for spark ML and spark rapids ml will currently match in all regularization
# cases only if features and labels were standardized in the original dataset. Otherwise,
# they will match only if regParam = 0 or elastNetParam = 1.0 (aka Lasso)
# cases only if features and labels were standardized in the original dataset or if standardization is enabled.
# Otherwise, they will match only if regParam = 0 or elasticNetParam = 1.0 (aka Lasso)
print(
f"RMSE: {rmse}, coefs l1: {coefs_l1}, coefs l2^2: {coefs_l2}, "
f"full_objective: {full_objective}, intercept: {model.intercept}"
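As the updated comment says, with standardization enabled the two implementations should now produce close results for any regParam/elasticNetParam combination. A minimal comparison sketch (hypothetical: assumes a Spark session, a GPU with spark-rapids-ml installed, and that the GPU estimator accepts the usual Spark ML-style constructor parameters; the data is made up):

```python
from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import LinearRegression as SparkLinearRegression
from pyspark.sql import SparkSession
from spark_rapids_ml.regression import LinearRegression as RapidsLinearRegression

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(Vectors.dense([1.0, 2.0]), 3.5), (Vectors.dense([2.0, 1.0]), 2.0),
     (Vectors.dense([3.0, 0.5]), 1.0), (Vectors.dense([4.0, 3.0]), 5.5)],
    ["features", "label"],
)

params = dict(regParam=0.1, elasticNetParam=0.5, standardization=True, fitIntercept=True)
cpu_model = SparkLinearRegression(**params).fit(df)   # baseline Spark ML fit
gpu_model = RapidsLinearRegression(**params).fit(df)  # Spark RAPIDS ML fit, needs a GPU worker

# With standardization enabled the results are expected to be close, not bit-identical
# (different solvers and tolerances).
print(cpu_model.coefficients, cpu_model.intercept)
print(gpu_model.coefficients, gpu_model.intercept)
```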
55 changes: 8 additions & 47 deletions python/src/spark_rapids_ml/classification.py
@@ -1011,53 +1011,14 @@ def _logistic_regression_fit(
# Use cupy to standardize dataset as a workaround to gain better numeric stability
standarization_with_cupy = standardization and not is_sparse
if standarization_with_cupy is True:
import cupy as cp

if isinstance(concated, np.ndarray):
concated = cp.array(concated)
elif isinstance(concated, pd.DataFrame):
concated = cp.array(concated.values)
else:
assert isinstance(
concated, cp.ndarray
), "only numpy array, cupy array, and pandas dataframe are supported when standardization_with_cupy is on"

mean_partial = concated.sum(axis=0) / pdesc.m

import json

from pyspark import BarrierTaskContext

context = BarrierTaskContext.get()

def all_gather_then_sum(
cp_array: cp.ndarray, dtype: Union[np.float32, np.float64]
) -> cp.ndarray:
msgs = context.allGather(json.dumps(cp_array.tolist()))
arrays = [json.loads(p) for p in msgs]
array_sum = np.sum(arrays, axis=0).astype(dtype)
return cp.array(array_sum)

mean = all_gather_then_sum(mean_partial, concated.dtype)
concated -= mean

l2 = cp.linalg.norm(concated, ord=2, axis=0)

var_partial = l2 * l2 / (pdesc.m - 1)
var = all_gather_then_sum(var_partial, concated.dtype)

assert cp.all(
var >= 0
), "numeric instable detected when calculating variance. Got negative variance"

stddev = cp.sqrt(var)

stddev_inv = cp.where(stddev != 0, 1.0 / stddev, 1.0)

if fit_intercept is False:
concated += mean

concated *= stddev_inv
from .utils import _standardize_dataset

# TODO: fix for multiple param sweep that changes standardization and/or fit intercept (unlikely scenario) since
# data modification affects all params. currently not invoked in these cases by fitMultiple (see fitMultiple)
_tmp_data = [(concated, None, None)]
# this will modify concated in place through _tmp_data
mean, stddev = _standardize_dataset(_tmp_data, pdesc, fit_intercept)
concated = _tmp_data[0][0]

def _single_fit(init_parameters: Dict[str, Any]) -> Dict[str, Any]:
if standarization_with_cupy is True:
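The arithmetic that the removed block performed on each worker, and that the shared _standardize_dataset helper now encapsulates, is: compute the global mean and unbiased standard deviation, center the columns, undo the centering when fit_intercept is False, and scale by 1/stddev. A single-process CuPy sketch of those steps with made-up data (the cross-worker allGather-and-sum step is elided; variable names are illustrative):

```python
import cupy as cp

# Hypothetical single-partition illustration of the standardization arithmetic.
X = cp.random.rand(1000, 8).astype(cp.float32)
fit_intercept = True

m = X.shape[0]
mean = X.sum(axis=0) / m           # global mean (here: one partition only)
X -= mean                          # center the columns
l2 = cp.linalg.norm(X, ord=2, axis=0)
var = l2 * l2 / (m - 1)            # unbiased variance from the centered data
stddev = cp.sqrt(var)
stddev_inv = cp.where(stddev != 0, 1.0 / stddev, 1.0)

if not fit_intercept:
    X += mean                      # undo centering: only scaling is applied in this case
X *= stddev_inv                    # scale columns by 1/stddev
```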
25 changes: 17 additions & 8 deletions python/src/spark_rapids_ml/core.py
@@ -76,12 +76,15 @@
from .metrics import EvalMetricInfo
from .params import _CumlParams
from .utils import (
FitInputType,
_ArrayOrder,
_configure_memory_resource,
_get_gpu_id,
_get_spark_session,
_is_local,
_is_standalone_or_localcluster,
_SingleNpArrayBatchType,
_SinglePdDataFrameBatchType,
dtype_to_pyspark_type,
get_logger,
)
@@ -95,14 +98,6 @@

_CumlParamMap = Dict[str, Any]

_SinglePdDataFrameBatchType = Tuple[
pd.DataFrame, Optional[pd.DataFrame], Optional[pd.DataFrame]
]
_SingleNpArrayBatchType = Tuple[np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]

# FitInputType is type of [(feature, label), ...]
FitInputType = Union[List[_SinglePdDataFrameBatchType], List[_SingleNpArrayBatchType]]

# TransformInput type
TransformInputType = Union["cudf.DataFrame", np.ndarray]

@@ -1170,13 +1165,27 @@ def fitMultiple(
using `paramMaps[index]`. `index` values may not be sequential.
"""

logger = get_logger(self.__class__)

if self._use_cpu_fallback():
return super().fitMultiple(dataset, paramMaps)

if self._enable_fit_multiple_in_single_pass():
for paramMap in paramMaps:
if self._use_cpu_fallback(paramMap):
return super().fitMultiple(dataset, paramMaps)
# standardization and fitIntercept currently may modify the dataset, and this is done once outside the param loop.
# If either appears in a param map, fall back to the regular multiple-pass fitMultiple.
# TODO: sparse logistic regression does not modify data, so it is ok in that case. Need logic to check the dataset to detect that case.
# TODO: implement single pass with either of these by processing param maps with no
# standardization or fitIntercept before those with standardization or fitIntercept.
param_names = [p.name for p in paramMap.keys()]
for unsupported in ["standardization", "fitIntercept"]:
if unsupported in param_names:
logger.warning(
f"{unsupported} in param maps not supported for one pass GPU fitMultiple. Falling back to baseline fitMultiple."
)
return super().fitMultiple(dataset, paramMaps)

# reach here if no cpu fallback
estimator = self.copy()
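The new check looks only at param names across the supplied maps, so sweeping either standardization or fitIntercept disables the single-pass path for the whole call. An illustrative grid (assumes the GPU estimator exposes the usual Spark ML Param attributes such as regParam and standardization; train_df is hypothetical):

```python
from pyspark.ml.tuning import ParamGridBuilder
from spark_rapids_ml.regression import LinearRegression

lr = LinearRegression()

# Grid over regParam only: eligible for the single-pass GPU fitMultiple.
single_pass_grid = ParamGridBuilder().addGrid(lr.regParam, [0.0, 0.1, 1.0]).build()

# Grid that also sweeps standardization: triggers the warning above and falls back
# to the baseline (one fit per param map) fitMultiple.
fallback_grid = (
    ParamGridBuilder()
    .addGrid(lr.regParam, [0.0, 0.1])
    .addGrid(lr.standardization, [True, False])
    .build()
)

# models = lr.fitMultiple(train_df, fallback_grid)  # train_df: a DataFrame with features/label
```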
62 changes: 51 additions & 11 deletions python/src/spark_rapids_ml/regression.py
@@ -191,7 +191,7 @@ def _param_mapping(cls) -> Dict[str, Optional[str]]:
"maxIter": "max_iter",
"regParam": "alpha",
"solver": "solver",
"standardization": "normalize",
"standardization": "normalize", # TODO: standardization is carried out in cupy not cuml so need a new type of param mapped value to indicate that.
"tol": "tol",
"weightCol": None,
}
@@ -309,9 +309,9 @@ class LinearRegression(
Notes
-----
Results for spark ML and spark rapids ml fit() will currently match in all regularization
cases only if features and labels are standardized in the input dataframe. Otherwise,
they will match only if regParam = 0 or elastNetParam = 1.0 (aka Lasso).
Results for spark ML and spark rapids ml fit() will currently be close in all regularization
cases only if features and labels are standardized in the input dataframe or when standardization is enabled. Otherwise,
they will be close only if regParam = 0 or elasticNetParam = 1.0 (aka Lasso).
Parameters
----------
@@ -513,6 +513,10 @@ def _get_cuml_fit_func(
[FitInputType, Dict[str, Any]],
Dict[str, Any],
]:

standardization = self.getStandardization()
fit_intercept = self.getFitIntercept()

def _linear_regression_fit(
dfs: FitInputType,
params: Dict[str, Any],
@@ -522,6 +526,20 @@ def _linear_regression_fit(
params[param_alias.part_sizes], params[param_alias.num_cols]
)

pdesc_labels = PartitionDescriptor.build(params[param_alias.part_sizes], 1)

if standardization:
from .utils import _standardize_dataset

# this modifies dfs in place by copying to gpu and standardizing in place on gpu
# TODO: fix for multiple param sweep that changes standardization and/or fit intercept (unlikely scenario) since
# data modification affects all params. currently not invoked in these cases by fitMultiple (see fitMultiple)
mean, stddev = _standardize_dataset(dfs, pdesc, fit_intercept)
stddev_label = stddev[-1]
stddev_features = stddev[:-1]
mean_label = mean[-1]
mean_features = mean[:-1]

def _single_fit(init_parameters: Dict[str, Any]) -> Dict[str, Any]:
if init_parameters["alpha"] == 0:
# LR
@@ -532,7 +550,6 @@ def _single_fit(init_parameters: Dict[str, Any]) -> Dict[str, Any]:
supported_params = [
"algorithm",
"fit_intercept",
"normalize",
"verbose",
"copy_X",
]
@@ -547,18 +564,19 @@ def _single_fit(init_parameters: Dict[str, Any]) -> Dict[str, Any]:
"alpha",
"solver",
"fit_intercept",
"normalize",
"verbose",
]
# spark ML normalizes sample portion of objective by the number of examples
# but cuml does not for RidgeRegression (l1_ratio=0). Induce similar behavior
# to spark ml by scaling up the reg parameter by the number of examples.
# With this, spark ML and spark rapids ML results match closely when features
# and label columns are all standardized.
# and label columns are all standardized, or when standardization is enabled.
init_parameters = init_parameters.copy()
if "alpha" in init_parameters.keys():
init_parameters["alpha"] *= (float)(pdesc.m)

if standardization:
# key to matching mllib when standardization is enabled
init_parameters["alpha"] /= stddev_label
else:
# LR + L1, or LR + L1 + L2
# Cuml uses Coordinate Descent algorithm to implement Lasso and ElasticNet
@@ -575,12 +593,15 @@ def _single_fit(init_parameters: Dict[str, Any]) -> Dict[str, Any]:
"l1_ratio",
"fit_intercept",
"max_iter",
"normalize",
"tol",
"shuffle",
"verbose",
]

if standardization:
# key to matching mllib when standardization is enabled
init_parameters["alpha"] /= stddev_label

# filter only supported params
final_init_parameters = {
k: v for k, v in init_parameters.items() if k in supported_params
@@ -604,9 +625,28 @@ def _single_fit(init_parameters: Dict[str, Any]) -> Dict[str, Any]:
pdesc.rank,
)

coef_ = linear_regression.coef_
intercept_ = linear_regression.intercept_
Comment on lines +628 to +629 (Contributor):

style: changed from .get().tolist() to .tolist()

Previously linear_regression.coef_.get().tolist() explicitly transferred from GPU to CPU via .get(). Now coef_.tolist() is called, which relies on CuPy's .tolist() to handle the GPU-to-CPU transfer implicitly. Verify this works correctly in all cases.
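On that point, CuPy's ndarray.tolist() does perform the device-to-host copy itself, so both spellings should yield the same nested Python list. A quick sketch of the kind of check one could run (assumes CuPy and a GPU are available):

```python
import cupy as cp

# Sanity check (illustrative): .tolist() on a CuPy array copies to host on its own,
# so it should match the explicit .get().tolist() spelling used previously.
a = cp.asarray([[1.0, 2.0], [3.0, 4.0]])
assert a.tolist() == a.get().tolist()
```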


if standardization is True:
import cupy as cp

coef_ = cp.where(
stddev_features > 0,
(coef_ / stddev_features) * stddev_label,
coef_,
)
if init_parameters["fit_intercept"] is True:

intercept_ = (
intercept_ * stddev_label
- cp.dot(coef_, mean_features)
+ mean_label
).tolist()
Comment on lines +641 to +645 (Contributor):

logic: when standardization=True but fit_intercept=False, the intercept adjustment will still use mean_label, which may not match expected behavior

The code adds back the mean before scaling (lines 914-919 in utils.py) when fit_intercept=False, but here the intercept calculation still subtracts cp.dot(coef_, mean_features) and adds mean_label, even though the means were added back before scaling.
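For reference, the un-scaling being discussed follows the standard identity for a model fit on standardized columns: with y_s = (y - mean_y)/std_y and x_s,j = (x_j - mean_j)/std_j, the original-scale parameters are coef_j = coef_s,j * std_y/std_j and intercept = intercept_s * std_y - sum_j(coef_j * mean_j) + mean_y, which is what the added code computes when fit_intercept is True. A small NumPy check of that identity (illustrative only: ordinary least squares instead of cuML, synthetic data, hypothetical helper names):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3)) * [1.0, 5.0, 0.2] + [0.0, 10.0, -3.0]
y = X @ np.array([1.5, -0.7, 2.0]) + 4.0 + 0.01 * rng.normal(size=200)

mu_x, sd_x = X.mean(axis=0), X.std(axis=0, ddof=1)
mu_y, sd_y = y.mean(), y.std(ddof=1)
Xs, ys = (X - mu_x) / sd_x, (y - mu_y) / sd_y

def ols(A, b):
    # least-squares fit with an explicit intercept column
    A1 = np.column_stack([A, np.ones(len(A))])
    sol = np.linalg.lstsq(A1, b, rcond=None)[0]
    return sol[:-1], sol[-1]

coef_s, intercept_s = ols(Xs, ys)
coef = coef_s * sd_y / sd_x                            # matches the cp.where(...) rescaling above
intercept = intercept_s * sd_y - coef @ mu_x + mu_y    # matches the intercept adjustment above

coef_ref, intercept_ref = ols(X, y)                    # fit on the original scale
assert np.allclose(coef, coef_ref) and np.isclose(intercept, intercept_ref)
```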


return {
"coef_": linear_regression.coef_.get().tolist(),
"intercept_": linear_regression.intercept_,
"coef_": coef_.tolist(),
Comment (Contributor):

logic: missing .get() call before .tolist() when standardization=False

When standardization is disabled, coef_ remains linear_regression.coef_ (a CuPy array). The old code used .get().tolist() to transfer from GPU to CPU; now only .tolist() is called, which will fail or give incorrect results.

Suggested change:
"coef_": coef_.tolist(),
"coef_": coef_.get().tolist() if isinstance(coef_, cp.ndarray) else coef_.tolist(),

"intercept_": intercept_,
"dtype": linear_regression.dtype.name,
"n_cols": linear_regression.n_cols,
}
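Pulling the regularization bookkeeping in the hunks above into one place (a descriptive sketch, not library code): Spark ML averages the data term of the objective over the rows while cuML's Ridge does not, so the Ridge path keeps the existing alpha *= m rescaling; and with standardization enabled the model is now fit against the label scaled by 1/stddev_label, so both paths additionally divide alpha by stddev_label. A hypothetical helper summarizing the mapping:

```python
def cuml_alpha(reg_param: float, n_rows: int, stddev_label: float,
               standardization: bool, l1_ratio: float) -> float:
    """Hypothetical summary of how Spark's regParam is rescaled before reaching cuML."""
    alpha = reg_param
    if l1_ratio == 0.0:
        alpha *= float(n_rows)      # Ridge path: compensate for cuML not averaging the data term
    if standardization:
        alpha /= stddev_label       # both paths: compensate for fitting on the standardized label
    return alpha
```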
116 changes: 116 additions & 0 deletions python/src/spark_rapids_ml/utils.py
@@ -43,6 +43,13 @@
from pyspark.sql.types import ArrayType, FloatType

_ArrayOrder = Literal["C", "F"]
_SinglePdDataFrameBatchType = Tuple[
pd.DataFrame, Optional[pd.DataFrame], Optional[pd.DataFrame]
]
_SingleNpArrayBatchType = Tuple[np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]

# FitInputType is type of [(feature, label), ...]
FitInputType = Union[List[_SinglePdDataFrameBatchType], List[_SingleNpArrayBatchType]]


def _method_names_from_param(spark_param_name: str) -> List[str]:
@@ -809,3 +816,112 @@ def getInputOrFeaturesCols(est: Union[Estimator, Transformer]) -> str:
else getattr(est, "getInputCol")
)
return getter()


def _standardize_dataset(
data: FitInputType, pdesc: PartitionDescriptor, fit_intercept: bool
) -> Tuple["cp.ndarray", "cp.ndarray"]:
"""Inplace standardize the dataset feature and optionally label columns
Args:
data: dataset to standardize (including features and label)
pdesc: Partition descriptor
fit_intercept: Whether to fit intercept in calling fit function.
Returns:
Mean and standard deviation of features and label columns (latter is last element if present)
Modifies data entries by replacing entries with standardized data on gpu.
If data is already on gpu, modifies in place (i.e. no copy is made).
"""
import cupy as cp

mean_partials_labels = (
cp.zeros(1, dtype=data[0][1].dtype) if data[0][1] is not None else None
)
mean_partials = [cp.zeros(pdesc.n, dtype=data[0][0].dtype), mean_partials_labels]
for i in range(len(data)):
_data = []
for j in range(2):
if data[i][j] is not None:

if isinstance(data[i][j], cp.ndarray):
_data.append(data[i][j]) # type: ignore
elif isinstance(data[i][j], np.ndarray):
_data.append(cp.array(data[i][j])) # type: ignore
elif isinstance(data[i][j], pd.DataFrame) or isinstance(
data[i][j], pd.Series
):
_data.append(cp.array(data[i][j].values)) # type: ignore
else:
raise ValueError("Unsupported data type: ", type(data[i][j]))
mean_partials[j] += _data[j].sum(axis=0) / pdesc.m # type: ignore
else:
_data.append(None)
data[i] = (_data[0], _data[1], data[i][2]) # type: ignore

import json

from pyspark import BarrierTaskContext

context = BarrierTaskContext.get()

def all_gather_then_sum(
cp_array: cp.ndarray, dtype: Union[np.float32, np.float64]
) -> cp.ndarray:
msgs = context.allGather(json.dumps(cp_array.tolist()))
arrays = [json.loads(p) for p in msgs]
array_sum = np.sum(arrays, axis=0).astype(dtype)
return cp.array(array_sum)

if mean_partials[1] is not None:
mean_partial = cp.concatenate(mean_partials) # type: ignore
else:
mean_partial = mean_partials[0]
mean = all_gather_then_sum(mean_partial, mean_partial.dtype)

_mean = (mean[:-1], mean[-1]) if mean_partials[1] is not None else (mean, None)

var_partials_labels = (
cp.zeros(1, dtype=data[0][1].dtype) if data[0][1] is not None else None
)
var_partials = [cp.zeros(pdesc.n, dtype=data[0][0].dtype), var_partials_labels]
for i in range(len(data)):
for j in range(2):
if data[i][j] is not None and _mean[j] is not None:
__data = data[i][j]
__data -= _mean[j] # type: ignore
l2 = cp.linalg.norm(__data, ord=2, axis=0)
var_partials[j] += l2 * l2 / (pdesc.m - 1)

if var_partials[1] is not None:
var_partial = cp.concatenate((var_partials[0], var_partials[1]))
else:
var_partial = var_partials[0]
var = all_gather_then_sum(var_partial, var_partial.dtype)

assert cp.all(
var >= 0
), "numeric instable detected when calculating variance. Got negative variance"

stddev = cp.sqrt(var)
stddev_inv = cp.where(stddev != 0, 1.0 / stddev, 1.0)
_stddev_inv = (
(stddev_inv[:-1], stddev_inv[-1])
if var_partials[1] is not None
else (stddev_inv, None)
)

if fit_intercept is False:
for i in range(len(data)):
for j in range(2):
if data[i][j] is not None and _mean[j] is not None:
__data = data[i][j]
__data += _mean[j] # type: ignore

for i in range(len(data)):
for j in range(2):
if data[i][j] is not None and _stddev_inv[j] is not None:
__data = data[i][j]
__data *= _stddev_inv[j] # type: ignore

return mean, stddev
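The helper above computes the mean and the unbiased variance from per-partition partials that are summed across workers via allGather. A NumPy-only sketch of the same two-pass arithmetic over simulated partitions, checked against a direct computation (illustrative; no Spark, CuPy, or GPU involved):

```python
import numpy as np

# Simulate the per-partition partial sums that _standardize_dataset aggregates.
rng = np.random.default_rng(42)
partitions = [rng.normal(loc=3.0, scale=2.0, size=(n, 4)) for n in (50, 80, 70)]
m = sum(p.shape[0] for p in partitions)

# Pass 1: partial means, then "allGather + sum" (here just a plain Python sum).
mean = sum(p.sum(axis=0) / m for p in partitions)

# Pass 2: center each partition in place, accumulate unbiased variance partials.
var = np.zeros(4)
for p in partitions:
    p -= mean
    l2 = np.linalg.norm(p, ord=2, axis=0)
    var += l2 * l2 / (m - 1)

# Check against a direct computation on the reassembled (un-centered) data.
full = np.concatenate([p + mean for p in partitions])
assert np.allclose(mean, full.mean(axis=0))
assert np.allclose(var, full.var(axis=0, ddof=1))
```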