Skip to content

Commit 8925aa6

Browse files
committed
[JTH] first version to try with real case scenario
1 parent 2383b91 commit 8925aa6

File tree

2 files changed

+48
-18
lines changed

2 files changed

+48
-18
lines changed

bluemath_tk/core/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ def set_num_processors_to_use(self, num_processors: int) -> None:
507507
num_processors = num_processors_available
508508
elif num_processors <= 0:
509509
raise ValueError("Number of processors must be greater than 0")
510-
elif (num_processors - num_processors_available) < 2:
510+
elif (num_processors_available - num_processors) < 2:
511511
raise ValueError(
512512
"Number of processors requested is less than 2 processors available"
513513
)

bluemath_tk/interpolation/rbf.py

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Callable, List, Tuple
33

44
import numpy as np
5+
import dask.array as da
56
import pandas as pd
67
from scipy.optimize import fmin, fminbound
78

@@ -278,6 +279,9 @@ def __init__(
278279
# Exclude attributes to .save_model() method
279280
self._exclude_attributes = []
280281

282+
# Row chunks for parallel computation
283+
self.row_chunks: int = None
284+
281285
@property
282286
def sigma_min(self) -> float:
283287
return self._sigma_min
@@ -714,24 +718,50 @@ def _rbf_variable_interpolation(
714718
The interpolated variable.
715719
"""
716720

717-
r = np.linalg.norm(
718-
normalized_dataset.values[:, np.newaxis, :]
719-
- self.normalized_subset_data.values[np.newaxis, :, :],
720-
axis=2,
721-
)
722-
kernel_values = self.kernel_func(r, opt_sigma)
723-
linear_part = np.dot(
724-
normalized_dataset.values,
725-
rbf_coeff[
726-
num_points_subset + 1 : num_points_subset + 1 + num_vars_subset
727-
].T,
728-
)
721+
# Calculate optimal chunk size based on memory
722+
norm_dataset = normalized_dataset.values
723+
norm_subset = self.normalized_subset_data.values
729724

730-
return (
731-
rbf_coeff[num_points_subset]
732-
+ np.dot(kernel_values, rbf_coeff[:num_points_subset])
733-
+ linear_part
734-
)
725+
if self.row_chunks is not None:
726+
chunks = (min(self.row_chunks, norm_dataset.shape[0]), -1)
727+
self.logger.info(f"Using row chunks of size {chunks[0]}")
728+
# elif self.num_workers > 1:
729+
# chunks = (norm_dataset.shape[0] // self.num_workers, -1)
730+
else:
731+
chunks = (norm_dataset.shape[0], -1)
732+
733+
# Convert to dask arrays for large operations
734+
d_dataset = da.from_array(norm_dataset, chunks=chunks)
735+
d_subset = da.from_array(norm_subset)
736+
737+
# Split computation into chunks
738+
result = []
739+
for i in range(0, len(d_dataset), chunks[0]):
740+
chunk = d_dataset[i : i + chunks[0]]
741+
742+
# Calculate r for this chunk
743+
r_chunk = da.linalg.norm(chunk[:, None, :] - d_subset[None, :, :], axis=2)
744+
745+
# Apply kernel and dot product
746+
kernel_values = self.kernel_func(r_chunk, opt_sigma)
747+
748+
# Compute this chunk's result
749+
chunk_result = (
750+
rbf_coeff[num_points_subset]
751+
+ da.dot(kernel_values, rbf_coeff[:num_points_subset])
752+
+ da.dot(
753+
chunk,
754+
rbf_coeff[
755+
num_points_subset + 1 : num_points_subset + 1 + num_vars_subset
756+
].T,
757+
)
758+
)
759+
760+
# Compute and append
761+
result.append(chunk_result.compute())
762+
763+
# Combine results
764+
return np.concatenate(result)
735765

736766
def _rbf_interpolate(
737767
self, dataset: pd.DataFrame, num_workers: int = None

0 commit comments

Comments (0)