
Conversation

@eordentlich
Collaborator

  • implement standardization in linear regression using CuPy-based data modification in a way that matches Spark
    (this greatly reduces the mismatch at the coefficient/intercept level to < 10% on a small existing unit test example - TBD to bridge the gap further)
  • patch an existing bug in GPU-optimized fitMultiple when the changed parameters include standardization, for LogisticRegression and LinearRegression (as this is somewhat of a corner case, the approach taken is to fall back to the MLlib fitMultiple, but other more optimized approaches are possible in the future)

…k; avoid on-gpu fitmultiple when param maps standardize (for now)

Signed-off-by: Erik Ordentlich <[email protected]>
Signed-off-by: Erik Ordentlich <[email protected]>
@eordentlich
Collaborator Author

build

@eordentlich eordentlich changed the title address a known mismatch with spark mllib lineargression when standardization is enabled. address a known mismatch with spark mllib linearegression when standardization is enabled. Nov 4, 2025
@eordentlich eordentlich changed the title address a known mismatch with spark mllib linearegression when standardization is enabled. address a known mismatch with spark mllib linearregression when standardization is enabled. Nov 4, 2025
@eordentlich
Collaborator Author

build

Contributor

@greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR implements standardization for linear regression using CuPy-based data modification to match Apache Spark MLlib behavior, significantly reducing coefficient/intercept mismatches. The implementation centers and scales features/labels on GPU, then applies inverse transformations to model coefficients and intercepts post-fitting.
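
To make the transformation concrete, here is a minimal single-GPU sketch of the scheme described above. The names (X, y, fit_cuml) are illustrative stand-ins, not the PR's actual identifiers, and the real implementation computes the statistics across Spark partitions rather than on a single in-memory array.

import cupy as cp

def fit_with_standardization(X: cp.ndarray, y: cp.ndarray, fit_cuml):
    # Column-wise statistics for features and label.
    mean_x, std_x = X.mean(axis=0), X.std(axis=0)
    mean_y, std_y = y.mean(), y.std()
    std_x = cp.where(std_x == 0, 1.0, std_x)  # guard constant columns

    # Center and scale on GPU, then fit on the standardized data.
    coef_std, intercept_std = fit_cuml((X - mean_x) / std_x, (y - mean_y) / std_y)

    # Inverse transform back to the original scale: rescale the
    # coefficients, then un-center the intercept.
    coef = coef_std / std_x * std_y
    intercept = float(intercept_std * std_y - cp.dot(coef, mean_x) + mean_y)
    return coef, intercept

One caveat worth noting: Spark MLlib summarizers use the unbiased (ddof=1) standard deviation while CuPy's std defaults to ddof=0, and small definitional differences like that are a plausible source of the residual gap mentioned below.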

Key Changes:

  • New _standardize_dataset utility function in utils.py handles distributed standardization across partitions using BarrierTaskContext
  • Linear regression now applies standardization adjustments: coefficients scaled by (coef_ / stddev_features) * stddev_label, intercept adjusted by mean centering
  • Ridge and ElasticNet alpha parameters scaled by stddev_label when standardization enabled
  • fitMultiple safeguard added to fall back to the baseline implementation when standardization or fitIntercept appear in param maps (prevents data modification conflicts)
  • Logistic regression refactored to use shared _standardize_dataset function
  • Comprehensive test coverage added including fallback behavior verification

Observations:

  • Tests show ~2-10% tolerance needed to match MLlib (improvement from previous behavior but gap remains)
  • The normalize param mapping in linear regression is commented as needing a new type, since standardization is now handled in CuPy rather than CuML

Confidence Score: 3/5

  • Safe to merge with minor concerns around numeric precision and incomplete optimization for param sweeps
  • The implementation is mathematically sound and well-tested, but relies on relatively large tolerances (2-10%) to match MLlib results. The fallback mechanism for fitMultiple with standardization params is conservative and safe but suboptimal. TODOs indicate incomplete optimization paths.
  • python/src/spark_rapids_ml/regression.py - verify numeric precision issues; python/src/spark_rapids_ml/utils.py - ensure distributed standardization is numerically stable

Important Files Changed

File Analysis

Filename Score Overview
python/src/spark_rapids_ml/utils.py 4/5 Added _standardize_dataset function to handle CuPy-based standardization of features and labels. Function handles in-place GPU standardization with proper mean centering and scaling.
python/src/spark_rapids_ml/regression.py 3/5 Implemented standardization using CuPy to match Spark MLlib behavior. Applies mean/stddev adjustments to coefficients and intercept after fitting. Alpha parameter scaled by stddev_label for Ridge/ElasticNet.
python/src/spark_rapids_ml/core.py 4/5 Added safeguard in fitMultiple to fall back to baseline when standardization or fitIntercept are in param maps, preventing data modification issues with single-pass GPU optimization (see the sketch below).
python/src/spark_rapids_ml/classification.py 4/5 Refactored to use new _standardize_dataset utility function, replacing inline CuPy standardization code with centralized implementation.
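
A hedged sketch of the kind of guard the core.py row describes; the helper name and the optimized-path method are hypothetical, but PySpark param maps really are keyed by Param objects carrying a .name attribute:

def _param_maps_touch_data(param_maps) -> bool:
    # Params whose values change the training data itself; a single-pass
    # GPU fitMultiple fits many models against one prepared dataset, so
    # it cannot honor per-map changes to these.
    sensitive = {"standardization", "fitIntercept"}
    return any(p.name in sensitive for pm in param_maps for p in pm)

The estimator would then branch on this check at the top of fitMultiple, calling the default MLlib-style implementation (one fit per param map) when it returns True and the optimized single-pass path otherwise.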

Sequence Diagram

sequenceDiagram
    participant User
    participant LinearRegression
    participant _linear_regression_fit
    participant _standardize_dataset
    participant BarrierTaskContext
    participant CuML
    
    User->>LinearRegression: fit(dataset)
    LinearRegression->>_linear_regression_fit: call with data partitions
    
    alt standardization enabled
        _linear_regression_fit->>_standardize_dataset: standardize(dfs, pdesc, fit_intercept)
        _standardize_dataset->>BarrierTaskContext: allGather(mean_partial)
        BarrierTaskContext-->>_standardize_dataset: global mean
        _standardize_dataset->>_standardize_dataset: center data (data -= mean)
        _standardize_dataset->>BarrierTaskContext: allGather(var_partial)
        BarrierTaskContext-->>_standardize_dataset: global variance
        _standardize_dataset->>_standardize_dataset: scale data (data *= 1/stddev)
        _standardize_dataset-->>_linear_regression_fit: mean, stddev
    end
    
    _linear_regression_fit->>CuML: fit(standardized_data)
    CuML-->>_linear_regression_fit: coef_, intercept_
    
    alt standardization enabled
        _linear_regression_fit->>_linear_regression_fit: coef_ = (coef_ / stddev_features) * stddev_label
        _linear_regression_fit->>_linear_regression_fit: intercept_ = intercept_ * stddev_label - dot(coef_, mean) + mean_label
    end
    
    _linear_regression_fit-->>LinearRegression: model parameters
    LinearRegression-->>User: trained model
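The allGather steps in the diagram correspond to a pattern roughly like the following. This is a simplified sketch, assuming the code runs inside a barrier stage with one CuPy feature matrix per partition; the actual _standardize_dataset also standardizes the label and handles the fit_intercept=False case discussed in the review comments below.

import json
import cupy as cp
from pyspark import BarrierTaskContext

def standardize_partition(data: cp.ndarray):
    ctx = BarrierTaskContext.get()

    # Global mean: every task broadcasts its (column sums, row count).
    replies = ctx.allGather(json.dumps([data.sum(axis=0).tolist(), data.shape[0]]))
    sums, counts = zip(*(json.loads(r) for r in replies))
    n = sum(counts)
    mean = cp.asarray(sums).sum(axis=0) / n
    data -= mean  # center in place on GPU

    # Global variance of the centered data, then scale each column.
    replies = ctx.allGather(json.dumps((data**2).sum(axis=0).tolist()))
    var = cp.asarray([json.loads(r) for r in replies]).sum(axis=0) / (n - 1)
    stddev = cp.sqrt(var)
    data /= cp.where(stddev == 0, 1.0, stddev)  # skip constant columns
    return mean, stddev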

9 files reviewed, 3 comments


Comment on lines +641 to +645
intercept_ = (
    intercept_ * stddev_label
    - cp.dot(coef_, mean_features)
    + mean_label
).tolist()

logic: when standardization=True but fit_intercept=False, the intercept adjustment will still use mean_label which may not match expected behavior

the code adds back the mean before scaling (lines 914-919 in utils.py) when fit_intercept=False, but here the intercept calculation still subtracts cp.dot(coef_, mean_features) and adds mean_label even though means were added back before scaling
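
For the fit_intercept=True path, the back-transform used here is easy to sanity-check on CPU; a NumPy sketch with illustrative data (this validates only the case the comment is not disputing):

import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = X @ np.array([1.5, -2.0, 0.5]) + 3.0 + rng.normal(scale=0.1, size=100)

# OLS on standardized features/label via the normal equations.
mx, sx, my, sy = X.mean(0), X.std(0), y.mean(), y.std()
Xs, ys = (X - mx) / sx, (y - my) / sy
sol, *_ = np.linalg.lstsq(np.c_[Xs, np.ones(len(ys))], ys, rcond=None)
coef_std, intercept_std = sol[:-1], sol[-1]

# Back-transform exactly as in the snippet above.
coef = coef_std / sx * sy
intercept = intercept_std * sy - coef @ mx + my

# Must match a direct fit on the raw data.
direct, *_ = np.linalg.lstsq(np.c_[X, np.ones(len(y))], y, rcond=None)
assert np.allclose(coef, direct[:-1]) and np.isclose(intercept, direct[-1])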

  return {
-     "coef_": linear_regression.coef_.get().tolist(),
-     "intercept_": linear_regression.intercept_,
+     "coef_": coef_.tolist(),

logic: missing .get() call before .tolist() when standardization=False

when standardization is disabled, coef_ remains as linear_regression.coef_ (a CuPy array). the old code used .get().tolist() to transfer from GPU to CPU. now only .tolist() is called which will fail or give incorrect results

Suggested change
-     "coef_": coef_.tolist(),
+     "coef_": coef_.get().tolist() if isinstance(coef_, cp.ndarray) else coef_.tolist(),

Comment on lines +628 to +629
coef_ = linear_regression.coef_
intercept_ = linear_regression.intercept_

style: changed from .get().tolist() to .tolist()

previously linear_regression.coef_.get().tolist() explicitly transferred from GPU to CPU via .get(). now coef_.tolist() is called which relies on CuPy's .tolist() to handle GPU-to-CPU transfer implicitly. verify this works correctly in all cases
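
For reference, cupy.ndarray.tolist() does perform the device-to-host transfer itself, so the implicit form is equivalent whenever coef_ is a CuPy array; the open question is whether coef_ is always a CuPy array on this path. A quick check (requires a CUDA device):

import cupy as cp

a = cp.arange(3, dtype=cp.float64)
assert a.tolist() == a.get().tolist() == [0.0, 1.0, 2.0]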

@eordentlich eordentlich merged commit 4685855 into NVIDIA:main Nov 5, 2025
4 checks passed
@eordentlich eordentlich deleted the eo_linear_standardization branch November 5, 2025 21:14