diff --git a/src/fes/datasets/synthetic_dataset.py b/src/fes/datasets/synthetic_dataset.py
index 695e246..c144db9 100644
--- a/src/fes/datasets/synthetic_dataset.py
+++ b/src/fes/datasets/synthetic_dataset.py
@@ -1,4 +1,5 @@
 from typing import Dict, Any
+from .utils import calculate_snr
 import numpy as np
 from numpy import random
 
@@ -145,13 +146,4 @@ def generate_grouped_data(n, m, noise_std, redundancy_rate, features_fill, num_g
 
     y_true = X @ w
     y = y_true + np.random.standard_normal((n)) * noise_std
-    return y, X, w, y_true, features_mask, groups_labels
-
-
-"""
-Support utils
-"""
-
-
-def calculate_snr(y_true, noise_std):
-    return (20 * np.log10(abs(np.where(noise_std == 0, 0, y_true / noise_std)))).mean()
\ No newline at end of file
+    return y, X, w, y_true, features_mask, groups_labels
\ No newline at end of file
diff --git a/src/fes/datasets/utils.py b/src/fes/datasets/utils.py
new file mode 100644
index 0000000..108eb2b
--- /dev/null
+++ b/src/fes/datasets/utils.py
@@ -0,0 +1,11 @@
+import numpy as np
+
+
+
+"""
+Support utils
+"""
+
+
+def calculate_snr(y_true, noise_std):
+    return (20 * np.log10(abs(np.where(noise_std == 0, 0, y_true / noise_std)))).mean()
\ No newline at end of file
diff --git a/src/fes/pipelines/data_science/nodes.py b/src/fes/pipelines/data_science/nodes.py
index 4428ae5..889068f 100644
--- a/src/fes/pipelines/data_science/nodes.py
+++ b/src/fes/pipelines/data_science/nodes.py
@@ -5,6 +5,19 @@
 from sklearn.metrics import mean_squared_error, r2_score
 
 
+# Tool to check dimensionality
+def assert_shapes(x, x_shape, y, y_shape):
+    assert_shape(x, x_shape)
+    assert_shape(y, y_shape)
+
+    shapes = defaultdict(set)
+    for arr, shape in [(x, x_shape), (y, y_shape)]:
+        for i, char in enumerate(shape):
+            if isinstance(char, str):
+                shapes[char].add(arr.shape[i])
+    for _, _set in shapes.items():
+        assert len(_set) == 1, (x, x_shape, y, y_shape)
+
 def fit_model(y, X):
     """
     Parameters
@@ -79,3 +92,4 @@ def evaluate_perm_importance(regressor, y, X, w, y_true, features_mask, paramete
 
     print(f"Approximation with {pi_parameters['explanation_rate']} explanation rate:")
     print(f"Number of proposed features: {len(features_hat_idx)}, {er_mse:.3f} MSE, {er_r2:.3f} R2", end='\n\n')
+
diff --git a/src/fes/pipelines/data_science/pipeline.py b/src/fes/pipelines/data_science/pipeline.py
index 7e97e8c..6e14035 100644
--- a/src/fes/pipelines/data_science/pipeline.py
+++ b/src/fes/pipelines/data_science/pipeline.py
@@ -8,13 +8,19 @@ def perm_importance_pipeline(**kwargs):
         [
             node(
                 func=fit_model,
-                inputs=["y", "X"],
+                inputs=[
+                    "y", "X"
+                ],
                 outputs="regressor",
                 name="fit_model_node",
             ),
             node(
                 func=evaluate_perm_importance,
-                inputs=["regressor", "y", "X", "w", "y_true", "features_mask", "parameters"],
+                inputs=[
+                    "regressor",
+                    "y", "X", "w", "y_true",
+                    "features_mask", "parameters"
+                ],
                 outputs=None,
                 name="evaluate_perm_importance_node",
             ),