 from pyspark.sql.types import ArrayType, FloatType
 
 _ArrayOrder = Literal["C", "F"]
+_SinglePdDataFrameBatchType = Tuple[
+    pd.DataFrame, Optional[pd.DataFrame], Optional[pd.DataFrame]
+]
+_SingleNpArrayBatchType = Tuple[np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]
+
+# FitInputType is the type of [(features, label, ...), ...]
+FitInputType = Union[List[_SinglePdDataFrameBatchType], List[_SingleNpArrayBatchType]]
 
 
 def _method_names_from_param(spark_param_name: str) -> List[str]:
@@ -809,3 +816,112 @@ def getInputOrFeaturesCols(est: Union[Estimator, Transformer]) -> str:
         else getattr(est, "getInputCol")
     )
     return getter()
+
+
+def _standardize_dataset(
+    data: FitInputType, pdesc: PartitionDescriptor, fit_intercept: bool
+) -> Tuple["cp.ndarray", "cp.ndarray"]:
+    """Standardize the dataset's feature and, optionally, label columns in place.
+
+    Args:
+        data: dataset to standardize (features and, if present, label)
+        pdesc: partition descriptor
+        fit_intercept: whether the calling fit function fits an intercept
+
+    Returns:
+        Mean and standard deviation of the feature columns and, if present, the
+        label column (the label statistics are the last elements).
+        Modifies ``data`` by replacing each batch with its standardized version
+        on GPU; batches that are already on GPU are modified in place (no copy
+        is made).
+    """
+    import cupy as cp
+
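+    # First pass: accumulate this partition's partial column means. Each batch
+    # contributes batch.sum(axis=0) / pdesc.m, where pdesc.m is the total row
+    # count across all partitions, so summing the partials over all workers
+    # yields the global means.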
+    mean_partials_labels = (
+        cp.zeros(1, dtype=data[0][1].dtype) if data[0][1] is not None else None
+    )
+    mean_partials = [cp.zeros(pdesc.n, dtype=data[0][0].dtype), mean_partials_labels]
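+    # Move each (features, label) batch onto the GPU: cupy arrays are kept
+    # as-is, while numpy arrays and pandas DataFrames/Series are copied to
+    # device memory.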
+    for i in range(len(data)):
+        _data = []
+        for j in range(2):
+            if data[i][j] is not None:
+                if isinstance(data[i][j], cp.ndarray):
+                    _data.append(data[i][j])  # type: ignore
+                elif isinstance(data[i][j], np.ndarray):
+                    _data.append(cp.array(data[i][j]))  # type: ignore
+                elif isinstance(data[i][j], (pd.DataFrame, pd.Series)):
+                    _data.append(cp.array(data[i][j].values))  # type: ignore
+                else:
+                    raise ValueError(f"Unsupported data type: {type(data[i][j])}")
+                mean_partials[j] += _data[j].sum(axis=0) / pdesc.m  # type: ignore
+            else:
+                _data.append(None)
+        data[i] = (_data[0], _data[1], data[i][2])  # type: ignore
+
+    import json
+
+    from pyspark import BarrierTaskContext
+
+    context = BarrierTaskContext.get()
+
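+    # Cross-worker reduction: every barrier task JSON-serializes its partial
+    # sums, allGather hands each task the full set of messages, and each task
+    # sums them locally, so all workers end up with identical global statistics.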
+    def all_gather_then_sum(cp_array: cp.ndarray, dtype: np.dtype) -> cp.ndarray:
+        msgs = context.allGather(json.dumps(cp_array.tolist()))
+        arrays = [json.loads(p) for p in msgs]
+        array_sum = np.sum(arrays, axis=0).astype(dtype)
+        return cp.array(array_sum)
+
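+    # Concatenate the feature partials with the optional label partial so a
+    # single allGather round trip covers both.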
+    if mean_partials[1] is not None:
+        mean_partial = cp.concatenate(mean_partials)  # type: ignore
+    else:
+        mean_partial = mean_partials[0]
+    mean = all_gather_then_sum(mean_partial, mean_partial.dtype)
+
+    _mean = (mean[:-1], mean[-1]) if mean_partials[1] is not None else (mean, None)
+
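+    # Second pass: center each batch in place and accumulate partial variances
+    # as the squared L2 norm of the centered columns, with the unbiased
+    # 1 / (m - 1) normalization.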
+    var_partials_labels = (
+        cp.zeros(1, dtype=data[0][1].dtype) if data[0][1] is not None else None
+    )
+    var_partials = [cp.zeros(pdesc.n, dtype=data[0][0].dtype), var_partials_labels]
+    for i in range(len(data)):
+        for j in range(2):
+            if data[i][j] is not None and _mean[j] is not None:
+                __data = data[i][j]
+                __data -= _mean[j]  # type: ignore
+                l2 = cp.linalg.norm(__data, ord=2, axis=0)
+                var_partials[j] += l2 * l2 / (pdesc.m - 1)
+
+    if var_partials[1] is not None:
+        var_partial = cp.concatenate((var_partials[0], var_partials[1]))
+    else:
+        var_partial = var_partials[0]
+    var = all_gather_then_sum(var_partial, var_partial.dtype)
+
+    assert cp.all(
+        var >= 0
+    ), "numerical instability detected when computing variance: got a negative variance"
+
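+    # Guard zero-variance (constant) columns: use a scale factor of 1.0 rather
+    # than dividing by a zero standard deviation.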
+    stddev = cp.sqrt(var)
+    stddev_inv = cp.where(stddev != 0, 1.0 / stddev, 1.0)
+    _stddev_inv = (
+        (stddev_inv[:-1], stddev_inv[-1])
+        if var_partials[1] is not None
+        else (stddev_inv, None)
+    )
+
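+    # When no intercept is fitted, add the means back to undo the centering done
+    # during the variance pass, so the data is scaled but not centered.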
+    if not fit_intercept:
+        for i in range(len(data)):
+            for j in range(2):
+                if data[i][j] is not None and _mean[j] is not None:
+                    __data = data[i][j]
+                    __data += _mean[j]  # type: ignore
+
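+    # Finally, scale every batch in place by the inverse standard deviation.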
+    for i in range(len(data)):
+        for j in range(2):
+            if data[i][j] is not None and _stddev_inv[j] is not None:
+                __data = data[i][j]
+                __data *= _stddev_inv[j]  # type: ignore
+
+    return mean, stddev