@@ -669,14 +669,63 @@ def _calc_opt_sigma(
669669
670670 return rbf_coeff , opt_sigma
671671
672- def _rbf_interpolate (self , dataset : pd .DataFrame ) -> pd .DataFrame :
672+ def _rbf_variable_interpolation (
673+ self ,
674+ normalized_dataset : pd .DataFrame ,
675+ opt_sigma : float ,
676+ rbf_coeff : np .ndarray ,
677+ num_points_subset : int ,
678+ num_vars_subset : int ,
679+ ) -> np .ndarray :
680+ """
681+ Interpolates the surface for a variable.
682+
683+ normalized_dataset : pd.DataFrame
684+ The normalized dataset.
685+ opt_sigma : float
686+ The optimal sigma calculated for variable.
687+ rbf_coeff : np.ndarray
688+ The fitted coefficients for variable.
689+ num_points_subset : int
690+ The number of points used in the fitting.
691+ num_vars_subset : int
692+ The number of variables used in the fitting.
693+
694+ np.ndarray
695+ The interpolated variable.
696+ """
697+
698+ r = np .linalg .norm (
699+ normalized_dataset .values [:, np .newaxis , :]
700+ - self .normalized_subset_data .values [np .newaxis , :, :],
701+ axis = 2 ,
702+ )
703+ kernel_values = self .kernel_func (r , opt_sigma )
704+ linear_part = np .dot (
705+ normalized_dataset .values ,
706+ rbf_coeff [
707+ num_points_subset + 1 : num_points_subset + 1 + num_vars_subset
708+ ].T ,
709+ )
710+
711+ return (
712+ rbf_coeff [num_points_subset ]
713+ + np .dot (kernel_values , rbf_coeff [:num_points_subset ])
714+ + linear_part
715+ )
716+
717+ def _rbf_interpolate (
718+ self , dataset : pd .DataFrame , num_threads : int = None
719+ ) -> pd .DataFrame :
673720 """
674721 This function interpolates the dataset.
675722
676723 Parameters
677724 ----------
678725 dataset : pd.DataFrame
679726 The dataset to interpolate (must have same variables as subset).
727+ num_threads : int, optional
728+ The number of threads to use for the interpolation. Default is None.
680729
681730 Returns
682731 -------
@@ -698,28 +747,42 @@ def _rbf_interpolate(self, dataset: pd.DataFrame) -> pd.DataFrame:
698747 )
699748
700749 # Loop through the target variables
701- for i_var , target_var in enumerate (self .target_processed_variables ):
702- self .logger .info (f"Interpolating target variable { target_var } " )
703- rbf_coeff = self ._rbf_coeffs [target_var ].values
704- opt_sigma = self ._opt_sigmas [target_var ]
705- r = np .linalg .norm (
706- normalized_dataset .values [:, np .newaxis , :]
707- - self .normalized_subset_data .values [np .newaxis , :, :],
708- axis = 2 ,
709- )
710- kernel_values = self .kernel_func (r , opt_sigma )
711- linear_part = np .dot (
712- normalized_dataset .values ,
713- rbf_coeff [
714- num_points_subset + 1 : num_points_subset + 1 + num_vars_subset
715- ].T ,
716- )
717- s = (
718- rbf_coeff [num_points_subset ]
719- + np .dot (kernel_values , rbf_coeff [:num_points_subset ])
720- + linear_part
721- )
722- interpolated_array [:, i_var ] = s
750+ if num_threads is not None :
751+ # self.set_num_processors_to_use(num_processors=num_threads)
752+ num_threads = min (num_threads , self .get_num_processors_available ())
753+ self .logger .info (f"Using { num_threads } threads for interpolation." )
754+ with ThreadPoolExecutor (max_workers = num_threads ) as executor :
755+ rbf_variable_calculation = {
756+ executor .submit (
757+ self ._rbf_variable_interpolation ,
758+ normalized_dataset ,
759+ self ._opt_sigmas [target_var ],
760+ self ._rbf_coeffs [target_var ].values ,
761+ num_points_subset ,
762+ num_vars_subset ,
763+ ): (i_var , target_var )
764+ for i_var , target_var in enumerate (self .target_processed_variables )
765+ }
766+ for future in as_completed (rbf_variable_calculation ):
767+ i_rbf_var , rbf_variable = rbf_variable_calculation [future ]
768+ try :
769+ interpolated_var = future .result ()
770+ interpolated_array [:, i_rbf_var ] = interpolated_array
771+ except Exception as exc :
772+ self .logger .error (
773+ f"Job for { rbf_variable } generated an exception: { exc } ."
774+ )
775+ else :
776+ for i_var , target_var in enumerate (self .target_processed_variables ):
777+ self .logger .info (f"Interpolating target variable { target_var } " )
778+ interpolated_var = self ._rbf_variable_interpolation (
779+ normalized_dataset = normalized_dataset ,
780+ opt_sigma = self ._opt_sigmas [target_var ],
781+ rbf_coeff = self ._rbf_coeffs [target_var ].values ,
782+ num_points_subset = num_points_subset ,
783+ num_vars_subset = num_vars_subset ,
784+ )
785+ interpolated_array [:, i_var ] = interpolated_var
723786
724787 return pd .DataFrame (interpolated_array , columns = self .target_processed_variables )
725788
@@ -826,14 +889,16 @@ def fit(
826889 # Set the is_fitted attribute to True
827890 self .is_fitted = True
828891
829- def predict (self , dataset : pd .DataFrame ) -> pd .DataFrame :
892+ def predict (self , dataset : pd .DataFrame , num_threads : int = None ) -> pd .DataFrame :
830893 """
831894 Predicts the data for the provided dataset.
832895
833896 Parameters
834897 ----------
835898 dataset : pd.DataFrame
836899 The dataset to predict (must have same variables than subset).
900+ num_threads : int, optional
901+ The number of threads to use for the interpolation. Default is None.
837902
838903 Returns
839904 -------
@@ -857,7 +922,9 @@ def predict(self, dataset: pd.DataFrame) -> pd.DataFrame:
857922 raise RBFError ("RBF model must be fitted before predicting." )
858923
859924 self .logger .info ("Reconstructing data using fitted coefficients." )
860- interpolated_target = self ._rbf_interpolate (dataset = dataset )
925+ interpolated_target = self ._rbf_interpolate (
926+ dataset = dataset , num_threads = num_threads
927+ )
861928 if self .is_target_normalized :
862929 self .logger .info ("Denormalizing target data" )
863930 interpolated_target = self .denormalize (
@@ -870,6 +937,7 @@ def predict(self, dataset: pd.DataFrame) -> pd.DataFrame:
870937 xu = interpolated_target [f"{ directional_variable } _u" ].values ,
871938 xv = interpolated_target [f"{ directional_variable } _v" ].values ,
872939 )
940+
873941 return interpolated_target
874942
875943 def fit_predict (
@@ -933,4 +1001,4 @@ def fit_predict(
9331001 iteratively_update_sigma = iteratively_update_sigma ,
9341002 )
9351003
936- return self .predict (dataset = dataset )
1004+ return self .predict (dataset = dataset , num_threads = num_threads )
0 commit comments