
Conversation


@rvandewater rvandewater commented Feb 27, 2025

Summary by CodeRabbit

  • New Features

    • HPC benchmarking configs & SLURM scripts; GPU XGBoost option; alarm-style prediction conversion; CLI dict-style dataset/file parsing; toggleable feature-explanation with explainer-value saving; per-sample prediction indicators export; calibration APIs for model probabilities.
  • Enhancements

    • Expanded hyperparameter search spaces and tuning repetitions; improved preprocessing (modality selection, infinite-value handling, row-indicators); new classification metrics (sensitivity, specificity, PPV, incidence); aggregation of fold-wise predictions.
  • Documentation

    • Added full Sphinx docs, API reference, and how-to guides.
  • Chores

    • Dependency/version bumps, package version update, removed CI workflow.


coderabbitai bot commented Feb 27, 2025

Walkthrough

Expanded hyperparameter grids and GPU configs; added explainability export/persistence, alarm-metrics, modality selection and infinite-value sanitation, per-stay row-indicators, calibration & CSV exports, CLI/W&B/tuning enhancements, many experiment and SLURM scripts, Sphinx docs, dependency bumps, and package version update.

Changes

• Configs: Model hyperparameters
  Files: configs/prediction_models/XGBClassifier.gin, configs/prediction_models/LGBMClassifier.gin, configs/prediction_models/XGBClassifierGPU.gin, configs/prediction_models/common/MLTuning.gin
  Summary: Expanded XGB/LGBM hyperparameter grids, added GPU XGB config, increased tuning calls (n_calls) and added repetitions_to_tune_on, adjusted folds defaults.
• Configs: Tasks & dataloader
  Files: configs/tasks/BinaryClassification.gin, configs/tasks/common/Dataloader.gin
  Summary: Declared required segments and file_names for BinaryClassification; enabled PredictionPolarsDataset.mps = True and ImputationPandasDataset.ram_cache = True.
• Experiments: YAMLs
  Files: experiments/*.yml (multiple new files)
  Summary: Added many benchmark experiment configs (datasets, models, modalities, horizons, explainability, tuning checkpoints).
• Experiments: SLURM scripts
  Files: experiments/slurm_base_char_sc.sh, experiments/charhpc_wandb_sweep_cpu.sh, experiments/charhpc_wandb_sweep_gpu.sh
  Summary: New SLURM helper scripts for CPU/GPU WandB sweeps and base submissions (conda init/activate, wandb agent invocation, resource directives).
• Data pipeline: utils & sanitation
  Files: icu_benchmarks/data/utils.py, icu_benchmarks/data/split_process_data.py
  Summary: New utilities: infinite_removal, check_sanitize_data, modality_selection; preprocess_data gains exclude_preproc, returns (data, vars); added modality filtering, infinite-value handling, stay-reduction utility, and enriched logging.
• Data pipeline: loader & preprocessor
  Files: icu_benchmarks/data/loader.py, icu_benchmarks/data/preprocessor.py
  Summary: Conditional SEQUENCE validation; reworked row-indicator handling; PredictionPolarsDataset now returns per-stay row_inds; added Preprocessor.vars_selection and selector-based scaling.
• Model wrappers & training flow
  Files: icu_benchmarks/models/wrappers.py, icu_benchmarks/models/train.py, icu_benchmarks/models/utils.py
  Summary: Explain-features API (set_explain_features), calibration support, persist explainer values (averaged/full), per-sample CSV exports and metric-curve logging; updated train_common and persist_shap_data signatures.
• Models: XGBoost / LGBM / imbalanced
  Files: icu_benchmarks/models/ml_models/xgboost.py, .../lgbm.py, .../imblearn.py
  Summary: XGBoost: unified SHAP explainer, added XGBClassifierGPU using xgboost native API; LGBM: suppressed logs, added SHAP explainer; imblearn: replaced RUSBoost with EasyEnsemble wrapping an XGB estimator and added fit_model.
• Metrics & alarm utilities
  Files: icu_benchmarks/models/constants.py, .../custom_metrics.py, .../alarm_metrics.py
  Summary: Added sensitivity, specificity, PPV, binary incidence metrics; refactored MAE compute; new alarm utilities (convert_to_alarm, silence_positives, fill_gaps).
• Run/CLI, aggregation & W&B
  Files: icu_benchmarks/run.py, icu_benchmarks/run_utils.py, icu_benchmarks/wandb_utils.py, icu_benchmarks/utils.py, icu_benchmarks/cross_validation.py, icu_benchmarks/tuning/hyperparameters.py
  Summary: New CLI args (--file_names, --explain_features, --load_data_vars); parse_dict added; optional include of data vars.gin; execute_repeated_cv gains explain_features and validation; aggregation handles explainer values; Optuna tuning adds repetitions_to_tune_on, early-exit for n_calls<=0, and copies checkpoints locally before loading; append_predictions_foldwise added.
• Model metrics persistence helper
  Files: icu_benchmarks/models/utils.py
  Summary: Added log_single_metric_to_file to persist tuple-array metrics to CSV; capitalized a logger message.
• Dependencies & versioning
  Files: requirements.txt, setup.py
  Summary: Bumped pytorch-ignite, optuna, wandb; added shap==0.46.0; package version bumped to 1.1.0.
• Documentation scaffold
  Files: docs/*
  Summary: Added Sphinx docs: Makefiles, requirements, conf, API and multiple guides (adding models, imputation, development notes).
• CI
  Files: .github/workflows/ci.yml
  Summary: Removed Ruff CI workflow (file deleted).

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant CLI as User/CLI
  participant Parser as build_parser
  participant Run as run.py
  participant CV as execute_repeated_cv
  participant Trainer as train_common
  participant Wrapper as model wrappers
  participant FS as Filesystem
  CLI->>Parser: parse args (--file_names, --explain_features, --load_data_vars)
  CLI->>Run: start training
  Run->>Run: optionally include data vars.gin (if --load_data_vars)
  Run->>Run: bind preprocess.file_names (if provided)
  Run->>CV: execute_repeated_cv(..., explain_features)
  CV->>Trainer: train_common(..., explain_features)
  Trainer->>Wrapper: set_explain_features(flag)
  loop per fold/repetition
    Wrapper->>Wrapper: train/test, compute metrics (calibrated/uncalibrated)
    alt explain_features enabled
      Wrapper->>FS: save explainer_values_train/test.parquet (averaged and full)
    end
    Wrapper->>FS: save predictions CSV and pred_indicators.csv
  end
  Run->>Parser: aggregate_results(explain_features)
  Parser->>FS: read/write aggregated metrics and explainer values

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Poem

I hop through configs with a tiny clap,
I stitch new metrics and a SHAP-filled map.
GPUs hum, CSVs stack in a nap,
Data cleaned, alarms set — I give a tap-tap.
YAIB grows wider; the rabbit takes a lap. 🐇✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
• Docstring Coverage (⚠️ Warning): Docstring coverage is 70.59%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
• Title Check (❓ Inconclusive): The title "Improvements from Cass project" is vague and does not convey what the changeset actually does. While the PR does contain multiple changes related to a "Cass project" (evidenced by the numerous benchmark_cass*.yml files and related configurations), the term "Improvements" is too generic: a teammate scanning the commit history would not understand whether the changes involve new models, features, refactoring, bug fixes, or performance optimizations. Resolution: consider a more specific and descriptive title, for example "Add Cass classification benchmarks and model improvements", "Integrate Cass project enhancements: calibration, metrics, and data processing", or "Add GPU support and calibration features from Cass project". This would help reviewers and future readers quickly understand the scope and intent of the PR.
✅ Passed checks (1 passed)
• Description Check (✅ Passed): Check skipped because CodeRabbit's high-level summary is enabled.
✨ Finishing touches
  • 📝 Generate docstrings
  • 🧪 Generate unit tests (beta)
    • Create PR with unit tests
    • Post copyable unit tests in a comment
    • Commit unit tests in branch cass

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 17

🔭 Outside diff range comments (2)
icu_benchmarks/models/custom_metrics.py (1)

66-74: ⚠️ Potential issue

Fix reference to moved function mae_with_invert_compute_fn.

The pipeline is failing because the function mae_with_invert_compute_fn has been moved outside the class, but it's still being referenced in the constructor on line 66. This function should either be moved back inside the class as a method, or it needs to be made accessible at the module level.

For example, ensure the helper is defined (or imported) at module level in custom_metrics.py, above the MAE class:

def mae_with_invert_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor, invert_fn=Callable) -> float:
    y_true = invert_fn(y_targets.numpy().reshape(-1, 1))[:, 0]
    y_pred = invert_fn(y_preds.numpy().reshape(-1, 1))[:, 0]
    return mean_absolute_error(y_true, y_pred)

with the constructor left unchanged:

def __init__(
    self,
    output_transform: Callable = lambda x: x,
    check_compute_fn: bool = False,
    invert_transform: Callable = lambda x: x,
) -> None:
    super(MAE, self).__init__(
        lambda x, y: mae_with_invert_compute_fn(x, y, invert_transform),
        output_transform=output_transform,
        check_compute_fn=check_compute_fn,
    )

Alternatively, move the function back inside the class:

def __init__(
    self,
    output_transform: Callable = lambda x: x,
    check_compute_fn: bool = False,
    invert_transform: Callable = lambda x: x,
) -> None:
    super(MAE, self).__init__(
-       lambda x, y: mae_with_invert_compute_fn(x, y, invert_transform),
+       lambda x, y: self.mae_with_invert_compute_fn(x, y, invert_transform),
        output_transform=output_transform,
        check_compute_fn=check_compute_fn,
    )

+    def mae_with_invert_compute_fn(self, y_preds: torch.Tensor, y_targets: torch.Tensor, invert_fn=Callable) -> float:
+        y_true = invert_fn(y_targets.numpy().reshape(-1, 1))[:, 0]
+        y_pred = invert_fn(y_preds.numpy().reshape(-1, 1))[:, 0]
+        return mean_absolute_error(y_true, y_pred)

-def mae_with_invert_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor, invert_fn=Callable) -> float:
-   y_true = invert_fn(y_targets.numpy().reshape(-1, 1))[:, 0]
-   y_pred = invert_fn(y_preds.numpy().reshape(-1, 1))[:, 0]
-   return mean_absolute_error(y_true, y_pred)
🧰 Tools
🪛 Ruff (0.8.2)

66-66: Undefined name mae_with_invert_compute_fn

(F821)

🪛 GitHub Actions: CI

[error] 66-66: Flake8: undefined name 'mae_with_invert_compute_fn'.

icu_benchmarks/data/split_process_data.py (1)

315-315: ⚠️ Potential issue

Reconciling train_size parameter type.

train_size is assigned a float (e.g., 0.80) yet is annotated as int in the signature. This mismatch can lead to confusion or TypeErrors. Clarify whether it's intended to be a float ratio or an absolute count.

Also applies to: 322-322
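
If the ratio interpretation is the intended one, a minimal illustrative sketch (the function name and body below are assumptions, not the actual preprocess_data code):

def absolute_train_count(n_stays: int, train_size: float = 0.80) -> int:
    # Assumption: train_size is a fraction in (0, 1]; convert it to the absolute
    # number of stays where an integer count is needed downstream.
    if not 0.0 < train_size <= 1.0:
        raise ValueError("train_size is interpreted as a ratio of stays")
    return int(round(n_stays * train_size))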

🧹 Nitpick comments (24)
icu_benchmarks/run.py (1)

187-187: Feature explanation enabled by default

Setting explain_features=True enables SHAP-based feature explanation during the cross-validation process, which enhances model interpretability but may increase computational overhead.

Consider making this parameter configurable through command-line arguments rather than hardcoding it to True. This would allow users to disable feature explanation when prioritizing performance or when explanations aren't needed.

- explain_features=True,
+ explain_features=args.explain_features,

Then add this argument to the parser in the build_parser function:

parser.add_argument("--explain-features", action="store_true", default=False, 
                    help="Enable feature explanation during model training")
experiments/charhpc_wandb_sweep.sh (1)

1-14: Add error handling and usage documentation to the SLURM script.

This script provides a good framework for running W&B sweeps on an HPC cluster with GPU resources. However, it lacks error handling and usage documentation that would make it more robust.

Consider adding:

  1. Parameter validation to check if the sweep ID is provided
  2. A usage comment at the top explaining the purpose and how to run the script
  3. Error handling for common failures
 #!/bin/bash
+# Usage: sbatch charhpc_wandb_sweep.sh <sweep_id>
+# Purpose: Run a W&B sweep on the pgpu partition with GPU access
+
 #SBATCH --job-name=yaib_experiment
 #SBATCH --partition=pgpu # -p
 #SBATCH --cpus-per-task=8 # -c
 #SBATCH --mem=200gb
 #SBATCH --output=logs/classification_%a_%j.log # %j is job id
 #SBATCH --gpus=1
 #SBATCH --time=24:00:00

 source /etc/profile.d/conda.sh

 eval "$(conda shell.bash hook)"
 conda activate yaib_req_pl
+
+if [ -z "$1" ]; then
+  echo "Error: No sweep ID provided. Usage: sbatch charhpc_wandb_sweep.sh <sweep_id>"
+  exit 1
+fi
+
 wandb agent --count 1 cassandra_hpi/cassandra/"$1"
+
experiments/charhpc_wandb_sweep_cpu.sh (1)

1-16: Add error handling to CPU-based SLURM script and standardize with GPU version.

This CPU-focused script is well-structured with appropriate resource allocations and includes a helpful debug command. However, it lacks error handling similar to the GPU version.

Consider:

  1. Adding parameter validation and error handling
  2. Including usage documentation
  3. Standardizing the structure with the GPU version for consistency
 #!/bin/bash
+# Usage: sbatch charhpc_wandb_sweep_cpu.sh <sweep_id>
+# Purpose: Run a W&B sweep on the compute partition with CPU resources
+
 #SBATCH --job-name=yaib_experiment
 #SBATCH --partition=compute # -p
 #SBATCH --cpus-per-task=16 # -c
 #SBATCH --mem=250gb
 #SBATCH --output=logs/classification_%a_%j.log # %j is job id
 #SBATCH --time=48:00:00


 source /etc/profile.d/conda.sh

 eval "$(conda shell.bash hook)"
 conda activate yaib_req_pl
+
+if [ -z "$1" ]; then
+  echo "Error: No sweep ID provided. Usage: sbatch charhpc_wandb_sweep_cpu.sh <sweep_id>"
+  exit 1
+fi
+
 wandb agent --count 1 cassandra_hpi/cassandra/"$1"

 # Debug instance: srun -p gpu --pty -t 5:00:00 --gres=gpu:1 --cpus-per-task=16 --mem=100GB bash
icu_benchmarks/models/train.py (1)

216-216: Remove unnecessary f-string prefix.

The string doesn't contain any placeholders, so the f prefix is redundant.

-        logging.info(f"Saving SHAPS")
+        logging.info("Saving SHAPS")
🧰 Tools
🪛 Ruff (0.8.2)

216-216: f-string without any placeholders

Remove extraneous f prefix

(F541)

icu_benchmarks/models/ml_models/xgboost.py (3)

6-6: Remove unused import.

The numpy package is imported but not used in the code.

-import numpy as np
🧰 Tools
🪛 Ruff (0.8.2)

6-6: numpy imported but unused

Remove unused import: numpy

(F401)


51-56: Consider making explain_features configurable.

The feature explanation is currently hardcoded to True. Consider making this configurable via a parameter to the constructor or through a gin configuration option.

- self.explain_features = True
+ self.explain_features = kwargs.get('explain_features', True)

83-86: Clean up commented-out code.

This commented-out method appears to be obsolete since it's been replaced by _explain_model. Consider removing it if it's no longer needed.

-# def explainer(self, reps):
-#     if not hasattr(self.model, "feature_importances_"):
-#         raise ValueError("Model has not been fit yet. Call fit_model() before getting feature importances.")
-#     return self.model.feature_importances_
icu_benchmarks/run_utils.py (1)

81-81: Explicitly convert random integer to string.

Concatenating an integer onto the path name will raise a TypeError unless it is converted to a string first, so make the conversion explicit.

- log_dir_run = log_dir_run.with_name(log_dir_run.name + random.randint(1, 10))
+ log_dir_run = log_dir_run.with_name(log_dir_run.name + str(random.randint(1, 10)))
icu_benchmarks/data/preprocessor.py (3)

185-185: Remove unused variable old_columns.

The variable old_columns is assigned but never used in the code.

-   old_columns = data[Split.train][Segment.dynamic].columns
🧰 Tools
🪛 Ruff (0.8.2)

185-185: Local variable old_columns is assigned to but never used

Remove assignment to unused variable old_columns

(F841)


193-195: Consider uncommenting logging statements.

These commented-out logging statements could provide useful debug information. Consider either fully removing them or uncommenting them if they're helpful for debugging.

- # logging.info(f"Excluding {len(self.vars_to_exclude)} : {self.vars_to_exclude}")
+ logging.debug(f"Excluding {len(self.vars_to_exclude)} : {self.vars_to_exclude}")
- # logging.info(f"Applying to {len(vars_to_apply)}")
+ logging.debug(f"Applying to {len(vars_to_apply)}")

206-206: Consider uncommenting or removing logging statement.

This commented-out logging statement regarding column changes could be useful for debugging. Consider either fully removing it or uncommenting it at a debug level.

- # logging.info(f"Data columns: {len(data[Split.train][Segment.dynamic].columns)} -> old columns: {len(old_columns)}, added columns: {set(data[Split.train][Segment.dynamic].columns) - set(old_columns)}")
+ logging.debug(f"Data columns: {len(data[Split.train][Segment.dynamic].columns)}, added columns: {set(data[Split.train][Segment.dynamic].columns) - set(old_columns)}")
icu_benchmarks/models/custom_metrics.py (2)

1-1: Remove unused import.

The numpy package is imported but not used in the code.

-import numpy
🧰 Tools
🪛 Ruff (0.8.2)

1-1: numpy imported but unused

Remove unused import: numpy

(F401)


191-209: Consider removing commented-out code.

This large block of commented-out code appears to be an alternative implementation of the Specificity class. If it's no longer needed, consider removing it to improve code readability. If it's being kept for future reference, add a comment explaining why.

icu_benchmarks/models/alarm_metrics.py (2)

13-29: Question the index-based slicing usage.

Inside convert_to_alarm, it's assumed that silence_positives and fill_gaps operate exclusively on integer-indexed arrays. If you intend to handle time-indexed data, consider a more robust approach (e.g., aligning predictions by timestamps).


55-72: Clarify the purpose of filling gaps.

  1. The function fill_gaps currently sets all subsequent values to 1 once it encounters a single 1. Ensure this behavior aligns with your domain needs.
  2. The function modifies both predictions and ground_truth, which can lead to confusion. Consider returning a separate updated copy or clarifying in docstrings.
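
If returning a copy is preferred (point 2), a minimal sketch of a non-mutating variant, under the assumption that predictions is a 1-D NumPy array of 0/1 alarms (this is illustrative only and deliberately leaves ground_truth untouched; it is not the actual fill_gaps implementation):

import numpy as np

def fill_gaps_copy(predictions: np.ndarray) -> np.ndarray:
    # Return a copy in which every value from the first positive onward is set
    # to 1, mirroring the "fill forward" behaviour described above.
    filled = predictions.copy()
    positives = np.flatnonzero(filled == 1)
    if positives.size > 0:
        filled[positives[0]:] = 1
    return filled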
icu_benchmarks/data/split_process_data.py (2)

138-138: Confirm data sanitization flow.

check_sanitize_data potentially mutates both data and vars. Ensure all necessary post-processing steps occur after this call so that your data transformations remain consistent throughout.


195-207: Ensure consistent handling of infinite values.

  1. max_float64 = 0 might not be representative of true maximum values. Consider retrieving np.finfo(np.float64).max or another large sentinel.
  2. The loop only replaces infinite values in columns with pl.Float64 dtype; ensure columns of other numeric types are also handled if necessary.
- max_float64 = 0
+ import numpy as np
+ max_float64 = np.finfo(np.float64).max
experiments/benchmark_cass.yml (1)

1-84: Add a newline at end of file.

YAML best practices include ending files with a newline to avoid lint warnings and ensure consistent merging in version control.

🧰 Tools
🪛 YAMLlint (1.35.1)

[error] 84-84: no new line character at the end of file

(new-line-at-end-of-file)

icu_benchmarks/data/utils.py (3)

9-20: Robust handling of infinite values.

The infinite_removal function effectively identifies and replaces infinite values in numeric columns with a safe maximum value. This prevents downstream issues that could occur with infinity values in models and statistical operations.

Consider using a more descriptive parameter name (for example, rename val to df) and adding type hints:

-def infinite_removal(val):
-    for col in val.select(cs.numeric()).columns:
-        if val[col].is_infinite().any():
-            logging.info(f"Column '{col}' contains infinite values. Datatype: {val[col].dtype}")
+def infinite_removal(df: pl.DataFrame) -> pl.DataFrame:
+    for col in df.select(cs.numeric()).columns:
+        if df[col].is_infinite().any():
+            logging.info(f"Column '{col}' contains infinite values. Datatype: {df[col].dtype}")

23-48: Comprehensive data sanitization function.

The check_sanitize_data function systematically removes duplicates from different data segments while preserving order, which is critical for time series data. It also handles infinite values and maintains the variable list integrity.

The function uses .keys() repeatedly, which static analysis flags as inefficient. Simplify these checks:

-    group = vars[Var.group] if Var.group in vars.keys() else None
-    sequence = vars[Var.sequence] if Var.sequence in vars.keys() else None
+    group = vars[Var.group] if Var.group in vars else None
+    sequence = vars[Var.sequence] if Var.sequence in vars else None

Also apply the same change to the other key in dict.keys() occurrences flagged by Ruff.

🧰 Tools
🪛 Ruff (0.8.2)

25-25: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


26-26: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


29-29: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


33-33: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


40-40: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


51-77: Well-structured modality selection implementation.

The modality_selection function provides a flexible way to filter data based on modalities, with proper validation, error handling, and detailed logging. It updates both the data dictionary and variable references consistently.

Consider adding type hints for vars parameter and return values, and use dictionary comprehensions for cleaner code:

def modality_selection(
-    data: dict[pl.DataFrame], modality_mapping: dict[str], selected_modalities: list[str], vars
+    data: dict[pl.DataFrame], modality_mapping: dict[str], selected_modalities: list[str], vars: dict
) -> tuple[dict[pl.DataFrame], dict]:

Also, replace the for loop in lines 73-76 with a dictionary comprehension:

data = {
    key: data[key].select([col for col in data[key].columns if col in selected_columns])
    for key in data
}
🧰 Tools
🪛 Ruff (0.8.2)

55-55: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


56-56: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


73-73: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

icu_benchmarks/models/wrappers.py (3)

6-6: Remove unused import.

The ignite.metrics.EpochMetric is imported but never used in the code.

-from ignite.metrics import EpochMetric
🧰 Tools
🪛 Ruff (0.8.2)

6-6: ignite.metrics.EpochMetric imported but unused

Remove unused import: ignite.metrics.EpochMetric

(F401)


482-519: Extended log_metrics to support row indicators.

The log_metrics method now accepts a pred_indicators parameter and handles different cases based on its presence and shape. This enables more detailed analysis of predictions at specific time points.

There is significant code duplication between the two branches (with and without pred_indicators). Consider refactoring:

def log_metrics(self, label, pred, metric_type, pred_indicators=None):
    """Log metrics to the PL logs."""
    if "Confusion_Matrix" in self.metrics:
        self.log_dict(confusion_matrix(self.label_transform(label), self.output_transform(pred)), sync_dist=True)
        
    self.log_dict(
        {
            f"{metric_type}/{name}": (metric(self.label_transform(label), self.output_transform(pred)))
            # For every metric
            for name, metric in self.metrics.items()
            # Filter out metrics that return a tuple (e.g. precision_recall_curve)
            if not isinstance(metric(self.label_transform(label), self.output_transform(pred)), tuple)
            and name != "Confusion_Matrix"
        },
        sync_dist=True,
    )
    
    # Handle pred_indicators if present
    if pred_indicators is not None:
        if (len(pred_indicators.shape) > 1 and len(pred.shape) > 1 and pred_indicators.shape[1] == pred.shape[1]
                and pred_indicators.shape[0] == pred.shape[0]):
            pred_indicators = np.hstack((pred_indicators, label.reshape(-1, 1)))
            pred_indicators = np.hstack((pred_indicators, pred))
            # TODO: Implement alarm metrics using row indicators

527-567: Comprehensive model output saving with improved formatting.

The _save_model_outputs method now handles different shapes of prediction indicators and includes proper headers in the saved CSV files. This significantly improves the usability of saved outputs for downstream analysis.

Remove unnecessary f-string prefixes from string literals that don't contain placeholders:

-            logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / f'row_indicators.csv'}")
+            logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / 'row_indicators.csv'}")

Apply similar fixes to lines 554 and 563.

🧰 Tools
🪛 Ruff (0.8.2)

538-538: f-string without any placeholders

Remove extraneous f prefix

(F541)


554-554: f-string without any placeholders

Remove extraneous f prefix

(F541)


563-563: f-string without any placeholders

Remove extraneous f prefix

(F541)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0bd9791 and 0afa916.

⛔ Files ignored due to path filters (1)
  • demo_data/mortality24/mimic_demo_static/attrition.csv is excluded by !**/*.csv
📒 Files selected for processing (23)
  • configs/prediction_models/XGBClassifier.gin (1 hunks)
  • configs/prediction_models/common/MLTuning.gin (1 hunks)
  • configs/tasks/common/Dataloader.gin (1 hunks)
  • experiments/benchmark_cass.yml (1 hunks)
  • experiments/charhpc_wandb_sweep.sh (1 hunks)
  • experiments/charhpc_wandb_sweep_cpu.sh (1 hunks)
  • experiments/slurm_base_char_sc.sh (1 hunks)
  • icu_benchmarks/cross_validation.py (3 hunks)
  • icu_benchmarks/data/loader.py (3 hunks)
  • icu_benchmarks/data/preprocessor.py (4 hunks)
  • icu_benchmarks/data/split_process_data.py (11 hunks)
  • icu_benchmarks/data/utils.py (1 hunks)
  • icu_benchmarks/models/alarm_metrics.py (1 hunks)
  • icu_benchmarks/models/constants.py (2 hunks)
  • icu_benchmarks/models/custom_metrics.py (4 hunks)
  • icu_benchmarks/models/ml_models/xgboost.py (4 hunks)
  • icu_benchmarks/models/train.py (4 hunks)
  • icu_benchmarks/models/wrappers.py (8 hunks)
  • icu_benchmarks/run.py (1 hunks)
  • icu_benchmarks/run_utils.py (4 hunks)
  • icu_benchmarks/tuning/hyperparameters.py (2 hunks)
  • requirements.txt (2 hunks)
  • setup.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
icu_benchmarks/data/preprocessor.py (1)
Learnt from: rvandewater
PR: rvandewater/YAIB#155
File: icu_benchmarks/data/preprocessor.py:77-143
Timestamp: 2024-11-12T07:39:38.994Z
Learning: In the `apply` method of `PolarsClassificationPreprocessor` in `icu_benchmarks/data/preprocessor.py`, the data is pre-split, necessitating multiple join operations; thus, optimizing by combining joins is not feasible.
🪛 Shellcheck (0.10.0)
experiments/slurm_base_char_sc.sh

[warning] 18-18: Remove space after = if trying to assign a value (for empty string, use var='' ... ).

(SC1007)


[warning] 28-28: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.

(SC2164)

🪛 Ruff (0.8.2)
icu_benchmarks/models/ml_models/xgboost.py

6-6: numpy imported but unused

Remove unused import: numpy

(F401)


78-78: Local variable feature_importances is assigned to but never used

Remove assignment to unused variable feature_importances

(F841)

icu_benchmarks/models/alarm_metrics.py

2-2: numpy.ndarray imported but unused

Remove unused import: numpy.ndarray

(F401)


3-3: torch imported but unused

Remove unused import: torch

(F401)


4-4: sklearn.metrics.precision_score imported but unused

Remove unused import: sklearn.metrics.precision_score

(F401)

icu_benchmarks/data/split_process_data.py

4-4: numpy imported but unused

Remove unused import: numpy

(F401)


45-45: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)

icu_benchmarks/data/preprocessor.py

185-185: Local variable old_columns is assigned to but never used

Remove assignment to unused variable old_columns

(F841)

icu_benchmarks/models/custom_metrics.py

1-1: numpy imported but unused

Remove unused import: numpy

(F401)


165-165: Redefinition of unused Specificity from line 10

(F811)

icu_benchmarks/models/train.py

216-216: f-string without any placeholders

Remove extraneous f prefix

(F541)

icu_benchmarks/data/utils.py

25-25: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


26-26: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


29-29: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


33-33: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


40-40: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


55-55: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


56-56: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


73-73: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

icu_benchmarks/models/wrappers.py

6-6: ignite.metrics.EpochMetric imported but unused

Remove unused import: ignite.metrics.EpochMetric

(F401)


538-538: f-string without any placeholders

Remove extraneous f prefix

(F541)


554-554: f-string without any placeholders

Remove extraneous f prefix

(F541)


563-563: f-string without any placeholders

Remove extraneous f prefix

(F541)

🪛 GitHub Actions: CI
icu_benchmarks/models/custom_metrics.py

[error] 66-66: Flake8: undefined name 'mae_with_invert_compute_fn'.

🪛 YAMLlint (1.35.1)
experiments/benchmark_cass.yml

[error] 84-84: no new line character at the end of file

(new-line-at-end-of-file)

🔇 Additional comments (33)
setup.py (1)

79-79: Significant version bump from 0.3.1 to 1.1.0

The version has been increased from 0.3.1 to 1.1.0, which indicates a major update according to semantic versioning principles. This aligns with the extensive modifications described in the PR summary, including new features like SHAP-based explanations, refined data processing, and extended XGBoost capabilities.

configs/prediction_models/common/MLTuning.gin (1)

4-4: Hyperparameter tuning has been disabled

Setting tune_hyperparameters.n_calls to 0 effectively disables the hyperparameter tuning process. While this may speed up the training process, it means models will use predefined hyperparameters instead of finding optimal ones.

Is this intentional? Consider documenting the reason for disabling tuning, especially if you're relying on manually optimized parameters from previous experiments.

configs/tasks/common/Dataloader.gin (1)

6-6: New MPS parameter enabled for PredictionPolarsDataset

The addition of PredictionPolarsDataset.mps = True enables a new behavior in the data loading process, affecting how data is processed in the to_tensor method.

Please consider documenting what the mps parameter represents in this context and how it impacts data processing. Is this related to Metal Performance Shaders for GPU acceleration on Apple devices, or does it have a different meaning in this codebase?

requirements.txt (1)

7-7:


✅ Verification successful

Verify compatibility of updated dependencies.

The updates to dependency versions and the addition of SHAP are meaningful improvements. The SHAP library (v0.46.0) enables model explainability, which is valuable for understanding feature importance.

Let's check if these versions are current and compatible:

Also applies to: 27-27, 29-29, 35-35


🏁 Script executed:

#!/bin/bash
pip index versions pytorch-ignite | head -n 5
pip index versions optuna | head -n 5
pip index versions wandb | head -n 5
pip index versions shap | head -n 5

# Check for any known compatibility issues between these packages
pip check pytorch-ignite==0.5.1 optuna==4.1.0 wandb==0.18.5 shap==0.46.0

Length of output: 4446


Dependency compatibility verified – no issues found.

The dependency versions stated in the requirements file (including the updated pytorch-ignite==0.5.1 and the addition of shap==0.46.0) have been checked. The output from pip check confirms that there are no compatibility issues with the specified versions. This applies to all indicated lines (7, 27, 29, and 35).

icu_benchmarks/models/constants.py (1)

30-30: Great addition of clinically relevant metrics.

Adding sensitivity, specificity, and positive predictive value (PPV) metrics is valuable for healthcare applications. These metrics are essential for evaluating models where false positives and false negatives have different clinical implications.

The additions properly complement the existing metrics and follow the established structure. Well done!

Also applies to: 42-46
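
For context, these quantities all derive from the binary confusion matrix; a minimal self-contained sketch (illustrative only, without zero-division guards; the actual implementations in custom_metrics.py may differ in interface and edge-case handling):

import numpy as np
from sklearn.metrics import confusion_matrix

def binary_rates(y_true: np.ndarray, y_pred: np.ndarray) -> dict:
    # Confusion-matrix layout for labels [0, 1]: [[TN, FP], [FN, TP]]
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    return {
        "sensitivity": tp / (tp + fn),                 # true positive rate / recall
        "specificity": tn / (tn + fp),                 # true negative rate
        "ppv": tp / (tp + fp),                         # positive predictive value / precision
        "incidence": (tp + fn) / (tn + fp + fn + tp),  # prevalence of the positive class
    }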

icu_benchmarks/cross_validation.py (3)

41-41: Good addition of feature explanation capability.

The new explain_features parameter provides a clean way to control feature explanation during model training, following the same style as other parameters in this function.


76-78: Excellent defensive programming check.

This validation prevents a common configuration error by catching cases where cv_folds_to_train exceeds cv_folds, which would lead to unexpected behavior or errors later in the execution.


126-126: LGTM - Correctly passes the explain_features parameter.

The parameter is correctly passed to the train_common function, ensuring the feature explanation capability is properly propagated to the underlying training process.

icu_benchmarks/models/train.py (3)

56-56: LGTM - Parameter addition is consistent with cross_validation.py changes.

The explain_features parameter addition maintains a consistent interface between the cross-validation and training modules.


177-178: Good implementation of the explainer configuration.

The explainer is appropriately set before testing begins, ensuring that any test-time feature explanations can be generated if requested.


196-197: LGTM - Conditional persistence of feature explanations.

The function correctly persists SHAP data only when feature explanations are requested, avoiding unnecessary computation and storage.
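
For readers less familiar with SHAP, this is roughly what the persisted values look like for a tree model. The snippet below is a stand-alone illustration on synthetic data; the wrapper's actual persist_shap_data signature and storage format are defined in the YAIB code, not here:

import numpy as np
import shap
import xgboost as xgb

rng = np.random.default_rng(0)
X = rng.random((200, 5))               # synthetic features, purely for illustration
y = (X[:, 0] > 0.5).astype(int)

model = xgb.XGBClassifier(n_estimators=20).fit(X, y)
explanation = shap.TreeExplainer(model)(X)   # shap.Explanation object
per_sample = explanation.values              # (n_samples, n_features) SHAP values ("full")
mean_abs = np.abs(per_sample).mean(axis=0)   # per-feature importance ("averaged")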

icu_benchmarks/tuning/hyperparameters.py (2)

86-86: Good improvement in file existence check.

Changing from exists() to isfile() makes the check more precise, ensuring that the checkpoint path specifically refers to a file rather than any filesystem object.


225-227: Excellent safeguard for skipping unnecessary tuning.

This early return when n_calls <= 0 prevents unnecessary execution of hyperparameter tuning, which is particularly useful when tuning is intentionally disabled through configuration.

icu_benchmarks/models/ml_models/xgboost.py (1)

28-28: LGTM: Added missing parameter settings to XGBoost model.

The addition of eval_metric=log_loss, device="cpu", missing='inf' parameters ensures more consistent model behavior.

icu_benchmarks/run_utils.py (3)

79-79: LGTM: Added microseconds to timestamp format.

Adding microseconds to the timestamp format makes the directory names more unique.
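
As an illustration of the effect (the exact format string used in run_utils.py is not shown in this diff, so the one below is an assumption):

from datetime import datetime

# "%f" appends microseconds, making collisions between runs started in the
# same second far less likely.
run_dir_name = datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")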


109-109: LGTM: Renamed variable to match new terminology.

Renaming shap_values_test to explainer_values_test ensures consistency with the naming changes in the rest of the codebase.


128-133: LGTM: Updated file paths for explainer values.

Consistently renamed file paths from test_shap_values.parquet to explainer_values_test.parquet to match the variable name changes.

icu_benchmarks/data/preprocessor.py (2)

131-141: LGTM: Improved conditional logic for sequence validation.

Restructuring the sequence check to first verify if "SEQUENCE" exists in vars makes the code more explicit and easier to understand.


188-188: LGTM: Added backend parameter to StepScale.

Specifying the backend for the StepScale operation ensures proper integration with the polars data structure.

icu_benchmarks/models/alarm_metrics.py (2)

13-29: Validate parameter types and thresholds.

  1. The function convert_to_alarm expects ground_truth and predictions to be arrays. Consider validating that numeric values are provided to avoid runtime errors.
  2. For the threshold of 0.5, be prepared to parameterize it or make it easily configurable if your use case evolves.

31-53: Check boundary conditions when silencing positives.

  1. positive_indices = positive_indices[positive_indices < len(ground_truth) - grace_horizon]: This drops positives that occur within grace_horizon steps of the end of the array. Confirm this is desired.
  2. Repeatedly slicing the array in a while loop (silence_array[positive_index+1: ... ]) could be susceptible to edge off-by-one errors.
icu_benchmarks/data/split_process_data.py (2)

99-107: Confirm correct variable exclusions.

The newly introduced logic for exclude_preproc merges modality-based variables. Ensure you verify that only the intended variables are excluded. Edge cases:

  • Overlapping modalities.
  • Missing keys in modality_mapping.

383-386: Validate multi-fold strategy logic.

Switching between StratifiedKFold and StratifiedShuffleSplit based on cv_folds can produce different distributions. Confirm that randomization is acceptable and that test coverage is thorough in both code paths.
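
A minimal illustration of the kind of branching being described (the condition, parameters, and helper name are assumptions, not the actual split_process_data.py code):

from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit

def make_splitter(cv_folds: int, seed: int = 42):
    # Assumption: a single fold falls back to a stratified shuffle split, while
    # cv_folds > 1 uses standard stratified K-fold cross-validation.
    if cv_folds > 1:
        return StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=seed)
    return StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=seed)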

configs/prediction_models/XGBClassifier.gin (3)

14-14: Enhanced min_child_weight granularity and range for better regularization control.

The values for min_child_weight now include both finer granularity and a wider range (0.1-10 instead of just 0.5-1), allowing for more precise control over overfitting.


18-21: Added regularization parameters for improved model tuning.

Four important hyperparameters have been added:

  1. gamma - Controls minimum loss reduction for creating a new tree split
  2. early_stopping_rounds - Prevents overfitting by stopping training when validation metric stops improving
  3. reg_lambda/reg_alpha - L2/L1 regularization parameters to control model complexity

These additions provide better control over the model's complexity and training behavior.
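
For reference, this is how those parameters map onto XGBoost's scikit-learn API (the values below are placeholders, not the tuned search space from the gin file):

from xgboost import XGBClassifier

model = XGBClassifier(
    gamma=0.1,                 # minimum loss reduction required to make a further split
    early_stopping_rounds=50,  # stop boosting when the validation metric stops improving
    reg_lambda=1.0,            # L2 regularisation on leaf weights
    reg_alpha=0.0,             # L1 regularisation on leaf weights
)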


17-17:


❓ Verification inconclusive

Verify impact of removing eval_metric parameter.

The eval_metric parameter has been commented out, which means the model might use a default evaluation metric instead of the previously specified "aucpr".

Run the following to check what evaluation metric is being used now:


🏁 Script executed:

#!/bin/bash
# Find where eval_metric is defined or used elsewhere
rg "eval_metric" --type=py icu_benchmarks/models/

Length of output: 220


Action: Confirm intended evaluation metric change

After verification, it appears that the XGBoost model’s evaluation metric is now explicitly set to log_loss in icu_benchmarks/models/ml_models/xgboost.py:

  • The gin config’s eval_metric parameter (previously set to "aucpr") is commented out.
  • The model instantiation explicitly passes eval_metric=log_loss, meaning it will use log_loss instead of "aucpr".

Please confirm if switching to log_loss is an intentional change. If this update is desired, the removal of the aucpr parameter is acceptable; otherwise, consider reverting or updating the evaluation metric accordingly.

icu_benchmarks/data/utils.py (1)

37-38:


✅ Verification successful

Verify the impact of removing the wearable_ppgfeature_HRV_SampEn columns.

Line 37 removes specific wearable features from dynamic data, and line 38 updates the variable list accordingly. This could affect downstream analysis if these features were previously used.

Let's check where these features were used:


🏁 Script executed:

#!/bin/bash
# Search for references to the removed feature
rg "wearable_ppgfeature_HRV_SampEn" --type=py

Length of output: 186


Impact Verification Complete – Removal Appears Safe

After executing the grep search, the only reference to "wearable_ppgfeature_HRV_SampEn" is within its removal in icu_benchmarks/data/utils.py. No downstream Python references were found, suggesting that the change won’t inadvertently affect further analysis. Please double-check any external or non-Python integrations if applicable, but within the codebase, the removal appears to be safe.

icu_benchmarks/data/loader.py (2)

45-46: Added explicit sorting for row indicators to ensure consistent order.

Sorting the row indicators ensures that data is processed in a consistent order across different components, which is especially important for time series data where sequence matters.

Also applies to: 53-53


176-179: Group column removal from representation matrix.

Line 179 removes the first column (group column) from the data representation. This prevents the group identifier from being used as a feature in the model, which is typically correct behavior as IDs shouldn't be treated as features.
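
Conceptually (toy arrays only, not the loader's actual representation):

import numpy as np

rep = np.array([[101, 0.50, 1.20],     # first column: stay/group identifier
                [102, 0.70, 0.90]])
features = rep[:, 1:]                  # drop the ID column so it is not used as a feature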

icu_benchmarks/models/wrappers.py (4)

50-50: Debug mode enabled by default.

Debug mode has been changed from False to True, which will result in more verbose logging and output saving. This may impact performance in production environments.

Check if this is intentional or if it should only be enabled in development environments.


109-110: Added support for feature explanation.

The new set_explain_features method allows enabling/disabling feature explanation, which can help with model interpretability.


461-467: Enhanced model output saving and feature explanation.

The error handling for saving model outputs has been commented out, and feature explanations are now stored in explainer_values_test. This change enables more robust feature importance analysis.

Consider whether removing the error handling is intentional:

-            # try:
            self._save_model_outputs(pred_indicators, test_pred, test_label)
-            # except Exception as e:
-            #     logging.warning(f"Could not save model outputs: {e}")

436-437:


❓ Verification inconclusive

Removed default validation loss initialization.

Previously, val_loss was initialized to 0.0 as a fallback, but now it directly uses the return value from model.fit(). This assumes that model.fit() always returns a valid validation loss.

Check if all model implementations properly return a validation loss:


🏁 Script executed:

#!/bin/bash
# Find model.fit implementations
rg "def fit\(" --type=py icu_benchmarks/models/ml_models/

Length of output: 59


Attention: Need Manual Verification of model.fit() Return Values

The removal of the default val_loss = 0.0 fallback now assumes that every model.fit() invocation reliably returns a valid validation loss. However, our automated search in the icu_benchmarks/models/ml_models/ directory (using rg "def fit\(" --type=py) did not yield any matching implementations. This could mean one of the following:

  • The fit() methods may be implemented elsewhere (or inherited from base classes), or
  • They might not explicitly return a validation loss as expected.

Action Items:

  • Manually verify that all model implementations (including those possibly inherited or defined in other directories) indeed return a meaningful validation loss.
  • If any model does not meet this assumption, consider either restoring a fallback mechanism or updating the implementation accordingly.


echo "This is a SLURM job named" $SLURM_JOB_NAME "with array id" $SLURM_ARRAY_TASK_ID "and job id" $SLURM_JOB_ID
echo "Resources allocated: " $SLURM_CPUS_PER_TASK "CPUs, " $SLURM_MEM_PER_NODE "GB RAM, " $SLURM_GPUS_PER_NODE "GPUs"
echi "Task type:" ${TASK}

⚠️ Potential issue

Fix typo in echo command.

There's a typo in the command: echi should be echo.

-echi "Task type:" ${TASK}
+echo "Task type:" ${TASK}

TASK=BinaryClassification # BinaryClassification
YAIB_PATH=/home/vandewrp/projects/YAIB #/dhc/home/robin.vandewater/projects/yaib
EXPERIMENT_PATH=../yaib_logs/${TASK_NAME}_experiment
DATASET_ROOT_PATH= /sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format #data/YAIB_Datasets/data

⚠️ Potential issue

Fix Bash variable assignment.

Bash assignments should not have spaces around the = sign, as it affects variable assignment.

-DATASET_ROOT_PATH= /sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format #data/YAIB_Datasets/data
+DATASET_ROOT_PATH=/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format #data/YAIB_Datasets/data
🧰 Tools
🪛 Shellcheck (0.10.0)

[warning] 18-18: Remove space after = if trying to assign a value (for empty string, use var='' ... ).

(SC1007)

echo "Dataset: "${DATASETS[$SLURM_ARRAY_TASK_ID]}
echo "Experiment path: "${EXPERIMENT_PATH}

cd ${YAIB_PATH}

🛠️ Refactor suggestion

Add error handling for directory change.

The script should exit if the directory change fails, which can happen if the path is incorrect or inaccessible.

-cd ${YAIB_PATH}
+cd ${YAIB_PATH} || { echo "Failed to change to directory ${YAIB_PATH}"; exit 1; }
🧰 Tools
🪛 Shellcheck (0.10.0)

[warning] 28-28: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.

(SC2164)


icu-benchmarks train \
-d ${DATASET_ROOT_PATH} \
-n ${DATASETS} \

⚠️ Potential issue

Variable expansion may be incorrect.

The script passes ${DATASETS} to the -n flag, but this appears to be an array from the usage on line 25. You likely need to reference a specific element of the array or convert it appropriately.

-  -n ${DATASETS} \
+  -n ${DATASETS[$SLURM_ARRAY_TASK_ID]} \

model/hyperparameter.class_to_tune = @XGBClassifier
model/hyperparameter.learning_rate = (0.01, 0.1, "log")
model/hyperparameter.n_estimators = [50, 100, 250, 500, 750, 1000,1500,2000]
model/hyperparameter.n_estimators = [50, 100, 250, 500, 750, 1000,1500,2000, 2500, 3000, 3500, 4000, 4500, 5000]

💡 Verification agent

🧩 Analysis chain

Expanded n_estimators range for potentially much larger ensembles.

The maximum of the n_estimators search space has been increased from 2000 to 5000 boosting rounds, allowing for much larger ensembles. This can lead to better performance but may also increase training time substantially.

Make sure this expanded range is compatible with your computational resources. Run the following to check if the tuning is actually enabled:


🏁 Script executed:

#!/bin/bash
# Check if hyperparameter tuning is enabled in MLTuning.gin
rg "tune_hyperparameters.n_calls" configs/prediction_models/common/MLTuning.gin

Length of output: 112


Action Required: Verify Hyperparameter Tuning Configuration

  • The expanded n_estimators range in configs/prediction_models/XGBClassifier.gin now goes up to 5000 (previous max was 2000). This change could improve model performance but also significantly increase training time.
  • However, the output from configs/prediction_models/common/MLTuning.gin shows tune_hyperparameters.n_calls = 0, indicating that hyperparameter tuning is currently disabled.
  • Please confirm whether hyperparameter tuning is intended to be enabled. If not, consider whether expanding the estimator range is still necessary. If tuning should be enabled to take advantage of the broader range, update the tuning configuration accordingly and verify that your computational resources can handle the increased training time.


def get_feature_names(self) -> List[str]:
return self.features_df.columns
return [col for col in self.features_df.columns ]#if col != self.vars["GROUP"] and col != self.vars["SEQUENCE"]]

💡 Verification agent

🧩 Analysis chain

Commented code may affect feature name filtering.

The commented portion would filter out GROUP and SEQUENCE columns from feature names. Currently, these columns will be included, which might affect downstream models that expect only feature columns.

Let's check if these columns are used as features elsewhere:


🏁 Script executed:

#!/bin/bash
# Find where get_feature_names() is used
rg "get_feature_names\(\)" --type=py icu_benchmarks/models/

Length of output: 157


Attention: Inconsistent Feature Filtering – Investigate GROUP and SEQUENCE Exclusion

The current implementation in icu_benchmarks/data/loader.py (line 75) returns all columns from features_df, including GROUP and SEQUENCE, because the filtering logic is commented out. Notably, in icu_benchmarks/models/train.py, the return value of get_feature_names() is used to set the model’s trained columns. This means that the extra columns are now treated as features, which may lead to downstream issues if these columns are not intended for model training.

  • Location Affected:

    • icu_benchmarks/data/loader.py (line 75): The filtering condition for GROUP and SEQUENCE is commented out.
    • icu_benchmarks/models/train.py: The output of get_feature_names() is passed to model.set_trained_columns().
  • Action Required:
    Verify whether excluding the GROUP and SEQUENCE columns is necessary. If these columns are not meant to be part of the feature set, please re-enable the filtering logic. Otherwise, update the documentation to clarify that these columns are intentionally included during training.
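
If the exclusion is indeed intended, the re-enabled filter could look like the following stand-alone sketch (the real get_feature_names takes no arguments and reads self.features_df and self.vars instead):

from typing import Iterable, List

def filtered_feature_names(columns: Iterable[str], group_col: str, sequence_col: str) -> List[str]:
    # Keep only genuine feature columns; drop the group and sequence identifiers.
    return [col for col in columns if col not in (group_col, sequence_col)]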

data, labels, row_indicators = self.get_data_and_labels()
if self.mps:
return from_numpy(data).to(float32), from_numpy(labels).to(float32)
return from_numpy(data).to(float32), from_numpy(labels).to(float32), row_indicators

💡 Verification agent

🧩 Analysis chain

Return type extended to include row indicators.

The to_tensor method now returns row_indicators along with data and labels, except when mps is enabled, in which case only data and labels are returned. Ensure all calling code handles the additional return value properly.

Check if all callers of this method handle the additional return value:


🏁 Script executed:

#!/bin/bash
# Find where to_tensor is called
rg "\bto_tensor\(\)" --type=py

Length of output: 129


Attention: Update DataLoader Usage to Properly Handle Triple Return Value

The recent change to extend the return type of to_tensor now includes row_indicators along with data and labels. However, the call in icu_benchmarks/models/train.py:

else DataLoader([test_dataset.to_tensor()], batch_size=1)

currently does not unpack or handle the additional row_indicators. Please update the caller to either unpack the tuple (e.g., (data, labels, row_indicators) = test_dataset.to_tensor()) or adjust downstream code to accommodate the extra value accordingly.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
icu_benchmarks/models/train.py (1)

209-235: ⚠️ Potential issue

Inconsistency in explainer values handling between test and train.

There's an inconsistency in how explainer values are processed:

  • For test values (line 222): explainer_values is used directly
  • For train values (line 230): explainer_values.values is used

This suggests different object types are being handled, which could cause errors.

-            shaps_train = pl.DataFrame(schema=trained_columns, data=explainer_values.values)
+            shaps_train = pl.DataFrame(schema=trained_columns, data=explainer_values)

Alternatively, if the objects truly have different structures:

-            shaps_train = pl.DataFrame(schema=trained_columns, data=explainer_values.values)
+            # Use consistent structure based on the explainer_values_train object type
+            if hasattr(explainer_values, 'values'):
+                shaps_train = pl.DataFrame(schema=trained_columns, data=explainer_values.values)
+            else:
+                shaps_train = pl.DataFrame(schema=trained_columns, data=explainer_values)
🧰 Tools
🪛 Ruff (0.8.2)

216-216: f-string without any placeholders

Remove extraneous f prefix

(F541)

🧹 Nitpick comments (20)
experiments/benchmark_cass.yml (9)

14-17: Hyperparameter Tuning Naming and Checkpoint Options
The parameters -tn and SSI are configured for naming conventions within hyperparameter tuning. Additionally, several hyperparameter checkpoint options remain commented out. If these are no longer required, consider removing them to improve clarity.


18-23: Historical Configuration Comments
The commented-out lines for alternate hyperparameter tuning log paths and verbose flags provide useful historical context. Once the configuration is finalized, consider cleaning up unnecessary commented lines to reduce clutter.


24-27: Modalities Flag in Command
The --modalities flag is introduced correctly with the active value "all". The commented-out verbose flag is acceptable during experimentation—if it’s no longer needed, its removal might simplify the configuration.


28-30: Optional Log Configuration Comments
Additional commented-out options (for example, alternate transformer log paths and flags like -gc or -lc) remain present. Keeping these can be useful during development, but you may want to remove them once the configuration stabilizes.


33-49: Data Directory Configuration
The data_dir parameter is configured with one active data path and several alternative paths commented out. This flexible setup allows for quick switching between data sources during development. Ensure that the selected active path is valid and consider cleaning up the commented entries once you've finalized the configuration.


68-81: Model Parameter Selection
The model parameter explicitly selects XGBClassifier as the active model, while alternative model options are provided as comments. This configuration aligns with the enhanced hyperparameter settings. If these alternatives will no longer be used, removing them might improve readability.


82-89: Modalities Configuration Consistency
The modalities parameter lists two active values: a string "all" and an array of modality identifiers. Please verify that the system consuming this config can handle mixed type values. If a consistent format is needed, consider normalizing these entries.


93-95: Pretrained Imputation Setting
The use_pretrained_imputation parameter is set with the value None, indicating that no pretrained imputation is applied. Confirm that downstream processes correctly interpret this setting (e.g., as a null value) or adjust if a different behavior is required.


101-101: Newline at End-of-File
The file is missing a newline character at the end. Please add a newline to adhere to YAML style guidelines and avoid potential issues with certain tools or environments.

🧰 Tools
🪛 YAMLlint (1.35.1)

[error] 101-101: no new line character at the end of file

(new-line-at-end-of-file)

icu_benchmarks/models/train.py (1)

216-216: Remove extraneous f-string prefix.

The string doesn't contain any placeholders, so the f prefix is unnecessary.

-        logging.info(f"Saving SHAPS")
+        logging.info("Saving SHAPS")
🧰 Tools
🪛 Ruff (0.8.2)

216-216: f-string without any placeholders

Remove extraneous f prefix

(F541)

icu_benchmarks/models/ml_models/xgboost.py (2)

6-6: Remove unused import.

The numpy module is imported but never used in this file.

-import numpy as np
🧰 Tools
🪛 Ruff (0.8.2)

6-6: numpy imported but unused

Remove unused import: numpy

(F401)


84-87: Consider removing commented-out code.

The commented-out explainer method seems to be replaced by the new SHAP-based implementation. Unless this is intentionally being kept for reference, consider removing it to keep the codebase clean.

-    # def explainer(self, reps):
-    #     if not hasattr(self.model, "feature_importances_"):
-    #         raise ValueError("Model has not been fit yet. Call fit_model() before getting feature importances.")
-    #     return self.model.feature_importances_
icu_benchmarks/models/wrappers.py (8)

6-6: Remove unused import.

The import EpochMetric is never referenced in the code and can safely be removed to reduce clutter and improve maintainability.

-from ignite.metrics import EpochMetric
🧰 Tools
🪛 Ruff (0.8.2)

6-6: ignite.metrics.EpochMetric imported but unused

Remove unused import: ignite.metrics.EpochMetric

(F401)


109-110: Add docstring for new method.

The set_explain_features method would benefit from a short docstring describing how and when to call it, ensuring maintainers and users understand its purpose.

 def set_explain_features(self, set_explain_features: bool):
+    """
+    Enable or disable feature explanations for downstream tasks.
+    """
     self.explain_features = set_explain_features

436-437: Clarify return value from model.fit.

The call to self.model.fit(train_data, train_labels) is stored in val_loss, but it's unclear if the model’s fit method actually returns a loss value. Consider renaming the variable or verifying the returned object to avoid confusion.
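
For illustration, a minimal standalone sketch (using a scikit-learn estimator as a stand-in for the wrapper's actual model class) shows why naming the return value val_loss is likely misleading:

import numpy as np
from sklearn.linear_model import LogisticRegression

X, y = np.random.rand(20, 3), np.array([0, 1] * 10)
model = LogisticRegression()
result = model.fit(X, y)  # scikit-learn style fit() returns the estimator itself
print(result is model)    # True: there is no loss value to store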


461-462: Restore error handling if necessary.

Removing the try block around _save_model_outputs means the method could raise exceptions that go unhandled during testing. Ensure that this is intentional or reintroduce suitable error handling.


463-464: Reevaluate commented-out exception handler.

The exception handling has been fully commented out, so any file I/O issues will now propagate. If the intention is to log such failures, reintroduce the logging statement in an except block.


473-473: Using raw label/pred in log metrics.

Unlike line 470, here you leave the data as-is. Consider standardizing your approach (either always cast or ensure consistent types) for consistent metric logging.


482-520: Unify metric logging logic.

Within log_metrics, both the if pred_indicators is None and else blocks contain largely similar metric logging code (confusion matrix, dictionary creation). Consider unifying these sections to reduce duplication and complexity.
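
As a rough sketch of what a unified version could look like (names and the metric set are illustrative, not the wrapper's actual implementation), the shared work is computed once and only the indicator-specific handling branches:

import numpy as np
from sklearn.metrics import confusion_matrix

def log_metrics(y_true, y_pred, pred_indicators=None):
    y_true = np.asarray(y_true, dtype=np.float32)
    y_pred = np.asarray(y_pred, dtype=np.float32)
    # shared work: confusion matrix and metric dict are built exactly once
    metrics = {"confusion_matrix": confusion_matrix(y_true, y_pred.round())}
    if pred_indicators is not None:
        # only the indicator-specific part differs between the two branches
        metrics["n_indicator_rows"] = len(pred_indicators)
    return metrics

print(log_metrics([0, 1, 1, 0], [0.1, 0.9, 0.4, 0.2], pred_indicators=[10, 11, 12, 13]))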


528-568: Clean up duplicate or commented-out lines & fix extraneous f-strings.

  1. Several commented-out lines in _save_model_outputs add noise. Remove or refactor them for clarity.
  2. The static analysis flags f-strings without placeholders (e.g., f'row_indicators.csv'). Remove unused f prefixes to avoid confusion.
-logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / f'row_indicators.csv'}")
+logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / 'row_indicators.csv'}")
🧰 Tools
🪛 Ruff (0.8.2)

538-538: f-string without any placeholders

Remove extraneous f prefix

(F541)


554-554: f-string without any placeholders

Remove extraneous f prefix

(F541)


563-563: f-string without any placeholders

Remove extraneous f prefix

(F541)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0afa916 and 320007a.

📒 Files selected for processing (9)
  • configs/prediction_models/common/MLTuning.gin (1 hunks)
  • configs/tasks/BinaryClassification.gin (1 hunks)
  • experiments/benchmark_cass.yml (1 hunks)
  • icu_benchmarks/data/split_process_data.py (12 hunks)
  • icu_benchmarks/models/ml_models/xgboost.py (4 hunks)
  • icu_benchmarks/models/train.py (4 hunks)
  • icu_benchmarks/models/wrappers.py (8 hunks)
  • icu_benchmarks/run_utils.py (6 hunks)
  • icu_benchmarks/tuning/hyperparameters.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • configs/prediction_models/common/MLTuning.gin
  • icu_benchmarks/run_utils.py
  • icu_benchmarks/tuning/hyperparameters.py
🧰 Additional context used
🧬 Code Definitions (2)
icu_benchmarks/models/wrappers.py (2)
icu_benchmarks/models/utils.py (2)
  • log_metrics (162-183)
  • name (158-159)
icu_benchmarks/models/custom_metrics.py (1)
  • confusion_matrix (137-146)
icu_benchmarks/models/train.py (1)
icu_benchmarks/models/wrappers.py (1)
  • set_explain_features (109-110)
🪛 Ruff (0.8.2)
icu_benchmarks/data/split_process_data.py

4-4: numpy imported but unused

Remove unused import: numpy

(F401)


45-45: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)

icu_benchmarks/models/ml_models/xgboost.py

6-6: numpy imported but unused

Remove unused import: numpy

(F401)

icu_benchmarks/models/wrappers.py

6-6: ignite.metrics.EpochMetric imported but unused

Remove unused import: ignite.metrics.EpochMetric

(F401)


538-538: f-string without any placeholders

Remove extraneous f prefix

(F541)


554-554: f-string without any placeholders

Remove extraneous f prefix

(F541)


563-563: f-string without any placeholders

Remove extraneous f prefix

(F541)

icu_benchmarks/models/train.py

216-216: f-string without any placeholders

Remove extraneous f prefix

(F541)

🪛 YAMLlint (1.35.1)
experiments/benchmark_cass.yml

[error] 101-101: no new line character at the end of file

(new-line-at-end-of-file)

🔇 Additional comments (23)
experiments/benchmark_cass.yml (5)

1-9: Command Configuration Validation
The command list is structured properly with the appropriate use of environment variables (${env}) and program parameters (${program}). The change from the commented-out "BinaryClassification" to the active "CassClassification" appears intentional and aligns with the updated benchmarking focus.


10-13: Logging and Tuning Parameters
The logging flag (--log-dir) and tuning flags (--tune, --wandb-sweep) are clearly specified. Please ensure that the log directory (/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs) is accessible and that appropriate permissions are in place in the target environment.


31-32: Benchmark Method and Name
Setting the benchmarking method to grid and naming the benchmark yaib_classification_benchmark clearly defines the experiment’s search strategy and identification.


90-92: Seed for Reproducibility
The random seed is set to 1111, which helps in ensuring reproducible results across experiment runs. This is a standard and effective configuration.


96-96: Program Specification
The program is set to icu-benchmarks, which should correspond with the intended execution context of the benchmark. Verify that all associated scripts and modules recognize this program identifier appropriately.

configs/tasks/BinaryClassification.gin (1)

25-30: Configuration enhancement for required data segments and filenames.

The addition of explicit required segments and file name mappings improves configuration clarity and makes dependencies more explicit. This change aligns well with the overall improvements in data processing from the PR.

icu_benchmarks/models/train.py (3)

56-56: Well-designed addition of feature explanation flag.

Adding the explain_features parameter with a default of False provides a clean way to optionally enable feature explanations during testing without breaking existing functionality.


177-178: Good implementation of feature explanation configuration.

The model is properly configured for feature explanation before testing, which aligns with the PR's goal of enabling feature explanation throughout training.


196-197: Conditionally persisting explainer data is a good approach.

The conditional call to persist_shap_data only when explain_features is True ensures that explainer values are only computed and saved when needed.

icu_benchmarks/models/ml_models/xgboost.py (3)

28-28: Enhanced model initialization with improved parameters.

Adding eval_metric=log_loss, missing='inf', and verbosity=0 parameters provides better default behavior for the XGBoost classifier.


51-57: Good implementation of SHAP-based feature explanation.

The implementation of TreeExplainer with feature_perturbation="interventional" and model_output="probability" parameters provides an effective way to explain feature importance. The conditional computation of SHAP values only when feature explanation is enabled is also a good approach.
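
For context, a minimal self-contained sketch of this SHAP pattern (synthetic data and a freshly trained model, not the benchmark's pipeline):

import numpy as np
import shap
import xgboost as xgb

X = np.random.rand(200, 5)
y = (X[:, 0] + X[:, 1] > 1).astype(int)
model = xgb.XGBClassifier(n_estimators=50, verbosity=0).fit(X, y)

# model_output="probability" requires interventional perturbation plus background data
explainer = shap.TreeExplainer(
    model, data=X[:100], feature_perturbation="interventional", model_output="probability"
)
shap_values = explainer.shap_values(X[:10])
print(shap_values.shape)  # (10, 5): one attribution per sample and feature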


76-83: Well-designed feature explanation method.

The _explain_model method properly returns SHAP values from the explainer, which is a more interpretable way to understand feature importance than the built-in feature importances.

icu_benchmarks/data/split_process_data.py (5)

4-4: Remove unused import.

The numpy module is imported but never used in this file.

-import numpy as np
🧰 Tools
🪛 Ruff (0.8.2)

4-4: numpy imported but unused

Remove unused import: numpy

(F401)


45-45: Avoid mutable default arguments.

Using ["GROUP", "LABEL"] as a default for required_var_types can introduce subtle bugs if the list is mutated. Replace with None or a tuple and initialize inside the function as needed.

-required_var_types=["GROUP", "LABEL"],
+required_var_types=None,

And inside the function:

if required_var_types is None:
    required_var_types = ["GROUP", "LABEL"]
🧰 Tools
🪛 Ruff (0.8.2)

45-45: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


99-111: Good implementation of modality exclusion logic.

The added logic for excluding variables from preprocessing based on modality mapping is well-structured and includes proper error checking and informative logging messages.


322-322: Improved default train_size parameter.

Setting a default value of 0.80 for train_size provides a reasonable split ratio and makes the function more user-friendly by not requiring explicit specification of this parameter.


383-386: Enhanced cross-validation strategy for small fold counts.

Using StratifiedShuffleSplit when fold count is ≤ 2 is a good approach for handling smaller cross-validation scenarios, as it ensures proper class distribution in the splits.
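
Illustrative sketch of this fold-count branching (thresholds, seeds, and variable names are assumptions, not the module's exact code):

import numpy as np
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit

y = np.array([0, 0, 0, 1, 1, 1, 0, 1, 0, 1])
cv_folds, train_size = 2, 0.80

if cv_folds <= 2:
    splitter = StratifiedShuffleSplit(n_splits=cv_folds, train_size=train_size, random_state=42)
else:
    splitter = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=42)

for train_idx, test_idx in splitter.split(np.zeros(len(y)), y):
    print(len(train_idx), len(test_idx))  # class ratios are preserved in both splitters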

icu_benchmarks/models/wrappers.py (6)

50-50: Confirm the impact of always enabling debug mode.

Setting debug = True might produce verbose logs or additional overhead. Ensure that this setting is intentional for production environments, as debug logs can expose sensitive details and affect performance.


66-66: Logging column count is useful.

This logging addition provides helpful visibility into how many columns are being set for training. No issues here.


228-229: No functional changes.

These lines appear to be blank space insertions or formatting tweaks. No action needed.


432-432: Confirm non-null usage of row_indicators.

You are now passing row_indicators to log_metrics for the training phase. If row_indicators can be None for the training set, ensure that log_metrics can handle it gracefully.


466-467: Use _explain_model consistently.

You replaced self.explain_model with self.explainer_values_test = self._explain_model(...). If _explain_model provides all needed functionality, confirm that explain_model is no longer required or reference it for clarity.


470-470: Casting labels and predictions to float32.

Confirm that explicitly casting test_label and test_pred to float32 does not introduce precision issues for downstream calculations.

Comment on lines 194 to 207
logging.info("Checking for infinite values.")
for col in val.select(cs.numeric()).columns:
if val[col].is_infinite().any():
logging.info(f"Column '{col}' contains infinite values. Datatype: {val[col].dtype}")

max_float64 = 0
# Replace infinite values with the maximum value for float64
val = val.with_columns([
pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col)
for col in val.columns if val[col].dtype == pl.Float64
])
dict[key] = val
logging.info(f"Amount of columns: {len(val.columns)}")

⚠️ Potential issue

The max_float64 value is incorrectly set to 0.

Setting max_float64 = 0 before replacing infinite values doesn't seem correct. To properly handle infinite values, you should use the maximum finite value for float64 from numpy.

-            max_float64 = 0
+            # Use numpy's maximum finite value for float64
+            import numpy as np
+            max_float64 = np.finfo(np.float64).max

This ensures that infinite values are replaced with the largest valid float64 value instead of 0, which could significantly affect model training.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
logging.info("Checking for infinite values.")
for col in val.select(cs.numeric()).columns:
if val[col].is_infinite().any():
logging.info(f"Column '{col}' contains infinite values. Datatype: {val[col].dtype}")
max_float64 = 0
# Replace infinite values with the maximum value for float64
val = val.with_columns([
pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col)
for col in val.columns if val[col].dtype == pl.Float64
])
dict[key] = val
logging.info(f"Amount of columns: {len(val.columns)}")
logging.info("Checking for infinite values.")
for col in val.select(cs.numeric()).columns:
if val[col].is_infinite().any():
logging.info(f"Column '{col}' contains infinite values. Datatype: {val[col].dtype}")
# Use numpy's maximum finite value for float64
import numpy as np
max_float64 = np.finfo(np.float64).max
# Replace infinite values with the maximum value for float64
val = val.with_columns([
pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col)
for col in val.columns if val[col].dtype == pl.Float64
])
dict[key] = val
logging.info(f"Amount of columns: {len(val.columns)}")

# Conflicts:
#	icu_benchmarks/data/loader.py
#	icu_benchmarks/data/split_process_data.py
#	icu_benchmarks/models/ml_models/xgboost.py

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
icu_benchmarks/data/split_process_data.py (1)

220-244: Update function documentation to match new return signature.

The check_sanitize_data function now returns both data and vars, but the function docstring doesn't reflect this change. Update the documentation to match the actual implementation.

def check_sanitize_data(data, vars):
-    """Check for duplicates in the loaded data and remove them."""
+    """Check for duplicates in the loaded data and remove them.
+    
+    Args:
+        data: Dictionary containing data divided into segments
+        vars: Dictionary containing variable metadata
+        
+    Returns:
+        Tuple of (sanitized_data, updated_vars)
+    """
🧰 Tools
🪛 Ruff (0.8.2)

220-220: Redefinition of unused check_sanitize_data from line 19

(F811)


222-222: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


223-223: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


225-225: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


230-230: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


235-235: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 22c8741 and de09ec8.

📒 Files selected for processing (1)
  • icu_benchmarks/data/split_process_data.py (12 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
icu_benchmarks/data/split_process_data.py

4-4: numpy imported but unused

Remove unused import: numpy

(F401)


45-45: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)

🔇 Additional comments (7)
icu_benchmarks/data/split_process_data.py (7)

4-4: The numpy import is still not referenced in the implementation.

Although the new infinite-value handling makes numpy conceptually relevant, the import is still not used anywhere in this file. Put it to use for setting the maximum finite float64 value, or remove it.

🧰 Tools
🪛 Ruff (0.8.2)

4-4: numpy imported but unused

Remove unused import: numpy

(F401)


45-45: Avoid mutable default arguments.

Using ["GROUP", "LABEL"] as a default for required_var_types can introduce subtle bugs if the list is mutated. Replace with None or a tuple and initialize inside the function as needed.

- required_var_types=["GROUP", "LABEL"],
+ required_var_types=None,

Then add inside the function:

if required_var_types is None:
    required_var_types = ["GROUP", "LABEL"]
🧰 Tools
🪛 Ruff (0.8.2)

45-45: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


99-111: LGTM! Good implementation of modality exclusion.

The implementation of modality-based variable exclusion is well-structured with appropriate error handling and logging. This provides a flexible way to exclude variables that have already undergone preprocessing.


138-138: Function call updated to match new return signature.

The code correctly handles the updated return signature of check_sanitize_data, which now returns both the sanitized data and updated vars.


199-207: The max_float64 value is incorrectly set to 0.

Setting max_float64 = 0 before replacing infinite values doesn't seem correct. To properly handle infinite values, you should use the maximum finite value for float64 from numpy.

-            max_float64 = 0
+            # Use numpy's maximum finite value for float64
+            max_float64 = np.finfo(np.float64).max

This ensures that infinite values are replaced with the largest valid float64 value instead of 0, which could significantly affect model training.


439-442: Good enhancement for flexibility in cross-validation.

The conditional logic for choosing between StratifiedKFold and StratifiedShuffleSplit based on the number of folds is a good enhancement that improves flexibility, especially for smaller datasets.


378-378: Good default value for train_size parameter.

Setting a default value of 0.80 for train_size is a sensible choice and aligns with common machine learning practices.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

🧹 Nitpick comments (1)
icu_benchmarks/data/utils.py (1)

26-51: Apply static analysis suggestions and improve code clarity

The function works correctly but has some style issues flagged by static analysis tools.

Apply this diff to address the static analysis hints:

def check_sanitize_data(data, vars):
    """Check for duplicates in the loaded data and remove them."""
-    group = vars[Var.group] if Var.group in vars.keys() else None
-    sequence = vars[Var.sequence] if Var.sequence in vars.keys() else None
+    group = vars[Var.group] if Var.group in vars else None
+    sequence = vars[Var.sequence] if Var.sequence in vars else None
     keep = "last"
     logging.info(data.keys())
-    if Segment.static in data.keys():
+    if Segment.static in data:
         old_len = len(data[Segment.static])
         data[Segment.static] = data[Segment.static].unique(subset=group, keep=keep, maintain_order=True)
         logging.warning(f"Removed {old_len - len(data[Segment.static])} duplicates from static data.")
-    if Segment.dynamic in data.keys():
+    if Segment.dynamic in data:
         old_len = len(data[Segment.dynamic])
         data[Segment.dynamic] = data[Segment.dynamic].unique(subset=[group, sequence], keep=keep, maintain_order=True)
         data[Segment.dynamic] = infinite_removal(data[Segment.dynamic])
         data[Segment.dynamic] = data[Segment.dynamic].select(~cs.starts_with("wearable_ppgfeature_HRV_SampEn"))
         vars[Segment.dynamic] = [col for col in vars[Segment.dynamic] if col in data[Segment.dynamic].columns]
         logging.warning(f"Removed {old_len - len(data[Segment.dynamic])} duplicates from dynamic data.")
-    if Segment.outcome in data.keys():
+    if Segment.outcome in data:
         old_len = len(data[Segment.outcome])
         if sequence in data[Segment.outcome].columns:
             # We have a dynamic outcome with group and sequence
             data[Segment.outcome] = data[Segment.outcome].unique(subset=[group, sequence], keep=keep, maintain_order=True)
         else:
             data[Segment.outcome] = data[Segment.outcome].unique(subset=[group], keep=keep, maintain_order=True)
         logging.warning(f"Removed {old_len - len(data[Segment.outcome])} duplicates from outcome data.")
     return data, vars
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 95ff49d and 5b71c64.

📒 Files selected for processing (4)
  • icu_benchmarks/data/loader.py (3 hunks)
  • icu_benchmarks/data/preprocessor.py (6 hunks)
  • icu_benchmarks/data/utils.py (1 hunks)
  • icu_benchmarks/tuning/hyperparameters.py (7 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • icu_benchmarks/data/loader.py
  • icu_benchmarks/tuning/hyperparameters.py
  • icu_benchmarks/data/preprocessor.py
🧰 Additional context used
🪛 Ruff (0.12.2)
icu_benchmarks/data/utils.py

28-28: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


29-29: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


32-32: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


36-36: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


43-43: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


58-58: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


59-59: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


77-77: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

🔇 Additional comments (1)
icu_benchmarks/data/utils.py (1)

1-6: LGTM: Clean imports and dependencies

The imports are well-organized and appropriate for the data processing utilities implemented in this module.

logging.info("No columns selected. Using all columns.")
return data, vars
selected_columns = sum(selected_columns, [])
selected_columns.extend([vars[Var.group], vars[Var.label], vars[Var.sequence]])

🛠️ Refactor suggestion

Add defensive programming for missing standard columns

The function assumes that Var.group, Var.label, and Var.sequence always exist in the vars dictionary. Add validation to prevent KeyError exceptions.

Apply this diff to add defensive checks:

-    selected_columns.extend([vars[Var.group], vars[Var.label], vars[Var.sequence]])
+    standard_cols = []
+    for var_type in [Var.group, Var.label, Var.sequence]:
+        if var_type in vars:
+            standard_cols.append(vars[var_type])
+        else:
+            logging.warning(f"Standard column {var_type} not found in vars")
+    selected_columns.extend(standard_cols)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    selected_columns.extend([vars[Var.group], vars[Var.label], vars[Var.sequence]])
+    standard_cols = []
+    for var_type in [Var.group, Var.label, Var.sequence]:
+        if var_type in vars:
+            standard_cols.append(vars[var_type])
+        else:
+            logging.warning(f"Standard column {var_type} not found in vars")
+    selected_columns.extend(standard_cols)
🤖 Prompt for AI Agents
In icu_benchmarks/data/utils.py at line 65, the code assumes that keys
Var.group, Var.label, and Var.sequence exist in the vars dictionary, which can
cause KeyError if any are missing. Modify the code to check if each of these
keys exists in vars before extending selected_columns. Only add the columns that
are present to avoid exceptions and ensure robustness.

Comment on lines +9 to +23
def infinite_removal(val):
    for col in val.select(cs.numeric()).columns:
        if val[col].is_infinite().any():
            logging.info(f"Column '{col}' contains infinite values. Datatype: {val[col].dtype}")

    max_float64 = np.finfo(np.float64).max / 100
    # Replace infinite values with the maximum value for float64
    val = val.with_columns(
        [
            pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col)
            for col in val.columns
            if val[col].dtype == pl.Float64
        ]
    )
    return val

⚠️ Potential issue

Fix inconsistent infinite value handling across numeric types

The function checks for infinite values in all numeric columns but only replaces them in Float64 columns. This could leave infinite values in other numeric types (Float32, Int64, etc.) unhandled.

Apply this diff to handle all numeric types consistently:

def infinite_removal(val):
-    for col in val.select(cs.numeric()).columns:
-        if val[col].is_infinite().any():
-            logging.info(f"Column '{col}' contains infinite values. Datatype: {val[col].dtype}")
-
-    max_float64 = np.finfo(np.float64).max / 100
-    # Replace infinite values with the maximum value for float64
-    val = val.with_columns(
-        [
-            pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col)
-            for col in val.columns
-            if val[col].dtype == pl.Float64
-        ]
-    )
+    numeric_cols = val.select(cs.numeric()).columns
+    cols_with_infinite = []
+    
+    for col in numeric_cols:
+        if val[col].is_infinite().any():
+            logging.info(f"Column '{col}' contains infinite values. Datatype: {val[col].dtype}")
+            cols_with_infinite.append(col)
+    
+    if cols_with_infinite:
+        max_float64 = np.finfo(np.float64).max / 100
+        # Replace infinite values with the maximum value for numeric columns
+        val = val.with_columns(
+            [
+                pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col)
+                for col in cols_with_infinite
+            ]
+        )
+    
     return val
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
 def infinite_removal(val):
-    for col in val.select(cs.numeric()).columns:
-        if val[col].is_infinite().any():
-            logging.info(f"Column '{col}' contains infinite values. Datatype: {val[col].dtype}")
-
-    max_float64 = np.finfo(np.float64).max / 100
-    # Replace infinite values with the maximum value for float64
-    val = val.with_columns(
-        [
-            pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col)
-            for col in val.columns
-            if val[col].dtype == pl.Float64
-        ]
-    )
+    numeric_cols = val.select(cs.numeric()).columns
+    cols_with_infinite = []
+
+    for col in numeric_cols:
+        if val[col].is_infinite().any():
+            logging.info(f"Column '{col}' contains infinite values. Datatype: {val[col].dtype}")
+            cols_with_infinite.append(col)
+
+    if cols_with_infinite:
+        max_float64 = np.finfo(np.float64).max / 100
+        # Replace infinite values with the maximum value for numeric columns
+        val = val.with_columns(
+            [
+                pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col)
+                for col in cols_with_infinite
+            ]
+        )
+
     return val
🤖 Prompt for AI Agents
In icu_benchmarks/data/utils.py around lines 9 to 23, the function
infinite_removal only replaces infinite values in Float64 columns but checks all
numeric columns, causing inconsistent handling. Modify the replacement logic to
include all numeric columns with floating point types (e.g., Float32 and
Float64) by adjusting the condition to cover these types, and ensure the
replacement value is appropriate for each type. This will consistently handle
infinite values across all relevant numeric columns.

Comment on lines +54 to +81
def modality_selection(
    data: dict[pl.DataFrame], modality_mapping: dict[str], selected_modalities: list[str], vars
) -> dict[pl.DataFrame]:
    logging.info(f"Selected modalities: {selected_modalities}")
    selected_columns = [modality_mapping[cols] for cols in selected_modalities if cols in modality_mapping.keys()]
    if not any(col in modality_mapping.keys() for col in selected_modalities):
        raise ValueError("None of the selected modalities found in modality mapping.")
    if selected_columns == []:
        logging.info("No columns selected. Using all columns.")
        return data, vars
    selected_columns = sum(selected_columns, [])
    selected_columns.extend([vars[Var.group], vars[Var.label], vars[Var.sequence]])
    old_columns = []
    # Update vars dict
    for key, value in vars.items():
        if key not in [Var.group, Var.label, Var.sequence]:
            old_columns.extend(value)
            vars[key] = [col for col in value if col in selected_columns]
    # -3 because of standard columns
    logging.info(f"Selected columns: {len(selected_columns) - 3}, original columns: {len(old_columns)}, "
                 f"not using: {len(set(old_columns) - set(selected_columns))} columns")
    logging.debug(f"Not using columns: {set(old_columns) - set(selected_columns)}")
    # Update data dict
    for key in data.keys():
        sel_col = [col for col in data[key].columns if col in selected_columns]
        data[key] = data[key].select(sel_col)
        logging.debug(f"Selected columns in {key}: {len(data[key].columns)}")
    return data, vars

⚠️ Potential issue

Fix return type annotation and apply static analysis suggestions

The function has an incorrect return type annotation and contains static analysis issues.

Apply this diff to fix the return type and style issues:

def modality_selection(
-    data: dict[pl.DataFrame], modality_mapping: dict[str], selected_modalities: list[str], vars
-) -> dict[pl.DataFrame]:
+    data: dict[pl.DataFrame], modality_mapping: dict[str], selected_modalities: list[str], vars
+) -> tuple[dict[pl.DataFrame], dict]:
     logging.info(f"Selected modalities: {selected_modalities}")
-    selected_columns = [modality_mapping[cols] for cols in selected_modalities if cols in modality_mapping.keys()]
-    if not any(col in modality_mapping.keys() for col in selected_modalities):
+    selected_columns = [modality_mapping[cols] for cols in selected_modalities if cols in modality_mapping]
+    if not any(col in modality_mapping for col in selected_modalities):
         raise ValueError("None of the selected modalities found in modality mapping.")
     if selected_columns == []:
         logging.info("No columns selected. Using all columns.")
         return data, vars
     selected_columns = sum(selected_columns, [])
     selected_columns.extend([vars[Var.group], vars[Var.label], vars[Var.sequence]])
     old_columns = []
     # Update vars dict
     for key, value in vars.items():
         if key not in [Var.group, Var.label, Var.sequence]:
             old_columns.extend(value)
             vars[key] = [col for col in value if col in selected_columns]
     # -3 because of standard columns
     logging.info(f"Selected columns: {len(selected_columns) - 3}, original columns: {len(old_columns)}, "
                  f"not using: {len(set(old_columns) - set(selected_columns))} columns")
     logging.debug(f"Not using columns: {set(old_columns) - set(selected_columns)}")
     # Update data dict
-    for key in data.keys():
+    for key in data:
         sel_col = [col for col in data[key].columns if col in selected_columns]
         data[key] = data[key].select(sel_col)
         logging.debug(f"Selected columns in {key}: {len(data[key].columns)}")
     return data, vars
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
 def modality_selection(
     data: dict[pl.DataFrame], modality_mapping: dict[str], selected_modalities: list[str], vars
-) -> dict[pl.DataFrame]:
+) -> tuple[dict[pl.DataFrame], dict]:
     logging.info(f"Selected modalities: {selected_modalities}")
-    selected_columns = [modality_mapping[cols] for cols in selected_modalities if cols in modality_mapping.keys()]
-    if not any(col in modality_mapping.keys() for col in selected_modalities):
+    selected_columns = [modality_mapping[cols] for cols in selected_modalities if cols in modality_mapping]
+    if not any(col in modality_mapping for col in selected_modalities):
         raise ValueError("None of the selected modalities found in modality mapping.")
     if selected_columns == []:
         logging.info("No columns selected. Using all columns.")
         return data, vars
     selected_columns = sum(selected_columns, [])
     selected_columns.extend([vars[Var.group], vars[Var.label], vars[Var.sequence]])
     old_columns = []
     # Update vars dict
     for key, value in vars.items():
         if key not in [Var.group, Var.label, Var.sequence]:
             old_columns.extend(value)
             vars[key] = [col for col in value if col in selected_columns]
     # -3 because of standard columns
     logging.info(f"Selected columns: {len(selected_columns) - 3}, original columns: {len(old_columns)}, "
                  f"not using: {len(set(old_columns) - set(selected_columns))} columns")
     logging.debug(f"Not using columns: {set(old_columns) - set(selected_columns)}")
     # Update data dict
-    for key in data.keys():
+    for key in data:
         sel_col = [col for col in data[key].columns if col in selected_columns]
         data[key] = data[key].select(sel_col)
         logging.debug(f"Selected columns in {key}: {len(data[key].columns)}")
     return data, vars
🧰 Tools
🪛 Ruff (0.12.2)

58-58: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


59-59: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


77-77: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

🤖 Prompt for AI Agents
In icu_benchmarks/data/utils.py lines 54 to 81, the function modality_selection
has an incorrect return type annotation and static analysis issues. Change the
return type annotation to a tuple of two dictionaries with keys as strings and
values as pl.DataFrame for the first and dict for the second. Also, ensure the
function consistently returns a tuple (data, vars) in all cases, including when
selected_columns is empty. Adjust the code style to fix any static analysis
warnings, such as using consistent key checks and avoiding redundant calls to
keys().

Comment on lines +61 to +63
    if selected_columns == []:
        logging.info("No columns selected. Using all columns.")
        return data, vars

💡 Verification agent

🧩 Analysis chain

Verify empty list check logic

The condition selected_columns == [] might not work as intended since selected_columns is a list of lists at this point. Consider checking if the flattened list is empty.

Please verify this logic handles the edge case correctly. The check should occur after flattening the list or use a different approach:


🏁 Script executed:

#!/bin/bash
# Description: Search for similar patterns in the codebase to understand the intended behavior
# Expected: Find usage patterns of modality_selection function

rg -A 10 -B 5 "modality_selection" --type py

Length of output: 5378


🏁 Script executed:

#!/bin/bash
# Extract the modality_selection implementation from utils.py to inspect selected_columns usage
rg -A 20 -B 10 "def modality_selection" --type py icu_benchmarks/data/utils.py

Length of output: 1849


Update empty‐check to use flattened columns

The current check if selected_columns == [] only detects when no modalities matched, but doesn’t catch cases where the mapping exists but yields empty lists (e.g. [[]]). Flatten first, then guard on the resulting column list:

  • File: icu_benchmarks/data/utils.py, around lines 61–63

  • Replace:

    if selected_columns == []:
        logging.info("No columns selected. Using all columns.")
        return data, vars
    selected_columns = sum(selected_columns, [])

    with:

    # Flatten and catch “no actual columns” after mapping
    flattened = sum(selected_columns, [])
    if not flattened:
        logging.info("No columns selected. Using all columns.")
        return data, vars
    selected_columns = flattened
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    if selected_columns == []:
-        logging.info("No columns selected. Using all columns.")
-        return data, vars
+    # Flatten and catch “no actual columns” after mapping
+    flattened = sum(selected_columns, [])
+    if not flattened:
+        logging.info("No columns selected. Using all columns.")
+        return data, vars
+    selected_columns = flattened
🤖 Prompt for AI Agents
In icu_benchmarks/data/utils.py around lines 61 to 63, the check for empty
selected_columns only detects an empty list but misses cases where
selected_columns contains empty lists like [[]]. To fix this, flatten the
selected_columns list first and then check if the flattened list is empty. If it
is empty, log the info message and return data and vars as before.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (4)
icu_benchmarks/run_utils.py (4)

98-98: Fix collision handling implementation.

This is the same issue identified in previous reviews - random.randint(1, 10) returns an integer but with_name() expects a string, and the limited range could cause issues.
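
A hypothetical fix sketch (the log-directory handling here is illustrative): build a string suffix over a wider range so Path.with_name() receives a str and collisions stay unlikely.

import random
from pathlib import Path

log_dir = Path("yaib_logs/2025-02-27T12-00-00")
if log_dir.exists():
    # with_name() expects a string; a zero-padded 6-digit suffix keeps collisions rare
    log_dir = log_dir.with_name(f"{log_dir.name}_{random.randint(0, 999999):06d}")
print(log_dir)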


190-190: Variable naming is consistent.

This addresses the previous review comment about variable naming consistency - the variable is now properly named explainer_values_test throughout.


209-220: Remove duplicate code block.

Lines 216-220 duplicate the functionality of lines 213-215. This redundant try-except block should be removed as it performs the exact same operation.


227-229: Fix TypeError risk with math.isnan().

The code applies math.isnan() to a value whose type is not guaranteed to be a float, which risks a TypeError for non-numeric inputs. Fix this by restricting the NaN check to float values or by converting the value to float before checking.
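
A defensive sketch of the float-only check (the value and helper names are illustrative):

import math

def is_missing(value) -> bool:
    if value is None:
        return True
    if isinstance(value, float):
        return math.isnan(value)
    return False  # ints and other concrete values are treated as present

print(is_missing(3), is_missing(float("nan")), is_missing(None))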

🧹 Nitpick comments (1)
icu_benchmarks/run_utils.py (1)

118-180: Well-structured function for prediction aggregation.

The append_predictions_foldwise function is well-documented and implements a systematic approach to aggregate prediction files across folds and iterations. The ID offsetting mechanism prevents clashes, and the fold-by-fold processing ensures consistent ordering.

Consider using enumerate() for the counter as suggested by static analysis:

-    counter = 0
     # Process fold by fold across all iterations
-    for fold_name in sorted(all_folds):
+    for counter, fold_name in enumerate(sorted(all_folds)):
         for iteration in iterations:
             # ... existing code ...
-        counter += 1
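
Putting the offset mechanism and the enumerate() suggestion together, a standalone sketch (the fold layout and column names are assumptions, not the repository's file format):

import polars as pl

fold_frames = {
    "fold_0": pl.DataFrame({"id": [1, 2], "pred": [0.2, 0.7]}),
    "fold_1": pl.DataFrame({"id": [1, 2], "pred": [0.4, 0.9]}),
}

aggregated = []
for counter, fold_name in enumerate(sorted(fold_frames)):
    df = fold_frames[fold_name]
    # offset IDs per fold so stays from different folds cannot clash
    df = df.with_columns((pl.col("id") + counter * 1_000_000).alias("id"))
    aggregated.append(df)

print(pl.concat(aggregated))
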
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5b71c64 and 187cb2c.

📒 Files selected for processing (2)
  • icu_benchmarks/run.py (5 hunks)
  • icu_benchmarks/run_utils.py (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • icu_benchmarks/run.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
icu_benchmarks/run_utils.py (4)
icu_benchmarks/utils.py (1)
  • parse_dict (5-22)
icu_benchmarks/wandb_utils.py (1)
  • wandb_log (46-53)
icu_benchmarks/models/utils.py (2)
  • default (95-109)
  • name (159-160)
scripts/evaluate_results/aggregrate_experiment.py (1)
  • aggregate_results (7-85)
🪛 Ruff (0.12.2)
icu_benchmarks/run_utils.py

175-175: Use enumerate() for index variable counter in for loop

(SIM113)


223-223: Loop control variable repetition not used within loop body

(B007)


224-224: Loop control variable fold not used within loop body

Rename unused fold to _fold

(B007)

🔇 Additional comments (5)
icu_benchmarks/run_utils.py (5)

2-2: New imports look good.

The addition of math and random imports are appropriately used within the file for NaN checking and collision handling respectively.

Also applies to: 18-18


20-21: Imported utilities are properly used.

The parse_dict function from utils and wandb_log from wandb_utils are correctly imported and used in their respective contexts.


66-79: New command line arguments enhance flexibility.

The addition of --file_names, --explain_features, and --load_data_vars arguments provides good configurability for dataset handling and feature explanation functionality.
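
For orientation, a minimal sketch of the dict-style parsing these flags rely on (this parse_dict is a stand-in, not the repository's implementation, which may accept a different syntax):

import argparse
import json

def parse_dict(value: str) -> dict:
    # accept JSON-style strings such as '{"DYNAMIC": "dyn.parquet", "STATIC": "sta.parquet"}'
    return json.loads(value)

parser = argparse.ArgumentParser()
parser.add_argument("--file_names", type=parse_dict, default=None)
args = parser.parse_args(['--file_names', '{"DYNAMIC": "dyn.parquet"}'])
print(args.file_names["DYNAMIC"])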


96-96: Improved timestamp precision.

Adding microsecond precision (%f) to the timestamp format reduces the likelihood of directory name collisions, which is a good improvement.


182-182: Good addition of explain_features parameter.

Adding the explain_features parameter to control SHAP value aggregation provides appropriate conditional behavior for feature explanation functionality.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🔭 Outside diff range comments (2)
icu_benchmarks/run.py (2)

150-154: Include vars.gin in gin_config_files (first), instead of editing task file.

To ensure variables are available when parsing the task, include vars.gin (if requested and present) as the first file in gin_config_files. This keeps all parsing consolidated in one place and avoids double-parsing.

Apply this diff:

-        gin_config_files = (
-            [Path(f"configs/experiments/{args.experiment}.gin")]
-            if args.experiment
-            else [model_path, Path(f"configs/tasks/{task}.gin")]
-        )
+        # Build config file list. Ensure vars.gin (if present) is parsed first.
+        if args.experiment:
+            gin_config_files = [Path(f"configs/experiments/{args.experiment}.gin")]
+        else:
+            gin_config_files = [model_path, Path(f"configs/tasks/{task}.gin")]
+        if load_data_vars and (data_dir / "vars.gin").exists():
+            gin_config_files.insert(0, data_dir / "vars.gin")

137-143: Bug: Directory name uses fine_tune when branching on samples.

This branch checks args.samples but names the directory with args.fine_tune. This can mislabel runs and cause directory collisions.

Apply this diff:

-        log_dir /= f"samples_{args.fine_tune}"
+        log_dir /= f"samples_{args.samples}"
♻️ Duplicate comments (2)
icu_benchmarks/run.py (2)

160-175: Use lazy logging and correct the logged variable; raising is fixed.

Good: you now raise ValueError. Still pending: switch logging to lazy formatting and log the parsed file_names when applicable to avoid duplicate parsing and to improve performance.

Apply this diff:

-        logging.info(f"Attempting to bind dataset files: {args.file_names}, type: {type(args.file_names)}")
+        logging.info("Attempting to bind dataset files: %s, type: %s", args.file_names, type(args.file_names))
         if isinstance(args.file_names, dict):
-            logging.info(f"Will load data from {args.file_names}")
+            logging.info("Will load data from %s", args.file_names)
             gin.bind_parameter("preprocess.file_names", args.file_names)
         elif isinstance(args.file_names, str):
             file_names = parse_dict(args.file_names)
-            logging.info(f"Will load data from {args.file_names}")
+            logging.info("Will load data from %s", file_names)
             gin.bind_parameter("preprocess.file_names", file_names)

147-149: Model path continuation indentation looks corrected.

The previous E122 issue appears resolved with consistent visual indentation in the multi-line expression.

🧹 Nitpick comments (2)
icu_benchmarks/run.py (2)

176-180: Cast Path objects to str before updating WandB config.

W&B config values should be JSON-serializable; Path objects may not serialize cleanly.

Apply this diff:

-        update_wandb_config({
-            "data_dir": data_dir.resolve(),
-            "task": task,
-            "run_dir": run_dir.resolve(),
-        })
+        update_wandb_config({
+            "data_dir": str(data_dir.resolve()),
+            "task": task,
+            "run_dir": str(run_dir.resolve()),
+        })

232-233: Persist fold-wise predictions returned by append_predictions_foldwise.

The function returns a Polars DataFrame; currently it’s discarded. Save it to disk to make the aggregation usable downstream.

Apply this diff:

-        append_predictions_foldwise(run_dir, "pred_indicators.csv")
+        pred_df = append_predictions_foldwise(run_dir, "pred_indicators.csv")
+        pred_df.write_parquet(run_dir / "pred_indicators_aggregated.parquet")

Add this import at the top of the file if Polars isn’t already imported elsewhere:

import polars as pl  # at top-level
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 187cb2c and aac4d18.

📒 Files selected for processing (1)
  • icu_benchmarks/run.py (5 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
icu_benchmarks/run.py (5)
icu_benchmarks/run_utils.py (6)
  • get_config_files (343-367)
  • append_predictions_foldwise (118-180)
  • aggregate_results (182-260)
  • aggregate_results (101-176)
  • build_parser (21-63)
  • name_datasets (179-181)
icu_benchmarks/utils.py (1)
  • parse_dict (5-22)
icu_benchmarks/wandb_utils.py (1)
  • update_wandb_config (15-23)
scripts/evaluate_results/aggregrate_experiment.py (1)
  • aggregate_results (7-85)
icu_benchmarks/models/train.py (1)
  • train_common (34-204)
🔇 Additional comments (1)
icu_benchmarks/run.py (1)

23-23: CLI parser flags are correctly defined

I’ve confirmed in icu_benchmarks/run_utils.py’s build_parser that all three flags are present with the intended types and defaults:

  • --file_names (line 67) uses type=parse_dict
  • --explain_features (line 73) has default=False, action=BOA
  • --load_data_vars (line 75) has default=False, action=BOA

No changes required.

Comment on lines +70 to +79
    if load_data_vars:
        logging.info(f"Loading variables from {task} from {data_dir} configuration")
        if (data_dir / "vars.gin").exists():
            # Open the task config file in append mode and add a line
            with open(f"configs/tasks/{task}.gin", "a") as config_file:
                config_file.write("\n# Added automatically by run.py\n")
                config_file.write(f"# Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
                config_file.write(f'include "{data_dir}/vars.gin"\n')
        else:
            logging.warning(f"No vars.gin file found in {data_dir}. Please ensure the file exists.")

🛠️ Refactor suggestion

Do not mutate repo configs at runtime; parse vars.gin instead (race condition and drift risk).

Appending an include to configs/tasks/{task}.gin is unsafe:

  • Concurrent runs can corrupt the file.
  • Repeated runs will keep appending, causing duplicates.
  • It pollutes the repo/installed package state.

Parse vars.gin directly instead of editing the task config file.

Apply this diff:

-    if load_data_vars:
-        logging.info(f"Loading variables from {task} from {data_dir} configuration")
-        if (data_dir / "vars.gin").exists():
-            # Open the task config file in append mode and add a line
-            with open(f"configs/tasks/{task}.gin", "a") as config_file:
-                config_file.write("\n# Added automatically by run.py\n")
-                config_file.write(f"# Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
-                config_file.write(f'include "{data_dir}/vars.gin"\n')
-        else:
-            logging.warning(f"No vars.gin file found in {data_dir}. Please ensure the file exists.")
+    if load_data_vars:
+        vars_path = data_dir / "vars.gin"
+        logging.info("Loading variables for task %s from %s configuration", task, data_dir)
+        if vars_path.exists():
+            gin.parse_config_file(vars_path)
+            logging.info("Loaded %s", vars_path)
+        else:
+            logging.warning("No vars.gin file found in %s. Please ensure the file exists.", data_dir)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    if load_data_vars:
-        logging.info(f"Loading variables from {task} from {data_dir} configuration")
-        if (data_dir / "vars.gin").exists():
-            # Open the task config file in append mode and add a line
-            with open(f"configs/tasks/{task}.gin", "a") as config_file:
-                config_file.write("\n# Added automatically by run.py\n")
-                config_file.write(f"# Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
-                config_file.write(f'include "{data_dir}/vars.gin"\n')
-        else:
-            logging.warning(f"No vars.gin file found in {data_dir}. Please ensure the file exists.")
+    if load_data_vars:
+        vars_path = data_dir / "vars.gin"
+        logging.info("Loading variables for task %s from %s configuration", task, data_dir)
+        if vars_path.exists():
+            gin.parse_config_file(vars_path)
+            logging.info("Loaded %s", vars_path)
+        else:
+            logging.warning("No vars.gin file found in %s. Please ensure the file exists.", data_dir)

                config_file.write(f'include "{data_dir}/vars.gin"\n')
        else:
            logging.warning(f"No vars.gin file found in {data_dir}. Please ensure the file exists.")
    gin.parse_config_file(f"configs/tasks/{task}.gin")

⚠️ Potential issue

Avoid parsing task config before branching; this double-parses and interferes with load_weights.

You parse configs/tasks/{task}.gin here, and then again later in the normal training path. When evaluating/fine-tuning (load_weights=True), this early parse can conflict with the config loaded from source_dir/train_config.gin.

Remove the unconditional parse here and handle task/var configs only inside the normal training branch.

Apply this diff:

-    gin.parse_config_file(f"configs/tasks/{task}.gin")
+    # NOTE: Task config is parsed in the normal training branch below.
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
gin.parse_config_file(f"configs/tasks/{task}.gin")
# NOTE: Task config is parsed in the normal training branch below.
🤖 Prompt for AI Agents
In icu_benchmarks/run.py around line 80, the unconditional
gin.parse_config_file(f"configs/tasks/{task}.gin") is parsed before branching
which causes the task config to be double-parsed and interferes with
load_weights when evaluating/fine-tuning; remove this unconditional parse and
instead parse task and variable configs only inside the normal training branch
(where train_config.gin and task-specific configs are meant to be applied),
ensuring the evaluation/fine-tune path does not re-parse the task config so
load_weights can load source_dir/train_config.gin without conflict.

    execution_time = datetime.now() - start_time
    log_full_line(f"DURATION: {execution_time}", level=logging.INFO, char="")
    try:
        aggregate_results(run_dir, execution_time)

⚠️ Potential issue

Pass explain_features into aggregate_results to write SHAP aggregates.

aggregate_results has an explain_features flag; without passing it through, SHAP values won’t be aggregated even if generated.

Apply this diff:

-        aggregate_results(run_dir, execution_time)
+        aggregate_results(run_dir, execution_time, explain_features=args.explain_features)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
aggregate_results(run_dir, execution_time)
- aggregate_results(run_dir, execution_time)
+ aggregate_results(run_dir, execution_time, explain_features=args.explain_features)
🤖 Prompt for AI Agents
In icu_benchmarks/run.py around line 231, the call to aggregate_results(run_dir,
execution_time) omits the explain_features argument so SHAP aggregates are never
written; update the call to pass the explain_features boolean (e.g.
aggregate_results(run_dir, execution_time, explain_features)) ensuring the
explain_features variable is in scope and matches the aggregate_results
parameter order.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (6)
icu_benchmarks/data/loader.py (4)

53-54: Duplicate comment: Static dataset mis-alignment risk exists.

The sorting of row_indicators without corresponding sorting of features_df creates the same alignment risk mentioned in previous reviews. This issue persists in the static dataset path.


75-75: Duplicate comment: Commented code affects feature filtering.

The feature filtering for GROUP and SEQUENCE columns remains commented out, which means these columns will be included in the feature names returned to downstream components. This issue was flagged in previous reviews.


182-199: Duplicate comment: Avoid mutating self.row_indicators in accessor method.

The method still mutates self.row_indicators directly on line 198, which raises the same thread-safety and reproducibility concerns flagged in previous reviews. The suggested fix of using a local copy should be applied.
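
For reference, a minimal sketch of the local-copy approach (the accessor name and surrounding logic are assumptions; only the copy-instead-of-mutate pattern matters here):

def get_data_and_labels(self):
    # Work on a local copy so the accessor never mutates dataset state
    # (avoids thread-safety and reproducibility issues on repeated calls).
    row_inds = self.row_indicators.clone()       # Polars DataFrame copy
    data = self.features_df.to_numpy()[:, 1:]    # drop the group column, as in the original
    return data, row_inds.to_numpy()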


46-47: Fix data alignment issue between features_df and row_indicators.

The current implementation sorts row_indicators after already sorting features_df by the same columns. However, this creates a potential misalignment since features_df has the SEQUENCE column dropped on Line 45, which could affect the sorting order.

Apply this diff to ensure both DataFrames remain aligned:

 self.features_df = data[split][DataSegment.features]
 self.features_df = self.features_df.sort([self.vars["GROUP"], self.vars["SEQUENCE"]])
-self.features_df = self.features_df.drop(self.vars["SEQUENCE"])
-self.row_indicators = self.row_indicators.sort([self.vars["GROUP"], self.vars["SEQUENCE"]])
+self.row_indicators = self.row_indicators.sort([self.vars["GROUP"], self.vars["SEQUENCE"]])
+self.features_df = self.features_df.drop(self.vars["SEQUENCE"])
icu_benchmarks/models/wrappers.py (2)

540-544: Error handling removal needs justification.

The try/catch block for _save_model_outputs has been commented out, which means errors will now propagate and potentially crash the test step. This was flagged in previous reviews.


574-615: Complex conditional logic in updated log_metrics method.

This was flagged in previous reviews for lacking comprehensive unit tests. The method now has two distinct branches based on whether pred_indicators is None, with complex shape matching logic for the non-None case.

🧹 Nitpick comments (2)
icu_benchmarks/data/loader.py (1)

97-104: Improve conditional logic for SEQUENCE handling.

The nested if-else structure could be simplified and made more readable. The logic correctly handles both temporal and non-temporal SEQUENCE columns, but the structure is unnecessarily complex.

Apply this diff to simplify the conditional logic:

-if "SEQUENCE" in self.vars:
-    if self.row_indicators[self.vars["SEQUENCE"]].dtype.is_temporal():
-        self.row_indicators = self.row_indicators.with_columns(pl.col(self.vars["SEQUENCE"]).dt.total_hours())
-    else:
-        self.row_indicators = self.row_indicators.with_columns(pl.col(self.vars["SEQUENCE"]))
-else:
-    logging.info("Using static dataset")
-    self.row_indicators = self.grouping_df[self.vars["GROUP"]].to_frame(self.vars["GROUP"])
+if "SEQUENCE" in self.vars:
+    sequence_col = pl.col(self.vars["SEQUENCE"])
+    if self.row_indicators[self.vars["SEQUENCE"]].dtype.is_temporal():
+        sequence_col = sequence_col.dt.total_hours()
+    self.row_indicators = self.row_indicators.with_columns(sequence_col)
+else:
+    logging.info("Using static dataset")
+    self.row_indicators = self.grouping_df[self.vars["GROUP"]].to_frame(self.vars["GROUP"])
icu_benchmarks/models/wrappers.py (1)

169-169: Remove unnecessary f-string prefix.

The f-string has no placeholders and should use a regular string.

-            logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / f'row_indicators.csv'}")
+            logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / 'row_indicators.csv'}")
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between aac4d18 and 5e93636.

📒 Files selected for processing (2)
  • icu_benchmarks/data/loader.py (7 hunks)
  • icu_benchmarks/models/wrappers.py (11 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (2)
icu_benchmarks/data/loader.py (1)
icu_benchmarks/data/constants.py (1)
  • DataSegment (7-11)
icu_benchmarks/models/wrappers.py (3)
icu_benchmarks/models/utils.py (2)
  • log_single_metric_to_file (290-315)
  • name (159-160)
icu_benchmarks/models/ml_models/xgboost.py (5)
  • _explain_model (76-82)
  • fit_model (41-62)
  • fit_model (121-154)
  • predict (29-39)
  • predict (109-119)
icu_benchmarks/models/custom_metrics.py (1)
  • confusion_matrix (141-150)
🪛 Ruff (0.12.2)
icu_benchmarks/models/wrappers.py

169-169: f-string without any placeholders

Remove extraneous f prefix

(F541)

🔇 Additional comments (18)
icu_benchmarks/data/loader.py (5)

36-39: SEQUENCE validation properly gated by feature availability.

The conditional validation ensures that SEQUENCE type checking only occurs when SEQUENCE is both defined in vars and present in the data columns. This prevents unnecessary errors when SEQUENCE is not needed.


132-132: Extract row indicators properly for current stay.

The extraction logic correctly filters row indicators for the current stay_id, ensuring proper alignment with the window data.


143-145: Row indicator padding implementation is correct.

The padding logic properly handles row indicators by extending them with pad_value and converting to float32 for consistency with other tensors.


194-195: Data processing logic is correct.

The conversion to NumPy and removal of the first column (group column) is appropriately handled. The logic correctly prepares the feature data for ML models.


154-155: PredictionPolarsDataset now returns (data, labels, pad_mask, row_inds) — callers verified

Short summary: I inspected the repo and confirmed internal callers support the 4-tuple return.

Relevant locations:

  • icu_benchmarks/data/loader.py — PredictionPolarsDataset.__getitem__: returns 4-tuple at line 155:
    return from_numpy(data), from_numpy(labels), from_numpy(pad_mask), from_numpy(row_inds)
  • icu_benchmarks/models/wrappers.py — step_fn explicitly handles both shapes:
    • len(element) == 3 branch (line ~376) -> data, labels, indicators
    • len(element) == 4 branch (line ~385) -> data, labels, mask, indicators
  • icu_benchmarks/data/loader.py — PredictionPandasDataset.__getitem__ still returns 3-tuple (line ~327); wrappers' 3-tuple branch covers this.
  • Imputation-related dataset __getitem__ implementations already return 4-tuples where applicable.

Optional suggestion: update PredictionPolarsDataset.__getitem__ type annotation/docstring to reflect the 4-tuple return for clarity (not required for correctness).

icu_benchmarks/models/wrappers.py (13)

20-20: Verify import addition aligns with actual usage.

The log_single_metric_to_file import has been added, which aligns with its usage in the new log_curves methods in both ML and DL wrappers.


66-67: LGTM: Useful logging enhancement.

The addition of column count logging in set_trained_columns provides helpful debugging information.


109-111: Simple and effective feature toggle method.

The set_explain_features method provides a clean interface to toggle explainability features.


375-393: Enhanced step function correctly handles indicators parameter.

The extension to support 3 or 4 element batches (with indicators and optional mask) is well implemented and maintains backward compatibility.


404-415: Good extraction of prediction data for output saving.

The code properly extracts indicators, predictions, and targets using the mask for valid indices, and correctly applies softmax transformation.


427-429: Calling _save_model_outputs in the step function.

This call correctly integrates the new output saving functionality into the training loop.


499-501: Conditional explainer computation added.

The explainer is now properly gated behind the explain_features flag, which helps avoid expensive computations during hyperparameter tuning.


510-510: Method signature updated to include pred_indicators.

The log_metrics call now includes the row_indicators parameter to support the enhanced logging capabilities.


545-547: Explainer logic updated correctly.

The explainer computation is now properly conditional on explain_features flag and stores results in explainer_values_test.


547-547: New log_curves method call added.

This integrates the curve logging functionality into the test workflow.


550-554: Updated log_metrics calls with pred_indicators parameter.

Both MPS and standard paths now correctly pass the pred_indicators parameter to enable enhanced metric logging.


562-573: New log_curves method implementation.

The method correctly handles metrics that return tuples (like precision-recall curves) and saves them to CSV files using the imported log_single_metric_to_file utility.
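
For reference, a standalone sketch of the same idea using scikit-learn and pandas directly (log_single_metric_to_file is assumed to perform a comparable CSV write):

import pandas as pd
from sklearn.metrics import precision_recall_curve

def save_pr_curve(y_true, y_score, path: str) -> None:
    # Curve-style metrics return a tuple of arrays; persist them as named columns.
    precision, recall, thresholds = precision_recall_curve(y_true, y_score)
    # precision/recall have one more element than thresholds, so trim to align.
    pd.DataFrame(
        {"precision": precision[:-1], "recall": recall[:-1], "threshold": thresholds}
    ).to_csv(path, index=False)

# Example: save_pr_curve([0, 1, 1, 0], [0.2, 0.9, 0.6, 0.3], "test_pr_curve.csv")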


616-622: Commented-out _explain_model method.

This method definition is commented out, but that does not cause the runtime error mentioned in previous reviews, since the actual _explain_model implementations live in the specific model classes (as seen in the relevant code snippets).

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 10

🧹 Nitpick comments (14)
experiments/charhpc_wandb_sweep_gpu.sh (4)

3-3: Verify GPU request flags for your Slurm version/cluster policy.

Some clusters require --gres=gpu:1 (and matching srun flags) instead of --gpus=1. Using the wrong flag can yield zero GPUs at runtime.

Suggested alternative if your site uses GRES:

-#SBATCH --partition=gpu # -p
-#SBATCH --gpus=1
+#SBATCH --partition=gpu # -p
+#SBATCH --gres=gpu:1

If you keep --gpus=1, prefer matching it in the srun invocation (see later comment) to propagate GPU binding to the job step.

Also applies to: 7-7


1-1: Minor: Prefer env-based shebang for portability.

This avoids hard-coding /bin path differences across clusters.

-#!/bin/bash
+#!/usr/bin/env bash

9-9: Enable strict bash mode for safer failures.

Prevents silent errors and unbound variable usage from derailing long-running jobs.

 #SBATCH --time=48:00:00
+
+set -euo pipefail
+IFS=$'\n\t'

10-13: Conda init is redundant; use one method consistently.

You’re both sourcing conda.sh and using the shell hook. One is enough; keeping both can be brittle across sites.

Option A — rely on conda.sh (common on HPC images):

-source /etc/profile.d/conda.sh
-
-eval "$(conda shell.bash hook)"
-conda activate yaib_req_pl
+source /etc/profile.d/conda.sh
+conda activate yaib_req_pl

Option B — rely on the hook (when conda is already on PATH):

-source /etc/profile.d/conda.sh
-
-eval "$(conda shell.bash hook)"
-conda activate yaib_req_pl
+eval "$(conda shell.bash hook)"
+conda activate yaib_req_pl

Optional: fail fast if the env is missing.

+conda env list | grep -q '^yaib_req_pl\s' || { echo "Conda env 'yaib_req_pl' not found"; exit 3; }
icu_benchmarks/models/ml_models/imblearn.py (8)

1-2: Prefer a module-scoped logger over the root logger.

Using the root logger (logging.info, logging.warning) makes it harder to control log levels per module. Define a module logger and use it.

Apply:

 import logging
+logger = logging.getLogger(__name__)

And update usages later in this file:

-logging.info(f"XGBoost model trained successfully. Validation loss: {val_loss:.4f}")
+logger.info(f"XGBoost model trained successfully. Validation loss: {val_loss:.4f}")

-logging.warning(f"XGBoost failed: {e}")
+logger.warning(f"XGBoost failed: {e}")

7-7: Remove unused import (Ruff F401).

DecisionTreeClassifier isn’t used in this file.

-from sklearn.tree import DecisionTreeClassifier

8-8: Clean up duplicate/unused imports (Ruff F401/F811).

RUSBoostClassifier and BalancedRandomForestClassifier are already imported on Line 3 and are unused here. Only EasyEnsembleClassifier is needed.

Within this hunk:

-from imblearn.ensemble import RUSBoostClassifier, BalancedRandomForestClassifier, EasyEnsembleClassifier
+from imblearn.ensemble import EasyEnsembleClassifier

Also update the earlier import to drop the unused RUSBoostClassifier:

-from imblearn.ensemble import BalancedRandomForestClassifier, RUSBoostClassifier
+from imblearn.ensemble import BalancedRandomForestClassifier

21-36: Base estimator defined as a class attribute may surprise; consider constructing it per instance and making it gin-configurable.

  • Defining individual_model as a class attribute creates a single estimator object shared by all instances. While scikit-learn meta-estimators clone base estimators, sharing a mutable instance at class scope is a common footgun and reduces configurability.
  • The current hyperparameters are very heavy for an ensemble wrapper (n_estimators=5000) and combined with EasyEnsembleClassifier can explode training time/memory. Also, with EasyEnsemble balancing, scale_pos_weight=30 likely double-weights the minority class.

Suggested refactor: build the XGB estimator in __init__ and allow passing hyperparameters via gin.

-@gin.configurable
-class XGBEnsembleClassifier(MLWrapper):
+@gin.configurable
+class XGBEnsembleClassifier(MLWrapper):
     _supported_run_modes = [RunMode.classification]
-    individual_model = xgb.XGBClassifier(
-        learning_rate=0.1,
-        n_estimators=5000,
-        max_depth=10,
-        scale_pos_weight=30,
-        min_child_weight=1,
-        max_delta_step=3,
-        colsample_bytree=0.25,
-        gamma=0.9,
-        reg_lambda=0.1,
-        reg_alpha=100,
-        random_state=42,
-        eval_metric='logloss'
-    )
-    def __init__(self, *args, **kwargs):
-        self.model = self.set_model_args(EasyEnsembleClassifier, *args, **kwargs, estimator=self.individual_model)
+    def __init__(self, *args, xgb_params=None, **kwargs):
+        xgb_params = xgb_params or {
+            "learning_rate": 0.1,
+            "n_estimators": 1000,        # Lower by default; tune via gin
+            "max_depth": 10,
+            "min_child_weight": 1,
+            "max_delta_step": 3,
+            "colsample_bytree": 0.25,
+            "gamma": 0.9,
+            "reg_lambda": 0.1,
+            "reg_alpha": 100,
+            "random_state": 42,
+            "eval_metric": "logloss",
+            "verbosity": 0,
+            # Consider dropping scale_pos_weight when using EasyEnsemble
+            # "scale_pos_weight": 30,
+        }
+        individual_model = xgb.XGBClassifier(**xgb_params)
+        self.model = self.set_model_args(EasyEnsembleClassifier, *args, **kwargs, estimator=individual_model)
         super().__init__(*args, **kwargs)

Verification checklist:

  • Confirm we don’t rely on class-level access to individual_model elsewhere.
  • Validate whether scale_pos_weight is still desired when using EasyEnsembleClassifier (it already balances the classes). If you keep it, document the intended effect.

38-39: Consider wiring through EasyEnsembleClassifier’s own random_state and n_jobs for reproducibility/throughput.

EasyEnsembleClassifier has random_state and n_jobs. Using them helps ensure deterministic behavior and better CPU utilization.

-        self.model = self.set_model_args(EasyEnsembleClassifier, *args, **kwargs, estimator=self.individual_model)
+        self.model = self.set_model_args(
+            EasyEnsembleClassifier, *args, **kwargs,
+            estimator=self.individual_model,
+            random_state=42,
+            n_jobs=-1,
+        )

23-36: Training-time/performance caveat with current hyperparameters.

  • n_estimators=5000 inside an EasyEnsemble can be prohibitively slow and memory-intensive.
  • With balancing done by EasyEnsemble, scale_pos_weight=30 likely over-penalizes negatives. Consider removing or tuning it based on post-resampling class ratios.
  • Consider adding tree_method="hist" (CPU) and n_jobs=-1 to speed up XGBoost training.

If helpful, I can propose a gin-tunable grid with sensible defaults for n_estimators and scale_pos_weight. Would you like me to open a follow-up PR for that?


7-8: Remove unused imports flagged by Ruff (F401).

  • Line 7: DecisionTreeClassifier is unused.
  • Line 42: Inner import xgboost as xgb is unused.

Apply:

-        import xgboost as xgb

(Line 7 fix shown in another comment.)

Also applies to: 41-42


21-22: Naming alignment and docstring suggestion.

The class name XGBEnsembleClassifier is clear, but a short class docstring would help future readers understand that this is an EasyEnsemble over XGBoost base learners and why this is preferred over RUSBoost in this project.

Example:

class XGBEnsembleClassifier(MLWrapper):
    """Imbalanced-class classifier using imblearn's EasyEnsemble with XGBClassifier base learners.

    Rationale: undersample majority per ensemble member; leverage XGBoost capacity;
    tune XGB via gin; prefer over RUSBoost due to <project-specific reasons>.
    """
configs/prediction_models/LGBMClassifier.gin (2)

34-37: DART ranges are quite aggressive; consider narrowing to reduce instability.

High drop_rate/skip_drop can make convergence noisy. Typical starting ranges are tighter.

Apply this diff if you observe unstable validation curves with DART:

-model/hyperparameter.drop_rate = (0.1, 0.5)
+model/hyperparameter.drop_rate = (0.05, 0.2)
@@
-model/hyperparameter.skip_drop = (0.1, 0.9)
+model/hyperparameter.skip_drop = (0.0, 0.5)

16-17: Sanity-check upper bounds for min_child_samples and min_child_weight.

Including very large min_child_samples (e.g., 1000) can eliminate splits on small/medium datasets. Keep if your datasets are large; otherwise, narrow to something like (5, 200) and (1e-3, 10.0, "log") respectively.

Apply if your typical training datasets are <100k rows:

-model/hyperparameter.min_child_samples = (10, 20, 50, 100, 500, 1000)
+model/hyperparameter.min_child_samples = (5, 200)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 5e93636 and 8af76b1.

📒 Files selected for processing (7)
  • configs/prediction_models/LGBMClassifier.gin (1 hunks)
  • docs/source/conf.py (1 hunks)
  • experiments/charhpc_wandb_sweep_cpu.sh (1 hunks)
  • experiments/charhpc_wandb_sweep_gpu.sh (1 hunks)
  • icu_benchmarks/models/ml_models/imblearn.py (2 hunks)
  • icu_benchmarks/models/ml_models/lgbm.py (0 hunks)
  • icu_benchmarks/wandb_utils.py (3 hunks)
💤 Files with no reviewable changes (1)
  • icu_benchmarks/models/ml_models/lgbm.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • docs/source/conf.py
  • experiments/charhpc_wandb_sweep_cpu.sh
  • icu_benchmarks/wandb_utils.py
🧰 Additional context used
🧬 Code graph analysis (1)
icu_benchmarks/models/ml_models/imblearn.py (3)
icu_benchmarks/constants.py (1)
  • RunMode (4-7)
icu_benchmarks/models/wrappers.py (5)
  • MLWrapper (444-652)
  • set_model_args (642-652)
  • fit_model (512-517)
  • fit (487-510)
  • fit (744-745)
icu_benchmarks/models/ml_models/xgboost.py (6)
  • XGBClassifier (21-87)
  • set_model_args (64-74)
  • fit_model (41-62)
  • fit_model (121-154)
  • XGBClassifier (21-74)
  • __init__ (25-27)
🪛 Ruff (0.12.2)
icu_benchmarks/models/ml_models/imblearn.py

7-7: sklearn.tree.DecisionTreeClassifier imported but unused

Remove unused import: sklearn.tree.DecisionTreeClassifier

(F401)


8-8: Redefinition of unused RUSBoostClassifier from line 3

Remove definition: RUSBoostClassifier

(F811)


8-8: imblearn.ensemble.RUSBoostClassifier imported but unused

Remove unused import: imblearn.ensemble.RUSBoostClassifier

(F401)


8-8: Redefinition of unused BalancedRandomForestClassifier from line 3

Remove definition: BalancedRandomForestClassifier

(F811)


42-42: xgboost imported but unused

Remove unused import: xgboost

(F401)

🔇 Additional comments (3)
icu_benchmarks/models/ml_models/imblearn.py (1)

9-9: LGTM: xgboost import is used below.

The module-level xgboost import is required for defining the base estimator.

configs/prediction_models/LGBMClassifier.gin (2)

9-11: LGTM: Explicitly scoping the tuner target.

Declaring model/hyperparameter.class_to_tune = @LGBMClassifier is clear and aligns with train_common.model = @LGBMClassifier.


13-17: Guard cross-parameter validity: num_leaves vs max_depth.

Some sampled combos (e.g., shallow max_depth with very large num_leaves) are wasteful or incoherent. If your tuner supports constraints, enforce num_leaves <= 2**max(1, max_depth) when max_depth > 0. Otherwise, consider capping num_leaves to 2^(max_depth) in post-processing before estimator instantiation.

If your tuner doesn’t support relational constraints, I can draft a small adapter to clamp num_leaves based on the sampled max_depth. Want me to open an issue and provide a PR?
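
As a starting point, the clamping adapter could be as small as the sketch below (the parameter-dict shape is an assumption about the tuner's output):

def clamp_num_leaves(params: dict) -> dict:
    """Clamp num_leaves to 2**max_depth when a positive depth limit is sampled (sketch)."""
    params = dict(params)  # do not mutate the tuner's record
    max_depth = params.get("max_depth", -1)
    if max_depth > 0 and "num_leaves" in params:
        params["num_leaves"] = min(params["num_leaves"], 2 ** max_depth)
    return params

# Example: max_depth=3 with num_leaves=256 is clamped to num_leaves=8.
print(clamp_num_leaves({"max_depth": 3, "num_leaves": 256}))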

Comment on lines +13 to +18
model/hyperparameter.n_estimators = (500, 2000, 5000, 10000, 100000)
model/hyperparameter.max_depth = (3, 5, 7, 10, 15)
model/hyperparameter.num_leaves = (8, 16, 31, 64, 128, 256, "log", 2)
model/hyperparameter.min_child_samples = (10, 20, 50, 100, 500, 1000)
model/hyperparameter.min_child_weight = (1e-3, 10.0, "log")


🛠️ Refactor suggestion

⚠️ Potential issue

Fix malformed num_leaves spec and consider simplifying n_estimators.

  • The num_leaves tuple mixes discrete values with "log", 2. This is likely to break your parser (or yield invalid values). Prefer either a discrete list or a proper range tuple.
  • Optional: Given early stopping is enabled below, a very large discrete set for n_estimators adds search noise/cost. A log-range or a single high ceiling with early stopping is more efficient.

Apply this diff to make both entries unambiguous:

-model/hyperparameter.n_estimators = (500, 2000, 5000, 10000, 100000)
+model/hyperparameter.n_estimators = (500, 20000, "log")
@@
-model/hyperparameter.num_leaves = (8, 16, 31, 64, 128, 256, "log", 2)
+model/hyperparameter.num_leaves = (8, 256, "log")

Notes:

  • If you prefer discrete leaves, use e.g. [15, 31, 63, 127, 255] instead.
  • Consider adding max_depth = [-1, 3, 5, 7, 10, 15] so the unconstrained depth case is explored.
📝 Committable suggestion


Suggested change
model/hyperparameter.n_estimators = (500, 2000, 5000, 10000, 100000)
model/hyperparameter.max_depth = (3, 5, 7, 10, 15)
model/hyperparameter.num_leaves = (8, 16, 31, 64, 128, 256, "log", 2)
model/hyperparameter.min_child_samples = (10, 20, 50, 100, 500, 1000)
model/hyperparameter.min_child_weight = (1e-3, 10.0, "log")
# configs/prediction_models/LGBMClassifier.gin
model/hyperparameter.n_estimators = (500, 20000, "log")
model/hyperparameter.max_depth = (3, 5, 7, 10, 15)
model/hyperparameter.num_leaves = (8, 256, "log")
model/hyperparameter.min_child_samples = (10, 20, 50, 100, 500, 1000)
model/hyperparameter.min_child_weight = (1e-3, 10.0, "log")
🤖 Prompt for AI Agents
In configs/prediction_models/LGBMClassifier.gin around lines 13 to 18, the
num_leaves hyperparameter is malformed (mixes discrete values with `"log", 2`)
and n_estimators is an overly large discrete set; fix num_leaves to be a clear
discrete list (e.g. 15, 31, 63, 127, 255) or a valid range spec so the parser
won’t break, and simplify n_estimators to either a log-range (e.g. 100, 500,
1000, 5000) or a single high ceiling relying on early stopping to control
training; optionally add max_depth = (-1, 3, 5, 7, 10, 15) so the unconstrained
depth case is included.

Comment on lines +24 to +29
# Sampling parameters
model/hyperparameter.subsample = (0.4, 1.0)
model/hyperparameter.subsample_freq = (0, 1, 5, 10)
model/hyperparameter.colsample_bytree = (0.4, 1.0)
model/hyperparameter.colsample_bynode = (0.4, 1.0)


🛠️ Refactor suggestion

⚠️ Potential issue

Remove LightGBM/XGBoost alias duplication and an invalid parameter name.

  • subsample/subsample_freq are LightGBM aliases of bagging_fraction/bagging_freq. Keeping both creates conflicting grids and harder-to-debug overrides.
  • colsample_bytree is an alias of feature_fraction in LightGBM; again, duplicate knobs.
  • colsample_bynode is not a LightGBM parameter (XGBoost-ism). If you want per-node feature sampling, use feature_fraction_bynode.

Apply this diff to remove the aliases here and manage sampling via the canonical keys in the Performance block (Lines 42-47):

-# Sampling parameters
-model/hyperparameter.subsample = (0.4, 1.0)
-model/hyperparameter.subsample_freq = (0, 1, 5, 10)
-model/hyperparameter.colsample_bytree = (0.4, 1.0)
-model/hyperparameter.colsample_bynode = (0.4, 1.0)

Then, add the per-node variant where the canonical feature/bagging params are declared (see the suggested change on Lines 42-47).

📝 Committable suggestion


Suggested change
# Sampling parameters
model/hyperparameter.subsample = (0.4, 1.0)
model/hyperparameter.subsample_freq = (0, 1, 5, 10)
model/hyperparameter.colsample_bytree = (0.4, 1.0)
model/hyperparameter.colsample_bynode = (0.4, 1.0)
🤖 Prompt for AI Agents
In configs/prediction_models/LGBMClassifier.gin around lines 24 to 29, remove
the LightGBM/XGBoost alias and invalid parameters — delete
model/hyperparameter.subsample, model/hyperparameter.subsample_freq,
model/hyperparameter.colsample_bytree and model/hyperparameter.colsample_bynode
— and instead manage sampling via the canonical LightGBM keys in the Performance
block (lines 42-47); then add the per-node variant
model/hyperparameter.feature_fraction_bynode alongside feature_fraction and
bagging_fraction in that Performance block so per-node sampling is available
under the canonical LightGBM name.

Comment on lines +30 to +41
# Boosting parameters
model/hyperparameter.boosting_type = ["gbdt", "dart"]

# Advanced DART parameters (active when boosting_type="dart")
model/hyperparameter.drop_rate = (0.1, 0.5)
model/hyperparameter.max_drop = (10, 50)
model/hyperparameter.skip_drop = (0.1, 0.9)

# GOSS parameters (active when boosting_type="goss")
model/hyperparameter.top_rate = (0.1, 0.5)
model/hyperparameter.other_rate = (0.05, 0.2)


🛠️ Refactor suggestion

GOSS parameters present without enabling GOSS; make the search space consistent.

You’re tuning top_rate/other_rate for GOSS but boosting_type doesn’t include "goss". Either add "goss" to boosting_type or drop those knobs to avoid dead/ignored params.

Apply one of the following:

Option A — enable GOSS:

-model/hyperparameter.boosting_type = ["gbdt", "dart"]
+model/hyperparameter.boosting_type = ["gbdt", "dart", "goss"]

Option B — keep only GBDT/DART (remove GOSS-only params):

-# GOSS parameters (active when boosting_type="goss")
-model/hyperparameter.top_rate = (0.1, 0.5)
-model/hyperparameter.other_rate = (0.05, 0.2)

Note: When using GOSS, ensure top_rate + other_rate < 1. If your tuner can’t enforce relational constraints, narrow ranges to, e.g., top_rate=(0.1,0.3) and other_rate=(0.05,0.15).
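
A small validity check in the same post-processing spirit could rescale offending samples (sketch; assumes the sampled parameters arrive as a plain dict):

def enforce_goss_budget(params: dict, margin: float = 0.05) -> dict:
    """Rescale top_rate/other_rate so their sum stays below 1 when GOSS is sampled (sketch)."""
    params = dict(params)
    if params.get("boosting_type") == "goss" and {"top_rate", "other_rate"} <= params.keys():
        total = params["top_rate"] + params["other_rate"]
        if total >= 1.0:
            scale = (1.0 - margin) / total
            params["top_rate"] *= scale
            params["other_rate"] *= scale
    return params

# Example: top_rate=0.7, other_rate=0.5 is rescaled so the pair sums to 0.95.
print(enforce_goss_budget({"boosting_type": "goss", "top_rate": 0.7, "other_rate": 0.5}))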

📝 Committable suggestion


Suggested change (Option A, enable GOSS)
# Boosting parameters
model/hyperparameter.boosting_type = ["gbdt", "dart", "goss"]
# Advanced DART parameters (active when boosting_type="dart")
model/hyperparameter.drop_rate = (0.1, 0.5)
model/hyperparameter.max_drop = (10, 50)
model/hyperparameter.skip_drop = (0.1, 0.9)
# GOSS parameters (active when boosting_type="goss")
model/hyperparameter.top_rate = (0.1, 0.5)
model/hyperparameter.other_rate = (0.05, 0.2)

Suggested change (Option B, keep only GBDT/DART)
# Boosting parameters
model/hyperparameter.boosting_type = ["gbdt", "dart"]
# Advanced DART parameters (active when boosting_type="dart")
model/hyperparameter.drop_rate = (0.1, 0.5)
model/hyperparameter.max_drop = (10, 50)
model/hyperparameter.skip_drop = (0.1, 0.9)
🤖 Prompt for AI Agents
In configs/prediction_models/LGBMClassifier.gin around lines 30 to 41, the file
defines GOSS-specific hyperparameters (top_rate, other_rate) but the
boosting_type list does not include "goss", so those knobs will be ignored; fix
by either (A) add "goss" to model/hyperparameter.boosting_type so the GOSS
params become active and, if enabling GOSS, ensure top_rate + other_rate < 1 (or
tighten ranges, e.g., top_rate=(0.1,0.3) and other_rate=(0.05,0.15) if the tuner
can’t enforce relational constraints), or (B) remove the GOSS-only params
(top_rate and other_rate) to keep the search space consistent with only "gbdt"
and "dart".

Comment on lines +42 to +47
# Performance and stability
model/hyperparameter.feature_fraction = (0.4, 1.0)
model/hyperparameter.bagging_fraction = (0.4, 1.0)
model/hyperparameter.bagging_freq = (0, 1, 5, 10)
model/hyperparameter.min_split_gain = (1e-6, 1.0, "log")


🛠️ Refactor suggestion

Consolidate sampling on canonical LightGBM params and add by-node feature sampling.

Given aliases were removed above, keep tuning here via canonical names. If you want per-node feature sampling, add feature_fraction_bynode.

Apply this diff to extend the canonical block:

 model/hyperparameter.feature_fraction = (0.4, 1.0)
 model/hyperparameter.bagging_fraction = (0.4, 1.0)
 model/hyperparameter.bagging_freq = (0, 1, 5, 10)
 model/hyperparameter.min_split_gain = (1e-6, 1.0, "log")
+model/hyperparameter.feature_fraction_bynode = (0.4, 1.0)
📝 Committable suggestion


Suggested change
# Performance and stability
model/hyperparameter.feature_fraction = (0.4, 1.0)
model/hyperparameter.bagging_fraction = (0.4, 1.0)
model/hyperparameter.bagging_freq = (0, 1, 5, 10)
model/hyperparameter.min_split_gain = (1e-6, 1.0, "log")
# Performance and stability
model/hyperparameter.feature_fraction = (0.4, 1.0)
model/hyperparameter.bagging_fraction = (0.4, 1.0)
model/hyperparameter.bagging_freq = (0, 1, 5, 10)
model/hyperparameter.min_split_gain = (1e-6, 1.0, "log")
model/hyperparameter.feature_fraction_bynode = (0.4, 1.0)
🤖 Prompt for AI Agents
In configs/prediction_models/LGBMClassifier.gin around lines 42 to 47, the
tuning block uses LightGBM aliases but should use canonical parameter names and
include per-node feature sampling; replace/extend the current entries so they
use canonical names (feature_fraction, bagging_fraction, bagging_freq,
min_split_gain) and add a new feature_fraction_bynode hyperparameter with an
appropriate search range (e.g., (0.4, 1.0)) to enable by-node feature sampling;
ensure the value formats match the file's existing hyperparameter tuple
conventions.

Comment on lines +52 to +55
# Early stopping and validation
model/hyperparameter.early_stopping_rounds = 100
model/hyperparameter.eval_metric = ["binary_logloss", "auc", "binary_error"]


💡 Verification agent

🧩 Analysis chain

Confirm fit-time plumbing for early stopping and evaluation metric.

early_stopping_rounds and eval_metric are fit-time arguments in the LGBMClassifier sklearn API. If your trainer only instantiates the estimator and never passes these to fit(), these settings won’t take effect. Please confirm they’re forwarded at training time.

Run the following script to verify that your training pipeline passes these through to LightGBM’s fit() (or wraps them in callbacks):

Optional: For stability, consider dropping binary_error from the metric grid for early-stopping-driven tuning and rely on auc or binary_logloss.

Apply if desired:

-model/hyperparameter.eval_metric = ["binary_logloss", "auc", "binary_error"]
+model/hyperparameter.eval_metric = ["binary_logloss", "auc"]

🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Inspect how tuning params are materialized and passed to LGBM
rg -n -C3 -e 'early_stopping_rounds|eval_metric|class_to_tune|LGBMClassifier|fit\('

# Common places:
# - Tuning utilities that map gin "model/hyperparameter.*" into estimator kwargs
# - Trainer code that calls estimator.fit(..., eval_set=..., eval_metric=..., early_stopping_rounds=...)

Length of output: 45667


🏁 Script executed:

#!/bin/bash
# Display the beginning of the LGBM wrapper to inspect how fit() is implemented
sed -n '1,200p' icu_benchmarks/models/ml_models/lgbm.py

Length of output: 1753


🏁 Script executed:

#!/bin/bash
# Locate the MLWrapper class to inspect how hyperparameters are mapped
rg -C5 -n 'class MLWrapper' icu_benchmarks/models/wrappers.py

Length of output: 485


🏁 Script executed:

#!/bin/bash
# Inspect how model hyperparameters (including eval_metric, early_stopping_rounds) are applied during model instantiation
rg -C5 -n 'def set_model_args' icu_benchmarks/models/wrappers.py

Length of output: 654


🏁 Script executed:

#!/bin/bash
python3 - << 'EOF'
import lightgbm as lgb
import inspect
print(inspect.signature(lgb.LGBMClassifier.__init__))
EOF

Length of output: 132


🏁 Script executed:

#!/bin/bash
sed -n '630,730p' icu_benchmarks/models/wrappers.py

Length of output: 4091


Update LGBMWrapper.fit_model to wire through Gin-configured early stopping and eval metrics

Your current implementation in icu_benchmarks/models/ml_models/lgbm.py hard-codes:

  • an early-stopping callback using self.hparams.patience (not your Gin model/hyperparameter.early_stopping_rounds)
  • a generic log_evaluation callback (ignoring your Gin model/hyperparameter.eval_metric)

As a result, neither early_stopping_rounds nor eval_metric from your Gin config ever reach the LightGBM fit() call. You’ll need to refactor the fit_model method to read from self.hparams and pass those values into the model, for example:

--- a/icu_benchmarks/models/ml_models/lgbm.py
+++ b/icu_benchmarks/models/ml_models/lgbm.py
@@ class LGBMWrapper(MLWrapper):
-        # Early stopping and logging (hard-coded)
-        callbacks = [
-            lgbm.early_stopping(self.hparams.patience, verbose=True),
-            lgbm.log_evaluation(period=-1)
-        ]
-
-        self.model = self.model.fit(
-            train_data,
-            train_labels,
-            eval_set=(val_data, val_labels),
-            callbacks=callbacks,
-        )
+        # Wire through Gin hyperparameters for early stopping and evaluation metrics
+        callbacks = []
+        # Use Gin’s early_stopping_rounds if set
+        if getattr(self.hparams, "early_stopping_rounds", None):
+            callbacks.append(lgbm.early_stopping(self.hparams.early_stopping_rounds, verbose=True))
+
+        # Override the eval_metric on the model if configured
+        if getattr(self.hparams, "eval_metric", None):
+            # scikit-learn wrapper will pass this metric to the Booster
+            self.model.set_params(eval_metric=self.hparams.eval_metric)
+
+        self.model = self.model.fit(
+            train_data,
+            train_labels,
+            eval_set=(val_data, val_labels),
+            callbacks=callbacks,
+        )

This change ensures that:

  • model/hyperparameter.early_stopping_rounds drives your early-stopping callback
  • model/hyperparameter.eval_metric is passed to LightGBM and used for evaluation logging and stopping

References:
LightGBM scikit-learn API expects callbacks for early stopping and log evaluation; early_stopping_rounds as a fit arg is deprecated in favor of early_stopping() callback (blog.csdn.net).
The wrapper’s fit(...) signature supports eval_metric and custom callbacks for metric tracking (lightgbm.readthedocs.io).

📝 Committable suggestion


Suggested change
# Early stopping and validation
model/hyperparameter.early_stopping_rounds = 100
model/hyperparameter.eval_metric = ["binary_logloss", "auc", "binary_error"]
# Wire through Gin hyperparameters for early stopping and evaluation metrics
callbacks = []
# Use Gin’s early_stopping_rounds if set
if getattr(self.hparams, "early_stopping_rounds", None):
    callbacks.append(lgbm.early_stopping(self.hparams.early_stopping_rounds, verbose=True))

# Override the eval_metric on the model if configured
if getattr(self.hparams, "eval_metric", None):
    # scikit-learn wrapper will pass this metric to the Booster
    self.model.set_params(eval_metric=self.hparams.eval_metric)

self.model = self.model.fit(
    train_data,
    train_labels,
    eval_set=(val_data, val_labels),
    callbacks=callbacks,
)
🤖 Prompt for AI Agents
In icu_benchmarks/models/ml_models/lgbm.py around the fit_model method (where
callbacks and eval metrics are assembled), the code currently hardcodes an
early-stopping callback using self.hparams.patience and always uses a generic
log_evaluation callback, so Gin-configured
model/hyperparameter.early_stopping_rounds and model/hyperparameter.eval_metric
never reach LightGBM; update fit_model to read early_stopping_rounds and
eval_metric from self.hparams, construct callbacks using
lightgbm.callback.early_stopping(self.hparams.early_stopping_rounds) and a
log_evaluation configured appropriately (or omit if not desired), and pass
eval_metric=self.hparams.eval_metric into the LGBMClassifier.fit(...) call along
with the validation set (eval_set/eval_names) and the assembled callbacks so
LightGBM uses the Gin-provided values for evaluation and stopping.

Comment on lines +56 to +59
# Class imbalance handling
model/hyperparameter.is_unbalance = [True, False]
model/hyperparameter.scale_pos_weight = (0.1, 10.0, "log")


🛠️ Refactor suggestion

Avoid combining is_unbalance with scale_pos_weight; prefer one strategy.

LightGBM treats these as alternative approaches to class imbalance. Searching both simultaneously can produce ambiguous or counteracting settings. Prefer tuning scale_pos_weight and fixing is_unbalance=False, or vice versa.

Apply this diff to prefer scale_pos_weight only:

-# Class imbalance handling
-model/hyperparameter.is_unbalance = [True, False]
-model/hyperparameter.scale_pos_weight = (0.1, 10.0, "log")
+# Class imbalance handling (prefer explicit weighting)
+model/hyperparameter.is_unbalance = [False]
+model/hyperparameter.scale_pos_weight = (0.1, 10.0, "log")

If you want to compare strategies, introduce a higher-level toggle (e.g., imbalance_strategy in {"none","is_unbalance","scale_pos_weight"}) and conditionally set params accordingly to avoid invalid combos.
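
A sketch of such a toggle (the imbalance_strategy key and the place it would be resolved are assumptions about the tuning pipeline):

def resolve_imbalance_params(params: dict) -> dict:
    """Map a single imbalance_strategy choice onto LightGBM's mutually exclusive knobs (sketch)."""
    params = dict(params)
    strategy = params.pop("imbalance_strategy", "none")
    if strategy == "is_unbalance":
        params["is_unbalance"] = True
        params.pop("scale_pos_weight", None)
    elif strategy == "scale_pos_weight":
        params["is_unbalance"] = False  # keep the sampled scale_pos_weight as-is
    else:  # "none"
        params["is_unbalance"] = False
        params.pop("scale_pos_weight", None)
    return params

# Example: {"imbalance_strategy": "scale_pos_weight", "scale_pos_weight": 3.2}
# resolves to {"is_unbalance": False, "scale_pos_weight": 3.2}.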

📝 Committable suggestion


Suggested change
# Class imbalance handling
model/hyperparameter.is_unbalance = [True, False]
model/hyperparameter.scale_pos_weight = (0.1, 10.0, "log")
# Class imbalance handling (prefer explicit weighting)
model/hyperparameter.is_unbalance = [False]
model/hyperparameter.scale_pos_weight = (0.1, 10.0, "log")
🤖 Prompt for AI Agents
In configs/prediction_models/LGBMClassifier.gin around lines 56 to 59, the
current config exposes both model/hyperparameter.is_unbalance and
model/hyperparameter.scale_pos_weight which are mutually exclusive; update the
file to prefer scale_pos_weight only by fixing model/hyperparameter.is_unbalance
to False and keeping model/hyperparameter.scale_pos_weight as the tunable range
(0.1, 10.0, "log"); alternatively, if you need to compare strategies, introduce
a single high-level toggle (e.g., imbalance_strategy in
{"none","is_unbalance","scale_pos_weight"}) and conditionally set either
is_unbalance or scale_pos_weight based on that toggle to avoid invalid
combinations.

#SBATCH --partition=gpu # -p
#SBATCH --cpus-per-task=16 # -c
#SBATCH --mem=200gb
#SBATCH --output=logs/classification_%a_%j.log # %j is job id

⚠️ Potential issue

Fix SBATCH output path: logs/ directory may not exist and %a is unused.

  • If the logs/ directory does not pre-exist at submission time, Slurm will fail to open the output file and your job won’t start. Also, %a resolves to 0 for non-array jobs, which is likely unintended noise.

Apply one of the following:

Option A — drop the subdir and unused array token:

-#SBATCH --output=logs/classification_%a_%j.log # %j is job id
+#SBATCH --output=classification_%j.log # %j is job id

Option B — keep the subdir but remove %a, and ensure the directory exists before submission:

-#SBATCH --output=logs/classification_%a_%j.log # %j is job id
+#SBATCH --output=logs/classification_%j.log # %j is job id

And document/ensure: mkdir -p logs before running sbatch.

🤖 Prompt for AI Agents
In experiments/charhpc_wandb_sweep_gpu.sh around line 6, the SBATCH output path
uses a non-existent logs/ directory and the unused %a array token; change the
output directive to remove %a (e.g. --output=logs/classification_%j.log or
--output=classification_%j.log) and ensure the logs directory exists before
submission by running mkdir -p logs (or drop the subdir and write to the current
dir if you don’t want to create directories).


eval "$(conda shell.bash hook)"
conda activate yaib_req_pl
wandb agent --count 1 cassandra_hpi/cassandra/"$1"

🛠️ Refactor suggestion

⚠️ Potential issue

Validate required arg and run via srun to inherit CPU/GPU binding.

  • Without an argument, the W&B agent path becomes invalid and the job will fail.
  • Launching the agent with srun ensures CUDA-visible device assignment and CPU binding honor the job allocation across Slurm versions.
-wandb agent --count 1 cassandra_hpi/cassandra/"$1"
+[[ -n "${1:-}" ]] || { echo "Usage: $0 <SWEEP_ID>"; exit 2; }
+# If cluster uses --gpus=1, mirror that here; if it uses GRES, use --gres=gpu:1 instead.
+# CPU binding can be adjusted per site (e.g., --cpu-bind=cores).
+srun --cpus-per-task="${SLURM_CPUS_PER_TASK:-16}" --gpus="${SLURM_GPUS:-1}" \
+  wandb agent --count 1 "cassandra_hpi/cassandra/${1}"

Optional, for better provenance in logs:

+echo "Job: ${SLURM_JOB_ID:-na} | Node: ${SLURM_NODELIST:-na} | GPUs: ${SLURM_GPUS:-na} | CPUs/Task: ${SLURM_CPUS_PER_TASK:-na}"
+command -v nvidia-smi >/dev/null 2>&1 && nvidia-smi || true
+conda info && conda list | grep -E '^(pytorch|torch|cuda|cudnn|wandb)\s' || true
📝 Committable suggestion


Suggested change
wandb agent --count 1 cassandra_hpi/cassandra/"$1"
[[ -n "${1:-}" ]] || { echo "Usage: $0 <SWEEP_ID>"; exit 2; }
# If cluster uses --gpus=1, mirror that here; if it uses GRES, use --gres=gpu:1 instead.
# CPU binding can be adjusted per site (e.g., --cpu-bind=cores).
srun --cpus-per-task="${SLURM_CPUS_PER_TASK:-16}" --gpus="${SLURM_GPUS:-1}" \
wandb agent --count 1 "cassandra_hpi/cassandra/${1}"
# Optional provenance/debug logging
echo "Job: ${SLURM_JOB_ID:-na} | Node: ${SLURM_NODELIST:-na} | GPUs: ${SLURM_GPUS:-na} | CPUs/Task: ${SLURM_CPUS_PER_TASK:-na}"
command -v nvidia-smi >/dev/null 2>&1 && nvidia-smi || true
conda info && conda list | grep -E '^(pytorch|torch|cuda|cudnn|wandb)\s' || true

Comment on lines +21 to 39
class XGBEnsembleClassifier(MLWrapper):
    _supported_run_modes = [RunMode.classification]

    individual_model = xgb.XGBClassifier(
        learning_rate=0.1,
        n_estimators=5000,
        max_depth=10,
        scale_pos_weight=30,
        min_child_weight=1,
        max_delta_step=3,
        colsample_bytree=0.25,
        gamma=0.9,
        reg_lambda=0.1,
        reg_alpha=100,
        random_state=42,
        eval_metric='logloss'
    )
    def __init__(self, *args, **kwargs):
        self.model = self.set_model_args(RUSBoostClassifier, *args, **kwargs)
        self.model = self.set_model_args(EasyEnsembleClassifier, *args, **kwargs, estimator=self.individual_model)
        super().__init__(*args, **kwargs)

💡 Verification agent

🧩 Analysis chain

Check for references to removed/renamed classifiers.

If RUSBClassifier was used in configs/scripts, renaming to XGBEnsembleClassifier may break gin configs or imports.

Run the following to find lingering references:


🏁 Script executed:

#!/bin/bash
# Find references to the old class and to RUSBoost usage in the repo
rg -nP -C2 '\bRUSBClassifier\b|\bRUSBoostClassifier\b'

Length of output: 3543


Update Remaining RUSBClassifier References

The renaming of RUSBClassifier to XGBEnsembleClassifier has left several broken imports and config entries. Please update or remove all occurrences to ensure the new class is correctly wired up.

• icu_benchmarks/models/__init__.py
– Replace the RUSBClassifier import and its entry in __all__ (around line 7 and the list at lines 34–62) with XGBEnsembleClassifier.

• configs/prediction_models/RUSBClassifier.gin
– Rename this file (and its references) to match XGBEnsembleClassifier.gin.
– Update train_common.model = @RUSBClassifier and model/hyperparameter.class_to_tune = @RUSBClassifier to use @XGBEnsembleClassifier.

• experiments/benchmark_*.yml
– Although currently commented out, there are lines referencing RUSBClassifier in several experiment YAMLs. Either remove or update these for consistency.

Addressing these spots is required to prevent import errors and broken gin configurations.

Comment on lines +41 to +57
    def fit_model(self, train_data, train_labels, val_data, val_labels):
        import xgboost as xgb
        from sklearn.metrics import log_loss

        # Try XGBoost first with provided config
        try:
            self.model.fit(train_data, train_labels,)

            val_pred_proba = self.model.predict_proba(val_data)
            val_loss = log_loss(val_labels, val_pred_proba)
            logging.info(f"XGBoost model trained successfully. Validation loss: {val_loss:.4f}")
            return val_loss

        except Exception as e:
            logging.warning(f"XGBoost failed: {e}")

        return val_loss

⚠️ Potential issue

Bug: val_loss may be referenced before assignment on failure; also prefer self.loss/self.predict and drop unused imports.

If self.model.fit raises, the function falls through to return val_loss without it being defined, causing UnboundLocalError. In-function imports are unnecessary, and using self.loss/self.predict preserves consistency with the wrapper’s configured loss.

Apply:

-    def fit_model(self, train_data, train_labels, val_data, val_labels):
-        import xgboost as xgb
-        from sklearn.metrics import log_loss
-
-        # Try XGBoost first with provided config
-        try:
-            self.model.fit(train_data, train_labels,)
-
-            val_pred_proba = self.model.predict_proba(val_data)
-            val_loss = log_loss(val_labels, val_pred_proba)
-            logging.info(f"XGBoost model trained successfully. Validation loss: {val_loss:.4f}")
-            return val_loss
-
-        except Exception as e:
-            logging.warning(f"XGBoost failed: {e}")
-
-        return val_loss
+    def fit_model(self, train_data, train_labels, val_data, val_labels):
+        # Ensure a defined loss even if training fails
+        val_loss = float("inf")
+        try:
+            self.model.fit(train_data, train_labels)
+            # Use the wrapper's predict() and configured loss for consistency
+            val_pred = self.predict(val_data)
+            val_loss = self.loss(val_labels, val_pred)
+            logging.info(f"XGBoost model trained successfully. Validation loss: {val_loss:.4f}")
+        except Exception:
+            logging.exception("XGBoost training failed")
+        return val_loss
📝 Committable suggestion


Suggested change
def fit_model(self, train_data, train_labels, val_data, val_labels):
    import xgboost as xgb
    from sklearn.metrics import log_loss

    # Try XGBoost first with provided config
    try:
        self.model.fit(train_data, train_labels,)

        val_pred_proba = self.model.predict_proba(val_data)
        val_loss = log_loss(val_labels, val_pred_proba)
        logging.info(f"XGBoost model trained successfully. Validation loss: {val_loss:.4f}")
        return val_loss

    except Exception as e:
        logging.warning(f"XGBoost failed: {e}")

    return val_loss

def fit_model(self, train_data, train_labels, val_data, val_labels):
    # Ensure a defined loss even if training fails
    val_loss = float("inf")
    try:
        self.model.fit(train_data, train_labels)
        # Use the wrapper's predict() and configured loss for consistency
        val_pred = self.predict(val_data)
        val_loss = self.loss(val_labels, val_pred)
        logging.info(f"XGBoost model trained successfully. Validation loss: {val_loss:.4f}")
    except Exception:
        logging.exception("XGBoost training failed")
    return val_loss
🧰 Tools
🪛 Ruff (0.12.2)

42-42: xgboost imported but unused

Remove unused import: xgboost

(F401)

🤖 Prompt for AI Agents
In icu_benchmarks/models/ml_models/imblearn.py around lines 41 to 57, the
current fit_model may return val_loss when an exception occurs before it is
defined and contains unnecessary in-function imports; update the method to
remove the in-function imports, call self.model.fit inside try, compute
validation predictions using self.predict(val_data) and compute loss with
self.loss(val_labels, val_pred_proba) (or assign to self.loss if a callable
exists), ensure val_loss is always assigned before any return by returning a
default failure value (e.g., float("inf")) or re-raising after logging, and log
the caught exception details; keep the success path returning val_loss and the
failure path returning the safe default.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

♻️ Duplicate comments (5)
icu_benchmarks/models/train.py (1)

253-254: Ensure consistent data source for train explainer values.

There's an inconsistency that was previously flagged - for test values (line 230), explainer_values is used directly, but for train values (line 255), the variable name suggests the same structure. Verify both have the same data structure.

icu_benchmarks/models/wrappers.py (3)

49-51: Debug flag should default to False in production.

Setting debug = True by default means all models will save outputs to disk by default, which may not be intended for production use. This is inconsistent with the CLI default (args.debug defaults to False).


116-133: Fix inconsistent logging message and improve code maintainability.

  1. Line 133 logs "row_indicators.csv" but saves to "pred_indicators.csv"
  2. Remove large blocks of commented code (lines 134-140)

147-153: Fix format string mismatch in static dataset case.

The format string "%d,%d,%0.3f,%0.3f" has 4 format specifiers but the header "id,ground_truth,prediction_0,prediction_1" suggests 4 columns with different types. The second %d should likely be a float format.

                 header="id,ground_truth,prediction_0,prediction_1",
-                fmt="%d,%d,%0.3f,%0.3f",
+                fmt="%d,%.3f,%.3f,%.3f",
icu_benchmarks/models/ml_models/xgboost.py (1)

76-82: Remove unused labels parameter from _explain_model.

The labels parameter is passed to the method but not used. The parent class signature might require it, but shap_values method doesn't need labels.

🧹 Nitpick comments (9)
icu_benchmarks/models/train.py (3)

204-205: Consider renaming persist_shap_data to align with the new explainer terminology.

The function now persists "explainer values" rather than "shap values" as indicated by the logging message and variable names. Consider renaming to persist_explainer_data for consistency.

-    if explain_features:
-        persist_shap_data(trainer, log_dir)
+    if explain_features:
+        persist_explainer_data(trainer, log_dir)

And update the function definition:

-def persist_shap_data(trainer: Trainer, log_dir: Path, save_full_valuesets=True) -> None:
+def persist_explainer_data(trainer: Trainer, log_dir: Path, save_full_valuesets=True) -> None:

221-225: Simplify the column filtering logic for better readability.

The current implementation could be more concise and easier to understand.

-        if (
-            any(x in ["stay_id", "id"] for x in trained_columns)
-            and len(trained_columns) != trainer.lightning_module.explainer_values_test.shape[1]
-        ):
-            trained_columns.remove("stay_id" if "stay_id" in trained_columns else "id")
+        # Remove ID columns if shape mismatch indicates they're not in the explainer values
+        if len(trained_columns) != trainer.lightning_module.explainer_values_test.shape[1]:
+            id_cols = [col for col in ["stay_id", "id"] if col in trained_columns]
+            if id_cols:
+                trained_columns.remove(id_cols[0])

237-243: Clarify and adjust rep_test/label_test saving conditions

Currently, within persist_shap_data(trainer, log_dir, save_full_valuesets=True) (icu_benchmarks/models/train.py:232–242), both rep_test and label_test are written to disk only when save_full_valuesets is true (which it always is, because no call ever overrides the default) and when those attributes exist on trainer.lightning_module.

  • Because persist_shap_data is only ever invoked as persist_shap_data(trainer, log_dir) (with no explicit save_full_valuesets=False), these files will always be saved whenever explainer_values_test and rep_test exist.
  • If the intention is to control representation persistence via the module’s own persist_reps flag (set in wrappers.py), consider replacing if save_full_valuesets: with if trainer.lightning_module.persist_reps:.
  • If instead you want to save rep_test and label_test independently of “full” SHAP values, move their write logic out of the save_full_valuesets block so they run whenever those attributes exist.

Please review and adjust the conditional to match your desired behavior.
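
A minimal sketch of the second option, assuming the attribute names quoted above (persist_reps, rep_test, label_test); the real LightningModule may expose them differently:

    from pathlib import Path
    import numpy as np

    # Hedged sketch: gate on the module's own flag instead of save_full_valuesets.
    # Attribute names come from this comment, not from verified code.
    def persist_test_representations(module, log_dir: Path) -> None:
        if getattr(module, "persist_reps", False) and hasattr(module, "rep_test"):
            np.save(log_dir / "rep_test.npy", np.asarray(module.rep_test))
            np.save(log_dir / "label_test.npy", np.asarray(module.label_test))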

icu_benchmarks/models/wrappers.py (4)

113-114: Consider making _explain_model abstract or providing a default implementation.

Currently, the base implementation raises NotImplementedError. Consider making this an abstract method using @abstractmethod decorator, or provide a default implementation that returns None with a warning, to make the interface clearer.

+    from abc import abstractmethod
+
+    @abstractmethod
     def _explain_model(self, reps, labels):
-        raise NotImplementedError(f"Model {self.__class__.__name__} does not currently support feature explanation.")
+        """Abstract method for model explanation. Must be implemented by subclasses that support feature explanation."""
+        pass

Or alternatively, provide a default implementation:

     def _explain_model(self, reps, labels):
-        raise NotImplementedError(f"Model {self.__class__.__name__} does not currently support feature explanation.")
+        logging.warning(f"Model {self.__class__.__name__} does not currently support feature explanation.")
+        return None

174-174: Remove unnecessary f-string prefix.

Static analysis correctly identified an f-string without any placeholders.

-            logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / f'row_indicators.csv'}")
+            logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / 'row_indicators.csv'}")

545-560: Consider consolidating the log_curves and log_metrics calls.

The test_step method now calls both log_curves and log_metrics with the same parameters. Consider whether these could be unified or if the separation is intentional for clarity.

If the separation is not required, consider creating a single method that handles both:

-        self.log_curves(test_label, test_pred, "test", pred_indicators)
         if self.debug:
             self._save_model_outputs(pred_indicators, test_pred, test_label)
         if self.explain_features:
-            # self.explain_model(test_rep, test_label)
             self.explainer_values_test = self._explain_model(test_rep, test_label)
         if self.persist_reps:
             self.rep_test = test_rep
             self.label_test = test_label
+        
+        # Log both curves and metrics
+        self.log_curves(test_label, test_pred, "test", pred_indicators)
         if self.mps:
             self.log("test/loss", np.float32(self.loss(test_label, test_pred)), sync_dist=True)
             self.log_metrics(np.float32(test_label), np.float32(test_pred), "test", pred_indicators)

597-621: Remove TODO comment or create an issue to track it.

Line 607 contains a TODO comment about implementing alarm metrics. Either implement this functionality or create a GitHub issue to track this enhancement.

Would you like me to help implement the alarm metrics functionality using the row indicators, or should I create a GitHub issue to track this enhancement?
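
If it helps scope the issue, here is a rough, illustrative sketch of an alarm-style sensitivity computed from row indicators. The column layout of pred_indicators and the stay-level semantics are assumptions, not the repository's API:

    import numpy as np
    import pandas as pd

    # Illustrative only: assumes the first indicator column is a stay id and that a
    # stay counts as "alarmed" if any of its predictions crosses the threshold.
    def alarm_sensitivity(pred_indicators: np.ndarray, pred: np.ndarray, label: np.ndarray, threshold: float = 0.5) -> float:
        frame = pd.DataFrame(
            {
                "stay_id": pred_indicators[:, 0],
                "alarm": pred >= threshold,
                "positive": label.astype(bool),
            }
        )
        per_stay = frame.groupby("stay_id").agg(alarm=("alarm", "any"), positive=("positive", "any"))
        positives = per_stay[per_stay["positive"]]
        return float(positives["alarm"].mean()) if len(positives) else float("nan")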

icu_benchmarks/models/ml_models/xgboost.py (2)

47-49: Consider using parameterized logging level.

The logging level changed from debug to info for shape information. Consider making this configurable or consistent with the verbosity setting.

-        logging.info(f"train_data: {train_data.shape}, train_labels: {train_labels.shape}")
-        logging.info(train_labels)
+        logging.debug(f"train_data: {train_data.shape}, train_labels: {train_labels.shape}")
+        if self.model.verbosity > 0:
+            logging.info(f"Training labels distribution: {train_labels}")

50-56: Clean up commented code or document why it's retained.

There's commented code for conditional SHAP value computation. Either remove it or add a comment explaining why it's kept for reference.

-        # if self.explain_features:
         self.explainer = shap.TreeExplainer(
             self.model, train_data, feature_perturbation="interventional", model_output="probability"
         )
-        # if self.explain_features:
-        #     logging.info("Explaining features")
-        #     self.train_shap_values = self.explainer.shap_values(train_data)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 8af76b1 and 9a2aa2a.

📒 Files selected for processing (3)
  • icu_benchmarks/models/ml_models/xgboost.py (4 hunks)
  • icu_benchmarks/models/train.py (3 hunks)
  • icu_benchmarks/models/wrappers.py (11 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
icu_benchmarks/models/train.py (2)
icu_benchmarks/models/wrappers.py (1)
  • set_explain_features (109-111)
icu_benchmarks/models/utils.py (1)
  • save_config_file (22-25)
icu_benchmarks/models/ml_models/xgboost.py (3)
icu_benchmarks/models/wrappers.py (8)
  • set_model_args (642-652)
  • fit (492-516)
  • fit (744-745)
  • _explain_model (113-114)
  • MLWrapper (449-652)
  • predict (562-566)
  • predict (750-756)
  • fit_model (518-523)
icu_benchmarks/constants.py (1)
  • RunMode (4-7)
icu_benchmarks/models/ml_models/imblearn.py (1)
  • fit_model (41-57)
icu_benchmarks/models/wrappers.py (3)
icu_benchmarks/models/utils.py (2)
  • log_single_metric_to_file (290-315)
  • name (159-160)
icu_benchmarks/models/ml_models/xgboost.py (3)
  • _explain_model (76-82)
  • fit_model (41-62)
  • fit_model (115-148)
icu_benchmarks/models/custom_metrics.py (1)
  • confusion_matrix (141-150)
🪛 Ruff (0.12.2)
icu_benchmarks/models/wrappers.py

174-174: f-string without any placeholders

Remove extraneous f prefix

(F541)

🔇 Additional comments (6)
icu_benchmarks/models/train.py (1)

185-186: LGTM! Good separation of concerns for explainer setup.

Setting the explainer feature flag just before testing cleanly separates training from explanation, which is good for performance.

icu_benchmarks/models/wrappers.py (3)

109-114: LGTM! Clean implementation of feature explanation toggle.

The method properly sets both explain_features and persist_reps flags together, ensuring consistent behavior.


504-506: LGTM! Clean integration of feature explanation during training.

The conditional execution of _explain_model when explain_features is enabled is well-placed and properly stores the results.


568-579: LGTM! Clean implementation of curve logging.

The log_curves method properly identifies metrics that return tuples and saves them to CSV files for later analysis.
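
For readers skimming the change, an illustrative sketch of that pattern; metric names and the CSV layout are assumptions, not the wrapper's actual log_curves implementation:

    from pathlib import Path
    import numpy as np

    # Illustrative pattern only: call each metric and, when it returns a tuple
    # (e.g. ROC's fpr/tpr/thresholds), write the arrays to a CSV for later analysis.
    def log_curve_metrics(metrics: dict, label, pred, out_dir: Path, prefix: str) -> None:
        for name, metric_fn in metrics.items():
            result = metric_fn(label, pred)
            if isinstance(result, tuple):
                np.savetxt(out_dir / f"{prefix}_{name}.csv", np.column_stack(result), delimiter=",")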

icu_benchmarks/models/ml_models/xgboost.py (2)

26-26: LGTM! Good addition of explicit XGBoost parameters.

Setting eval_metric=log_loss, device="cpu", and verbosity=0 explicitly ensures consistent behavior across different environments.


66-72: LGTM! Better approach to get valid XGBoost parameters.

Using model().get_params().keys() is more reliable than inspecting the constructor signature for XGBoost models.
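
For readers unfamiliar with the pattern, a self-contained sketch of filtering arbitrary kwargs against the estimator's known parameters; hyperparameter values are arbitrary examples:

    import xgboost as xgb

    # Sketch of the approach endorsed above: ask the estimator which parameters it
    # accepts and silently drop everything else.
    def build_xgb(**kwargs) -> xgb.XGBClassifier:
        valid_params = xgb.XGBClassifier().get_params().keys()
        valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
        return xgb.XGBClassifier(**valid_kwargs)

    clf = build_xgb(max_depth=6, learning_rate=0.1, not_a_real_param=42)  # extra kwarg is dropped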

Comment on lines +84 to +101
@gin.configurable
class XGBClassifierGPU(MLWrapper):
    _supported_run_modes = [RunMode.classification]
    _explain_values = False

    def __init__(self, *args, **kwargs):
        # self.model = self.set_model_args(
        #     xgb, *args, **kwargs, eval_metric="logloss", tree_method="hist", device="cuda", verbosity=0
        # )
        self.model = xgb
        self.params = {
            "eval_metric": "logloss",
            "tree_method": "hist",
            "device": "cuda",
            "verbosity": 0,
            **kwargs,
        }
        super().__init__(*args, **kwargs)

🛠️ Refactor suggestion

Document the GPU implementation status and clean up initialization.

The GPU class has commented-out model initialization code. Either complete the implementation or document its current status.

 @gin.configurable
 class XGBClassifierGPU(MLWrapper):
+    """GPU-accelerated XGBoost classifier using native XGBoost API.
+    
+    Note: This implementation uses the native XGBoost train API rather than 
+    the scikit-learn compatible API for better GPU performance control.
+    """
     _supported_run_modes = [RunMode.classification]
     _explain_values = False
 
     def __init__(self, *args, **kwargs):
-        # self.model = self.set_model_args(
-        #     xgb, *args, **kwargs, eval_metric="logloss", tree_method="hist", device="cuda", verbosity=0
-        # )
         self.model = xgb
         self.params = {
             "eval_metric": "logloss",
📝 Committable suggestion


Suggested change

Before:

    @gin.configurable
    class XGBClassifierGPU(MLWrapper):
        _supported_run_modes = [RunMode.classification]
        _explain_values = False

        def __init__(self, *args, **kwargs):
            # self.model = self.set_model_args(
            #     xgb, *args, **kwargs, eval_metric="logloss", tree_method="hist", device="cuda", verbosity=0
            # )
            self.model = xgb
            self.params = {
                "eval_metric": "logloss",
                "tree_method": "hist",
                "device": "cuda",
                "verbosity": 0,
                **kwargs,
            }
            super().__init__(*args, **kwargs)

After:

    @gin.configurable
    class XGBClassifierGPU(MLWrapper):
        """GPU-accelerated XGBoost classifier using native XGBoost API.

        Note: This implementation uses the native XGBoost train API rather than
        the scikit-learn compatible API for better GPU performance control.
        """

        _supported_run_modes = [RunMode.classification]
        _explain_values = False

        def __init__(self, *args, **kwargs):
            self.model = xgb
            self.params = {
                "eval_metric": "logloss",
                "tree_method": "hist",
                "device": "cuda",
                "verbosity": 0,
                **kwargs,
            }
            super().__init__(*args, **kwargs)
🤖 Prompt for AI Agents
In icu_benchmarks/models/ml_models/xgboost.py around lines 84 to 101, the GPU
classifier contains commented-out initialization and leaves self.model set to
the xgb module while params are stored separately; either restore a proper
GPU-ready model construction or explicitly mark the GPU implementation as
not-initialized. Fix by replacing the commented block with a concrete
initialization that constructs an XGBoost estimator (e.g., call
xgb.XGBClassifier or xgb.XGBModel with the params merged in) and assign that
instance to self.model, or if GPU support is not yet available, remove the
commented code, add a clear docstring or comment indicating GPU mode is
unimplemented, and raise NotImplementedError in __init__ (or fall back to CPU)
so behavior is explicit and the codebase has no stale commented-out
initialization.

        Returns:
            numpy.ndarray: Predicted probabilities for each class.
        """
        return self.model.predict(xgb.DMatrix(features))

🛠️ Refactor suggestion

Add error handling for GPU-specific operations.

The predict method should handle potential GPU-related errors gracefully.

     def predict(self, features):
         """
         Predicts class probabilities for the given features.
 
         Args:
             features: Input features for prediction.
 
         Returns:
             numpy.ndarray: Predicted probabilities for each class.
         """
-        return self.model.predict(xgb.DMatrix(features))
+        try:
+            dmatrix = xgb.DMatrix(features)
+            return self.model.predict(dmatrix)
+        except Exception as e:
+            logging.error(f"GPU prediction failed: {e}")
+            raise
📝 Committable suggestion


Suggested change

Before:

        return self.model.predict(xgb.DMatrix(features))

After:

    def predict(self, features):
        """
        Predicts class probabilities for the given features.
        Args:
            features: Input features for prediction.
        Returns:
            numpy.ndarray: Predicted probabilities for each class.
        """
        try:
            dmatrix = xgb.DMatrix(features)
            return self.model.predict(dmatrix)
        except Exception as e:
            logging.error(f"GPU prediction failed: {e}")
            raise
🤖 Prompt for AI Agents
In icu_benchmarks/models/ml_models/xgboost.py around line 113, the return of
self.model.predict(xgb.DMatrix(features)) lacks GPU-specific error handling;
wrap the predict call in a try/except that catches XGBoost GPU-related errors
(e.g., xgboost.core.XGBoostError) and general Exception, log a clear warning
about falling back to CPU, then retry prediction using a CPU fallback (for
example recreate the DMatrix for CPU and call model.predict with
predictor='cpu_predictor' or reset model params to use CPU) and if that also
fails re-raise the exception so callers can handle it.


        if wandb.run is not None:
            callbacks.append(wandb_xgb())
        self.model.train(self.params, train_data=dtrain, evals=evals, callbacks=callbacks)

⚠️ Potential issue

Fix incorrect API usage in GPU fit_model.

The self.model.train() call is incorrect - xgb.train() is a module-level function, not a method. The model should be assigned from the train function's return value.

-        self.model.train(self.params, train_data=dtrain, evals=evals, callbacks=callbacks)
+        self.model = xgb.train(self.params, dtrain=dtrain, evals=evals, callbacks=callbacks)
📝 Committable suggestion


Suggested change

Before:

        self.model.train(self.params, train_data=dtrain, evals=evals, callbacks=callbacks)

After:

        self.model = xgb.train(self.params, dtrain=dtrain, evals=evals, callbacks=callbacks)
🤖 Prompt for AI Agents
In icu_benchmarks/models/ml_models/xgboost.py around line 136, the code calls
self.model.train(...), but xgboost's train is a module-level function and
returns the trained Booster; replace the method call with an assignment using
xgb.train(...) and set self.model = xgb.train(self.params, train_data=dtrain,
evals=evals, callbacks=callbacks) (preserve existing args/kwargs), removing the
erroneous method invocation so self.model holds the returned trained model.

# logging.info("Explaining features")
# self.train_shap_values = self.explainer.shap_values(dtrain)

eval_score = mean(next(iter(self.model.evals_result()["validation"].values())))

⚠️ Potential issue

Fix method call for getting evaluation results.

The evals_result() should be called on the trained booster, not as shown.

-        eval_score = mean(next(iter(self.model.evals_result()["validation"].values())))
+        # Get the evaluation history from the last callback or the model's evals_result
+        evals_result = self.model.evals_result() if hasattr(self.model, 'evals_result') else {}
+        if "validation" in evals_result and evals_result["validation"]:
+            eval_score = mean(next(iter(evals_result["validation"].values())))
+        else:
+            logging.warning("No validation metrics found")
+            eval_score = float('inf')
📝 Committable suggestion


Suggested change

Before:

        eval_score = mean(next(iter(self.model.evals_result()["validation"].values())))

After:

        # Get the evaluation history from the last callback or the model's evals_result
        evals_result = self.model.evals_result() if hasattr(self.model, 'evals_result') else {}
        if "validation" in evals_result and evals_result["validation"]:
            eval_score = mean(next(iter(evals_result["validation"].values())))
        else:
            logging.warning("No validation metrics found")
            eval_score = float('inf')
🤖 Prompt for AI Agents
In icu_benchmarks/models/ml_models/xgboost.py around line 147, the code calls
evals_result() on the estimator object; instead call it on the trained booster.
Replace self.model.evals_result() with self.model.get_booster().evals_result(),
and keep the existing extraction of the validation metric (e.g.,
mean(next(iter(...))) ), ensuring get_booster() is used to access
evals_result().

Comment on lines +210 to +217
def persist_shap_data(trainer: Trainer, log_dir: Path, save_full_valuesets=True) -> None:
"""
Persist shap values to disk.
Args:
trainer: Pytorch lightning trainer object
log_dir: Log directory
"""
logging.info("Persisting explainer values to disk.")

🛠️ Refactor suggestion

Update docstring to reflect the new function behavior.

The docstring still mentions "shap values" but the function now handles "explainer values" as shown in the logging and variable names.

 def persist_shap_data(trainer: Trainer, log_dir: Path, save_full_valuesets=True) -> None:
     """
-    Persist shap values to disk.
+    Persist explainer values to disk.
     Args:
         trainer: Pytorch lightning trainer object
         log_dir: Log directory
+        save_full_valuesets: Whether to save full explainer values (default: True)
     """
📝 Committable suggestion


Suggested change

Before:

    def persist_shap_data(trainer: Trainer, log_dir: Path, save_full_valuesets=True) -> None:
        """
        Persist shap values to disk.
        Args:
            trainer: Pytorch lightning trainer object
            log_dir: Log directory
        """
        logging.info("Persisting explainer values to disk.")

After:

    def persist_shap_data(trainer: Trainer, log_dir: Path, save_full_valuesets=True) -> None:
        """
        Persist explainer values to disk.
        Args:
            trainer: Pytorch lightning trainer object
            log_dir: Log directory
            save_full_valuesets: Whether to save full explainer values (default: True)
        """
        logging.info("Persisting explainer values to disk.")
🤖 Prompt for AI Agents
In icu_benchmarks/models/train.py around lines 210 to 217, the function
docstring is outdated — it refers to "shap values" while the function now
persists "explainer values" and adds the parameter save_full_valuesets; update
the docstring to accurately describe that the function persists explainer (e.g.,
SHAP or other explainer) values to disk, list and describe the parameters
trainer: Trainer, log_dir: Path, save_full_valuesets: bool (default True) and
their meanings, and mention the function returns None; keep the brief summary
and Args/Returns sections consistent with the new logging and variable names.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (4)
icu_benchmarks/models/ml_models/xgboost.py (4)

86-104: GPU classifier not constructing a model; add docstring and initialize params; set objective.

self.model = xgb is a module reference, not a model. Default an objective; Booster gets created in fit.

 @gin.configurable
 class XGBClassifierGPU(MLWrapper):
+    """GPU-accelerated XGBoost classifier using native xgboost.train API."""
@@
-        self.model = xgb
-        self.params = {
+        self.model = None  # Will hold the trained Booster
+        self.params = {
+            "objective": "binary:logistic",
             "eval_metric": "logloss",
             "tree_method": "hist",
             "device": "cuda",
             "verbosity": 0,
             **kwargs,
         }

105-115: Predict uses module instead of trained Booster; add error handling and return NxK probs.

Return shape should align with predict_proba; add GPU error handling.

-        return self.model.predict(xgb.DMatrix(features))
+        if self.model is None:
+            raise ValueError("Model not trained. Call fit() first.")
+        try:
+            dmatrix = xgb.DMatrix(features)
+            preds = self.model.predict(dmatrix)
+        except xgb.core.XGBoostError as e:
+            logging.error(f"GPU prediction failed: {e}")
+            raise
+        # Ensure 2D probabilities for binary classification
+        if preds.ndim == 1:
+            preds = np.column_stack([1.0 - preds, preds])
+        return preds

134-138: Train via xgb.train, set objective/num_boost_round, and assign returned Booster.

Current call is wrong and ignores rounds; set objective based on class count.

-        self.model.train(self.params, train_data=dtrain, evals=evals, callbacks=callbacks)
+        # Configure objective / classes
+        num_classes = int(len(np.unique(train_labels)))
+        params = dict(self.params)
+        if num_classes > 2:
+            params.update({"objective": "multi:softprob", "num_class": num_classes})
+        else:
+            params.update({"objective": "binary:logistic"})
+        num_boost_round = int(params.pop("num_boost_round", params.pop("n_estimators", 100)))
+        self.model = xgb.train(params, dtrain, num_boost_round=num_boost_round, evals=evals, callbacks=callbacks)

149-150: Fix evals_result access; take the final validation metric.

Call on Booster and read the last value.

-        eval_score = mean(next(iter(self.model.evals_result()["validation"].values())))
-        return eval_score
+        evals_result = self.model.evals_result()
+        if "validation" in evals_result and evals_result["validation"]:
+            metric = next(iter(evals_result["validation"]))
+            eval_score = evals_result["validation"][metric][-1]
+        else:
+            logging.warning("No validation metrics found")
+            eval_score = float("inf")
+        return eval_score
🧹 Nitpick comments (3)
icu_benchmarks/models/ml_models/xgboost.py (3)

50-55: Gate SHAP explainer creation behind explain_features and make sampling reproducible.

Avoid unnecessary overhead and non-determinism.

-        n_samples = min(1000, len(train_data))
-        indices = np.random.choice(len(train_data), size=n_samples, replace=False)
-        background_sample = train_data[indices]
-        self.explainer = shap.TreeExplainer(
-            self.model, background_sample, feature_perturbation="interventional", model_output="probability"
-        )
-        # if self.explain_features:
-        #     logging.info("Explaining features")
-        #     self.train_shap_values = self.explainer.shap_values(train_data)
+        if self.explain_features:
+            n_samples = min(1000, len(train_data))
+            rng = np.random.default_rng(getattr(self.hparams, "random_state", None))
+            indices = rng.choice(len(train_data), size=n_samples, replace=False)
+            background_sample = train_data[indices]
+            self.explainer = shap.TreeExplainer(
+                self.model, background_sample, feature_perturbation="interventional", model_output="probability"
+            )

Also applies to: 56-58


63-64: Report the final/best validation metric, not the mean across rounds.

Mean hides the true end-of-training performance.

-        eval_score = mean(next(iter(self.model.evals_result_["validation_0"].values())))
-        return eval_score  # , callbacks=callbacks)
+        evals = self.model.evals_result_.get("validation_0", {})
+        metric = next(iter(evals)) if evals else None
+        eval_score = evals[metric][-1] if metric else float("inf")
+        return eval_score

73-75: Clarify warning about dropped kwargs.

Minor wording tweak to make logs clearer.

-        if len(valid_kwargs) == 0:
-            logging.warning("No valid arguments passed to XGBoost")
+        if len(valid_kwargs) == 0:
+            logging.warning("No valid XGBoost hyperparameters found in kwargs")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 9a2aa2a and 2632228.

📒 Files selected for processing (1)
  • icu_benchmarks/models/ml_models/xgboost.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
icu_benchmarks/models/ml_models/xgboost.py (1)
icu_benchmarks/models/wrappers.py (9)
  • set_model_args (642-652)
  • fit (492-516)
  • fit (744-745)
  • _explain_model (113-114)
  • MLWrapper (449-652)
  • predict (562-566)
  • predict (750-756)
  • fit_model (518-523)
  • _explain_model (488-492)
🔇 Additional comments (2)
icu_benchmarks/models/ml_models/xgboost.py (2)

68-70: LGTM: deriving valid kwargs from get_params() is the right call.


117-131: DMatrix construction LGTM.

Inputs are correctly wrapped for native training.

Comment on lines +47 to +49
logging.info(f"train_data: {train_data.shape}, train_labels: {train_labels.shape}")
logging.info(train_labels)
self.model.fit(train_data, train_labels, eval_set=[(val_data, val_labels)], verbose=0)

🛠️ Refactor suggestion

Enable early stopping (callbacks currently unused) and avoid logging raw labels.

Wire callbacks or early_stopping_rounds, and don’t log full label arrays.

-        logging.info(f"train_data: {train_data.shape}, train_labels: {train_labels.shape}")
-        logging.info(train_labels)
-        self.model.fit(train_data, train_labels, eval_set=[(val_data, val_labels)], verbose=0)
+        logging.debug(f"train_data: {train_data.shape}, train_labels: {train_labels.shape}")
+        self.model.fit(
+            train_data,
+            train_labels,
+            eval_set=[(val_data, val_labels)],
+            early_stopping_rounds=self.hparams.patience,
+            callbacks=callbacks,
+            verbose=False,
+        )
📝 Committable suggestion


Suggested change

Before:

        logging.info(f"train_data: {train_data.shape}, train_labels: {train_labels.shape}")
        logging.info(train_labels)
        self.model.fit(train_data, train_labels, eval_set=[(val_data, val_labels)], verbose=0)

After:

        logging.debug(f"train_data: {train_data.shape}, train_labels: {train_labels.shape}")
        self.model.fit(
            train_data,
            train_labels,
            eval_set=[(val_data, val_labels)],
            early_stopping_rounds=self.hparams.patience,
            callbacks=callbacks,
            verbose=False,
        )

Comment on lines 78 to 84
    def _explain_model(self, reps, labels):
        if not hasattr(self.model, "feature_importances_"):
            raise ValueError("Model has not been fit yet. Call fit_model() before getting feature importances.")
        return self.model.feature_importances_
        # feature_importances = self.model.feature_importances_
        shap_values = self.explainer.shap_values(reps, labels)
        # feature_importances = np.abs(shap_values).mean(axis=1)
        return shap_values

🛠️ Refactor suggestion

Fix explainer precondition and SHAP API usage.

Check for self.explainer instead of feature_importances_; avoid passing labels to shap_values (API-compat issues).

     def _explain_model(self, reps, labels):
-        if not hasattr(self.model, "feature_importances_"):
-            raise ValueError("Model has not been fit yet. Call fit_model() before getting feature importances.")
-        # feature_importances = self.model.feature_importances_
-        shap_values = self.explainer.shap_values(reps, labels)
-        # feature_importances = np.abs(shap_values).mean(axis=1)
-        return shap_values
+        if not hasattr(self, "explainer"):
+            raise ValueError("Explainer not initialized. Call fit_model() with explain_features=True.")
+        return self.explainer.shap_values(reps)
📝 Committable suggestion


Suggested change

Before:

    def _explain_model(self, reps, labels):
        if not hasattr(self.model, "feature_importances_"):
            raise ValueError("Model has not been fit yet. Call fit_model() before getting feature importances.")
        return self.model.feature_importances_
        # feature_importances = self.model.feature_importances_
        shap_values = self.explainer.shap_values(reps, labels)
        # feature_importances = np.abs(shap_values).mean(axis=1)
        return shap_values

After:

    def _explain_model(self, reps, labels):
        if not hasattr(self, "explainer"):
            raise ValueError("Explainer not initialized. Call fit_model() with explain_features=True.")
        return self.explainer.shap_values(reps)
🤖 Prompt for AI Agents
In icu_benchmarks/models/ml_models/xgboost.py around lines 78 to 84, the
precondition and SHAP call are wrong: check for the explainer (e.g., if not
hasattr(self, "explainer") or self.explainer is None) and raise a clear
ValueError instructing to call the explainer creation step before explaining;
then call self.explainer.shap_values(reps) without passing labels (remove the
labels argument) to match the SHAP API.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 24

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
icu_benchmarks/models/wrappers.py (1)

569-579: Bug: validation_step miscalls log_metrics and omits pred_indicators.

This will raise a TypeError and skip proper metric logging.

-        self.log_metrics("val/loss", self.loss(val_label, val_pred), sync_dist=True)
+        self.log("val/loss", self.loss(val_label, val_pred), sync_dist=True)
         logging.info(f"Val loss: {self.loss(val_label, val_pred)}")
-        self.log_metrics(val_label, val_pred, "val")
+        self.log_metrics(val_label, val_pred, "val", row_indicators)
♻️ Duplicate comments (6)
icu_benchmarks/models/ml_models/xgboost.py (3)

79-85: Fix explainer precondition and SHAP API usage.

Check for explainer, not feature_importances_; don’t pass labels to shap_values.

-    def _explain_model(self, reps, labels):
-        if not hasattr(self.model, "feature_importances_"):
-                raise ValueError("Model has not been fit yet. Call fit_model() before getting feature importances.")
-        # feature_importances = self.model.feature_importances_
-        shap_values = self.explainer.shap_values(reps, labels)
-        # feature_importances = np.abs(shap_values).mean(axis=1)
-        return shap_values
+    def _explain_model(self, reps, labels):
+        if not hasattr(self, "explainer") or self.explainer is None:
+            raise ValueError("Explainer not initialized. Set explain_features=True when fitting.")
+        return self.explainer.shap_values(reps)

45-50: Don’t log full label arrays; wire early stopping callbacks.

Avoid printing labels (PII/verbosity) and pass callbacks to fit.

-        logging.info(f"train_data: {train_data.shape}, train_labels: {train_labels.shape}")
-        logging.info(train_labels)
-        self.model.fit(train_data, train_labels, eval_set=[(val_data, val_labels)], verbose=0)
+        logging.debug(f"train_data: {train_data.shape}, train_labels: {train_labels.shape}")
+        self.model.fit(
+            train_data,
+            train_labels,
+            eval_set=[(val_data, val_labels)],
+            callbacks=callbacks,
+            verbose=False,
+        )

118-151: Fix xgb.train usage and eval extraction; assign the Booster.

Use module-level xgb.train, assign to self.model, and safely read evals_result.

-        callbacks = [EarlyStopping(self.hparams.patience)]
+        callbacks = [EarlyStopping(self.hparams.patience)]
@@
-        if wandb.run is not None:
-            callbacks.append(wandb_xgb())
-        self.model.train(self.params, train_data=dtrain, evals=evals, callbacks=callbacks)
+        if wandb.run is not None:
+            callbacks.append(wandb_xgb())
+        self.model = xgb.train(self.params, dtrain=dtrain, evals=evals, callbacks=callbacks)
@@
-        eval_score = mean(next(iter(self.model.evals_result()["validation"].values())))
+        evals_result = self.model.evals_result() if hasattr(self.model, "evals_result") else {}
+        if "validation" in evals_result and evals_result["validation"]:
+            eval_score = mean(next(iter(evals_result["validation"].values())))
+        else:
+            logging.warning("No validation metrics found for GPU model.")
+            eval_score = float("inf")
icu_benchmarks/models/wrappers.py (3)

410-414: Remove duplicate valid_idx computation.

Identical assignment appears twice back-to-back.

-        valid_idx = mask.detach().cpu().numpy().astype(bool)
-
-        # valid_idx is (batch, seq_len) boolean mask
-        valid_idx = mask.detach().cpu().numpy().astype(bool)
+        # valid_idx is (batch, seq_len) boolean mask
+        valid_idx = mask.detach().cpu().numpy().astype(bool)

632-673: Tests needed for new pred_indicators branches in log_metrics.

Edge cases (temporal vs. flat, shape mismatches) aren’t covered.

I can add unit tests covering:

  • pred_indicators=None
  • 2-D indicators matching pred (temporal)
  • Non-matching shapes path
  • Classification vs regression

Say the word and I’ll open a PR-ready test module.
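
In the meantime, a schematic sketch of the branch selection those tests would pin down; this mirrors the cases listed above rather than the wrapper's actual log_metrics implementation:

    import numpy as np

    # Illustrative helper, not the real log_metrics logic: classifies which branch
    # a (pred, pred_indicators) pair would take based purely on shapes.
    def indicator_mode(pred: np.ndarray, pred_indicators) -> str:
        if pred_indicators is None:
            return "flat"
        if pred_indicators.ndim == 2 and pred_indicators.shape[0] == pred.shape[0]:
            return "temporal"
        return "shape_mismatch"

    def test_indicator_mode():
        pred = np.zeros((8, 2))
        assert indicator_mode(pred, None) == "flat"
        assert indicator_mode(pred, np.zeros((8, 3))) == "temporal"
        assert indicator_mode(pred, np.zeros((5, 3))) == "shape_mismatch"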


50-52: Default debug=True causes unintended disk I/O; gate saves behind debug and revert default.

Persisting outputs by default is risky (large files, slower runs) and inconsistent with CLI defaults. Also, _save_model_outputs is called unconditionally in DL, bypassing debug.

Apply:

-    debug = True
+    debug = False
-        # Save predictions to file
-        self._save_model_outputs(indicators_np, transformed_predictors, target_np)
+        # Save predictions to file
+        if self.debug:
+            self._save_model_outputs(indicators_np, transformed_predictors, target_np)

Also applies to: 433-435

🧹 Nitpick comments (18)
experiments/benchmark_cass_baselines.yml (2)

50-53: Normalize file_names to a proper dict string.

Safer to pass a JSON-like dict (what parse_dict typically expects).

-      - '"OUTCOME":"outc.parquet","STATIC":"sta.parquet"'
+      - '{"OUTCOME":"outc.parquet","STATIC":"sta.parquet"}'

If the parser intentionally accepts the current format, ignore this change; otherwise adopt braces consistently across all experiment files.
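
For illustration, a small parser sketch that accepts either form; the repository's parse_dict may behave differently, so treat this as an assumption:

    import json

    # Hedged sketch: wrap the bare key:value form in braces before parsing.
    # The real CLI parser may already do something equivalent.
    def parse_file_names(raw: str) -> dict:
        text = raw.strip()
        if not text.startswith("{"):
            text = "{" + text + "}"
        return json.loads(text)

    print(parse_file_names('"OUTCOME":"outc.parquet","STATIC":"sta.parquet"'))
    # {'OUTCOME': 'outc.parquet', 'STATIC': 'sta.parquet'}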


64-64: Add trailing newline.

Fixes YAML lint warning.

experiments/benchmark_sex.yml (2)

50-53: Use braces for file_names dict string.

-      - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"'
+      - '{"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"}'

60-60: Add trailing newline.

experiments/benchmark_cass_modalities.yml (2)

60-61: Wrap file_names mapping in braces for consistency.

-      - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"'
+      - '{"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"}'

68-68: Add trailing newline.

experiments/benchmark_cass_wearable_modalities.yml (2)

66-68: Consistent dict format for file_names.

-      - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"'
+      - '{"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"}'

73-73: Add trailing newline.

experiments/benchmark_cass_segment_length.yml (2)

66-69: Use braces around file_names mapping.

-      - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"'
+      - '{"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"}'

79-79: Add trailing newline.

experiments/benchmark_cass_model_architecture.yml (2)

59-59: Fix YAML indentation (yamllint).

Indentation under values should be 6 spaces, not 8.

-        - [ top_500_features ]
+      - [ top_500_features ]

74-74: Add trailing newline (yamllint).

-program: icu-benchmarks
+program: icu-benchmarks
+
experiments/benchmark_cass_explain_modalities.yml (1)

69-69: Add trailing newline (yamllint).

-program: icu-benchmarks
+program: icu-benchmarks
+
experiments/top_features_benchmark_cass.yml (1)

88-89: Add trailing newline.

-program: icu-benchmarks
+program: icu-benchmarks
+
icu_benchmarks/models/ml_models/xgboost.py (1)

90-91: Annotate class var per Ruff (RUF012).

Mark _explain_values as ClassVar[bool] or remove if unused.

-from icu_benchmarks.models.wrappers import MLWrapper
+from icu_benchmarks.models.wrappers import MLWrapper
+from typing import ClassVar
@@
-    _explain_values = False
+    _explain_values: ClassVar[bool] = False
icu_benchmarks/models/wrappers.py (3)

539-547: Clean up calibration toggles; use the method variable and remove duplicate assignment.

Avoid unused local (method) and duplicate calibrate assignment.

-        val_loss = self.fit_model(train_rep, train_label, val_rep, val_label)
-        calibrate = True
-        method = "isotonic"
-        # Apply calibration if desired
-        calibrate = True
-        if calibrate and self.run_mode == RunMode.classification:
-            cal_val_loss = self.setup_calibration(val_rep, val_label, method='isotonic')
+        val_loss = self.fit_model(train_rep, train_label, val_rep, val_label)
+        calibrate = True
+        method = "isotonic"
+        # Apply calibration if desired
+        if calibrate and self.run_mode == RunMode.classification:
+            cal_val_loss = self.setup_calibration(val_rep, val_label, method=method)

620-631: Silence ARG002: pred_indicators unused in log_curves.

Rename the arg to underscore to satisfy linters without changing behavior.

-    def log_curves(self, label, pred, metric_type, pred_indicators):
+    def log_curves(self, label, pred, metric_type, _pred_indicators=None):

381-399: Clarify/align DataLoader contract in DL step_fn.

Currently enforces 3 or 4 elements, but the exception message mentions older 2-tuple signatures. Either re-support the 2-tuple path (with pred_indicators=None) or update the message.

If you only intend to support indicators now, change:

-            raise Exception("Loader should return either (data, label) or (data, label, mask)")
+            raise Exception("Loader should return (data, label, indicators) or (data, label, mask, indicators)")

Otherwise, I can send a patch to handle 2-tuples by setting indicators=None and skipping saves/indicator-based logic.
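
A sketch of the tuple handling floated above, assuming the element order implied by the exception message; the real DataLoader contract should be confirmed before adopting it:

    # Hedged sketch: unpack 2-, 3-, or 4-element batches, defaulting mask and
    # indicators to None when absent. Element order is an assumption.
    def unpack_batch(batch):
        if len(batch) == 4:
            data, label, mask, indicators = batch
        elif len(batch) == 3:
            data, label, indicators = batch
            mask = None
        elif len(batch) == 2:
            data, label = batch
            mask, indicators = None, None
        else:
            raise ValueError(f"Unexpected batch length: {len(batch)}")
        return data, label, mask, indicators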

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 2632228 and 52895ac.

📒 Files selected for processing (11)
  • experiments/benchmark_cass.yml (1 hunks)
  • experiments/benchmark_cass_baselines.yml (1 hunks)
  • experiments/benchmark_cass_explain_modalities.yml (1 hunks)
  • experiments/benchmark_cass_modalities.yml (1 hunks)
  • experiments/benchmark_cass_model_architecture.yml (1 hunks)
  • experiments/benchmark_cass_segment_length.yml (1 hunks)
  • experiments/benchmark_cass_wearable_modalities.yml (1 hunks)
  • experiments/benchmark_sex.yml (1 hunks)
  • experiments/top_features_benchmark_cass.yml (1 hunks)
  • icu_benchmarks/models/ml_models/xgboost.py (3 hunks)
  • icu_benchmarks/models/wrappers.py (12 hunks)
🧰 Additional context used
🧬 Code graph analysis (10)
experiments/benchmark_cass_wearable_modalities.yml (1)
icu_benchmarks/run.py (1)
  • main (33-198)
experiments/benchmark_sex.yml (1)
icu_benchmarks/run.py (1)
  • main (33-198)
experiments/top_features_benchmark_cass.yml (1)
icu_benchmarks/run.py (1)
  • main (33-198)
experiments/benchmark_cass_baselines.yml (1)
icu_benchmarks/run.py (1)
  • main (33-198)
experiments/benchmark_cass_model_architecture.yml (1)
icu_benchmarks/run.py (1)
  • main (33-198)
experiments/benchmark_cass_segment_length.yml (2)
icu_benchmarks/run.py (1)
  • main (33-198)
icu_benchmarks/wandb_utils.py (1)
  • set_wandb_experiment_name (54-78)
experiments/benchmark_cass_explain_modalities.yml (1)
icu_benchmarks/run.py (1)
  • main (33-198)
experiments/benchmark_cass.yml (1)
icu_benchmarks/run.py (1)
  • main (33-198)
icu_benchmarks/models/ml_models/xgboost.py (1)
icu_benchmarks/models/wrappers.py (7)
  • set_model_args (694-704)
  • fit_model (562-567)
  • fit (529-560)
  • fit (796-797)
  • _explain_model (114-115)
  • predict (608-619)
  • predict (802-808)
icu_benchmarks/models/wrappers.py (6)
icu_benchmarks/models/utils.py (2)
  • log_single_metric_to_file (290-315)
  • name (159-160)
icu_benchmarks/models/ml_models/xgboost.py (1)
  • predict (106-116)
icu_benchmarks/constants.py (1)
  • RunMode (4-7)
icu_benchmarks/models/ml_models/lgbm.py (1)
  • predict (36-46)
icu_benchmarks/models/ml_models/catboost.py (1)
  • predict (16-26)
icu_benchmarks/models/custom_metrics.py (1)
  • confusion_matrix (141-150)
🪛 YAMLlint (1.37.1)
experiments/benchmark_cass_wearable_modalities.yml

[error] 73-73: no new line character at the end of file

(new-line-at-end-of-file)

experiments/benchmark_sex.yml

[error] 60-60: no new line character at the end of file

(new-line-at-end-of-file)

experiments/benchmark_cass_baselines.yml

[error] 64-64: no new line character at the end of file

(new-line-at-end-of-file)

experiments/benchmark_cass_model_architecture.yml

[warning] 59-59: wrong indentation: expected 6 but found 8

(indentation)


[error] 74-74: no new line character at the end of file

(new-line-at-end-of-file)

experiments/benchmark_cass_segment_length.yml

[error] 79-79: no new line character at the end of file

(new-line-at-end-of-file)

experiments/benchmark_cass_explain_modalities.yml

[error] 69-69: no new line character at the end of file

(new-line-at-end-of-file)

experiments/benchmark_cass_modalities.yml

[error] 68-68: no new line character at the end of file

(new-line-at-end-of-file)

🪛 Ruff (0.12.2)
icu_benchmarks/models/ml_models/xgboost.py

81-81: Avoid specifying long messages outside the exception class

(TRY003)


89-89: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)

icu_benchmarks/models/wrappers.py

175-175: f-string without any placeholders

Remove extraneous f prefix

(F541)


541-541: Local variable method is assigned to but never used

Remove assignment to unused variable method

(F841)


620-620: Unused method argument: pred_indicators

(ARG002)

Comment on lines +1 to +24
command:
- ${env}
- ${program}
- train
- -d
- ../data/
- -t
# - BinaryClassification
- CassClassification
- --log-dir
- /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs
- --tune
- --wandb-sweep
- -tn
- SSI
- --hp-checkpoint
- "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/dataset_baseline_cohort_post-operative_static_2025-08-01T11:00:26/SSI/XGBClassifier/2025-08-01T11-08-08.547141/hyperparameter_tuning_logs.db"
- --modalities
- "all"
- --file-names
- '"DYNAMIC":"dyn.parquet","STATIC":"sta.parquet"'
- --explain_features
# - --verbose
- --load_data_vars

⚠️ Potential issue

Sweeps won’t receive parameters: missing ${args} in command.

Without ${args}, W&B won’t inject sweep parameters; runs will ignore the grid under parameters. Add it to the command.

   - --load_data_vars
+  - ${args}
 method: grid
📝 Committable suggestion


Suggested change
command:
- ${env}
- ${program}
- train
- -d
- ../data/
- -t
# - BinaryClassification
- CassClassification
- --log-dir
- /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs
- --tune
- --wandb-sweep
- -tn
- SSI
- --hp-checkpoint
- "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/dataset_baseline_cohort_post-operative_static_2025-08-01T11:00:26/SSI/XGBClassifier/2025-08-01T11-08-08.547141/hyperparameter_tuning_logs.db"
- --modalities
- "all"
- --file-names
- '"DYNAMIC":"dyn.parquet","STATIC":"sta.parquet"'
- --explain_features
# - --verbose
- --load_data_vars
command:
- ${env}
- ${program}
- train
- -d
- ../data/
- -t
# - BinaryClassification
- CassClassification
- --log-dir
- /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs
- --tune
- --wandb-sweep
- -tn
- SSI
- --hp-checkpoint
- "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/dataset_baseline_cohort_post-operative_static_2025-08-01T11:00:26/SSI/XGBClassifier/2025-08-01T11-08-08.547141/hyperparameter_tuning_logs.db"
- --modalities
- "all"
- --file-names
- '"DYNAMIC":"dyn.parquet","STATIC":"sta.parquet"'
- --explain_features
# - --verbose
- --load_data_vars
- ${args}
method: grid
🤖 Prompt for AI Agents
In experiments/benchmark_cass_baselines.yml around lines 1-24, the command list
is missing the ${args} placeholder so W&B sweeps cannot inject parameter values;
add a new list entry "- ${args}" immediately after "- ${program}" in the command
sequence so sweep parameters are expanded into the CLI before the fixed
subcommands/flags.

Comment on lines +5 to +6
- -d
- ../data/

🛠️ Refactor suggestion

Remove fixed flags duplicated by sweep parameters (they shadow/contradict the grid).

-d/--modalities/--file-names are already provided via parameters; keep only one source of truth.

-  - -d
-  - ../data/
@@
-  - --modalities
-  - "all"
-  - --file-names
-  - '"DYNAMIC":"dyn.parquet","STATIC":"sta.parquet"'

Also applies to: 18-21

🤖 Prompt for AI Agents
In experiments/benchmark_cass_baselines.yml around lines 5-6 (and similarly
lines 18-21), the fixed CLI flags (-d, --modalities, --file-names, etc.) are
duplicated while the sweep parameters already provide those values; remove the
redundant fixed flags so there is a single source of truth: delete the -d
../data/ entries (and the corresponding duplicated modality/file-name flags at
lines 18-21) from the YAML so the sweep parameters control those options
exclusively.

Comment on lines +5 to +6
- -d
- ../data/

⚠️ Potential issue

Grid parameters not wired into command.

Hardcoded data_dir, modalities, and file_names prevent the sweep grid from taking effect.

-  - -d
-  - ../data/
+  - -d
+  - ${data_dir}
@@
-  - --modalities
-  - "all"
+  - --modalities
+  - ${modalities}
@@
-  - --file-names
-  - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"'
+  - --file-names
+  - ${file_names}

Also applies to: 19-23

🤖 Prompt for AI Agents
In experiments/benchmark_cass_explain_modalities.yml around lines 5-6 (and
similarly lines 19-23) the YAML hardcodes -d ../data/ plus fixed modalities and
file_names which overrides the sweep grid; change those hardcoded values to
reference the grid parameters (remove the fixed -d entry and replace
modalities/file_names entries with the grid variable placeholders or anchors
used elsewhere in the file) so that the sweep can inject data_dir, modalities,
and file_names at runtime; ensure the command or args use the same variable
names as defined in the grid section and validate the YAML syntax after
updating.

Comment on lines +21 to +22
- --file-names
- '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"'

💡 Verification agent

🧩 Analysis chain

OUTCOME parquet mismatch (outc_6 vs outc_9).

Unify to one value; prefer using ${file_names} so the parameters block controls it.

Also applies to: 61-62


🏁 Script executed:

#!/bin/bash
set -e
file=$(fd -g 'benchmark_cass_explain_modalities.yml')
echo "=== File path: $file ==="
rg -n '"DYNAMIC"' "$file"
rg -n 'file_names' "$file"
echo "=== Top of file (1-80) ==="
sed -n '1,80p' "$file"
echo "=== Mid of file (60-100) ==="
sed -n '60,100p' "$file"

Length of output: 5497


Use ${file_names} in --file-names flag
Change in experiments/benchmark_cass_explain_modalities.yml lines 21–22:

- --file-names
- '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"'

to

- --file-names
- ${file_names}

This aligns the command with the parameters block’s value (currently outc_9.parquet at lines 59–62).

🤖 Prompt for AI Agents
In experiments/benchmark_cass_explain_modalities.yml around lines 21 to 22,
replace the hard-coded file-names string with the parameter reference: change
the second list item from the literal
'"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' to
${file_names} so the --file-names flag uses the pipeline parameter (which is set
later to outc_9.parquet) instead of a hardcoded value.

Comment on lines +1 to +26
command:
- ${env}
- ${program}
- train
- -d
- ../data/
- -t
# - BinaryClassification
- CassClassification
- --log-dir
- /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs
# - --tune
- --wandb-sweep
- -tn
- SSI
- --hp-checkpoint
- "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-20T18-29-31.571508/hyperparameter_tuning_logs.db"
# - --verbose
- --modalities
- "all"
- --file-names
- '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"'
# - --explain_features
# - --verbose
- --load_data_vars
method: grid

⚠️ Potential issue

Ensure sweep parameters are injected: add ${args}.

   - --load_data_vars
+  - ${args}
 method: grid
📝 Committable suggestion


Suggested change
command:
- ${env}
- ${program}
- train
- -d
- ../data/
- -t
# - BinaryClassification
- CassClassification
- --log-dir
- /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs
# - --tune
- --wandb-sweep
- -tn
- SSI
- --hp-checkpoint
- "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-20T18-29-31.571508/hyperparameter_tuning_logs.db"
# - --verbose
- --modalities
- "all"
- --file-names
- '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"'
# - --explain_features
# - --verbose
- --load_data_vars
method: grid
- --load_data_vars
- ${args}
method: grid
🤖 Prompt for AI Agents
In experiments/benchmark_cass_modalities.yml around lines 1 to 26, the command
list does not include the ${args} placeholder so sweep parameters from the
runner won't be injected; add a '- ${args}' entry into the command array
(typically immediately after '- ${program}') so wandb-sweep and other runtime
arguments are passed through to the train invocation.

Comment on lines +51 to +56
        n_samples = min(1000, len(train_data))
        indices = np.random.choice(len(train_data), size=n_samples, replace=False)
        background_sample = train_data[indices]
        self.explainer = shap.TreeExplainer(
            self.model, background_sample, feature_perturbation="interventional", model_output="probability"
        )

🛠️ Refactor suggestion

Gate SHAP explainer creation behind explain_features to avoid overhead.

Currently always building SHAP explainer; compute only when requested.

-        n_samples = min(1000, len(train_data))
-        indices = np.random.choice(len(train_data), size=n_samples, replace=False)
-        background_sample = train_data[indices]
-        self.explainer = shap.TreeExplainer(
-            self.model, background_sample, feature_perturbation="interventional", model_output="probability"
-        )
+        if getattr(self, "explain_features", False):
+            n_samples = min(1000, len(train_data))
+            indices = np.random.choice(len(train_data), size=n_samples, replace=False)
+            background_sample = train_data[indices]
+            self.explainer = shap.TreeExplainer(
+                self.model, background_sample, feature_perturbation="interventional", model_output="probability"
+            )
📝 Committable suggestion


Suggested change

Before:

        n_samples = min(1000, len(train_data))
        indices = np.random.choice(len(train_data), size=n_samples, replace=False)
        background_sample = train_data[indices]
        self.explainer = shap.TreeExplainer(
            self.model, background_sample, feature_perturbation="interventional", model_output="probability"
        )

After:

        if getattr(self, "explain_features", False):
            n_samples = min(1000, len(train_data))
            indices = np.random.choice(len(train_data), size=n_samples, replace=False)
            background_sample = train_data[indices]
            self.explainer = shap.TreeExplainer(
                self.model,
                background_sample,
                feature_perturbation="interventional",
                model_output="probability"
            )
🤖 Prompt for AI Agents
In icu_benchmarks/models/ml_models/xgboost.py around lines 51 to 56, the SHAP
TreeExplainer is always constructed which adds unnecessary overhead; change this
so the explainer is only created when explain_features is requested — either
wrap the current explainer construction in an if self.explain_features: guard or
remove it from initialization and lazily instantiate and cache self.explainer
inside the explain_features method the first time it is called (use the same
background_sample selection logic), ensuring subsequent calls reuse the cached
explainer.

Comment on lines +92 to +105
    def __init__(self, *args, **kwargs):
        # self.model = self.set_model_args(
        #     xgb, *args, **kwargs, eval_metric="logloss", tree_method="hist", device="cuda", verbosity=0
        # )
        self.model = xgb
        self.params = {
            "eval_metric": "logloss",
            "tree_method": "hist",
            "device": "cuda",
            "verbosity": 0,
            **kwargs,
        }
        super().__init__(*args, **kwargs)


🛠️ Refactor suggestion

GPU init sets model to the module — make intent explicit.

Keep params but initialize model as None; a Booster will be assigned in fit_model.

-        self.model = xgb
+        self.model = None
         self.params = {
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def __init__(self, *args, **kwargs):
# self.model = self.set_model_args(
# xgb, *args, **kwargs, eval_metric="logloss", tree_method="hist", device="cuda", verbosity=0
# )
self.model = xgb
self.params = {
"eval_metric": "logloss",
"tree_method": "hist",
"device": "cuda",
"verbosity": 0,
**kwargs,
}
super().__init__(*args, **kwargs)
def __init__(self, *args, **kwargs):
# self.model = self.set_model_args(
# xgb, *args, **kwargs, eval_metric="logloss", tree_method="hist", device="cuda", verbosity=0
# )
self.model = None
self.params = {
"eval_metric": "logloss",
"tree_method": "hist",
"device": "cuda",
"verbosity": 0,
**kwargs,
}
super().__init__(*args, **kwargs)
🤖 Prompt for AI Agents
In icu_benchmarks/models/ml_models/xgboost.py around lines 92 to 105, the GPU
init currently assigns self.model = xgb (the module) which is ambiguous; change
it to initialize self.model = None and retain self.params as-is so that a proper
xgboost.Booster or estimator is created and assigned later in fit_model; update
any comments to state that a Booster will be created during fit_model and ensure
subsequent code checks self.model is None before using it.
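
For reference, a minimal sketch of the fit_model counterpart implied here, assuming train/validation features and labels arrive as arrays; num_boost_round and early_stopping_rounds are placeholder values, not settings from this PR:

import xgboost as xgb

def fit_model(self, train_data, train_labels, val_data, val_labels):
    # self.model is None until training; build a Booster from self.params here.
    dtrain = xgb.DMatrix(train_data, label=train_labels)
    dval = xgb.DMatrix(val_data, label=val_labels)
    evals_result = {}
    self.model = xgb.train(
        self.params,
        dtrain,
        num_boost_round=500,          # placeholder
        evals=[(dval, "validation")],
        early_stopping_rounds=50,     # placeholder
        evals_result=evals_result,
        verbose_eval=False,
    )
    # eval_metric is "logloss" in self.params, so this returns the last validation loss.
    return evals_result["validation"]["logloss"][-1]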

Comment on lines +106 to +117
def predict(self, features):
"""
Predicts class probabilities for the given features.
Args:
features: Input features for prediction.
Returns:
numpy.ndarray: Predicted probabilities for each class.
"""
return self.model.predict(xgb.DMatrix(features))


🛠️ Refactor suggestion

GPU predict: handle untrained state and return class probabilities.

Return Nx2 probabilities for binary tasks and add basic error handling.

-    def predict(self, features):
+    def predict(self, features):
         """
         Predicts class probabilities for the given features.
@@
-        return self.model.predict(xgb.DMatrix(features))
+        if self.model is None:
+            raise RuntimeError("Model not trained. Call fit_model() first.")
+        try:
+            preds = self.model.predict(xgb.DMatrix(features))  # shape: (n,) for binary
+        except Exception as e:
+            logging.error(f"GPU prediction failed: {e}")
+            raise
+        # Convert to Nx2 proba for classification wrappers
+        if preds.ndim == 1:
+            return np.vstack([1.0 - preds, preds]).T
+        return preds
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def predict(self, features):
"""
Predicts class probabilities for the given features.
Args:
features: Input features for prediction.
Returns:
numpy.ndarray: Predicted probabilities for each class.
"""
return self.model.predict(xgb.DMatrix(features))
def predict(self, features):
"""
Predicts class probabilities for the given features.
Args:
features: Input features for prediction.
Returns:
numpy.ndarray: Predicted probabilities for each class.
"""
if self.model is None:
raise RuntimeError("Model not trained. Call fit_model() first.")
try:
preds = self.model.predict(xgb.DMatrix(features)) # shape: (n,) for binary
except Exception as e:
logging.error(f"GPU prediction failed: {e}")
raise
# Convert to Nx2 proba for classification wrappers
if preds.ndim == 1:
return np.vstack([1.0 - preds, preds]).T
return preds

Comment on lines 493 to 527
def setup_calibration(self, val_data, val_labels, method='isotonic'):
"""
Setup model calibration using validation data for better probability estimates.
Args:
val_data: Validation features for calibration
val_labels: Validation labels for calibration
method: 'isotonic' or 'sigmoid' calibration method
Returns:
float: Validation loss after calibration
"""
if self.run_mode != RunMode.classification:
logging.warning("Calibration only supported for classification tasks")
return None

logging.info(f"Applying {method} calibration using validation data")

# Create calibrated version using validation data as holdout set
self.calibrated_model = CalibratedClassifierCV(
self.model,
method=method,
cv='prefit' # Use prefit model with holdout validation set
)
self.calibrated_model.fit(val_data, val_labels)

# Calculate calibrated validation score
cal_val_pred = self.calibrated_model.predict_proba(val_data)
cal_val_loss = self.loss(val_labels, cal_val_pred)

logging.info(f"Calibration complete. Original val loss: {self.loss(val_labels, self.predict(val_data)):.4f}, "
f"Calibrated val loss: {cal_val_loss:.4f}")

return cal_val_loss


🛠️ Refactor suggestion

Calibration: handle multiclass 'isotonic' case and log true pre-calibration loss.

Isotonic isn’t supported for multiclass; current code may throw. Also “Original val loss” is computed with calibrated preds.

-    def setup_calibration(self, val_data, val_labels, method='isotonic'):
+    def setup_calibration(self, val_data, val_labels, method='isotonic'):
@@
-        if self.run_mode != RunMode.classification:
+        if self.run_mode != RunMode.classification:
             logging.warning("Calibration only supported for classification tasks")
             return None
+        n_classes = np.unique(val_labels).size
+        if n_classes > 2 and method == 'isotonic':
+            logging.warning("Isotonic calibration is not supported for multiclass; falling back to 'sigmoid'.")
+            method = 'sigmoid'
@@
-        logging.info(f"Calibration complete. Original val loss: {self.loss(val_labels, self.predict(val_data)):.4f}, "
+        logging.info(f"Calibration complete. Original val loss: {self.loss(val_labels, self.predict(val_data, use_calibrated=False)):.4f}, "
                      f"Calibrated val loss: {cal_val_loss:.4f}")

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In icu_benchmarks/models/wrappers.py around lines 493 to 527, the calibration
code can break because isotonic calibration is not supported for multiclass
problems and the "Original val loss" log is computed after replacing the model
so it shows calibrated predictions; fix by (1) before fitting, compute and store
original_val_pred = self.predict(val_data) and original_val_loss =
self.loss(val_labels, original_val_pred) and use those in the log; (2) detect
multiclass (e.g. >2 unique labels) and if method == 'isotonic' either switch
method to 'sigmoid' or raise/log a clear warning and set method='sigmoid' so
CalibratedClassifierCV is not invoked with unsupported method.
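
A minimal sketch of the combined fix, assuming self.loss takes (labels, probabilities) as in the snippet above and self.model is a fitted classifier exposing predict_proba:

import logging
import numpy as np
from sklearn.calibration import CalibratedClassifierCV

def setup_calibration(self, val_data, val_labels, method="isotonic"):
    # Log the true pre-calibration loss before any calibrated model exists.
    original_val_loss = self.loss(val_labels, self.model.predict_proba(val_data))
    # Fall back to sigmoid for multiclass labels, as recommended above.
    if np.unique(val_labels).size > 2 and method == "isotonic":
        logging.warning("Isotonic calibration not supported for multiclass; falling back to 'sigmoid'.")
        method = "sigmoid"
    self.calibrated_model = CalibratedClassifierCV(self.model, method=method, cv="prefit")
    self.calibrated_model.fit(val_data, val_labels)
    cal_val_loss = self.loss(val_labels, self.calibrated_model.predict_proba(val_data))
    logging.info(f"Calibration complete. Original val loss: {original_val_loss:.4f}, "
                 f"Calibrated val loss: {cal_val_loss:.4f}")
    return cal_val_loss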

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

🧹 Nitpick comments (13)
experiments/benchmark_organsystems.yml (5)

59-61: Param name likely unused: use_pretrained_imputation ≠ pretrained_imputation.

run.py reads args.pretrained_imputation; this sweep defines use_pretrained_imputation and never passes a corresponding CLI flag. This will be ignored.

Consider:

-  use_pretrained_imputation:
+  pretrained_imputation:
     values:
-      - None
+      - None

Optionally pass explicitly in command when set:

+  - --pretrained-imputation
+  - None

5-6: Duplicate/conflicting data_dir sources.

You hardcode “-d ../data/” in command while also sweeping data_dir in parameters. Depending on the sweep runner/argparse resolution, the earlier value may be overridden or cause confusion.

Remove the stub and rely on the sweep parameter:

-  - -d
-  - ../data/

Also applies to: 29-35


17-17: hp-checkpoint path may mismatch dataset/task.

The checkpoint points to a different cohort (“lab_set_normal_ward…”) while the sweep targets organ_cohorts. If the checkpoint encodes dataset-specific search space or metrics, reuse could be invalid.

Parameterize this path so it can be swapped per organ system, or drop it if not intended.


11-11: Hard-coded absolute paths reduce portability.

Absolute cluster paths (/sc-projects/…) in log-dir and data_dir values make this sweep non-portable.

  • Use env interpolation (e.g., ${env:DATA_ROOT}) or anchor variables.
  • Expose log-dir as a parameter.

Example:

-  - /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs
+  - ${env:YAIB_LOG_DIR}

and

-      - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/organ_cohorts_.../Pancreas"
+      - "${env:DATA_ROOT}/organ_cohorts_.../Pancreas"

Also applies to: 31-35


62-62: Add trailing newline.

YAMLlint flags missing newline at EOF.

-program: icu-benchmarks
+program: icu-benchmarks
+
experiments/benchmark_cass_outcome_time.yml (7)

5-6: Duplicate data_dir (CLI vs sweep params) can cause drift.

You set -d ../data/ in the command but also sweep over data_dir. Remove the CLI arg and let the sweep drive it.

-  - -d
-  - ../data/

Also applies to: 44-47


11-11: Parameterize cluster-specific absolute paths via env vars.

Hardcoded /sc-projects/... paths reduce portability. Use env-substitution so agents can inject paths.

   - --log-dir
-  - /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs
+  - ${env:YAIB_LOG_DIR}
@@
   - --hp-checkpoint
-  - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-20T18-29-31.571508/hyperparameter_tuning_logs.db"
+  - ${env:YAIB_HP_CHECKPOINT}

Also applies to: 17-17


79-81: Use YAML null instead of Python None string in sweeps.

W&B/YAML interpret null, not Python’s None literal. This avoids passing the string "None".

   use_pretrained_imputation:
     values:
-      - None
+      - null
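
A quick PyYAML check illustrating the difference; the W&B agent's own YAML handling may differ in minor details:

import yaml

print(yaml.safe_load("use_pretrained_imputation: None"))
# {'use_pretrained_imputation': 'None'}  -> the literal string "None"

print(yaml.safe_load("use_pretrained_imputation: null"))
# {'use_pretrained_imputation': None}    -> a real null value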

34-35: De-duplicate modalities (CLI vs sweep param).

Keep one source (recommend: sweep param) to avoid future drift.

-  - --modalities
-  - "all"

Also applies to: 61-64


13-15: Add a sweep metric to enable ranking/early stopping dashboards.

Grid doesn’t require a metric, but specifying one improves reporting. Use the exact key your trainer logs (e.g., val_auroc).

Proposed addition (adjust name to your logged key):

 method: grid
+metric:
+  name: val_auroc
+  goal: maximize
 name: yaib_classification_benchmark

Also applies to: 41-43


66-76: Consider native YAML mappings for file_names values.

Less quoting and no custom parsing needed. Only do this if the code accepts dicts from W&B config.

 file_names:
   values:
-      - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_1.parquet","STATIC":"sta.parquet"'
-      - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_2.parquet","STATIC":"sta.parquet"'
+      - {DYNAMIC: dyn.parquet, OUTCOME: outc_1.parquet, STATIC: sta.parquet}
+      - {DYNAMIC: dyn.parquet, OUTCOME: outc_2.parquet, STATIC: sta.parquet}

1-4: Minor sweep ergonomics + lint fix.

  • Consider adding ${interpreter} before ${program} if icu-benchmarks isn’t an installed console script on all agents.
  • Add a trailing newline to satisfy YAMLlint.

Also applies to: 82-82

icu_benchmarks/models/__init__.py (1)

34-36: Keep RUSBClassifier in the classifier union for type hints and tooling

If RUSBClassifier remains supported, include it in MLModelClassifier to keep annotations accurate.

 MLModelClassifier = Union[
     XGBClassifier,
     LGBMClassifier,
+    RUSBClassifier,
     XGBEnsembleClassifier,
     BRFClassifier,
     CBClassifier,
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 52895ac and ba742cd.

📒 Files selected for processing (4)
  • configs/prediction_models/RUSBClassifier.gin (1 hunks)
  • experiments/benchmark_cass_outcome_time.yml (1 hunks)
  • experiments/benchmark_organsystems.yml (1 hunks)
  • icu_benchmarks/models/__init__.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
icu_benchmarks/models/__init__.py (1)
icu_benchmarks/models/ml_models/imblearn.py (3)
  • BRFClassifier (12-17)
  • XGBEnsembleClassifier (21-57)
  • RUSBClassifier (17-22)
configs/prediction_models/RUSBClassifier.gin (1)
icu_benchmarks/models/ml_models/imblearn.py (2)
  • RUSBClassifier (17-22)
  • __init__ (20-22)
experiments/benchmark_organsystems.yml (1)
icu_benchmarks/run.py (1)
  • main (33-198)
🪛 YAMLlint (1.37.1)
experiments/benchmark_organsystems.yml

[error] 62-62: no new line character at the end of file

(new-line-at-end-of-file)

experiments/benchmark_cass_outcome_time.yml

[error] 82-82: no new line character at the end of file

(new-line-at-end-of-file)

🔇 Additional comments (10)
experiments/benchmark_organsystems.yml (3)

14-15: Check short flag -tn and task name.

  • Verify that -tn is a valid alias for task-name (run.py shows args.task_name but parser isn’t shown).
  • Ensure "CassClassification" exists under configs/tasks/*.gin.

Would you like me to scan the repo to confirm the parser flags and task file presence?

Also applies to: 8-10


25-25: --load_data_vars has no value — confirm it’s a boolean flag.

If the parser expects an argument (path or list), this will error. If it’s store_true, fine.

Point me to the argparse definition if you want me to validate and fix.


13-13: Confirm wandb sweep integration mode.

You’re passing --wandb-sweep and also using a sweep YAML. Ensure this runner expects both; some setups rely solely on W&B agent to inject args without an internal --wandb-sweep switch.

If redundant, drop the flag to avoid mode drift.

experiments/benchmark_cass_outcome_time.yml (2)

40-40: Confirm --load_data_vars is a boolean flag.

If it expects a value, passing it bare will error. If it’s a store_true flag, all good.


36-37: Drop the --file-names CLI flag to avoid conflicts with the OUTCOME sweep param

The CLI pinning of "OUTCOME":"outc_6.parquet" conflicts with the sweep over OUTCOME={outc_1,outc_2}. Remove these lines at 36–37 (and the analogous block at 66–76) so the sweep parameter solely controls file names:

-  - --file-names
-  - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"'
icu_benchmarks/models/__init__.py (1)

34-35: XGBEnsembleClassifier.fit_model can raise UnboundLocalError on failure

The referenced implementation returns val_loss in the except block without initialization, causing UnboundLocalError if training fails. Also logs then swallows the error, which hides hard failures.

Patch in icu_benchmarks/models/ml_models/imblearn.py:

 def fit_model(self, train_data, train_labels, val_data, val_labels):
-    import xgboost as xgb
-    from sklearn.metrics import log_loss
+    from sklearn.metrics import log_loss
+    val_loss = float("inf")
 
     # Try XGBoost first with provided config
     try:
-        self.model.fit(train_data, train_labels,)
-
+        self.model.fit(train_data, train_labels)
         val_pred_proba = self.model.predict_proba(val_data)
         val_loss = log_loss(val_labels, val_pred_proba)
         logging.info(f"XGBoost model trained successfully. Validation loss: {val_loss:.4f}")
-        return val_loss
-
     except Exception as e:
-        logging.warning(f"XGBoost failed: {e}")
-
-    return val_loss
+        logging.warning(f"XGBoost failed during fit/eval: {e}")
+        # optionally: raise to fail fast
+        # raise
+    return val_loss

Likely an incorrect or invalid review comment.

configs/prediction_models/RUSBClassifier.gin (4)

18-20: Validate sampling_strategy values against current imbalanced-learn

"auto", "majority", "not minority" are valid historically, but versions differ. Given deps were bumped, confirm these map 1:1 in the installed imbalanced-learn.


21-25: Check param namespace for base estimator (base_estimator__*) after sklearn/imbalanced-learn renames

Recent versions migrated from base_estimator -> estimator in several APIs. Ensure RUSBoostClassifier still exposes base_estimator__* rather than estimator__*, or your grid won’t bind.


26-27: Ensure replacement is a valid top-level RUSBoost parameter

Depending on the version, replacement may belong to the internal sampler; some releases surface it, others don’t. Verify it is accepted by RUSBoostClassifier in your pinned version; otherwise expose via a sampler object or drop it.


12-17: Tuples of the form (min, max, "log") are explicitly handled in the Optuna binding (it checks for len == 3 and sets log=True if value[2] == "log"), and lists fall back to categorical sampling via trial.suggest_categorical. Your config patterns match those handlers, so (0.01, 2.0, "log") and the other tuple ranges are supported; no change needed.
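
To make the described binding concrete, a hypothetical helper that mirrors this behavior (not the repo's implementation):

import optuna

def suggest_from_spec(trial, name, value):
    # Hypothetical helper mirroring the binding described above.
    if isinstance(value, tuple) and len(value) == 3 and value[2] == "log":
        return trial.suggest_float(name, value[0], value[1], log=True)
    if isinstance(value, tuple):
        return trial.suggest_float(name, value[0], value[1])
    if isinstance(value, list):
        return trial.suggest_categorical(name, value)
    return value  # fixed value, nothing to tune

def objective(trial):
    learning_rate = suggest_from_spec(trial, "learning_rate", (0.01, 2.0, "log"))
    sampling_strategy = suggest_from_spec(trial, "sampling_strategy", ["auto", "majority"])
    return learning_rate  # placeholder objective

study = optuna.create_study()
study.optimize(objective, n_trials=3)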

Comment on lines 7 to 11
train_common.model = @RUSBClassifier

# Hyperparameter tuning configuration
model/hyperparameter.class_to_tune = @RUSBClassifier
model/hyperparameter.n_estimators = (10, 50, 100, 200, 500)
model/hyperparameter.learning_rate = (0.005, 1, "log")
model/hyperparameter.sampling_strategy = "auto"


💡 Verification agent

🧩 Analysis chain

Align model name with public API choice (RUSB vs XGBEnsemble)

You removed RUSBClassifier from the public API but this config still binds train_common.model and class_to_tune to RUSBClassifier. Choose one:

  • Keep RUSBClassifier supported (re-export in __init__.py as suggested), or
  • Migrate this config to XGBEnsembleClassifier and rename the file accordingly.

If migrating:

-train_common.model = @RUSBClassifier
-model/hyperparameter.class_to_tune = @RUSBClassifier
+train_common.model = @XGBEnsembleClassifier
+model/hyperparameter.class_to_tune = @XGBEnsembleClassifier

🏁 Script executed:

#!/usr/bin/env bash
set -e

echo "Finding RUSBClassifier class definitions..."
rg -nP "class\s+RUSBClassifier\b" .

echo "Searching for RUSBClassifier exports in __init__.py..."
rg -n "RUSBClassifier" -g "**/__init__.py" -n

echo "Finding XGBEnsembleClassifier class definitions..."
rg -nP "class\s+XGBEnsembleClassifier\b" .

echo "Searching for XGBEnsembleClassifier references..."
rg -n "XGBEnsembleClassifier" -g "**/*.py" -n

Length of output: 143


🏁 Script executed:

#!/usr/bin/env bash
set -e

echo "Searching for any occurrences of RUSBClassifier..."
rg -n "RUSBClassifier" .

echo "Searching for any occurrences of XGBEnsembleClassifier..."
rg -n "XGBEnsembleClassifier" .

echo "Searching for filenames containing RUSBClassifier or XGBEnsembleClassifier..."
fd --hidden --no-ignore RUSBClassifier
fd --hidden --no-ignore XGBEnsembleClassifier

Length of output: 2022


Align config with public API—migrate or re-export RUSBClassifier
configs/prediction_models/RUSBClassifier.gin still references the removed RUSBClassifier class; either re-export RUSBClassifier in icu_benchmarks/models/__init__.py or update this config (and filename) to use XGBEnsembleClassifier.

If migrating:

-train_common.model = @RUSBClassifier
-model/hyperparameter.class_to_tune = @RUSBClassifier
+train_common.model = @XGBEnsembleClassifier
+model/hyperparameter.class_to_tune = @XGBEnsembleClassifier
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
train_common.model = @RUSBClassifier
# Hyperparameter tuning configuration
model/hyperparameter.class_to_tune = @RUSBClassifier
model/hyperparameter.n_estimators = (10, 50, 100, 200, 500)
model/hyperparameter.learning_rate = (0.005, 1, "log")
model/hyperparameter.sampling_strategy = "auto"
train_common.model = @XGBEnsembleClassifier
# Hyperparameter tuning configuration
model/hyperparameter.class_to_tune = @XGBEnsembleClassifier
🤖 Prompt for AI Agents
In configs/prediction_models/RUSBClassifier.gin around lines 7 to 11, the config
references the removed RUSBClassifier; either re-export RUSBClassifier from
icu_benchmarks/models/__init__.py or update this config (and its filename) to
use the current public API XGBEnsembleClassifier. If you choose re-exporting,
add an import and export of RUSBClassifier in icu_benchmarks/models/__init__.py
so existing gin configs keep working. If you choose migration, rename the file
to match the new model (e.g., XGBEnsembleClassifier.gin), replace both
occurrences of RUSBClassifier with XGBEnsembleClassifier, and verify any
hyperparameter names match the XGBEnsembleClassifier API.

Comment on lines +21 to +22
- --file-names
- '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"'

🛠️ Refactor suggestion

file_names appears malformed (missing braces) — likely to break parse_dict.

The dictionary-like string is missing outer {} both in the command and in parameters. Most dict parsers (incl. typical CLI "parse_dict") expect a braced JSON-like payload.

Apply this diff:

-  - --file-names
-  - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"'
+  - --file-names
+  - '{"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"}'

And for parameters:

-      - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"'
+      - '{"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"}'

Also applies to: 53-55

🤖 Prompt for AI Agents
In experiments/benchmark_organsystems.yml around lines 21-22 (and similarly
lines 53-55) the file_names value is a dictionary-like string missing outer
braces which will break parse_dict; wrap the key/value list with curly braces so
it becomes a proper dict string (e.g. add leading "{" and trailing "}" around
the quoted pairs) in both the command and parameter sections, ensuring quoting
remains valid for the CLI parser.
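
A hypothetical JSON-based stand-in for the parser (the repo's parse_dict is not shown in this diff and may be more lenient) illustrating why the outer braces matter:

import json

def parse_dict_arg(raw: str) -> dict:
    # Hypothetical parser: a JSON-style loader requires the outer braces.
    return json.loads(raw)

ok = parse_dict_arg('{"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"}')
print(ok["OUTCOME"])  # outc_6.parquet

try:
    parse_dict_arg('"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"')
except json.JSONDecodeError as err:
    print(f"braceless form rejected: {err}")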

from icu_benchmarks.models.dl_models.transformer import BaseTransformer, LocalTransformer, Transformer
from icu_benchmarks.models.ml_models.catboost import CBClassifier
from icu_benchmarks.models.ml_models.imblearn import BRFClassifier, RUSBClassifier
from icu_benchmarks.models.ml_models.imblearn import BRFClassifier, XGBEnsembleClassifier

💡 Verification agent

🧩 Analysis chain

Re-export RUSBClassifier to avoid breaking configs referencing @RUSBClassifier

configs/prediction_models/RUSBClassifier.gin still points to RUSBClassifier; removing it from the public API will likely break resolution if the gin loader relies on from icu_benchmarks.models import *. Re-expose RUSBClassifier alongside XGBEnsembleClassifier, or update the configs to the new class. Keeping both exported is the lowest-risk path.

Apply:

-from icu_benchmarks.models.ml_models.imblearn import BRFClassifier, XGBEnsembleClassifier
+from icu_benchmarks.models.ml_models.imblearn import BRFClassifier, RUSBClassifier, XGBEnsembleClassifier

Run to find remaining references that rely on the top-level export:


🏁 Script executed:

#!/bin/bash
rg -nC2 -S '\bRUSBClassifier\b' configs/ icu_benchmarks/ | sed -n '1,200p'

Length of output: 640


Re-export RUSBClassifier in icu_benchmarks/models/__init__.py
Add RUSBClassifier alongside existing imports to preserve compatibility with configs/prediction_models/RUSBClassifier.gin:

-from icu_benchmarks.models.ml_models.imblearn import BRFClassifier, XGBEnsembleClassifier
+from icu_benchmarks.models.ml_models.imblearn import BRFClassifier, RUSBClassifier, XGBEnsembleClassifier
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
from icu_benchmarks.models.ml_models.imblearn import BRFClassifier, XGBEnsembleClassifier
from icu_benchmarks.models.ml_models.imblearn import BRFClassifier, RUSBClassifier, XGBEnsembleClassifier
🤖 Prompt for AI Agents
In icu_benchmarks/models/__init__.py around line 7, the module re-exports only
BRFClassifier and XGBEnsembleClassifier but needs to also expose RUSBClassifier
for compatibility with configs/prediction_models/RUSBClassifier.gin; update the
import statement to include RUSBClassifier from
icu_benchmarks.models.ml_models.imblearn so it is re-exported by this package
(ensure the symbol name matches the implementation in ml_models/imblearn).

"LocalTransformer",
"CBClassifier",
"RUSBClassifier",
"XGBEnsembleClassifier",

🛠️ Refactor suggestion

Add RUSBClassifier back to __all__ for star-import/generation and docs

This restores backward-compatible public exposure used by configs and (likely) Sphinx autodoc.

 __all__ = [
   "GRUNet",
   "RNNet",
   "LSTMNet",
   "TemporalConvNet",
   "BaseTransformer",
   "Transformer",
   "LocalTransformer",
   "CBClassifier",
   "XGBEnsembleClassifier",
   "BRFClassifier",
+  "RUSBClassifier",
   "LGBMClassifier",
   "LGBMRegressor",
   "XGBClassifier",
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
"XGBEnsembleClassifier",
__all__ = [
"GRUNet",
"RNNet",
"LSTMNet",
"TemporalConvNet",
"BaseTransformer",
"Transformer",
"LocalTransformer",
"CBClassifier",
"XGBEnsembleClassifier",
"BRFClassifier",
"RUSBClassifier",
"LGBMClassifier",
"LGBMRegressor",
"XGBClassifier",
# ... other entries ...
]
🤖 Prompt for AI Agents
In icu_benchmarks/models/__init__.py around line 60, the exported names list
__all__ no longer includes "RUSBClassifier"; add the string "RUSBClassifier"
back into the __all__ list (as an entry with proper quotes and a trailing comma)
so star-imports, configs and Sphinx autodoc continue to expose that public
symbol.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
icu_benchmarks/data/split_process_data.py (4)

169-176: Modality selection result is dropped (assigned to data, not sanitized_data).

You pass sanitized_data in but assign to data, so subsequent steps ignore the filtered frames.

-            data, vars = modality_selection(sanitized_data, modality_mapping, selected_modalities, vars)
+            sanitized_data, vars = modality_selection(sanitized_data, modality_mapping, selected_modalities, vars)

Also applies to: 173-173


203-205: Use sanitized data for full-train path.

The complete-train branch bypasses sanitation/modality filtering by using data.

-        sanitized_data = make_train_val_polars(data, vars, train_size=None, seed=seed, debug=debug, runmode=runmode)
+        sanitized_data = make_train_val_polars(sanitized_data, vars, train_size=None, seed=seed, debug=debug, runmode=runmode)

265-268: Cache the processed artifact, not the raw input.

caching(..., data, ...) writes the unprocessed dict, inconsistent with load_cache expectations.

-        caching(cache_dir, cache_file, data, load_cache)
+        caching(cache_dir, cache_file, sanitized_data, load_cache)

621-629: Missing random_state compromises reproducibility.

Add random_state=seed to all Shuffle/StratifiedShuffle splitters.

-            outer_cv = StratifiedShuffleSplit(cv_repetitions, train_size=train_size)
+            outer_cv = StratifiedShuffleSplit(cv_repetitions, train_size=train_size, random_state=seed)
@@
-            inner_cv = StratifiedShuffleSplit(cv_folds, train_size=0.75, random_state=seed)
+            inner_cv = StratifiedShuffleSplit(cv_folds, train_size=0.75, random_state=seed)
@@
-            outer_cv = ShuffleSplit(cv_repetitions, train_size=train_size)
+            outer_cv = ShuffleSplit(cv_repetitions, train_size=train_size, random_state=seed)

Also applies to: 636-640

♻️ Duplicate comments (12)
icu_benchmarks/data/split_process_data.py (4)

285-315: Local check_sanitize_data diverges from utils (drops infinite-removal and column pruning).

This regression removes infinite_removal(...) and HRV column filtering present in utils, changing pipeline semantics.

Either delete this override to use utils, or feature-parity it with utils.

-def check_sanitize_data(data: dict[str, pl.DataFrame], vars: dict[str, str | list[str]]) -> dict[str, pl.DataFrame]:  # noqa: F811
+def check_sanitize_data(data: dict[str, pl.DataFrame], vars: dict[str, str | list[str]]) -> dict[str, pl.DataFrame]:  # noqa: F811
@@
-    if DataSegment.dynamic in data.keys():
+    if DataSegment.dynamic in data.keys():
         old_len = len(data[DataSegment.dynamic])
-
-        data[DataSegment.dynamic] = data[DataSegment.dynamic].unique(
+        data[DataSegment.dynamic] = data[DataSegment.dynamic].unique(
             subset=flatten_column_names(group, sequence), keep=keep, maintain_order=True
         )
+        # Parity with utils.check_sanitize_data
+        from icu_benchmarks.data.utils import infinite_removal
+        data[DataSegment.dynamic] = infinite_removal(data[DataSegment.dynamic])
+        data[DataSegment.dynamic] = data[DataSegment.dynamic].select(~cs.starts_with("wearable_ppgfeature_HRV_SampEn"))
+        vars[DataSegment.dynamic] = [c for c in vars.get(DataSegment.dynamic, []) if c in data[DataSegment.dynamic].columns]

Or remove this function and rely on the imported one.


317-357: Redundant modality_selection override; prefer shared implementation.

Duplicating logic increases drift risk. If specialization is needed, rename and document delta; otherwise, drop this override.

-def modality_selection(  # noqa: F811
+def modality_selection(  # noqa: F811
@@
-    return data, vars
+    return data, vars

Recommendation: delete this function and import from utils, or rename to _modality_selection_local and only call it when needed.


23-23: Remove import or local redefinitions to avoid F811/behavior drift.

You import check_sanitize_data, modality_selection but later redefine both (see Lines 285–315 and 317–357). This shadows the imports and changes behavior from utils (e.g., missing infinite-removal), risking regressions.

Apply one of:

  • Prefer utils: delete the local function defs.
  • Or keep locals: remove this import and rename functions to avoid confusion.
-from .utils import check_sanitize_data, modality_selection
+from .utils import check_sanitize_data, modality_selection
+# If keeping local versions, remove this import and rename locals to e.g.
+# _check_sanitize_data_local, _modality_selection_local

91-91: Restore “SEQUENCE” to required_var_types (or guard all downstream uses).

Defaults exclude "SEQUENCE", but the code expects it (e.g., modality selection, splitting, reduction) and will raise at runtime when missing.

-    if required_var_types is None:
-        required_var_types = ["GROUP", "LABEL"]
+    if required_var_types is None:
+        required_var_types = ["GROUP", "LABEL", "SEQUENCE"]
icu_benchmarks/models/wrappers.py (8)

405-410: Remove duplicate variable assignment

The valid_idx variable is computed twice with identical logic on lines 406 and 409.

     # valid_idx is (batch, seq_len) boolean mask
     valid_idx = mask.detach().cpu().numpy().astype(bool)

     # indicators is (batch, seq_len, ...) or (batch, seq_len, 2)
-    indicators_np = indicators.detach().cpu().numpy()[valid_idx]
-    predictors = prediction.detach().cpu()
-    target_np = target.detach().cpu().numpy()
+    indicators_np = indicators.detach().cpu().numpy()[valid_idx]
+    predictors = prediction.detach().cpu()
+    target_np = target.detach().cpu().numpy()

50-52: Debug flag should default to False for production

The debug flag is set to True by default, which causes model outputs to be saved automatically. This is inconsistent with typical CLI defaults (which use False) and can lead to unexpected disk usage in production environments.

Apply this diff to fix the default value:

-    debug = True
+    debug = False
     # We do not want to explain features by default as it is expensive (and not needed during hp tuning)
     explain_features = False

Additionally, ensure that the debug flag is properly propagated from runtime arguments when models are instantiated.


155-155: Fix logging inconsistency

Same logging issue as above - the message refers to "row_indicators.csv" instead of "pred_indicators.csv".

-                logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / 'row_indicators.csv'}")
+                logging.debug(f"Saved pred indicators to {Path(self.logger.save_dir) / 'pred_indicators.csv'}")

451-451: Guard _save_model_outputs call with debug flag

The _save_model_outputs method is called unconditionally in DLPredictionWrapper.step_fn, which could cause unnecessary I/O operations.

         # Save predictions to file
-        self._save_model_outputs(indicators_np, transformed_predictors, target_np)
+        if self.debug:
+            self._save_model_outputs(indicators_np, transformed_predictors, target_np)

117-134: Fix logging message and format string mismatch

Line 134 logs "row_indicators.csv" but the actual file saved is "pred_indicators.csv". Additionally, the format string needs to handle both classification and regression cases correctly.

-                logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / 'row_indicators.csv'}")
+                logging.debug(f"Saved pred indicators to {Path(self.logger.save_dir) / 'pred_indicators.csv'}")

156-167: Fix logging and remove extraneous f-string prefix

Line 167 has an unnecessary nested f-string and incorrect file reference.

-            logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / f'row_indicators.csv'}")
+            logging.debug(f"Saved pred indicators to {Path(self.logger.save_dir) / 'pred_indicators.csv'}")

646-681: Fix logging issue for pre-calibration loss

The "Original val loss" is computed after the model is already replaced, showing calibrated predictions instead of the original uncalibrated ones.

Store the original predictions before calibration:

     def setup_calibration(self, val_data, val_labels, method='sigmoid'):
         """
         Setup model calibration using validation data for better probability estimates.
         
         Args:
             val_data: Validation features for calibration
             val_labels: Validation labels for calibration
             method: 'isotonic' or 'sigmoid' calibration method
         
         Returns:
             float: Validation loss after calibration
         """
         if self.run_mode != RunMode.classification:
             logging.warning("Calibration only supported for classification tasks")
             return None
+        
+        # Store original predictions before calibration
+        original_val_pred = self.model.predict_proba(val_data)
+        original_val_loss = self.loss(val_labels, original_val_pred)
         
         logging.info(f"Applying {method} calibration using validation data")
         
         # Create calibrated version using validation data as holdout set
         self.calibrated_model = CalibratedClassifierCV(
             self.model,
             method=method,
             cv='prefit'  # Use prefit model with holdout validation set
         )
         self.calibrated_model.fit(val_data, val_labels)
         
         # Calculate calibrated validation score
         cal_val_pred = self.calibrated_model.predict_proba(val_data)
         cal_val_loss = self.loss(val_labels, cal_val_pred)
         
         logging.info(f"Calibration complete. "
-                     f"Original val loss: {self.loss(val_labels, self.model.predict_proba(val_data)):.4f}, "
+                     f"Original val loss: {original_val_loss:.4f}, "
                      f"Calibrated val loss: {cal_val_loss:.4f}")

465-559: Handle multiclass limitation for isotonic calibration

Isotonic calibration is not supported for multiclass problems. The code should detect this and either fall back to sigmoid or raise a clear warning.

Add multiclass detection in the setup_calibration method:

     def setup_calibration(self, val_loader, method='isotonic'):
         """
         Setup model calibration using validation data for better probability estimates.
         
         Args:
             val_loader: Validation dataloader for calibration
             method: 'isotonic' or 'sigmoid' calibration method
         
         Returns:
             float: Validation loss after calibration
         """
         if self.run_mode != RunMode.classification:
             logging.warning("Calibration only supported for classification tasks")
             return None
+        
+        # Check if multiclass and isotonic
+        if hasattr(self, 'logit') and self.logit.out_features > 2 and method == 'isotonic':
+            logging.warning("Isotonic calibration not supported for multiclass; falling back to sigmoid")
+            method = 'sigmoid'
🧹 Nitpick comments (7)
icu_benchmarks/data/split_process_data.py (5)

120-135: exclude_preproc typing should be Optional.

Use Optional in the signature per PEP 484 and the Ruff hint; also coerce to set to deduplicate.

-    exclude_preproc: list[str] = None,
+    exclude_preproc: Optional[list[str]] = None,
@@
-    vars_to_exclude = []
+    vars_to_exclude: list[str] = []
@@
-            for modality in exclude_preproc:
+            for modality in exclude_preproc or []:
                 if modality in modality_mapping:
                     vars_to_exclude.extend(modality_mapping.get(modality))
@@
-            logging.info(
-                f"Excluding modalities in {exclude_preproc}. Total vars excluded from preprocessing: {len(vars_to_exclude)}"
-            )
+            vars_to_exclude = list(set(vars_to_exclude))
+            logging.info(f"Excluding modalities in {exclude_preproc}. Total vars excluded: {len(vars_to_exclude)}")

512-512: Inconsistent train_size defaults across helpers.

make_single_split_pandas defaults to 0.80, while make_single_split_polars and make_single_split use None. Align for predictable behavior and docs.

-def make_single_split_polars(
+def make_single_split_polars(
@@
-    train_size: Optional[float] = None,
+    train_size: Optional[float] = 0.80,

And/or set wrapper default to 0.80.

Also applies to: 591-591, 674-674


841-843: Minor: use iterable unpacking in concat (RUF005).

Slightly cleaner; no behavior change.

-            data[DataSegment.dynamic] = pl.concat([unchanged_dynamic] + reduced_dynamic_parts).sort(
+            data[DataSegment.dynamic] = pl.concat([unchanged_dynamic, *reduced_dynamic_parts]).sort(
                 [group_var, sequence_var])
@@
-                data[DataSegment.outcome] = pl.concat([unchanged_outcome] + reduced_outcome_parts).sort(
+                data[DataSegment.outcome] = pl.concat([unchanged_outcome, *reduced_outcome_parts]).sort(
                     [group_var, sequence_var])

Also applies to: 863-865


158-159: Style: long exception messages.

Ruff TRY003 flags long messages. Consider shorter messages or custom exception types, but this is non-blocking.

Also applies to: 767-768


211-221: Log message says “Dropping columns with nulls” but no drop occurs.

Adjust wording to avoid confusion.

-            logging.debug("Dropping columns with nulls")
+            logging.debug("Columns with remaining nulls (no drop performed):")
icu_benchmarks/models/wrappers.py (2)

694-700: Remove duplicate calibration flag assignment

Lines 694 and 697 both set calibrate = True, which is redundant.

-        calibrate = True
-        method = "sigmoid"
-        # Apply calibration if desired
         calibrate = True
+        method = "sigmoid"
         if calibrate and self.run_mode == RunMode.classification:

765-766: Clean up commented code

Line 765 contains incomplete commented code that should be removed.

         test_pred = self.predict(test_rep)
-        # test_pred_uncalibrated = self.predict(test_rep, use_calibrated=False)False
         self.compare_rankings(test_rep, test_label)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ba742cd and 235c02f.

📒 Files selected for processing (2)
  • icu_benchmarks/data/split_process_data.py (15 hunks)
  • icu_benchmarks/models/wrappers.py (16 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
icu_benchmarks/data/split_process_data.py (4)
icu_benchmarks/constants.py (1)
  • RunMode (4-7)
icu_benchmarks/data/preprocessor.py (6)
  • PolarsClassificationPreprocessor (90-268)
  • Preprocessor (36-87)
  • PolarsClassificationPreprocessor (71-246)
  • apply (563-591)
  • apply (102-176)
  • apply (355-402)
icu_benchmarks/data/constants.py (3)
  • DataSegment (7-11)
  • DataSplit (1-4)
  • VarType (14-17)
icu_benchmarks/data/utils.py (2)
  • check_sanitize_data (26-51)
  • modality_selection (54-81)
icu_benchmarks/models/wrappers.py (3)
icu_benchmarks/models/utils.py (2)
  • log_single_metric_to_file (290-315)
  • name (159-160)
icu_benchmarks/constants.py (1)
  • RunMode (4-7)
icu_benchmarks/models/ml_models/lgbm.py (1)
  • predict (36-46)
🪛 Ruff (0.12.2)
icu_benchmarks/data/split_process_data.py

35-35: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


158-158: Avoid specifying long messages outside the exception class

(TRY003)


767-767: Avoid specifying long messages outside the exception class

(TRY003)


841-841: Consider [unchanged_dynamic, *reduced_dynamic_parts] instead of concatenation

Replace with [unchanged_dynamic, *reduced_dynamic_parts]

(RUF005)


863-863: Consider [unchanged_outcome, *reduced_outcome_parts] instead of concatenation

Replace with [unchanged_outcome, *reduced_outcome_parts]

(RUF005)

icu_benchmarks/models/wrappers.py

167-167: f-string without any placeholders

Remove extraneous f prefix

(F541)


447-447: Avoid specifying long messages outside the exception class

(TRY003)


495-495: Create your own exception

(TRY002)


495-495: Avoid specifying long messages outside the exception class

(TRY003)


818-818: Unused method argument: pred_indicators

(ARG002)

🔇 Additional comments (3)
icu_benchmarks/models/wrappers.py (3)

114-116: Good addition of feature explanation placeholder

The _explain_model method provides a clear interface for subclasses to implement feature explanation functionality with an appropriate NotImplementedError message.


786-798: Well-implemented calibration fallback logic

The predict method properly handles both calibrated and uncalibrated prediction paths with appropriate fallbacks for both classification and regression modes.


818-829: Good implementation of metric logging to CSV

The log_curves method properly exports metrics to CSV files using the log_single_metric_to_file utility, providing valuable persistence for analysis.

Comment on lines +211 to +242
logging.debug(f"Checking for NaNs and nulls in {data.keys()}.")
for _dict in sanitized_data.values():
for key, val in _dict.items():
for key, dataframe in _dict.items():
logging.debug(f"Data type: {key}")
logging.debug("Is NaN:")
sel = _dict[key].select(pl.selectors.numeric().is_nan().max())
logging.debug(sel.select(col.name for col in sel if col.item(0)))
logging.debug("Has nulls:")
sel = _dict[key].select(pl.all().has_nulls())
logging.debug(sel.select(col.name for col in sel if col.item(0)))
_dict[key] = val.fill_null(strategy="zero")
_dict[key] = val.fill_nan(0)
_dict[key] = dataframe.fill_null(strategy="zero")
_dict[key] = dataframe.fill_nan(0)
logging.debug("Dropping columns with nulls")
sel = _dict[key].select(pl.all().has_nulls())
logging.debug(sel.select(col.name for col in sel if col.item(0)))
logging.debug("Checking for infinite values.")
for col in dataframe.select(cs.numeric()).columns:
if dataframe[col].is_infinite().any():
logging.warning(f"Column '{col}' contains infinite values. Datatype: {dataframe[col].dtype}")

max_float64 = 0
# Replace infinite values with the maximum value for float64
dataframe = dataframe.with_columns(
[
pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col)
for col in dataframe.columns
if dataframe[col].dtype == pl.Float64
]
)
_dict[key] = dataframe
logging.debug(f"Amount of columns: {len(dataframe.columns)}")
logging.info(f"{len(sanitized_data[DataSplit.train][DataSegment.features].columns)} columns in dynamic data.")

⚠️ Potential issue

Fix NaN/null/Inf handling; current logic loses earlier fills and replaces Inf with 0.

  • Uses stale dataframe after assigning to _dict[key], discarding previous fills.
  • Replaces Inf with 0 and only in Float64 columns.
-    logging.debug(f"Checking for NaNs and nulls in {data.keys()}.")
+    logging.debug(f"Checking for NaNs and nulls.")
@@
-        for key, dataframe in _dict.items():
+        for key, dataframe in _dict.items():
@@
-            _dict[key] = dataframe.fill_null(strategy="zero")
-            _dict[key] = dataframe.fill_nan(0)
+            # Fill nulls/NaNs in numeric columns
+            dataframe = dataframe.fill_null(strategy="zero").fill_nan(0)
+            _dict[key] = dataframe
@@
-            logging.debug("Checking for infinite values.")
-            for col in dataframe.select(cs.numeric()).columns:
-                if dataframe[col].is_infinite().any():
-                    logging.warning(f"Column '{col}' contains infinite values. Datatype: {dataframe[col].dtype}")
+            logging.debug("Checking for infinite values.")
+            for col in dataframe.select(cs.numeric()).columns:
+                if dataframe[col].is_infinite().any():
+                    logging.warning(f"Column '{col}' contains infinite values. Datatype: {dataframe[col].dtype}")
@@
-            max_float64 = 0
-            # Replace infinite values with the maximum value for float64
-            dataframe = dataframe.with_columns(
-                [
-                    pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col)
-                    for col in dataframe.columns
-                    if dataframe[col].dtype == pl.Float64
-                ]
-            )
+            # Replace +/-inf in all numeric columns with max finite float64
+            import numpy as np
+            max_float64 = np.finfo(np.float64).max
+            num_cols = dataframe.select(cs.numeric()).columns
+            dataframe = dataframe.with_columns([
+                pl.when(pl.col(c).is_infinite()).then(pl.lit(max_float64)).otherwise(pl.col(c)).alias(c)
+                for c in num_cols
+            ])
             _dict[key] = dataframe
             logging.debug(f"Amount of columns: {len(dataframe.columns)}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
logging.debug(f"Checking for NaNs and nulls in {data.keys()}.")
for _dict in sanitized_data.values():
for key, val in _dict.items():
for key, dataframe in _dict.items():
logging.debug(f"Data type: {key}")
logging.debug("Is NaN:")
sel = _dict[key].select(pl.selectors.numeric().is_nan().max())
logging.debug(sel.select(col.name for col in sel if col.item(0)))
logging.debug("Has nulls:")
sel = _dict[key].select(pl.all().has_nulls())
logging.debug(sel.select(col.name for col in sel if col.item(0)))
_dict[key] = val.fill_null(strategy="zero")
_dict[key] = val.fill_nan(0)
_dict[key] = dataframe.fill_null(strategy="zero")
_dict[key] = dataframe.fill_nan(0)
logging.debug("Dropping columns with nulls")
sel = _dict[key].select(pl.all().has_nulls())
logging.debug(sel.select(col.name for col in sel if col.item(0)))
logging.debug("Checking for infinite values.")
for col in dataframe.select(cs.numeric()).columns:
if dataframe[col].is_infinite().any():
logging.warning(f"Column '{col}' contains infinite values. Datatype: {dataframe[col].dtype}")
max_float64 = 0
# Replace infinite values with the maximum value for float64
dataframe = dataframe.with_columns(
[
pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col)
for col in dataframe.columns
if dataframe[col].dtype == pl.Float64
]
)
_dict[key] = dataframe
logging.debug(f"Amount of columns: {len(dataframe.columns)}")
logging.info(f"{len(sanitized_data[DataSplit.train][DataSegment.features].columns)} columns in dynamic data.")
logging.debug(f"Checking for NaNs and nulls.")
for _dict in sanitized_data.values():
for key, dataframe in _dict.items():
logging.debug(f"Data type: {key}")
logging.debug("Is NaN:")
sel = _dict[key].select(pl.selectors.numeric().is_nan().max())
logging.debug(sel.select(col.name for col in sel if col.item(0)))
logging.debug("Has nulls:")
sel = _dict[key].select(pl.all().has_nulls())
logging.debug(sel.select(col.name for col in sel if col.item(0)))
# Fill nulls/NaNs in numeric columns
dataframe = dataframe.fill_null(strategy="zero").fill_nan(0)
_dict[key] = dataframe
logging.debug("Dropping columns with nulls")
sel = _dict[key].select(pl.all().has_nulls())
logging.debug(sel.select(col.name for col in sel if col.item(0)))
logging.debug("Checking for infinite values.")
for col in dataframe.select(cs.numeric()).columns:
if dataframe[col].is_infinite().any():
logging.warning(f"Column '{col}' contains infinite values. Datatype: {dataframe[col].dtype}")
# Replace +/-inf in all numeric columns with max finite float64
import numpy as np
max_float64 = np.finfo(np.float64).max
num_cols = dataframe.select(cs.numeric()).columns
dataframe = dataframe.with_columns([
pl.when(pl.col(c).is_infinite()).then(pl.lit(max_float64)).otherwise(pl.col(c)).alias(c)
for c in num_cols
])
_dict[key] = dataframe
logging.debug(f"Amount of columns: {len(dataframe.columns)}")
logging.info(f"{len(sanitized_data[DataSplit.train][DataSegment.features].columns)} columns in dynamic data.")
🤖 Prompt for AI Agents
In icu_benchmarks/data/split_process_data.py around lines 211-242, the code
fills nulls/NaNs into _dict[key] but then continues to use the older dataframe
variable (losing the fills) and replaces infinite values with 0 only for
Float64; fix by performing fills and Inf-replacement on the same dataframe
variable and assign back to _dict[key] once at the end. Specifically: use
dataframe = dataframe.fill_null(strategy="zero") and dataframe =
dataframe.fill_nan(0) (or equivalent) so subsequent checks operate on the
updated DataFrame; when replacing infinities compute a replacement per numeric
float column (either the column's finite max or the dtype max via numpy.finfo)
while ignoring inf/NaN, and apply the replacement to all float dtypes (Float32
and Float64) using dataframe.with_columns([...]); finally assign the fully
sanitized dataframe to _dict[key] and only then run final logging/column checks.
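
A small polars sketch of the per-column replacement described above, on toy data rather than the project's frames; the fallback to the float64 dtype maximum only applies when a column has no finite values:

import numpy as np
import polars as pl
import polars.selectors as cs

df = pl.DataFrame({"a": [1.0, float("inf"), 3.0], "b": [float("-inf"), 2.0, None]})

float_cols = df.select(cs.float()).columns
exprs = []
for c in float_cols:
    # Compute the replacement per column from its finite values only.
    finite = df[c].filter(df[c].is_finite())
    replacement = finite.max() if finite.len() > 0 else float(np.finfo(np.float64).max)
    exprs.append(
        pl.when(pl.col(c).is_infinite()).then(pl.lit(replacement)).otherwise(pl.col(c)).alias(c)
    )
df = df.with_columns(exprs)
print(df)  # inf in "a" becomes 3.0, -inf in "b" becomes 2.0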

Comment on lines +246 to +248
train_incidence = sanitized_data[DataSplit.test][DataSegment.outcome][vars[VarType.label]]
val_incidence = sanitized_data[DataSplit.val][DataSegment.outcome][vars[VarType.label]]
test_incidence = sanitized_data[DataSplit.test][DataSegment.outcome][vars[VarType.label]]

⚠️ Potential issue

Wrong splits used for incidence stats (train uses test split).

Use the corresponding split for each metric.

-    train_incidence = sanitized_data[DataSplit.test][DataSegment.outcome][vars[VarType.label]]
+    train_incidence = sanitized_data[DataSplit.train][DataSegment.outcome][vars[VarType.label]]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
train_incidence = sanitized_data[DataSplit.test][DataSegment.outcome][vars[VarType.label]]
val_incidence = sanitized_data[DataSplit.val][DataSegment.outcome][vars[VarType.label]]
test_incidence = sanitized_data[DataSplit.test][DataSegment.outcome][vars[VarType.label]]
train_incidence = sanitized_data[DataSplit.train][DataSegment.outcome][vars[VarType.label]]
val_incidence = sanitized_data[DataSplit.val][DataSegment.outcome][vars[VarType.label]]
test_incidence = sanitized_data[DataSplit.test][DataSegment.outcome][vars[VarType.label]]
🤖 Prompt for AI Agents
In icu_benchmarks/data/split_process_data.py around lines 246 to 248, the
incidence stats use the wrong data splits (train_incidence is using
DataSplit.test); change train_incidence to read from DataSplit.train, keep
val_incidence using DataSplit.val, and test_incidence using DataSplit.test so
each metric pulls from its corresponding split.

Comment on lines 446 to 448
else:
raise ValueError(f"Run mode {self.run_mode} not yet supported. Please implement it.")
transformed_output = self.output_transform((prediction, target))

⚠️ Potential issue

Remove duplicate output transform assignment

Line 448 reassigns transformed_output immediately after it was already assigned based on calibration status (lines 440-442).

             else:
                 transformed_output = self.output_transform((prediction, target))
         elif self.run_mode == RunMode.regression:
             # Regression task
             loss = self.loss(prediction[:, 0], target.float()) + aux_loss
+            transformed_output = self.output_transform((prediction, target))
         else:
             raise ValueError(f"Run mode {self.run_mode} not yet supported. Please implement it.")
-        transformed_output = self.output_transform((prediction, target))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
else:
raise ValueError(f"Run mode {self.run_mode} not yet supported. Please implement it.")
transformed_output = self.output_transform((prediction, target))
if self.run_mode == RunMode.classification:
# Classification task (with optional calibration)
if self.calibrated:
prediction = self.calibrator.predict_proba(raw_prediction)
transformed_output = self.output_transform((prediction, target))
else:
transformed_output = self.output_transform((prediction, target))
elif self.run_mode == RunMode.regression:
# Regression task
loss = self.loss(prediction[:, 0], target.float()) + aux_loss
transformed_output = self.output_transform((prediction, target))
else:
raise ValueError(f"Run mode {self.run_mode} not yet supported. Please implement it.")
🧰 Tools
🪛 Ruff (0.12.2)

447-447: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In icu_benchmarks/models/wrappers.py around lines 446 to 448, there is a
duplicate assignment to transformed_output immediately after it was already set
based on calibration status; remove the second assignment (the one after the
run_mode branch) so transformed_output is only assigned once from the earlier
calibration-aware branch, or if the intent was to use a different input there,
replace the duplicate with the correct variable—ensure transformed_output is
computed exactly once and used thereafter.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
icu_benchmarks/data/split_process_data.py (2)

266-270: Cache the preprocessed data and honor generate_cache

Currently this writes the raw data and is gated by load_cache, so no cache is created when generate_cache is True.

-    if generate_cache:
-        caching(cache_dir, cache_file, data, load_cache)
+    if generate_cache:
+        caching(cache_dir, cache_file, sanitized_data, use_cache=True)

287-317: Fix return type annotation to match actual return

Function returns (data, vars) but is annotated as a dict.

-def check_sanitize_data(data: dict[str, pl.DataFrame], vars: dict[str, str | list[str]]) -> dict[str, pl.DataFrame]:  # noqa: F811
+def check_sanitize_data(
+    data: dict[str, pl.DataFrame], vars: dict[str, str | list[str]]
+) -> tuple[dict[str, pl.DataFrame], dict[str, str | list[str]]]:  # noqa: F811
♻️ Duplicate comments (7)
icu_benchmarks/run_utils.py (3)

218-226: De-duplicate SHAP aggregation and log exceptions with traceback

The two concat+write blocks are duplicates; also log with logging.exception to retain stack trace.

-    if explain_features:
-        if explainer_values_test:
-            shap_values = pl.concat(explainer_values_test)
-            shap_values.write_parquet(log_dir / "aggregated_explainer_values.parquet")
-        try:
-            shap_values = pl.concat(explainer_values_test)
-            shap_values.write_parquet(log_dir / "aggregated_explainer_values.parquet")
-        except Exception as e:
-            logging.error(f"Error aggregating or writing SHAP values: {e}")
+    if explain_features and explainer_values_test:
+        try:
+            pl.concat(explainer_values_test).write_parquet(log_dir / "aggregated_explainer_values.parquet")
+        except Exception:
+            logging.exception("Error aggregating or writing explainer values")

233-235: Guard math.isnan for ints to avoid TypeError

math.isnan(int) raises TypeError.

-                if isinstance(score, (float, int)):
-                    if math.isnan(score):
+                if isinstance(score, float):
+                    if math.isnan(score):
                         logging.warning(f"Score for metric {metric} is NaN, adding 0 instead.")
                         score = 0

102-106: Fix collision handling (type error and weak uniqueness)

  • String–int concat raises TypeError.
  • Small range risks repeated collisions; add safer suffix and bounded retries.
-    log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f"))
-    while log_dir_run.exists():
-        log_dir_run = log_dir_run.with_name(log_dir_run.name + random.randint(1, 10))
+    log_dir_run = log_dir / datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")
+    for _ in range(50):
+        if not log_dir_run.exists():
+            break
+        log_dir_run = log_dir_run.with_name(f"{log_dir_run.name}_{random.randint(1, 1000)}")
+    else:
+        # Last resort: timestamp + random salt
+        log_dir_run = log_dir / f"{datetime.now().strftime('%Y-%m-%dT%H-%M-%S-%f')}_{random.randint(1000, 9999)}"
icu_benchmarks/data/split_process_data.py (4)

248-250: Use the correct split for train incidence

train_incidence reads from test. Use train.

-    train_incidence = sanitized_data[DataSplit.test][DataSegment.outcome][vars[VarType.label]]
+    train_incidence = sanitized_data[DataSplit.train][DataSegment.outcome][vars[VarType.label]]

23-24: Avoid redefining imported names in the same module

This file imports check_sanitize_data and modality_selection but also defines them later (F811). Drop the imports and use the local definitions.

-from .utils import check_sanitize_data, modality_selection

91-93: Include SEQUENCE in required_var_types

Downstream code indexes vars[VarType.sequence]; omitting it will raise KeyError.

-    if required_var_types is None:
-        required_var_types = ["GROUP", "LABEL"]
+    if required_var_types is None:
+        required_var_types = ["GROUP", "LABEL", "SEQUENCE"]

213-244: Fix NaN/Inf sanitization: operate on the updated frame and replace Inf with max float for all numeric cols

  • The current code updates _dict[key] but keeps transforming the stale dataframe.
  • It replaces Inf with 0, and only for Float64 columns.
-            _dict[key] = dataframe.fill_null(strategy="zero")
-            _dict[key] = dataframe.fill_nan(0)
+            dataframe = dataframe.fill_null(strategy="zero").fill_nan(0)
@@
-            logging.debug("Checking for infinite values.")
-            for col in dataframe.select(cs.numeric()).columns:
-                if dataframe[col].is_infinite().any():
-                    logging.warning(f"Column '{col}' contains infinite values. Datatype: {dataframe[col].dtype}")
+            logging.debug("Checking for infinite values.")
+            for col in dataframe.select(cs.numeric()).columns:
+                if dataframe[col].is_infinite().any():
+                    logging.warning(f"Column '{col}' contains infinite values. Datatype: {dataframe[col].dtype}")
@@
-            max_float64 = 0
-            # Replace infinite values with the maximum value for float64
-            dataframe = dataframe.with_columns(
-                [
-                    pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col)
-                    for col in dataframe.columns
-                    if dataframe[col].dtype == pl.Float64
-                ]
-            )
-            _dict[key] = dataframe
+            # Replace +/-inf in all numeric columns with max finite float64
+            import numpy as np
+            max_float64 = np.finfo(np.float64).max
+            num_cols = dataframe.select(cs.numeric()).columns
+            dataframe = dataframe.with_columns([
+                pl.when(pl.col(c).is_infinite()).then(pl.lit(max_float64)).otherwise(pl.col(c)).alias(c)
+                for c in num_cols
+            ])
+            _dict[key] = dataframe

Based on learnings (Polars selectors).
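
For reference, a self-contained version of the suggested sanitation step might look like the sketch below. The helper name clamp_infinite is made up, and it restricts itself to float columns since those are the only dtypes that can hold infinities; the in-repo code may integrate this differently.

import numpy as np
import polars as pl
import polars.selectors as cs


def clamp_infinite(df: pl.DataFrame) -> pl.DataFrame:
    """Replace +/-inf in every float column with the largest finite float64."""
    max_f64 = np.finfo(np.float64).max
    float_cols = df.select(cs.float()).columns
    return df.with_columns(
        [
            pl.when(pl.col(c).is_infinite()).then(pl.lit(max_f64)).otherwise(pl.col(c)).alias(c)
            for c in float_cols
        ]
    )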

🧹 Nitpick comments (6)
icu_benchmarks/models/ml_models/lgbm.py (3)

40-49: Make background sampling configurable to cap SHAP cost

Current default uses the full train set as background, which can be very slow. Add a configurable cap (e.g., background_size=1000) and sample without replacement.

-        reduce_samples = False
-        if reduce_samples:
-            n_samples = min(1000, len(train_data))
-            indices = np.random.choice(len(train_data), size=n_samples, replace=False)
-        else:
-            indices = np.arange(len(train_data))
+        background_size = getattr(self, "background_size", 1000)
+        n_samples = min(background_size, len(train_data))
+        indices = np.random.choice(len(train_data), size=n_samples, replace=False)

59-61: Drop redundant/ineffective verbose kw and rely on verbosity

set_model_args filters by constructor signature; verbose isn’t a valid init kw for scikit LGBM and gets ignored. verbosity is already set to -1.

-        self.model = self.set_model_args(lgbm.LGBMClassifier, *args, **kwargs, verbose=-1)
+        self.model = self.set_model_args(lgbm.LGBMClassifier, *args, **kwargs)

7-9: Remove unused imports

os, wandb, and wandb_lgbm aren’t used.

-import os
-import wandb
-from wandb.integration.lightgbm import wandb_callback as wandb_lgbm
icu_benchmarks/run_utils.py (2)

163-179: Replace prints and NumPy conversion with Polars ops

Use logging and Polars min/max to avoid materializing full columns.

-                    print(f"Loading file: {file_path}")
+                    logging.info(f"Loading file: {file_path}")
@@
-                    original_ids = df.select(pl.col(id_column)).to_numpy().flatten()
-                    print(f"  Original ID range: {original_ids.min()} - {original_ids.max()}")
+                    min_id, max_id = df.select(
+                        pl.col(id_column).min().alias("min_id"),
+                        pl.col(id_column).max().alias("max_id")
+                    ).row(0)
+                    logging.info(f"  Original ID range: {min_id} - {max_id}")
@@
-                    modified_ids = df.select(pl.col(id_column)).to_numpy().flatten()
-                    print(f"  Modified ID range: {modified_ids.min()} - {modified_ids.max()}")
-                    print(f"  Counter: {counter}")
-                    print()
+                    new_min_id, new_max_id = df.select(
+                        pl.col(id_column).min().alias("min_id"),
+                        pl.col(id_column).max().alias("max_id")
+                    ).row(0)
+                    logging.info(f"  Modified ID range: {new_min_id} - {new_max_id}; counter={counter}")

240-249: Avoid shadowing built-in list in comprehensions

Rename to vals for readability and to avoid confusion.

-    averaged_scores = {metric: (mean(list)) for metric, list in list_scores.items()}
+    averaged_scores = {metric: mean(vals) for metric, vals in list_scores.items()}
@@
-    std_scores = {metric: (pstdev(list) / sqrt(len(list))) for metric, list in list_scores.items()}
+    std_scores = {metric: pstdev(vals) / sqrt(len(vals)) for metric, vals in list_scores.items()}
@@
-    confidence_interval = {
-        metric: (stats.t.interval(0.95, len(list) - 1, loc=mean(list), scale=stats.sem(list)))
-        for metric, list in list_scores.items()
-    }
+    confidence_interval = {
+        metric: stats.t.interval(0.95, len(vals) - 1, loc=mean(vals), scale=stats.sem(vals))
+        for metric, vals in list_scores.items()
+    }
icu_benchmarks/data/split_process_data.py (1)

820-867: Optional: Vectorize stay reduction for performance

The per-stay loop filters each stay individually (O(stays × rows)). Consider using groupby + cumcount to slice rows without Python loops.

I can provide a vectorized Polars variant if desired.
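
For illustration, a vectorized variant along those lines could look like the sketch below. Column and parameter names are placeholders: the real grouping column comes from vars[VarType.group], and the slice kept per stay depends on the reduction logic in the repository.

import polars as pl


def reduce_stays(df: pl.DataFrame, group_col: str, max_rows_per_stay: int) -> pl.DataFrame:
    """Keep only the first `max_rows_per_stay` rows of each stay without a Python loop."""
    return (
        df.with_columns(pl.int_range(0, pl.len()).over(group_col).alias("_row_in_stay"))
        .filter(pl.col("_row_in_stay") < max_rows_per_stay)
        .drop("_row_in_stay")
    )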

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 235c02f and 88b63e7.

📒 Files selected for processing (3)
  • icu_benchmarks/data/split_process_data.py (15 hunks)
  • icu_benchmarks/models/ml_models/lgbm.py (4 hunks)
  • icu_benchmarks/run_utils.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
icu_benchmarks/run_utils.py (2)
icu_benchmarks/utils.py (1)
  • parse_dict (5-22)
icu_benchmarks/wandb_utils.py (1)
  • wandb_log (47-54)
icu_benchmarks/data/split_process_data.py (4)
icu_benchmarks/constants.py (1)
  • RunMode (4-7)
icu_benchmarks/data/preprocessor.py (6)
  • PolarsClassificationPreprocessor (90-268)
  • Preprocessor (36-87)
  • PolarsClassificationPreprocessor (71-246)
  • apply (563-591)
  • apply (102-176)
  • apply (355-402)
icu_benchmarks/data/constants.py (3)
  • DataSegment (7-11)
  • DataSplit (1-4)
  • VarType (14-17)
icu_benchmarks/data/utils.py (2)
  • check_sanitize_data (26-51)
  • modality_selection (54-81)
icu_benchmarks/models/ml_models/lgbm.py (2)
icu_benchmarks/models/wrappers.py (4)
  • fit (683-714)
  • fit (994-995)
  • set_model_args (892-902)
  • _explain_model (114-115)
icu_benchmarks/models/ml_models/xgboost.py (2)
  • set_model_args (67-77)
  • _explain_model (79-85)
🪛 Ruff (0.14.0)
icu_benchmarks/run_utils.py

104-104: Standard pseudo-random generators are not suitable for cryptographic purposes

(S311)


188-188: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


225-225: Do not catch blind exception: Exception

(BLE001)


226-226: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


229-229: Loop control variable repetition not used within loop body

(B007)


230-230: Loop control variable fold not used within loop body

Rename unused fold to _fold

(B007)

icu_benchmarks/data/split_process_data.py

35-35: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


158-158: Avoid specifying long messages outside the exception class

(TRY003)


769-769: Avoid specifying long messages outside the exception class

(TRY003)


843-843: Consider [unchanged_dynamic, *reduced_dynamic_parts] instead of concatenation

Replace with [unchanged_dynamic, *reduced_dynamic_parts]

(RUF005)


865-865: Consider [unchanged_outcome, *reduced_outcome_parts] instead of concatenation

Replace with [unchanged_outcome, *reduced_outcome_parts]

(RUF005)

icu_benchmarks/models/ml_models/lgbm.py

77-77: Avoid specifying long messages outside the exception class

(TRY003)

Comment on lines +24 to +40
        # Redirect stderr and stdout to suppress warnings
        stderr_backup = sys.stderr
        stdout_backup = sys.stdout
        sys.stderr = StringIO()
        sys.stdout = StringIO()
        try:
            self.model = self.model.fit(
                train_data,
                train_labels,
                eval_set=(val_data, val_labels),
                callbacks=callbacks,
            )
        finally:
            sys.stderr = stderr_backup
            # to surpress: [Warning] No further splits with positive gain, best gain: -inf with the current config
            sys.stdout = stdout_backup
        reduce_samples = False

⚠️ Potential issue | 🟠 Major

Fix eval_set and avoid process-wide stdout/stderr hijack

  • Pass eval_set as a list of tuples; the current tuple form can skip validation logging/metrics and leave best_score_ unset.
  • Replace global sys redirection with scoped context managers to avoid hiding logs from other threads/process parts.

Apply:

+from contextlib import redirect_stdout, redirect_stderr
+from io import StringIO
@@
-        # Redirect stderr and stdout to suppress warnings
-        stderr_backup = sys.stderr
-        stdout_backup = sys.stdout
-        sys.stderr = StringIO()
-        sys.stdout = StringIO()
-        try:
-            self.model = self.model.fit(
-                train_data,
-                train_labels,
-                eval_set=(val_data, val_labels),
-                callbacks=callbacks,
-            )
-        finally:
-            sys.stderr = stderr_backup
-            # to surpress:  [Warning] No further splits with positive gain, best gain: -inf with the current config
-            sys.stdout = stdout_backup
+        # Suppress LightGBM chatter locally
+        with redirect_stderr(StringIO()), redirect_stdout(StringIO()):
+            self.model = self.model.fit(
+                train_data,
+                train_labels,
+                eval_set=[(val_data, val_labels)],
+                callbacks=callbacks,
+            )

Also applies to: 30-35

🤖 Prompt for AI Agents
In icu_benchmarks/models/ml_models/lgbm.py around lines 24-40 (and also apply
same fix to lines 30-35), the code currently passes eval_set as a single tuple
and hijacks sys.stdout/sys.stderr by reassigning them; change eval_set to a list
of tuples (e.g. [(val_data, val_labels)]) so LightGBM records validation metrics
and best_score_, and replace the manual sys reassignment with scoped context
managers (use contextlib.redirect_stdout and contextlib.redirect_stderr or an
equivalent suppression context) so output suppression is limited to the fit call
and does not affect other threads or parts of the process; ensure stdout/stderr
are automatically restored by the context managers.
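
As a standalone illustration of the scoped-suppression pattern (the wrapper name fit_quietly is hypothetical and not part of the codebase):

from contextlib import redirect_stderr, redirect_stdout
from io import StringIO


def fit_quietly(model, train_data, train_labels, val_data, val_labels, callbacks=None):
    """Suppress LightGBM chatter only for the duration of fit(); stdout/stderr are restored afterwards."""
    with redirect_stdout(StringIO()), redirect_stderr(StringIO()):
        model.fit(
            train_data,
            train_labels,
            eval_set=[(val_data, val_labels)],  # list of tuples so eval metrics and best_score_ are recorded
            callbacks=callbacks,
        )
    return model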

Comment on lines +75 to +81
    def _explain_model(self, reps, labels):
        if not hasattr(self.model, "feature_importances_"):
                raise ValueError("Model has not been fit yet. Call fit_model() before getting feature importances.")
        # feature_importances = self.model.feature_importances_
        shap_values = self.explainer.shap_values(reps, labels)
        # feature_importances = np.abs(shap_values).mean(axis=1)
        return shap_values

🛠️ Refactor suggestion | 🟠 Major

Use a robust fitted-model check

Checking feature_importances_ is brittle. Prefer scikit’s check_is_fitted or LightGBM’s booster_.

+from sklearn.utils.validation import check_is_fitted
@@
-        if not hasattr(self.model, "feature_importances_"):
-                raise ValueError("Model has not been fit yet. Call fit_model() before getting feature importances.")
+        check_is_fitted(self.model, attributes=["booster_"])
🧰 Tools
🪛 Ruff (0.14.0)

77-77: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In icu_benchmarks/models/ml_models/lgbm.py around lines 75 to 81, the current
fitted-model check uses the brittle attribute feature_importances_; replace this
with a robust check by using
sklearn.utils.validation.check_is_fitted(self.model) (import it at the top) or,
for LightGBM-specific safety, check for hasattr(self.model, "booster_") or
callable(getattr(self.model, "booster_", None)); wrap the check in a try/except
to raise the same ValueError message if not fitted, then proceed to compute
shap_values and return them.
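
A minimal sketch of such a guard, assuming a LightGBM scikit-learn estimator and a SHAP explainer as in the snippet above (the standalone function is illustrative, not the repository's method):

from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted


def explain_if_fitted(model, explainer, reps, labels):
    """Compute SHAP values only if the LightGBM model has been fitted."""
    try:
        check_is_fitted(model, attributes=["booster_"])  # booster_ is set by LGBMClassifier.fit()
    except NotFittedError as err:
        raise ValueError("Model has not been fit yet. Call fit_model() first.") from err
    return explainer.shap_values(reps, labels)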

Comment on lines +157 to +183
    # Process fold by fold across all iterations
    for fold_name in sorted(all_folds):
        for iteration in iterations:
            fold_path = iteration / fold_name
            if fold_path.exists():
                for file_path in sorted(fold_path.rglob(filename)):
                    print(f"Loading file: {file_path}")

                    # Load the CSV file
                    df = pl.read_csv(file_path)

                    # Check original ID range
                    original_ids = df.select(pl.col(id_column)).to_numpy().flatten()
                    print(f"  Original ID range: {original_ids.min()} - {original_ids.max()}")

                    # Apply offset to prevent clashes
                    df = df.with_columns(pl.col(id_column) + counter * max_id)

                    # Check modified ID range
                    modified_ids = df.select(pl.col(id_column)).to_numpy().flatten()
                    print(f"  Modified ID range: {modified_ids.min()} - {modified_ids.max()}")
                    print(f"  Counter: {counter}")
                    print()
                    dataframes.append(df)
        counter += 1

    # Concatenate all DataFrames vertically

⚠️ Potential issue | 🟠 Major

ID offset logic collides across iterations within the same fold

counter increments per fold, not per file; files in different iterations for the same fold receive the same offset.

-    counter = 0
+    counter = 0
@@
-    for fold_name in sorted(all_folds):
-        for iteration in iterations:
+    for fold_idx, fold_name in enumerate(sorted(all_folds)):
+        for iter_idx, iteration in enumerate(iterations):
             fold_path = iteration / fold_name
             if fold_path.exists():
                 for file_path in sorted(fold_path.rglob(filename)):
-                    print(f"Loading file: {file_path}")
+                    logging.info(f"Loading file: {file_path}")
@@
-                    # Apply offset to prevent clashes
-                    df = df.with_columns(pl.col(id_column) + counter * max_id)
+                    # Apply offset to prevent clashes
+                    df = df.with_columns(pl.col(id_column) + counter * max_id)
@@
-                    dataframes.append(df)
-        counter += 1
+                    dataframes.append(df)
+                    counter += 1
🤖 Prompt for AI Agents
In icu_benchmarks/run_utils.py around lines 157 to 183, the ID offset counter is
only incremented once per fold causing files from different iterations of the
same fold to get the same offset; move the counter increment to occur after each
file is processed (inside the file loop) or compute the offset from a unique
per-file index (e.g., iteration index + fold index) so each file receives a
distinct counter before applying df.with_columns(pl.col(id_column) + counter *
max_id); ensure the counter update happens immediately after appending the
modified df to dataframes so offsets do not collide.
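
A self-contained sketch of the per-file offset (the function name and signature are illustrative; the real helper in run_utils.py differs):

from pathlib import Path

import polars as pl


def aggregate_fold_files(
    iterations: list[Path], all_folds: list[str], filename: str, id_column: str, max_id: int
) -> pl.DataFrame:
    """Concatenate per-fold CSVs, shifting IDs by a distinct offset for every file."""
    dataframes = []
    counter = 0
    for fold_name in sorted(all_folds):
        for iteration in iterations:
            fold_path = iteration / fold_name
            if not fold_path.exists():
                continue
            for file_path in sorted(fold_path.rglob(filename)):
                df = pl.read_csv(file_path)
                # Each file gets its own multiple of max_id, so IDs cannot clash across iterations
                df = df.with_columns(pl.col(id_column) + counter * max_id)
                dataframes.append(df)
                counter += 1  # increment per file, not per fold
    return pl.concat(dataframes)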
