diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml deleted file mode 100644 index 959ba588..00000000 --- a/.github/workflows/ci.yml +++ /dev/null @@ -1,33 +0,0 @@ -name: Ruff CI - -on: - push: - branches: - - main - - development - pull_request: - branches: - - main - - development - -jobs: - lint: - runs-on: ubuntu-latest - - steps: - - name: Checkout code - uses: actions/checkout@v3 - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: "3.10.17" - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install ruff - - - name: Run Ruff - run: | - ruff check --line-length 127 --statistics diff --git a/configs/prediction_models/LGBMClassifier.gin b/configs/prediction_models/LGBMClassifier.gin index f29f40cc..0ea5d245 100644 --- a/configs/prediction_models/LGBMClassifier.gin +++ b/configs/prediction_models/LGBMClassifier.gin @@ -6,11 +6,54 @@ include "configs/prediction_models/common/MLCommon.gin" # Train params train_common.model = @LGBMClassifier +# Hyperparameter tuning configuration model/hyperparameter.class_to_tune = @LGBMClassifier -model/hyperparameter.colsample_bytree = (0.33, 1.0) -model/hyperparameter.max_depth = (3, 7) -model/hyperparameter.min_child_samples = 1000 -model/hyperparameter.n_estimators = 100000 -model/hyperparameter.num_leaves = (8, 128, "log", 2) -model/hyperparameter.subsample = (0.33, 1.0) -model/hyperparameter.subsample_freq = 1 + +# Core tree parameters +model/hyperparameter.n_estimators = (500, 2000, 5000, 10000, 100000) +model/hyperparameter.max_depth = (3, 5, 7, 10, 15) +model/hyperparameter.num_leaves = (8, 16, 31, 64, 128, 256, "log", 2) +model/hyperparameter.min_child_samples = (10, 20, 50, 100, 500, 1000) +model/hyperparameter.min_child_weight = (1e-3, 10.0, "log") + +# Learning rate and regularization +model/hyperparameter.learning_rate = (0.01, 0.3, "log") +model/hyperparameter.reg_alpha = (1e-6, 1.0, "log") +model/hyperparameter.reg_lambda = (1e-6, 1.0, "log") + +# Sampling parameters +model/hyperparameter.subsample = (0.4, 1.0) +model/hyperparameter.subsample_freq = (0, 1, 5, 10) +model/hyperparameter.colsample_bytree = (0.4, 1.0) +model/hyperparameter.colsample_bynode = (0.4, 1.0) + +# Boosting parameters +model/hyperparameter.boosting_type = ["gbdt", "dart"] + +# Advanced DART parameters (active when boosting_type="dart") +model/hyperparameter.drop_rate = (0.1, 0.5) +model/hyperparameter.max_drop = (10, 50) +model/hyperparameter.skip_drop = (0.1, 0.9) + +# GOSS parameters (active when boosting_type="goss") +model/hyperparameter.top_rate = (0.1, 0.5) +model/hyperparameter.other_rate = (0.05, 0.2) + +# Performance and stability +model/hyperparameter.feature_fraction = (0.4, 1.0) +model/hyperparameter.bagging_fraction = (0.4, 1.0) +model/hyperparameter.bagging_freq = (0, 1, 5, 10) +model/hyperparameter.min_split_gain = (1e-6, 1.0, "log") + +# Categorical handling +model/hyperparameter.cat_smooth = (1.0, 100.0, "log") +model/hyperparameter.cat_l2 = (1.0, 100.0, "log") + +# Early stopping and validation +model/hyperparameter.early_stopping_rounds = 100 +model/hyperparameter.eval_metric = ["binary_logloss", "auc", "binary_error"] + +# Class imbalance handling +model/hyperparameter.is_unbalance = [True, False] +model/hyperparameter.scale_pos_weight = (0.1, 10.0, "log") + diff --git a/configs/prediction_models/RUSBClassifier.gin b/configs/prediction_models/RUSBClassifier.gin index e8f17722..2516fb44 100644 --- a/configs/prediction_models/RUSBClassifier.gin +++ b/configs/prediction_models/RUSBClassifier.gin @@ -1,4 +1,4 @@ -# Settings for ImbLearn Balanced Random Forest Classifier. +# Settings for ImbLearn RUSBoost Classifier (Random Under-sampling with Boosting) # Common settings for ML models include "configs/prediction_models/common/MLCommon.gin" @@ -6,9 +6,23 @@ include "configs/prediction_models/common/MLCommon.gin" # Train params train_common.model = @RUSBClassifier +# Hyperparameter tuning configuration model/hyperparameter.class_to_tune = @RUSBClassifier -model/hyperparameter.n_estimators = (10, 50, 100, 200, 500) -model/hyperparameter.learning_rate = (0.005, 1, "log") -model/hyperparameter.sampling_strategy = "auto" +# Number of estimators (boosting rounds) +model/hyperparameter.n_estimators = (50, 100, 200, 300, 500) + +# Learning rate for boosting +model/hyperparameter.learning_rate = (0.01, 2.0, "log") + +# Sampling strategy for random under-sampling +model/hyperparameter.sampling_strategy = ["auto", "majority", "not minority"] + +# Base estimator parameters (typically DecisionTreeClassifier) +model/hyperparameter.base_estimator__max_depth = [1, 2, 3, 4, 5, 6] +model/hyperparameter.base_estimator__min_samples_split = [2, 5, 10, 20] +model/hyperparameter.base_estimator__min_samples_leaf = [1, 2, 5, 10] + +# Replacement strategy for under-sampling +model/hyperparameter.replacement = [True, False] diff --git a/configs/prediction_models/XGBClassifier.gin b/configs/prediction_models/XGBClassifier.gin index f1070672..74099f09 100644 --- a/configs/prediction_models/XGBClassifier.gin +++ b/configs/prediction_models/XGBClassifier.gin @@ -8,10 +8,14 @@ train_common.model = @XGBClassifier model/hyperparameter.class_to_tune = @XGBClassifier model/hyperparameter.learning_rate = (0.01, 0.1, "log") -model/hyperparameter.n_estimators = [50, 100, 250, 500, 750, 1000,1500,2000] +model/hyperparameter.n_estimators = [50, 100, 250, 500, 750, 1000,1500,2000, 2500, 3000, 3500, 4000, 4500, 5000] model/hyperparameter.max_depth = [3, 5, 10, 15] model/hyperparameter.scale_pos_weight = [1, 5, 10, 15, 20, 25, 30, 35, 40, 50, 75, 99, 100, 1000] -model/hyperparameter.min_child_weight = [1, 0.5] +model/hyperparameter.min_child_weight = [0.1, 0.5, 1, 2, 5, 10] model/hyperparameter.max_delta_step = [0, 1, 2, 3, 4, 5, 10] model/hyperparameter.colsample_bytree = [0.1, 0.25, 0.5, 0.75, 1.0] -model/hyperparameter.eval_metric = "aucpr" \ No newline at end of file +# model/hyperparameter.eval_metric = "aucpr" +model/hyperparameter.gamma = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 2.0] +# model/hyperparameter.early_stopping_rounds = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100] +model/hyperparameter.reg_lambda = [0, 0.01, 0.1, 1, 10, 100] +model/hyperparameter.reg_alpha = [0, 0.01, 0.1, 1, 10, 100] \ No newline at end of file diff --git a/configs/prediction_models/XGBClassifierGPU.gin b/configs/prediction_models/XGBClassifierGPU.gin new file mode 100644 index 00000000..482f9972 --- /dev/null +++ b/configs/prediction_models/XGBClassifierGPU.gin @@ -0,0 +1,21 @@ +# Settings for XGBoost classifier. + +# Common settings for ML models +include "configs/prediction_models/common/MLCommon.gin" + +# Train params +train_common.model = @XGBClassifierGPU + +model/hyperparameter.class_to_tune = @XGBClassifierGPU +model/hyperparameter.learning_rate = (0.01, 0.1, "log") +model/hyperparameter.n_estimators = [50, 100, 250, 500, 750, 1000,1500,2000, 2500, 3000, 3500, 4000, 4500, 5000] +model/hyperparameter.max_depth = [3, 5, 10, 15] +model/hyperparameter.scale_pos_weight = [1, 5, 10, 15, 20, 25, 30, 35, 40, 50, 75, 99, 100, 1000] +model/hyperparameter.min_child_weight = [0.1, 0.5, 1, 2, 5, 10] +model/hyperparameter.max_delta_step = [0, 1, 2, 3, 4, 5, 10] +model/hyperparameter.colsample_bytree = [0.1, 0.25, 0.5, 0.75, 1.0] +# model/hyperparameter.eval_metric = "aucpr" +model/hyperparameter.gamma = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 2.0] +# model/hyperparameter.early_stopping_rounds = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100] +model/hyperparameter.reg_lambda = [0, 0.01, 0.1, 1, 10, 100] +model/hyperparameter.reg_alpha = [0, 0.01, 0.1, 1, 10, 100] \ No newline at end of file diff --git a/configs/prediction_models/common/MLTuning.gin b/configs/prediction_models/common/MLTuning.gin index 9df38c47..51361f39 100644 --- a/configs/prediction_models/common/MLTuning.gin +++ b/configs/prediction_models/common/MLTuning.gin @@ -1,5 +1,6 @@ # Hyperparameter tuner settings for classical Machine Learning. tune_hyperparameters.scopes = ["model"] tune_hyperparameters.n_initial_points = 5 -tune_hyperparameters.n_calls = 30 -tune_hyperparameters.folds_to_tune_on = 5 \ No newline at end of file +tune_hyperparameters.n_calls = 100 +tune_hyperparameters.folds_to_tune_on = 1 +tune_hyperparameters.repetitions_to_tune_on = 5 \ No newline at end of file diff --git a/configs/tasks/BinaryClassification.gin b/configs/tasks/BinaryClassification.gin index f86436a4..3bd32779 100644 --- a/configs/tasks/BinaryClassification.gin +++ b/configs/tasks/BinaryClassification.gin @@ -22,6 +22,12 @@ preprocess.preprocessor = @base_classification_preprocessor preprocess.modality_mapping = %modality_mapping preprocess.vars = %vars preprocess.use_static = True +preprocess.required_segments = ["OUTCOME", "STATIC"] +preprocess.file_names = { + "DYNAMIC": "dyn.parquet", + "OUTCOME": "outc.parquet", + "STATIC": "sta.parquet", +} # SELECTING DATASET include "configs/tasks/common/Dataloader.gin" diff --git a/configs/tasks/common/Dataloader.gin b/configs/tasks/common/Dataloader.gin index 6bed1b7e..8e0a14a2 100644 --- a/configs/tasks/common/Dataloader.gin +++ b/configs/tasks/common/Dataloader.gin @@ -3,6 +3,8 @@ PredictionPandasDataset.vars = %vars PredictionPandasDataset.ram_cache = True PredictionPolarsDataset.vars = %vars PredictionPolarsDataset.ram_cache = True +PredictionPolarsDataset.mps = True # Imputation ImputationPandasDataset.vars = %vars -ImputationPandasDataset.ram_cache = True \ No newline at end of file +ImputationPandasDataset.ram_cache = True +PredictionPolarsDataset.mps = True \ No newline at end of file diff --git a/demo_data/mortality24/mimic_demo_static/attrition.csv b/demo_data/mortality24/mimic_demo_static/attrition.csv new file mode 100644 index 00000000..2ef75d4e --- /dev/null +++ b/demo_data/mortality24/mimic_demo_static/attrition.csv @@ -0,0 +1,3 @@ +incl_n,excl_n_total,excl_n +125,10,7 +99,34,26 diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 00000000..d0c3cbf1 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 00000000..dc1312ab --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 00000000..a4401b59 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,4 @@ +sphinx +sphinx-rtd-theme +sphinx-autoapi +sphinx-autobuild \ No newline at end of file diff --git a/docs/source/adding_models.rst b/docs/source/adding_models.rst new file mode 100644 index 00000000..e445a574 --- /dev/null +++ b/docs/source/adding_models.rst @@ -0,0 +1,226 @@ +Adding new models to YAIB +========================== + +Example +------- + +We refer to the page `adding a new model `_ for detailed instructions on adding new models. +We allow prediction models to be easily added and integrated into a Pytorch Lightning module. This +incorporates advanced logging and debugging capabilities, as well as +built-in parallelism. Our interface derives from the `BaseModule `_. + +Adding a model consists of three steps: + +1. Add a model through the existing ``MLPredictionWrapper`` or ``DLPredictionWrapper``. +2. Add a GIN config file to bind hyperparameters. +3. Execute YAIB using a simple command. + +This folder contains everything you need to add a model to YAIB. +Putting the ``RNN.gin`` file in ``configs/prediction_models`` and the ``rnn.py`` file into ``icu_benchmarks/models`` allows you to run the model fully. + +.. code-block:: bash + + icu-benchmarks train \ + -d demo_data/mortality24/mimic_demo \ # Insert cohort dataset here + -n mimic_demo \ + -t BinaryClassification \ # Insert task name here + -tn Mortality24 \ + --log-dir ../yaib_logs/ \ + -m RNN \ # Insert model here + -s 2222 \ + -l ../yaib_logs/ \ + --tune + +Adding more models +================== + +Regular ML +---------- + +For standard Scikit-Learn type models (e.g., LGBM), one can +simply wrap ``MLPredictionWrapper`` the function with minimal code +overhead. Many ML (and some DL) models can be incorporated this way, requiring minimal code additions. See below. + +.. code-block:: python + :caption: Example ML model definition + + @gin.configurable + class RFClassifier(MLWrapper): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model = self.model_args() + + @gin.configurable(module="RFClassifier") + def model_args(self, *args, **kwargs): + return RandomForestClassifier(*args, **kwargs) + +Adding DL models +---------------- + +It is relatively straightforward to add new Pytorch models to YAIB. We first provide a standard RNN-model which needs no extra components. Then, we show the implementation of the Temporal Fusion Transformer model. + +Standard RNN-model +~~~~~~~~~~~~~~~~~~ + +The definition of dl models can be done by creating a subclass from the +``DLPredictionWrapper``, inherits the standard methods needed for +training dl learning models. Pytorch Lightning significantly reduces the code +overhead. + +.. code-block:: python + :caption: Example DL model definition + + @gin.configurable + class RNNet(DLPredictionWrapper): + """Torch standard RNN model""" + + def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): + super().__init__( + input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs + ) + self.hidden_dim = hidden_dim + self.layer_dim = layer_dim + self.rnn = nn.RNN(input_size[2], hidden_dim, layer_dim, batch_first=True) + self.logit = nn.Linear(hidden_dim, num_classes) + + def init_hidden(self, x): + h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim) + return h0 + + def forward(self, x): + h0 = self.init_hidden(x) + out, hn = self.rnn(x, h0) + pred = self.logit(out) + return pred + +Adding a SOTA model: Temporal Fusion Transformer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +There are two main questions when you want to add a more complex model: + +* *Do you want to manually define the model or use an existing library?* This might require adapting the ``DLPredictionWrapper``. +* *Does the model expect the data to be in a certain format?* This might require adapting the ``PredictionDataset``. + +By adapting, we mean creating a new subclass that inherits most functionality to avoid code duplication, is future-proof, and follows good coding practices. + +First, you can add modules to ``models/layers.py`` to use them for your model. + +.. code-block:: python + :caption: Example building block + + class StaticCovariateEncoder(nn.Module): + """ + Network to produce 4 context vectors to enrich static variables + Variable selection Network --> GRNs + """ + + def __init__(self, num_static_vars, hidden, dropout): + super().__init__() + self.vsn = VariableSelectionNetwork(hidden, dropout, num_static_vars) + self.context_grns = nn.ModuleList([GRN(hidden, hidden, dropout=dropout) for _ in range(4)]) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + variable_ctx, sparse_weights = self.vsn(x) + + # Context vectors: + # variable selection context + # enrichment context + # state_c context + # state_h context + cs, ce, ch, cc = [m(variable_ctx) for m in self.context_grns] + + return cs, ce, ch, cc + +Note that we can create modules out of modules as well. + +Adapting the ``DLPredictionWrapper`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The next step is to use the building blocks defined in ``layers.py`` or modules from an existing library to add to the model in ``models/dl_models.py``. In this case, we use the Pytorch-forecasting library (https://github.com/jdb78/pytorch-forecasting): + +.. code-block:: python + :caption: Example DL model definition + + class TFTpytorch(DLPredictionWrapper): + + supported_run_modes = [RunMode.classification, RunMode.regression] + + def __init__(self, dataset, hidden, dropout, n_heads, dropout_att, lr, optimizer, num_classes, *args, **kwargs): + super().__init__(lr=lr, optimizer=optimizer, *args, **kwargs) + self.model = TemporalFusionTransformer.from_dataset( + dataset=dataset) + self.logit = nn.Linear(7, num_classes) + + + def forward(self, x): + out = self.model(x) + pred = self.logit(out["prediction"]) + return pred + +Adapting the ``PredictionDataset`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Some models require an adjusted dataloader to facilitate, for example, explainability methods. In this case, changes need to be made to the ``data/loader.py`` file to ensure the data loader returns the data in the correct format. +This can be done by creating a class that inherits from ``PredictionDataset`` and editing the ``get_item`` method. + +.. code-block:: python + :caption: Example custom dataset definition + + @gin.configurable("PredictionDatasetTFT") + class PredictionDatasetTFT(PredictionDataset): + def __init__(self, *args, ram_cache: bool = True, **kwargs): + super().__init__(*args, ram_cache=True, **kwargs) + + def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]: + """Function to sample from the data split of choice. Used for TFT. + The data needs to be given to the model in the following order + [static categorical, static continuous,known categorical,known continuous, observed categorical, observed continuous,target,id] + +Then, you must check ``models/wrapper.py``, particularly the ``step_fn`` method, to ensure the data is correctly transferred to the device. + +Adding the model config GIN file +================================= + +To define hyperparameters for each model in a standardized manner, we use GIN-config. We need to specify a GIN file to bind the parameters to train and optimize this model from a choice of hyperparameters. Note that we can use modifiers for the optimizer (e.g, Adam optimizer) and ranges that we can specify in rounded brackets "()" . Square brackets, "[]", result in a random choice where the variable is uniformly sampled. + +.. code-block:: gin + + # Hyperparameters for TFT model. + + # Common settings for DL models + include "configs/prediction_models/common/DLCommon.gin" + + # Optimizer params + train_common.model = @TFT + + optimizer/hyperparameter.class_to_tune = @Adam + optimizer/hyperparameter.weight_decay = 1e-6 + optimizer/hyperparameter.lr = (1e-5, 3e-4) + + # Encoder params + model/hyperparameter.class_to_tune = @TFT + model/hyperparameter.encoder_length = 24 + model/hyperparameter.hidden = 256 + model/hyperparameter.num_classes = %NUM_CLASSES + model/hyperparameter.dropout = (0.0, 0.4) + model/hyperparameter.dropout_att = (0.0, 0.4) + model/hyperparameter.n_heads =4 + model/hyperparameter.example_length=25 + +Training the model +================== + +After these steps, your model should be trainable with the following command: + +.. code-block:: bash + + icu-benchmarks train \ + -d demo_data/mortality24/mimic_demo \ # Insert cohort dataset here + -n mimic_demo \ + -t BinaryClassification \ # Insert task name here + -tn Mortality24 \ + --log-dir ../yaib_logs/ \ + -m TFT \ # Insert model here + -s 2222 \ + -l ../yaib_logs/ \ + --tune \ No newline at end of file diff --git a/docs/source/api.rst b/docs/source/api.rst new file mode 100644 index 00000000..45c57ae3 --- /dev/null +++ b/docs/source/api.rst @@ -0,0 +1,122 @@ +API Reference +============= + +This page contains the API reference for YAIB (Yet Another ICU Benchmark). + +Core Modules +------------ + +.. automodule:: icu_benchmarks + :members: + :undoc-members: + :show-inheritance: + +Run Utilities +~~~~~~~~~~~~~ + +.. automodule:: icu_benchmarks.run + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: icu_benchmarks.run_utils + :members: + :undoc-members: + :show-inheritance: + +Utilities +~~~~~~~~~ + +.. automodule:: icu_benchmarks.utils + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: icu_benchmarks.wandb_utils + :members: + :undoc-members: + :show-inheritance: + +Cross Validation +~~~~~~~~~~~~~~~~ + +.. automodule:: icu_benchmarks.cross_validation + :members: + :undoc-members: + :show-inheritance: + +Constants +~~~~~~~~~ + +.. automodule:: icu_benchmarks.constants + :members: + :undoc-members: + :show-inheritance: + +Data Processing +--------------- + +.. automodule:: icu_benchmarks.data + :members: + :undoc-members: + :show-inheritance: + +Models +------ + +.. automodule:: icu_benchmarks.models + :members: + :undoc-members: + :show-inheritance: + +Imputation +---------- + +.. automodule:: icu_benchmarks.imputation + :members: + :undoc-members: + :show-inheritance: + +Hyperparameter Tuning +---------------------- + +.. automodule:: icu_benchmarks.tuning + :members: + :undoc-members: + :show-inheritance: + +Scripts +------- + +Evaluation Scripts +~~~~~~~~~~~~~~~~~~ + +.. automodule:: scripts.evaluate_results + :members: + :undoc-members: + :show-inheritance: + +Plotting Scripts +~~~~~~~~~~~~~~~~ + +.. automodule:: scripts.plotting + :members: + :undoc-members: + :show-inheritance: + +Sample Usage +~~~~~~~~~~~~ + +.. automodule:: scripts.sample_usage + :members: + :undoc-members: + :show-inheritance: + +Sweep Configurations +~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: scripts.sweep_configs + :members: + :undoc-members: + :show-inheritance: + diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 00000000..88cd230d --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,76 @@ +import os +import sys +sys.path.insert(0, os.path.abspath('../../')) # Add your project root to Python path +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = 'Yet Another ICU Benchmark' +copyright = '2025, Robin P. van de Water, Hendrik Schmidt, Patrick Rockenschaub, MIT License' +author = 'Robin P. van de Water, Hendrik Schmidt, Patrick Rockenschaub' +release = '1.0' +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.autosummary', + 'sphinx.ext.viewcode', + 'sphinx.ext.napoleon', # For Google/NumPy style docstrings + 'sphinx_immaterial' +] +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + + +templates_path = ['_templates'] +exclude_patterns = [] + + + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +# Theme configuration +html_theme = 'sphinx_immaterial' +# Logo configuration +html_logo = '../figures/yaib_logo.png' # Path to your logo file +html_favicon = '../figures/yaib_logo.png' # Optional: favicon +# Theme options +# html_theme_options = { +# 'canonical_url': '', +# 'analytics_id': '', +# 'logo_only': False, +# 'display_version': True, +# 'prev_next_buttons_location': 'bottom', +# 'style_external_links': False, +# 'vcs_pageview_mode': '', +# 'style_nav_header_background': 'white', +# # Toc options +# 'collapse_navigation': True, +# 'sticky_navigation': True, +# 'navigation_depth': 4, +# 'includehidden': True, +# 'titles_only': False, +# 'body_max_width': 'none', +# 'page_width': 'auto', +# } +# sphinx_immaterial theme options + + +html_static_path = ['_static'] + + + +# Autodoc settings +autodoc_default_options = { + 'members': True, + 'member-order': 'bysource', + 'special-members': '__init__', + 'undoc-members': True, + 'exclude-members': '__weakref__' +} + +# Generate autosummary even if no references +autosummary_generate = True \ No newline at end of file diff --git a/docs/source/development.rst b/docs/source/development.rst new file mode 100644 index 00000000..37f4d52a --- /dev/null +++ b/docs/source/development.rst @@ -0,0 +1,55 @@ +Development +=========== + +YAIB is in active development. The following sections could be relevant for adding new code to our repository. + +Libraries +--------- + +The following libraries are important to the operation of YAIB: + +Core Dependencies +~~~~~~~~~~~~~~~~~ + +- `Pandas `_: Popular data structure framework. +- `ReciPys `_: A modular preprocessing package for Pandas dataframes. +- `Pytorch `_: An open source machine learning framework for deep learning applications. +- `Pytorch Lightning `_: A lightweight Pytorch wrapper for AI research. +- `Pytorch Ignite `_: Library for training and evaluating neural networks in Pytorch. +- `Cuda Toolkit `_: GPU acceleration used for deep learning models. +- `Scikit-learn `_: Machine learning library. +- `Scikit-optimize `_: Used for Bayesian optimization. +- `LightGBM `_: Gradient boosting framework. +- `GIN `_: Provides a lightweight configuration framework for Python. +- `Wandb `_: A tool for visualizing and tracking machine learning experiments. +- `Pytest `_: A testing framework for Python. + +Imputation Libraries +~~~~~~~~~~~~~~~~~~~~ + +- `HyperImpute `_: Imputation library for MissForest and GAIN. +- `PyPOTS `_: Imputation library. + +Running Tests +------------- + +To run the test suite: + +.. code-block:: bash + + python -m pytest ./tests/recipes + coverage run -m pytest ./tests/recipes + + # then use either of the following + coverage report + coverage html + +Code Formatting and Linting +---------------------------- + +For development purposes, we use the ``Black`` package to autoformat our code and a ``Flake8`` linting/CI check: + +.. code-block:: bash + + black . -l 127 + flake8 . --count --max-complexity=14 --max-line-length=127 --statistics \ No newline at end of file diff --git a/docs/source/imputation_methods.rst b/docs/source/imputation_methods.rst new file mode 100644 index 00000000..4b04c553 --- /dev/null +++ b/docs/source/imputation_methods.rst @@ -0,0 +1,100 @@ +=========================== +Adding New Imputation Models +============================ + +To add another imputation model, you have to create a class that inherits from ``ImputationWrapper`` in ``icu_benchmarks.models.wrappers``. Your model class should look like this: + +.. code-block:: python + + from icu_benchmarks.models.wrappers import ImputationWrapper + import gin + + + @gin.configurable("newmethod") + class New_Method(ImputationWrapper): + # adjust this accordingly + # if true, the method is trained iteratively (like a deep learning model). + # If false it receives the complete training data to perform a fit on + requires_backprop = False + + def __init__(self, *args, model_arg1, model_arg2, **kwargs): + super().__init__(*args, **kwargs) + # define your new model here + self.model = ... + + # the following method has to be implemented for all methods + def forward(self, amputated_values, amputation_mask): + imputated_values = amputated_values + ... + return imputated_values + + # implement this, if needs_fit is true, otherwise you can leave it out. + # this method receives the complete input training data to perform a fit on. + def fit(self, train_data): + ... + +Configuration File +------------------ + +You also need to create a gin configuration file in the `configs/imputation` directory, +named `newmethod.gin` after the name that was entered into the ``gin.configurable`` decorator call. + +Your ``.gin`` file should look like this: + +.. code-block:: python + + import gin.torch.external_configurables + import icu_benchmarks.models.wrappers + import icu_benchmarks.models.dl_models + import icu_benchmarks.models.utils + import icu_benchmarks.data.split_process_data + # import here the file you created your New_Method class in + import icu_benchmarks.imputation.new_model + + # Train params + train_common.model = + + + @newmethod # change this into the name of the gin configuration file + + # here you can set some training parameters + + + train_common.epochs = 1000 + train_common.batch_size = 64 + train_common.patience = 10 + train_common.min_delta = 1e-4 + train_common.use_wandb = True + + ImputationWrapper.optimizer = + + + @Adam + + + ImputationWrapper.lr_scheduler = "cosine" + + # Optimizer params + Adam.lr = 3e-4 + Adam.weight_decay = 1e-6 + + # here you can set the model parameters you want to configure + newmethod.model_arg1 = 20 + newmethod.model_arg2 = 15 + +Running Training +---------------- + +You can find further configurations in the ``Dataset_Imputation.gin`` file in the `configs/tasks/` directory. +To start a training of an imputation method with the newly created imputation method, use the following command: + +.. code-block:: bash + + python run.py train -d path/to/preprocessed/data/files -n dataset_name -t Dataset_Imputation -m newmethod + +For the dataset path please enter the path to the directory where the preprocessed `dyn.parquet`, `outc.parquet` and `sta.parquet` are stored. The ``dataset_name`` is only for logging purposes and breaks nothing if not set correctly. Keep in mind to use the name of the ``.gin`` config file created for the imputation method as model name for the ``-m`` parameter. + +Examples and References +----------------------- + +For reference for a deep learning based imputation method you can take a look at how the ``MLPImputation`` method is implemented in `icu_benchmarks/imputation/mlp.py` with its `MLP.gin` configuration file. For reference regarding methods with ``needs_fit=True``, take a look at the `icu_benchmarks/imputation/baselines.py` file with several baseline implementations and their corresponding config files in `configs/imputation/`. \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 00000000..c786a319 --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,372 @@ +.. image:: https://github.com/rvandewater/YAIB/blob/development/docs/figures/yaib_logo.png?raw=true + :alt: YAIB logo + +🧪 Yet Another ICU Benchmark +============================ + +.. image:: https://github.com/rvandewater/YAIB/actions/workflows/ci.yml/badge.svg?branch=development + :target: https://github.com/rvandewater/YAIB/actions/workflows/ci.yml + :alt: CI + +.. image:: https://img.shields.io/badge/code%20style-black-000000.svg + :target: https://github.com/psf/black + :alt: Black + +.. image:: https://img.shields.io/badge/platform-linux--64%20|%20win--64%20|%20osx--64-lightgrey + :alt: Platform + +.. image:: https://img.shields.io/badge/arXiv-2306.05109-b31b1b.svg + :target: http://arxiv.org/abs/2306.05109 + :alt: arXiv + +.. image:: https://img.shields.io/pypi/v/yaib.svg + :target: https://pypi.python.org/pypi/yaib/ + :alt: PyPI version shields.io + +.. image:: https://img.shields.io/badge/-Python_3.10-blue?logo=python&logoColor=white + :target: https://www.python.org/downloads/release/python-3100/ + :alt: python + +.. image:: https://img.shields.io/badge/PyTorch_2.0+-ee4c2c?logo=pytorch&logoColor=white + :target: https://pytorch.org/get-started/locally/ + :alt: pytorch + +.. image:: https://img.shields.io/badge/-Lightning_2.0+-792ee5?logo=pytorchlightning&logoColor=white + :target: https://pytorchlightning.ai/ + :alt: lightning + +.. image:: https://img.shields.io/badge/license-MIT-green.svg + :target: LICENSE + :alt: License + +Yet another ICU benchmark (YAIB) provides a framework for doing clinical machine learning experiments on Intensive Care Unit +(ICU) EHR data. + + ++----------------+---------------------------------------------------------+--------------------------------------------------------+--------------------------------------------------------+-----------------------------------------------------+ +| **Dataset** | `MIMIC-III `__ | `eICU-CRD `__ | `HiRID `__ | `AUMCdb `__ | +| | / `IV `__ | | | | ++================+=========================================================+========================================================+========================================================+=====================================================+ +| **Admissions** | 40k / 73k | 200k | 33k | 23k | ++----------------+---------------------------------------------------------+--------------------------------------------------------+--------------------------------------------------------+-----------------------------------------------------+ +| **Version** | v1.4 / v2.2 | v2.0 | v1.1.1 | v1.0.2 | ++----------------+---------------------------------------------------------+--------------------------------------------------------+--------------------------------------------------------+-----------------------------------------------------+ +| **Frequency** | 1 hour | 5 minutes | 2 / 5 minutes | up to 1 minute | +| (time-series) | | | | | ++----------------+---------------------------------------------------------+--------------------------------------------------------+--------------------------------------------------------+-----------------------------------------------------+ +| **Originally | 2015 / 2020 | 2017 | 2020 | 2019 | +| published** | | | | | ++----------------+---------------------------------------------------------+--------------------------------------------------------+--------------------------------------------------------+-----------------------------------------------------+ +| **Origin** | USA | USA | Switzerland | Netherlands | ++----------------+---------------------------------------------------------+--------------------------------------------------------+--------------------------------------------------------+-----------------------------------------------------+ + +New datasets can also be added. We are currently working on a package to +make this process as smooth as possible. The benchmark is designed for +operating on preprocessed parquet files. + +We provide five common tasks for clinical prediction by default: + ++----+----------------------+----------------------+-------------------+ +| No | Task | Frequency | Type | ++====+======================+======================+===================+ +| 1 | ICU Mortality | Once per Stay (after | Binary | +| | | 24H) | Classification | ++----+----------------------+----------------------+-------------------+ +| 2 | Acute Kidney Injury | Hourly (within 6H) | Binary | +| | (AKI) | | Classification | ++----+----------------------+----------------------+-------------------+ +| 3 | Sepsis | Hourly (within 6H) | Binary | +| | | | Classification | ++----+----------------------+----------------------+-------------------+ +| 4 | Kidney Function(KF) | Once per stay | Regression | ++----+----------------------+----------------------+-------------------+ +| 5 | Length of Stay (LoS) | Hourly (within 7D) | Regression | ++----+----------------------+----------------------+-------------------+ + +New tasks can be easily added. +To get started right away, we include the eICU and MIMIC-III demo datasets in our repository. + +The following repositories may be relevant as well: + +- `YAIB-cohorts `_: Cohort generation for YAIB. +- `YAIB-models `_: Pretrained models for YAIB. +- `ReciPys `_: Preprocessing package for YAIB pipelines. + +For all YAIB-related repositories, please see: https://github.com/stars/rvandewater/lists/yaib. + +📄Paper +======= + +To reproduce the benchmarks in our paper, we refer to the `ML reproducibility document `_. +If you use this code in your research, please cite the following publication: + +.. code-block:: bibtex + + @inproceedings{vandewaterYetAnotherICUBenchmark2024, + title = {Yet Another ICU Benchmark: A Flexible Multi-Center Framework for Clinical ML}, + shorttitle = {Yet Another ICU Benchmark}, + booktitle = {The Twelfth International Conference on Learning Representations}, + author = {van de Water, Robin and Schmidt, Hendrik Nils Aurel and Elbers, Paul and Thoral, Patrick and Arnrich, Bert and Rockenschaub, Patrick}, + year = {2024}, + month = oct, + urldate = {2024-02-19}, + langid = {english}, + } + +This paper can also be found on arxiv `2306.05109 `_ + +💿Installation +============== + +YAIB is currently ideally installed from source, however we also offer it an early PyPi release. + +Installation from source +------------------------- + +First, we clone this repository using git: + +.. code-block:: bash + + git clone https://github.com/rvandewater/YAIB.git + +Please note the branch. The newest features and fixes are available at the development branch: + +.. code-block:: bash + + git checkout development + +YAIB can be installed using a conda environment (preferred) or pip. Below are the three CLI commands to install YAIB +using **conda**. + +The first command will install an environment based on Python 3.10. + +.. code-block:: bash + + conda env update -f environment.yml + +.. note:: + Use ``environment.yml`` on x86 hardware. Please note that this installs Pytorch as well. + +.. note:: + For mps, one needs to comment out *pytorch-cuda*, see the `PyTorch install guide `_. + +We then activate the environment and install a package called ``icu-benchmarks``, after which YAIB should be operational. + +.. code-block:: bash + + conda activate yaib + pip install -e . + +After installation, please check if your Pytorch version works with CUDA (in case available) to ensure the best performance. +YAIB will automatically list available processors at initialization in its log files. + +👩‍💻Usage +========= + +Please refer to `our wiki `_ for detailed information on how to use YAIB. + +Quickstart 🚀 (demo data) +-------------------------- + +The authors of MIMIC-III and eICU have made a small demo dataset available to demonstrate their use. They can be found on Physionet: `MIMIC-III Clinical Database Demo `_ and `eICU Collaborative Research Database Demo `_. These datasets are published under the `Open Data Commons Open Database License v1.0 `_ and can be used without credentialing procedure. We have created demo cohorts processed **solely from these datasets** for each of our currently supported task endpoints. To the best of our knowledge, this complies with the license and the respective dataset author's instructions. Usage of the task cohorts and the dataset is only permitted with the above license. +We **strongly recommend** completing a human subject research training to ensure you properly handle human subject research data. + +In the folder ``demo_data`` we provide processed publicly available demo datasets from eICU and MIMIC with the necessary labels +for ``Mortality at 24h``, ``Sepsis``, ``Akute Kidney Injury``, ``Kidney Function``, and ``Length of Stay``. + +If you do not yet have access to the ICU datasets, you can run the following command to train models for the included demo +cohorts: + +.. code-block:: bash + + wandb sweep --verbose experiments/demo_benchmark_classification.yml + wandb sweep --verbose experiments/demo_benchmark_regression.yml + +.. code-block:: bash + + wandb agent + +.. tip:: + You can choose to run each of the configurations on a SLURM cluster instance by ``wandb agent --count 1 `` + +.. note:: + You will need to have a wandb account and be logged in to run the above commands. + +Getting the datasets +--------------------- + +HiRID, eICU, and MIMIC IV can be accessed through `PhysioNet `_. A guide to this process can be +found `here `_. +AUMCdb can be accessed through a separate access `procedure `_. We do not have +involvement in the access procedure and can not answer to any requests for data access. + +Cohort creation +--------------- + +Since the datasets were created independently of each other, they do not share the same data structure or data identifiers. In +order to make them interoperable, use the preprocessing utilities +provided by the `ricu package `_. +Ricu pre-defines a large number of clinical concepts and how to load them from a given dataset, providing a common interface to +the data, that is used in this +benchmark. Please refer to our `cohort definition `_ code for generating the cohorts +using our python interface for ricu. +After this, you can run the benchmark once you have gained access to the datasets. + +👟 Running YAIB +=============== + +Preprocessing and Training +--------------------------- + +The following command will run training and evaluation on the MIMIC demo dataset for (Binary) mortality prediction at 24h with +the +LGBMClassifier. Child samples are reduced due to the small amount of training data. We load available cache and, if available, +load +existing cache files. + +.. code-block:: bash + + icu-benchmarks \ + -d demo_data/mortality24/mimic_demo \ + -n mimic_demo \ + -t BinaryClassification \ + -tn Mortality24 \ + -m LGBMClassifier \ + -hp LGBMClassifier.min_child_samples=10 \ + --generate_cache \ + --load_cache \ + --seed 2222 \ + -l ../yaib_logs/ \ + --tune + +.. tip:: + For a list of available flags, run ``icu-benchmarks train -h``. + +.. tip:: + Run with ``PYTORCH_ENABLE_MPS_FALLBACK=1`` on Macs with Metal Performance Shaders. + +.. note:: + For Windows based systems, the next line character (\\) needs to be replaced by (^) (Command Prompt) or (\`) (Powershell) + respectively. + +Alternatively, the easiest method to train all the models in the paper is to run these commands from the directory root: + +.. code-block:: bash + + wandb sweep --verbose experiments/benchmark_classification.yml + wandb sweep --verbose experiments/benchmark_regression.yml + +This will create two hyperparameter sweeps for WandB for the classification and regression tasks. +This configuration will train all the models in the paper. You can then run the following command to train the models: + +.. code-block:: bash + + wandb agent + +.. tip:: + You can choose to run each of the configurations on a SLURM cluster instance by ``wandb agent --count 1 `` + +.. note:: + You will need to have a wandb account and be logged in to run the above commands. + +Evaluate or Finetune +-------------------- + +It is possible to evaluate a model trained on another dataset and no additional training is done. +In this case, the source dataset is the demo data from MIMIC and the target is the eICU demo: + +.. code-block:: bash + + icu-benchmarks \ + --eval \ + -d demo_data/mortality24/eicu_demo \ + -n eicu_demo \ + -t BinaryClassification \ + -tn Mortality24 \ + -m LGBMClassifier \ + --generate_cache \ + --load_cache \ + -s 2222 \ + -l ../yaib_logs \ + -sn mimic \ + --source-dir ../yaib_logs/mimic_demo/Mortality24/LGBMClassifier/2022-12-12T15-24-46/repetition_0/fold_0 + +.. note:: + A similar syntax is used for finetuning, where a model is loaded and then retrained. To run finetuning, replace ``--eval`` with ``-ft``. + +Models +------ + +We provide several existing machine learning models that are commonly used for multivariate time-series data. +``pytorch`` is used for the deep learning models, ``lightgbm`` for the boosted tree approaches, and ``sklearn`` for other classical +machine learning models. +The benchmark provides (among others) the following built-in models: + +- `Logistic Regression `_: + Standard regression approach. +- `Elastic Net `_: Linear regression with + combined L1 and L2 priors as regularizer. +- `LightGBM `_: Efficient gradient + boosting trees. +- `Long Short-term Memory (LSTM) `_: The most commonly used type of Recurrent Neural + Networks for long sequences. +- `Gated Recurrent Unit (GRU) `_ : A extension to LSTM which showed + improvements (`paper `_). +- `Temporal Convolutional Networks (TCN) `_ : 1D convolution approach to sequence data. By + using dilated convolution to extend the receptive field of the network it has shown great performance on long-term + dependencies. +- `Transformers `_: The most common Attention + based approach. + +🛠️ Development +=============== + +To adapt YAIB to your own use case, you can use +the `development information `_ page as a reference. +We appreciate contributions to the project. Please read the `contribution guidelines `_ before submitting a pull +request. + +Acknowledgements +================ + +This project has been developed partially under the funding of "Gemeinsamer Bundesausschuss (G-BA) Innovationsausschuss" in the framework of "CASSANDRA - Clinical ASSist AND aleRt Algorithms". +(project number 01VSF20015). We would like to acknowledge the work of Alisher Turubayev, Anna Shopova, Fabian Lange, Mahmut Kamalak, Paul Mattes, and Victoria Ayvasky for adding Pytorch Lightning, Weights and Biases compatibility, and several optional imputation methods to a later version of the benchmark repository. + +We do not own any of the datasets used in this benchmark. This project uses heavily adapted components of +the `HiRID benchmark `_. We thank the authors for providing this codebase and +encourage further development to benefit the scientific community. The demo datasets have been released under +an `Open Data Commons Open Database License (ODbL) `_. + +License +======= + +This source code is released under the MIT license, included `here `_. We do not own any of the datasets used or +included in this repository. + + .. toctree:: + :maxdepth: 2 + + api + Example-usage + adding_models + development + imputation_methods + Adding-datasets + Defining-a-new-clinical-concept + Generating-Cohorts + Adding-tasks + Preprocessing + Adding-a-new-model + Adding-evaluation-metrics + Creating-Pooled-datasets + Getting-access-to-ICU-EHR-datasets + Reproducing-Paper-Results + Weights-and-Biases + + + Indices and tables + ================== + + * :ref:`genindex` + * :ref:`modindex` + * :ref:`search` \ No newline at end of file diff --git a/experiments/benchmark_cass.yml b/experiments/benchmark_cass.yml new file mode 100644 index 00000000..0dd22e51 --- /dev/null +++ b/experiments/benchmark_cass.yml @@ -0,0 +1,152 @@ +command: + - ${env} + - ${program} + - train + - -d + - ../data/ + - -t +# - BinaryClassification + - CassClassification + - --log-dir + - /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs + - --tune + - --wandb-sweep + - -tn + - SSI + - --hp-checkpoint + - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-20T18-29-31.571508/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-23T09-42-04.806320/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-20T18-29-31.571508/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-07-16T14:02:27/SSI/XGBClassifier/2025-07-20T18-32-19.553702/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/real_life_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-16T14:58:54/SSI/XGBClassifier/2025-07-16T19-51-45.484723/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-07-16T14:02:27/SSI/XGBClassifier/2025-07-16T19-41-44.887611/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-16T19-39-58.063736/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-14T17-02-09.569252/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-14T13-49-47.629185/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/real_life_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-06-27T14:38:43/SSI/XGBClassifier/2025-06-27T17-13-28.076070/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-06-27T13:36:02/SSI/XGBClassifier/2025-06-27T17-12-03.184584/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-06-27T14:45:50/SSI/XGBClassifier/2025-06-27T17-08-55.389121/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_segment_1.0_horizon_6:00:00_transfer_full_2025-06-19T19:11:23/SSI/XGBClassifier/2025-06-20T12-14-27.191590/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-06-20T16:37:14/SSI/XGBClassifier/2025-06-20T17-36-53.795313/hyperparameter_tuning_logs.db" +# - /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/real_life_normal_ward_SSI-3_segment_1.0_horizon_6:00:00_transfer_full_2025-06-03T16:05:00/SSI/XGBClassifier/2025-06-04T17-09-19.640244/hyperparameter_tuning_logs.db +## - Mortality +# - --verbose + - --modalities + - "all" + - --file-names + - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' + - --explain_features +# - --verbose + - --load_data_vars +method: grid +name: yaib_classification_benchmark +parameters: + data_dir: + values: +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-21T17:08:54" + - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/real_life_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-09T13:30:00" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-05T12:24:09" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/real_life_set_ICU_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-06T10:34:13" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-01T12:13:31" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-05T12:05:19" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/real_life_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-05T17:02:14" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-05T12:24:09" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-07-31T11:05:22" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-07-30T10:54:08" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-29T12:10:11" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-27T13:50:45" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-25T17:15:01" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/real_life_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-16T14:58:54" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-07-16T14:02:27" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-12T16:49:25" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-10T21:42:32" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-07-10T14:16:29" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-07-09T16:09:03" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-08T15:24:29" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-08T12:12:15" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-07T19:04:46" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-07-06T15:14:29" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-06T15:06:43" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-04T13:11:42" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-07-04T13:30:20" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/real_life_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-02T14:52:41" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-06-27T13:36:02" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-06-27T14:45:50" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/real_life_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-06-27T14:38:43" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/real_life_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-06-26T21:23:03" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-06-25T19:27:55" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-06-25T16:49:25" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-06-23T17:45:52" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-06-23T12:50:14" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck/BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-06-20T16:37:14" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_SSI-3_POPF_Galleleck/BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-06-20T14:51:44" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck/BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-06-20T15:17:09" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_segment_1.0_horizon_6:00:00_transfer_full_2025-06-19T15:30:35" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_SSI-3_segment_1.0_horizon_6:00:00_transfer_full_2025-06-19T12:07:34" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_segment_1.0_horizon_6:00:00_transfer_full_2025-06-19T13:53:54" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_SSI-3_segment_1.0_horizon_6:00:00_transfer_full_2025-06-10T20:57:18" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_segment_1.0_horizon_6:00:00_transfer_full_2025-06-10T16:12:39" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/real_life_normal_ward_SSI-3_segment_1.0_horizon_6:00:00_transfer_full_2025-06-03T16:05:00" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/normal_ward_SSI-3_segment_1.0_horizon_6:00:00_transfer_full_2025-05-27T23:34:00" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/normal_ward_SSI-3_segment_1.0_horizon_6:00:00_transfer_full_2025-05-09T19:19:33_wearable_cut_off_0.75" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/ICU_and_normal_ward_SSI-3_segment_1.0_horizon_6:00:00_transfer_full_2025-04-22T21:19:32" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-04-15T12:42:04" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/normal_ward_SSI-3_POPF_heamatoma_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-04-07T15:35:16" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/cohort_separation/1/" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/cohort_separation/2/" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/cohort_separation/3/" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/cohort_separation/4/" + model: + values: +# - LogisticRegression +# - TimesNet +# - LGBMClassifier + - XGBClassifier +# - RUSBClassifier +# - CBClassifier +# - BRFClassifier +# - RFClassifier +# - GRU +# - LSTM +# - TCN +# - Transformer + modalities: + values: + - "all" +# - [ top_features ] +# - [copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding, static, static_retro_icu, static_retro_pre_intra ] # missing:hour + # - [copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_clinical_notes, cat_clinical_notes, static] +# - [copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_clinical_notes, wearable_activity, wearable_core, static] +# - [copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, static] #, wearable_activity, wearable_core] # without wearable data +# - [copra_observations, copra_scores, copra_fluid_balance, cat_clinical_notes, wearable_activity, wearable_core, static] # without ishmed lab data and clinical notes +# - [copra_observations, copra_scores, copra_fluid_balance, cat_clinical_notes, static ] # without wearable data and ishmed lab data, clinical notes +# - [copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_clinical_notes, wearable_activity, wearable_core, static] # without clinical notes +# - [copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_clinical_notes, wearable_activity, wearable_core, static] # without med embeddings +# - [wearable_activity, wearable_core, static] +# - [ishmed_lab_numeric, static] +# - [static] + # 1 missing modality +# - [ wearable_ppgembedding, core, wearable_ppgfeature, static ] +# - [ wearable_activity, static ] +# - [ wearable_ppgfeature, static ] +# - [ wearable_core, static] +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, static, static_retro_icu, static_retro_pre_intra, wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding, cat_medications, cat_clinical_notes] #hour ] # missing:notes, medications +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, static, static_retro_icu, static_retro_pre_intra, wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding, hour ] # missing:notes, medications +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, static, static_retro_icu, static_retro_pre_intra, hour ] # missing: wearable (activity, ppgfeature, ppgembedding) + - [static, hour, cat_clinical_notes] + file_names: + values: +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_3.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' + - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_12.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_18.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_24.parquet","STATIC":"sta.parquet"' + seed: + values: + - 1111 + use_pretrained_imputation: + values: + - None +program: icu-benchmarks diff --git a/experiments/benchmark_cass_baselines.yml b/experiments/benchmark_cass_baselines.yml new file mode 100644 index 00000000..00069b90 --- /dev/null +++ b/experiments/benchmark_cass_baselines.yml @@ -0,0 +1,64 @@ +command: + - ${env} + - ${program} + - train + - -d + - ../data/ + - -t +# - BinaryClassification + - CassClassification + - --log-dir + - /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs + - --tune + - --wandb-sweep + - -tn + - SSI + - --hp-checkpoint + - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/dataset_baseline_cohort_post-operative_static_2025-08-01T11:00:26/SSI/XGBClassifier/2025-08-01T11-08-08.547141/hyperparameter_tuning_logs.db" + - --modalities + - "all" + - --file-names + - '"DYNAMIC":"dyn.parquet","STATIC":"sta.parquet"' + - --explain_features +# - --verbose + - --load_data_vars +method: grid +name: yaib_classification_benchmark +parameters: + data_dir: + values: +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/dataset_baseline_cohort_pre-operative_2025-08-01T10:57:46" + - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/dataset_baseline_cohort_intra-operative_static_2025-08-01T10:59:15" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/dataset_baseline_cohort_post-operative_static_2025-08-01T11:00:26" + model: + values: +# - LogisticRegression +# - TimesNet +# - LGBMClassifier + - XGBClassifier +# - CBClassifier +# - BRFClassifier +# - RUSBClassifier +# - RFClassifier +# - GRU +# - LSTM +# - TCN +# - Transformer + modalities: + values: + - "all" + file_names: + values: + - '"OUTCOME":"outc.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_12.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_18.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_24.parquet","STATIC":"sta.parquet"' + seed: + values: + - 1111 + use_pretrained_imputation: + values: + - None +program: icu-benchmarks \ No newline at end of file diff --git a/experiments/benchmark_cass_explain_modalities.yml b/experiments/benchmark_cass_explain_modalities.yml new file mode 100644 index 00000000..6b4f1586 --- /dev/null +++ b/experiments/benchmark_cass_explain_modalities.yml @@ -0,0 +1,69 @@ +command: + - ${env} + - ${program} + - train + - -d + - ../data/ + - -t +# - BinaryClassification + - CassClassification + - --log-dir + - /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs +# - --tune + - --wandb-sweep + - -tn + - SSI + - --hp-checkpoint + - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-20T18-29-31.571508/hyperparameter_tuning_logs.db" +# - --verbose + - --modalities + - "all" + - --file-names + - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' + - --explain_features +# - --verbose + - --load_data_vars +method: grid +name: yaib_classification_benchmark +parameters: + data_dir: + values: + - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-05T12:24:09" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-01T12:13:31" + model: + values: +# - LogisticRegression +# - TimesNet +# - LGBMClassifier + - XGBClassifier +# - CBClassifier +# - BRFClassifier +# - RUSBClassifier +# - RFClassifier +# - GRU +# - LSTM +# - TCN +# - Transformer + modalities: + values: +# - "all" +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding, static, static_retro_icu, static_retro_pre_intra, hour ] +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, static, static_retro_icu, static_retro_pre_intra, hour ] # missing:wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, static, static_retro_icu, static_retro_pre_intra, hour, wearable_core ] # missing:wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, static, static_retro_icu, static_retro_pre_intra, hour, wearable_ppgfeature ] # missing:wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, static, static_retro_icu, static_retro_pre_intra, hour, wearable_activity ] # missing:wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, static, static_retro_icu, static_retro_pre_intra, hour, wearable_ppgembedding ] # missing:wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding + - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding, static, static_retro_icu, static_retro_pre_intra ] # missing:hour +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding, static, static_retro_icu, static_retro_pre_intra, hour ] # missing:cat_medications, cat_clinical_notes +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding, static, hour ] # missing:static_retro_icu, static_retro_pre_intra + file_names: + values: +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' + - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"' + seed: + values: + - 1111 + use_pretrained_imputation: + values: + - None +program: icu-benchmarks \ No newline at end of file diff --git a/experiments/benchmark_cass_modalities.yml b/experiments/benchmark_cass_modalities.yml new file mode 100644 index 00000000..efc8b893 --- /dev/null +++ b/experiments/benchmark_cass_modalities.yml @@ -0,0 +1,68 @@ +command: + - ${env} + - ${program} + - train + - -d + - ../data/ + - -t +# - BinaryClassification + - CassClassification + - --log-dir + - /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs +# - --tune + - --wandb-sweep + - -tn + - SSI + - --hp-checkpoint + - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-20T18-29-31.571508/hyperparameter_tuning_logs.db" +# - --verbose + - --modalities + - "all" + - --file-names + - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' +# - --explain_features +# - --verbose + - --load_data_vars +method: grid +name: yaib_classification_benchmark +parameters: + data_dir: + values: + - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-05T12:24:09" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-01T12:13:31" + model: + values: +# - LogisticRegression +# - TimesNet +# - LGBMClassifier + - XGBClassifier +# - CBClassifier +# - BRFClassifier +# - RUSBClassifier +# - RFClassifier +# - GRU +# - LSTM +# - TCN +# - Transformer + modalities: + values: +# - "all" + - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding, static, static_retro_icu, static_retro_pre_intra, hour ] + - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, static, static_retro_icu, static_retro_pre_intra, hour ] # missing:wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding +# - [ ishmed_lab_numeric, cat_medications, cat_clinical_notes, wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding, static, static_retro_icu, static_retro_pre_intra, hour ] # missing:copra_observations, copra_scores, copra_fluid_balance +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding, hour ] # missing:static, static_retro_icu, static_retro_pre_intra + - [ copra_observations, copra_scores, copra_fluid_balance, cat_medications, cat_clinical_notes, wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding, static, static_retro_icu, static_retro_pre_intra, hour ] # missing:ishmed_lab_numeric +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding, static, static_retro_icu, static_retro_pre_intra ] # missing:hour +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding, static, static_retro_icu, static_retro_pre_intra, hour ] # missing:cat_medications, cat_clinical_notes +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding, static, hour ] # missing:static_retro_icu, static_retro_pre_intra + file_names: + values: +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' + - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"' + seed: + values: + - 1111 + use_pretrained_imputation: + values: + - None +program: icu-benchmarks \ No newline at end of file diff --git a/experiments/benchmark_cass_model_architecture.yml b/experiments/benchmark_cass_model_architecture.yml new file mode 100644 index 00000000..47db893f --- /dev/null +++ b/experiments/benchmark_cass_model_architecture.yml @@ -0,0 +1,74 @@ +command: + - ${env} + - ${program} + - train + - -d + - ../data/ + - -t +# - BinaryClassification + - CassClassification + - --log-dir + - /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs + - --tune + - --wandb-sweep + - -tn + - SSI + - --hp-checkpoint + - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-21T17:08:54/SSI/GRU/2025-08-28T10-30-02.002119/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-01T12:13:31/SSI/CBClassifier/2025-08-13T11-36-55.269201/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-01T12:13:31/SSI/GRU/2025-08-10T10-57-05.448548/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-01T12:13:31/SSI/Transformer/2025-08-10T10-59-28.332900/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-01T12:13:31/SSI/RFClassifier/2025-08-11T10-10-03.033462/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-01T12:13:31/SSI/CBClassifier/2025-08-10T10-53-33.604092/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-01T12:13:31/SSI/LGBMClassifier/2025-08-11T10-09-52.484304/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-20T18-29-31.571508/hyperparameter_tuning_logs.db" + - --modalities + - "all" + - --file-names + - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' +# - --explain_features +# - --verbose + - --load_data_vars +method: grid +name: yaib_classification_benchmark +parameters: + data_dir: + values: + - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-21T17:08:54" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-01T12:13:31" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-05T12:05:19" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-05T12:24:09" + model: + values: +# - LogisticRegression +## - TimesNet +# - LGBMClassifier +## - XGBClassifier +# - CBClassifier +# - BRFClassifier +# - XGBEnsembleClassifier +# - RUSBClassifier +# - RFClassifier + - GRU +# - LSTM +# - TCN +# - Transformer + modalities: + values: +# - "all" + - [ top_500_features ] + file_names: + values: +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_3.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' + - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_12.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_18.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_24.parquet","STATIC":"sta.parquet"' + seed: + values: + - 1111 + use_pretrained_imputation: + values: + - None +program: icu-benchmarks \ No newline at end of file diff --git a/experiments/benchmark_cass_outcome_time.yml b/experiments/benchmark_cass_outcome_time.yml new file mode 100644 index 00000000..8a5d57a4 --- /dev/null +++ b/experiments/benchmark_cass_outcome_time.yml @@ -0,0 +1,82 @@ +command: + - ${env} + - ${program} + - train + - -d + - ../data/ + - -t +# - BinaryClassification + - CassClassification + - --log-dir + - /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs +# - --tune + - --wandb-sweep + - -tn + - SSI + - --hp-checkpoint + - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-20T18-29-31.571508/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-23T09-42-04.806320/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-20T18-29-31.571508/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-07-16T14:02:27/SSI/XGBClassifier/2025-07-20T18-32-19.553702/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/real_life_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-16T14:58:54/SSI/XGBClassifier/2025-07-16T19-51-45.484723/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-07-16T14:02:27/SSI/XGBClassifier/2025-07-16T19-41-44.887611/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-16T19-39-58.063736/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-14T17-02-09.569252/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-14T13-49-47.629185/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/real_life_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-06-27T14:38:43/SSI/XGBClassifier/2025-06-27T17-13-28.076070/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-06-27T13:36:02/SSI/XGBClassifier/2025-06-27T17-12-03.184584/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-06-27T14:45:50/SSI/XGBClassifier/2025-06-27T17-08-55.389121/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_segment_1.0_horizon_6:00:00_transfer_full_2025-06-19T19:11:23/SSI/XGBClassifier/2025-06-20T12-14-27.191590/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-06-20T16:37:14/SSI/XGBClassifier/2025-06-20T17-36-53.795313/hyperparameter_tuning_logs.db" +# - /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/real_life_normal_ward_SSI-3_segment_1.0_horizon_6:00:00_transfer_full_2025-06-03T16:05:00/SSI/XGBClassifier/2025-06-04T17-09-19.640244/hyperparameter_tuning_logs.db +## - Mortality +# - --verbose + - --modalities + - "all" + - --file-names + - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' +# - --explain_features +# - --verbose + - --load_data_vars +method: grid +name: yaib_classification_benchmark +parameters: + data_dir: + values: + - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-01T12:13:31" + model: + values: +# - LogisticRegression +# - TimesNet +# - LGBMClassifier + - XGBClassifier +# - CBClassifier +# - BRFClassifier +# - RUSBClassifier +# - RFClassifier +# - GRU +# - LSTM +# - TCN +# - Transformer + modalities: + values: + - "all" +# - [copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, static] #, wearable_activity, wearable_core] # without wearable data + + file_names: + values: + - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_1.parquet","STATIC":"sta.parquet"' + - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_2.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_3.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_12.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_18.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_24.parquet","STATIC":"sta.parquet"' + seed: + values: + - 1111 + use_pretrained_imputation: + values: + - None +program: icu-benchmarks \ No newline at end of file diff --git a/experiments/benchmark_cass_segment_length.yml b/experiments/benchmark_cass_segment_length.yml new file mode 100644 index 00000000..753690a1 --- /dev/null +++ b/experiments/benchmark_cass_segment_length.yml @@ -0,0 +1,79 @@ +command: + - ${env} + - ${program} + - train + - -d + - ../data/ + - -t +# - BinaryClassification + - CassClassification + - --log-dir + - /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs +# - --tune + - --wandb-sweep + - -tn + - SSI + - --hp-checkpoint + - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-20T18-29-31.571508/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.25_horizon_6:00:00_transfer_full_2025-08-07T20:09:44/SSI/XGBClassifier/2025-08-08T10-30-09.905452/hyperparameter_tuning_logs.db" +# - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-20T18-29-31.571508/hyperparameter_tuning_logs.db" +## - Mortality +# - --verbose + - --modalities + - "all" + - --file-names + - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' +# - --explain_features +# - --verbose + - --load_data_vars +method: grid +name: cass_segmentation_length +parameters: + data_dir: + values: + - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.125_full_2025-08-22T12:28:04" + - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.25_horizon_6:00:00_transfer_full_2025-08-21T17:01:27" + - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-21T17:08:54" + - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_full_2025-08-21T16:46:25" + - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_3.0_full_2025-08-21T16:44:31" + - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_6.0_full_2025-08-22T11:38:16" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_9.0_full_2025-08-21T16:46:46" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_12.0_full_2025-08-21T16:46:09" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.25_horizon_6:00:00_transfer_full_2025-08-07T20:09:44" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-01T12:13:31" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-08-07T20:36:31" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_2.0_horizon_6:00:00_transfer_full_2025-08-07T20:02:59" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_4.0_horizon_6:00:00_transfer_full_2025-08-07T20:11:16" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_6.0_horizon_6:00:00_transfer_full_2025-08-07T20:03:41" + model: + values: +# - LogisticRegression +# - TimesNet +# - LGBMClassifier + - XGBClassifier +# - CBClassifier +# - BRFClassifier +# - RUSBClassifier +# - RFClassifier +# - GRU +# - LSTM +# - TCN +# - Transformer + modalities: + values: + - "all" + file_names: + values: +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_3.parquet","STATIC":"sta.parquet"' + - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_12.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_18.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_24.parquet","STATIC":"sta.parquet"' + seed: + values: + - 1111 + use_pretrained_imputation: + values: + - None +program: icu-benchmarks \ No newline at end of file diff --git a/experiments/benchmark_cass_wearable_modalities.yml b/experiments/benchmark_cass_wearable_modalities.yml new file mode 100644 index 00000000..87c0ba64 --- /dev/null +++ b/experiments/benchmark_cass_wearable_modalities.yml @@ -0,0 +1,73 @@ +command: + - ${env} + - ${program} + - train + - -d + - ../data/ + - -t +# - BinaryClassification + - CassClassification + - --log-dir + - /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs +# - --tune + - --wandb-sweep + - -tn + - SSI + - --hp-checkpoint + - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-20T18-29-31.571508/hyperparameter_tuning_logs.db" +# - --verbose + - --modalities + - "all" + - --file-names + - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' +# - --explain_features +# - --verbose + - --load_data_vars +method: grid +name: yaib_classification_benchmark +parameters: + data_dir: + values: + - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-05T12:24:09" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-01T12:13:31" + model: + values: +# - LogisticRegression +# - TimesNet +# - LGBMClassifier + - XGBClassifier +# - CBClassifier +# - BRFClassifier +# - RUSBClassifier +# - RFClassifier +# - GRU +# - LSTM +# - TCN +# - Transformer + modalities: + values: +# - "all" + - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding, static, static_retro_icu, static_retro_pre_intra, hour ] + - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, wearable_activity, wearable_core, wearable_ppgfeature, static, static_retro_icu, static_retro_pre_intra, hour ] # missing:wearable_ppgembedding + - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, wearable_activity, wearable_ppgfeature, wearable_ppgembedding, static, static_retro_icu, static_retro_pre_intra, hour ] # missing:wearable_core + - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, wearable_activity, wearable_core, wearable_ppgembedding, static, static_retro_icu, static_retro_pre_intra, hour ] # missing:wearable_ppgfeature + - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, wearable_core, wearable_ppgfeature, wearable_ppgembedding, static, static_retro_icu, static_retro_pre_intra, hour ] # missing:wearable_activity +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, static, static_retro_icu, static_retro_pre_intra, hour ] # missing:wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, static, static_retro_icu, static_retro_pre_intra, hour, wearable_core ] # missing:wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, static, static_retro_icu, static_retro_pre_intra, hour, wearable_ppgfeature ] # missing:wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, static, static_retro_icu, static_retro_pre_intra, hour, wearable_activity ] # missing:wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, static, static_retro_icu, static_retro_pre_intra, hour, wearable_ppgembedding ] # missing:wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding, static, static_retro_icu, static_retro_pre_intra ] # missing:hour +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding, static, static_retro_icu, static_retro_pre_intra, hour ] # missing:cat_medications, cat_clinical_notes +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding, static, hour ] # missing:static_retro_icu, static_retro_pre_intra + file_names: + values: +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' + - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"' + seed: + values: + - 1111 + use_pretrained_imputation: + values: + - None +program: icu-benchmarks \ No newline at end of file diff --git a/experiments/benchmark_organsystems.yml b/experiments/benchmark_organsystems.yml new file mode 100644 index 00000000..4dc17890 --- /dev/null +++ b/experiments/benchmark_organsystems.yml @@ -0,0 +1,62 @@ +command: + - ${env} + - ${program} + - train + - -d + - ../data/ + - -t +# - BinaryClassification + - CassClassification + - --log-dir + - /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs +# - --tune + - --wandb-sweep + - -tn + - SSI + - --hp-checkpoint + - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-20T18-29-31.571508/hyperparameter_tuning_logs.db" +# - --verbose + - --modalities + - "all" + - --file-names + - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' +# - --explain_features +# - --verbose + - --load_data_vars +method: grid +name: organ_system_benchmark +parameters: + data_dir: + values: + - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/organ_cohorts_lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-01T12:13:31/Pancreas" + - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/organ_cohorts_lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-01T12:13:31/Liver" +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/organ_cohorts_lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-01T12:13:31/Upper_GI" + - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/organ_cohorts_lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-01T12:13:31/Colorectal" + model: + values: +# - LogisticRegression +# - TimesNet +# - LGBMClassifier + - XGBClassifier +# - CBClassifier +# - BRFClassifier +# - RUSBClassifier +# - RFClassifier +# - GRU +# - LSTM +# - TCN +# - Transformer + modalities: + values: + - "all" + file_names: + values: +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' + - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"' + seed: + values: + - 1111 + use_pretrained_imputation: + values: + - None +program: icu-benchmarks \ No newline at end of file diff --git a/experiments/benchmark_sex.yml b/experiments/benchmark_sex.yml new file mode 100644 index 00000000..5daf9f92 --- /dev/null +++ b/experiments/benchmark_sex.yml @@ -0,0 +1,60 @@ +command: + - ${env} + - ${program} + - train + - -d + - ../data/ + - -t +# - BinaryClassification + - CassClassification + - --log-dir + - /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs +# - --tune + - --wandb-sweep + - -tn + - SSI + - --hp-checkpoint + - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-20T18-29-31.571508/hyperparameter_tuning_logs.db" +# - --verbose + - --modalities + - "all" + - --file-names + - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' +# - --explain_features +# - --verbose + - --load_data_vars +method: grid +name: organ_system_benchmark +parameters: + data_dir: + values: +# - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/sex_cohorts_lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-01T12:13:31/female" + - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/sex_cohorts_lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-01T12:13:31/male" + model: + values: +# - LogisticRegression +# - TimesNet +# - LGBMClassifier + - XGBClassifier +# - CBClassifier +# - BRFClassifier +# - RUSBClassifier +# - RFClassifier +# - GRU +# - LSTM +# - TCN +# - Transformer + modalities: + values: + - "all" + file_names: + values: +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' + - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"' + seed: + values: + - 1111 + use_pretrained_imputation: + values: + - None +program: icu-benchmarks \ No newline at end of file diff --git a/experiments/charhpc_wandb_sweep_cpu.sh b/experiments/charhpc_wandb_sweep_cpu.sh new file mode 100644 index 00000000..da98bda8 --- /dev/null +++ b/experiments/charhpc_wandb_sweep_cpu.sh @@ -0,0 +1,16 @@ +#!/bin/bash +#SBATCH --job-name=yaib_experiment +#SBATCH --partition=compute # -p +#SBATCH --cpus-per-task=16 # -c +#SBATCH --mem=300gb +#SBATCH --output=logs/classification_%a_%j.log # %j is job id +#SBATCH --time=48:00:00 + + +source /etc/profile.d/conda.sh + +eval "$(conda shell.bash hook)" +conda activate yaib_req_pl +wandb agent --count 1 cassandra_hpi/cassandra/"$1" + +# Debug instance: srun -p gpu --pty -t 5:00:00 --gres=gpu:1 --cpus-per-task=16 --mem=100GB bash \ No newline at end of file diff --git a/experiments/charhpc_wandb_sweep_gpu.sh b/experiments/charhpc_wandb_sweep_gpu.sh new file mode 100644 index 00000000..1181ee2e --- /dev/null +++ b/experiments/charhpc_wandb_sweep_gpu.sh @@ -0,0 +1,14 @@ +#!/bin/bash +#SBATCH --job-name=yaib_experiment +#SBATCH --partition=gpu # -p +#SBATCH --cpus-per-task=16 # -c +#SBATCH --mem=200gb +#SBATCH --output=logs/classification_%a_%j.log # %j is job id +#SBATCH --gpus=1 +#SBATCH --time=48:00:00 + +source /etc/profile.d/conda.sh + +eval "$(conda shell.bash hook)" +conda activate yaib_req_pl +wandb agent --count 1 cassandra_hpi/cassandra/"$1" \ No newline at end of file diff --git a/experiments/slurm_base_char_sc.sh b/experiments/slurm_base_char_sc.sh new file mode 100644 index 00000000..6693b339 --- /dev/null +++ b/experiments/slurm_base_char_sc.sh @@ -0,0 +1,44 @@ +#!/bin/bash +#SBATCH --job-name=default +#SBATCH --mail-type=ALL +#SBATCH --mail-user=[INSERT:EMAIL] +#SBATCH --partition=gpu # -p +#SBATCH --cpus-per-task=4 # -c +#SBATCH --mem=48gb +#SBATCH --gpus=1 +#SBATCH --output=%x_%a_%j.log # %x is job-name, %j is job id, %a is array id +#SBATCH --array=0-3 + +# Submit with e.g. --export=TASK_NAME=mortality24,MODEL_NAME=LGBMClassifier +# Basic experiment variables, please exchange [INSERT] for your experiment parameters + +TASK=BinaryClassification # BinaryClassification +YAIB_PATH=/home/vandewrp/projects/YAIB #/dhc/home/robin.vandewater/projects/yaib +EXPERIMENT_PATH=../yaib_logs/${TASK_NAME}_experiment +DATASET_ROOT_PATH= /sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format #data/YAIB_Datasets/data + +echo "This is a SLURM job named" $SLURM_JOB_NAME "with array id" $SLURM_ARRAY_TASK_ID "and job id" $SLURM_JOB_ID +echo "Resources allocated: " $SLURM_CPUS_PER_TASK "CPUs, " $SLURM_MEM_PER_NODE "GB RAM, " $SLURM_GPUS_PER_NODE "GPUs" +echi "Task type:" ${TASK} +echo "Task: " ${TASK_NAME} +echo "Model: "${MODEL_NAME} +echo "Dataset: "${DATASETS[$SLURM_ARRAY_TASK_ID]} +echo "Experiment path: "${EXPERIMENT_PATH} + +cd ${YAIB_PATH} + +eval "$(conda shell.bash hook)" +conda activate yaib + + + +icu-benchmarks train \ + -d ${DATASET_ROOT_PATH} \ + -n ${DATASETS} \ + -t ${TASK} \ + -tn ${TASK_NAME} \ + -m ${MODEL_NAME} \ + -c \ + -s 1111 \ + -l ${EXPERIMENT_PATH} \ + --tune \ No newline at end of file diff --git a/experiments/top_features_benchmark_cass.yml b/experiments/top_features_benchmark_cass.yml new file mode 100644 index 00000000..aa03153e --- /dev/null +++ b/experiments/top_features_benchmark_cass.yml @@ -0,0 +1,88 @@ +command: + - ${env} + - ${program} + - train + - -d + - ../data/ + - -t +# - BinaryClassification + - CassClassification + - --log-dir + - /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs + - --tune + - --wandb-sweep + - -tn + - SSI + - --hp-checkpoint + - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-20T18-29-31.571508/hyperparameter_tuning_logs.db" + - --modalities + - "all" + - --file-names + - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' +# - --explain_features +# - --verbose + - --load_data_vars +method: grid +name: yaib_classification_benchmark +parameters: + data_dir: + values: + - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/lab_set_ICU_and_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_0.5_horizon_6:00:00_transfer_full_2025-08-21T17:08:54" + model: + values: +# - LogisticRegression +# - TimesNet +# - LGBMClassifier + - XGBClassifier +# - RUSBClassifier +# - CBClassifier +# - BRFClassifier +# - RFClassifier +# - GRU +# - LSTM +# - TCN +# - Transformer + modalities: + values: +# - "all" + - [ top_50_features ] +# - [ top_100_features ] + - [ top_250_features ] + - [ top_500_features ] + - [ top_1000_features ] + - [ top_2000_features ] +# - [copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding, static, static_retro_icu, static_retro_pre_intra ] # missing:hour + # - [copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_clinical_notes, cat_clinical_notes, static] +# - [copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_clinical_notes, wearable_activity, wearable_core, static] +# - [copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, static] #, wearable_activity, wearable_core] # without wearable data +# - [copra_observations, copra_scores, copra_fluid_balance, cat_clinical_notes, wearable_activity, wearable_core, static] # without ishmed lab data and clinical notes +# - [copra_observations, copra_scores, copra_fluid_balance, cat_clinical_notes, static ] # without wearable data and ishmed lab data, clinical notes +# - [copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_clinical_notes, wearable_activity, wearable_core, static] # without clinical notes +# - [copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_clinical_notes, wearable_activity, wearable_core, static] # without med embeddings +# - [wearable_activity, wearable_core, static] +# - [ishmed_lab_numeric, static] +# - [static] + # 1 missing modality +# - [ wearable_ppgembedding, core, wearable_ppgfeature, static ] +# - [ wearable_activity, static ] +# - [ wearable_ppgfeature, static ] +# - [ wearable_core, static] +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, static, static_retro_icu, static_retro_pre_intra, wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding, cat_medications, cat_clinical_notes] #hour ] # missing:notes, medications +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, static, static_retro_icu, static_retro_pre_intra, wearable_activity, wearable_core, wearable_ppgfeature, wearable_ppgembedding, hour ] # missing:notes, medications +# - [ copra_observations, copra_scores, copra_fluid_balance, ishmed_lab_numeric, cat_medications, cat_clinical_notes, static, static_retro_icu, static_retro_pre_intra, hour ] # missing: wearable (activity, ppgfeature, ppgembedding) +# - [static, cat_clinical_notes] + file_names: + values: +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_3.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' + - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_12.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_18.parquet","STATIC":"sta.parquet"' +# - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_24.parquet","STATIC":"sta.parquet"' + seed: + values: + - 1111 + use_pretrained_imputation: + values: + - None +program: icu-benchmarks diff --git a/icu_benchmarks/cross_validation.py b/icu_benchmarks/cross_validation.py index c3951d19..f62d7f0e 100644 --- a/icu_benchmarks/cross_validation.py +++ b/icu_benchmarks/cross_validation.py @@ -39,6 +39,7 @@ def execute_repeated_cv( verbose: bool = False, wandb: bool = False, complete_train: bool = False, + explain_features: bool = False, ) -> float: """Preprocesses data and trains a model for each fold. @@ -73,6 +74,11 @@ def execute_repeated_cv( cv_repetitions_to_train = cv_repetitions if not cv_folds_to_train: cv_folds_to_train = cv_folds + elif cv_folds_to_train > cv_folds: + raise ValueError( + f"cv_folds_to_train is {cv_folds_to_train}, cv_folds is {cv_folds}. " + f" This is likely due to a hyperparameter tuning settings mismatch." + ) agg_loss = 0 seed_everything(seed, reproducible) if complete_train: @@ -120,6 +126,7 @@ def execute_repeated_cv( verbose=verbose, use_wandb=wandb, train_only=complete_train, + explain_features=explain_features, ) train_time = datetime.now() - start_time @@ -135,7 +142,7 @@ def execute_repeated_cv( wandb_log({"Iteration": repetition * cv_folds_to_train + fold_index}) if repetition * cv_folds_to_train + fold_index > 1: try: - aggregate_results(log_dir) + aggregate_results(log_dir, explain_features=explain_features) except Exception as e: logging.error(f"Failed to aggregate results: {e}") log_full_line(f"FINISHED CV REPETITION {repetition}", level=logging.INFO, char="=", num_newlines=3) diff --git a/icu_benchmarks/data/loader.py b/icu_benchmarks/data/loader.py index bfa944ad..b3808d7e 100644 --- a/icu_benchmarks/data/loader.py +++ b/icu_benchmarks/data/loader.py @@ -32,9 +32,10 @@ def __init__( self.vars = vars self.grouping_df = data[split][grouping_segment] # Get the row indicators for the data to be able to match predicted labels - if not isinstance(vars["SEQUENCE"], str): - raise ValueError(f'Expected key "SEQUENCE" to be of type str, got {type(vars["SEQUENCE"])} instead') + if "SEQUENCE" in self.vars and self.vars["SEQUENCE"] in data[split][DataSegment.features].columns: + if not isinstance(vars["SEQUENCE"], str): + raise ValueError(f'Expected key "SEQUENCE" to be of type str, got {type(vars["SEQUENCE"])} instead') # We have a time series dataset self.row_indicators = data[split][DataSegment.features][self.vars["GROUP"], self.vars["SEQUENCE"]] @@ -42,11 +43,15 @@ def __init__( self.features_df = data[split][DataSegment.features] self.features_df = self.features_df.sort([self.vars["GROUP"], self.vars["SEQUENCE"]]) self.features_df = self.features_df.drop(self.vars["SEQUENCE"]) + self.row_indicators = self.row_indicators.sort([self.vars["GROUP"], self.vars["SEQUENCE"]]) + else: # We have a static dataset logging.info("Using static dataset") self.row_indicators = data[split][DataSegment.features][self.vars["GROUP"]] self.features_df = data[split][DataSegment.features] + # Series with unique values + self.row_indicators = self.row_indicators.sort() # calculate basic info for the data self.num_stays = self.grouping_df[self.vars["GROUP"]].unique().shape[0] self.maxlen = self.features_df.group_by([self.vars["GROUP"]]).len().max().item(0, 1) @@ -67,7 +72,7 @@ def __len__(self) -> int: return self.num_stays def get_feature_names(self) -> List[str]: - return self.features_df.columns + return [col for col in self.features_df.columns] # if col != self.vars["GROUP"] and col != self.vars["SEQUENCE"]] def to_tensor(self) -> tuple[Union[Tensor, np.ndarray], ...]: values: list[list] = [] @@ -89,6 +94,14 @@ class PredictionPolarsDataset(CommonPolarsDataset): def __init__(self, *args, ram_cache: bool = True, **kwargs): super().__init__(*args, **kwargs) + if "SEQUENCE" in self.vars: + if self.row_indicators[self.vars["SEQUENCE"]].dtype.is_temporal(): + self.row_indicators = self.row_indicators.with_columns(pl.col(self.vars["SEQUENCE"]).dt.total_hours()) + else: + self.row_indicators = self.row_indicators.with_columns(pl.col(self.vars["SEQUENCE"])) + else: + logging.info("Using static dataset") + self.row_indicators = self.grouping_df[self.vars["GROUP"]].to_frame(self.vars["GROUP"]) self.outcome_df = self.grouping_df self.ram_cache(ram_cache) @@ -116,6 +129,7 @@ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]: if len(labels) == 1: # only one label per stay, align with window labels = np.concatenate([np.empty(window.shape[0] - 1) * np.nan, labels], axis=0) + row_inds = self.row_indicators.filter(pl.col(self.vars["GROUP"]) == stay_id).to_numpy() length_diff = self.maxlen - window.shape[0] pad_mask = np.ones(window.shape[0]) @@ -126,7 +140,9 @@ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]: window = np.concatenate([window, np.ones((length_diff, window.shape[1])) * pad_value], axis=0) labels = np.concatenate([labels, np.ones(length_diff) * pad_value], axis=0) pad_mask = np.concatenate([pad_mask, np.zeros(length_diff)], axis=0) - + # row_inds = np.concatenate([row_inds, np.ones(length_diff, row_inds.shape[1]) * pad_value], axis=0) + row_inds = np.concatenate([row_inds, np.ones((length_diff, row_inds.shape[1])) * pad_value], axis=0) + row_inds = row_inds.astype(np.float32) not_labeled = np.argwhere(np.isnan(labels)) if len(not_labeled) > 0: labels[not_labeled] = -1 @@ -135,8 +151,8 @@ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor, Tensor]: pad_mask = pad_mask.astype(bool) labels = labels.astype(np.float32) data = window.astype(np.float32) - - return from_numpy(data), from_numpy(labels), from_numpy(pad_mask) + # if self.vars"SEQUENCE" in self.vars: + return from_numpy(data), from_numpy(labels), from_numpy(pad_mask), from_numpy(row_inds) def get_balance(self) -> list: """Return the weight balance for the split of interest. @@ -163,12 +179,24 @@ def get_data_and_labels(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: if len(labels) == self.num_stays: # order of groups could be random, we make sure not to change it rep = rep.group_by(self.vars["GROUP"]).last() - else: - # Adding segment count for each stay id and timestep. - rep = rep.with_columns(pl.col(self.vars["GROUP"]).cum_count().over(self.vars["GROUP"]).alias("counter")) - rep = rep.to_numpy().astype(float) + # else: + # # Adding segment count for each stay id and timestep. + # rep = rep.with_columns(pl.col(self.vars["GROUP"]).cum_count().over(self.vars["GROUP"]).alias("counter")) + # rep = rep.sort([self.vars["GROUP"], "counter"]) + # rep = rep.to_numpy().astype(float) + # Remove the first column from the rep array (group column) + # needs to still be in there? logging.debug(f"rep shape: {rep.shape}") logging.debug(f"labels shape: {labels.shape}") + # if self.vars["SEQUENCE"] in rep.columns: + # rep = rep.select(pl.exclude(self.vars["SEQUENCE"])) + # logging.info(f"Removed sequence column {self.vars['SEQUENCE']} from features_df.") + rep = rep.to_numpy().astype(float) + rep = rep[:, 1:] + if ("SEQUENCE" in self.vars and self.vars["SEQUENCE"] in self.row_indicators + and self.row_indicators[self.vars["SEQUENCE"]].dtype == pl.Duration): + self.row_indicators = self.row_indicators.with_columns(pl.col(self.vars["SEQUENCE"]).dt.total_hours()) + # Todo: check if row indicators introduce information loss return rep, labels, self.row_indicators.to_numpy() def to_tensor(self) -> tuple[Union[Tensor, np.ndarray], Union[Tensor, np.ndarray], Union[Tensor, np.ndarray]]: diff --git a/icu_benchmarks/data/preprocessor.py b/icu_benchmarks/data/preprocessor.py index c4b0e6d4..ef126839 100644 --- a/icu_benchmarks/data/preprocessor.py +++ b/icu_benchmarks/data/preprocessor.py @@ -1,13 +1,14 @@ import copy -import logging import pickle from abc import ABC, abstractmethod from pathlib import Path from typing import Optional, Union +import logging import gin import pandas as pd import polars as pl +import recipys import torch from numpy import nan as np_nan from recipys.recipe import Recipe @@ -66,6 +67,24 @@ def set_imputation_model(self, imputation_model): update_wandb_config({"imputation_model": self.imputation_model.__class__.__name__}) + def vars_selection(self, input_variables, segment: str) -> Union[str, list[str]]: + if input_variables.get(segment, None) is None or len(input_variables[segment]) == 0: + logging.warning("No dynamic variables provided. Skipping dynamic preprocessing.") + return [] # Return empty list if no variables are provided + vars_to_apply: Union[str, list[str]] + if self.vars_to_exclude is not None: + # Exclude vars_to_exclude from missing indicator/ feature generation + vars_to_apply = list(set(input_variables[segment]) - set(self.vars_to_exclude)) + logging.info(f"Excluding features: {len(self.vars_to_exclude)} : " + f"{self.vars_to_exclude if len(self.vars_to_exclude) < 10 else f'{self.vars_to_exclude[:10]}...'}...") + if len(vars_to_apply) == 0: + logging.warning( + f"No variables left after excluding vars_to_exclude: {self.vars_to_exclude} for segment {segment}. " + "Skipping preprocessing for this segment." + ) + else: + vars_to_apply = input_variables[segment] + return vars_to_apply @gin.configurable("base_classification_preprocessor") class PolarsClassificationPreprocessor(Preprocessor): @@ -157,25 +176,29 @@ def apply( logging.debug(data[DataSplit.train][DataSegment.features].head()) logging.debug(data[DataSplit.train][DataSegment.outcome]) - if not isinstance(vars["SEQUENCE"], str): - raise TypeError(f'Expected key "SEQUENCE" to be of type str, got {type(vars["SEQUENCE"])} instead') - - for split in [DataSplit.train, DataSplit.val, DataSplit.test]: - if vars["SEQUENCE"] in data[split][DataSegment.outcome] and len(data[split][DataSegment.features]) != len( - data[split][DataSegment.outcome] - ): - raise Exception( - f"Data and outcome length mismatch in {split} split: " - f"features: {len(data[split][DataSegment.features])}, outcome: {len(data[split][DataSegment.outcome])}" - ) + # if not isinstance(vars["SEQUENCE"], str): + # raise TypeError(f'Expected key "SEQUENCE" to be of type str, got {type(vars["SEQUENCE"])} instead') + if "SEQUENCE" in vars: + for split in [DataSplit.train, DataSplit.val, DataSplit.test]: + # Check if we have sequence in the outcome and the data and outcome length match + if vars["SEQUENCE"] in data[split][DataSegment.outcome] and len(data[split][DataSegment.features]) != len( + data[split][DataSegment.outcome] + ): + raise Exception( + f"Data and outcome length mismatch in {split} split: " + f"features: {len(data[split][DataSegment.features])}, outcome: {len(data[split][DataSegment.outcome])}" + ) data[DataSplit.train][DataSegment.features] = data[DataSplit.train][DataSegment.features].unique() data[DataSplit.val][DataSegment.features] = data[DataSplit.val][DataSegment.features].unique() data[DataSplit.test][DataSegment.features] = data[DataSplit.test][DataSegment.features].unique() - logging.info(f"Generate features: {self.generate_features}") + logging.debug(f"Generate features in preprocessing: {self.generate_features}") return data def _process_static(self, data: dict[str, dict[str, pl.DataFrame]], vars: dict[str, Union[str, list[str]]]): + vars_to_apply = self.vars_selection(input_variables=vars, segment=DataSegment.dynamic) + if len(vars_to_apply) == 0: + return data sta_rec = Recipe(data[DataSplit.train][DataSegment.static], [], vars[DataSegment.static]) sta_rec.add_step(StepSklearn(MissingIndicator(features="all"), sel=all_of(vars[DataSegment.static]), in_place=False)) if self.scaling: @@ -209,28 +232,27 @@ def _model_impute(self, data: pd.DataFrame, group: Optional[str] = None): return data def _process_dynamic(self, data: dict[str, dict[str, pl.DataFrame]], vars: dict[str, Union[str, list[str]]]): + vars_to_apply = self.vars_selection(input_variables=vars, segment=DataSegment.dynamic) + if len(vars_to_apply) == 0: + return data dyn_rec = Recipe( data[DataSplit.train][DataSegment.dynamic], [], vars[DataSegment.dynamic], vars["GROUP"], vars["SEQUENCE"] ) if self.scaling: - dyn_rec.add_step(StepScale()) + dyn_rec.add_step(StepScale(sel=all_numeric_predictors(backend=recipys.constants.Backend.POLARS))) if self.imputation_model is not None: dyn_rec.add_step(StepImputeModel(model=self.model_impute, sel=all_of(vars[DataSegment.dynamic]))) - - vars_to_apply: Union[str, list[str]] - if self.vars_to_exclude is not None: - # Exclude vars_to_exclude from missing indicator/ feature generation - vars_to_apply = list(set(vars[DataSegment.dynamic]) - set(self.vars_to_exclude)) - else: - vars_to_apply = vars[DataSegment.dynamic] dyn_rec.add_step(StepSklearn(MissingIndicator(features="all"), sel=all_of(vars_to_apply), in_place=False)) dyn_rec.add_step(StepImputeFill(strategy="forward")) dyn_rec.add_step(StepImputeFill(strategy="zero")) if self.generate_features: dyn_rec = self._dynamic_feature_generation(dyn_rec, all_of(vars_to_apply)) data = apply_recipe_to_splits(dyn_rec, data, DataSegment.dynamic, self.save_cache, self.load_cache) + # logging.info(f"Data columns: {len(data[Split.train][Segment.dynamic].columns)} -> old columns: {len(old_columns)}, added columns: {set(data[Split.train][Segment.dynamic].columns) - set(old_columns)}") return data + + def _dynamic_feature_generation(self, data, dynamic_vars): logging.debug("Adding dynamic feature generation.") data.add_step(StepHistorical(sel=dynamic_vars, fun=Accumulator.MIN, suffix="min_hist")) @@ -398,12 +420,15 @@ def apply( logging.debug("Data head") logging.debug(data[DataSplit.train][DataSegment.features].head()) logging.debug(data[DataSplit.train][DataSegment.outcome].head()) - logging.info(f"Generate features: {self.generate_features}") + logging.debug(f"Generate features: {self.generate_features}") return data def _process_static( self, data: dict[str, dict[str, pd.DataFrame]], vars: dict[str, Union[str, list[str]]] ) -> dict[str, dict[str, pd.DataFrame]]: + vars_to_apply = self.vars_selection(input_variables=vars, segment=DataSegment.static) + if len(vars_to_apply) == 0: + return data sta_rec = Recipe(data[DataSplit.train][DataSegment.static], [], vars[DataSegment.static]) if self.scaling: sta_rec.add_step(StepScale()) @@ -439,9 +464,15 @@ def _model_impute(self, data: pd.DataFrame, group: Optional[str] = None) -> pd.D def _process_dynamic( self, data: dict[str, dict[str, pd.DataFrame]], vars: dict[str, Union[str, list[str]]] ) -> dict[str, dict[str, pd.DataFrame]]: + vars_to_apply = self.vars_selection(input_variables=vars, segment=DataSegment.dynamic) + if len(vars_to_apply) == 0: + return data dyn_rec = Recipe( data[DataSplit.train][DataSegment.dynamic], [], vars[DataSegment.dynamic], vars["GROUP"], vars["SEQUENCE"] ) + vars_to_apply = self.vars_selection(input_variables=vars, segment=DataSegment.dynamic) + if len(vars_to_apply) == 0: + return data if self.scaling: dyn_rec.add_step(StepScale()) if self.imputation_model is not None: diff --git a/icu_benchmarks/data/split_process_data.py b/icu_benchmarks/data/split_process_data.py index 9e4df484..39a1a9f2 100644 --- a/icu_benchmarks/data/split_process_data.py +++ b/icu_benchmarks/data/split_process_data.py @@ -3,36 +3,36 @@ import json import logging import os -import pickle -from pathlib import Path from timeit import default_timer as timer from typing import Any, Iterable, Optional, Union - import gin import pandas as pd import polars as pl from sklearn.model_selection import KFold, ShuffleSplit, StratifiedKFold, StratifiedShuffleSplit - +import polars.selectors as cs +from pathlib import Path +import pickle from icu_benchmarks.constants import RunMode from icu_benchmarks.data.preprocessor import ( PandasClassificationPreprocessor, PolarsClassificationPreprocessor, - PolarsRegressionPreprocessor, Preprocessor, ) from .constants import DataSegment, DataSplit, VarType +from .utils import check_sanitize_data, modality_selection @gin.configurable("preprocess") def preprocess_data( data_dir: Path, file_names: dict[str, str] | Any = gin.REQUIRED, - preprocessor: type[PolarsClassificationPreprocessor | PolarsRegressionPreprocessor] = PolarsClassificationPreprocessor, + preprocessor: type[Preprocessor] = PolarsClassificationPreprocessor, use_static: bool = True, vars: dict[str, str | list[str]] | Any = gin.REQUIRED, modality_mapping: Optional[dict[str, list[str]]] = None, selected_modalities: Optional[list[str]] = None, + exclude_preproc: list[str] = None, seed: int = 42, debug: bool = False, cv_repetitions: int = 5, @@ -48,6 +48,9 @@ def preprocess_data( label: Optional[str] = None, required_var_types: Optional[list[str]] = None, required_segments: Optional[list[str]] = None, + reduce_sequence_steps: int = 0, + remove_short_stays: bool = True, + min_remaining_steps: int = 1, ) -> dict[str, dict[str, pl.DataFrame]]: """ Perform loading, splitting, imputing and normalising of task data. @@ -60,6 +63,9 @@ def preprocess_data( data_dir: Path to the directory holding the data. file_names: Contains the parquet file names in data_dir. vars: Contains the names of columns in the data. + modality_mapping: Mapping of modalities to column names. + selected_modalities: List of selected modalities to use. + exclude_preproc: List of modalities to exclude from preprocessing. seed: Random seed. debug: Load less data if true. cv_repetitions: Number of times to repeat cross validation. @@ -70,7 +76,9 @@ def preprocess_data( generate_cache: Generate cached preprocessed data if true. fold_index: Index of the fold to return. pretrained_imputation_model: pretrained imputation model to use. if None, standard imputation is used. - + reduce_sequence_steps: Number of steps to reduce sequence length. + remove_short_stays: Whether to remove stays that are shorter than reduce_sequence_steps. + min_remaining_steps: Minimum number of remaining steps after reduction to keep a stay. Returns: Preprocessed data as DataFrame in a hierarchical dict with features type (STATIC) / DYNAMIC/ OUTCOME nested within split (train/val/test). @@ -80,7 +88,7 @@ def preprocess_data( if selected_modalities is None: selected_modalities = ["all"] if required_var_types is None: - required_var_types = ["GROUP", "SEQUENCE", "LABEL"] + required_var_types = ["GROUP", "LABEL"] if required_segments is None: required_segments = [DataSegment.static, DataSegment.dynamic, DataSegment.outcome] @@ -106,16 +114,24 @@ def preprocess_data( dumped_vars = json.dumps(vars, sort_keys=True) logging.info(f"Using preprocessor: {preprocessor.__name__}") - - cat_clinical_notes = modality_mapping.get("cat_clinical_notes") - cat_med_embeddings_map = modality_mapping.get("cat_med_embeddings_map") - if cat_clinical_notes is not None and cat_med_embeddings_map is not None: - vars_to_exclude = cat_clinical_notes + cat_med_embeddings_map - else: - vars_to_exclude = None - cache_dir = data_dir / "cache" cache_filename = f"s_{seed}_r_{repetition_index}_f_{fold_index}_t_{train_size}_d_{debug}" + + vars_to_exclude = [] + if exclude_preproc is not None: + # Exclude variables from preprocessing based on modality: + # useful if modality has already undergone extensive preprocessing. + if modality_mapping is not None and len(modality_mapping) > 0: + for modality in exclude_preproc: + if modality in modality_mapping: + vars_to_exclude.extend(modality_mapping.get(modality)) + else: + logging.warning(f"Modality '{modality}' not found in modality mapping.") + logging.info( + f"Excluding modalities in {exclude_preproc}. Total vars excluded from preprocessing: {len(vars_to_exclude)}" + ) + else: + logging.warning("No modality mapping provided. Excluding variables from preprocessing will have no effect.") preprocessor_instance: Preprocessor = preprocessor( use_static_features=use_static, save_cache=data_dir / "preproc" / (cache_filename + "_recipe") if generate_cache else None, @@ -138,12 +154,14 @@ def preprocess_data( # Read parquet files into dataframes and remove the parquet file from memory logging.info(f"Loading data from directory {data_dir.absolute()}") + if not data_dir.exists(): + raise FileNotFoundError(f"Data directory {data_dir} does not exist. Please check the path.") data: dict[str, pl.DataFrame] = { f: pl.read_parquet(data_dir / file_names[f]) for f in file_names.keys() if os.path.exists(data_dir / file_names[f]) } - logging.info(f"Loaded data: {list(data.keys())}") - sanitized_data = check_sanitize_data(data, vars) + logging.info(f"Loaded datasets: {list(data.keys())}") + sanitized_data, vars = check_sanitize_data(data, vars) if DataSegment.dynamic not in sanitized_data.keys(): logging.warning("No dynamic data found, using only static data.") @@ -156,6 +174,18 @@ def preprocess_data( else: logging.info("Selecting all modalities.") + # Reduce stays by sequence steps if requested + if reduce_sequence_steps > 0: + logging.info(f"Reducing stays by {reduce_sequence_steps} sequence steps") + if remove_short_stays: + logging.info(f"Removing stays with less than {min_remaining_steps} remaining steps after reduction.") + sanitized_data = reduce_stays_by_steps( + sanitized_data, + vars, + reduce_sequence_steps, + remove_short_stays, + min_remaining_steps + ) # Generate the splits logging.info("Generating splits.") if not complete_train: @@ -180,9 +210,9 @@ def preprocess_data( sanitized_data = preprocessor_instance.apply(sanitized_data, vars) end = timer() logging.info(f"Preprocessing took {end - start:.2f} seconds.") - logging.info(f"Checking for NaNs and nulls in {data.keys()}.") + logging.debug(f"Checking for NaNs and nulls in {data.keys()}.") for _dict in sanitized_data.values(): - for key, val in _dict.items(): + for key, dataframe in _dict.items(): logging.debug(f"Data type: {key}") logging.debug("Is NaN:") sel = _dict[key].select(pl.selectors.numeric().is_nan().max()) @@ -190,12 +220,49 @@ def preprocess_data( logging.debug("Has nulls:") sel = _dict[key].select(pl.all().has_nulls()) logging.debug(sel.select(col.name for col in sel if col.item(0))) - _dict[key] = val.fill_null(strategy="zero") - _dict[key] = val.fill_nan(0) + _dict[key] = dataframe.fill_null(strategy="zero") + _dict[key] = dataframe.fill_nan(0) logging.debug("Dropping columns with nulls") sel = _dict[key].select(pl.all().has_nulls()) logging.debug(sel.select(col.name for col in sel if col.item(0))) + logging.debug("Checking for infinite values.") + for col in dataframe.select(cs.numeric()).columns: + if dataframe[col].is_infinite().any(): + logging.warning(f"Column '{col}' contains infinite values. Datatype: {dataframe[col].dtype}") + + max_float64 = 0 + # Replace infinite values with the maximum value for float64 + dataframe = dataframe.with_columns( + [ + pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col) + for col in dataframe.columns + if dataframe[col].dtype == pl.Float64 + ] + ) + _dict[key] = dataframe + logging.debug(f"Amount of columns: {len(dataframe.columns)}") + logging.info(f"{len(sanitized_data[DataSplit.train][DataSegment.features].columns)} columns in dynamic data.") + train_samples = len(sanitized_data[DataSplit.train][DataSegment.outcome]) + val_samples = len(sanitized_data[DataSplit.val][DataSegment.outcome]) + test_samples = len(sanitized_data[DataSplit.test][DataSegment.outcome]) + train_incidence = sanitized_data[DataSplit.test][DataSegment.outcome][vars[VarType.label]] + val_incidence = sanitized_data[DataSplit.val][DataSegment.outcome][vars[VarType.label]] + test_incidence = sanitized_data[DataSplit.test][DataSegment.outcome][vars[VarType.label]] + total_samples = train_samples + val_samples + test_samples + logging.info( + f"Train segments: {train_samples} ({train_samples / total_samples:.1%}), " + f"Val segments: {val_samples} ({val_samples / total_samples:.1%}), " + f"Test segments: {test_samples} ({test_samples / total_samples:.1%})" + ) + # Define the number of decimal places for rounding + decimal_places = 4 # + logging.info( + f"Train incidence: {train_incidence.mean():.{decimal_places}f}, STD {train_incidence.std():.{decimal_places}f}| " + f"Val incidence: {val_incidence.mean():.{decimal_places}f}, STD {val_incidence.std():.{decimal_places}f}| " + f"Test incidence: {test_incidence.mean():.{decimal_places}f}, STD {test_incidence.std():.{decimal_places}f}" + ) + # logging.info(f"{len(ou)}") # Generate cache if generate_cache: caching(cache_dir, cache_file, data, load_cache) @@ -217,7 +284,7 @@ def flatten_column_names(*args: object) -> list[str]: return result -def check_sanitize_data(data: dict[str, pl.DataFrame], vars: dict[str, str | list[str]]) -> dict[str, pl.DataFrame]: +def check_sanitize_data(data: dict[str, pl.DataFrame], vars: dict[str, str | list[str]]) -> dict[str, pl.DataFrame]: # noqa: F811 """Check for duplicates in the loaded data and remove them.""" group: Optional[Union[str, list[str]]] = vars.get(VarType.group) sequence: Optional[Union[str, list[str]]] = vars.get(VarType.sequence) @@ -243,13 +310,13 @@ def check_sanitize_data(data: dict[str, pl.DataFrame], vars: dict[str, str | lis subset=flatten_column_names(group, sequence), keep=keep, maintain_order=True ) else: - data[DataSegment.outcome] = data[DataSegment.outcome].unique(subset=group, keep=keep, maintain_order=True) + data[DataSegment.outcome] = data[DataSegment.outcome].unique(subset=[group], keep=keep, maintain_order=True) if old_len != len(data[DataSegment.outcome]): logging.warning(f"Removed {old_len - len(data[DataSegment.outcome])} duplicates from outcome data.") - return data + return data, vars -def modality_selection( +def modality_selection( # noqa: F811 data: dict[str, pl.DataFrame], modality_mapping: dict[str, list[str]], selected_modalities: list[str], @@ -444,7 +511,7 @@ def make_single_split_pandas( repetition_index: int, cv_folds: int, fold_index: int, - train_size: Optional[float] = None, + train_size: Optional[float] = 0.80, seed: int = 42, debug: bool = False, runmode: RunMode = RunMode.classification, @@ -557,8 +624,10 @@ def make_single_split_polars( outer_cv = StratifiedShuffleSplit(cv_repetitions, train_size=train_size) else: outer_cv = StratifiedKFold(cv_repetitions, shuffle=True, random_state=seed) - - inner_cv = StratifiedKFold(cv_folds, shuffle=True, random_state=seed) + if cv_folds > 2: + inner_cv = StratifiedKFold(cv_folds, shuffle=True, random_state=seed) + else: + inner_cv = StratifiedShuffleSplit(cv_folds, train_size=0.75, random_state=seed) dev, test = list(outer_cv.split(stays, labels))[repetition_index] dev_stays = stays[dev] train, val = list(inner_cv.split(dev_stays, labels[dev]))[fold_index] @@ -569,6 +638,7 @@ def make_single_split_polars( outer_cv = ShuffleSplit(cv_repetitions, train_size=train_size) else: outer_cv = KFold(cv_repetitions, shuffle=True, random_state=seed) + inner_cv = KFold(cv_folds, shuffle=True, random_state=seed) dev, test = list(outer_cv.split(stays))[repetition_index] dev_stays = stays[dev] @@ -666,3 +736,137 @@ def check_required_keys(vars, required_keys): missing_keys = [key for key in required_keys if key not in vars] if missing_keys: raise KeyError(f"Missing required keys in vars: {', '.join(missing_keys)}") + + +def reduce_stays_by_steps( + data: dict[str, pl.DataFrame], + vars: dict[str, Union[str, list[str]]], + steps_to_remove: int, + remove_short_stays: bool = True, + min_remaining_steps: int = 1 +) -> dict[str, pl.DataFrame]: + """ + Reduce stays by removing the last x sequence steps from each stay. + + Args: + data: Dictionary containing DataFrames for different segments (DYNAMIC, STATIC, OUTCOME) + vars: Dictionary containing variable names including GROUP and SEQUENCE + steps_to_remove: Number of sequence steps to remove from the end of each stay + remove_short_stays: Whether to remove stays that have fewer steps than steps_to_remove + min_remaining_steps: Minimum number of steps that must remain after reduction + + Returns: + Modified data dictionary with reduced stays + """ + if steps_to_remove <= 0: + logging.warning("steps_to_remove must be positive. No changes made.") + return data + + group_var = vars[VarType.group] + sequence_var = vars[VarType.sequence] + + if not isinstance(group_var, str) or not isinstance(sequence_var, str): + raise TypeError(f"GROUP and SEQUENCE variables must be strings, got {type(group_var)} and {type(sequence_var)}") + + # Check if we have dynamic data with sequences + if DataSegment.dynamic not in data: + logging.warning("No dynamic data found. Cannot reduce stays by sequence steps.") + return data + + dynamic_data = data[DataSegment.dynamic] + + # Get stay lengths from dynamic data + stay_lengths = dynamic_data.group_by(group_var).len().sort(group_var) + total_stays = stay_lengths.height + + stays_to_remove = [] + stays_modified = 0 + + # Identify stays that are too short + short_stays = stay_lengths.filter(pl.col("len") <= steps_to_remove) + if short_stays.height > 0: + short_stay_ids = short_stays[group_var].to_list() + if remove_short_stays: + stays_to_remove.extend(short_stay_ids) + logging.info(f"Removing {len(short_stay_ids)} stays with <= {steps_to_remove} steps: {short_stay_ids}") + else: + logging.warning( + f"Found {len(short_stay_ids)} stays with <= {steps_to_remove} steps. Keeping them unchanged.") + + # Identify stays that would have too few remaining steps + insufficient_stays = stay_lengths.filter( + (pl.col("len") > steps_to_remove) & + (pl.col("len") - steps_to_remove < min_remaining_steps) + ) + if insufficient_stays.height > 0: + insufficient_stay_ids = insufficient_stays[group_var].to_list() + if remove_short_stays: + stays_to_remove.extend(insufficient_stay_ids) + logging.info( + f"Removing {len(insufficient_stay_ids)} stays that would have < {min_remaining_steps} steps after reduction: {insufficient_stay_ids}") + else: + logging.warning( + f"Found {len(insufficient_stay_ids)} stays that would have < {min_remaining_steps} steps after reduction. Keeping them unchanged.") + + # Remove identified stays from all data segments + if stays_to_remove: + for segment in data: + original_length = data[segment].height + data[segment] = data[segment].filter(~pl.col(group_var).is_in(stays_to_remove)) + removed_count = original_length - data[segment].height + logging.info(f"Removed {removed_count} rows from {segment} segment") + + # Reduce remaining stays by removing last steps from dynamic data + valid_stays = stay_lengths.filter( + (pl.col("len") > steps_to_remove) & + (pl.col("len") - steps_to_remove >= min_remaining_steps) & + (~pl.col(group_var).is_in(stays_to_remove)) + )[group_var].to_list() + + if valid_stays: + # For dynamic data: keep only the first (total_length - steps_to_remove) rows per stay + reduced_dynamic_parts = [] + unchanged_dynamic = data[DataSegment.dynamic].filter(~pl.col(group_var).is_in(valid_stays + stays_to_remove)) + + for stay_id in valid_stays: + stay_data = data[DataSegment.dynamic].filter(pl.col(group_var) == stay_id) + current_length = stay_data.height + keep_length = current_length - steps_to_remove + + # Sort by sequence and keep first keep_length rows + reduced_stay = stay_data.sort(sequence_var).head(keep_length) + reduced_dynamic_parts.append(reduced_stay) + stays_modified += 1 + + # Reconstruct dynamic data + if reduced_dynamic_parts: + data[DataSegment.dynamic] = pl.concat([unchanged_dynamic] + reduced_dynamic_parts).sort( + [group_var, sequence_var]) + + # Handle outcome data if it has sequence information (sequence-to-sequence case) + if DataSegment.outcome in data and sequence_var in data[DataSegment.outcome].columns: + logging.info("Outcome data has sequence information. Reducing outcome sequences to match dynamic data.") + + reduced_outcome_parts = [] + unchanged_outcome = data[DataSegment.outcome].filter( + ~pl.col(group_var).is_in(valid_stays + stays_to_remove)) + + for stay_id in valid_stays: + stay_outcome = data[DataSegment.outcome].filter(pl.col(group_var) == stay_id) + current_length = stay_outcome.height + keep_length = current_length - steps_to_remove + + # Sort by sequence and keep first keep_length rows + reduced_outcome = stay_outcome.sort(sequence_var).head(keep_length) + reduced_outcome_parts.append(reduced_outcome) + + # Reconstruct outcome data + if reduced_outcome_parts: + data[DataSegment.outcome] = pl.concat([unchanged_outcome] + reduced_outcome_parts).sort( + [group_var, sequence_var]) + + logging.info(f"Stay reduction complete: {stays_modified} stays modified, " + f"{len(stays_to_remove)} stays removed, " + f"{total_stays - len(stays_to_remove)} stays remaining") + + return data \ No newline at end of file diff --git a/icu_benchmarks/data/utils.py b/icu_benchmarks/data/utils.py new file mode 100644 index 00000000..30afcae7 --- /dev/null +++ b/icu_benchmarks/data/utils.py @@ -0,0 +1,81 @@ +import logging +import polars as pl +import polars.selectors as cs +import numpy as np + +from icu_benchmarks.data.constants import VarType as Var, DataSegment as Segment + + +def infinite_removal(val): + for col in val.select(cs.numeric()).columns: + if val[col].is_infinite().any(): + logging.info(f"Column '{col}' contains infinite values. Datatype: {val[col].dtype}") + + max_float64 = np.finfo(np.float64).max / 100 + # Replace infinite values with the maximum value for float64 + val = val.with_columns( + [ + pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col) + for col in val.columns + if val[col].dtype == pl.Float64 + ] + ) + return val + + +def check_sanitize_data(data, vars): + """Check for duplicates in the loaded data and remove them.""" + group = vars[Var.group] if Var.group in vars.keys() else None + sequence = vars[Var.sequence] if Var.sequence in vars.keys() else None + keep = "last" + logging.info(data.keys()) + if Segment.static in data.keys(): + old_len = len(data[Segment.static]) + data[Segment.static] = data[Segment.static].unique(subset=group, keep=keep, maintain_order=True) + logging.warning(f"Removed {old_len - len(data[Segment.static])} duplicates from static data.") + if Segment.dynamic in data.keys(): + old_len = len(data[Segment.dynamic]) + data[Segment.dynamic] = data[Segment.dynamic].unique(subset=[group, sequence], keep=keep, maintain_order=True) + data[Segment.dynamic] = infinite_removal(data[Segment.dynamic]) + data[Segment.dynamic] = data[Segment.dynamic].select(~cs.starts_with("wearable_ppgfeature_HRV_SampEn")) + vars[Segment.dynamic] = [col for col in vars[Segment.dynamic] if col in data[Segment.dynamic].columns] + logging.warning(f"Removed {old_len - len(data[Segment.dynamic])} duplicates from dynamic data.") + if Segment.outcome in data.keys(): + old_len = len(data[Segment.outcome]) + if sequence in data[Segment.outcome].columns: + # We have a dynamic outcome with group and sequence + data[Segment.outcome] = data[Segment.outcome].unique(subset=[group, sequence], keep=keep, maintain_order=True) + else: + data[Segment.outcome] = data[Segment.outcome].unique(subset=[group], keep=keep, maintain_order=True) + logging.warning(f"Removed {old_len - len(data[Segment.outcome])} duplicates from outcome data.") + return data, vars + + +def modality_selection( + data: dict[pl.DataFrame], modality_mapping: dict[str], selected_modalities: list[str], vars +) -> dict[pl.DataFrame]: + logging.info(f"Selected modalities: {selected_modalities}") + selected_columns = [modality_mapping[cols] for cols in selected_modalities if cols in modality_mapping.keys()] + if not any(col in modality_mapping.keys() for col in selected_modalities): + raise ValueError("None of the selected modalities found in modality mapping.") + if selected_columns == []: + logging.info("No columns selected. Using all columns.") + return data, vars + selected_columns = sum(selected_columns, []) + selected_columns.extend([vars[Var.group], vars[Var.label], vars[Var.sequence]]) + old_columns = [] + # Update vars dict + for key, value in vars.items(): + if key not in [Var.group, Var.label, Var.sequence]: + old_columns.extend(value) + vars[key] = [col for col in value if col in selected_columns] + # -3 because of standard columns + logging.info(f"Selected columns: {len(selected_columns) - 3}, original columns: {len(old_columns)}, " + f"not using: {len(set(old_columns) - set(selected_columns))} columns") + logging.debug(f"Not using columns: {set(old_columns) - set(selected_columns)}") + # Update data dict + for key in data.keys(): + sel_col = [col for col in data[key].columns if col in selected_columns] + data[key] = data[key].select(sel_col) + logging.debug(f"Selected columns in {key}: {len(data[key].columns)}") + return data, vars diff --git a/icu_benchmarks/models/__init__.py b/icu_benchmarks/models/__init__.py index ffe34641..9821d2fe 100644 --- a/icu_benchmarks/models/__init__.py +++ b/icu_benchmarks/models/__init__.py @@ -4,7 +4,7 @@ from icu_benchmarks.models.dl_models.tcn import TemporalConvNet from icu_benchmarks.models.dl_models.transformer import BaseTransformer, LocalTransformer, Transformer from icu_benchmarks.models.ml_models.catboost import CBClassifier -from icu_benchmarks.models.ml_models.imblearn import BRFClassifier, RUSBClassifier +from icu_benchmarks.models.ml_models.imblearn import BRFClassifier, XGBEnsembleClassifier from icu_benchmarks.models.ml_models.lgbm import LGBMClassifier, LGBMRegressor from icu_benchmarks.models.ml_models.sklearn import ( ElasticNet, @@ -31,7 +31,7 @@ MLModelClassifier = Union[ XGBClassifier, LGBMClassifier, - RUSBClassifier, + XGBEnsembleClassifier, BRFClassifier, CBClassifier, LogisticRegression, @@ -57,7 +57,7 @@ "Transformer", "LocalTransformer", "CBClassifier", - "RUSBClassifier", + "XGBEnsembleClassifier", "BRFClassifier", "LGBMClassifier", "LGBMRegressor", diff --git a/icu_benchmarks/models/alarm_metrics.py b/icu_benchmarks/models/alarm_metrics.py new file mode 100644 index 00000000..68c8c21b --- /dev/null +++ b/icu_benchmarks/models/alarm_metrics.py @@ -0,0 +1,70 @@ +import numpy as np + +# def convert_to_alarm(y_true: ndarray, y_pred: ndarray, normalize=False) -> torch.tensor: +# y_pred = np.rint(y_pred).astype(int) +# confusion = sk_confusion_matrix(y_true, y_pred) +# +# return y_pred + + +def convert_to_alarm(ground_truth, predictions, grace_horizon=12, silencing_length=6): + """Convert the predictions to alarm predictions based on the grace horizon and silencing length.""" + # Silence positive predictions + if len(ground_truth) != len(predictions): + raise ValueError("Ground truth and predictions must have the same length.") + if len(ground_truth) < grace_horizon: + raise ValueError("Length of the ground truth must be greater than the grace horizon.") + if not isinstance(ground_truth, np.ndarray): + ground_truth = np.array(ground_truth) + if not isinstance(predictions, np.ndarray): + predictions = np.array(predictions) + threshold = 0.5 + predictions = np.where(predictions >= threshold, 1, 0) + silenced_predictions = silence_positives(ground_truth, predictions, grace_horizon, silencing_length) + # Fill gaps in the predictions + filled_predictions = fill_gaps(silenced_predictions, grace_horizon) + return filled_predictions, ground_truth + + +def silence_positives(ground_truth, predictions, grace_horizon=12, silencing_length=6): + """Silence positive predictions in the predictions array based on the grace horizon and silencing length.""" + # Find all positive indices in the rounded predictions + positive_indices = np.where(predictions == 1)[0] + positive_indices = positive_indices[positive_indices < len(ground_truth) - grace_horizon] + # print(positive_indices) + # Create the silence array + silence_array = np.ones_like(ground_truth) + + while len(positive_indices) > 0: + positive_index = positive_indices[0] + if silence_array[positive_indices[0]] == 0: + # print(f"Already silenced: {positive_index}") + positive_indices = positive_indices[1:] + continue + silence_array[positive_index + 1 : positive_index + silencing_length] = 0 + positive_indices = positive_indices[1:] + # print(predictions) + # print(silence_array) + silenced_predictions = predictions * silence_array + # print(silenced_predictions) + return silenced_predictions + + +def fill_gaps(predictions, ground_truth, grace_horizon=12): + """Fill gaps in the predictions by taking the maximum value between ground truth and predictions.""" + if grace_horizon > len(predictions): + grace_horizon = len(predictions) + # Take the last grace_horizon values + grace_values = predictions[-grace_horizon:] + # print(grace_values) + # Find the first occurrence of 1 in the grace_values + first_one_index = np.where(grace_values == 1) + if first_one_index[0].size > 0: + # Make all subsequent values after the first 1 positive + first_one_index = first_one_index[0][0] + grace_values[first_one_index:] = 1 + # Update the original prediction array + # print(grace_values) + predictions[-grace_horizon:] = grace_values + ground_truth[-grace_horizon:] = np.maximum(predictions[-grace_horizon:], ground_truth[-grace_horizon:]) + return predictions, ground_truth diff --git a/icu_benchmarks/models/constants.py b/icu_benchmarks/models/constants.py index 43843db8..088584ab 100644 --- a/icu_benchmarks/models/constants.py +++ b/icu_benchmarks/models/constants.py @@ -27,6 +27,10 @@ JSD, BinaryFairnessWrapper, confusion_matrix, + sensitivity, + specificity, + positive_predictive_value, + binary_incidence, ) @@ -38,6 +42,13 @@ class MLMetrics: "PR_Curve": precision_recall_curve, "RO_Curve": roc_curve, "Confusion_Matrix": confusion_matrix, + "Sensitivity": sensitivity, + # "Precision": precision_score, + # "Recall": recall_score, + "Specificity": specificity, + "PPV": positive_predictive_value, + # "MCC": matthews_corrcoef, + "Incidence": binary_incidence, } MULTICLASS_CLASSIFICATION = { @@ -63,6 +74,7 @@ class DLMetrics: "PR": AveragePrecision, "PR_Curve": PrecisionRecallCurve, "RO_Curve": RocCurve, + # "MCC": MatthewsCorrCoef } BINARY_CLASSIFICATION_TORCHMETRICS = { diff --git a/icu_benchmarks/models/custom_metrics.py b/icu_benchmarks/models/custom_metrics.py index eb0a5d23..5c621da3 100644 --- a/icu_benchmarks/models/custom_metrics.py +++ b/icu_benchmarks/models/custom_metrics.py @@ -3,7 +3,12 @@ import numpy as np from ignite.metrics import EpochMetric from numpy import ndarray -from sklearn.metrics import balanced_accuracy_score, mean_absolute_error, confusion_matrix as sk_confusion_matrix +from sklearn.metrics import ( + balanced_accuracy_score, + mean_absolute_error, + confusion_matrix as sk_confusion_matrix, + matthews_corrcoef, +) from sklearn.calibration import calibration_curve from scipy.spatial.distance import jensenshannon from torchmetrics.classification import BinaryFairness @@ -62,15 +67,15 @@ def __init__( invert_transform: Callable = lambda x: x, ) -> None: super(MAE, self).__init__( - lambda x, y: mae_with_invert_compute_fn(x, y, invert_transform), + lambda x, y: self.mae_with_invert_compute_fn(x, y, invert_transform), output_transform=output_transform, check_compute_fn=check_compute_fn, ) - def mae_with_invert_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor, invert_fn=Callable) -> float: - y_true = invert_fn(y_targets.numpy().reshape(-1, 1))[:, 0] - y_pred = invert_fn(y_preds.numpy().reshape(-1, 1))[:, 0] - return mean_absolute_error(y_true, y_pred) + def mae_with_invert_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor, invert_fn=Callable) -> float: + y_true = invert_fn(y_targets.numpy().reshape(-1, 1))[:, 0] + y_pred = invert_fn(y_preds.numpy().reshape(-1, 1))[:, 0] + return mean_absolute_error(y_true, y_pred) class JSD(EpochMetric): @@ -143,3 +148,90 @@ def confusion_matrix(y_true: ndarray, y_pred: ndarray, normalize=False) -> torch for j in range(confusion.shape[1]): confusion_dict[f"class_{i}_pred_{j}"] = confusion[i][j] return confusion_dict + + +def matthews_correlation_coefficient(y_true: ndarray, y_pred: ndarray, normalize=False) -> float: + if y_pred.ndim == 2: + y_pred = np.argmax(y_pred, axis=-1) + return matthews_corrcoef() + + +class Sensitivity(EpochMetric): + def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False) -> None: + super(Sensitivity, self).__init__( + self.sensitivity_compute, output_transform=output_transform, check_compute_fn=check_compute_fn + ) + + +def sensitivity(y_preds: torch.Tensor | ndarray, y_targets: torch.Tensor | ndarray) -> float: + if isinstance(y_preds, torch.Tensor): + y_preds = y_preds.numpy() + if isinstance(y_targets, torch.Tensor): + y_targets = y_targets.numpy() + y_true = np.rint(y_targets).astype(int) + y_pred = np.rint(y_preds).astype(int) + tn, fp, fn, tp = sk_confusion_matrix(y_true, y_pred).ravel() + return tp / (tp + fn) + + +# class Specificity(EpochMetric): +# def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False) -> None: +# super(Specificity, self).__init__( +# self.specificity_compute, output_transform=output_transform, check_compute_fn=check_compute_fn +# ) + + +def specificity(y_preds: torch.Tensor, y_targets: torch.Tensor) -> float: + if isinstance(y_preds, torch.Tensor): + y_preds = y_preds.numpy() + if isinstance(y_targets, torch.Tensor): + y_targets = y_targets.numpy() + y_true = np.rint(y_targets).astype(int) + y_pred = np.rint(y_preds).astype(int) + tn, fp, fn, tp = sk_confusion_matrix(y_true, y_pred).ravel() + return tn / (tn + fp) + + +def positive_predictive_value(y_preds: torch.Tensor | np.ndarray, y_targets: torch.Tensor | np.ndarray) -> float: + if isinstance(y_preds, torch.Tensor): + y_preds = y_preds.numpy() + if isinstance(y_targets, torch.Tensor): + y_targets = y_targets.numpy() + y_true = np.rint(y_targets).astype(int) + y_pred = np.rint(y_preds).astype(int) + tn, fp, fn, tp = sk_confusion_matrix(y_true, y_pred).ravel() + return tp / (tp + fp) + + +def binary_incidence(y_preds, y_targets): + """ + Computes the binary incidence (proportion of positive labels). + + Args: + y_true (numpy.ndarray): Ground truth binary labels (0 or 1). + + Returns: + float: Proportion of positive labels. + """ + y_true = np.rint(y_targets).astype(int) # Ensure binary labels + return np.sum(y_true) / len(y_true) + + +# from torchmetrics.classification import Specificity as TorchMetricsSpecificity +# +# class Specificity(EpochMetric): +# def __init__(self, task="binary", output_transform: Callable = lambda x: x, check_compute_fn: bool = False) -> None: +# super(Specificity, self).__init__( +# self.specificity_compute, output_transform=output_transform, check_compute_fn=check_compute_fn +# ) +# if isinstance(task, np.ndarray): +# task = task.item() +# self.metric = TorchMetricsSpecificity(task=task) +# +# def specificity_compute(self, y_preds: torch.Tensor, y_targets: torch.Tensor) -> float: +# if isinstance(y_preds, np.ndarray): +# y_preds = torch.tensor(y_preds) +# if isinstance(y_targets, np.ndarray): +# y_targets = torch.tensor(y_targets) +# self.metric.update(y_preds, y_targets) +# return self.metric.compute().item() diff --git a/icu_benchmarks/models/ml_models/imblearn.py b/icu_benchmarks/models/ml_models/imblearn.py index d1db0703..42916da8 100644 --- a/icu_benchmarks/models/ml_models/imblearn.py +++ b/icu_benchmarks/models/ml_models/imblearn.py @@ -1,8 +1,12 @@ +import logging + from imblearn.ensemble import BalancedRandomForestClassifier, RUSBoostClassifier from icu_benchmarks.constants import RunMode from icu_benchmarks.models.wrappers import MLWrapper import gin - +from sklearn.tree import DecisionTreeClassifier +from imblearn.ensemble import RUSBoostClassifier, BalancedRandomForestClassifier, EasyEnsembleClassifier +import xgboost as xgb @gin.configurable class BRFClassifier(MLWrapper): @@ -14,9 +18,40 @@ def __init__(self, *args, **kwargs): @gin.configurable -class RUSBClassifier(MLWrapper): +class XGBEnsembleClassifier(MLWrapper): _supported_run_modes = [RunMode.classification] - + individual_model = xgb.XGBClassifier( + learning_rate=0.1, + n_estimators=5000, + max_depth=10, + scale_pos_weight=30, + min_child_weight=1, + max_delta_step=3, + colsample_bytree=0.25, + gamma=0.9, + reg_lambda=0.1, + reg_alpha=100, + random_state=42, + eval_metric='logloss' + ) def __init__(self, *args, **kwargs): - self.model = self.set_model_args(RUSBoostClassifier, *args, **kwargs) + self.model = self.set_model_args(EasyEnsembleClassifier, *args, **kwargs, estimator=self.individual_model) super().__init__(*args, **kwargs) + + def fit_model(self, train_data, train_labels, val_data, val_labels): + import xgboost as xgb + from sklearn.metrics import log_loss + + # Try XGBoost first with provided config + try: + self.model.fit(train_data, train_labels,) + + val_pred_proba = self.model.predict_proba(val_data) + val_loss = log_loss(val_labels, val_pred_proba) + logging.info(f"XGBoost model trained successfully. Validation loss: {val_loss:.4f}") + return val_loss + + except Exception as e: + logging.warning(f"XGBoost failed: {e}") + + return val_loss diff --git a/icu_benchmarks/models/ml_models/lgbm.py b/icu_benchmarks/models/ml_models/lgbm.py index c2207555..6d262b13 100644 --- a/icu_benchmarks/models/ml_models/lgbm.py +++ b/icu_benchmarks/models/ml_models/lgbm.py @@ -1,6 +1,10 @@ import gin import lightgbm as lgbm import numpy as np +import shap +import sys +from io import StringIO +import os import wandb from wandb.integration.lightgbm import wandb_callback as wandb_lgbm @@ -12,16 +16,36 @@ class LGBMWrapper(MLWrapper): def fit_model(self, train_data, train_labels, val_data, val_labels): """Fitting function for LGBM models.""" self.model.set_params(random_state=np.random.get_state()[1][0]) - callbacks = [lgbm.early_stopping(self.hparams.patience, verbose=True), lgbm.log_evaluation(period=-1)] - - if wandb.run is not None: - callbacks.append(wandb_lgbm()) - - self.model = self.model.fit( - train_data, - train_labels, - eval_set=(val_data, val_labels), - callbacks=callbacks, + # callbacks = [lgbm.early_stopping(self.hparams.patience, verbose=False)] + callbacks = [ + lgbm.log_evaluation(period=0), # Disable default logging + ] + # callbacks = [] + # Redirect stderr and stdout to suppress warnings + stderr_backup = sys.stderr + stdout_backup = sys.stdout + sys.stderr = StringIO() + sys.stdout = StringIO() + try: + self.model = self.model.fit( + train_data, + train_labels, + eval_set=(val_data, val_labels), + callbacks=callbacks, + ) + finally: + sys.stderr = stderr_backup + # to surpress: [Warning] No further splits with positive gain, best gain: -inf with the current config + sys.stdout = stdout_backup + reduce_samples = False + if reduce_samples: + n_samples = min(1000, len(train_data)) + indices = np.random.choice(len(train_data), size=n_samples, replace=False) + else: + indices = np.arange(len(train_data)) + background_sample = train_data[indices] + self.explainer = shap.TreeExplainer( + self.model, background_sample, feature_perturbation="interventional", model_output="probability" ) val_loss = list(self.model.best_score_["valid_0"].values())[0] return val_loss @@ -32,7 +56,8 @@ class LGBMClassifier(LGBMWrapper): _supported_run_modes = [RunMode.classification] def __init__(self, *args, **kwargs): - self.model = self.set_model_args(lgbm.LGBMClassifier, *args, **kwargs) + kwargs.setdefault("verbosity", -1) + self.model = self.set_model_args(lgbm.LGBMClassifier, *args, **kwargs, verbose=-1) super().__init__(*args, **kwargs) def predict(self, features): @@ -47,6 +72,14 @@ def predict(self, features): """ return self.model.predict_proba(features) + def _explain_model(self, reps, labels): + if not hasattr(self.model, "feature_importances_"): + raise ValueError("Model has not been fit yet. Call fit_model() before getting feature importances.") + # feature_importances = self.model.feature_importances_ + shap_values = self.explainer.shap_values(reps, labels) + # feature_importances = np.abs(shap_values).mean(axis=1) + return shap_values + @gin.configurable class LGBMRegressor(LGBMWrapper): diff --git a/icu_benchmarks/models/ml_models/xgboost.py b/icu_benchmarks/models/ml_models/xgboost.py index f1802ece..064d4b8e 100644 --- a/icu_benchmarks/models/ml_models/xgboost.py +++ b/icu_benchmarks/models/ml_models/xgboost.py @@ -1,4 +1,3 @@ -import inspect import logging from statistics import mean @@ -6,8 +5,10 @@ import shap import wandb import xgboost as xgb +import numpy as np from xgboost.callback import EarlyStopping from wandb.integration.xgboost import wandb_callback as wandb_xgb +from sklearn.metrics import log_loss from icu_benchmarks.constants import RunMode from icu_benchmarks.models.wrappers import MLWrapper @@ -20,23 +21,22 @@ @gin.configurable class XGBClassifier(MLWrapper): _supported_run_modes = [RunMode.classification] - _explain_values = False def __init__(self, *args, **kwargs): - self.model = self.set_model_args(xgb.XGBClassifier, *args, **kwargs, device="cpu") + self.model = self.set_model_args(xgb.XGBClassifier, *args, **kwargs, eval_metric=log_loss, device="cpu", verbosity=0) super().__init__(*args, **kwargs) - def predict(self, features): - """ - Predicts class probabilities for the given features. - - Args: - features: Input features for prediction. - - Returns: - numpy.ndarray: Predicted probabilities for each class. - """ - return self.model.predict_proba(features) + # def predict(self, features): + # """ + # Predicts class probabilities for the given features. + # + # Args: + # features: Input features for prediction. + # + # Returns: + # numpy.ndarray: Predicted probabilities for each class. + # """ + # return self.model.predict_proba(features) def fit_model(self, train_data, train_labels, val_data, val_labels): """Fit the model to the training data (default SKlearn syntax)""" @@ -44,12 +44,19 @@ def fit_model(self, train_data, train_labels, val_data, val_labels): if wandb.run is not None: callbacks.append(wandb_xgb()) - logging.debug(f"train_data: {train_data.shape}, train_labels: {train_labels.shape}") - logging.debug(train_labels) - self.model.fit(train_data, train_labels, eval_set=[(val_data, val_labels)], verbose=False) - if self._explain_values: - self.explainer = shap.TreeExplainer(self.model) - self.train_shap_values = self.explainer(train_data) + logging.info(f"train_data: {train_data.shape}, train_labels: {train_labels.shape}") + logging.info(train_labels) + self.model.fit(train_data, train_labels, eval_set=[(val_data, val_labels)], verbose=0) + + n_samples = min(1000, len(train_data)) + indices = np.random.choice(len(train_data), size=n_samples, replace=False) + background_sample = train_data[indices] + self.explainer = shap.TreeExplainer( + self.model, background_sample, feature_perturbation="interventional", model_output="probability" + ) + # if self.explain_features: + # logging.info("Explaining features") + # self.train_shap_values = self.explainer.shap_values(train_data) # shap.summary_plot(shap_values, X_test, feature_names=features) # logging.info(self.model.get_booster().get_score(importance_type='weight')) # self.log_dict(self.model.get_booster().get_score(importance_type='weight')) @@ -59,16 +66,86 @@ def fit_model(self, train_data, train_labels, val_data, val_labels): def set_model_args(self, model, *args, **kwargs): """XGBoost signature does not include the hyperparams so we need to pass them manually.""" - signature = inspect.signature(model.__init__).parameters - valid_params = signature.keys() - + # signature = inspect.signature(model.__init__).parameters + # valid_params = signature.keys() + valid_params = model().get_params().keys() # Filter out invalid arguments valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} - + if len(valid_kwargs) == 0: + logging.warning("No valid arguments passed to XGBoost") logging.debug(f"Creating model with: {valid_kwargs}.") return model(**valid_kwargs) - def get_feature_importance(self): + def _explain_model(self, reps, labels): if not hasattr(self.model, "feature_importances_"): - raise ValueError("Model has not been fit yet. Call fit_model() before getting feature importances.") - return self.model.feature_importances_ + raise ValueError("Model has not been fit yet. Call fit_model() before getting feature importances.") + # feature_importances = self.model.feature_importances_ + shap_values = self.explainer.shap_values(reps, labels) + # feature_importances = np.abs(shap_values).mean(axis=1) + return shap_values + +@gin.configurable +class XGBClassifierGPU(MLWrapper): + _supported_run_modes = [RunMode.classification] + _explain_values = False + + def __init__(self, *args, **kwargs): + # self.model = self.set_model_args( + # xgb, *args, **kwargs, eval_metric="logloss", tree_method="hist", device="cuda", verbosity=0 + # ) + self.model = xgb + self.params = { + "eval_metric": "logloss", + "tree_method": "hist", + "device": "cuda", + "verbosity": 0, + **kwargs, + } + super().__init__(*args, **kwargs) + + def predict(self, features): + """ + Predicts class probabilities for the given features. + + Args: + features: Input features for prediction. + + Returns: + numpy.ndarray: Predicted probabilities for each class. + """ + return self.model.predict(xgb.DMatrix(features)) + + def fit_model(self, train_data, train_labels, val_data, val_labels): + """ + Train the model using the XGBoost `train` method. + + Args: + train_data: Training features. + train_labels: Training labels. + val_data: Validation features. + val_labels: Validation labels. + + Returns: + float: Evaluation score on the validation set. + """ + dtrain = xgb.DMatrix(train_data, label=train_labels) + dval = xgb.DMatrix(val_data, label=val_labels) + evals = [(dtrain, "train"), (dval, "validation")] + + callbacks = [EarlyStopping(self.hparams.patience)] + + if wandb.run is not None: + callbacks.append(wandb_xgb()) + self.model.train(self.params, train_data=dtrain, evals=evals, callbacks=callbacks) + # self.model.fit(train_data, train_labels, eval_set=[(val_data, val_labels)], verbose=0) + + # shap_interaction_values = self.model.predict(dtrain) + # self.explainer = shap.TreeExplainer( + # self.model, dtrain, feature_perturbation="interventional", model_output="probability" + # ) + # if self.explain_features: + # logging.info("Explaining features") + # self.train_shap_values = self.explainer.shap_values(dtrain) + + eval_score = mean(next(iter(self.model.evals_result()["validation"].values()))) + return eval_score diff --git a/icu_benchmarks/models/train.py b/icu_benchmarks/models/train.py index 1e5360b1..21593117 100644 --- a/icu_benchmarks/models/train.py +++ b/icu_benchmarks/models/train.py @@ -4,7 +4,6 @@ from typing import Literal, Optional import gin -import numpy as np import polars as pl import torch from joblib import load @@ -58,6 +57,7 @@ def train_common( num_workers: int = min(cpu_core_count, torch.cuda.device_count() * 8 * int(torch.cuda.is_available()), 32), polars: bool = True, persistent_workers: bool = False, + explain_features: bool = False, ): """Common wrapper to train all benchmarked models. @@ -182,6 +182,8 @@ def train_common( return 0 test_dataset = dataset_class(data, split=test_on, name=dataset_names["test"], ram_cache=ram_cache) test_dataset = assure_minimum_length(test_dataset) + # Set explainer for testing + model.set_explain_features(explain_features) logging.info(f"Testing on {test_dataset.name} with {len(test_dataset)} samples.") test_loader = ( DataLoader( @@ -199,33 +201,65 @@ def train_common( model.set_weight("balanced", train_dataset) test_loss = trainer.test(model, dataloaders=test_loader, verbose=verbose)[0]["test/loss"] - persist_shap_data(trainer, log_dir) + if explain_features: + persist_shap_data(trainer, log_dir) save_config_file(log_dir) return test_loss -def persist_shap_data(trainer: Trainer, log_dir: Path): +def persist_shap_data(trainer: Trainer, log_dir: Path, save_full_valuesets=True) -> None: """ Persist shap values to disk. Args: trainer: Pytorch lightning trainer object log_dir: Log directory """ + logging.info("Persisting explainer values to disk.") try: - if trainer.lightning_module.test_shap_values is not None: - shap_values = trainer.lightning_module.test_shap_values - shaps_test = pl.DataFrame(schema=trainer.lightning_module.trained_columns, data=np.transpose(shap_values.values)) - with (log_dir / "shap_values_test.parquet").open("wb") as f: + trained_columns = trainer.lightning_module.trained_columns + + if ( + any(x in ["stay_id", "id"] for x in trained_columns) + and len(trained_columns) != trainer.lightning_module.explainer_values_test.shape[1] + ): + trained_columns.remove("stay_id" if "stay_id" in trained_columns else "id") + logging.info("Saving SHAPS") + if hasattr(trainer.lightning_module, "explainer_values_test"): + # todo: abs values + explainer_values = trainer.lightning_module.explainer_values_test + shaps_test = pl.DataFrame(schema=trained_columns, data=explainer_values) + # reps_test = pl.DataFrame(schema=trained_columns, data=trainer.lightning_module.representations_test) + if save_full_valuesets: + with (log_dir / "full_explainer_values_test.parquet").open("wb") as f: + shaps_test.write_parquet(f) + # with (log_dir / "explainer_reps_test.parquet").open("wb") as f: + # reps_test.write_parquet(f) + if hasattr(trainer.lightning_module, "rep_test"): + reps_test = pl.DataFrame(schema=trained_columns, data=trainer.lightning_module.rep_test) + with (log_dir / "explainer_rep_test.parquet").open("wb") as f: + reps_test.write_parquet(f) + labels = pl.DataFrame(schema=["label"], data=trainer.lightning_module.label_test) + with (log_dir / "explainer_label_test.parquet").open("wb") as f: + labels.write_parquet(f) + # logging.info(f"{explainer_values.values.shape}") + # shaps_test = pl.DataFrame(schema=trainer.lightning_module.trained_columns, data=np.transpose(explainer_values.values)) + shaps_test = shaps_test.select(pl.all().mean()) + with (log_dir / "explainer_values_test.parquet").open("wb") as f: shaps_test.write_parquet(f) - logging.info(f"Saved shap values to {log_dir / 'test_shap_values.parquet'}") - if trainer.lightning_module.train_shap_values is not None: - shap_values = trainer.lightning_module.train_shap_values - shaps_train = pl.DataFrame(schema=trainer.lightning_module.trained_columns, data=np.transpose(shap_values.values)) - with (log_dir / "shap_values_train.parquet").open("wb") as f: + logging.debug(f"Saved test explainer values to {log_dir / 'explainer_values_test.parquet'}") + else: + logging.warning("No explainer values for test set found. Skipping saving.") + if hasattr(trainer.lightning_module, "explainer_values_train"): + explainer_values = trainer.lightning_module.explainer_values_train + # shaps_train = pl.DataFrame(schema=trainer.lightning_module.trained_columns, data=np.transpose(explainer_values.values)) + shaps_train = pl.DataFrame(schema=trained_columns, data=explainer_values) + with (log_dir / "explainer_values_train.parquet").open("wb") as f: shaps_train.write_parquet(f) + else: + logging.warning("No explainer values for train set found. Skipping saving.") except Exception as e: - logging.error(f"Failed to save shap values: {e}") + logging.error(f"Failed to save explainer values: {e}") def load_model(model, source_dir, pl_model=True) -> DLModel | MLModelClassifier | MLModelRegression: diff --git a/icu_benchmarks/models/utils.py b/icu_benchmarks/models/utils.py index fc5b5506..4cf35980 100644 --- a/icu_benchmarks/models/utils.py +++ b/icu_benchmarks/models/utils.py @@ -8,6 +8,7 @@ import logging import numpy as np import torch +import csv from pytorch_lightning.loggers.logger import Logger from pytorch_lightning.utilities import rank_zero_only @@ -150,7 +151,7 @@ def __init__(self, output_dir=None, **kwargs): super().__init__(**kwargs) if output_dir is None: output_dir = Path.cwd() / "metrics" - logging.info(f"logging metrics to file: {str(output_dir.resolve())}") + logging.info(f"Logging metrics to file: {str(output_dir.resolve())}") self.output_dir = output_dir self.output_dir.mkdir(parents=True, exist_ok=True) @@ -282,3 +283,33 @@ def get_smoothed_labels( return np.array( list(map(lambda x: smoothing_fn(x, h_true=h_true, h_min=h_min, h_max=h_max, delta_h=delta_h, gamma=gamma), dte)) ) + + + + +def log_single_metric_to_file(metric_name: str, data_points: tuple[np.ndarray], output_file: Path) -> None: + """ + Logs a single metric to a file, where the input is a tuple of numpy arrays. + + Args: + metric_name (str): Name of the metric. + data_points (tuple[np.ndarray]): Tuple of numpy arrays to log. + output_file (Path): Path to the file where the metric will be saved. + + Raises: + ValueError: If data_points is not a tuple of numpy arrays or output_file is not a valid Path. + """ + if not isinstance(data_points, tuple) or not all(isinstance(arr, np.ndarray) for arr in data_points): + raise ValueError("data_points must be a tuple of numpy arrays.") + if not isinstance(output_file, Path): + raise ValueError("output_file must be a Path object.") + + output_file.parent.mkdir(parents=True, exist_ok=True) # Ensure the directory exists + + with output_file.open(mode="w", newline="") as file: + writer = csv.writer(file) + # Write header + writer.writerow(["Row"] + [f"Column {i}" for i in range(data_points[0].size)]) + # Write each numpy array as a row + for index, array in enumerate(data_points): + writer.writerow([index] + array.tolist()) diff --git a/icu_benchmarks/models/wrappers.py b/icu_benchmarks/models/wrappers.py index 9c0a6f68..b7ae2512 100644 --- a/icu_benchmarks/models/wrappers.py +++ b/icu_benchmarks/models/wrappers.py @@ -4,6 +4,7 @@ from pathlib import Path import torchmetrics from sklearn.metrics import log_loss, mean_squared_error, average_precision_score, roc_auc_score +from sklearn.calibration import CalibratedClassifierCV import torch from torch.nn import MSELoss, CrossEntropyLoss @@ -17,7 +18,7 @@ from ignite.exceptions import NotComputableError from icu_benchmarks.models.constants import ImputationInit from icu_benchmarks.models.custom_metrics import confusion_matrix -from icu_benchmarks.models.utils import create_optimizer, create_scheduler +from icu_benchmarks.models.utils import create_optimizer, create_scheduler, log_single_metric_to_file from joblib import dump from pytorch_lightning import LightningModule @@ -47,6 +48,7 @@ class BaseModule(LightningModule): # Type of run mode run_mode = None debug = False + # We do not want to explain features by default as it is expensive (and not needed during hp tuning) explain_features = False def forward(self, *args, **kwargs): @@ -62,6 +64,7 @@ def set_metrics(self, *args, **kwargs): self.metrics = {} def set_trained_columns(self, columns: List[str]): + logging.info(f"Setting trained columns: {len(columns)}") self.trained_columns = columns def set_weight(self, weight, dataset): @@ -104,6 +107,64 @@ def check_supported_runmode(self, runmode: RunMode): raise ValueError(f"Runmode {runmode} not supported for {self.__class__.__name__}") return True + def set_explain_features(self, set_explain_features: bool): + self.explain_features = set_explain_features + self.persist_reps = set_explain_features + + def _explain_model(self, reps, labels): + raise NotImplementedError(f"Model {self.__class__.__name__} does not currently support feature explanation.") + + def _save_model_outputs(self, pred_indicators, test_pred, test_label): + "Save model outputs to CSV file for further post-hoc analysis" + if len(pred_indicators.shape) > 1 and len(test_pred.shape) > 1: + # Temporal dataset + if pred_indicators.shape[1] == test_pred.shape[1] and pred_indicators.shape[0] == test_pred.shape[0]: + # One outcome per dataset + pred_indicators = np.hstack((pred_indicators, test_label.reshape(-1, 1))) + pred_indicators = np.hstack((pred_indicators, test_pred)) + # Save as: id, time (hours), ground truth, prediction 0, prediction 1 + # if pred_indicators.shape(1) == 5: + np.savetxt( + Path(self.logger.save_dir) / "pred_indicators.csv", + pred_indicators, + delimiter=",", + header="id,time,ground_truth,prediction_0,prediction_1", + fmt="%d,%d,%.3f,%.3f,%.3f", + ) + logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / 'row_indicators.csv'}") + # else: + # # Flat/static dataset + # np.savetxt(Path(self.logger.save_dir) / "pred_indicators.csv", pred_indicators, delimiter=",", + # header="id,ground_truth,prediction_0,prediction_1", fmt='%d,%d,%.3f,%.3f,%.3f') + # np.savetxt(Path(self.logger.save_dir) / "pred_indicators.csv", pred_indicators, delimiter=",", + # header="id,time,ground_truth,prediction_0,prediction_1", fmt='%d,%d,%.3f,%.3f,%.3f') + # logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / f'row_indicators.csv'}") + else: + logging.info(np.unique(pred_indicators[:, 0])) + pred_indicators = np.unique(pred_indicators[:, 0]) + pred_indicators = np.hstack((pred_indicators.reshape(-1, 1), test_label.reshape(-1, 1))) + pred_indicators = np.hstack((pred_indicators, test_pred)) + + np.savetxt( + Path(self.logger.save_dir) / "pred_indicators.csv", + pred_indicators, + delimiter=",", + header="id,ground_truth,prediction_0,prediction_1", + fmt="%d,%d,%0.3f,%0.3f", + ) + logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / 'row_indicators.csv'}") + else: + pred_indicators = np.hstack((pred_indicators.reshape(-1, 1), test_label.reshape(-1, 1))) + logging.info(pred_indicators.shape) + pred_indicators = np.hstack((pred_indicators, test_pred)) + np.savetxt( + Path(self.logger.save_dir) / "pred_indicators.csv", + pred_indicators, + delimiter=",", + header="id,ground_truth,prediction_0,prediction_1", + fmt="%d,%d,%0.3f,%0.3f", + ) + logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / f'row_indicators.csv'}") @gin.configurable("DLWrapper") class DLWrapper(BaseModule, ABC): @@ -210,6 +271,10 @@ def on_test_epoch_start(self) -> None: step_name: {metric_name: metric() for metric_name, metric in self.set_metrics().items()} for step_name in ["train", "val", "test"] } + if hasattr(self.trainer, 'val_dataloaders') and self.trainer.val_dataloaders: + val_loader = self.trainer.val_dataloaders[0] if isinstance(self.trainer.val_dataloaders, + list) else self.trainer.val_dataloaders + self.setup_calibration(val_loader) return super().on_test_epoch_start() def save_model(self, save_path, file_name, file_extension=".ckpt"): @@ -260,6 +325,8 @@ def __init__( ) self.output_transform = None self.loss_weights = None + self.calibrated_model = None + def set_metrics(self, *args): """Set the evaluation metrics for the prediction model.""" @@ -306,8 +373,8 @@ def step_fn(self, element, step_prefix=""): step_prefix (str): Step type, by default: test, train, val. """ - if len(element) == 2: - data, labels = element[0], element[1].to(self.device) + if len(element) == 3: + data, labels, indicators = element[0], element[1].to(self.device), element[2] if isinstance(data, list): for i in range(len(data)): data[i] = data[i].float().to(self.device) @@ -315,8 +382,8 @@ def step_fn(self, element, step_prefix=""): data = data.float().to(self.device) mask = torch.ones_like(labels).bool() - elif len(element) == 3: - data, labels, mask = element[0], element[1].to(self.device), element[2].to(self.device) + elif len(element) == 4: + data, labels, mask, indicators = element[0], element[1].to(self.device), element[2].to(self.device), element[3] if isinstance(data, list): for i in range(len(data)): data[i] = data[i].float().to(self.device) @@ -335,10 +402,44 @@ def step_fn(self, element, step_prefix=""): prediction = torch.masked_select(out, mask.unsqueeze(-1)).reshape(-1, out.shape[-1]).to(self.device) target = torch.masked_select(labels, mask).to(self.device) + # valid_idx is (batch, seq_len) boolean mask + valid_idx = mask.detach().cpu().numpy().astype(bool) + + # indicators is (batch, seq_len, ...) or (batch, seq_len, 2) + indicators_np = indicators.detach().cpu().numpy()[valid_idx] + predictors = prediction.detach().cpu() + target_np = target.detach().cpu().numpy() + # transformed_predictors = torch.softmax(predictors, dim=1) + # transformed_predictors = transformed_predictors.numpy() + # Apply calibration for test step if available + if (step_prefix == "test" and hasattr(self, 'calibrated_model') and + self.calibrated_model is not None and self.run_mode == RunMode.classification): + # Get uncalibrated probabilities + uncalibrated_probs = torch.softmax(prediction, dim=1).cpu().numpy() + # Apply calibration + calibrated_probs = self.calibrated_model.predict_proba(uncalibrated_probs) + # Convert back to tensor for loss calculation (keeping original logits for loss) + calibrated_tensor = torch.tensor(calibrated_probs, device=self.device) + transformed_predictors = calibrated_probs + logging.info("Using calibrated predictions for test step") + else: + # Use uncalibrated predictions + predictors = prediction.detach().cpu() + transformed_predictors = torch.softmax(predictors, dim=1) + transformed_predictors = transformed_predictors.numpy() + + # if prediction.shape[-1] > 1 and self.run_mode == RunMode.classification: + # # Classification task + # loss = self.loss(prediction, target.long(), weight=self.loss_weights.to(self.device)) + aux_loss + # # Returns torch.long because negative log likelihood loss if prediction.shape[-1] > 1 and self.run_mode == RunMode.classification: # Classification task loss = self.loss(prediction, target.long(), weight=self.loss_weights.to(self.device)) + aux_loss - # Returns torch.long because negative log likelihood loss + # Use calibrated probabilities for output transform in test step + if step_prefix == "test" and hasattr(self, 'calibrated_model') and self.calibrated_model is not None: + transformed_output = self.output_transform((calibrated_tensor, target)) + else: + transformed_output = self.output_transform((prediction, target)) elif self.run_mode == RunMode.regression: # Regression task loss = self.loss(prediction[:, 0], target.float()) + aux_loss @@ -346,6 +447,9 @@ def step_fn(self, element, step_prefix=""): raise ValueError(f"Run mode {self.run_mode} not yet supported. Please implement it.") transformed_output = self.output_transform((prediction, target)) + # Save predictions to file + self._save_model_outputs(indicators_np, transformed_predictors, target_np) + for key, value in self.metrics[step_prefix].items(): if isinstance(value, torchmetrics.Metric): if key == "Binary_Fairness": @@ -358,6 +462,141 @@ def step_fn(self, element, step_prefix=""): self.log(f"{step_prefix}/loss", loss, on_step=False, on_epoch=True, sync_dist=True) return loss + def setup_calibration(self, val_loader, method='isotonic'): + """ + Setup model calibration using validation data for better probability estimates. + + Args: + val_loader: Validation dataloader for calibration + method: 'isotonic' or 'sigmoid' calibration method + + Returns: + float: Validation loss after calibration + """ + if self.run_mode != RunMode.classification: + logging.warning("Calibration only supported for classification tasks") + return None + + logging.info(f"Applying {method} calibration using validation data") + + # Collect validation predictions and labels + self.eval() + val_probs = [] + val_labels = [] + + with torch.no_grad(): + for batch in val_loader: + if len(batch) == 3: + data, labels, _ = batch + mask = torch.ones_like(labels).bool() + elif len(batch) == 4: + data, labels, mask, _ = batch + else: + raise Exception("Loader should return either (data, label) or (data, label, mask)") + + # Move to device + if isinstance(data, list): + data = [d.float().to(self.device) for d in data] + else: + data = data.float().to(self.device) + labels = labels.to(self.device) + mask = mask.to(self.device) + + # Get model predictions + out = self(data) + if len(out) == 2 and isinstance(out, tuple): + out, _ = out + + # Apply mask and get probabilities + prediction = torch.masked_select(out, mask.unsqueeze(-1)).reshape(-1, out.shape[-1]) + target = torch.masked_select(labels, mask) + + # Convert to probabilities + probs = torch.softmax(prediction, dim=1) + + val_probs.append(probs.cpu().numpy()) + val_labels.append(target.cpu().numpy()) + + # Concatenate all validation data + val_probs = np.vstack(val_probs) + val_labels = np.concatenate(val_labels) + + # Create sklearn-compatible wrapper for the PyTorch model + class PytorchModelWrapper: + def __init__(self, model, device): + self.model = model + self.device = device + self.classes_ = np.arange(val_probs.shape[1]) + + def predict_proba(self, X): + # X should be the raw features, but we'll use pre-computed probabilities + # This is a simplified approach - in practice you'd need to handle the full pipeline + return X # X is already probabilities in this case + + # Use pre-computed probabilities for calibration + wrapper = PytorchModelWrapper(self, self.device) + + # Create calibrated version + self.calibrated_model = CalibratedClassifierCV( + wrapper, + method=method, + cv='prefit' + ) + + # Fit calibration on validation probabilities + self.calibrated_model.fit(val_probs, val_labels) + + # Calculate calibrated validation score + cal_val_pred = self.calibrated_model.predict_proba(val_probs) + cal_val_loss = self.loss_fn_numpy(val_labels, cal_val_pred) + orig_val_loss = self.loss_fn_numpy(val_labels, val_probs) + + logging.info(f"Calibration complete. " + f"Original val loss: {orig_val_loss:.4f}, " + f"Calibrated val loss: {cal_val_loss:.4f}") + + return cal_val_loss + + + def loss_fn_numpy(self, labels, probs): + """Convert tensor loss to numpy equivalent for calibration evaluation.""" + if self.run_mode == RunMode.classification: + return log_loss(labels, probs) + else: + return mean_squared_error(labels, probs) + + + def predict_calibrated(self, features): + """Get calibrated predictions if calibration is available.""" + if self.calibrated_model is None: + logging.warning("Calibration not set up. Using base model predictions.") + return self.predict_uncalibrated(features) + + # Get base model probabilities + base_probs = self.predict_uncalibrated(features) + + # Apply calibration + if isinstance(base_probs, torch.Tensor): + base_probs = base_probs.cpu().numpy() + + cal_probs = self.calibrated_model.predict_proba(base_probs) + return torch.tensor(cal_probs, device=self.device) + + + def predict_uncalibrated(self, features): + """Get raw model predictions without calibration.""" + self.eval() + with torch.no_grad(): + if isinstance(features, list): + features = [f.float().to(self.device) for f in features] + else: + features = features.float().to(self.device) + + out = self(features) + if len(out) == 2 and isinstance(out, tuple): + out, _ = out + + return torch.softmax(out, dim=1) @gin.configurable("MLWrapper") class MLWrapper(BaseModule, ABC): @@ -376,6 +615,7 @@ def __init__(self, *args, run_mode=RunMode.classification, loss=log_loss, patien self.patience = patience self.mps = mps self.loss_weight = None + self.save_model_outputs = True def set_metrics(self, labels): if self.run_mode == RunMode.classification: @@ -403,6 +643,43 @@ def set_metrics(self, labels): self.label_transform = lambda x: x self.metrics = MLMetrics.REGRESSION + def setup_calibration(self, val_data, val_labels, method='sigmoid'): + """ + Setup model calibration using validation data for better probability estimates. + + Args: + val_data: Validation features for calibration + val_labels: Validation labels for calibration + method: 'isotonic' or 'sigmoid' calibration method + + Returns: + float: Validation loss after calibration + """ + if self.run_mode != RunMode.classification: + logging.warning("Calibration only supported for classification tasks") + return None + + logging.info(f"Applying {method} calibration using validation data") + + # Create calibrated version using validation data as holdout set + self.calibrated_model = CalibratedClassifierCV( + self.model, + method=method, + cv='prefit' # Use prefit model with holdout validation set + ) + self.calibrated_model.fit(val_data, val_labels) + + # Calculate calibrated validation score + cal_val_pred = self.calibrated_model.predict_proba(val_data) + cal_val_loss = self.loss(val_labels, cal_val_pred) + + logging.info(f"Calibration complete. " + f"Original val loss: {self.loss(val_labels, self.model.predict_proba(val_data)):.4f}, " + f"Calibrated val loss: {cal_val_loss:.4f}") + + return cal_val_loss + + def fit(self, train_dataset, val_dataset): """Fit the model to the training data.""" train_rep, train_label, row_indicators = train_dataset.get_data_and_labels() @@ -414,6 +691,17 @@ def fit(self, train_dataset, val_dataset): self.model.set_params(class_weight=self.weight) val_loss = self.fit_model(train_rep, train_label, val_rep, val_label) + calibrate = True + method = "sigmoid" + # Apply calibration if desired + calibrate = True + if calibrate and self.run_mode == RunMode.classification: + cal_val_loss = self.setup_calibration(val_rep, val_label, method=method) + logging.info(f"Model calibrated. val loss {val_loss} Calibrated val loss: {cal_val_loss}") + + if self.explain_features: + self.explainer_values_train = self._explain_model(train_rep, train_label) + train_pred = self.predict(train_rep) @@ -423,12 +711,13 @@ def fit(self, train_dataset, val_dataset): logging.debug(f"Train loss: {self.loss(train_label, train_pred)}") self.log("val/loss", val_loss, sync_dist=True) logging.debug(f"Val loss: {val_loss}") - self.log_metrics(train_label, train_pred, "train") + self.log_metrics(train_label, train_pred, "train", row_indicators) def fit_model(self, train_data, train_labels, val_data, val_labels): """Fit the model to the training data (default SKlearn syntax)""" self.model.fit(train_data, train_labels) - val_loss = 0.0 + val_pred = self.predict(val_data) + val_loss = self.loss(val_labels, val_pred) return val_loss def validation_step(self, val_dataset, _): @@ -442,6 +731,28 @@ def validation_step(self, val_dataset, _): logging.info(f"Val loss: {self.loss(val_label, val_pred)}") self.log_metrics(val_label, val_pred, "val") + def compare_rankings(self, features, labels): + """Compare rankings between base and calibrated models.""" + base_probs = self.model.predict_proba(features)[:, 1] + cal_probs = self.calibrated_model.predict_proba(features)[:, 1] + + # Check if rankings are identical + base_ranking = np.argsort(base_probs) + cal_ranking = np.argsort(cal_probs) + + ranking_identical = np.array_equal(base_ranking, cal_ranking) + + # Calculate metrics for both + from sklearn.metrics import roc_auc_score, average_precision_score + base_auroc = roc_auc_score(labels, base_probs) + cal_auroc = roc_auc_score(labels, cal_probs) + base_auprc = average_precision_score(labels, base_probs) + cal_auprc = average_precision_score(labels, cal_probs) + logging.info(f"ranking identical: {ranking_identical}") + logging.info(f"Base AUROC: {base_auroc}, Calibrated AUROC: {cal_auroc}, ") + logging.info(f"Base AUPRC: {base_auprc}, Calibrated AUPRC: {cal_auprc}, ") + + def test_step(self, dataset, _): test_rep, test_label, pred_indicators = dataset test_rep, test_label, pred_indicators = ( @@ -451,55 +762,115 @@ def test_step(self, dataset, _): ) self.set_metrics(test_label) test_pred = self.predict(test_rep) - if self.debug: + # test_pred_uncalibrated = self.predict(test_rep, use_calibrated=False)False + self.compare_rankings(test_rep, test_label) + self.log_curves(test_label, test_pred, "test", pred_indicators) + # if self.debug: + if self.save_model_outputs: self._save_model_outputs(pred_indicators, test_pred, test_label) if self.explain_features: - self.explain_model(test_rep, test_label) + # self.explain_model(test_rep, test_label) + self.explainer_values_test = self._explain_model(test_rep, test_label) + if self.persist_reps: + self.rep_test = test_rep + self.label_test = test_label if self.mps: self.log("test/loss", np.float32(self.loss(test_label, test_pred)), sync_dist=True) - self.log_metrics(np.float32(test_label), np.float32(test_pred), "test") + self.log_metrics(np.float32(test_label), np.float32(test_pred), "test", pred_indicators) else: self.log("test/loss", self.loss(test_label, test_pred), sync_dist=True) - self.log_metrics(test_label, test_pred, "test") - logging.debug(f"Test loss: {self.loss(test_label, test_pred)}") - - def predict(self, features): - if self.run_mode == RunMode.regression: - return self.model.predict(features) - else: # Classification: return probabilities - return self.model.predict_proba(features) - - def log_metrics(self, label, pred, metric_type): + self.log_metrics(test_label, test_pred, "test", pred_indicators) + # logging.info(f"Test loss: {self.loss(test_label, test_pred)}, " + # f"uncalibrated {self.loss(test_label,test_pred_uncalibrated)}") + + def predict(self, features, use_calibrated=True): + """Predict using calibrated model if available, otherwise use base model.""" + if hasattr(self, 'calibrated_model') and use_calibrated: + if self.run_mode == RunMode.regression: + return self.calibrated_model.predict(features) + else: + return self.calibrated_model.predict_proba(features) + else: + if self.run_mode == RunMode.regression: + return self.model.predict(features) + else: + return self.model.predict_proba(features) + + + # def predict_with_preserved_ranking(self, features): + # """Get calibrated probabilities while preserving base model ranking.""" + # base_probs = self.model.predict_proba(features)[:, 1] + # cal_probs = self.calibrated_model.predict_proba(features)[:, 1] + # + # # Get ranking from base model + # ranking_indices = np.argsort(base_probs) + # + # # Sort calibrated probabilities to match base model ranking + # sorted_cal_probs = np.sort(cal_probs) + # + # # Create output that preserves base ranking but uses calibrated values + # result = np.zeros_like(cal_probs) + # result[ranking_indices] = sorted_cal_probs + # + # # Convert back to two-column format + # return np.column_stack([1 - result, result]) + + def log_curves(self, label, pred, metric_type, pred_indicators): + for name, metric in self.metrics.items(): + result = metric(self.label_transform(label), self.output_transform(pred)) + if isinstance(result, tuple): + # Vertical stacking for saving to file + # result = tuple(arr.reshape(-1, 1) for arr in result) + log_single_metric_to_file( + metric_name=name, + data_points=result, + output_file=Path(self.logger.save_dir) / f"{metric_type}_metrics_{name}.csv", + ) + + def log_metrics(self, label, pred, metric_type, pred_indicators): """Log metrics to the PL logs.""" - if "Confusion_Matrix" in self.metrics: - self.log_dict(confusion_matrix(self.label_transform(label), self.output_transform(pred)), sync_dist=True) - self.log_dict( - { - f"{metric_type}/{name}": (metric(self.label_transform(label), self.output_transform(pred))) - # For every metric - for name, metric in self.metrics.items() - # Filter out metrics that return a tuple (e.g. precision_recall_curve) - if not isinstance(metric(self.label_transform(label), self.output_transform(pred)), tuple) - and name != "Confusion_Matrix" - }, - sync_dist=True, - ) - - def _explain_model(self, test_rep, test_label): - if self.explainer is not None: - self.test_shap_values = self.explainer(test_rep) + if pred_indicators is None: + if "Confusion_Matrix" in self.metrics: + self.log_dict(confusion_matrix(self.label_transform(label), self.output_transform(pred)), sync_dist=True) + self.log_dict( + { + f"{metric_type}/{name}": (metric(self.label_transform(label), self.output_transform(pred))) + # For every metric + for name, metric in self.metrics.items() + # Filter out metrics that return a tuple (e.g. precision_recall_curve) + if not isinstance(metric(self.label_transform(label), self.output_transform(pred)), tuple) + and name != "Confusion_Matrix" + }, + sync_dist=True, + ) else: - logging.warning("No explainer or explain_features values set.") + if ( + len(pred_indicators.shape) > 1 + and len(pred.shape) > 1 + and pred_indicators.shape[1] == pred.shape[1] + and pred_indicators.shape[0] == pred.shape[0] + ): + pred_indicators = np.hstack((pred_indicators, label.reshape(-1, 1))) + pred_indicators = np.hstack((pred_indicators, pred)) + # Format: id, time (hours), ground truth, prediction 0, prediction 1 + + # TODO: Implement alarm metrics using row indicators + if "Confusion_Matrix" in self.metrics: + self.log_dict(confusion_matrix(self.label_transform(label), self.output_transform(pred)), sync_dist=True) + self.log_dict( + { + f"{metric_type}/{name}": (metric(self.label_transform(label), self.output_transform(pred))) + # For every metric + for name, metric in self.metrics.items() + # Filter out metrics that return a tuple (e.g. precision_recall_curve) + if not isinstance(metric(self.label_transform(label), self.output_transform(pred)), tuple) + and name != "Confusion_Matrix" + }, + sync_dist=True, + ) + + - def _save_model_outputs(self, pred_indicators, test_pred, test_label): - if len(pred_indicators.shape) > 1 and len(test_pred.shape) > 1 and pred_indicators.shape[1] == test_pred.shape[1]: - pred_indicators = np.hstack((pred_indicators, test_label.reshape(-1, 1))) - pred_indicators = np.hstack((pred_indicators, test_pred)) - # Save as: id, time (hours), ground truth, prediction 0, prediction 1 - np.savetxt(Path(self.logger.save_dir) / "pred_indicators.csv", pred_indicators, delimiter=",") - logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / 'row_indicators.csv'}") - else: - logging.warning("Could not save row indicators.") def configure_optimizers(self): return None diff --git a/icu_benchmarks/run.py b/icu_benchmarks/run.py index 1c5479a0..3a129d8f 100644 --- a/icu_benchmarks/run.py +++ b/icu_benchmarks/run.py @@ -18,8 +18,9 @@ setup_logging, import_preprocessor, name_datasets, - get_config_files, + get_config_files, append_predictions_foldwise, ) +from icu_benchmarks.utils import parse_dict from icu_benchmarks.constants import RunMode @@ -51,6 +52,8 @@ def main(my_args=tuple(sys.argv[1:])): experiment = args.experiment source_dir = args.source_dir modalities = args.modalities + load_data_vars = args.load_data_vars + if modalities: logging.debug(f"Binding modalities: {modalities}") gin.bind_parameter("preprocess.selected_modalities", modalities) @@ -64,6 +67,16 @@ def main(my_args=tuple(sys.argv[1:])): f"Model: {model} {'not ' if model not in models else ''}found." ) # Load task config + if load_data_vars: + logging.info(f"Loading variables from {task} from {data_dir} configuration") + if (data_dir / "vars.gin").exists(): + # Open the task config file in append mode and add a line + with open(f"configs/tasks/{task}.gin", "a") as config_file: + config_file.write("\n# Added automatically by run.py\n") + config_file.write(f"# Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + config_file.write(f'include "{data_dir}/vars.gin"\n') + else: + logging.warning(f"No vars.gin file found in {data_dir}. Please ensure the file exists.") gin.parse_config_file(f"configs/tasks/{task}.gin") mode = get_mode() @@ -139,9 +152,33 @@ def main(my_args=tuple(sys.argv[1:])): if args.experiment else [model_path, Path(f"configs/tasks/{task}.gin")] ) + gin.parse_config_files_and_bindings(gin_config_files, args.hyperparams, finalize_config=False) log_full_line(f"Data directory: {data_dir.resolve()}", level=logging.INFO) run_dir = create_run_dir(log_dir) + + # manually bind dataset files + if args.file_names: + logging.info(f"Attempting to bind dataset files: {args.file_names}, type: {type(args.file_names)}") + if isinstance(args.file_names, dict): + logging.info(f"Will load data from {args.file_names}") + gin.bind_parameter("preprocess.file_names", args.file_names) + elif isinstance(args.file_names, str): + file_names = parse_dict(args.file_names) + logging.info(f"Will load data from {args.file_names}") + gin.bind_parameter("preprocess.file_names", file_names) + else: + raise ValueError( + f"Please provide a dictionary type for the file names, got {args.file_names}, " + f"type: {type(args.file_names)}" + ) + + update_wandb_config({ + "data_dir": data_dir.resolve(), + "task": task, + "run_dir": run_dir.resolve(), + }) + choose_and_bind_hyperparameters_optuna( do_tune=args.tune, data_dir=data_dir, @@ -184,6 +221,7 @@ def main(my_args=tuple(sys.argv[1:])): cpu=args.cpu, wandb=args.wandb_sweep, complete_train=args.complete_train, + explain_features=args.explain_features, ) log_full_line("FINISHED TRAINING", level=logging.INFO, char="=", num_newlines=3) @@ -191,6 +229,8 @@ def main(my_args=tuple(sys.argv[1:])): log_full_line(f"DURATION: {execution_time}", level=logging.INFO, char="") try: aggregate_results(run_dir, execution_time) + append_predictions_foldwise(run_dir, "pred_indicators.csv") + except Exception as e: logging.error(f"Failed to aggregate results: {e}") logging.debug("Error details:", exc_info=True) diff --git a/icu_benchmarks/run_utils.py b/icu_benchmarks/run_utils.py index ba2e9bd6..58019be5 100644 --- a/icu_benchmarks/run_utils.py +++ b/icu_benchmarks/run_utils.py @@ -1,8 +1,8 @@ import importlib +import math import sys import warnings from math import sqrt - import gin import torch import json @@ -14,8 +14,11 @@ import shutil from statistics import mean, pstdev from icu_benchmarks.models.utils import JsonResultLoggingEncoder -from icu_benchmarks.wandb_utils import wandb_log import polars as pl +import random + +from .utils import parse_dict +from .wandb_utils import wandb_log def build_parser() -> ArgumentParser: @@ -60,6 +63,26 @@ def build_parser() -> ArgumentParser: help="Optional modality selection to use. Specify multiple modalities separated by spaces.", ) parser.add_argument("--label", type=str, help="Label to use for evaluation in case of multiple labels.", default=None) + parser.add_argument( + "--file_names", + type=parse_dict, + help="Dictionary of file names to use in data_dir " + "(e.g., 'DYNAMIC:dyno.parquet,OUTCOME:outco.parquet,STATIC:sta.parquet')", + default=None, + ) + parser.add_argument("--explain_features", default=False, action=BOA, help="Enable feature explanation.") + parser.add_argument( + "--load_data_vars", + default=False, + action=BOA, + help="Load data variables from the dataset directory. Avoids having to manually add the path in the task.gin", + ) + parser.add_argument( + "--reduce_stay_steps", + default=None, + type=int, + help="Reduce the stay length dynamically", + ) return parser @@ -76,9 +99,9 @@ def create_run_dir(log_dir: Path, randomly_searched_params: str = None) -> Path: Returns: Path to the created run log directory. """ - log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S")) + log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")) while log_dir_run.exists(): - log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")) + log_dir_run = log_dir_run.with_name(log_dir_run.name + random.randint(1, 10)) log_dir_run.mkdir(parents=True) if randomly_searched_params: (log_dir_run / randomly_searched_params).touch() @@ -98,7 +121,71 @@ def import_preprocessor(preprocessor_path: str): logging.error(f"Could not import custom preprocessor from {preprocessor_path}: {e}") -def aggregate_results(log_dir: Path, execution_time: timedelta = None): +def append_predictions_foldwise(directory: str, filename: str, max_id: int = 1300) -> pl.DataFrame: + """ + Load all prediction CSV files with the specified filename in the given directory and its subdirectories, + and append them vertically into a single Polars DataFrame. + Files are processed fold by fold across all iterations. + + Parameters: + directory (str): The root directory to search for CSV files. + filename (str): The specific filename to look for. + max_id (int): The maximum ID value to offset the IDs in each fold to prevent clashes. + + Returns: + pl.DataFrame: A single DataFrame containing all the appended CSV files. + + Example usage: + directory = 'home' + filename = 'pred_indicators.csv' + combined_df = load_and_append_csv_files(directory, filename) + """ + dataframes = [] + id_column = "# id" + counter = 0 + + # Get all iteration directories sorted + iterations = sorted([d for d in Path(directory).iterdir() if d.is_dir()]) + + # Get all unique fold names across all iterations + all_folds = set() + for iteration in iterations: + for fold_dir in iteration.iterdir(): + if fold_dir.is_dir(): + all_folds.add(fold_dir.name) + + # Process fold by fold across all iterations + for fold_name in sorted(all_folds): + for iteration in iterations: + fold_path = iteration / fold_name + if fold_path.exists(): + for file_path in sorted(fold_path.rglob(filename)): + print(f"Loading file: {file_path}") + + # Load the CSV file + df = pl.read_csv(file_path) + + # Check original ID range + original_ids = df.select(pl.col(id_column)).to_numpy().flatten() + print(f" Original ID range: {original_ids.min()} - {original_ids.max()}") + + # Apply offset to prevent clashes + df = df.with_columns(pl.col(id_column) + counter * max_id) + + # Check modified ID range + modified_ids = df.select(pl.col(id_column)).to_numpy().flatten() + print(f" Modified ID range: {modified_ids.min()} - {modified_ids.max()}") + print(f" Counter: {counter}") + print() + dataframes.append(df) + counter += 1 + + # Concatenate all DataFrames vertically + combined_df = pl.concat(dataframes, how="vertical") + + return combined_df + +def aggregate_results(log_dir: Path, execution_time: timedelta = None, explain_features: bool = False): """Aggregates results from all folds and writes to JSON file. Args: @@ -106,7 +193,7 @@ def aggregate_results(log_dir: Path, execution_time: timedelta = None): execution_time: Overall execution time. """ aggregated = {} - shap_values_test = [] + explainer_values_test = [] for repetition in log_dir.iterdir(): if repetition.is_dir(): aggregated[repetition.name] = {} @@ -125,24 +212,27 @@ def aggregate_results(log_dir: Path, execution_time: timedelta = None): with open(fold_iter / "durations.json", "r") as f: result = json.load(f) aggregated[repetition.name][fold_iter.name].update(result) - if (fold_iter / "test_shap_values.parquet").is_file(): - shap_values_test.append(pl.read_parquet(fold_iter / "test_shap_values.parquet")) - - if shap_values_test: - shap_values = pl.concat(shap_values_test) - shap_values.write_parquet(log_dir / "aggregated_shap_values.parquet") + if (fold_iter / "explainer_values_test.parquet").is_file(): + explainer_values_test.append(pl.read_parquet(fold_iter / "explainer_values_test.parquet")) - try: - shap_values = pl.concat(shap_values_test) - shap_values.write_parquet(log_dir / "aggregated_shap_values.parquet") - except Exception as e: - logging.error(f"Error aggregating or writing SHAP values: {e}") + if explain_features: + if explainer_values_test: + shap_values = pl.concat(explainer_values_test) + shap_values.write_parquet(log_dir / "aggregated_explainer_values.parquet") + try: + shap_values = pl.concat(explainer_values_test) + shap_values.write_parquet(log_dir / "aggregated_explainer_values.parquet") + except Exception as e: + logging.error(f"Error aggregating or writing SHAP values: {e}") # Aggregate results per metric list_scores = {} for repetition, folds in aggregated.items(): for fold, result in folds.items(): for metric, score in result.items(): if isinstance(score, (float, int)): + if math.isnan(score): + logging.warning(f"Score for metric {metric} is NaN, adding 0 instead.") + score = 0 list_scores[metric] = list_scores.setdefault(metric, []) list_scores[metric].append(score) diff --git a/icu_benchmarks/tuning/hyperparameters.py b/icu_benchmarks/tuning/hyperparameters.py index cc169710..30b220b1 100644 --- a/icu_benchmarks/tuning/hyperparameters.py +++ b/icu_benchmarks/tuning/hyperparameters.py @@ -1,4 +1,6 @@ import json +import shutil + import gin import logging from logging import NOTSET @@ -83,7 +85,7 @@ def choose_and_bind_hyperparameters_scikit_optimize( configuration, evaluation = None, None if checkpoint: checkpoint_path = checkpoint / checkpoint_file - if not checkpoint_path.exists(): + if not checkpoint_path.isfile(): logging.warning(f"Hyperparameter checkpoint {checkpoint_path} does not exist.") logging.info("Attempting to find latest checkpoint file.") checkpoint_path = find_checkpoint(log_dir.parent, checkpoint_file) @@ -188,6 +190,7 @@ def choose_and_bind_hyperparameters_optuna( n_calls: int = 20, sampler=optuna.samplers.GPSampler, folds_to_tune_on: int = None, + repetitions_to_tune_on: int = 1, checkpoint_file: str = "hyperparameter_tuning_logs.db", generate_cache: bool = False, load_cache: bool = False, @@ -199,6 +202,7 @@ def choose_and_bind_hyperparameters_optuna( """Choose hyperparameters to tune and bind them to gin. Uses Optuna for hyperparameter optimization. Args: + repetitions_to_tune_on: Repetitions to tune on. If None, 1 repetitions are trained on. plot: Whether to plot hyperparameter importances. sampler: The sampler to use for hyperparameter optimization. wandb: Whether we use wandb or not. @@ -222,6 +226,9 @@ def choose_and_bind_hyperparameters_optuna( ValueError: If checkpoint is not None and the checkpoint does not exist. """ hyperparams = {} + if n_calls <= 0: + logging.info(f"Initialized with n_calls: {n_calls} , skipping tuning.") + return if len(scopes) == 0 or folds_to_tune_on is None: logging.warning("No scopes and/or folds to tune on, skipping tuning.") @@ -306,7 +313,7 @@ def bind_params_and_train(hyperparams): Path(temp_dir), seed, mode=run_mode, - cv_repetitions_to_train=1, + cv_repetitions_to_train=repetitions_to_tune_on, cv_folds_to_train=folds_to_tune_on, generate_cache=generate_cache, load_cache=load_cache, @@ -314,6 +321,7 @@ def bind_params_and_train(hyperparams): debug=debug, verbose=verbose, wandb=wandb, + explain_features=False, ) logging.info(f"Score: {score}") return score @@ -326,12 +334,17 @@ def bind_params_and_train(hyperparams): # Optuna study # Attempt checkpoint loading if checkpoint and checkpoint.exists(): - logging.warning(f"Hyperparameter checkpoint {checkpoint} does not exist.") + # logging.warning(f"Hyperparameter checkpoint {checkpoint} does not exist.") # logging.info("Attempting to find latest checkpoint file.") # checkpoint_path = find_checkpoint(log_dir.parent, checkpoint_file) - # Check if we found a checkpoint file - logging.info(f"Loading checkpoint at {checkpoint}") - study = optuna.load_study(study_name="tuning", storage="sqlite:///" + str(checkpoint), sampler=sampler, pruner=pruner) + # Check if we found a checkpoint file and copy it. + logging.info(f"Copying checkpoint and loading checkpoint at {checkpoint}") + local_path = log_dir / checkpoint_file + if not str(local_path).endswith(".db"): + local_path = local_path / "hyperparameter_tuning_logs.db" + logging.warning(f"Checkpoint file {checkpoint_file} does not end with .db, trying {local_path} instead.") + shutil.copy(str(checkpoint), local_path) + study = optuna.load_study(study_name="tuning", storage="sqlite:///" + str(local_path), sampler=sampler, pruner=pruner) n_calls = n_calls - len(study.trials) else: if checkpoint: diff --git a/icu_benchmarks/utils.py b/icu_benchmarks/utils.py new file mode 100644 index 00000000..60148d59 --- /dev/null +++ b/icu_benchmarks/utils.py @@ -0,0 +1,22 @@ +import argparse +import json + + +def parse_dict(arg): + """ + Parses a string into a dictionary. Handles both: + - Unquoted format: 'key1:value1,key2:value2' + - JSON-like quoted format: '"key1":"value1","key2":"value2"' + """ + try: + # Check if the input is in JSON-like format + if ":" in arg and '"' in arg: + # Wrap in curly braces to make it valid JSON + json_string = f"{{{arg}}}" + return json.loads(json_string) + else: + # Handle unquoted format + pairs = arg.split(",") + return {key.strip(): value.strip() for key, value in (pair.split(":", 1) for pair in pairs)} + except Exception as e: + raise argparse.ArgumentTypeError(f"Invalid dictionary format: {e}") diff --git a/icu_benchmarks/wandb_utils.py b/icu_benchmarks/wandb_utils.py index 2ea06b57..1e9559a7 100644 --- a/icu_benchmarks/wandb_utils.py +++ b/icu_benchmarks/wandb_utils.py @@ -4,6 +4,8 @@ import wandb +from .utils import parse_dict + def wandb_running() -> bool: """Check if wandb is running.""" @@ -30,7 +32,8 @@ def apply_wandb_sweep(args: Namespace) -> Namespace: Returns: Namespace: arguments with sweep configuration applied (some are applied via hyperparams) """ - wandb.init(allow_val_change=True, dir=args.log_dir) + wandb.init(allow_val_change=True, dir=args.log_dir, config={"allow_val_change": True}) + wandb.config.allow_val_change = True sweep_config = wandb.config args.__dict__.update(sweep_config) if args.hyperparams is None: @@ -72,6 +75,9 @@ def set_wandb_experiment_name(args, mode): run_name += f"_train_size_{args.samples}_samples" elif args.complete_train: run_name += "_complete_training" + elif args.file_names: + file_names = parse_dict(args.file_names) + run_name += f"_outcome_{file_names['OUTCOME'].removesuffix('.parquet')}" if wandb_running(): wandb.config.update({"run-name": run_name}) diff --git a/requirements.txt b/requirements.txt index 88ecc9af..35afd8d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ coverage==7.2.3 flake8>=7.0.0 matplotlib==3.7.1 gin-config==0.5.0 -pytorch-ignite==0.5.0.post2 +pytorch-ignite==0.5.1 # Note: versioning of Pytorch might be dependent on compatible CUDA version. # Please check yourself if your Pytorch installation supports cuda (for gpu acceleration) torch==2.6.0+cu118 @@ -24,12 +24,13 @@ tensorboard==2.12.2 tqdm==4.66.3 einops==0.6.1 hydra-core==1.3 -optuna==4.0.0 +optuna==4.1.0 optuna-integration==4.0.0 -wandb==0.17.3 +wandb==0.18.5 recipies==1.0 #Fixed version because of NumPy incompatibility and stale development status. scikit-optimize-fix==0.9.1 hydra-submitit-launcher==1.2.0 pytest-runner==6.0.1 +shap==0.46.0 diff --git a/setup.py b/setup.py index 9bcb9890..48d6e0b2 100644 --- a/setup.py +++ b/setup.py @@ -77,6 +77,6 @@ def parse_environment_yml(): test_suite="tests", tests_require=[], url="https://github.com/rvandewater/YAIB", - version="0.3.1", + version="1.1.0", zip_safe=False, )