Improvements from Cass project #164
base: development
Conversation
Walkthrough
Expanded hyperparameter grids and GPU configs; added explainability export/persistence, alarm metrics, modality selection and infinite-value sanitation, per-stay row indicators, calibration & CSV exports, CLI/W&B/tuning enhancements, many experiment and SLURM scripts, Sphinx docs, dependency bumps, and a package version update.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant CLI as User/CLI
participant Parser as build_parser
participant Run as run.py
participant CV as execute_repeated_cv
participant Trainer as train_common
participant Wrapper as model wrappers
participant FS as Filesystem
CLI->>Parser: parse args (--file_names, --explain_features, --load_data_vars)
CLI->>Run: start training
Run->>Run: optionally include data vars.gin (if --load_data_vars)
Run->>Run: bind preprocess.file_names (if provided)
Run->>CV: execute_repeated_cv(..., explain_features)
CV->>Trainer: train_common(..., explain_features)
Trainer->>Wrapper: set_explain_features(flag)
loop per fold/repetition
Wrapper->>Wrapper: train/test, compute metrics (calibrated/uncalibrated)
alt explain_features enabled
Wrapper->>FS: save explainer_values_train/test.parquet (averaged and full)
end
Wrapper->>FS: save predictions CSV and pred_indicators.csv
end
Run->>Parser: aggregate_results(explain_features)
Parser->>FS: read/write aggregated metrics and explainer values
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
Actionable comments posted: 17
🔭 Outside diff range comments (2)
icu_benchmarks/models/custom_metrics.py (1)
66-74: ⚠️ Potential issue — Fix reference to moved function mae_with_invert_compute_fn.
The pipeline is failing because the function mae_with_invert_compute_fn has been moved outside the class, but it is still being referenced in the constructor on line 66. This function should either be moved back inside the class as a method, or it needs to be made accessible at the module level.

    def __init__(
        self,
        output_transform: Callable = lambda x: x,
        check_compute_fn: bool = False,
        invert_transform: Callable = lambda x: x,
    ) -> None:
        super(MAE, self).__init__(
-            lambda x, y: mae_with_invert_compute_fn(x, y, invert_transform),
+            lambda x, y: mae_with_invert_compute_fn(x, y, invert_transform),
            output_transform=output_transform,
            check_compute_fn=check_compute_fn,
        )

-def mae_with_invert_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor, invert_fn=Callable) -> float:
+def mae_with_invert_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor, invert_fn=Callable) -> float:
    y_true = invert_fn(y_targets.numpy().reshape(-1, 1))[:, 0]
    y_pred = invert_fn(y_preds.numpy().reshape(-1, 1))[:, 0]
    return mean_absolute_error(y_true, y_pred)

Alternatively, move the function back inside the class:

    def __init__(
        self,
        output_transform: Callable = lambda x: x,
        check_compute_fn: bool = False,
        invert_transform: Callable = lambda x: x,
    ) -> None:
        super(MAE, self).__init__(
-            lambda x, y: mae_with_invert_compute_fn(x, y, invert_transform),
+            lambda x, y: self.mae_with_invert_compute_fn(x, y, invert_transform),
            output_transform=output_transform,
            check_compute_fn=check_compute_fn,
        )

+    def mae_with_invert_compute_fn(self, y_preds: torch.Tensor, y_targets: torch.Tensor, invert_fn=Callable) -> float:
+        y_true = invert_fn(y_targets.numpy().reshape(-1, 1))[:, 0]
+        y_pred = invert_fn(y_preds.numpy().reshape(-1, 1))[:, 0]
+        return mean_absolute_error(y_true, y_pred)

-def mae_with_invert_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor, invert_fn=Callable) -> float:
-    y_true = invert_fn(y_targets.numpy().reshape(-1, 1))[:, 0]
-    y_pred = invert_fn(y_preds.numpy().reshape(-1, 1))[:, 0]
-    return mean_absolute_error(y_true, y_pred)

🧰 Tools
🪛 Ruff (0.8.2)
66-66: Undefined name
mae_with_invert_compute_fn(F821)
🪛 GitHub Actions: CI
[error] 66-66: Flake8: undefined name 'mae_with_invert_compute_fn'.
icu_benchmarks/data/split_process_data.py (1)
315-315: ⚠️ Potential issue — Reconciling train_size parameter type.
train_size is assigned a float (e.g., 0.80) yet is annotated as int in the signature. This mismatch can lead to confusion or TypeErrors. Clarify whether it is intended to be a float ratio or an absolute count.
Also applies to: 322-322
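For reference, a minimal sketch of how such a parameter could accept either form; the helper name and signature are hypothetical, not the project's API:

```python
from typing import Union

def resolve_train_rows(n_rows: int, train_size: Union[float, int, None] = 0.80) -> int:
    """Interpret train_size as a ratio when it is a float in (0, 1), otherwise as an absolute row count."""
    if train_size is None:
        return n_rows
    if isinstance(train_size, float) and 0.0 < train_size < 1.0:
        return int(n_rows * train_size)
    return int(train_size)
```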
🧹 Nitpick comments (24)
icu_benchmarks/run.py (1)
187-187: Feature explanation enabled by default
Setting explain_features=True enables SHAP-based feature explanation during the cross-validation process, which enhances model interpretability but may increase computational overhead.
Consider making this parameter configurable through command-line arguments rather than hardcoding it to True. This would allow users to disable feature explanation when prioritizing performance or when explanations aren't needed.
-        explain_features=True,
+        explain_features=args.explain_features,
Then add this argument to the parser in the build_parser function:
parser.add_argument("--explain-features", action="store_true", default=False, help="Enable feature explanation during model training")
experiments/charhpc_wandb_sweep.sh (1)
1-14: Add error handling and usage documentation to the SLURM script.This script provides a good framework for running W&B sweeps on an HPC cluster with GPU resources. However, it lacks error handling and usage documentation that would make it more robust.
Consider adding:
- Parameter validation to check if the sweep ID is provided
- A usage comment at the top explaining the purpose and how to run the script
- Error handling for common failures
#!/bin/bash +# Usage: sbatch charhpc_wandb_sweep.sh <sweep_id> +# Purpose: Run a W&B sweep on the pgpu partition with GPU access + #SBATCH --job-name=yaib_experiment #SBATCH --partition=pgpu # -p #SBATCH --cpus-per-task=8 # -c #SBATCH --mem=200gb #SBATCH --output=logs/classification_%a_%j.log # %j is job id #SBATCH --gpus=1 #SBATCH --time=24:00:00 source /etc/profile.d/conda.sh eval "$(conda shell.bash hook)" conda activate yaib_req_pl + +if [ -z "$1" ]; then + echo "Error: No sweep ID provided. Usage: sbatch charhpc_wandb_sweep.sh <sweep_id>" + exit 1 +fi + wandb agent --count 1 cassandra_hpi/cassandra/"$1" +experiments/charhpc_wandb_sweep_cpu.sh (1)
1-16: Add error handling to CPU-based SLURM script and standardize with GPU version.This CPU-focused script is well-structured with appropriate resource allocations and includes a helpful debug command. However, it lacks error handling similar to the GPU version.
Consider:
- Adding parameter validation and error handling
- Including usage documentation
- Standardizing the structure with the GPU version for consistency
#!/bin/bash +# Usage: sbatch charhpc_wandb_sweep_cpu.sh <sweep_id> +# Purpose: Run a W&B sweep on the compute partition with CPU resources + #SBATCH --job-name=yaib_experiment #SBATCH --partition=compute # -p #SBATCH --cpus-per-task=16 # -c #SBATCH --mem=250gb #SBATCH --output=logs/classification_%a_%j.log # %j is job id #SBATCH --time=48:00:00 source /etc/profile.d/conda.sh eval "$(conda shell.bash hook)" conda activate yaib_req_pl + +if [ -z "$1" ]; then + echo "Error: No sweep ID provided. Usage: sbatch charhpc_wandb_sweep_cpu.sh <sweep_id>" + exit 1 +fi + wandb agent --count 1 cassandra_hpi/cassandra/"$1" # Debug instance: srun -p gpu --pty -t 5:00:00 --gres=gpu:1 --cpus-per-task=16 --mem=100GB bashicu_benchmarks/models/train.py (1)
216-216: Remove unnecessary f-string prefix.The string doesn't contain any placeholders, so the
fprefix is redundant.- logging.info(f"Saving SHAPS") + logging.info("Saving SHAPS")🧰 Tools
🪛 Ruff (0.8.2)
216-216: f-string without any placeholders
Remove extraneous
fprefix(F541)
icu_benchmarks/models/ml_models/xgboost.py (3)
6-6: Remove unused import.The
numpypackage is imported but not used in the code.-import numpy as np🧰 Tools
🪛 Ruff (0.8.2)
6-6:
numpyimported but unusedRemove unused import:
numpy(F401)
51-56: Consider makingexplain_featuresconfigurable.The feature explanation is currently hardcoded to
True. Consider making this configurable via a parameter to the constructor or through a gin configuration option.- self.explain_features = True + self.explain_features = kwargs.get('explain_features', True)
83-86: Clean up commented-out code.This commented-out method appears to be obsolete since it's been replaced by
_explain_model. Consider removing it if it's no longer needed.-# def explainer(self, reps): -# if not hasattr(self.model, "feature_importances_"): -# raise ValueError("Model has not been fit yet. Call fit_model() before getting feature importances.") -# return self.model.feature_importances_icu_benchmarks/run_utils.py (1)
81-81: Explicitly convert random integer to string.Adding a random integer to the path name requires conversion to string, which is implicit in this operation. For clarity, consider making this conversion explicit.
- log_dir_run = log_dir_run.with_name(log_dir_run.name + random.randint(1, 10)) + log_dir_run = log_dir_run.with_name(log_dir_run.name + str(random.randint(1, 10)))icu_benchmarks/data/preprocessor.py (3)
185-185: Remove unused variableold_columns.The variable
old_columnsis assigned but never used in the code.- old_columns = data[Split.train][Segment.dynamic].columns🧰 Tools
🪛 Ruff (0.8.2)
185-185: Local variable
old_columnsis assigned to but never usedRemove assignment to unused variable
old_columns(F841)
193-195: Consider uncommenting logging statements.These commented-out logging statements could provide useful debug information. Consider either fully removing them or uncommenting them if they're helpful for debugging.
- # logging.info(f"Excluding {len(self.vars_to_exclude)} : {self.vars_to_exclude}") + logging.debug(f"Excluding {len(self.vars_to_exclude)} : {self.vars_to_exclude}")- # logging.info(f"Applying to {len(vars_to_apply)}") + logging.debug(f"Applying to {len(vars_to_apply)}")
206-206: Consider uncommenting or removing logging statement.This commented-out logging statement regarding column changes could be useful for debugging. Consider either fully removing it or uncommenting it at a debug level.
- # logging.info(f"Data columns: {len(data[Split.train][Segment.dynamic].columns)} -> old columns: {len(old_columns)}, added columns: {set(data[Split.train][Segment.dynamic].columns) - set(old_columns)}") + logging.debug(f"Data columns: {len(data[Split.train][Segment.dynamic].columns)}, added columns: {set(data[Split.train][Segment.dynamic].columns) - set(old_columns)}")icu_benchmarks/models/custom_metrics.py (2)
1-1: Remove unused import.The
numpypackage is imported but not used in the code.-import numpy🧰 Tools
🪛 Ruff (0.8.2)
1-1:
numpyimported but unusedRemove unused import:
numpy(F401)
191-209: Consider removing commented-out code.This large block of commented-out code appears to be an alternative implementation of the
Specificityclass. If it's no longer needed, consider removing it to improve code readability. If it's being kept for future reference, add a comment explaining why.icu_benchmarks/models/alarm_metrics.py (2)
13-29: Question the index-based slicing usage.Inside
convert_to_alarm, it's assumed thatsilence_positivesandfill_gapsoperate exclusively on integer-indexed arrays. If you intend to handle time-indexed data, consider a more robust approach (e.g., aligning predictions by timestamps).
55-72: Clarify the purpose of filling gaps.
- The function fill_gaps currently sets all subsequent values to 1 once it encounters a single 1. Ensure this behavior aligns with your domain needs.
- The function modifies both predictions and ground_truth, which can lead to confusion. Consider returning a separate updated copy or clarifying in docstrings.
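As a reference point for the second bullet, a non-mutating variant might look like the sketch below (illustrative only; it assumes the described behavior of propagating a 1 to all later positions and leaves the caller's arrays untouched):

```python
import numpy as np

def fill_gaps_copy(predictions: np.ndarray) -> np.ndarray:
    """Return a copy in which every value after the first 1 is also set to 1."""
    filled = np.asarray(predictions).copy()
    ones = np.flatnonzero(filled == 1)
    if ones.size:
        filled[ones[0]:] = 1
    return filled
```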
icu_benchmarks/data/split_process_data.py (2)
138-138: Confirm data sanitization flow.
check_sanitize_data potentially mutates both data and vars. Ensure all necessary post-processing steps occur after this call so that your data transformations remain consistent throughout.
195-207: Ensure consistent handling of infinite values.
- max_float64 = 0 might not be representative of true maximum values. Consider retrieving np.finfo(np.float64).max or another large sentinel.
- The loop only replaces infinite values in columns with pl.Float64 dtype; ensure columns of other numeric types are also handled if necessary.
-    max_float64 = 0
+    import numpy as np
+    max_float64 = np.finfo(np.float64).max
experiments/benchmark_cass.yml (1)
1-84: Add a newline at end of file.YAML best practices include ending files with a newline to avoid lint warnings and ensure consistent merging in version control.
Apply this diff:
No changes to the content +🧰 Tools
🪛 YAMLlint (1.35.1)
[error] 84-84: no new line character at the end of file
(new-line-at-end-of-file)
icu_benchmarks/data/utils.py (3)
9-20: Robust handling of infinite values.
The infinite_removal function effectively identifies and replaces infinite values in numeric columns with a safe maximum value. This prevents downstream issues that could occur with infinity values in models and statistical operations.
Consider using a more descriptive name for the parameter (val → df) and adding return type hints:
-def infinite_removal(val):
+def infinite_removal(df: pl.DataFrame) -> pl.DataFrame:
     for col in val.select(cs.numeric()).columns:
         if val[col].is_infinite().any():
             logging.info(f"Column '{col}' contains infinite values. Datatype: {val[col].dtype}")
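For comparison, an expression-based polars sketch that also addresses the ±inf and multi-dtype concerns raised above (hypothetical helper, assuming the goal is to cap infinities in every float column with a finite sentinel):

```python
import numpy as np
import polars as pl
import polars.selectors as cs

def clip_infinite(df: pl.DataFrame) -> pl.DataFrame:
    """Replace +inf/-inf in all float columns with +/- the largest finite float64."""
    sentinel = np.finfo(np.float64).max
    exprs = [
        pl.when(pl.col(c).is_infinite())
        .then(pl.col(c).sign() * sentinel)
        .otherwise(pl.col(c))
        .alias(c)
        for c in df.select(cs.float()).columns
    ]
    return df.with_columns(exprs) if exprs else df
```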
23-48: Comprehensive data sanitization function.The
check_sanitize_datafunction systematically removes duplicates from different data segments while preserving order, which is critical for time series data. It also handles infinite values and maintains the variable list integrity.The function uses
.keys()repeatedly, which static analysis flags as inefficient. Simplify these checks:- group = vars[Var.group] if Var.group in vars.keys() else None - sequence = vars[Var.sequence] if Var.sequence in vars.keys() else None + group = vars[Var.group] if Var.group in vars else None + sequence = vars[Var.sequence] if Var.sequence in vars else NoneAlso apply similar changes to other instances of
in dict.keys().🧰 Tools
🪛 Ruff (0.8.2)
25-25: Use
key in dictinstead ofkey in dict.keys()Remove
.keys()(SIM118)
26-26: Use
key in dictinstead ofkey in dict.keys()Remove
.keys()(SIM118)
29-29: Use
key in dictinstead ofkey in dict.keys()Remove
.keys()(SIM118)
33-33: Use
key in dictinstead ofkey in dict.keys()Remove
.keys()(SIM118)
40-40: Use
key in dictinstead ofkey in dict.keys()Remove
.keys()(SIM118)
51-77: Well-structured modality selection implementation.The
modality_selectionfunction provides a flexible way to filter data based on modalities, with proper validation, error handling, and detailed logging. It updates both the data dictionary and variable references consistently.Consider adding type hints for
varsparameter and return values, and use dictionary comprehensions for cleaner code:def modality_selection( - data: dict[pl.DataFrame], modality_mapping: dict[str], selected_modalities: list[str], vars + data: dict[pl.DataFrame], modality_mapping: dict[str], selected_modalities: list[str], vars: dict ) -> tuple[dict[pl.DataFrame], dict]:Also, replace the for loop in lines 73-76 with a dictionary comprehension:
data = { key: data[key].select([col for col in data[key].columns if col in selected_columns]) for key in data }🧰 Tools
🪛 Ruff (0.8.2)
55-55: Use
key in dictinstead ofkey in dict.keys()Remove
.keys()(SIM118)
56-56: Use
key in dictinstead ofkey in dict.keys()Remove
.keys()(SIM118)
73-73: Use
key in dictinstead ofkey in dict.keys()Remove
.keys()(SIM118)
icu_benchmarks/models/wrappers.py (3)
6-6: Remove unused import.The
ignite.metrics.EpochMetricis imported but never used in the code.-from ignite.metrics import EpochMetric🧰 Tools
🪛 Ruff (0.8.2)
6-6:
ignite.metrics.EpochMetricimported but unusedRemove unused import:
ignite.metrics.EpochMetric(F401)
482-519: Extended log_metrics to support row indicators.The
log_metricsmethod now accepts apred_indicatorsparameter and handles different cases based on its presence and shape. This enables more detailed analysis of predictions at specific time points.There is significant code duplication between the two branches (with and without pred_indicators). Consider refactoring:
def log_metrics(self, label, pred, metric_type, pred_indicators=None): """Log metrics to the PL logs.""" if "Confusion_Matrix" in self.metrics: self.log_dict(confusion_matrix(self.label_transform(label), self.output_transform(pred)), sync_dist=True) self.log_dict( { f"{metric_type}/{name}": (metric(self.label_transform(label), self.output_transform(pred))) # For every metric for name, metric in self.metrics.items() # Filter out metrics that return a tuple (e.g. precision_recall_curve) if not isinstance(metric(self.label_transform(label), self.output_transform(pred)), tuple) and name != "Confusion_Matrix" }, sync_dist=True, ) # Handle pred_indicators if present if pred_indicators is not None: if (len(pred_indicators.shape) > 1 and len(pred.shape) > 1 and pred_indicators.shape[1] == pred.shape[1] and pred_indicators.shape[0] == pred.shape[0]): pred_indicators = np.hstack((pred_indicators, label.reshape(-1, 1))) pred_indicators = np.hstack((pred_indicators, pred)) # TODO: Implement alarm metrics using row indicators
527-567: Comprehensive model output saving with improved formatting.The
_save_model_outputsmethod now handles different shapes of prediction indicators and includes proper headers in the saved CSV files. This significantly improves the usability of saved outputs for downstream analysis.Remove unnecessary f-string prefixes from string literals that don't contain placeholders:
- logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / f'row_indicators.csv'}") + logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / 'row_indicators.csv'}")Apply similar fixes to lines 554 and 563.
🧰 Tools
🪛 Ruff (0.8.2)
538-538: f-string without any placeholders
Remove extraneous
fprefix(F541)
554-554: f-string without any placeholders
Remove extraneous
fprefix(F541)
563-563: f-string without any placeholders
Remove extraneous
fprefix(F541)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (1)
demo_data/mortality24/mimic_demo_static/attrition.csv is excluded by !**/*.csv
📒 Files selected for processing (23)
configs/prediction_models/XGBClassifier.gin (1 hunks)
configs/prediction_models/common/MLTuning.gin (1 hunks)
configs/tasks/common/Dataloader.gin (1 hunks)
experiments/benchmark_cass.yml (1 hunks)
experiments/charhpc_wandb_sweep.sh (1 hunks)
experiments/charhpc_wandb_sweep_cpu.sh (1 hunks)
experiments/slurm_base_char_sc.sh (1 hunks)
icu_benchmarks/cross_validation.py (3 hunks)
icu_benchmarks/data/loader.py (3 hunks)
icu_benchmarks/data/preprocessor.py (4 hunks)
icu_benchmarks/data/split_process_data.py (11 hunks)
icu_benchmarks/data/utils.py (1 hunks)
icu_benchmarks/models/alarm_metrics.py (1 hunks)
icu_benchmarks/models/constants.py (2 hunks)
icu_benchmarks/models/custom_metrics.py (4 hunks)
icu_benchmarks/models/ml_models/xgboost.py (4 hunks)
icu_benchmarks/models/train.py (4 hunks)
icu_benchmarks/models/wrappers.py (8 hunks)
icu_benchmarks/run.py (1 hunks)
icu_benchmarks/run_utils.py (4 hunks)
icu_benchmarks/tuning/hyperparameters.py (2 hunks)
requirements.txt (2 hunks)
setup.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
icu_benchmarks/data/preprocessor.py (1)
Learnt from: rvandewater
PR: rvandewater/YAIB#155
File: icu_benchmarks/data/preprocessor.py:77-143
Timestamp: 2024-11-12T07:39:38.994Z
Learning: In the `apply` method of `PolarsClassificationPreprocessor` in `icu_benchmarks/data/preprocessor.py`, the data is pre-split, necessitating multiple join operations; thus, optimizing by combining joins is not feasible.
🪛 Shellcheck (0.10.0)
experiments/slurm_base_char_sc.sh
[warning] 18-18: Remove space after = if trying to assign a value (for empty string, use var='' ... ).
(SC1007)
[warning] 28-28: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.
(SC2164)
🪛 Ruff (0.8.2)
icu_benchmarks/models/ml_models/xgboost.py
6-6: numpy imported but unused
Remove unused import: numpy
(F401)
78-78: Local variable feature_importances is assigned to but never used
Remove assignment to unused variable feature_importances
(F841)
icu_benchmarks/models/alarm_metrics.py
2-2: numpy.ndarray imported but unused
Remove unused import: numpy.ndarray
(F401)
3-3: torch imported but unused
Remove unused import: torch
(F401)
4-4: sklearn.metrics.precision_score imported but unused
Remove unused import: sklearn.metrics.precision_score
(F401)
icu_benchmarks/data/split_process_data.py
4-4: numpy imported but unused
Remove unused import: numpy
(F401)
45-45: Do not use mutable data structures for argument defaults
Replace with None; initialize within function
(B006)
icu_benchmarks/data/preprocessor.py
185-185: Local variable old_columns is assigned to but never used
Remove assignment to unused variable old_columns
(F841)
icu_benchmarks/models/custom_metrics.py
1-1: numpy imported but unused
Remove unused import: numpy
(F401)
165-165: Redefinition of unused Specificity from line 10
(F811)
icu_benchmarks/models/train.py
216-216: f-string without any placeholders
Remove extraneous f prefix
(F541)
icu_benchmarks/data/utils.py
25-25: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
26-26: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
29-29: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
33-33: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
40-40: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
55-55: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
56-56: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
73-73: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
icu_benchmarks/models/wrappers.py
6-6: ignite.metrics.EpochMetric imported but unused
Remove unused import: ignite.metrics.EpochMetric
(F401)
538-538: f-string without any placeholders
Remove extraneous f prefix
(F541)
554-554: f-string without any placeholders
Remove extraneous f prefix
(F541)
563-563: f-string without any placeholders
Remove extraneous f prefix
(F541)
🪛 GitHub Actions: CI
icu_benchmarks/models/custom_metrics.py
[error] 66-66: Flake8: undefined name 'mae_with_invert_compute_fn'.
🪛 YAMLlint (1.35.1)
experiments/benchmark_cass.yml
[error] 84-84: no new line character at the end of file
(new-line-at-end-of-file)
🔇 Additional comments (33)
setup.py (1)
79-79: Significant version bump from 0.3.1 to 1.1.0The version has been increased from 0.3.1 to 1.1.0, which indicates a major update according to semantic versioning principles. This aligns with the extensive modifications described in the PR summary, including new features like SHAP-based explanations, refined data processing, and extended XGBoost capabilities.
configs/prediction_models/common/MLTuning.gin (1)
4-4: Hyperparameter tuning has been disabledSetting
tune_hyperparameters.n_callsto 0 effectively disables the hyperparameter tuning process. While this may speed up the training process, it means models will use predefined hyperparameters instead of finding optimal ones.Is this intentional? Consider documenting the reason for disabling tuning, especially if you're relying on manually optimized parameters from previous experiments.
configs/tasks/common/Dataloader.gin (1)
6-6: New MPS parameter enabled for PredictionPolarsDatasetThe addition of
PredictionPolarsDataset.mps = Trueenables a new behavior in the data loading process, affecting how data is processed in theto_tensormethod.Please consider documenting what the
mpsparameter represents in this context and how it impacts data processing. Is this related to Metal Performance Shaders for GPU acceleration on Apple devices, or does it have a different meaning in this codebase?requirements.txt (1)
7-7:Details
✅ Verification successful
Verify compatibility of updated dependencies.
The updates to dependency versions and the addition of SHAP are meaningful improvements. The SHAP library (v0.46.0) enables model explainability, which is valuable for understanding feature importance.
Let's check if these versions are current and compatible:
Also applies to: 27-27, 29-29, 35-35
🏁 Script executed:
#!/bin/bash pip index versions pytorch-ignite | head -n 5 pip index versions optuna | head -n 5 pip index versions wandb | head -n 5 pip index versions shap | head -n 5 # Check for any known compatibility issues between these packages pip check pytorch-ignite==0.5.1 optuna==4.1.0 wandb==0.18.5 shap==0.46.0Length of output: 4446
Dependency compatibility verified – no issues found.
The dependency versions stated in the requirements file (including the updated
pytorch-ignite==0.5.1and the addition ofshap==0.46.0) have been checked. The output frompip checkconfirms that there are no compatibility issues with the specified versions. This applies to all indicated lines (7, 27, 29, and 35).icu_benchmarks/models/constants.py (1)
30-30: Great addition of clinically relevant metrics.
Adding sensitivity, specificity, and positive predictive value (PPV) metrics is valuable for healthcare applications. These metrics are essential for evaluating models where false positives and false negatives have different clinical implications.
The additions properly complement the existing metrics and follow the established structure. Well done!
Also applies to: 42-46
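For readers less familiar with the terminology, these three quantities reduce to simple ratios over the binary confusion matrix; a standalone illustration (not the project's implementation, which presumably goes through its metric framework):

```python
from sklearn.metrics import confusion_matrix

def clinical_metrics(y_true, y_pred):
    """Sensitivity (recall), specificity, and PPV (precision) for binary labels."""
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) else 0.0  # true positive rate
    specificity = tn / (tn + fp) if (tn + fp) else 0.0  # true negative rate
    ppv = tp / (tp + fp) if (tp + fp) else 0.0          # positive predictive value
    return sensitivity, specificity, ppv
```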
icu_benchmarks/cross_validation.py (3)
41-41: Good addition of feature explanation capability.The new
explain_featuresparameter provides a clean way to control feature explanation during model training, following the same style as other parameters in this function.
76-78: Excellent defensive programming check.This validation prevents a common configuration error by catching cases where
cv_folds_to_trainexceedscv_folds, which would lead to unexpected behavior or errors later in the execution.
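The guard presumably amounts to something like the following sketch (names taken from the review text; the exact message and exception type are assumptions):

```python
def check_cv_config(cv_folds: int, cv_folds_to_train: int) -> None:
    """Fail fast when more folds are requested for training than exist in the CV split."""
    if cv_folds_to_train > cv_folds:
        raise ValueError(
            f"cv_folds_to_train ({cv_folds_to_train}) must not exceed cv_folds ({cv_folds})."
        )
```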
126-126: LGTM - Correctly passes the explain_features parameter.The parameter is correctly passed to the
train_commonfunction, ensuring the feature explanation capability is properly propagated to the underlying training process.icu_benchmarks/models/train.py (3)
56-56: LGTM - Parameter addition is consistent with cross_validation.py changes.The
explain_featuresparameter addition maintains a consistent interface between the cross-validation and training modules.
177-178: Good implementation of the explainer configuration.The explainer is appropriately set before testing begins, ensuring that any test-time feature explanations can be generated if requested.
196-197: LGTM - Conditional persistence of feature explanations.The function correctly persists SHAP data only when feature explanations are requested, avoiding unnecessary computation and storage.
icu_benchmarks/tuning/hyperparameters.py (2)
86-86: Good improvement in file existence check.Changing from
exists()toisfile()makes the check more precise, ensuring that the checkpoint path specifically refers to a file rather than any filesystem object.
225-227: Excellent safeguard for skipping unnecessary tuning.This early return when
n_calls <= 0prevents unnecessary execution of hyperparameter tuning, which is particularly useful when tuning is intentionally disabled through configuration.icu_benchmarks/models/ml_models/xgboost.py (1)
28-28: LGTM: Added missing parameter settings to XGBoost model.The addition of
eval_metric=log_loss, device="cpu", missing='inf'parameters ensures more consistent model behavior.icu_benchmarks/run_utils.py (3)
79-79: LGTM: Added microseconds to timestamp format.Adding microseconds to the timestamp format makes the directory names more unique.
109-109: LGTM: Renamed variable to match new terminology.Renaming
shap_values_testtoexplainer_values_testensures consistency with the naming changes in the rest of the codebase.
128-133: LGTM: Updated file paths for explainer values.Consistently renamed file paths from
test_shap_values.parquettoexplainer_values_test.parquetto match the variable name changes.icu_benchmarks/data/preprocessor.py (2)
131-141: LGTM: Improved conditional logic for sequence validation.Restructuring the sequence check to first verify if "SEQUENCE" exists in vars makes the code more explicit and easier to understand.
188-188: LGTM: Added backend parameter to StepScale.Specifying the backend for the StepScale operation ensures proper integration with the polars data structure.
icu_benchmarks/models/alarm_metrics.py (2)
13-29: Validate parameter types and thresholds.
- The function
convert_to_alarmexpectsground_truthandpredictionsto be arrays. Consider validating that numeric values are provided to avoid runtime errors.- For the threshold of 0.5, be prepared to parameterize it or make it easily configurable if your use case evolves.
31-53: Check boundary conditions when silencing positives.
positive_indices = positive_indices[positive_indices < len(ground_truth) - grace_horizon]: This logic might skip positives near file end. Confirm this is desired.- Repeatedly slicing the array in a while loop (
silence_array[positive_index+1: ... ]) could be susceptible to edge off-by-one errors.icu_benchmarks/data/split_process_data.py (2)
99-107: Confirm correct variable exclusions.The newly introduced logic for
exclude_preprocmerges modality-based variables. Ensure you verify that only the intended variables are excluded. Edge cases:
- Overlapping modalities.
- Missing keys in
modality_mapping.
383-386: Validate multi-fold strategy logic.Switching between
StratifiedKFoldandStratifiedShuffleSplitbased oncv_foldscan produce different distributions. Confirm that randomization is acceptable and that test coverage is thorough in both code paths.configs/prediction_models/XGBClassifier.gin (3)
14-14: Enhanced min_child_weight granularity and range for better regularization control.The values for min_child_weight now include both finer granularity and a wider range (0.1-10 instead of just 0.5-1), allowing for more precise control over overfitting.
18-21: Added regularization parameters for improved model tuning.Four important hyperparameters have been added:
- gamma - Controls minimum loss reduction for creating a new tree split
- early_stopping_rounds - Prevents overfitting by stopping training when validation metric stops improving
- reg_lambda/reg_alpha - L2/L1 regularization parameters to control model complexity
These additions provide better control over the model's complexity and training behavior.
17-17:Details
❓ Verification inconclusive
Verify impact of removing eval_metric parameter.
The eval_metric parameter has been commented out, which means the model might use a default evaluation metric instead of the previously specified "aucpr".
Run the following to check what evaluation metric is being used now:
🏁 Script executed:
#!/bin/bash # Find where eval_metric is defined or used elsewhere rg "eval_metric" --type=py icu_benchmarks/models/Length of output: 220
Action: Confirm intended evaluation metric change
After verification, it appears that the XGBoost model’s evaluation metric is now explicitly set to
log_lossinicu_benchmarks/models/ml_models/xgboost.py:
- The gin config’s
eval_metricparameter (previously set to"aucpr") is commented out.- The model instantiation explicitly passes
eval_metric=log_loss, meaning it will uselog_lossinstead of"aucpr".Please confirm if switching to
log_lossis an intentional change. If this update is desired, the removal of theaucprparameter is acceptable; otherwise, consider reverting or updating the evaluation metric accordingly.icu_benchmarks/data/utils.py (1)
37-38:Details
✅ Verification successful
Verify the impact of removing the wearable_ppgfeature_HRV_SampEn columns.
Line 37 removes specific wearable features from dynamic data, and line 38 updates the variable list accordingly. This could affect downstream analysis if these features were previously used.
Let's check where these features were used:
🏁 Script executed:
#!/bin/bash # Search for references to the removed feature rg "wearable_ppgfeature_HRV_SampEn" --type=pyLength of output: 186
Impact Verification Complete – Removal Appears Safe
After executing the grep search, the only reference to "wearable_ppgfeature_HRV_SampEn" is within its removal in
icu_benchmarks/data/utils.py. No downstream Python references were found, suggesting that the change won’t inadvertently affect further analysis. Please double-check any external or non-Python integrations if applicable, but within the codebase, the removal appears to be safe.icu_benchmarks/data/loader.py (2)
45-46: Added explicit sorting for row indicators to ensure consistent order.Sorting the row indicators ensures that data is processed in a consistent order across different components, which is especially important for time series data where sequence matters.
Also applies to: 53-53
176-179: Group column removal from representation matrix.Line 179 removes the first column (group column) from the data representation. This prevents the group identifier from being used as a feature in the model, which is typically correct behavior as IDs shouldn't be treated as features.
icu_benchmarks/models/wrappers.py (4)
50-50: Debug mode enabled by default.Debug mode has been changed from False to True, which will result in more verbose logging and output saving. This may impact performance in production environments.
Check if this is intentional or if it should only be enabled in development environments.
109-110: Added support for feature explanation.The new
set_explain_featuresmethod allows enabling/disabling feature explanation, which can help with model interpretability.
461-467: Enhanced model output saving and feature explanation.The error handling for saving model outputs has been commented out, and feature explanations are now stored in
explainer_values_test. This change enables more robust feature importance analysis.Consider whether removing the error handling is intentional:
- # try: self._save_model_outputs(pred_indicators, test_pred, test_label) - # except Exception as e: - # logging.warning(f"Could not save model outputs: {e}")
436-437:Details
❓ Verification inconclusive
Removed default validation loss initialization.
Previously, val_loss was initialized to 0.0 as a fallback, but now it directly uses the return value from model.fit(). This assumes that model.fit() always returns a valid validation loss.
Check if all model implementations properly return a validation loss:
🏁 Script executed:
#!/bin/bash # Find model.fit implementations rg "def fit\(" --type=py icu_benchmarks/models/ml_models/Length of output: 59
Attention: Need Manual Verification of
model.fit()Return ValuesThe removal of the default
val_loss = 0.0fallback now assumes that everymodel.fit()invocation reliably returns a valid validation loss. However, our automated search in theicu_benchmarks/models/ml_models/directory (usingrg "def fit\(" --type=py) did not yield any matching implementations. This could mean one of the following:
- The
fit()methods may be implemented elsewhere (or inherited from base classes), or- They might not explicitly return a validation loss as expected.
Action Items:
- Manually verify that all model implementations (including those possibly inherited or defined in other directories) indeed return a meaningful validation loss.
- If any model does not meet this assumption, consider either restoring a fallback mechanism or updating the implementation accordingly.
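One possible defensive pattern for this concern, shown as a hedged sketch (the wrapper's real attributes and control flow may differ):

```python
def fit_with_fallback(model, train_data, train_labels, default_loss: float = 0.0) -> float:
    """Use the value returned by fit() as the validation loss when it is numeric, otherwise fall back."""
    result = model.fit(train_data, train_labels)
    return float(result) if isinstance(result, (int, float)) else default_loss
```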
| echo "This is a SLURM job named" $SLURM_JOB_NAME "with array id" $SLURM_ARRAY_TASK_ID "and job id" $SLURM_JOB_ID | ||
| echo "Resources allocated: " $SLURM_CPUS_PER_TASK "CPUs, " $SLURM_MEM_PER_NODE "GB RAM, " $SLURM_GPUS_PER_NODE "GPUs" | ||
| echi "Task type:" ${TASK} |
Fix typo in echo command.
There's a typo in the command: echi should be echo.
-echi "Task type:" ${TASK}
+echo "Task type:" ${TASK}📝 Committable suggestion
TASK=BinaryClassification # BinaryClassification
YAIB_PATH=/home/vandewrp/projects/YAIB #/dhc/home/robin.vandewater/projects/yaib
EXPERIMENT_PATH=../yaib_logs/${TASK_NAME}_experiment
DATASET_ROOT_PATH= /sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format #data/YAIB_Datasets/data
Fix Bash variable assignment.
Bash assignments should not have spaces around the = sign, as it affects variable assignment.
-DATASET_ROOT_PATH= /sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format #data/YAIB_Datasets/data
+DATASET_ROOT_PATH=/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format #data/YAIB_Datasets/data
🧰 Tools
🪛 Shellcheck (0.10.0)
[warning] 18-18: Remove space after = if trying to assign a value (for empty string, use var='' ... ).
(SC1007)
| echo "Dataset: "${DATASETS[$SLURM_ARRAY_TASK_ID]} | ||
| echo "Experiment path: "${EXPERIMENT_PATH} | ||
|
|
||
| cd ${YAIB_PATH} |
🛠️ Refactor suggestion
Add error handling for directory change.
The script should exit if the directory change fails, which can happen if the path is incorrect or inaccessible.
-cd ${YAIB_PATH}
+cd ${YAIB_PATH} || { echo "Failed to change to directory ${YAIB_PATH}"; exit 1; }
🧰 Tools
🪛 Shellcheck (0.10.0)
[warning] 28-28: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.
(SC2164)
icu-benchmarks train \
    -d ${DATASET_ROOT_PATH} \
    -n ${DATASETS} \
Variable expansion may be incorrect.
The script passes ${DATASETS} to the -n flag, but this appears to be an array from the usage on line 25. You likely need to reference a specific element of the array or convert it appropriately.
- -n ${DATASETS} \
+    -n ${DATASETS[$SLURM_ARRAY_TASK_ID]} \
model/hyperparameter.class_to_tune = @XGBClassifier
model/hyperparameter.learning_rate = (0.01, 0.1, "log")
model/hyperparameter.n_estimators = [50, 100, 250, 500, 750, 1000,1500,2000]
model/hyperparameter.n_estimators = [50, 100, 250, 500, 750, 1000,1500,2000, 2500, 3000, 3500, 4000, 4500, 5000]
💡 Verification agent
🧩 Analysis chain
Expanded n_estimators range for potentially deeper ensemble models.
The range of n_estimators has been significantly increased from a previous maximum of 2000 to 5000, allowing for much deeper ensemble models. This can lead to better performance but may also increase training time substantially.
Make sure this expanded range is compatible with your computational resources. Run the following to check if the tuning is actually enabled:
🏁 Script executed:
#!/bin/bash
# Check if hyperparameter tuning is enabled in MLTuning.gin
rg "tune_hyperparameters.n_calls" configs/prediction_models/common/MLTuning.ginLength of output: 112
Action Required: Verify Hyperparameter Tuning Configuration
- The expanded n_estimators range in configs/prediction_models/XGBClassifier.gin now goes up to 5000 (previous max was 2000). This change could improve model performance but also significantly increase training time.
- However, the output from configs/prediction_models/common/MLTuning.gin shows tune_hyperparameters.n_calls = 0, indicating that hyperparameter tuning is currently disabled.
- Please confirm whether hyperparameter tuning is intended to be enabled. If not, consider whether expanding the estimator range is still necessary. If tuning should be enabled to take advantage of the broader range, update the tuning configuration accordingly and verify that your computational resources can handle the increased training time.
icu_benchmarks/data/loader.py
Outdated
def get_feature_names(self) -> List[str]:
    return self.features_df.columns
    return [col for col in self.features_df.columns ]#if col != self.vars["GROUP"] and col != self.vars["SEQUENCE"]]
💡 Verification agent
🧩 Analysis chain
Commented code may affect feature name filtering.
The commented portion would filter out GROUP and SEQUENCE columns from feature names. Currently, these columns will be included, which might affect downstream models that expect only feature columns.
Let's check if these columns are used as features elsewhere:
🏁 Script executed:
#!/bin/bash
# Find where get_feature_names() is used
rg "get_feature_names\(\)" --type=py icu_benchmarks/models/Length of output: 157
Attention: Inconsistent Feature Filtering – Investigate GROUP and SEQUENCE Exclusion
The current implementation in icu_benchmarks/data/loader.py (line 75) returns all columns from features_df, including GROUP and SEQUENCE, because the filtering logic is commented out. Notably, in icu_benchmarks/models/train.py, the return value of get_feature_names() is used to set the model’s trained columns. This means that the extra columns are now treated as features, which may lead to downstream issues if these columns are not intended for model training.
- Location Affected:
  - icu_benchmarks/data/loader.py (line 75): the filtering condition for GROUP and SEQUENCE is commented out.
  - icu_benchmarks/models/train.py: the output of get_feature_names() is passed to model.set_trained_columns().
- Action Required:
  Verify whether excluding the GROUP and SEQUENCE columns is necessary. If these columns are not meant to be part of the feature set, please re-enable the filtering logic. Otherwise, update the documentation to clarify that these columns are intentionally included during training.
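If the filter is re-enabled, it presumably reduces to something like the following sketch (column keys taken from the commented-out code; the standalone helper form is illustrative):

```python
def filter_feature_names(columns: list[str], vars: dict) -> list[str]:
    """Drop identifier columns (GROUP, SEQUENCE) so they are never treated as model features."""
    id_cols = {vars.get("GROUP"), vars.get("SEQUENCE")}
    return [col for col in columns if col not in id_cols]
```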
icu_benchmarks/data/loader.py
Outdated
data, labels, row_indicators = self.get_data_and_labels()
if self.mps:
    return from_numpy(data).to(float32), from_numpy(labels).to(float32)
return from_numpy(data).to(float32), from_numpy(labels).to(float32), row_indicators
💡 Verification agent
🧩 Analysis chain
Return type extended to include row indicators.
The to_tensor method now returns row_indicators along with data and labels (unless mps is true, in which case only data and labels are returned). Ensure all calling code handles this additional return value properly.
Check if all callers of this method handle the additional return value:
🏁 Script executed:
#!/bin/bash
# Find where to_tensor is called
rg "\bto_tensor\(\)" --type=pyLength of output: 129
Attention: Update DataLoader Usage to Properly Handle Triple Return Value
The recent change to extend the return type of to_tensor now includes row_indicators along with data and labels. However, the call in icu_benchmarks/models/train.py:
else DataLoader([test_dataset.to_tensor()], batch_size=1)currently does not unpack or handle the additional row_indicators. Please update the caller to either unpack the tuple (e.g., (data, labels, row_indicators) = test_dataset.to_tensor()) or adjust downstream code to accommodate the extra value accordingly.
Actionable comments posted: 1
♻️ Duplicate comments (1)
icu_benchmarks/models/train.py (1)
209-235:⚠️ Potential issueInconsistency in explainer values handling between test and train.
There's an inconsistency in how explainer values are processed:
- For test values (line 222):
explainer_valuesis used directly- For train values (line 230):
explainer_values.valuesis usedThis suggests different object types are being handled, which could cause errors.
- shaps_train = pl.DataFrame(schema=trained_columns, data=explainer_values.values) + shaps_train = pl.DataFrame(schema=trained_columns, data=explainer_values)Alternatively, if the objects truly have different structures:
- shaps_train = pl.DataFrame(schema=trained_columns, data=explainer_values.values) + # Use consistent structure based on the explainer_values_train object type + if hasattr(explainer_values, 'values'): + shaps_train = pl.DataFrame(schema=trained_columns, data=explainer_values.values) + else: + shaps_train = pl.DataFrame(schema=trained_columns, data=explainer_values)🧰 Tools
🪛 Ruff (0.8.2)
216-216: f-string without any placeholders
Remove extraneous
fprefix(F541)
🧹 Nitpick comments (20)
experiments/benchmark_cass.yml (9)
14-17: Hyperparameter Tuning Naming and Checkpoint Options
The parameters-tnandSSIare configured for naming conventions within hyperparameter tuning. Additionally, several hyperparameter checkpoint options remain commented out. If these are no longer required, consider removing them to improve clarity.
18-23: Historical Configuration Comments
The commented-out lines for alternate hyperparameter tuning log paths and verbose flags provide useful historical context. Once the configuration is finalized, consider cleaning up unnecessary commented lines to reduce clutter.
24-27: Modalities Flag in Command
The--modalitiesflag is introduced correctly with the active value"all". The commented-out verbose flag is acceptable during experimentation—if it’s no longer needed, its removal might simplify the configuration.
28-30: Optional Log Configuration Comments
Additional commented-out options (for example, alternate transformer log paths and flags like-gcor-lc) remain present. Keeping these can be useful during development, but you may want to remove them once the configuration stabilizes.
33-49: Data Directory Configuration
Thedata_dirparameter is configured with one active data path and several alternative paths commented out. This flexible setup allows for quick switching between data sources during development. Ensure that the selected active path is valid and consider cleaning up the commented entries once you've finalized the configuration.
68-81: Model Parameter Selection
Themodelparameter explicitly selectsXGBClassifieras the active model, while alternative model options are provided as comments. This configuration aligns with the enhanced hyperparameter settings. If these alternatives will no longer be used, removing them might improve readability.
82-89: Modalities Configuration Consistency
Themodalitiesparameter lists two active values: a string"all"and an array of modality identifiers. Please verify that the system consuming this config can handle mixed type values. If a consistent format is needed, consider normalizing these entries.
93-95: Pretrained Imputation Setting
Theuse_pretrained_imputationparameter is set with the valueNone, indicating that no pretrained imputation is applied. Confirm that downstream processes correctly interpret this setting (e.g., as a null value) or adjust if a different behavior is required.
101-101: Newline at End-of-File
The file is missing a newline character at the end. Please add a newline to adhere to YAML style guidelines and avoid potential issues with certain tools or environments.🧰 Tools
🪛 YAMLlint (1.35.1)
[error] 101-101: no new line character at the end of file
(new-line-at-end-of-file)
icu_benchmarks/models/train.py (1)
216-216: Remove extraneous f-string prefix.The string doesn't contain any placeholders, so the
fprefix is unnecessary.- logging.info(f"Saving SHAPS") + logging.info("Saving SHAPS")🧰 Tools
🪛 Ruff (0.8.2)
216-216: f-string without any placeholders
Remove extraneous
fprefix(F541)
icu_benchmarks/models/ml_models/xgboost.py (2)
6-6: Remove unused import.The
numpymodule is imported but never used in this file.-import numpy as np🧰 Tools
🪛 Ruff (0.8.2)
6-6:
numpyimported but unusedRemove unused import:
numpy(F401)
84-87: Consider removing commented-out code.The commented-out
explainermethod seems to be replaced by the new SHAP-based implementation. Unless this is intentionally being kept for reference, consider removing it to keep the codebase clean.- # def explainer(self, reps): - # if not hasattr(self.model, "feature_importances_"): - # raise ValueError("Model has not been fit yet. Call fit_model() before getting feature importances.") - # return self.model.feature_importances_icu_benchmarks/models/wrappers.py (8)
6-6: Remove unused import.The import
EpochMetricis never referenced in the code and can safely be removed to reduce clutter and improve maintainability.-from ignite.metrics import EpochMetric🧰 Tools
🪛 Ruff (0.8.2)
6-6:
ignite.metrics.EpochMetricimported but unusedRemove unused import:
ignite.metrics.EpochMetric(F401)
109-110: Add docstring for new method.The
set_explain_featuresmethod would benefit from a short docstring describing how and when to call it, ensuring maintainers and users understand its purpose.def set_explain_features(self, set_explain_features: bool): + """ + Enable or disable feature explanations for downstream tasks. + """ self.explain_features = set_explain_features
436-437: Clarify return value frommodel.fit.The call to
self.model.fit(train_data, train_labels)is stored inval_loss, but it's unclear if the model’sfitmethod actually returns a loss value. Consider renaming the variable or verifying the returned object to avoid confusion.
461-462: Restore error handling if necessary.Removing the
tryblock around_save_model_outputsmeans the method could raise exceptions that go unhandled during testing. Ensure that this is intentional or reintroduce suitable error handling.
463-464: Reevaluate commented-out exception handler.The exception handling has been fully commented out, so any file I/O issues will now propagate. If the intention is to log such failures, reintroduce the logging statement in an
exceptblock.
473-473: Using raw label/pred in log metrics.Unlike line 470, here you leave the data as-is. Consider standardizing your approach (either always cast or ensure consistent types) for consistent metric logging.
482-520: Unify metric logging logic.Within
log_metrics, both theif pred_indicators is Noneandelseblocks contain largely similar metric logging code (confusion matrix, dictionary creation). Consider unifying these sections to reduce duplication and complexity.
528-568: Clean up duplicate or commented-out lines & fix extraneous f-strings.
- Several commented-out lines in
_save_model_outputsadd noise. Remove or refactor them for clarity.- The static analysis flags f-strings without placeholders (e.g.,
f'row_indicators.csv'). Remove unusedfprefixes to avoid confusion.-logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / f'row_indicators.csv'}") +logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / 'row_indicators.csv'}")🧰 Tools
🪛 Ruff (0.8.2)
538-538: f-string without any placeholders
Remove extraneous
fprefix(F541)
554-554: f-string without any placeholders
Remove extraneous
fprefix(F541)
563-563: f-string without any placeholders
Remove extraneous
fprefix(F541)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
configs/prediction_models/common/MLTuning.gin (1 hunks)
configs/tasks/BinaryClassification.gin (1 hunks)
experiments/benchmark_cass.yml (1 hunks)
icu_benchmarks/data/split_process_data.py (12 hunks)
icu_benchmarks/models/ml_models/xgboost.py (4 hunks)
icu_benchmarks/models/train.py (4 hunks)
icu_benchmarks/models/wrappers.py (8 hunks)
icu_benchmarks/run_utils.py (6 hunks)
icu_benchmarks/tuning/hyperparameters.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- configs/prediction_models/common/MLTuning.gin
- icu_benchmarks/run_utils.py
- icu_benchmarks/tuning/hyperparameters.py
🧰 Additional context used
🧬 Code Definitions (2)
icu_benchmarks/models/wrappers.py (2)
icu_benchmarks/models/utils.py (2)
log_metrics(162-183)name(158-159)icu_benchmarks/models/custom_metrics.py (1)
confusion_matrix(137-146)
icu_benchmarks/models/train.py (1)
icu_benchmarks/models/wrappers.py (1)
set_explain_features(109-110)
🪛 Ruff (0.8.2)
icu_benchmarks/data/split_process_data.py
4-4: numpy imported but unused
Remove unused import: numpy
(F401)
45-45: Do not use mutable data structures for argument defaults
Replace with None; initialize within function
(B006)
icu_benchmarks/models/ml_models/xgboost.py
6-6: numpy imported but unused
Remove unused import: numpy
(F401)
icu_benchmarks/models/wrappers.py
6-6: ignite.metrics.EpochMetric imported but unused
Remove unused import: ignite.metrics.EpochMetric
(F401)
538-538: f-string without any placeholders
Remove extraneous f prefix
(F541)
554-554: f-string without any placeholders
Remove extraneous f prefix
(F541)
563-563: f-string without any placeholders
Remove extraneous f prefix
(F541)
icu_benchmarks/models/train.py
216-216: f-string without any placeholders
Remove extraneous f prefix
(F541)
🪛 YAMLlint (1.35.1)
experiments/benchmark_cass.yml
[error] 101-101: no new line character at the end of file
(new-line-at-end-of-file)
🔇 Additional comments (23)
experiments/benchmark_cass.yml (5)
1-9: Command Configuration Validation
The command list is structured properly with the appropriate use of environment variables (${env}) and program parameters (${program}). The change from the commented-out "BinaryClassification" to the active "CassClassification" appears intentional and aligns with the updated benchmarking focus.
10-13: Logging and Tuning Parameters
The logging flag (--log-dir) and tuning flags (--tune,--wandb-sweep) are clearly specified. Please ensure that the log directory (/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs) is accessible and that appropriate permissions are in place in the target environment.
31-32: Benchmark Method and Name
Setting the benchmarking method togridand naming the benchmarkyaib_classification_benchmarkclearly defines the experiment’s search strategy and identification.
90-92: Seed for Reproducibility
The random seed is set to1111, which helps in ensuring reproducible results across experiment runs. This is a standard and effective configuration.
96-96: Program Specification
The program is set to icu-benchmarks, which should correspond with the intended execution context of the benchmark. Verify that all associated scripts and modules recognize this program identifier appropriately.
configs/tasks/BinaryClassification.gin (1)
25-30: Configuration enhancement for required data segments and filenames.The addition of explicit required segments and file name mappings improves configuration clarity and makes dependencies more explicit. This change aligns well with the overall improvements in data processing from the PR.
icu_benchmarks/models/train.py (3)
56-56: Well-designed addition of feature explanation flag. Adding the explain_features parameter with a default of False provides a clean way to optionally enable feature explanations during testing without breaking existing functionality.
177-178: Good implementation of feature explanation configuration.The model is properly configured for feature explanation before testing, which aligns with the PR's goal of enabling feature explanation throughout training.
196-197: Conditionally persisting explainer data is a good approach. The conditional call to persist_shap_data only when explain_features is True ensures that explainer values are only computed and saved when needed.
icu_benchmarks/models/ml_models/xgboost.py (3)
28-28: Enhanced model initialization with improved parameters. Adding the eval_metric=log_loss, missing='inf', and verbosity=0 parameters provides better default behavior for the XGBoost classifier.
51-57: Good implementation of SHAP-based feature explanation. The implementation of TreeExplainer with the feature_perturbation="interventional" and model_output="probability" parameters provides an effective way to explain feature importance. The conditional computation of SHAP values only when feature explanation is enabled is also a good approach.
76-83: Well-designed feature explanation method. The _explain_model method properly returns SHAP values from the explainer, which is a more interpretable way to understand feature importance than the built-in feature importances.
icu_benchmarks/data/split_process_data.py (5)
4-4: Remove unused import.The
numpymodule is imported but never used in this file.-import numpy as np🧰 Tools
🪛 Ruff (0.8.2)
4-4:
numpyimported but unusedRemove unused import:
numpy(F401)
45-45: Avoid mutable default arguments.Using
["GROUP", "LABEL"]as a default forrequired_var_typescan introduce subtle bugs if the list is mutated. Replace withNoneor a tuple and initialize inside the function as needed.-required_var_types=["GROUP", "LABEL"], +required_var_types=None,And inside the function:
if required_var_types is None: required_var_types = ["GROUP", "LABEL"]🧰 Tools
🪛 Ruff (0.8.2)
45-45: Do not use mutable data structures for argument defaults
Replace with
None; initialize within function(B006)
99-111: Good implementation of modality exclusion logic.The added logic for excluding variables from preprocessing based on modality mapping is well-structured and includes proper error checking and informative logging messages.
322-322: Improved default train_size parameter. Setting a default value of 0.80 for train_size provides a reasonable split ratio and makes the function more user-friendly by not requiring explicit specification of this parameter.
383-386: Enhanced cross-validation strategy for small fold counts. Using StratifiedShuffleSplit when the fold count is ≤ 2 is a good approach for handling smaller cross-validation scenarios, as it ensures proper class distribution in the splits.
icu_benchmarks/models/wrappers.py (6)
50-50: Confirm the impact of always enabling debug mode. Setting debug = True might produce verbose logs or additional overhead. Ensure that this setting is intentional for production environments, as debug logs can expose sensitive details and affect performance.
66-66: Logging column count is useful.This logging addition provides helpful visibility into how many columns are being set for training. No issues here.
228-229: No functional changes.These lines appear to be blank space insertions or formatting tweaks. No action needed.
432-432: Confirm non-null usage of row_indicators. You are now passing row_indicators to log_metrics for the training phase. If row_indicators can be None for the training set, ensure that log_metrics can handle it gracefully.
466-467: Use _explain_model consistently. You replaced self.explain_model with self.explainer_values_test = self._explain_model(...). If _explain_model provides all needed functionality, confirm that explain_model is no longer required, or reference it for clarity.
470-470: Casting labels and predictions to float32. Confirm that explicitly casting test_label and test_pred to float32 does not introduce precision issues for downstream calculations.
| logging.info("Checking for infinite values.") | ||
| for col in val.select(cs.numeric()).columns: | ||
| if val[col].is_infinite().any(): | ||
| logging.info(f"Column '{col}' contains infinite values. Datatype: {val[col].dtype}") | ||
|
|
||
| max_float64 = 0 | ||
| # Replace infinite values with the maximum value for float64 | ||
| val = val.with_columns([ | ||
| pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col) | ||
| for col in val.columns if val[col].dtype == pl.Float64 | ||
| ]) | ||
| dict[key] = val | ||
| logging.info(f"Amount of columns: {len(val.columns)}") | ||
|
|
The max_float64 value is incorrectly set to 0.
Setting max_float64 = 0 before replacing infinite values doesn't seem correct. To properly handle infinite values, you should use the maximum finite value for float64 from numpy.
- max_float64 = 0
+ # Use numpy's maximum finite value for float64
+ import numpy as np
+ max_float64 = np.finfo(np.float64).max
This ensures that infinite values are replaced with the largest valid float64 value instead of 0, which could significantly affect model training.
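For context, a minimal standalone sketch of that replacement (with made-up column names, not the project's data) looks like this:

```python
import numpy as np
import polars as pl

df = pl.DataFrame({"hr": [60.0, float("inf"), 72.0], "resp": [12.0, 18.0, float("-inf")]})

max_float64 = np.finfo(np.float64).max
df = df.with_columns(
    [
        pl.when(pl.col(c).is_infinite()).then(max_float64).otherwise(pl.col(c)).alias(c)
        for c in df.columns
        if df[c].dtype == pl.Float64
    ]
)
# Both +inf and -inf are now the largest finite float64 instead of 0.
```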
# Conflicts:
#   icu_benchmarks/data/loader.py
#   icu_benchmarks/data/split_process_data.py
#   icu_benchmarks/models/ml_models/xgboost.py
Actionable comments posted: 0
🧹 Nitpick comments (1)
icu_benchmarks/data/split_process_data.py (1)
220-244: Update function documentation to match new return signature. The check_sanitize_data function now returns both data and vars, but the function docstring doesn't reflect this change. Update the documentation to match the actual implementation:
 def check_sanitize_data(data, vars):
-    """Check for duplicates in the loaded data and remove them."""
+    """Check for duplicates in the loaded data and remove them.
+
+    Args:
+        data: Dictionary containing data divided into segments
+        vars: Dictionary containing variable metadata
+
+    Returns:
+        Tuple of (sanitized_data, updated_vars)
+    """
🪛 Ruff (0.8.2)
220-220: Redefinition of unused
check_sanitize_datafrom line 19(F811)
222-222: Use
key in dictinstead ofkey in dict.keys()Remove
.keys()(SIM118)
223-223: Use
key in dictinstead ofkey in dict.keys()Remove
.keys()(SIM118)
225-225: Use
key in dictinstead ofkey in dict.keys()Remove
.keys()(SIM118)
230-230: Use
key in dictinstead ofkey in dict.keys()Remove
.keys()(SIM118)
235-235: Use
key in dictinstead ofkey in dict.keys()Remove
.keys()(SIM118)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
icu_benchmarks/data/split_process_data.py(12 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
icu_benchmarks/data/split_process_data.py
4-4: numpy imported but unused
Remove unused import: numpy
(F401)
45-45: Do not use mutable data structures for argument defaults
Replace with None; initialize within function
(B006)
🔇 Additional comments (7)
icu_benchmarks/data/split_process_data.py (7)
4-4: Thenumpyimport is now being used in the code.Previously flagged as unused, now it appears to be used for handling infinite values (in concept at least). However, the numpy import is still not directly referenced in the implementation. It could be better utilized for setting the maximum float64 value.
🧰 Tools
🪛 Ruff (0.8.2)
4-4:
numpyimported but unusedRemove unused import:
numpy(F401)
45-45: Avoid mutable default arguments. Using ["GROUP", "LABEL"] as a default for required_var_types can introduce subtle bugs if the list is mutated. Replace with None or a tuple and initialize inside the function as needed:
- required_var_types=["GROUP", "LABEL"],
+ required_var_types=None,
Then add inside the function:
if required_var_types is None:
    required_var_types = ["GROUP", "LABEL"]
🧰 Tools
🪛 Ruff (0.8.2)
45-45: Do not use mutable data structures for argument defaults
Replace with
None; initialize within function(B006)
99-111: LGTM! Good implementation of modality exclusion.The implementation of modality-based variable exclusion is well-structured with appropriate error handling and logging. This provides a flexible way to exclude variables that have already undergone preprocessing.
138-138: Function call updated to match new return signature. The code correctly handles the updated return signature of check_sanitize_data, which now returns both the sanitized data and updated vars.
199-207: The max_float64 value is incorrectly set to 0. Setting max_float64 = 0 before replacing infinite values doesn't seem correct. To properly handle infinite values, you should use the maximum finite value for float64 from numpy:
- max_float64 = 0
+ # Use numpy's maximum finite value for float64
+ max_float64 = np.finfo(np.float64).max
This ensures that infinite values are replaced with the largest valid float64 value instead of 0, which could significantly affect model training.
439-442: Good enhancement for flexibility in cross-validation.The conditional logic for choosing between StratifiedKFold and StratifiedShuffleSplit based on the number of folds is a good enhancement that improves flexibility, especially for smaller datasets.
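A minimal sketch of that fold-count switch (assumed argument names, not the exact split_process_data implementation):

```python
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit

def make_splitter(cv_folds: int, train_size: float = 0.80, seed: int = 42):
    # With one or two folds, a stratified shuffle split keeps class balance
    # without forcing the 1/k test fraction that k-fold implies.
    if cv_folds <= 2:
        return StratifiedShuffleSplit(n_splits=cv_folds, train_size=train_size, random_state=seed)
    return StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=seed)
```

Either splitter yields (train_idx, test_idx) pairs from .split(X, y), so downstream code can stay the same.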
378-378: Good default value for train_size parameter.Setting a default value of 0.80 for train_size is a sensible choice and aligns with common machine learning practices.
Actionable comments posted: 4
🧹 Nitpick comments (1)
icu_benchmarks/data/utils.py (1)
26-51: Apply static analysis suggestions and improve code clarity. The function works correctly but has some style issues flagged by static analysis tools.
Apply this diff to address the static analysis hints:
def check_sanitize_data(data, vars): """Check for duplicates in the loaded data and remove them.""" - group = vars[Var.group] if Var.group in vars.keys() else None - sequence = vars[Var.sequence] if Var.sequence in vars.keys() else None + group = vars[Var.group] if Var.group in vars else None + sequence = vars[Var.sequence] if Var.sequence in vars else None keep = "last" logging.info(data.keys()) - if Segment.static in data.keys(): + if Segment.static in data: old_len = len(data[Segment.static]) data[Segment.static] = data[Segment.static].unique(subset=group, keep=keep, maintain_order=True) logging.warning(f"Removed {old_len - len(data[Segment.static])} duplicates from static data.") - if Segment.dynamic in data.keys(): + if Segment.dynamic in data: old_len = len(data[Segment.dynamic]) data[Segment.dynamic] = data[Segment.dynamic].unique(subset=[group, sequence], keep=keep, maintain_order=True) data[Segment.dynamic] = infinite_removal(data[Segment.dynamic]) data[Segment.dynamic] = data[Segment.dynamic].select(~cs.starts_with("wearable_ppgfeature_HRV_SampEn")) vars[Segment.dynamic] = [col for col in vars[Segment.dynamic] if col in data[Segment.dynamic].columns] logging.warning(f"Removed {old_len - len(data[Segment.dynamic])} duplicates from dynamic data.") - if Segment.outcome in data.keys(): + if Segment.outcome in data: old_len = len(data[Segment.outcome]) if sequence in data[Segment.outcome].columns: # We have a dynamic outcome with group and sequence data[Segment.outcome] = data[Segment.outcome].unique(subset=[group, sequence], keep=keep, maintain_order=True) else: data[Segment.outcome] = data[Segment.outcome].unique(subset=[group], keep=keep, maintain_order=True) logging.warning(f"Removed {old_len - len(data[Segment.outcome])} duplicates from outcome data.") return data, vars
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
icu_benchmarks/data/loader.py (3 hunks)
icu_benchmarks/data/preprocessor.py (6 hunks)
icu_benchmarks/data/utils.py (1 hunks)
icu_benchmarks/tuning/hyperparameters.py (7 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- icu_benchmarks/data/loader.py
- icu_benchmarks/tuning/hyperparameters.py
- icu_benchmarks/data/preprocessor.py
🧰 Additional context used
🪛 Ruff (0.12.2)
icu_benchmarks/data/utils.py
28-28: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
29-29: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
32-32: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
36-36: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
43-43: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
58-58: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
59-59: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
77-77: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
🔇 Additional comments (1)
icu_benchmarks/data/utils.py (1)
1-6: LGTM: Clean imports and dependencies. The imports are well-organized and appropriate for the data processing utilities implemented in this module.
| logging.info("No columns selected. Using all columns.") | ||
| return data, vars | ||
| selected_columns = sum(selected_columns, []) | ||
| selected_columns.extend([vars[Var.group], vars[Var.label], vars[Var.sequence]]) |
🛠️ Refactor suggestion
Add defensive programming for missing standard columns
The function assumes that Var.group, Var.label, and Var.sequence always exist in the vars dictionary. Add validation to prevent KeyError exceptions.
Apply this diff to add defensive checks:
- selected_columns.extend([vars[Var.group], vars[Var.label], vars[Var.sequence]])
+ standard_cols = []
+ for var_type in [Var.group, Var.label, Var.sequence]:
+ if var_type in vars:
+ standard_cols.append(vars[var_type])
+ else:
+ logging.warning(f"Standard column {var_type} not found in vars")
+ selected_columns.extend(standard_cols)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| selected_columns.extend([vars[Var.group], vars[Var.label], vars[Var.sequence]]) | |
| standard_cols = [] | |
| for var_type in [Var.group, Var.label, Var.sequence]: | |
| if var_type in vars: | |
| standard_cols.append(vars[var_type]) | |
| else: | |
| logging.warning(f"Standard column {var_type} not found in vars") | |
| selected_columns.extend(standard_cols) |
def infinite_removal(val):
    for col in val.select(cs.numeric()).columns:
        if val[col].is_infinite().any():
            logging.info(f"Column '{col}' contains infinite values. Datatype: {val[col].dtype}")

    max_float64 = np.finfo(np.float64).max / 100
    # Replace infinite values with the maximum value for float64
    val = val.with_columns(
        [
            pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col)
            for col in val.columns
            if val[col].dtype == pl.Float64
        ]
    )
    return val
Fix inconsistent infinite value handling across numeric types
The function checks for infinite values in all numeric columns but only replaces them in Float64 columns. This could leave infinite values in other numeric types (Float32, Int64, etc.) unhandled.
Apply this diff to handle all numeric types consistently:
def infinite_removal(val):
- for col in val.select(cs.numeric()).columns:
- if val[col].is_infinite().any():
- logging.info(f"Column '{col}' contains infinite values. Datatype: {val[col].dtype}")
-
- max_float64 = np.finfo(np.float64).max / 100
- # Replace infinite values with the maximum value for float64
- val = val.with_columns(
- [
- pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col)
- for col in val.columns
- if val[col].dtype == pl.Float64
- ]
- )
+ numeric_cols = val.select(cs.numeric()).columns
+ cols_with_infinite = []
+
+ for col in numeric_cols:
+ if val[col].is_infinite().any():
+ logging.info(f"Column '{col}' contains infinite values. Datatype: {val[col].dtype}")
+ cols_with_infinite.append(col)
+
+ if cols_with_infinite:
+ max_float64 = np.finfo(np.float64).max / 100
+ # Replace infinite values with the maximum value for numeric columns
+ val = val.with_columns(
+ [
+ pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col)
+ for col in cols_with_infinite
+ ]
+ )
+
return val📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def infinite_removal(val): | |
| for col in val.select(cs.numeric()).columns: | |
| if val[col].is_infinite().any(): | |
| logging.info(f"Column '{col}' contains infinite values. Datatype: {val[col].dtype}") | |
| max_float64 = np.finfo(np.float64).max / 100 | |
| # Replace infinite values with the maximum value for float64 | |
| val = val.with_columns( | |
| [ | |
| pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col) | |
| for col in val.columns | |
| if val[col].dtype == pl.Float64 | |
| ] | |
| ) | |
| return val | |
| def infinite_removal(val): | |
| numeric_cols = val.select(cs.numeric()).columns | |
| cols_with_infinite = [] | |
| for col in numeric_cols: | |
| if val[col].is_infinite().any(): | |
| logging.info(f"Column '{col}' contains infinite values. Datatype: {val[col].dtype}") | |
| cols_with_infinite.append(col) | |
| if cols_with_infinite: | |
| max_float64 = np.finfo(np.float64).max / 100 | |
| # Replace infinite values with the maximum value for numeric columns | |
| val = val.with_columns( | |
| [ | |
| pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col) | |
| for col in cols_with_infinite | |
| ] | |
| ) | |
| return val |
def modality_selection(
    data: dict[pl.DataFrame], modality_mapping: dict[str], selected_modalities: list[str], vars
) -> dict[pl.DataFrame]:
    logging.info(f"Selected modalities: {selected_modalities}")
    selected_columns = [modality_mapping[cols] for cols in selected_modalities if cols in modality_mapping.keys()]
    if not any(col in modality_mapping.keys() for col in selected_modalities):
        raise ValueError("None of the selected modalities found in modality mapping.")
    if selected_columns == []:
        logging.info("No columns selected. Using all columns.")
        return data, vars
    selected_columns = sum(selected_columns, [])
    selected_columns.extend([vars[Var.group], vars[Var.label], vars[Var.sequence]])
    old_columns = []
    # Update vars dict
    for key, value in vars.items():
        if key not in [Var.group, Var.label, Var.sequence]:
            old_columns.extend(value)
            vars[key] = [col for col in value if col in selected_columns]
    # -3 because of standard columns
    logging.info(f"Selected columns: {len(selected_columns) - 3}, original columns: {len(old_columns)}, "
                 f"not using: {len(set(old_columns) - set(selected_columns))} columns")
    logging.debug(f"Not using columns: {set(old_columns) - set(selected_columns)}")
    # Update data dict
    for key in data.keys():
        sel_col = [col for col in data[key].columns if col in selected_columns]
        data[key] = data[key].select(sel_col)
        logging.debug(f"Selected columns in {key}: {len(data[key].columns)}")
    return data, vars
Fix return type annotation and apply static analysis suggestions
The function has an incorrect return type annotation and contains static analysis issues.
Apply this diff to fix the return type and style issues:
def modality_selection(
- data: dict[pl.DataFrame], modality_mapping: dict[str], selected_modalities: list[str], vars
-) -> dict[pl.DataFrame]:
+ data: dict[pl.DataFrame], modality_mapping: dict[str], selected_modalities: list[str], vars
+) -> tuple[dict[pl.DataFrame], dict]:
logging.info(f"Selected modalities: {selected_modalities}")
- selected_columns = [modality_mapping[cols] for cols in selected_modalities if cols in modality_mapping.keys()]
- if not any(col in modality_mapping.keys() for col in selected_modalities):
+ selected_columns = [modality_mapping[cols] for cols in selected_modalities if cols in modality_mapping]
+ if not any(col in modality_mapping for col in selected_modalities):
raise ValueError("None of the selected modalities found in modality mapping.")
if selected_columns == []:
logging.info("No columns selected. Using all columns.")
return data, vars
selected_columns = sum(selected_columns, [])
selected_columns.extend([vars[Var.group], vars[Var.label], vars[Var.sequence]])
old_columns = []
# Update vars dict
for key, value in vars.items():
if key not in [Var.group, Var.label, Var.sequence]:
old_columns.extend(value)
vars[key] = [col for col in value if col in selected_columns]
# -3 because of standard columns
logging.info(f"Selected columns: {len(selected_columns) - 3}, original columns: {len(old_columns)}, "
f"not using: {len(set(old_columns) - set(selected_columns))} columns")
logging.debug(f"Not using columns: {set(old_columns) - set(selected_columns)}")
# Update data dict
- for key in data.keys():
+ for key in data:
sel_col = [col for col in data[key].columns if col in selected_columns]
data[key] = data[key].select(sel_col)
logging.debug(f"Selected columns in {key}: {len(data[key].columns)}")
return data, vars📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def modality_selection( | |
| data: dict[pl.DataFrame], modality_mapping: dict[str], selected_modalities: list[str], vars | |
| ) -> dict[pl.DataFrame]: | |
| logging.info(f"Selected modalities: {selected_modalities}") | |
| selected_columns = [modality_mapping[cols] for cols in selected_modalities if cols in modality_mapping.keys()] | |
| if not any(col in modality_mapping.keys() for col in selected_modalities): | |
| raise ValueError("None of the selected modalities found in modality mapping.") | |
| if selected_columns == []: | |
| logging.info("No columns selected. Using all columns.") | |
| return data, vars | |
| selected_columns = sum(selected_columns, []) | |
| selected_columns.extend([vars[Var.group], vars[Var.label], vars[Var.sequence]]) | |
| old_columns = [] | |
| # Update vars dict | |
| for key, value in vars.items(): | |
| if key not in [Var.group, Var.label, Var.sequence]: | |
| old_columns.extend(value) | |
| vars[key] = [col for col in value if col in selected_columns] | |
| # -3 because of standard columns | |
| logging.info(f"Selected columns: {len(selected_columns) - 3}, original columns: {len(old_columns)}, " | |
| f"not using: {len(set(old_columns) - set(selected_columns))} columns") | |
| logging.debug(f"Not using columns: {set(old_columns) - set(selected_columns)}") | |
| # Update data dict | |
| for key in data.keys(): | |
| sel_col = [col for col in data[key].columns if col in selected_columns] | |
| data[key] = data[key].select(sel_col) | |
| logging.debug(f"Selected columns in {key}: {len(data[key].columns)}") | |
| return data, vars | |
| def modality_selection( | |
| data: dict[pl.DataFrame], modality_mapping: dict[str], selected_modalities: list[str], vars | |
| ) -> tuple[dict[pl.DataFrame], dict]: | |
| logging.info(f"Selected modalities: {selected_modalities}") | |
| selected_columns = [ | |
| modality_mapping[cols] | |
| for cols in selected_modalities | |
| if cols in modality_mapping | |
| ] | |
| if not any(col in modality_mapping for col in selected_modalities): | |
| raise ValueError("None of the selected modalities found in modality mapping.") | |
| if selected_columns == []: | |
| logging.info("No columns selected. Using all columns.") | |
| return data, vars | |
| selected_columns = sum(selected_columns, []) | |
| selected_columns.extend([vars[Var.group], vars[Var.label], vars[Var.sequence]]) | |
| old_columns = [] | |
| # Update vars dict | |
| for key, value in vars.items(): | |
| if key not in [Var.group, Var.label, Var.sequence]: | |
| old_columns.extend(value) | |
| vars[key] = [col for col in value if col in selected_columns] | |
| # -3 because of standard columns | |
| logging.info( | |
| f"Selected columns: {len(selected_columns) - 3}, " | |
| f"original columns: {len(old_columns)}, " | |
| f"not using: {len(set(old_columns) - set(selected_columns))} columns" | |
| ) | |
| logging.debug(f"Not using columns: {set(old_columns) - set(selected_columns)}") | |
| # Update data dict | |
| for key in data: | |
| sel_col = [col for col in data[key].columns if col in selected_columns] | |
| data[key] = data[key].select(sel_col) | |
| logging.debug(f"Selected columns in {key}: {len(data[key].columns)}") | |
| return data, vars |
🧰 Tools
🪛 Ruff (0.12.2)
58-58: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
59-59: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
77-77: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
if selected_columns == []:
    logging.info("No columns selected. Using all columns.")
    return data, vars
💡 Verification agent
🧩 Analysis chain
Verify empty list check logic
The condition selected_columns == [] might not work as intended since selected_columns is a list of lists at this point. Consider checking if the flattened list is empty.
Please verify this logic handles the edge case correctly. The check should occur after flattening the list or use a different approach:
🏁 Script executed:
#!/bin/bash
# Description: Search for similar patterns in the codebase to understand the intended behavior
# Expected: Find usage patterns of modality_selection function
rg -A 10 -B 5 "modality_selection" --type pyLength of output: 5378
🏁 Script executed:
#!/bin/bash
# Extract the modality_selection implementation from utils.py to inspect selected_columns usage
rg -A 20 -B 10 "def modality_selection" --type py icu_benchmarks/data/utils.pyLength of output: 1849
Update empty‐check to use flattened columns
The current check if selected_columns == [] only detects when no modalities matched, but doesn’t catch cases where the mapping exists but yields empty lists (e.g. [[]]). Flatten first, then guard on the resulting column list:
File: icu_benchmarks/data/utils.py, around lines 61–63
Replace:
if selected_columns == []:
    logging.info("No columns selected. Using all columns.")
    return data, vars
selected_columns = sum(selected_columns, [])
with:
+ # Flatten and catch “no actual columns” after mapping
+ flattened = sum(selected_columns, [])
+ if not flattened:
+ logging.info("No columns selected. Using all columns.")
+ return data, vars
+ selected_columns = flattened
Actionable comments posted: 0
♻️ Duplicate comments (4)
icu_benchmarks/run_utils.py (4)
98-98: Fix collision handling implementation. This is the same issue identified in previous reviews: random.randint(1, 10) returns an integer but with_name() expects a string, and the limited range could cause issues.
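One type-safe way to handle such collisions (a sketch with assumed names, not the actual run_utils code):

```python
import random
from pathlib import Path

def unique_dir(base: Path) -> Path:
    # Path.with_name() expects a string; a wider random suffix space makes
    # collisions far less likely than randint(1, 10).
    candidate = base
    while candidate.exists():
        candidate = base.with_name(f"{base.name}_{random.randint(0, 99999):05d}")
    return candidate
```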
190-190: Variable naming is consistent. This addresses the previous review comment about variable naming consistency: the variable is now properly named explainer_values_test throughout.
209-220: Remove duplicate code block.Lines 216-220 duplicate the functionality of lines 213-215. This redundant try-except block should be removed as it performs the exact same operation.
227-229: Fix TypeError risk with math.isnan(). The code uses math.isnan() on a variable that can be an integer, which will raise a TypeError. This needs to be fixed by either restricting the type check to floats only or casting to float first.
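A defensive version of that check could look like this (sketch only; `value` stands in for whatever the aggregation code inspects):

```python
import math

def is_nan_value(value) -> bool:
    # math.isnan() raises TypeError for non-numeric input, so cast first and
    # treat anything that cannot be cast to float as not-NaN.
    try:
        return math.isnan(float(value))
    except (TypeError, ValueError):
        return False
```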
🧹 Nitpick comments (1)
icu_benchmarks/run_utils.py (1)
118-180: Well-structured function for prediction aggregation. The append_predictions_foldwise function is well-documented and implements a systematic approach to aggregate prediction files across folds and iterations. The ID offsetting mechanism prevents clashes, and the fold-by-fold processing ensures consistent ordering. Consider using enumerate() for the counter as suggested by static analysis:
- counter = 0
  # Process fold by fold across all iterations
- for fold_name in sorted(all_folds):
+ for counter, fold_name in enumerate(sorted(all_folds)):
      for iteration in iterations:
          # ... existing code ...
-         counter += 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
icu_benchmarks/run.py (5 hunks)
icu_benchmarks/run_utils.py (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- icu_benchmarks/run.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
icu_benchmarks/run_utils.py (4)
icu_benchmarks/utils.py (1)
parse_dict (5-22)
icu_benchmarks/wandb_utils.py (1)
wandb_log (46-53)
icu_benchmarks/models/utils.py (2)
default (95-109), name (159-160)
scripts/evaluate_results/aggregrate_experiment.py (1)
aggregate_results(7-85)
🪛 Ruff (0.12.2)
icu_benchmarks/run_utils.py
175-175: Use enumerate() for index variable counter in for loop
(SIM113)
223-223: Loop control variable repetition not used within loop body
(B007)
224-224: Loop control variable fold not used within loop body
Rename unused fold to _fold
(B007)
🔇 Additional comments (5)
icu_benchmarks/run_utils.py (5)
2-2: New imports look good. The math and random imports are appropriately used within the file for NaN checking and collision handling respectively. Also applies to: 18-18.
20-21: Imported utilities are properly used. The parse_dict function from utils and wandb_log from wandb_utils are correctly imported and used in their respective contexts.
66-79: New command line arguments enhance flexibility. The addition of the --file_names, --explain_features, and --load_data_vars arguments provides good configurability for dataset handling and feature explanation functionality.
96-96: Improved timestamp precision. Adding microsecond precision (%f) to the timestamp format reduces the likelihood of directory name collisions, which is a good improvement.
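For illustration, a microsecond-resolution log-directory name can be built like this (format string assumed, not necessarily the exact one in run_utils.py):

```python
from datetime import datetime
from pathlib import Path

stamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")  # %f appends microseconds
run_dir = Path("yaib_logs") / stamp  # e.g. yaib_logs/2025-01-31T13-37-42.123456
```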
182-182: Good addition of explain_features parameter. Adding the explain_features parameter to control SHAP value aggregation provides appropriate conditional behavior for feature explanation functionality.
Actionable comments posted: 3
🔭 Outside diff range comments (2)
icu_benchmarks/run.py (2)
150-154: Include vars.gin in gin_config_files (first), instead of editing the task file. To ensure variables are available when parsing the task, include vars.gin (if requested and present) as the first file in gin_config_files. This keeps all parsing consolidated in one place and avoids double-parsing.
Apply this diff:
- gin_config_files = (
-     [Path(f"configs/experiments/{args.experiment}.gin")]
-     if args.experiment
-     else [model_path, Path(f"configs/tasks/{task}.gin")]
- )
+ # Build config file list. Ensure vars.gin (if present) is parsed first.
+ if args.experiment:
+     gin_config_files = [Path(f"configs/experiments/{args.experiment}.gin")]
+ else:
+     gin_config_files = [model_path, Path(f"configs/tasks/{task}.gin")]
+ if load_data_vars and (data_dir / "vars.gin").exists():
+     gin_config_files.insert(0, data_dir / "vars.gin")
137-143: Bug: Directory name uses fine_tune when branching on samples. This branch checks args.samples but names the directory with args.fine_tune. This can mislabel runs and cause directory collisions.
Apply this diff:
- log_dir /= f"samples_{args.fine_tune}" + log_dir /= f"samples_{args.samples}"
♻️ Duplicate comments (2)
icu_benchmarks/run.py (2)
160-175: Use lazy logging and correct the logged variable; raising is fixed.Good: you now raise ValueError. Still pending: switch logging to lazy formatting and log the parsed file_names when applicable to avoid duplicate parsing and to improve performance.
Apply this diff:
- logging.info(f"Attempting to bind dataset files: {args.file_names}, type: {type(args.file_names)}") + logging.info("Attempting to bind dataset files: %s, type: %s", args.file_names, type(args.file_names)) if isinstance(args.file_names, dict): - logging.info(f"Will load data from {args.file_names}") + logging.info("Will load data from %s", args.file_names) gin.bind_parameter("preprocess.file_names", args.file_names) elif isinstance(args.file_names, str): file_names = parse_dict(args.file_names) - logging.info(f"Will load data from {args.file_names}") + logging.info("Will load data from %s", file_names) gin.bind_parameter("preprocess.file_names", file_names)
147-149: Model path continuation indentation looks corrected.The previous E122 issue appears resolved with consistent visual indentation in the multi-line expression.
🧹 Nitpick comments (2)
icu_benchmarks/run.py (2)
176-180: Cast Path objects to str before updating WandB config.W&B config values should be JSON-serializable; Path objects may not serialize cleanly.
Apply this diff:
- update_wandb_config({ - "data_dir": data_dir.resolve(), - "task": task, - "run_dir": run_dir.resolve(), - }) + update_wandb_config({ + "data_dir": str(data_dir.resolve()), + "task": task, + "run_dir": str(run_dir.resolve()), + })
232-233: Persist fold-wise predictions returned by append_predictions_foldwise.The function returns a Polars DataFrame; currently it’s discarded. Save it to disk to make the aggregation usable downstream.
Apply this diff:
- append_predictions_foldwise(run_dir, "pred_indicators.csv")
+ pred_df = append_predictions_foldwise(run_dir, "pred_indicators.csv")
+ pred_df.write_parquet(run_dir / "pred_indicators_aggregated.parquet")
Add this import at the top of the file if Polars isn't already imported elsewhere:
import polars as pl # at top-level
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
icu_benchmarks/run.py(5 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
icu_benchmarks/run.py (5)
icu_benchmarks/run_utils.py (6)
get_config_files (343-367), append_predictions_foldwise (118-180), aggregate_results (182-260), aggregate_results (101-176), build_parser (21-63), name_datasets (179-181)
icu_benchmarks/utils.py (1)
parse_dict (5-22)
icu_benchmarks/wandb_utils.py (1)
update_wandb_config (15-23)
scripts/evaluate_results/aggregrate_experiment.py (1)
aggregate_results (7-85)
icu_benchmarks/models/train.py (1)
train_common(34-204)
🔇 Additional comments (1)
icu_benchmarks/run.py (1)
23-23: CLI parser flags are correctly defined. I've confirmed in icu_benchmarks/run_utils.py's build_parser that all three flags are present with the intended types and defaults:
--file_names (line 67) uses type=parse_dict
--explain_features (line 73) has default=False, action=BOA
--load_data_vars (line 75) has default=False, action=BOA
No changes required.
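Assuming BOA above is shorthand for argparse.BooleanOptionalAction, the wiring would look roughly like this (a sketch, not the project's build_parser; the parse_dict stand-in here simply reads JSON):

```python
import argparse
import json

def parse_dict(value: str) -> dict:
    # Hypothetical stand-in for the project's parse_dict helper.
    return json.loads(value)

parser = argparse.ArgumentParser()
parser.add_argument("--file_names", type=parse_dict, default=None)
parser.add_argument("--explain_features", default=False, action=argparse.BooleanOptionalAction)
parser.add_argument("--load_data_vars", default=False, action=argparse.BooleanOptionalAction)

args = parser.parse_args(["--explain_features", "--file_names", '{"DYNAMIC": "dyn.parquet"}'])
```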
if load_data_vars:
    logging.info(f"Loading variables from {task} from {data_dir} configuration")
    if (data_dir / "vars.gin").exists():
        # Open the task config file in append mode and add a line
        with open(f"configs/tasks/{task}.gin", "a") as config_file:
            config_file.write("\n# Added automatically by run.py\n")
            config_file.write(f"# Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            config_file.write(f'include "{data_dir}/vars.gin"\n')
    else:
        logging.warning(f"No vars.gin file found in {data_dir}. Please ensure the file exists.")
🛠️ Refactor suggestion
Do not mutate repo configs at runtime; parse vars.gin instead (race condition and drift risk).
Appending an include to configs/tasks/{task}.gin is unsafe:
- Concurrent runs can corrupt the file.
- Repeated runs will keep appending, causing duplicates.
- It pollutes the repo/installed package state.
Parse vars.gin directly instead of editing the task config file.
Apply this diff:
- if load_data_vars:
- logging.info(f"Loading variables from {task} from {data_dir} configuration")
- if (data_dir / "vars.gin").exists():
- # Open the task config file in append mode and add a line
- with open(f"configs/tasks/{task}.gin", "a") as config_file:
- config_file.write("\n# Added automatically by run.py\n")
- config_file.write(f"# Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
- config_file.write(f'include "{data_dir}/vars.gin"\n')
- else:
- logging.warning(f"No vars.gin file found in {data_dir}. Please ensure the file exists.")
+ if load_data_vars:
+ vars_path = data_dir / "vars.gin"
+ logging.info("Loading variables for task %s from %s configuration", task, data_dir)
+ if vars_path.exists():
+ gin.parse_config_file(vars_path)
+ logging.info("Loaded %s", vars_path)
+ else:
+ logging.warning("No vars.gin file found in %s. Please ensure the file exists.", data_dir)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| if load_data_vars: | |
| logging.info(f"Loading variables from {task} from {data_dir} configuration") | |
| if (data_dir / "vars.gin").exists(): | |
| # Open the task config file in append mode and add a line | |
| with open(f"configs/tasks/{task}.gin", "a") as config_file: | |
| config_file.write("\n# Added automatically by run.py\n") | |
| config_file.write(f"# Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") | |
| config_file.write(f'include "{data_dir}/vars.gin"\n') | |
| else: | |
| logging.warning(f"No vars.gin file found in {data_dir}. Please ensure the file exists.") | |
| if load_data_vars: | |
| vars_path = data_dir / "vars.gin" | |
| logging.info("Loading variables for task %s from %s configuration", task, data_dir) | |
| if vars_path.exists(): | |
| gin.parse_config_file(vars_path) | |
| logging.info("Loaded %s", vars_path) | |
| else: | |
| logging.warning("No vars.gin file found in %s. Please ensure the file exists.", data_dir) |
            config_file.write(f'include "{data_dir}/vars.gin"\n')
    else:
        logging.warning(f"No vars.gin file found in {data_dir}. Please ensure the file exists.")
gin.parse_config_file(f"configs/tasks/{task}.gin")
Avoid parsing task config before branching; this double-parses and interferes with load_weights.
You parse configs/tasks/{task}.gin here, and then again later in the normal training path. When evaluating/fine-tuning (load_weights=True), this early parse can conflict with the config loaded from source_dir/train_config.gin.
Remove the unconditional parse here and handle task/var configs only inside the normal training branch.
Apply this diff:
- gin.parse_config_file(f"configs/tasks/{task}.gin")
+ # NOTE: Task config is parsed in the normal training branch below.📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| gin.parse_config_file(f"configs/tasks/{task}.gin") | |
| # NOTE: Task config is parsed in the normal training branch below. |
execution_time = datetime.now() - start_time
log_full_line(f"DURATION: {execution_time}", level=logging.INFO, char="")
try:
    aggregate_results(run_dir, execution_time)
Pass explain_features into aggregate_results to write SHAP aggregates.
aggregate_results has an explain_features flag; without passing it through, SHAP values won’t be aggregated even if generated.
Apply this diff:
- aggregate_results(run_dir, execution_time)
+ aggregate_results(run_dir, execution_time, explain_features=args.explain_features)
+ aggregate_results(run_dir, execution_time, explain_features=args.explain_features)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| aggregate_results(run_dir, execution_time) | |
| - aggregate_results(run_dir, execution_time) | |
| + aggregate_results(run_dir, execution_time, explain_features=args.explain_features) |
Actionable comments posted: 2
♻️ Duplicate comments (6)
icu_benchmarks/data/loader.py (4)
53-54: Duplicate comment: Static dataset mis-alignment risk exists. The sorting of row_indicators without corresponding sorting of features_df creates the same alignment risk mentioned in previous reviews. This issue persists in the static dataset path.
75-75: Duplicate comment: Commented code affects feature filtering.The feature filtering for GROUP and SEQUENCE columns remains commented out, which means these columns will be included in the feature names returned to downstream components. This issue was flagged in previous reviews.
182-199: Duplicate comment: Avoid mutating self.row_indicators in accessor method. The method still mutates self.row_indicators directly on line 198, which creates the same thread-safety and reproducibility concerns flagged in previous reviews. The suggested fix to use a local copy should be applied.
46-47: Fix data alignment issue between features_df and row_indicators. The current implementation sorts row_indicators after already sorting features_df by the same columns. However, this creates a potential misalignment since features_df has the SEQUENCE column dropped on line 45, which could affect the sorting order. Apply this diff to ensure both DataFrames remain aligned:
 self.features_df = data[split][DataSegment.features]
 self.features_df = self.features_df.sort([self.vars["GROUP"], self.vars["SEQUENCE"]])
-self.features_df = self.features_df.drop(self.vars["SEQUENCE"])
-self.row_indicators = self.row_indicators.sort([self.vars["GROUP"], self.vars["SEQUENCE"]])
+self.row_indicators = self.row_indicators.sort([self.vars["GROUP"], self.vars["SEQUENCE"]])
+self.features_df = self.features_df.drop(self.vars["SEQUENCE"])
icu_benchmarks/models/wrappers.py (2)
540-544: Error handling removal needs justification. The try/catch block for _save_model_outputs has been commented out, which means errors will now propagate and potentially crash the test step. This was flagged in previous reviews.
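If the guard is reinstated, a hedged sketch of the shape it might take (the helper and message here are assumptions, not the repository's code):

```python
import logging

def save_outputs_safely(save_fn, *args, **kwargs):
    # Keep the test step alive if writing outputs fails; log instead of crashing.
    try:
        save_fn(*args, **kwargs)
    except OSError as exc:
        logging.warning(f"Could not save model outputs: {exc}")
```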
574-615: Complex conditional logic in the updated log_metrics method. This was flagged in previous reviews for lacking comprehensive unit tests. The method now has two distinct branches based on whether pred_indicators is None, with complex shape matching logic for the non-None case.
🧹 Nitpick comments (2)
icu_benchmarks/data/loader.py (1)
97-104: Improve conditional logic for SEQUENCE handling.The nested if-else structure could be simplified and made more readable. The logic correctly handles both temporal and non-temporal SEQUENCE columns, but the structure is unnecessarily complex.
Apply this diff to simplify the conditional logic:
-if "SEQUENCE" in self.vars: - if self.row_indicators[self.vars["SEQUENCE"]].dtype.is_temporal(): - self.row_indicators = self.row_indicators.with_columns(pl.col(self.vars["SEQUENCE"]).dt.total_hours()) - else: - self.row_indicators = self.row_indicators.with_columns(pl.col(self.vars["SEQUENCE"])) -else: - logging.info("Using static dataset") - self.row_indicators = self.grouping_df[self.vars["GROUP"]].to_frame(self.vars["GROUP"]) +if "SEQUENCE" in self.vars: + sequence_col = pl.col(self.vars["SEQUENCE"]) + if self.row_indicators[self.vars["SEQUENCE"]].dtype.is_temporal(): + sequence_col = sequence_col.dt.total_hours() + self.row_indicators = self.row_indicators.with_columns(sequence_col) +else: + logging.info("Using static dataset") + self.row_indicators = self.grouping_df[self.vars["GROUP"]].to_frame(self.vars["GROUP"])icu_benchmarks/models/wrappers.py (1)
169-169: Remove unnecessary f-string prefix. The f-string has no placeholders and should use a regular string.
- logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / f'row_indicators.csv'}")
+ logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / 'row_indicators.csv'}")
- logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / f'row_indicators.csv'}") + logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / 'row_indicators.csv'}")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
icu_benchmarks/data/loader.py(7 hunks)icu_benchmarks/models/wrappers.py(11 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (2)
icu_benchmarks/data/loader.py (1)
icu_benchmarks/data/constants.py (1)
DataSegment(7-11)
icu_benchmarks/models/wrappers.py (3)
icu_benchmarks/models/utils.py (2)
log_single_metric_to_file(290-315)name(159-160)icu_benchmarks/models/ml_models/xgboost.py (5)
_explain_model(76-82)fit_model(41-62)fit_model(121-154)predict(29-39)predict(109-119)icu_benchmarks/models/custom_metrics.py (1)
confusion_matrix(141-150)
🪛 Ruff (0.12.2)
icu_benchmarks/models/wrappers.py
169-169: f-string without any placeholders
Remove extraneous f prefix
(F541)
🔇 Additional comments (18)
icu_benchmarks/data/loader.py (5)
36-39: SEQUENCE validation properly gated by feature availability.The conditional validation ensures that SEQUENCE type checking only occurs when SEQUENCE is both defined in vars and present in the data columns. This prevents unnecessary errors when SEQUENCE is not needed.
132-132: Extract row indicators properly for current stay.The extraction logic correctly filters row indicators for the current stay_id, ensuring proper alignment with the window data.
143-145: Row indicator padding implementation is correct.The padding logic properly handles row indicators by extending them with pad_value and converting to float32 for consistency with other tensors.
194-195: Data processing logic is correct.The conversion to NumPy and removal of the first column (group column) is appropriately handled. The logic correctly prepares the feature data for ML models.
154-155: PredictionPolarsDataset now returns (data, labels, pad_mask, row_inds) — callers verifiedShort summary: I inspected the repo and confirmed internal callers support the 4-tuple return.
Relevant locations:
- icu_benchmarks/data/loader.py — PredictionPolarsDataset.getitem: returns 4-tuple at line 155:
return from_numpy(data), from_numpy(labels), from_numpy(pad_mask), from_numpy(row_inds)- icu_benchmarks/models/wrappers.py — step_fn explicitly handles both shapes:
- len(element) == 3 branch (line ~376) -> data, labels, indicators
- len(element) == 4 branch (line ~385) -> data, labels, mask, indicators
- icu_benchmarks/data/loader.py — PredictionPandasDataset.getitem still returns 3-tuple (line ~327); wrappers' 3-tuple branch covers this.
- Imputation-related dataset getitem implementations already return 4-tuples where applicable.
Optional suggestion: update PredictionPolarsDataset.getitem type annotation/docstring to reflect the 4-tuple return for clarity (not required for correctness).
icu_benchmarks/models/wrappers.py (13)
20-20: Verify import addition aligns with actual usage.The
log_single_metric_to_fileimport has been added, which aligns with its usage in the newlog_curvesmethods in both ML and DL wrappers.
66-67: LGTM: Useful logging enhancement.The addition of column count logging in
set_trained_columnsprovides helpful debugging information.
109-111: Simple and effective feature toggle method.The
set_explain_featuresmethod provides a clean interface to toggle explainability features.
375-393: Enhanced step function correctly handles indicators parameter.The extension to support 3 or 4 element batches (with indicators and optional mask) is well implemented and maintains backward compatibility.
404-415: Good extraction of prediction data for output saving.The code properly extracts indicators, predictions, and targets using the mask for valid indices, and correctly applies softmax transformation.
427-429: Calling_save_model_outputsin the step function.This call correctly integrates the new output saving functionality into the training loop.
499-501: Conditional explainer computation added.The explainer is now properly gated behind the
explain_featuresflag, which helps avoid expensive computations during hyperparameter tuning.
510-510: Method signature updated to include pred_indicators.The
log_metricscall now includes therow_indicatorsparameter to support the enhanced logging capabilities.
545-547: Explainer logic updated correctly.The explainer computation is now properly conditional on
explain_featuresflag and stores results inexplainer_values_test.
547-547: Newlog_curvesmethod call added.This integrates the curve logging functionality into the test workflow.
550-554: Updatedlog_metricscalls with pred_indicators parameter.Both MPS and standard paths now correctly pass the
pred_indicatorsparameter to enable enhanced metric logging.
562-573: Newlog_curvesmethod implementation.The method correctly handles metrics that return tuples (like precision-recall curves) and saves them to CSV files using the imported
log_single_metric_to_fileutility.
616-622: Commented-out_explain_modelmethod.This method definition is commented out but this doesn't cause the runtime error mentioned in previous reviews since the actual
_explain_modelimplementations are in the specific model classes (as seen in the relevant code snippets).
Actionable comments posted: 10
🧹 Nitpick comments (14)
experiments/charhpc_wandb_sweep_gpu.sh (4)
3-3: Verify GPU request flags for your Slurm version/cluster policy. Some clusters require
--gres=gpu:1 (and matching srun flags) instead of --gpus=1. Using the wrong flag can yield zero GPUs at runtime. Suggested alternative if your site uses GRES:
-#SBATCH --partition=gpu # -p
-#SBATCH --gpus=1
+#SBATCH --partition=gpu # -p
+#SBATCH --gres=gpu:1
If you keep
--gpus=1, prefer matching it in the srun invocation (see later comment) to propagate GPU binding to the job step. Also applies to: 7-7
1-1: Minor: Prefer env-based shebang for portability. This avoids hard-coding /bin path differences across clusters.
-#!/bin/bash
+#!/usr/bin/env bash
9-9: Enable strict bash mode for safer failures. Prevents silent errors and unbound variable usage from derailing long-running jobs.
#SBATCH --time=48:00:00
+
+set -euo pipefail
+IFS=$'\n\t'
10-13: Conda init is redundant; use one method consistently. You’re both sourcing conda.sh and using the shell hook. One is enough; keeping both can be brittle across sites. Option A — rely on conda.sh (common on HPC images):
Option A — rely on conda.sh (common on HPC images):
-source /etc/profile.d/conda.sh
-
-eval "$(conda shell.bash hook)"
-conda activate yaib_req_pl
+eval "$(conda shell.bash hook)"
+conda activate yaib_req_pl
Optional: fail fast if the env is missing.
-source /etc/profile.d/conda.sh - -eval "$(conda shell.bash hook)" -conda activate yaib_req_pl +eval "$(conda shell.bash hook)" +conda activate yaib_req_plOptional: fail fast if the env is missing.
+conda env list | grep -q '^yaib_req_pl\s' || { echo "Conda env 'yaib_req_pl' not found"; exit 3; }icu_benchmarks/models/ml_models/imblearn.py (8)
1-2: Prefer a module-scoped logger over the root logger.Using the root logger (
logging.info,logging.warning) makes it harder to control log levels per module. Define a module logger and use it.Apply:
import logging +logger = logging.getLogger(__name__)And update usages later in this file:
-logging.info(f"XGBoost model trained successfully. Validation loss: {val_loss:.4f}") +logger.info(f"XGBoost model trained successfully. Validation loss: {val_loss:.4f}") -logging.warning(f"XGBoost failed: {e}") +logger.warning(f"XGBoost failed: {e}")
7-7: Remove unused import (Ruff F401).
DecisionTreeClassifierisn’t used in this file.-from sklearn.tree import DecisionTreeClassifier
8-8: Clean up duplicate/unused imports (Ruff F401/F811).
RUSBoostClassifierandBalancedRandomForestClassifierare already imported on Line 3 and are unused here. OnlyEasyEnsembleClassifieris needed.Within this hunk:
-from imblearn.ensemble import RUSBoostClassifier, BalancedRandomForestClassifier, EasyEnsembleClassifier +from imblearn.ensemble import EasyEnsembleClassifierAlso update the earlier import to drop the unused
RUSBoostClassifier:-from imblearn.ensemble import BalancedRandomForestClassifier, RUSBoostClassifier +from imblearn.ensemble import BalancedRandomForestClassifier
21-36: Base estimator defined as a class attribute may surprise; consider constructing per-instance and make it gin-configurable.
- Defining
individual_modelas a class attribute creates a single estimator object shared by all instances. While scikit-learn meta-estimators clone base estimators, sharing a mutable instance at class scope is a common footgun and reduces configurability.- The current hyperparameters are very heavy for an ensemble wrapper (
n_estimators=5000) and combined withEasyEnsembleClassifiercan explode training time/memory. Also, with EasyEnsemble balancing,scale_pos_weight=30likely double-weights the minority class.Suggested refactor: build the XGB estimator in
__init__and allow passing hyperparameters via gin.-@gin.configurable -class XGBEnsembleClassifier(MLWrapper): +@gin.configurable +class XGBEnsembleClassifier(MLWrapper): _supported_run_modes = [RunMode.classification] - individual_model = xgb.XGBClassifier( - learning_rate=0.1, - n_estimators=5000, - max_depth=10, - scale_pos_weight=30, - min_child_weight=1, - max_delta_step=3, - colsample_bytree=0.25, - gamma=0.9, - reg_lambda=0.1, - reg_alpha=100, - random_state=42, - eval_metric='logloss' - ) - def __init__(self, *args, **kwargs): - self.model = self.set_model_args(EasyEnsembleClassifier, *args, **kwargs, estimator=self.individual_model) + def __init__(self, *args, xgb_params=None, **kwargs): + xgb_params = xgb_params or { + "learning_rate": 0.1, + "n_estimators": 1000, # Lower by default; tune via gin + "max_depth": 10, + "min_child_weight": 1, + "max_delta_step": 3, + "colsample_bytree": 0.25, + "gamma": 0.9, + "reg_lambda": 0.1, + "reg_alpha": 100, + "random_state": 42, + "eval_metric": "logloss", + "verbosity": 0, + # Consider dropping scale_pos_weight when using EasyEnsemble + # "scale_pos_weight": 30, + } + individual_model = xgb.XGBClassifier(**xgb_params) + self.model = self.set_model_args(EasyEnsembleClassifier, *args, **kwargs, estimator=individual_model) super().__init__(*args, **kwargs)Verification checklist:
- Confirm we don’t rely on class-level access to
individual_modelelsewhere.- Validate whether
scale_pos_weightis still desired when usingEasyEnsembleClassifier(it already balances the classes). If you keep it, document the intended effect.
38-39: Consider wiring through EasyEnsembleClassifier’s own random_state and n_jobs for reproducibility/throughput.
EasyEnsembleClassifierhasrandom_stateandn_jobs. Using them helps ensure deterministic behavior and better CPU utilization.- self.model = self.set_model_args(EasyEnsembleClassifier, *args, **kwargs, estimator=self.individual_model) + self.model = self.set_model_args( + EasyEnsembleClassifier, *args, **kwargs, + estimator=self.individual_model, + random_state=42, + n_jobs=-1, + )
23-36: Training-time/performance caveat with current hyperparameters.
n_estimators=5000inside an EasyEnsemble can be prohibitively slow and memory-intensive.- With balancing done by EasyEnsemble,
scale_pos_weight=30likely over-penalizes negatives. Consider removing or tuning it based on post-resampling class ratios.- Consider adding
tree_method="hist"(CPU) andn_jobs=-1to speed up XGBoost training.If helpful, I can propose a gin-tunable grid with sensible defaults for
n_estimatorsandscale_pos_weight. Would you like me to open a follow-up PR for that?
7-8: Remove unused imports flagged by Ruff (F401).
- Line 7:
DecisionTreeClassifieris unused.- Line 42: Inner
import xgboost as xgbis unused.Apply:
- import xgboost as xgb(Line 7 fix shown in another comment.)
Also applies to: 41-42
21-22: Naming alignment and docstring suggestion.The class name
XGBEnsembleClassifieris clear, but a short class docstring would help future readers understand that this is an EasyEnsemble over XGBoost base learners and why this is preferred over RUSBoost in this project.Example:
class XGBEnsembleClassifier(MLWrapper): """Imbalanced-class classifier using imblearn's EasyEnsemble with XGBClassifier base learners. Rationale: undersample majority per ensemble member; leverage XGBoost capacity; tune XGB via gin; prefer over RUSBoost due to <project-specific reasons>. """configs/prediction_models/LGBMClassifier.gin (2)
34-37: DART ranges are quite aggressive; consider narrowing to reduce instability.High
drop_rate/skip_dropcan make convergence noisy. Typical starting ranges are tighter.Apply this diff if you observe unstable validation curves with DART:
-model/hyperparameter.drop_rate = (0.1, 0.5) +model/hyperparameter.drop_rate = (0.05, 0.2) @@ -model/hyperparameter.skip_drop = (0.1, 0.9) +model/hyperparameter.skip_drop = (0.0, 0.5)
16-17: Sanity-check upper bounds for min_child_samples and min_child_weight.Including very large
min_child_samples(e.g., 1000) can eliminate splits on small/medium datasets. Keep if your datasets are large; otherwise, narrow to something like(5, 200)and(1e-3, 10.0, "log")respectively.Apply if your typical training datasets are <100k rows:
-model/hyperparameter.min_child_samples = (10, 20, 50, 100, 500, 1000) +model/hyperparameter.min_child_samples = (5, 200)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (7)
configs/prediction_models/LGBMClassifier.gin(1 hunks)docs/source/conf.py(1 hunks)experiments/charhpc_wandb_sweep_cpu.sh(1 hunks)experiments/charhpc_wandb_sweep_gpu.sh(1 hunks)icu_benchmarks/models/ml_models/imblearn.py(2 hunks)icu_benchmarks/models/ml_models/lgbm.py(0 hunks)icu_benchmarks/wandb_utils.py(3 hunks)
💤 Files with no reviewable changes (1)
- icu_benchmarks/models/ml_models/lgbm.py
🚧 Files skipped from review as they are similar to previous changes (3)
- docs/source/conf.py
- experiments/charhpc_wandb_sweep_cpu.sh
- icu_benchmarks/wandb_utils.py
🧰 Additional context used
🧬 Code graph analysis (1)
icu_benchmarks/models/ml_models/imblearn.py (3)
icu_benchmarks/constants.py (1)
RunMode(4-7)icu_benchmarks/models/wrappers.py (5)
MLWrapper(444-652)set_model_args(642-652)fit_model(512-517)fit(487-510)fit(744-745)icu_benchmarks/models/ml_models/xgboost.py (6)
XGBClassifier(21-87)set_model_args(64-74)fit_model(41-62)fit_model(121-154)XGBClassifier(21-74)__init__(25-27)
🪛 Ruff (0.12.2)
icu_benchmarks/models/ml_models/imblearn.py
7-7: sklearn.tree.DecisionTreeClassifier imported but unused
Remove unused import: sklearn.tree.DecisionTreeClassifier
(F401)
8-8: Redefinition of unused RUSBoostClassifier from line 3
Remove definition: RUSBoostClassifier
(F811)
8-8: imblearn.ensemble.RUSBoostClassifier imported but unused
Remove unused import: imblearn.ensemble.RUSBoostClassifier
(F401)
8-8: Redefinition of unused BalancedRandomForestClassifier from line 3
Remove definition: BalancedRandomForestClassifier
(F811)
42-42: xgboost imported but unused
Remove unused import: xgboost
(F401)
🔇 Additional comments (3)
icu_benchmarks/models/ml_models/imblearn.py (1)
9-9: LGTM: xgboost import is used below.The module-level
xgboostimport is required for defining the base estimator.configs/prediction_models/LGBMClassifier.gin (2)
9-11: LGTM: Explicitly scoping the tuner target.Declaring
model/hyperparameter.class_to_tune = @LGBMClassifieris clear and aligns withtrain_common.model = @LGBMClassifier.
13-17: Guard cross-parameter validity: num_leaves vs max_depth.Some sampled combos (e.g., shallow
max_depthwith very largenum_leaves) are wasteful or incoherent. If your tuner supports constraints, enforcenum_leaves <= 2**max(1, max_depth)whenmax_depth > 0. Otherwise, consider cappingnum_leavesto 2^(max_depth) in post-processing before estimator instantiation.If your tuner doesn’t support relational constraints, I can draft a small adapter to clamp
num_leavesbased on the sampledmax_depth. Want me to open an issue and provide a PR?
model/hyperparameter.n_estimators = (500, 2000, 5000, 10000, 100000)
model/hyperparameter.max_depth = (3, 5, 7, 10, 15)
model/hyperparameter.num_leaves = (8, 16, 31, 64, 128, 256, "log", 2)
model/hyperparameter.min_child_samples = (10, 20, 50, 100, 500, 1000)
model/hyperparameter.min_child_weight = (1e-3, 10.0, "log")
🛠️ Refactor suggestion
Fix malformed num_leaves spec and consider simplifying n_estimators.
- The num_leaves tuple mixes discrete values with "log", 2. This is likely to break your parser (or yield invalid values). Prefer either a discrete list or a proper range tuple.
- Optional: Given early stopping is enabled below, a very large discrete set for n_estimators adds search noise/cost. A log-range or a single high ceiling with early stopping is more efficient.
Apply this diff to make both entries unambiguous:
-model/hyperparameter.n_estimators = (500, 2000, 5000, 10000, 100000)
+model/hyperparameter.n_estimators = (500, 20000, "log")
@@
-model/hyperparameter.num_leaves = (8, 16, 31, 64, 128, 256, "log", 2)
+model/hyperparameter.num_leaves = (8, 256, "log")
Notes:
- If you prefer discrete leaves, use e.g.
[15, 31, 63, 127, 255]instead. - Consider adding
max_depth = [-1, 3, 5, 7, 10, 15]so the unconstrained depth case is explored.
📝 Committable suggestion
| model/hyperparameter.n_estimators = (500, 2000, 5000, 10000, 100000) | |
| model/hyperparameter.max_depth = (3, 5, 7, 10, 15) | |
| model/hyperparameter.num_leaves = (8, 16, 31, 64, 128, 256, "log", 2) | |
| model/hyperparameter.min_child_samples = (10, 20, 50, 100, 500, 1000) | |
| model/hyperparameter.min_child_weight = (1e-3, 10.0, "log") | |
| # configs/prediction_models/LGBMClassifier.gin | |
| model/hyperparameter.n_estimators = (500, 20000, "log") | |
| model/hyperparameter.max_depth = (3, 5, 7, 10, 15) | |
| model/hyperparameter.num_leaves = (8, 256, "log") | |
| model/hyperparameter.min_child_samples = (10, 20, 50, 100, 500, 1000) | |
| model/hyperparameter.min_child_weight = (1e-3, 10.0, "log") |
🤖 Prompt for AI Agents
In configs/prediction_models/LGBMClassifier.gin around lines 13 to 18, the
num_leaves hyperparameter is malformed (mixes discrete values with `"log", 2`)
and n_estimators is an overly large discrete set; fix num_leaves to be a clear
discrete list (e.g. 15, 31, 63, 127, 255) or a valid range spec so the parser
won’t break, and simplify n_estimators to either a log-range (e.g. 100, 500,
1000, 5000) or a single high ceiling relying on early stopping to control
training; optionally add max_depth = (-1, 3, 5, 7, 10, 15) so the unconstrained
depth case is included.
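If the tuner cannot express relational constraints, the clamping adapter offered above could look roughly like this. It is a sketch only: it assumes sampled hyperparameters arrive as a plain dict before the estimator is built, and none of it is existing project code.

def clamp_lgbm_params(params: dict) -> dict:
    """Post-process sampled LightGBM hyperparameters so num_leaves stays
    consistent with max_depth (num_leaves <= 2**max_depth when depth is bounded)."""
    out = dict(params)
    max_depth = out.get("max_depth", -1)
    num_leaves = out.get("num_leaves")
    if num_leaves is not None and max_depth and max_depth > 0:
        out["num_leaves"] = min(num_leaves, 2 ** max(1, max_depth))
    return out

print(clamp_lgbm_params({"max_depth": 3, "num_leaves": 256}))   # num_leaves -> 8
print(clamp_lgbm_params({"max_depth": -1, "num_leaves": 256}))  # unbounded depth, unchanged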
# Sampling parameters
model/hyperparameter.subsample = (0.4, 1.0)
model/hyperparameter.subsample_freq = (0, 1, 5, 10)
model/hyperparameter.colsample_bytree = (0.4, 1.0)
model/hyperparameter.colsample_bynode = (0.4, 1.0)
🛠️ Refactor suggestion
Remove LightGBM/XGBoost alias duplication and an invalid parameter name.
- subsample/subsample_freq are LightGBM aliases of bagging_fraction/bagging_freq. Keeping both creates conflicting grids and harder-to-debug overrides.
- colsample_bytree is an alias of feature_fraction in LightGBM; again, duplicate knobs.
- colsample_bynode is not a LightGBM parameter (it is an XGBoost-ism). If you want per-node feature sampling, use feature_fraction_bynode.
Apply this diff to remove the aliases here and manage sampling via the canonical keys in the Performance block (Lines 42-47):
-# Sampling parameters
-model/hyperparameter.subsample = (0.4, 1.0)
-model/hyperparameter.subsample_freq = (0, 1, 5, 10)
-model/hyperparameter.colsample_bytree = (0.4, 1.0)
-model/hyperparameter.colsample_bynode = (0.4, 1.0)
Then, add the per-node variant where the canonical feature/bagging params are declared (see the suggested change on Lines 42-47).
📝 Committable suggestion
| # Sampling parameters | |
| model/hyperparameter.subsample = (0.4, 1.0) | |
| model/hyperparameter.subsample_freq = (0, 1, 5, 10) | |
| model/hyperparameter.colsample_bytree = (0.4, 1.0) | |
| model/hyperparameter.colsample_bynode = (0.4, 1.0) |
🤖 Prompt for AI Agents
In configs/prediction_models/LGBMClassifier.gin around lines 24 to 29, remove
the LightGBM/XGBoost alias and invalid parameters — delete
model/hyperparameter.subsample, model/hyperparameter.subsample_freq,
model/hyperparameter.colsample_bytree and model/hyperparameter.colsample_bynode
— and instead manage sampling via the canonical LightGBM keys in the Performance
block (lines 42-47); then add the per-node variant
model/hyperparameter.feature_fraction_bynode alongside feature_fraction and
bagging_fraction in that Performance block so per-node sampling is available
under the canonical LightGBM name.
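The alias cleanup can also be enforced mechanically right before the estimator is constructed. The sketch below is illustrative, not project code, and the mapping covers only the aliases named in this comment.

# LightGBM alias -> canonical parameter name (subset relevant to this config).
LGBM_ALIASES = {
    "subsample": "bagging_fraction",
    "subsample_freq": "bagging_freq",
    "colsample_bytree": "feature_fraction",
}

def canonicalize(params: dict) -> dict:
    """Rewrite alias keys to canonical LightGBM names; if both an alias and its
    canonical key are present, the canonical key wins."""
    out = {}
    # First pass: copy keys that are already canonical.
    for key, value in params.items():
        if key not in LGBM_ALIASES:
            out[key] = value
    # Second pass: fill in aliases only where the canonical key is missing.
    for alias, canonical in LGBM_ALIASES.items():
        if alias in params and canonical not in out:
            out[canonical] = params[alias]
    return out

print(canonicalize({"subsample": 0.8, "bagging_fraction": 0.9, "learning_rate": 0.05}))
# -> {'bagging_fraction': 0.9, 'learning_rate': 0.05}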
# Boosting parameters
model/hyperparameter.boosting_type = ["gbdt", "dart"]

# Advanced DART parameters (active when boosting_type="dart")
model/hyperparameter.drop_rate = (0.1, 0.5)
model/hyperparameter.max_drop = (10, 50)
model/hyperparameter.skip_drop = (0.1, 0.9)

# GOSS parameters (active when boosting_type="goss")
model/hyperparameter.top_rate = (0.1, 0.5)
model/hyperparameter.other_rate = (0.05, 0.2)
🛠️ Refactor suggestion
GOSS parameters present without enabling GOSS; make the search space consistent.
You’re tuning top_rate/other_rate for GOSS but boosting_type doesn’t include "goss". Either add "goss" to boosting_type or drop those knobs to avoid dead/ignored params.
Apply one of the following:
Option A — enable GOSS:
-model/hyperparameter.boosting_type = ["gbdt", "dart"]
+model/hyperparameter.boosting_type = ["gbdt", "dart", "goss"]Option B — keep only GBDT/DART (remove GOSS-only params):
-# GOSS parameters (active when boosting_type="goss")
-model/hyperparameter.top_rate = (0.1, 0.5)
-model/hyperparameter.other_rate = (0.05, 0.2)Note: When using GOSS, ensure top_rate + other_rate < 1. If your tuner can’t enforce relational constraints, narrow ranges to, e.g., top_rate=(0.1,0.3) and other_rate=(0.05,0.15).
📝 Committable suggestion
| # Boosting parameters | |
| model/hyperparameter.boosting_type = ["gbdt", "dart"] | |
| # Advanced DART parameters (active when boosting_type="dart") | |
| model/hyperparameter.drop_rate = (0.1, 0.5) | |
| model/hyperparameter.max_drop = (10, 50) | |
| model/hyperparameter.skip_drop = (0.1, 0.9) | |
| # GOSS parameters (active when boosting_type="goss") | |
| model/hyperparameter.top_rate = (0.1, 0.5) | |
| model/hyperparameter.other_rate = (0.05, 0.2) | |
| # Boosting parameters | |
| model/hyperparameter.boosting_type = ["gbdt", "dart", "goss"] | |
| # Advanced DART parameters (active when boosting_type="dart") | |
| model/hyperparameter.drop_rate = (0.1, 0.5) | |
| model/hyperparameter.max_drop = (10, 50) | |
| model/hyperparameter.skip_drop = (0.1, 0.9) | |
| # GOSS parameters (active when boosting_type="goss") | |
| model/hyperparameter.top_rate = (0.1, 0.5) | |
| model/hyperparameter.other_rate = (0.05, 0.2) |
| # Boosting parameters | |
| model/hyperparameter.boosting_type = ["gbdt", "dart"] | |
| # Advanced DART parameters (active when boosting_type="dart") | |
| model/hyperparameter.drop_rate = (0.1, 0.5) | |
| model/hyperparameter.max_drop = (10, 50) | |
| model/hyperparameter.skip_drop = (0.1, 0.9) | |
| # GOSS parameters (active when boosting_type="goss") | |
| model/hyperparameter.top_rate = (0.1, 0.5) | |
| model/hyperparameter.other_rate = (0.05, 0.2) | |
| # Boosting parameters | |
| model/hyperparameter.boosting_type = ["gbdt", "dart"] | |
| # Advanced DART parameters (active when boosting_type="dart") | |
| model/hyperparameter.drop_rate = (0.1, 0.5) | |
| model/hyperparameter.max_drop = (10, 50) | |
| model/hyperparameter.skip_drop = (0.1, 0.9) |
🤖 Prompt for AI Agents
In configs/prediction_models/LGBMClassifier.gin around lines 30 to 41, the file
defines GOSS-specific hyperparameters (top_rate, other_rate) but the
boosting_type list does not include "goss", so those knobs will be ignored; fix
by either (A) add "goss" to model/hyperparameter.boosting_type so the GOSS
params become active and, if enabling GOSS, ensure top_rate + other_rate < 1 (or
tighten ranges, e.g., top_rate=(0.1,0.3) and other_rate=(0.05,0.15) if the tuner
can’t enforce relational constraints), or (B) remove the GOSS-only params
(top_rate and other_rate) to keep the search space consistent with only "gbdt"
and "dart".
# Performance and stability
model/hyperparameter.feature_fraction = (0.4, 1.0)
model/hyperparameter.bagging_fraction = (0.4, 1.0)
model/hyperparameter.bagging_freq = (0, 1, 5, 10)
model/hyperparameter.min_split_gain = (1e-6, 1.0, "log")
🛠️ Refactor suggestion
Consolidate sampling on canonical LightGBM params and add by-node feature sampling.
Given aliases were removed above, keep tuning here via canonical names. If you want per-node feature sampling, add feature_fraction_bynode.
Apply this diff to extend the canonical block:
model/hyperparameter.feature_fraction = (0.4, 1.0)
model/hyperparameter.bagging_fraction = (0.4, 1.0)
model/hyperparameter.bagging_freq = (0, 1, 5, 10)
model/hyperparameter.min_split_gain = (1e-6, 1.0, "log")
+model/hyperparameter.feature_fraction_bynode = (0.4, 1.0)
📝 Committable suggestion
| # Performance and stability | |
| model/hyperparameter.feature_fraction = (0.4, 1.0) | |
| model/hyperparameter.bagging_fraction = (0.4, 1.0) | |
| model/hyperparameter.bagging_freq = (0, 1, 5, 10) | |
| model/hyperparameter.min_split_gain = (1e-6, 1.0, "log") | |
| # Performance and stability | |
| model/hyperparameter.feature_fraction = (0.4, 1.0) | |
| model/hyperparameter.bagging_fraction = (0.4, 1.0) | |
| model/hyperparameter.bagging_freq = (0, 1, 5, 10) | |
| model/hyperparameter.min_split_gain = (1e-6, 1.0, "log") | |
| model/hyperparameter.feature_fraction_bynode = (0.4, 1.0) |
🤖 Prompt for AI Agents
In configs/prediction_models/LGBMClassifier.gin around lines 42 to 47, the
tuning block uses LightGBM aliases but should use canonical parameter names and
include per-node feature sampling; replace/extend the current entries so they
use canonical names (feature_fraction, bagging_fraction, bagging_freq,
min_split_gain) and add a new feature_fraction_bynode hyperparameter with an
appropriate search range (e.g., (0.4, 1.0)) to enable by-node feature sampling;
ensure the value formats match the file's existing hyperparameter tuple
conventions.
# Early stopping and validation
model/hyperparameter.early_stopping_rounds = 100
model/hyperparameter.eval_metric = ["binary_logloss", "auc", "binary_error"]
💡 Verification agent
🧩 Analysis chain
Confirm fit-time plumbing for early stopping and evaluation metric.
early_stopping_rounds and eval_metric are fit-time arguments in the LGBMClassifier sklearn API. If your trainer only instantiates the estimator and never passes these to fit(), these settings won’t take effect. Please confirm they’re forwarded at training time.
Run the following script to verify that your training pipeline passes these through to LightGBM’s fit() (or wraps them in callbacks):
Optional: For stability, consider dropping binary_error from the metric grid for early-stopping-driven tuning and rely on auc or binary_logloss.
Apply if desired:
-model/hyperparameter.eval_metric = ["binary_logloss", "auc", "binary_error"]
+model/hyperparameter.eval_metric = ["binary_logloss", "auc"]🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Inspect how tuning params are materialized and passed to LGBM
rg -n -C3 -e 'early_stopping_rounds|eval_metric|class_to_tune|LGBMClassifier|fit\('
# Common places:
# - Tuning utilities that map gin "model/hyperparameter.*" into estimator kwargs
# - Trainer code that calls estimator.fit(..., eval_set=..., eval_metric=..., early_stopping_rounds=...)Length of output: 45667
🏁 Script executed:
#!/bin/bash
# Display the beginning of the LGBM wrapper to inspect how fit() is implemented
sed -n '1,200p' icu_benchmarks/models/ml_models/lgbm.pyLength of output: 1753
🏁 Script executed:
#!/bin/bash
# Locate the MLWrapper class to inspect how hyperparameters are mapped
rg -C5 -n 'class MLWrapper' icu_benchmarks/models/wrappers.pyLength of output: 485
🏁 Script executed:
#!/bin/bash
# Inspect how model hyperparameters (including eval_metric, early_stopping_rounds) are applied during model instantiation
rg -C5 -n 'def set_model_args' icu_benchmarks/models/wrappers.pyLength of output: 654
🏁 Script executed:
#!/bin/bash
python3 - << 'EOF'
import lightgbm as lgb
import inspect
print(inspect.signature(lgb.LGBMClassifier.__init__))
EOFLength of output: 132
🏁 Script executed:
#!/bin/bash
sed -n '630,730p' icu_benchmarks/models/wrappers.pyLength of output: 4091
Update LGBMWrapper.fit_model to wire through Gin-configured early stopping and eval metrics
Your current implementation in icu_benchmarks/models/ml_models/lgbm.py hard-codes:
- an early-stopping callback using
self.hparams.patience (not your Gin model/hyperparameter.early_stopping_rounds)
- a generic log_evaluation callback (ignoring your Gin
model/hyperparameter.eval_metric)
As a result, neither early_stopping_rounds nor eval_metric from your Gin config ever reach the LightGBM fit() call. You’ll need to refactor the fit_model method to read from self.hparams and pass those values into the model, for example:
--- a/icu_benchmarks/models/ml_models/lgbm.py
+++ b/icu_benchmarks/models/ml_models/lgbm.py
@@ class LGBMWrapper(MLWrapper):
- # Early stopping and logging (hard-coded)
- callbacks = [
- lgbm.early_stopping(self.hparams.patience, verbose=True),
- lgbm.log_evaluation(period=-1)
- ]
-
- self.model = self.model.fit(
- train_data,
- train_labels,
- eval_set=(val_data, val_labels),
- callbacks=callbacks,
- )
+ # Wire through Gin hyperparameters for early stopping and evaluation metrics
+ callbacks = []
+ # Use Gin’s early_stopping_rounds if set
+ if getattr(self.hparams, "early_stopping_rounds", None):
+ callbacks.append(lgbm.early_stopping(self.hparams.early_stopping_rounds, verbose=True))
+
+ # Override the eval_metric on the model if configured
+ if getattr(self.hparams, "eval_metric", None):
+ # scikit-learn wrapper will pass this metric to the Booster
+ self.model.set_params(eval_metric=self.hparams.eval_metric)
+
+ self.model = self.model.fit(
+ train_data,
+ train_labels,
+ eval_set=(val_data, val_labels),
+ callbacks=callbacks,
+ )This change ensures that:
- model/hyperparameter.early_stopping_rounds drives your early-stopping callback
- model/hyperparameter.eval_metric is passed to LightGBM and used for evaluation logging and stopping
References:
LightGBM scikit-learn API expects callbacks for early stopping and log evaluation; early_stopping_rounds as a fit arg is deprecated in favor of early_stopping() callback (blog.csdn.net).
The wrapper’s fit(...) signature supports eval_metric and custom callbacks for metric tracking (lightgbm.readthedocs.io).
📝 Committable suggestion
| # Early stopping and validation | |
| model/hyperparameter.early_stopping_rounds = 100 | |
| model/hyperparameter.eval_metric = ["binary_logloss", "auc", "binary_error"] | |
| # Wire through Gin hyperparameters for early stopping and evaluation metrics | |
| callbacks = [] | |
| # Use Gin’s early_stopping_rounds if set | |
| if getattr(self.hparams, "early_stopping_rounds", None): | |
| callbacks.append(lgbm.early_stopping(self.hparams.early_stopping_rounds, verbose=True)) | |
| # Override the eval_metric on the model if configured | |
| if getattr(self.hparams, "eval_metric", None): | |
| # scikit-learn wrapper will pass this metric to the Booster | |
| self.model.set_params(eval_metric=self.hparams.eval_metric) | |
| self.model = self.model.fit( | |
| train_data, | |
| train_labels, | |
| eval_set=(val_data, val_labels), | |
| callbacks=callbacks, | |
| ) |
🤖 Prompt for AI Agents
In icu_benchmarks/models/ml_models/lgbm.py around the fit_model method (where
callbacks and eval metrics are assembled), the code currently hardcodes an
early-stopping callback using self.hparams.patience and always uses a generic
log_evaluation callback, so Gin-configured
model/hyperparameter.early_stopping_rounds and model/hyperparameter.eval_metric
never reach LightGBM; update fit_model to read early_stopping_rounds and
eval_metric from self.hparams, construct callbacks using
lightgbm.callback.early_stopping(self.hparams.early_stopping_rounds) and a
log_evaluation configured appropriately (or omit if not desired), and pass
eval_metric=self.hparams.eval_metric into the LGBMClassifier.fit(...) call along
with the validation set (eval_set/eval_names) and the assembled callbacks so
LightGBM uses the Gin-provided values for evaluation and stopping.
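For reference, the fit-time wiring this verification is looking for amounts to passing the metric list and an early-stopping callback into LGBMClassifier.fit. The snippet below uses only public LightGBM and scikit-learn APIs on synthetic data; the numeric values are placeholders, not this project's configuration.

import lightgbm as lgbm
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=2000, n_features=20, random_state=0)
X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.25, random_state=0)

model = lgbm.LGBMClassifier(n_estimators=5000, learning_rate=0.05)

model.fit(
    X_tr,
    y_tr,
    eval_set=[(X_val, y_val)],
    eval_metric=["binary_logloss", "auc"],          # mirrors the gin eval_metric grid
    callbacks=[
        lgbm.early_stopping(stopping_rounds=100),   # mirrors early_stopping_rounds = 100
        lgbm.log_evaluation(period=0),              # silence per-iteration logs
    ],
)
print("best iteration:", model.best_iteration_)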
# Class imbalance handling
model/hyperparameter.is_unbalance = [True, False]
model/hyperparameter.scale_pos_weight = (0.1, 10.0, "log")
🛠️ Refactor suggestion
Avoid combining is_unbalance with scale_pos_weight; prefer one strategy.
LightGBM treats these as alternative approaches to class imbalance. Searching both simultaneously can produce ambiguous or counteracting settings. Prefer tuning scale_pos_weight and fixing is_unbalance=False, or vice versa.
Apply this diff to prefer scale_pos_weight only:
-# Class imbalance handling
-model/hyperparameter.is_unbalance = [True, False]
-model/hyperparameter.scale_pos_weight = (0.1, 10.0, "log")
+# Class imbalance handling (prefer explicit weighting)
+model/hyperparameter.is_unbalance = [False]
+model/hyperparameter.scale_pos_weight = (0.1, 10.0, "log")
If you want to compare strategies, introduce a higher-level toggle (e.g., imbalance_strategy in {"none","is_unbalance","scale_pos_weight"}) and conditionally set params accordingly to avoid invalid combos.
📝 Committable suggestion
| # Class imbalance handling | |
| model/hyperparameter.is_unbalance = [True, False] | |
| model/hyperparameter.scale_pos_weight = (0.1, 10.0, "log") | |
| # Class imbalance handling (prefer explicit weighting) | |
| model/hyperparameter.is_unbalance = [False] | |
| model/hyperparameter.scale_pos_weight = (0.1, 10.0, "log") |
🤖 Prompt for AI Agents
In configs/prediction_models/LGBMClassifier.gin around lines 56 to 59, the
current config exposes both model/hyperparameter.is_unbalance and
model/hyperparameter.scale_pos_weight which are mutually exclusive; update the
file to prefer scale_pos_weight only by fixing model/hyperparameter.is_unbalance
to False and keeping model/hyperparameter.scale_pos_weight as the tunable range
(0.1, 10.0, "log"); alternatively, if you need to compare strategies, introduce
a single high-level toggle (e.g., imbalance_strategy in
{"none","is_unbalance","scale_pos_weight"}) and conditionally set either
is_unbalance or scale_pos_weight based on that toggle to avoid invalid
combinations.
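If the higher-level toggle route is taken, resolving it into concrete LightGBM kwargs is a few lines. The imbalance_strategy name and the helper below are hypothetical, shown only to illustrate keeping the two mechanisms mutually exclusive.

def resolve_imbalance(strategy: str, pos_weight: float = 1.0) -> dict:
    """Map a single imbalance_strategy choice onto mutually exclusive
    LightGBM parameters so is_unbalance and scale_pos_weight never clash."""
    if strategy == "none":
        return {}
    if strategy == "is_unbalance":
        return {"is_unbalance": True}
    if strategy == "scale_pos_weight":
        return {"is_unbalance": False, "scale_pos_weight": pos_weight}
    raise ValueError(f"Unknown imbalance strategy: {strategy}")

print(resolve_imbalance("scale_pos_weight", pos_weight=4.2))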
#SBATCH --partition=gpu # -p
#SBATCH --cpus-per-task=16 # -c
#SBATCH --mem=200gb
#SBATCH --output=logs/classification_%a_%j.log # %j is job id
Fix SBATCH output path: logs/ directory may not exist and %a is unused.
- If the logs/ directory does not pre-exist at submission time, Slurm will fail to open the output file and your job won’t start. Also,
%a resolves to 0 for non-array jobs, which is likely unintended noise.
Apply one of the following:
Option A — drop the subdir and unused array token:
-#SBATCH --output=logs/classification_%a_%j.log # %j is job id
+#SBATCH --output=classification_%j.log # %j is job idOption B — keep the subdir but remove %a, and ensure the directory exists before submission:
-#SBATCH --output=logs/classification_%a_%j.log # %j is job id
+#SBATCH --output=logs/classification_%j.log # %j is job idAnd document/ensure: mkdir -p logs before running sbatch.
🤖 Prompt for AI Agents
In experiments/charhpc_wandb_sweep_gpu.sh around line 6, the SBATCH output path
uses a non-existent logs/ directory and the unused %a array token; change the
output directive to remove %a (e.g. --output=logs/classification_%j.log or
--output=classification_%j.log) and ensure the logs directory exists before
submission by running mkdir -p logs (or drop the subdir and write to the current
dir if you don’t want to create directories).
| eval "$(conda shell.bash hook)" | ||
| conda activate yaib_req_pl | ||
| wandb agent --count 1 cassandra_hpi/cassandra/"$1" No newline at end of file |
🛠️ Refactor suggestion
Validate required arg and run via srun to inherit CPU/GPU binding.
- Without an argument, the W&B agent path becomes invalid and the job will fail.
- Launching the agent with
srunensures CUDA-visible device assignment and CPU binding honor the job allocation across Slurm versions.
-wandb agent --count 1 cassandra_hpi/cassandra/"$1"
+[[ -n "${1:-}" ]] || { echo "Usage: $0 <SWEEP_ID>"; exit 2; }
+# If cluster uses --gpus=1, mirror that here; if it uses GRES, use --gres=gpu:1 instead.
+# CPU binding can be adjusted per site (e.g., --cpu-bind=cores).
+srun --cpus-per-task="${SLURM_CPUS_PER_TASK:-16}" --gpus="${SLURM_GPUS:-1}" \
+ wandb agent --count 1 "cassandra_hpi/cassandra/${1}"Optional, for better provenance in logs:
+echo "Job: ${SLURM_JOB_ID:-na} | Node: ${SLURM_NODELIST:-na} | GPUs: ${SLURM_GPUS:-na} | CPUs/Task: ${SLURM_CPUS_PER_TASK:-na}"
+command -v nvidia-smi >/dev/null 2>&1 && nvidia-smi || true
+conda info && conda list | grep -E '^(pytorch|torch|cuda|cudnn|wandb)\s' || true📝 Committable suggestion
| wandb agent --count 1 cassandra_hpi/cassandra/"$1" | |
| [[ -n "${1:-}" ]] || { echo "Usage: $0 <SWEEP_ID>"; exit 2; } | |
| # If cluster uses --gpus=1, mirror that here; if it uses GRES, use --gres=gpu:1 instead. | |
| # CPU binding can be adjusted per site (e.g., --cpu-bind=cores). | |
| srun --cpus-per-task="${SLURM_CPUS_PER_TASK:-16}" --gpus="${SLURM_GPUS:-1}" \ | |
| wandb agent --count 1 "cassandra_hpi/cassandra/${1}" | |
| # Optional provenance/debug logging | |
| echo "Job: ${SLURM_JOB_ID:-na} | Node: ${SLURM_NODELIST:-na} | GPUs: ${SLURM_GPUS:-na} | CPUs/Task: ${SLURM_CPUS_PER_TASK:-na}" | |
| command -v nvidia-smi >/dev/null 2>&1 && nvidia-smi || true | |
| conda info && conda list | grep -E '^(pytorch|torch|cuda|cudnn|wandb)\s' || true |
class XGBEnsembleClassifier(MLWrapper):
    _supported_run_modes = [RunMode.classification]

    individual_model = xgb.XGBClassifier(
        learning_rate=0.1,
        n_estimators=5000,
        max_depth=10,
        scale_pos_weight=30,
        min_child_weight=1,
        max_delta_step=3,
        colsample_bytree=0.25,
        gamma=0.9,
        reg_lambda=0.1,
        reg_alpha=100,
        random_state=42,
        eval_metric='logloss'
    )

    def __init__(self, *args, **kwargs):
        self.model = self.set_model_args(RUSBoostClassifier, *args, **kwargs)
        self.model = self.set_model_args(EasyEnsembleClassifier, *args, **kwargs, estimator=self.individual_model)
        super().__init__(*args, **kwargs)
💡 Verification agent
🧩 Analysis chain
Check for references to removed/renamed classifiers.
If RUSBClassifier was used in configs/scripts, renaming to XGBEnsembleClassifier may break gin configs or imports.
Run the following to find lingering references:
🏁 Script executed:
#!/bin/bash
# Find references to the old class and to RUSBoost usage in the repo
rg -nP -C2 '\bRUSBClassifier\b|\bRUSBoostClassifier\b'Length of output: 3543
Update Remaining RUSBClassifier References
The renaming of RUSBClassifier to XGBEnsembleClassifier has left several broken imports and config entries. Please update or remove all occurrences to ensure the new class is correctly wired up.
• icu_benchmarks/models/init.py
– Replace the RUSBClassifier import and its entry in __all__ (around line 7 and the list at lines 34–62) with XGBEnsembleClassifier.
• configs/prediction_models/RUSBClassifier.gin
– Rename this file (and its references) to match XGBEnsembleClassifier.gin.
– Update train_common.model = @RUSBClassifier and model/hyperparameter.class_to_tune = @RUSBClassifier to use @XGBEnsembleClassifier.
• experiments/benchmark_*.yml
– Although currently commented out, there are lines referencing RUSBClassifier in several experiment YAMLs. Either remove or update these for consistency.
Addressing these spots is required to prevent import errors and broken gin configurations.
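Independent of the renaming cleanup, the ensemble itself can be constructed per instance so no mutable estimator lives at class scope. The sketch below uses public imblearn and xgboost APIs with toy hyperparameters and synthetic data; it is not the project's wrapper plumbing.

from imblearn.ensemble import EasyEnsembleClassifier
from sklearn.datasets import make_classification
from xgboost import XGBClassifier

X, y = make_classification(
    n_samples=3000, n_features=15, weights=[0.95, 0.05], random_state=0
)

# Build the base learner inside the constructor call so every ensemble gets
# its own estimator object; EasyEnsemble already balances each bag by
# undersampling, so no scale_pos_weight is set here.
base = XGBClassifier(
    n_estimators=300, max_depth=6, learning_rate=0.1,
    eval_metric="logloss", random_state=42,
)
ensemble = EasyEnsembleClassifier(
    estimator=base, n_estimators=10, random_state=42, n_jobs=-1
)
ensemble.fit(X, y)
print(ensemble.predict_proba(X[:3]))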
def fit_model(self, train_data, train_labels, val_data, val_labels):
    import xgboost as xgb
    from sklearn.metrics import log_loss

    # Try XGBoost first with provided config
    try:
        self.model.fit(train_data, train_labels,)

        val_pred_proba = self.model.predict_proba(val_data)
        val_loss = log_loss(val_labels, val_pred_proba)
        logging.info(f"XGBoost model trained successfully. Validation loss: {val_loss:.4f}")
        return val_loss

    except Exception as e:
        logging.warning(f"XGBoost failed: {e}")

    return val_loss
Bug: val_loss may be referenced before assignment on failure; also prefer self.loss/self.predict and drop unused imports.
If self.model.fit raises, the function falls through to return val_loss without it being defined, causing UnboundLocalError. In-function imports are unnecessary, and using self.loss/self.predict preserves consistency with the wrapper’s configured loss.
Apply:
- def fit_model(self, train_data, train_labels, val_data, val_labels):
- import xgboost as xgb
- from sklearn.metrics import log_loss
-
- # Try XGBoost first with provided config
- try:
- self.model.fit(train_data, train_labels,)
-
- val_pred_proba = self.model.predict_proba(val_data)
- val_loss = log_loss(val_labels, val_pred_proba)
- logging.info(f"XGBoost model trained successfully. Validation loss: {val_loss:.4f}")
- return val_loss
-
- except Exception as e:
- logging.warning(f"XGBoost failed: {e}")
-
- return val_loss
+ def fit_model(self, train_data, train_labels, val_data, val_labels):
+ # Ensure a defined loss even if training fails
+ val_loss = float("inf")
+ try:
+ self.model.fit(train_data, train_labels)
+ # Use the wrapper's predict() and configured loss for consistency
+ val_pred = self.predict(val_data)
+ val_loss = self.loss(val_labels, val_pred)
+ logging.info(f"XGBoost model trained successfully. Validation loss: {val_loss:.4f}")
+ except Exception:
+ logging.exception("XGBoost training failed")
+ return val_loss📝 Committable suggestion
| def fit_model(self, train_data, train_labels, val_data, val_labels): | |
| import xgboost as xgb | |
| from sklearn.metrics import log_loss | |
| # Try XGBoost first with provided config | |
| try: | |
| self.model.fit(train_data, train_labels,) | |
| val_pred_proba = self.model.predict_proba(val_data) | |
| val_loss = log_loss(val_labels, val_pred_proba) | |
| logging.info(f"XGBoost model trained successfully. Validation loss: {val_loss:.4f}") | |
| return val_loss | |
| except Exception as e: | |
| logging.warning(f"XGBoost failed: {e}") | |
| return val_loss | |
| def fit_model(self, train_data, train_labels, val_data, val_labels): | |
| # Ensure a defined loss even if training fails | |
| val_loss = float("inf") | |
| try: | |
| self.model.fit(train_data, train_labels) | |
| # Use the wrapper's predict() and configured loss for consistency | |
| val_pred = self.predict(val_data) | |
| val_loss = self.loss(val_labels, val_pred) | |
| logging.info(f"XGBoost model trained successfully. Validation loss: {val_loss:.4f}") | |
| except Exception: | |
| logging.exception("XGBoost training failed") | |
| return val_loss |
🧰 Tools
🪛 Ruff (0.12.2)
42-42: xgboost imported but unused
Remove unused import: xgboost
(F401)
🤖 Prompt for AI Agents
In icu_benchmarks/models/ml_models/imblearn.py around lines 41 to 57, the
current fit_model may return val_loss when an exception occurs before it is
defined and contains unnecessary in-function imports; update the method to
remove the in-function imports, call self.model.fit inside try, compute
validation predictions using self.predict(val_data) and compute loss with
self.loss(val_labels, val_pred_proba) (or assign to self.loss if a callable
exists), ensure val_loss is always assigned before any return by returning a
default failure value (e.g., float("inf")) or re-raising after logging, and log
the caught exception details; keep the success path returning val_loss and the
failure path returning the safe default.
Actionable comments posted: 6
♻️ Duplicate comments (5)
icu_benchmarks/models/train.py (1)
253-254: Ensure consistent data source for train explainer values. There's an inconsistency that was previously flagged - for test values (line 230),
explainer_valuesis used directly, but for train values (line 255), the variable name suggests the same structure. Verify both have the same data structure.icu_benchmarks/models/wrappers.py (3)
49-51: Debug flag should default to False in production.Setting
debug = Trueby default means all models will save outputs to disk by default, which may not be intended for production use. This is inconsistent with the CLI default (args.debugdefaults to False).
116-133: Fix inconsistent logging message and improve code maintainability.
- Line 133 logs "row_indicators.csv" but saves to "pred_indicators.csv"
- Remove large blocks of commented code (lines 134-140)
147-153: Fix format string mismatch in static dataset case.The format string
"%d,%d,%0.3f,%0.3f"has 4 format specifiers but the header "id,ground_truth,prediction_0,prediction_1" suggests 4 columns with different types. The second%dshould likely be a float format.header="id,ground_truth,prediction_0,prediction_1", - fmt="%d,%d,%0.3f,%0.3f", + fmt="%d,%.3f,%.3f,%.3f",icu_benchmarks/models/ml_models/xgboost.py (1)
76-82: Remove unusedlabelsparameter from_explain_model.The
labelsparameter is passed to the method but not used. The parent class signature might require it, butshap_valuesmethod doesn't need labels.
🧹 Nitpick comments (9)
icu_benchmarks/models/train.py (3)
204-205: Consider renamingpersist_shap_datato align with the new explainer terminology.The function now persists "explainer values" rather than "shap values" as indicated by the logging message and variable names. Consider renaming to
persist_explainer_datafor consistency.- if explain_features: - persist_shap_data(trainer, log_dir) + if explain_features: + persist_explainer_data(trainer, log_dir)And update the function definition:
-def persist_shap_data(trainer: Trainer, log_dir: Path, save_full_valuesets=True) -> None: +def persist_explainer_data(trainer: Trainer, log_dir: Path, save_full_valuesets=True) -> None:
221-225: Simplify the column filtering logic for better readability.The current implementation could be more concise and easier to understand.
- if ( - any(x in ["stay_id", "id"] for x in trained_columns) - and len(trained_columns) != trainer.lightning_module.explainer_values_test.shape[1] - ): - trained_columns.remove("stay_id" if "stay_id" in trained_columns else "id") + # Remove ID columns if shape mismatch indicates they're not in the explainer values + if len(trained_columns) != trainer.lightning_module.explainer_values_test.shape[1]: + id_cols = [col for col in ["stay_id", "id"] if col in trained_columns] + if id_cols: + trained_columns.remove(id_cols[0])
237-243: Clarify and adjust rep_test/label_test saving conditions. Currently, within
persist_shap_data(trainer, log_dir, save_full_valuesets=True)(icu_benchmarks/models/train.py:232–242), bothrep_testandlabel_testare written to disk only whensave_full_valuesetsis true (which it always is, because no call ever overrides the default) and when those attributes exist ontrainer.lightning_module.
- Because
persist_shap_datais only ever invoked aspersist_shap_data(trainer, log_dir)(with no explicitsave_full_valuesets=False), these files will always be saved wheneverexplainer_values_testandrep_testexist.- If the intention is to control representation persistence via the module’s own
persist_repsflag (set inwrappers.py), consider replacingif save_full_valuesets:withif trainer.lightning_module.persist_reps:.- If instead you want to save
rep_test and label_test independently of “full” SHAP values, move their write logic out of the save_full_valuesets block so they run whenever those attributes exist.
Please review and adjust the conditional to match your desired behavior.
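Whichever conditional is chosen, the persistence step itself stays small. The sketch below uses assumed attribute names mirroring this discussion (explainer_values_test, persist_reps, rep_test, label_test) and plain pandas/numpy writes; it is illustrative, not the repository's persist_shap_data.

from pathlib import Path

import numpy as np
import pandas as pd

def persist_explainer_data(module, log_dir, save_full_valuesets=True):
    """Write explainer values always; write raw representations/labels only
    when the module explicitly asked for them via persist_reps."""
    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)

    values = getattr(module, "explainer_values_test", None)
    if values is not None:
        cols = [f"feature_{i}" for i in range(np.asarray(values).shape[1])]
        # to_parquet requires a parquet engine such as pyarrow.
        pd.DataFrame(values, columns=cols).to_parquet(log_dir / "explainer_values_test.parquet")

    if getattr(module, "persist_reps", False) and save_full_valuesets:
        np.save(log_dir / "rep_test.npy", module.rep_test)
        np.save(log_dir / "label_test.npy", module.label_test)

# Tiny stand-in object for demonstration.
class _Module:
    explainer_values_test = np.random.rand(4, 3)
    persist_reps = True
    rep_test = np.random.rand(4, 3)
    label_test = np.array([0, 1, 1, 0])

persist_explainer_data(_Module(), "demo_logs")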
icu_benchmarks/models/wrappers.py (4)
113-114: Consider making_explain_modelabstract or providing a default implementation.Currently, the base implementation raises
NotImplementedError. Consider making this an abstract method using@abstractmethoddecorator, or provide a default implementation that returns None with a warning, to make the interface clearer.+ from abc import abstractmethod + + @abstractmethod def _explain_model(self, reps, labels): - raise NotImplementedError(f"Model {self.__class__.__name__} does not currently support feature explanation.") + """Abstract method for model explanation. Must be implemented by subclasses that support feature explanation.""" + passOr alternatively, provide a default implementation:
def _explain_model(self, reps, labels): - raise NotImplementedError(f"Model {self.__class__.__name__} does not currently support feature explanation.") + logging.warning(f"Model {self.__class__.__name__} does not currently support feature explanation.") + return None
174-174: Remove unnecessary f-string prefix. Static analysis correctly identified an f-string without any placeholders.
- logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / f'row_indicators.csv'}") + logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / 'row_indicators.csv'}")
545-560: Consider consolidating the log_curves and log_metrics calls. The test_step method now calls both
log_curvesandlog_metricswith the same parameters. Consider whether these could be unified or if the separation is intentional for clarity.If the separation is not required, consider creating a single method that handles both:
- self.log_curves(test_label, test_pred, "test", pred_indicators) if self.debug: self._save_model_outputs(pred_indicators, test_pred, test_label) if self.explain_features: - # self.explain_model(test_rep, test_label) self.explainer_values_test = self._explain_model(test_rep, test_label) if self.persist_reps: self.rep_test = test_rep self.label_test = test_label + + # Log both curves and metrics + self.log_curves(test_label, test_pred, "test", pred_indicators) if self.mps: self.log("test/loss", np.float32(self.loss(test_label, test_pred)), sync_dist=True) self.log_metrics(np.float32(test_label), np.float32(test_pred), "test", pred_indicators)
597-621: Remove TODO comment or create an issue to track it. Line 607 contains a TODO comment about implementing alarm metrics. Either implement this functionality or create a GitHub issue to track this enhancement.
Would you like me to help implement the alarm metrics functionality using the row indicators, or should I create a GitHub issue to track this enhancement?
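If the alarm metrics do get implemented, one plausible shape (purely illustrative, with hypothetical names) is to group row-level predictions by stay id via the indicators and score each stay on whether any time step crossed the alarm threshold:

import numpy as np
import pandas as pd

def stay_level_alarm_metrics(stay_ids, y_true, y_prob, threshold=0.5):
    """Aggregate per-row predictions to per-stay alarms and return
    event-level sensitivity and false-alarm rate."""
    df = pd.DataFrame({"stay_id": stay_ids, "y_true": y_true, "alarm": y_prob >= threshold})
    per_stay = df.groupby("stay_id").agg(event=("y_true", "max"), alarmed=("alarm", "max"))

    positives = per_stay[per_stay["event"] == 1]
    negatives = per_stay[per_stay["event"] == 0]
    sensitivity = positives["alarmed"].mean() if len(positives) else float("nan")
    false_alarm_rate = negatives["alarmed"].mean() if len(negatives) else float("nan")
    return {"stay_sensitivity": sensitivity, "stay_false_alarm_rate": false_alarm_rate}

stay_ids = np.array([1, 1, 1, 2, 2, 3, 3])
y_true   = np.array([0, 0, 1, 0, 0, 0, 0])
y_prob   = np.array([0.1, 0.2, 0.9, 0.6, 0.1, 0.2, 0.3])
print(stay_level_alarm_metrics(stay_ids, y_true, y_prob))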
icu_benchmarks/models/ml_models/xgboost.py (2)
47-49: Consider using parameterized logging level. The logging level changed from debug to info for shape information. Consider making this configurable or consistent with the verbosity setting.
- logging.info(f"train_data: {train_data.shape}, train_labels: {train_labels.shape}") - logging.info(train_labels) + logging.debug(f"train_data: {train_data.shape}, train_labels: {train_labels.shape}") + if self.model.verbosity > 0: + logging.info(f"Training labels distribution: {train_labels}")
50-56: Clean up commented code or document why it's retained. There's commented code for conditional SHAP value computation. Either remove it or add a comment explaining why it's kept for reference.
- # if self.explain_features: self.explainer = shap.TreeExplainer( self.model, train_data, feature_perturbation="interventional", model_output="probability" ) - # if self.explain_features: - # logging.info("Explaining features") - # self.train_shap_values = self.explainer.shap_values(train_data)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (3)
icu_benchmarks/models/ml_models/xgboost.py(4 hunks)icu_benchmarks/models/train.py(3 hunks)icu_benchmarks/models/wrappers.py(11 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
icu_benchmarks/models/train.py (2)
icu_benchmarks/models/wrappers.py (1)
set_explain_features(109-111)icu_benchmarks/models/utils.py (1)
save_config_file(22-25)
icu_benchmarks/models/ml_models/xgboost.py (3)
icu_benchmarks/models/wrappers.py (8)
set_model_args(642-652)fit(492-516)fit(744-745)_explain_model(113-114)MLWrapper(449-652)predict(562-566)predict(750-756)fit_model(518-523)icu_benchmarks/constants.py (1)
RunMode(4-7)icu_benchmarks/models/ml_models/imblearn.py (1)
fit_model(41-57)
icu_benchmarks/models/wrappers.py (3)
icu_benchmarks/models/utils.py (2)
log_single_metric_to_file(290-315)name(159-160)icu_benchmarks/models/ml_models/xgboost.py (3)
_explain_model(76-82)fit_model(41-62)fit_model(115-148)icu_benchmarks/models/custom_metrics.py (1)
confusion_matrix(141-150)
🪛 Ruff (0.12.2)
icu_benchmarks/models/wrappers.py
174-174: f-string without any placeholders
Remove extraneous f prefix
(F541)
🔇 Additional comments (6)
icu_benchmarks/models/train.py (1)
185-186: LGTM! Good separation of concerns for explainer setup.Setting the explainer feature flag just before testing cleanly separates training from explanation, which is good for performance.
icu_benchmarks/models/wrappers.py (3)
109-114: LGTM! Clean implementation of feature explanation toggle.The method properly sets both
explain_featuresandpersist_repsflags together, ensuring consistent behavior.
504-506: LGTM! Clean integration of feature explanation during training.The conditional execution of
_explain_modelwhenexplain_featuresis enabled is well-placed and properly stores the results.
568-579: LGTM! Clean implementation of curve logging.The
log_curvesmethod properly identifies metrics that return tuples and saves them to CSV files for later analysis.icu_benchmarks/models/ml_models/xgboost.py (2)
26-26: LGTM! Good addition of explicit XGBoost parameters.Setting
eval_metric=log_loss,device="cpu", andverbosity=0explicitly ensures consistent behavior across different environments.
66-72: LGTM! Better approach to get valid XGBoost parameters.Using
model().get_params().keys()is more reliable than inspecting the constructor signature for XGBoost models.
@gin.configurable
class XGBClassifierGPU(MLWrapper):
    _supported_run_modes = [RunMode.classification]
    _explain_values = False

    def __init__(self, *args, **kwargs):
        # self.model = self.set_model_args(
        #     xgb, *args, **kwargs, eval_metric="logloss", tree_method="hist", device="cuda", verbosity=0
        # )
        self.model = xgb
        self.params = {
            "eval_metric": "logloss",
            "tree_method": "hist",
            "device": "cuda",
            "verbosity": 0,
            **kwargs,
        }
        super().__init__(*args, **kwargs)
🛠️ Refactor suggestion
Document the GPU implementation status and clean up initialization.
The GPU class has commented-out model initialization code. Either complete the implementation or document its current status.
@gin.configurable
class XGBClassifierGPU(MLWrapper):
+ """GPU-accelerated XGBoost classifier using native XGBoost API.
+
+ Note: This implementation uses the native XGBoost train API rather than
+ the scikit-learn compatible API for better GPU performance control.
+ """
_supported_run_modes = [RunMode.classification]
_explain_values = False
def __init__(self, *args, **kwargs):
- # self.model = self.set_model_args(
- # xgb, *args, **kwargs, eval_metric="logloss", tree_method="hist", device="cuda", verbosity=0
- # )
self.model = xgb
self.params = {
"eval_metric": "logloss",📝 Committable suggestion
| @gin.configurable | |
| class XGBClassifierGPU(MLWrapper): | |
| _supported_run_modes = [RunMode.classification] | |
| _explain_values = False | |
| def __init__(self, *args, **kwargs): | |
| # self.model = self.set_model_args( | |
| # xgb, *args, **kwargs, eval_metric="logloss", tree_method="hist", device="cuda", verbosity=0 | |
| # ) | |
| self.model = xgb | |
| self.params = { | |
| "eval_metric": "logloss", | |
| "tree_method": "hist", | |
| "device": "cuda", | |
| "verbosity": 0, | |
| **kwargs, | |
| } | |
| super().__init__(*args, **kwargs) | |
| @gin.configurable | |
| class XGBClassifierGPU(MLWrapper): | |
| """GPU-accelerated XGBoost classifier using native XGBoost API. | |
| Note: This implementation uses the native XGBoost train API rather than | |
| the scikit-learn compatible API for better GPU performance control. | |
| """ | |
| _supported_run_modes = [RunMode.classification] | |
| _explain_values = False | |
| def __init__(self, *args, **kwargs): | |
| self.model = xgb | |
| self.params = { | |
| "eval_metric": "logloss", | |
| "tree_method": "hist", | |
| "device": "cuda", | |
| "verbosity": 0, | |
| **kwargs, | |
| } | |
| super().__init__(*args, **kwargs) |
🤖 Prompt for AI Agents
In icu_benchmarks/models/ml_models/xgboost.py around lines 84 to 101, the GPU
classifier contains commented-out initialization and leaves self.model set to
the xgb module while params are stored separately; either restore a proper
GPU-ready model construction or explicitly mark the GPU implementation as
not-initialized. Fix by replacing the commented block with a concrete
initialization that constructs an XGBoost estimator (e.g., call
xgb.XGBClassifier or xgb.XGBModel with the params merged in) and assign that
instance to self.model, or if GPU support is not yet available, remove the
commented code, add a clear docstring or comment indicating GPU mode is
unimplemented, and raise NotImplementedError in __init__ (or fall back to CPU)
so behavior is explicit and the codebase has no stale commented-out
initialization.
        Returns:
            numpy.ndarray: Predicted probabilities for each class.
        """
        return self.model.predict(xgb.DMatrix(features))
🛠️ Refactor suggestion
Add error handling for GPU-specific operations.
The predict method should handle potential GPU-related errors gracefully.
     def predict(self, features):
         """
         Predicts class probabilities for the given features.

         Args:
             features: Input features for prediction.

         Returns:
             numpy.ndarray: Predicted probabilities for each class.
         """
-        return self.model.predict(xgb.DMatrix(features))
+        try:
+            dmatrix = xgb.DMatrix(features)
+            return self.model.predict(dmatrix)
+        except Exception as e:
+            logging.error(f"GPU prediction failed: {e}")
+            raise
🤖 Prompt for AI Agents
In icu_benchmarks/models/ml_models/xgboost.py around line 113, the return of
self.model.predict(xgb.DMatrix(features)) lacks GPU-specific error handling;
wrap the predict call in a try/except that catches XGBoost GPU-related errors
(e.g., xgboost.core.XGBoostError) and general Exception, log a clear warning
about falling back to CPU, then retry prediction using a CPU fallback (for
example recreate the DMatrix for CPU and call model.predict with
predictor='cpu_predictor' or reset model params to use CPU) and if that also
fails re-raise the exception so callers can handle it.
        if wandb.run is not None:
            callbacks.append(wandb_xgb())
        self.model.train(self.params, train_data=dtrain, evals=evals, callbacks=callbacks)
Fix incorrect API usage in GPU fit_model.
The self.model.train() call is incorrect - xgb.train() is a module-level function, not a method. The model should be assigned from the train function's return value.
-        self.model.train(self.params, train_data=dtrain, evals=evals, callbacks=callbacks)
+        self.model = xgb.train(self.params, dtrain=dtrain, evals=evals, callbacks=callbacks)
🤖 Prompt for AI Agents
In icu_benchmarks/models/ml_models/xgboost.py around line 136, the code calls
self.model.train(...), but xgboost's train is a module-level function and
returns the trained Booster; replace the method call with an assignment using
xgb.train(...) and set self.model = xgb.train(self.params, dtrain=dtrain,
evals=evals, callbacks=callbacks) (preserve existing args/kwargs), removing the
erroneous method invocation so self.model holds the returned trained model.
        # logging.info("Explaining features")
        # self.train_shap_values = self.explainer.shap_values(dtrain)

        eval_score = mean(next(iter(self.model.evals_result()["validation"].values())))
Fix method call for getting evaluation results.
The evals_result() should be called on the trained booster, not as shown.
-        eval_score = mean(next(iter(self.model.evals_result()["validation"].values())))
+        # Get the evaluation history from the last callback or the model's evals_result
+        evals_result = self.model.evals_result() if hasattr(self.model, 'evals_result') else {}
+        if "validation" in evals_result and evals_result["validation"]:
+            eval_score = mean(next(iter(evals_result["validation"].values())))
+        else:
+            logging.warning("No validation metrics found")
+            eval_score = float('inf')
🤖 Prompt for AI Agents
In icu_benchmarks/models/ml_models/xgboost.py around line 147, the code calls
evals_result() on the estimator object; instead call it on the trained booster.
Replace self.model.evals_result() with self.model.get_booster().evals_result(),
and keep the existing extraction of the validation metric (e.g.,
mean(next(iter(...))) ), ensuring get_booster() is used to access
evals_result().
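Side note: with the native API there is also a pattern that avoids any evals_result() method entirely, since xgb.train fills a caller-supplied dict and returns the trained Booster. The sketch below only illustrates that pattern; dtrain, dval, params and patience are assumed names standing in for the objects built in fit_model, not the PR's actual code.

# Illustrative sketch only (assumed names: params, dtrain, dval, patience).
import xgboost as xgb

results = {}
booster = xgb.train(
    params,
    dtrain,
    num_boost_round=100,
    evals=[(dval, "validation")],
    evals_result=results,  # xgboost fills this dict with per-round metrics
    callbacks=[xgb.callback.EarlyStopping(rounds=patience)],
)
# last recorded value of the first validation metric (e.g. logloss)
eval_score = next(iter(results["validation"].values()))[-1]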
def persist_shap_data(trainer: Trainer, log_dir: Path, save_full_valuesets=True) -> None:
    """
    Persist shap values to disk.
    Args:
        trainer: Pytorch lightning trainer object
        log_dir: Log directory
    """
    logging.info("Persisting explainer values to disk.")
🛠️ Refactor suggestion
Update docstring to reflect the new function behavior.
The docstring still mentions "shap values" but the function now handles "explainer values" as shown in the logging and variable names.
 def persist_shap_data(trainer: Trainer, log_dir: Path, save_full_valuesets=True) -> None:
     """
-    Persist shap values to disk.
+    Persist explainer values to disk.

     Args:
         trainer: Pytorch lightning trainer object
         log_dir: Log directory
+        save_full_valuesets: Whether to save full explainer values (default: True)
     """
🤖 Prompt for AI Agents
In icu_benchmarks/models/train.py around lines 210 to 217, the function
docstring is outdated — it refers to "shap values" while the function now
persists "explainer values" and adds the parameter save_full_valuesets; update
the docstring to accurately describe that the function persists explainer (e.g.,
SHAP or other explainer) values to disk, list and describe the parameters
trainer: Trainer, log_dir: Path, save_full_valuesets: bool (default True) and
their meanings, and mention the function returns None; keep the brief summary
and Args/Returns sections consistent with the new logging and variable names.
Actionable comments posted: 3
♻️ Duplicate comments (4)
icu_benchmarks/models/ml_models/xgboost.py (4)
86-104: GPU classifier not constructing a model; add docstring and initialize params; set objective.
self.model = xgbis a module reference, not a model. Default an objective; Booster gets created in fit.@gin.configurable class XGBClassifierGPU(MLWrapper): + """GPU-accelerated XGBoost classifier using native xgboost.train API.""" @@ - self.model = xgb - self.params = { + self.model = None # Will hold the trained Booster + self.params = { + "objective": "binary:logistic", "eval_metric": "logloss", "tree_method": "hist", "device": "cuda", "verbosity": 0, **kwargs, }
105-115: Predict uses module instead of trained Booster; add error handling and return NxK probs.Return shape should align with
predict_proba; add GPU error handling.- return self.model.predict(xgb.DMatrix(features)) + if self.model is None: + raise ValueError("Model not trained. Call fit() first.") + try: + dmatrix = xgb.DMatrix(features) + preds = self.model.predict(dmatrix) + except xgb.core.XGBoostError as e: + logging.error(f"GPU prediction failed: {e}") + raise + # Ensure 2D probabilities for binary classification + if preds.ndim == 1: + preds = np.column_stack([1.0 - preds, preds]) + return preds
134-138: Train viaxgb.train, set objective/num_boost_round, and assign returned Booster.Current call is wrong and ignores rounds; set objective based on class count.
- self.model.train(self.params, train_data=dtrain, evals=evals, callbacks=callbacks) + # Configure objective / classes + num_classes = int(len(np.unique(train_labels))) + params = dict(self.params) + if num_classes > 2: + params.update({"objective": "multi:softprob", "num_class": num_classes}) + else: + params.update({"objective": "binary:logistic"}) + num_boost_round = int(params.pop("num_boost_round", params.pop("n_estimators", 100))) + self.model = xgb.train(params, dtrain, num_boost_round=num_boost_round, evals=evals, callbacks=callbacks)
149-150: Fix evals_result access; take the final validation metric.Call on Booster and read the last value.
- eval_score = mean(next(iter(self.model.evals_result()["validation"].values()))) - return eval_score + evals_result = self.model.evals_result() + if "validation" in evals_result and evals_result["validation"]: + metric = next(iter(evals_result["validation"])) + eval_score = evals_result["validation"][metric][-1] + else: + logging.warning("No validation metrics found") + eval_score = float("inf") + return eval_score
🧹 Nitpick comments (3)
icu_benchmarks/models/ml_models/xgboost.py (3)
50-55: Gate SHAP explainer creation behindexplain_featuresand make sampling reproducible.Avoid unnecessary overhead and non-determinism.
- n_samples = min(1000, len(train_data)) - indices = np.random.choice(len(train_data), size=n_samples, replace=False) - background_sample = train_data[indices] - self.explainer = shap.TreeExplainer( - self.model, background_sample, feature_perturbation="interventional", model_output="probability" - ) - # if self.explain_features: - # logging.info("Explaining features") - # self.train_shap_values = self.explainer.shap_values(train_data) + if self.explain_features: + n_samples = min(1000, len(train_data)) + rng = np.random.default_rng(getattr(self.hparams, "random_state", None)) + indices = rng.choice(len(train_data), size=n_samples, replace=False) + background_sample = train_data[indices] + self.explainer = shap.TreeExplainer( + self.model, background_sample, feature_perturbation="interventional", model_output="probability" + )Also applies to: 56-58
63-64: Report the final/best validation metric, not the mean across rounds.Mean hides the true end-of-training performance.
- eval_score = mean(next(iter(self.model.evals_result_["validation_0"].values()))) - return eval_score # , callbacks=callbacks) + evals = self.model.evals_result_.get("validation_0", {}) + metric = next(iter(evals)) if evals else None + eval_score = evals[metric][-1] if metric else float("inf") + return eval_score
73-75: Clarify warning about dropped kwargs.Minor wording tweak to make logs clearer.
- if len(valid_kwargs) == 0: - logging.warning("No valid arguments passed to XGBoost") + if len(valid_kwargs) == 0: + logging.warning("No valid XGBoost hyperparameters found in kwargs")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
icu_benchmarks/models/ml_models/xgboost.py(4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
icu_benchmarks/models/ml_models/xgboost.py (1)
icu_benchmarks/models/wrappers.py (9)
set_model_args(642-652)fit(492-516)fit(744-745)_explain_model(113-114)MLWrapper(449-652)predict(562-566)predict(750-756)fit_model(518-523)_explain_model(488-492)
🔇 Additional comments (2)
icu_benchmarks/models/ml_models/xgboost.py (2)
68-70: LGTM: deriving valid kwargs fromget_params()is the right call.
117-131: DMatrix construction LGTM.Inputs are correctly wrapped for native training.
| logging.info(f"train_data: {train_data.shape}, train_labels: {train_labels.shape}") | ||
| logging.info(train_labels) | ||
| self.model.fit(train_data, train_labels, eval_set=[(val_data, val_labels)], verbose=0) |
🛠️ Refactor suggestion
Enable early stopping (callbacks currently unused) and avoid logging raw labels.
Wire callbacks or early_stopping_rounds, and don’t log full label arrays.
- logging.info(f"train_data: {train_data.shape}, train_labels: {train_labels.shape}")
- logging.info(train_labels)
- self.model.fit(train_data, train_labels, eval_set=[(val_data, val_labels)], verbose=0)
+ logging.debug(f"train_data: {train_data.shape}, train_labels: {train_labels.shape}")
+ self.model.fit(
+ train_data,
+ train_labels,
+ eval_set=[(val_data, val_labels)],
+ early_stopping_rounds=self.hparams.patience,
+ callbacks=callbacks,
+ verbose=False,
+ )📝 Committable suggestion
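One caveat on the suggested fit() call: which keyword xgboost accepts for early stopping depends on the installed version, so the snippet above may need adjusting. A minimal sketch of the estimator-level variant, assuming xgboost >= 1.6 and with patience standing in for self.hparams.patience:

# Sketch assuming xgboost >= 1.6, where early_stopping_rounds is set on the
# estimator; older releases accept it as a fit() keyword instead.
import xgboost as xgb

model = xgb.XGBClassifier(eval_metric="logloss", early_stopping_rounds=patience)
model.fit(train_data, train_labels, eval_set=[(val_data, val_labels)], verbose=False)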
    def _explain_model(self, reps, labels):
        if not hasattr(self.model, "feature_importances_"):
            raise ValueError("Model has not been fit yet. Call fit_model() before getting feature importances.")
        return self.model.feature_importances_
        # feature_importances = self.model.feature_importances_
        shap_values = self.explainer.shap_values(reps, labels)
        # feature_importances = np.abs(shap_values).mean(axis=1)
        return shap_values
🛠️ Refactor suggestion
Fix explainer precondition and SHAP API usage.
Check for self.explainer instead of feature_importances_; avoid passing labels to shap_values (API-compat issues).
     def _explain_model(self, reps, labels):
-        if not hasattr(self.model, "feature_importances_"):
-            raise ValueError("Model has not been fit yet. Call fit_model() before getting feature importances.")
-        # feature_importances = self.model.feature_importances_
-        shap_values = self.explainer.shap_values(reps, labels)
-        # feature_importances = np.abs(shap_values).mean(axis=1)
-        return shap_values
+        if not hasattr(self, "explainer"):
+            raise ValueError("Explainer not initialized. Call fit_model() with explain_features=True.")
+        return self.explainer.shap_values(reps)
🤖 Prompt for AI Agents
In icu_benchmarks/models/ml_models/xgboost.py around lines 78 to 84, the
precondition and SHAP call are wrong: check for the explainer (e.g., if not
hasattr(self, "explainer") or self.explainer is None) and raise a clear
ValueError instructing to call the explainer creation step before explaining;
then call self.explainer.shap_values(reps) without passing labels (remove the
labels argument) to match the SHAP API.
Actionable comments posted: 24
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
icu_benchmarks/models/wrappers.py (1)
569-579: Bug: validation_step miscalls log_metrics and omits pred_indicators.This will raise a TypeError and skip proper metric logging.
- self.log_metrics("val/loss", self.loss(val_label, val_pred), sync_dist=True) + self.log("val/loss", self.loss(val_label, val_pred), sync_dist=True) logging.info(f"Val loss: {self.loss(val_label, val_pred)}") - self.log_metrics(val_label, val_pred, "val") + self.log_metrics(val_label, val_pred, "val", row_indicators)
♻️ Duplicate comments (6)
icu_benchmarks/models/ml_models/xgboost.py (3)
79-85: Fix explainer precondition and SHAP API usage.Check for explainer, not feature_importances_; don’t pass labels to shap_values.
- def _explain_model(self, reps, labels): - if not hasattr(self.model, "feature_importances_"): - raise ValueError("Model has not been fit yet. Call fit_model() before getting feature importances.") - # feature_importances = self.model.feature_importances_ - shap_values = self.explainer.shap_values(reps, labels) - # feature_importances = np.abs(shap_values).mean(axis=1) - return shap_values + def _explain_model(self, reps, labels): + if not hasattr(self, "explainer") or self.explainer is None: + raise ValueError("Explainer not initialized. Set explain_features=True when fitting.") + return self.explainer.shap_values(reps)
45-50: Don’t log full label arrays; wire early stopping callbacks.Avoid printing labels (PII/verbosity) and pass callbacks to fit.
- logging.info(f"train_data: {train_data.shape}, train_labels: {train_labels.shape}") - logging.info(train_labels) - self.model.fit(train_data, train_labels, eval_set=[(val_data, val_labels)], verbose=0) + logging.debug(f"train_data: {train_data.shape}, train_labels: {train_labels.shape}") + self.model.fit( + train_data, + train_labels, + eval_set=[(val_data, val_labels)], + callbacks=callbacks, + verbose=False, + )
118-151: Fix xgb.train usage and eval extraction; assign the Booster.Use module-level xgb.train, assign to self.model, and safely read evals_result.
- callbacks = [EarlyStopping(self.hparams.patience)] + callbacks = [EarlyStopping(self.hparams.patience)] @@ - if wandb.run is not None: - callbacks.append(wandb_xgb()) - self.model.train(self.params, train_data=dtrain, evals=evals, callbacks=callbacks) + if wandb.run is not None: + callbacks.append(wandb_xgb()) + self.model = xgb.train(self.params, dtrain=dtrain, evals=evals, callbacks=callbacks) @@ - eval_score = mean(next(iter(self.model.evals_result()["validation"].values()))) + evals_result = self.model.evals_result() if hasattr(self.model, "evals_result") else {} + if "validation" in evals_result and evals_result["validation"]: + eval_score = mean(next(iter(evals_result["validation"].values()))) + else: + logging.warning("No validation metrics found for GPU model.") + eval_score = float("inf")icu_benchmarks/models/wrappers.py (3)
410-414: Remove duplicate valid_idx computation.Identical assignment appears twice back-to-back.
- valid_idx = mask.detach().cpu().numpy().astype(bool) - - # valid_idx is (batch, seq_len) boolean mask - valid_idx = mask.detach().cpu().numpy().astype(bool) + # valid_idx is (batch, seq_len) boolean mask + valid_idx = mask.detach().cpu().numpy().astype(bool)
632-673: Tests needed for new pred_indicators branches in log_metrics.Edge cases (temporal vs. flat, shape mismatches) aren’t covered.
I can add unit tests covering:
- pred_indicators=None
- 2-D indicators matching pred (temporal)
- Non-matching shapes path
- Classification vs regression
Say the word and I’ll open a PR-ready test module.
50-52: Default debug=True causes unintended disk I/O; gate saves behind debug and revert default.Persisting outputs by default is risky (large files, slower runs) and inconsistent with CLI defaults. Also,
_save_model_outputsis called unconditionally in DL, bypassing debug.Apply:
- debug = True + debug = False- # Save predictions to file - self._save_model_outputs(indicators_np, transformed_predictors, target_np) + # Save predictions to file + if self.debug: + self._save_model_outputs(indicators_np, transformed_predictors, target_np)Also applies to: 433-435
🧹 Nitpick comments (18)
experiments/benchmark_cass_baselines.yml (2)
50-53: Normalize file_names to a proper dict string.Safer to pass a JSON-like dict (what parse_dict typically expects).
- - '"OUTCOME":"outc.parquet","STATIC":"sta.parquet"' + - '{"OUTCOME":"outc.parquet","STATIC":"sta.parquet"}'If the parser intentionally accepts the current format, ignore this change; otherwise adopt braces consistently across all experiment files.
64-64: Add trailing newline.Fixes YAML lint warning.
experiments/benchmark_sex.yml (2)
50-53: Use braces for file_names dict string.- - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"' + - '{"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"}'
60-60: Add trailing newline.experiments/benchmark_cass_modalities.yml (2)
60-61: Wrap file_names mapping in braces for consistency.- - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"' + - '{"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"}'
68-68: Add trailing newline.experiments/benchmark_cass_wearable_modalities.yml (2)
66-68: Consistent dict format for file_names.- - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"' + - '{"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"}'
73-73: Add trailing newline.experiments/benchmark_cass_segment_length.yml (2)
66-69: Use braces around file_names mapping.- - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' + - '{"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"}'
79-79: Add trailing newline.experiments/benchmark_cass_model_architecture.yml (2)
59-59: Fix YAML indentation (yamllint).Indentation under values should be 6 spaces, not 8.
- - [ top_500_features ] + - [ top_500_features ]
74-74: Add trailing newline (yamllint).-program: icu-benchmarks +program: icu-benchmarks +experiments/benchmark_cass_explain_modalities.yml (1)
69-69: Add trailing newline (yamllint).-program: icu-benchmarks +program: icu-benchmarks +experiments/top_features_benchmark_cass.yml (1)
88-89: Add trailing newline.-program: icu-benchmarks +program: icu-benchmarks +icu_benchmarks/models/ml_models/xgboost.py (1)
90-91: Annotate class var per Ruff (RUF012).Mark _explain_values as ClassVar[bool] or remove if unused.
-from icu_benchmarks.models.wrappers import MLWrapper +from icu_benchmarks.models.wrappers import MLWrapper +from typing import ClassVar @@ - _explain_values = False + _explain_values: ClassVar[bool] = Falseicu_benchmarks/models/wrappers.py (3)
539-547: Clean up calibration toggles; use the method variable and remove duplicate assignment.Avoid unused local (method) and duplicate
calibrateassignment.- val_loss = self.fit_model(train_rep, train_label, val_rep, val_label) - calibrate = True - method = "isotonic" - # Apply calibration if desired - calibrate = True - if calibrate and self.run_mode == RunMode.classification: - cal_val_loss = self.setup_calibration(val_rep, val_label, method='isotonic') + val_loss = self.fit_model(train_rep, train_label, val_rep, val_label) + calibrate = True + method = "isotonic" + # Apply calibration if desired + if calibrate and self.run_mode == RunMode.classification: + cal_val_loss = self.setup_calibration(val_rep, val_label, method=method)
620-631: Silence ARG002: pred_indicators unused in log_curves.Rename the arg to underscore to satisfy linters without changing behavior.
- def log_curves(self, label, pred, metric_type, pred_indicators): + def log_curves(self, label, pred, metric_type, _pred_indicators=None):
381-399: Clarify/align DataLoader contract in DL step_fn.Currently enforces 3 or 4 elements, but the exception message mentions older 2-tuple signatures. Either re-support the 2-tuple path (with pred_indicators=None) or update the message.
If you only intend to support indicators now, change:
- raise Exception("Loader should return either (data, label) or (data, label, mask)") + raise Exception("Loader should return (data, label, indicators) or (data, label, mask, indicators)")Otherwise, I can send a patch to handle 2-tuples by setting indicators=None and skipping saves/indicator-based logic.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (11)
experiments/benchmark_cass.yml(1 hunks)experiments/benchmark_cass_baselines.yml(1 hunks)experiments/benchmark_cass_explain_modalities.yml(1 hunks)experiments/benchmark_cass_modalities.yml(1 hunks)experiments/benchmark_cass_model_architecture.yml(1 hunks)experiments/benchmark_cass_segment_length.yml(1 hunks)experiments/benchmark_cass_wearable_modalities.yml(1 hunks)experiments/benchmark_sex.yml(1 hunks)experiments/top_features_benchmark_cass.yml(1 hunks)icu_benchmarks/models/ml_models/xgboost.py(3 hunks)icu_benchmarks/models/wrappers.py(12 hunks)
🧰 Additional context used
🧬 Code graph analysis (10)
experiments/benchmark_cass_wearable_modalities.yml (1)
icu_benchmarks/run.py (1)
main(33-198)
experiments/benchmark_sex.yml (1)
icu_benchmarks/run.py (1)
main(33-198)
experiments/top_features_benchmark_cass.yml (1)
icu_benchmarks/run.py (1)
main(33-198)
experiments/benchmark_cass_baselines.yml (1)
icu_benchmarks/run.py (1)
main(33-198)
experiments/benchmark_cass_model_architecture.yml (1)
icu_benchmarks/run.py (1)
main(33-198)
experiments/benchmark_cass_segment_length.yml (2)
icu_benchmarks/run.py (1)
main(33-198)icu_benchmarks/wandb_utils.py (1)
set_wandb_experiment_name(54-78)
experiments/benchmark_cass_explain_modalities.yml (1)
icu_benchmarks/run.py (1)
main(33-198)
experiments/benchmark_cass.yml (1)
icu_benchmarks/run.py (1)
main(33-198)
icu_benchmarks/models/ml_models/xgboost.py (1)
icu_benchmarks/models/wrappers.py (7)
set_model_args(694-704)fit_model(562-567)fit(529-560)fit(796-797)_explain_model(114-115)predict(608-619)predict(802-808)
icu_benchmarks/models/wrappers.py (6)
icu_benchmarks/models/utils.py (2)
log_single_metric_to_file(290-315)name(159-160)icu_benchmarks/models/ml_models/xgboost.py (1)
predict(106-116)icu_benchmarks/constants.py (1)
RunMode(4-7)icu_benchmarks/models/ml_models/lgbm.py (1)
predict(36-46)icu_benchmarks/models/ml_models/catboost.py (1)
predict(16-26)icu_benchmarks/models/custom_metrics.py (1)
confusion_matrix(141-150)
🪛 YAMLlint (1.37.1)
experiments/benchmark_cass_wearable_modalities.yml
[error] 73-73: no new line character at the end of file
(new-line-at-end-of-file)
experiments/benchmark_sex.yml
[error] 60-60: no new line character at the end of file
(new-line-at-end-of-file)
experiments/benchmark_cass_baselines.yml
[error] 64-64: no new line character at the end of file
(new-line-at-end-of-file)
experiments/benchmark_cass_model_architecture.yml
[warning] 59-59: wrong indentation: expected 6 but found 8
(indentation)
[error] 74-74: no new line character at the end of file
(new-line-at-end-of-file)
experiments/benchmark_cass_segment_length.yml
[error] 79-79: no new line character at the end of file
(new-line-at-end-of-file)
experiments/benchmark_cass_explain_modalities.yml
[error] 69-69: no new line character at the end of file
(new-line-at-end-of-file)
experiments/benchmark_cass_modalities.yml
[error] 68-68: no new line character at the end of file
(new-line-at-end-of-file)
🪛 Ruff (0.12.2)
icu_benchmarks/models/ml_models/xgboost.py
81-81: Avoid specifying long messages outside the exception class
(TRY003)
89-89: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
icu_benchmarks/models/wrappers.py
175-175: f-string without any placeholders
Remove extraneous f prefix
(F541)
541-541: Local variable method is assigned to but never used
Remove assignment to unused variable method
(F841)
620-620: Unused method argument: pred_indicators
(ARG002)
command:
  - ${env}
  - ${program}
  - train
  - -d
  - ../data/
  - -t
  # - BinaryClassification
  - CassClassification
  - --log-dir
  - /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs
  - --tune
  - --wandb-sweep
  - -tn
  - SSI
  - --hp-checkpoint
  - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/dataset_baseline_cohort_post-operative_static_2025-08-01T11:00:26/SSI/XGBClassifier/2025-08-01T11-08-08.547141/hyperparameter_tuning_logs.db"
  - --modalities
  - "all"
  - --file-names
  - '"DYNAMIC":"dyn.parquet","STATIC":"sta.parquet"'
  - --explain_features
  # - --verbose
  - --load_data_vars
Sweeps won’t receive parameters: missing ${args} in command.
Without ${args}, W&B won’t inject sweep parameters; runs will ignore the grid under parameters. Add it to the command.
   - --load_data_vars
+  - ${args}
 method: grid
🤖 Prompt for AI Agents
In experiments/benchmark_cass_baselines.yml around lines 1-24, the command list
is missing the ${args} placeholder so W&B sweeps cannot inject parameter values;
add a new list entry "- ${args}" immediately after "- ${program}" in the command
sequence so sweep parameters are expanded into the CLI before the fixed
subcommands/flags.
  - -d
  - ../data/
🛠️ Refactor suggestion
Remove fixed flags duplicated by sweep parameters (they shadow/contradict the grid).
-d/--modalities/--file-names are already provided via parameters; keep only one source of truth.
-  - -d
-  - ../data/
@@
-  - --modalities
-  - "all"
-  - --file-names
-  - '"DYNAMIC":"dyn.parquet","STATIC":"sta.parquet"'

Also applies to: 18-21
🤖 Prompt for AI Agents
In experiments/benchmark_cass_baselines.yml around lines 5-6 (and similarly
lines 18-21), the fixed CLI flags (-d, --modalities, --file-names, etc.) are
duplicated while the sweep parameters already provide those values; remove the
redundant fixed flags so there is a single source of truth: delete the -d
../data/ entries (and the corresponding duplicated modality/file-name flags at
lines 18-21) from the YAML so the sweep parameters control those options
exclusively.
  - -d
  - ../data/
Grid parameters not wired into command.
Hardcoded data_dir, modalities, and file_names prevent the sweep grid from taking effect.
-  - -d
-  - ../data/
+  - -d
+  - ${data_dir}
@@
-  - --modalities
-  - "all"
+  - --modalities
+  - ${modalities}
@@
-  - --file-names
-  - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"'
+  - --file-names
+  - ${file_names}

Also applies to: 19-23
🤖 Prompt for AI Agents
In experiments/benchmark_cass_explain_modalities.yml around lines 5-6 (and
similarly lines 19-23) the YAML hardcodes -d ../data/ plus fixed modalities and
file_names which overrides the sweep grid; change those hardcoded values to
reference the grid parameters (remove the fixed -d entry and replace
modalities/file_names entries with the grid variable placeholders or anchors
used elsewhere in the file) so that the sweep can inject data_dir, modalities,
and file_names at runtime; ensure the command or args use the same variable
names as defined in the grid section and validate the YAML syntax after
updating.
  - --file-names
  - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"'
💡 Verification agent
🧩 Analysis chain
OUTCOME parquet mismatch (outc_6 vs outc_9).
Unify to one value; prefer using ${file_names} so the parameters block controls it.
Also applies to: 61-62
🏁 Script executed:
#!/bin/bash
set -e
file=$(fd -g 'benchmark_cass_explain_modalities.yml')
echo "=== File path: $file ==="
rg -n '"DYNAMIC"' "$file"
rg -n 'file_names' "$file"
echo "=== Top of file (1-80) ==="
sed -n '1,80p' "$file"
echo "=== Mid of file (60-100) ==="
sed -n '60,100p' "$file"

Length of output: 5497
Use ${file_names} in --file-names flag
Change in experiments/benchmark_cass_explain_modalities.yml lines 21–22:
- --file-names
- '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"'
to
- --file-names
- ${file_names}
This aligns the command with the parameters block’s value (currently outc_9.parquet at lines 59–62).
🤖 Prompt for AI Agents
In experiments/benchmark_cass_explain_modalities.yml around lines 21 to 22,
replace the hard-coded file-names string with the parameter reference: change
the second list item from the literal
'"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"' to
${file_names} so the --file-names flag uses the pipeline parameter (which is set
later to outc_9.parquet) instead of a hardcoded value.
command:
  - ${env}
  - ${program}
  - train
  - -d
  - ../data/
  - -t
  # - BinaryClassification
  - CassClassification
  - --log-dir
  - /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs
  # - --tune
  - --wandb-sweep
  - -tn
  - SSI
  - --hp-checkpoint
  - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-20T18-29-31.571508/hyperparameter_tuning_logs.db"
  # - --verbose
  - --modalities
  - "all"
  - --file-names
  - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"'
  # - --explain_features
  # - --verbose
  - --load_data_vars
method: grid
Ensure sweep parameters are injected: add ${args}.
   - --load_data_vars
+  - ${args}
 method: grid
🤖 Prompt for AI Agents
In experiments/benchmark_cass_modalities.yml around lines 1 to 26, the command
list does not include the ${args} placeholder so sweep parameters from the
runner won't be injected; add a '- ${args}' entry into the command array
(typically immediately after '- ${program}') so wandb-sweep and other runtime
arguments are passed through to the train invocation.
        n_samples = min(1000, len(train_data))
        indices = np.random.choice(len(train_data), size=n_samples, replace=False)
        background_sample = train_data[indices]
        self.explainer = shap.TreeExplainer(
            self.model, background_sample, feature_perturbation="interventional", model_output="probability"
        )
🛠️ Refactor suggestion
Gate SHAP explainer creation behind explain_features to avoid overhead.
Currently always building SHAP explainer; compute only when requested.
-        n_samples = min(1000, len(train_data))
-        indices = np.random.choice(len(train_data), size=n_samples, replace=False)
-        background_sample = train_data[indices]
-        self.explainer = shap.TreeExplainer(
-            self.model, background_sample, feature_perturbation="interventional", model_output="probability"
-        )
+        if getattr(self, "explain_features", False):
+            n_samples = min(1000, len(train_data))
+            indices = np.random.choice(len(train_data), size=n_samples, replace=False)
+            background_sample = train_data[indices]
+            self.explainer = shap.TreeExplainer(
+                self.model, background_sample, feature_perturbation="interventional", model_output="probability"
+            )
🤖 Prompt for AI Agents
In icu_benchmarks/models/ml_models/xgboost.py around lines 51 to 56, the SHAP
TreeExplainer is always constructed which adds unnecessary overhead; change this
so the explainer is only created when explain_features is requested — either
wrap the current explainer construction in an if self.explain_features: guard or
remove it from initialization and lazily instantiate and cache self.explainer
inside the explain_features method the first time it is called (use the same
background_sample selection logic), ensuring subsequent calls reuse the cached
explainer.
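A minimal sketch of the lazy, cached variant described above, assuming self.model is already fitted and reusing the existing explain_features flag; the helper name _get_explainer is hypothetical:

# Hypothetical helper, shown only to illustrate lazy creation and reuse of the explainer.
import numpy as np
import shap

def _get_explainer(self, train_data):
    if getattr(self, "explainer", None) is None:
        n_samples = min(1000, len(train_data))
        indices = np.random.choice(len(train_data), size=n_samples, replace=False)
        self.explainer = shap.TreeExplainer(
            self.model,
            train_data[indices],
            feature_perturbation="interventional",
            model_output="probability",
        )
    return self.explainer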
    def __init__(self, *args, **kwargs):
        # self.model = self.set_model_args(
        #     xgb, *args, **kwargs, eval_metric="logloss", tree_method="hist", device="cuda", verbosity=0
        # )
        self.model = xgb
        self.params = {
            "eval_metric": "logloss",
            "tree_method": "hist",
            "device": "cuda",
            "verbosity": 0,
            **kwargs,
        }
        super().__init__(*args, **kwargs)
🛠️ Refactor suggestion
GPU init sets model to the module — make intent explicit.
Keep params but initialize model as None; a Booster will be assigned in fit_model.
-        self.model = xgb
+        self.model = None
         self.params = {
🤖 Prompt for AI Agents
In icu_benchmarks/models/ml_models/xgboost.py around lines 92 to 105, the GPU
init currently assigns self.model = xgb (the module) which is ambiguous; change
it to initialize self.model = None and retain self.params as-is so that a proper
xgboost.Booster or estimator is created and assigned later in fit_model; update
any comments to state that a Booster will be created during fit_model and ensure
subsequent code checks self.model is None before using it.
    def predict(self, features):
        """
        Predicts class probabilities for the given features.

        Args:
            features: Input features for prediction.

        Returns:
            numpy.ndarray: Predicted probabilities for each class.
        """
        return self.model.predict(xgb.DMatrix(features))
🛠️ Refactor suggestion
GPU predict: handle untrained state and return class probabilities.
Return Nx2 probabilities for binary tasks and add basic error handling.
-    def predict(self, features):
+    def predict(self, features):
         """
         Predicts class probabilities for the given features.
@@
-        return self.model.predict(xgb.DMatrix(features))
+        if self.model is None:
+            raise RuntimeError("Model not trained. Call fit_model() first.")
+        try:
+            preds = self.model.predict(xgb.DMatrix(features))  # shape: (n,) for binary
+        except Exception as e:
+            logging.error(f"GPU prediction failed: {e}")
+            raise
+        # Convert to Nx2 proba for classification wrappers
+        if preds.ndim == 1:
+            return np.vstack([1.0 - preds, preds]).T
+        return preds
icu_benchmarks/models/wrappers.py
Outdated
    def setup_calibration(self, val_data, val_labels, method='isotonic'):
        """
        Setup model calibration using validation data for better probability estimates.

        Args:
            val_data: Validation features for calibration
            val_labels: Validation labels for calibration
            method: 'isotonic' or 'sigmoid' calibration method

        Returns:
            float: Validation loss after calibration
        """
        if self.run_mode != RunMode.classification:
            logging.warning("Calibration only supported for classification tasks")
            return None

        logging.info(f"Applying {method} calibration using validation data")

        # Create calibrated version using validation data as holdout set
        self.calibrated_model = CalibratedClassifierCV(
            self.model,
            method=method,
            cv='prefit'  # Use prefit model with holdout validation set
        )
        self.calibrated_model.fit(val_data, val_labels)

        # Calculate calibrated validation score
        cal_val_pred = self.calibrated_model.predict_proba(val_data)
        cal_val_loss = self.loss(val_labels, cal_val_pred)

        logging.info(f"Calibration complete. Original val loss: {self.loss(val_labels, self.predict(val_data)):.4f}, "
                     f"Calibrated val loss: {cal_val_loss:.4f}")

        return cal_val_loss
🛠️ Refactor suggestion
Calibration: handle multiclass 'isotonic' case and log true pre-calibration loss.
Isotonic isn’t supported for multiclass; current code may throw. Also “Original val loss” is computed with calibrated preds.
-    def setup_calibration(self, val_data, val_labels, method='isotonic'):
+    def setup_calibration(self, val_data, val_labels, method='isotonic'):
@@
-        if self.run_mode != RunMode.classification:
+        if self.run_mode != RunMode.classification:
             logging.warning("Calibration only supported for classification tasks")
             return None
+        n_classes = np.unique(val_labels).size
+        if n_classes > 2 and method == 'isotonic':
+            logging.warning("Isotonic calibration is not supported for multiclass; falling back to 'sigmoid'.")
+            method = 'sigmoid'
@@
-        logging.info(f"Calibration complete. Original val loss: {self.loss(val_labels, self.predict(val_data)):.4f}, "
+        logging.info(f"Calibration complete. Original val loss: {self.loss(val_labels, self.predict(val_data, use_calibrated=False)):.4f}, "
                      f"Calibrated val loss: {cal_val_loss:.4f}")

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In icu_benchmarks/models/wrappers.py around lines 493 to 527, the calibration
code can break because isotonic calibration is not supported for multiclass
problems and the "Original val loss" log is computed after replacing the model
so it shows calibrated predictions; fix by (1) before fitting, compute and store
original_val_pred = self.predict(val_data) and original_val_loss =
self.loss(val_labels, original_val_pred) and use those in the log; (2) detect
multiclass (e.g. >2 unique labels) and if method == 'isotonic' either switch
method to 'sigmoid' or raise/log a clear warning and set method='sigmoid' so
CalibratedClassifierCV is not invoked with unsupported method.
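For illustration, a minimal sketch of that flow, assuming self.model is already fit and val_data/val_labels form the held-out calibration split; loss_fn is a hypothetical stand-in for self.loss:

# Sketch only: compute the uncalibrated loss first, then fall back to sigmoid for multiclass.
import numpy as np
from sklearn.calibration import CalibratedClassifierCV

original_val_loss = loss_fn(val_labels, model.predict_proba(val_data))
method = "sigmoid" if np.unique(val_labels).size > 2 else "isotonic"
calibrated = CalibratedClassifierCV(model, method=method, cv="prefit")
calibrated.fit(val_data, val_labels)
calibrated_val_loss = loss_fn(val_labels, calibrated.predict_proba(val_data))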
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 4
🧹 Nitpick comments (13)
experiments/benchmark_organsystems.yml (5)
59-61: Param name likely unused: use_pretrained_imputation ≠ pretrained_imputation.run.py reads args.pretrained_imputation; this sweep defines use_pretrained_imputation and never passes a corresponding CLI flag. This will be ignored.
Consider:
- use_pretrained_imputation: + pretrained_imputation: values: - - None + - NoneOptionally pass explicitly in command when set:
+ - --pretrained-imputation + - None
5-6: Duplicate/conflicting data_dir sources.You hardcode “-d ../data/” in command while also sweeping data_dir in parameters. Depending on the sweep runner/argparse resolution, the earlier value may be overridden or cause confusion.
Remove the stub and rely on the sweep parameter:
- - -d - - ../data/Also applies to: 29-35
17-17: hp-checkpoint path may mismatch dataset/task.The checkpoint points to a different cohort (“lab_set_normal_ward…”) while the sweep targets organ_cohorts. If the checkpoint encodes dataset-specific search space or metrics, reuse could be invalid.
Parameterize this path so it can be swapped per organ system, or drop it if not intended.
11-11: Hard-coded absolute paths reduce portability.Absolute cluster paths (/sc-projects/…) in log-dir and data_dir values make this sweep non-portable.
- Use env interpolation (e.g., ${env:DATA_ROOT}) or anchor variables.
- Expose log-dir as a parameter.
Example:
- - /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs + - ${env:YAIB_LOG_DIR}and
- - "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/organ_cohorts_.../Pancreas" + - "${env:DATA_ROOT}/organ_cohorts_.../Pancreas"Also applies to: 31-35
62-62: Add trailing newline.YAMLlint flags missing newline at EOF.
-program: icu-benchmarks +program: icu-benchmarks +experiments/benchmark_cass_outcome_time.yml (7)
5-6: Duplicate data_dir (CLI vs sweep params) can cause drift.You set -d ../data/ in the command but also sweep over data_dir. Remove the CLI arg and let the sweep drive it.
- - -d - - ../data/Also applies to: 44-47
11-11: Parameterize cluster-specific absolute paths via env vars.Hardcoded /sc-projects/... paths reduce portability. Use env-substitution so agents can inject paths.
- --log-dir - - /sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs + - ${env:YAIB_LOG_DIR} @@ - --hp-checkpoint - - "/sc-projects/sc-proj-cc08-cassandra/RW_Prospective/yaib_logs/lab_set_normal_ward_SSI-3_POPF_Galleleck_BDA_segment_1.0_horizon_6:00:00_transfer_full_2025-07-14T13:20:08/SSI/XGBClassifier/2025-07-20T18-29-31.571508/hyperparameter_tuning_logs.db" + - ${env:YAIB_HP_CHECKPOINT}Also applies to: 17-17
79-81: Use YAML null instead of Python None string in sweeps.W&B/YAML interpret null, not Python’s None literal. This avoids passing the string "None".
use_pretrained_imputation: values: - - None + - null
34-35: De-duplicate modalities (CLI vs sweep param).Keep one source (recommend: sweep param) to avoid future drift.
- - --modalities - - "all"Also applies to: 61-64
13-15: Add a sweep metric to enable ranking/early stopping dashboards.Grid doesn’t require a metric, but specifying one improves reporting. Use the exact key your trainer logs (e.g., val_auroc).
Proposed addition (adjust name to your logged key):
method: grid +metric: + name: val_auroc + goal: maximize name: yaib_classification_benchmarkAlso applies to: 41-43
66-76: Consider native YAML mappings for file_names values.Less quoting and no custom parsing needed. Only do this if the code accepts dicts from W&B config.
file_names: values: - - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_1.parquet","STATIC":"sta.parquet"' - - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_2.parquet","STATIC":"sta.parquet"' + - {DYNAMIC: dyn.parquet, OUTCOME: outc_1.parquet, STATIC: sta.parquet} + - {DYNAMIC: dyn.parquet, OUTCOME: outc_2.parquet, STATIC: sta.parquet}
1-4: Minor sweep ergonomics + lint fix.
- Consider adding ${interpreter} before ${program} if icu-benchmarks isn’t an installed console script on all agents.
- Add a trailing newline to satisfy YAMLlint.
Also applies to: 82-82
icu_benchmarks/models/__init__.py (1)
34-36: Keep RUSBClassifier in the classifier union for type hints and tooling.
If RUSBClassifier remains supported, include it in MLModelClassifier to keep annotations accurate.

 MLModelClassifier = Union[
     XGBClassifier,
     LGBMClassifier,
+    RUSBClassifier,
     XGBEnsembleClassifier,
     BRFClassifier,
     CBClassifier,
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (4)
- configs/prediction_models/RUSBClassifier.gin (1 hunks)
- experiments/benchmark_cass_outcome_time.yml (1 hunks)
- experiments/benchmark_organsystems.yml (1 hunks)
- icu_benchmarks/models/__init__.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
icu_benchmarks/models/__init__.py (1)
icu_benchmarks/models/ml_models/imblearn.py (3)
BRFClassifier(12-17)XGBEnsembleClassifier(21-57)RUSBClassifier(17-22)
configs/prediction_models/RUSBClassifier.gin (1)
icu_benchmarks/models/ml_models/imblearn.py (2)
RUSBClassifier(17-22)__init__(20-22)
experiments/benchmark_organsystems.yml (1)
icu_benchmarks/run.py (1)
main(33-198)
🪛 YAMLlint (1.37.1)
experiments/benchmark_organsystems.yml
[error] 62-62: no new line character at the end of file
(new-line-at-end-of-file)
experiments/benchmark_cass_outcome_time.yml
[error] 82-82: no new line character at the end of file
(new-line-at-end-of-file)
🔇 Additional comments (10)
experiments/benchmark_organsystems.yml (3)
14-15: Check short flag -tn and task name.
- Verify that -tn is a valid alias for task-name (run.py shows args.task_name but parser isn’t shown).
- Ensure "CassClassification" exists under configs/tasks/.gin.
Would you like me to scan the repo to confirm the parser flags and task file presence?
Also applies to: 8-10
25-25: --load_data_vars has no value — confirm it’s a boolean flag.
If the parser expects an argument (path or list), this will error. If it’s store_true, fine.
Point me to the argparse definition if you want me to validate and fix.
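For reference, the two possible parser behaviors look roughly like this; a hypothetical sketch, not the repo's build_parser, and the real flag definition may differ:

```python
# Hypothetical sketch of the two ways --load_data_vars could be defined.
import argparse

parser = argparse.ArgumentParser()
# Variant A: boolean switch; passing "--load_data_vars" with no value is fine.
parser.add_argument("--load_data_vars", action="store_true")
# Variant B: value-taking option; passing it bare would fail with "expected one argument".
# parser.add_argument("--load_data_vars", type=str)

print(parser.parse_args(["--load_data_vars"]))  # Namespace(load_data_vars=True) under variant A
```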
13-13: Confirm wandb sweep integration mode.
You’re passing --wandb-sweep and also using a sweep YAML. Ensure this runner expects both; some setups rely solely on W&B agent to inject args without an internal --wandb-sweep switch.
If redundant, drop the flag to avoid mode drift.
experiments/benchmark_cass_outcome_time.yml (2)
40-40: Confirm --load_data_vars is a boolean flag.
If it expects a value, passing it bare will error. If it’s a store_true flag, all good.
36-37: Drop the --file-names CLI flag to avoid conflicts with the OUTCOME sweep param.
The CLI pinning of
"OUTCOME":"outc_6.parquet"conflicts with the sweep overOUTCOME={outc_1,outc_2}. Remove these lines at 36–37 (and the analogous block at 66–76) so the sweep parameter solely controls file names:- - --file-names - - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"'icu_benchmarks/models/__init__.py (1)
34-35: XGBEnsembleClassifier.fit_model can raise UnboundLocalError on failure.
The referenced implementation returns val_loss in the except block without initialization, causing UnboundLocalError if training fails. Also logs then swallows the error, which hides hard failures.

Patch in icu_benchmarks/models/ml_models/imblearn.py:

 def fit_model(self, train_data, train_labels, val_data, val_labels):
-    import xgboost as xgb
-    from sklearn.metrics import log_loss
+    from sklearn.metrics import log_loss
+    val_loss = float("inf")
     # Try XGBoost first with provided config
     try:
-        self.model.fit(train_data, train_labels,)
-
+        self.model.fit(train_data, train_labels)
         val_pred_proba = self.model.predict_proba(val_data)
         val_loss = log_loss(val_labels, val_pred_proba)
         logging.info(f"XGBoost model trained successfully. Validation loss: {val_loss:.4f}")
-        return val_loss
     except Exception as e:
-        logging.warning(f"XGBoost failed: {e}")
-
-    return val_loss
+        logging.warning(f"XGBoost failed during fit/eval: {e}")
+        # optionally: raise to fail fast
+        # raise
+    return val_loss

Likely an incorrect or invalid review comment.
configs/prediction_models/RUSBClassifier.gin (4)
18-20: Validate sampling_strategy values against current imbalanced-learn.
"auto", "majority", "not minority" are valid historically, but versions differ. Given deps were bumped, confirm these map 1:1 in the installed imbalanced-learn.
21-25: Check param namespace for base estimator (base_estimator__*) after sklearn/imbalanced-learn renames.
Recent versions migrated from base_estimator -> estimator in several APIs. Ensure RUSBoostClassifier still exposes base_estimator__* rather than estimator__*, or your grid won’t bind.
26-27: Ensure replacement is a valid top-level RUSBoost parameter.
Depending on the version, replacement may belong to the internal sampler; some releases surface it, others don’t. Verify it is accepted by RUSBoostClassifier in your pinned version; otherwise expose via a sampler object or drop it.
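One low-effort way to settle both of the questions above in the pinned environment is to list the estimator's accepted parameters; a small sketch, assuming imbalanced-learn is installed:

```python
# Sketch: inspect the constructor parameters RUSBoostClassifier exposes in the
# installed imbalanced-learn version, to confirm whether "base_estimator",
# "estimator" and "replacement" are valid keys for the tuning grid.
from imblearn.ensemble import RUSBoostClassifier

params = RUSBoostClassifier().get_params()
print(sorted(params))                            # all accepted top-level parameters
print("replacement" in params)                   # True if replacement is a top-level parameter
print([p for p in params if "estimator" in p])   # shows base_estimator vs estimator naming
```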
12-17: Tuple "(min, max, 'log')" is explicitly handled in Optuna binding (it checks for len == 3 and sets log=True if value[2] == "log"). Lists fall back to categorical sampling (using trial.suggest_categorical). Your config patterns match those handlers, so (0.01, 2.0, "log") and tuples are supported; no change needed.
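For readers unfamiliar with that binding, the described handling corresponds roughly to this Optuna pattern (a paraphrase for illustration, not the repo's actual tuning code):

```python
# Rough sketch of how a (low, high, "log") tuple vs. a plain tuple of choices is
# typically mapped onto Optuna suggestions; the repo's binding may differ in detail.
import optuna


def suggest_hyperparameter(trial: optuna.trial.Trial, name: str, value):
    if isinstance(value, tuple) and len(value) == 3 and value[2] == "log":
        return trial.suggest_float(name, value[0], value[1], log=True)
    return trial.suggest_categorical(name, list(value))


def objective(trial):
    lr = suggest_hyperparameter(trial, "learning_rate", (0.005, 1, "log"))
    n_estimators = suggest_hyperparameter(trial, "n_estimators", (10, 50, 100, 200, 500))
    return lr * n_estimators  # dummy objective, only to make the sketch runnable


optuna.create_study().optimize(objective, n_trials=3)
```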
train_common.model = @RUSBClassifier

# Hyperparameter tuning configuration
model/hyperparameter.class_to_tune = @RUSBClassifier
model/hyperparameter.n_estimators = (10, 50, 100, 200, 500)
model/hyperparameter.learning_rate = (0.005, 1, "log")
model/hyperparameter.sampling_strategy = "auto"
💡 Verification agent
🧩 Analysis chain
Align model name with public API choice (RUSB vs XGBEnsemble)
You removed RUSBClassifier from the public API but this config still binds train_common.model and class_to_tune to RUSBClassifier. Choose one:
- Keep RUSBClassifier supported (re-export in init.py as suggested), or
- Migrate this config to XGBEnsembleClassifier and rename the file accordingly.
If migrating:
-train_common.model = @RUSBClassifier
-model/hyperparameter.class_to_tune = @RUSBClassifier
+train_common.model = @XGBEnsembleClassifier
+model/hyperparameter.class_to_tune = @XGBEnsembleClassifier

🏁 Script executed:
#!/usr/bin/env bash
set -e
echo "Finding RUSBClassifier class definitions..."
rg -nP "class\s+RUSBClassifier\b" .
echo "Searching for RUSBClassifier exports in __init__.py..."
rg -n "RUSBClassifier" -g "**/__init__.py" -n
echo "Finding XGBEnsembleClassifier class definitions..."
rg -nP "class\s+XGBEnsembleClassifier\b" .
echo "Searching for XGBEnsembleClassifier references..."
rg -n "XGBEnsembleClassifier" -g "**/*.py" -n

Length of output: 143
🏁 Script executed:
#!/usr/bin/env bash
set -e
echo "Searching for any occurrences of RUSBClassifier..."
rg -n "RUSBClassifier" .
echo "Searching for any occurrences of XGBEnsembleClassifier..."
rg -n "XGBEnsembleClassifier" .
echo "Searching for filenames containing RUSBClassifier or XGBEnsembleClassifier..."
fd --hidden --no-ignore RUSBClassifier
fd --hidden --no-ignore XGBEnsembleClassifier

Length of output: 2022
Align config with public API—migrate or re-export RUSBClassifier
configs/prediction_models/RUSBClassifier.gin still references the removed RUSBClassifier class; either re-export RUSBClassifier in icu_benchmarks/models/init.py or update this config (and filename) to use XGBEnsembleClassifier.
If migrating:
-train_common.model = @RUSBClassifier
-model/hyperparameter.class_to_tune = @RUSBClassifier
+train_common.model = @XGBEnsembleClassifier
+model/hyperparameter.class_to_tune = @XGBEnsembleClassifier

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
train_common.model = @RUSBClassifier
# Hyperparameter tuning configuration
model/hyperparameter.class_to_tune = @RUSBClassifier
model/hyperparameter.n_estimators = (10, 50, 100, 200, 500)
model/hyperparameter.learning_rate = (0.005, 1, "log")
model/hyperparameter.sampling_strategy = "auto"
train_common.model = @XGBEnsembleClassifier
# Hyperparameter tuning configuration
model/hyperparameter.class_to_tune = @XGBEnsembleClassifier
🤖 Prompt for AI Agents
In configs/prediction_models/RUSBClassifier.gin around lines 7 to 11, the config
references the removed RUSBClassifier; either re-export RUSBClassifier from
icu_benchmarks/models/__init__.py or update this config (and its filename) to
use the current public API XGBEnsembleClassifier. If you choose re-exporting,
add an import and export of RUSBClassifier in icu_benchmarks/models/__init__.py
so existing gin configs keep working. If you choose migration, rename the file
to match the new model (e.g., XGBEnsembleClassifier.gin), replace both
occurrences of RUSBClassifier with XGBEnsembleClassifier, and verify any
hyperparameter names match the XGBEnsembleClassifier API.
  - --file-names
  - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"'
🛠️ Refactor suggestion
file_names appears malformed (missing braces) — likely to break parse_dict.
The dictionary-like string is missing outer {} both in the command and in parameters. Most dict parsers (incl. typical CLI "parse_dict") expect a braced JSON-like payload.
Apply this diff:
- - --file-names
- - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"'
+ - --file-names
+ - '{"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"}'

And for parameters:

- - '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"'
+ - '{"DYNAMIC":"dyn.parquet","OUTCOME":"outc_9.parquet","STATIC":"sta.parquet"}'

Also applies to: 53-55
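As a sanity check on why the braces matter: a braced payload is valid JSON-style input, the brace-less one is not. A minimal illustration, assuming the CLI-side parser behaves like json.loads (the repo's parse_dict may differ in detail):

```python
# Illustration only: a braced payload parses cleanly, the brace-less one is rejected.
import json

braced = '{"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"}'
print(json.loads(braced))  # {'DYNAMIC': 'dyn.parquet', 'OUTCOME': 'outc_6.parquet', 'STATIC': 'sta.parquet'}

unbraced = '"DYNAMIC":"dyn.parquet","OUTCOME":"outc_6.parquet","STATIC":"sta.parquet"'
try:
    json.loads(unbraced)
except json.JSONDecodeError as err:
    print(f"rejected without braces: {err}")
```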
🤖 Prompt for AI Agents
In experiments/benchmark_organsystems.yml around lines 21-22 (and similarly
lines 53-55) the file_names value is a dictionary-like string missing outer
braces which will break parse_dict; wrap the key/value list with curly braces so
it becomes a proper dict string (e.g. add leading "{" and trailing "}" around
the quoted pairs) in both the command and parameter sections, ensuring quoting
remains valid for the CLI parser.
from icu_benchmarks.models.dl_models.transformer import BaseTransformer, LocalTransformer, Transformer
from icu_benchmarks.models.ml_models.catboost import CBClassifier
from icu_benchmarks.models.ml_models.imblearn import BRFClassifier, RUSBClassifier
from icu_benchmarks.models.ml_models.imblearn import BRFClassifier, XGBEnsembleClassifier
💡 Verification agent
🧩 Analysis chain
Re-export RUSBClassifier to avoid breaking configs referencing @RUSBClassifier
configs/prediction_models/RUSBClassifier.gin still points to RUSBClassifier; removing it from the public API will likely break resolution if the gin loader relies on from icu_benchmarks.models import *. Re-expose RUSBClassifier alongside XGBEnsembleClassifier, or update the configs to the new class. Keeping both exported is the lowest-risk path.
Apply:
-from icu_benchmarks.models.ml_models.imblearn import BRFClassifier, XGBEnsembleClassifier
+from icu_benchmarks.models.ml_models.imblearn import BRFClassifier, RUSBClassifier, XGBEnsembleClassifier

Run to find remaining references that rely on the top-level export:
🏁 Script executed:
#!/bin/bash
rg -nC2 -S '\bRUSBClassifier\b' configs/ icu_benchmarks/ | sed -n '1,200p'

Length of output: 640
Re-export RUSBClassifier in icu_benchmarks/models/init.py
Add RUSBClassifier alongside existing imports to preserve compatibility with configs/prediction_models/RUSBClassifier.gin:
-from icu_benchmarks.models.ml_models.imblearn import BRFClassifier, XGBEnsembleClassifier
+from icu_benchmarks.models.ml_models.imblearn import BRFClassifier, RUSBClassifier, XGBEnsembleClassifier

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
from icu_benchmarks.models.ml_models.imblearn import BRFClassifier, XGBEnsembleClassifier
from icu_benchmarks.models.ml_models.imblearn import BRFClassifier, RUSBClassifier, XGBEnsembleClassifier
🤖 Prompt for AI Agents
In icu_benchmarks/models/__init__.py around line 7, the module re-exports only
BRFClassifier and XGBEnsembleClassifier but needs to also expose RUSBClassifier
for compatibility with configs/prediction_models/RUSBClassifier.gin; update the
import statement to include RUSBClassifier from
icu_benchmarks.models.ml_models.imblearn so it is re-exported by this package
(ensure the symbol name matches the implementation in ml_models/imblearn).
"LocalTransformer",
"CBClassifier",
"RUSBClassifier",
"XGBEnsembleClassifier",
🛠️ Refactor suggestion
Add RUSBClassifier back to __all__ for star-import/generation and docs
This restores backward-compatible public exposure used by configs and (likely) Sphinx autodoc.
__all__ = [
"GRUNet",
"RNNet",
"LSTMNet",
"TemporalConvNet",
"BaseTransformer",
"Transformer",
"LocalTransformer",
"CBClassifier",
"XGBEnsembleClassifier",
"BRFClassifier",
+ "RUSBClassifier",
"LGBMClassifier",
"LGBMRegressor",
"XGBClassifier",📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
"XGBEnsembleClassifier",
__all__ = [
    "GRUNet",
    "RNNet",
    "LSTMNet",
    "TemporalConvNet",
    "BaseTransformer",
    "Transformer",
    "LocalTransformer",
    "CBClassifier",
    "XGBEnsembleClassifier",
    "BRFClassifier",
    "RUSBClassifier",
    "LGBMClassifier",
    "LGBMRegressor",
    "XGBClassifier",
    # ... other entries ...
]
🤖 Prompt for AI Agents
In icu_benchmarks/models/__init__.py around line 60, the exported names list
__all__ no longer includes "RUSBClassifier"; add the string "RUSBClassifier"
back into the __all__ list (as an entry with proper quotes and a trailing comma)
so star-imports, configs and Sphinx autodoc continue to expose that public
symbol.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
icu_benchmarks/data/split_process_data.py (4)
169-176: Modality selection result is dropped (assigned to data, not sanitized_data).
You pass sanitized_data in but assign to data, so subsequent steps ignore the filtered frames.

- data, vars = modality_selection(sanitized_data, modality_mapping, selected_modalities, vars)
+ sanitized_data, vars = modality_selection(sanitized_data, modality_mapping, selected_modalities, vars)

Also applies to: 173-173
203-205: Use sanitized data for full-train path.
The complete-train branch bypasses sanitation/modality filtering by using data.

- sanitized_data = make_train_val_polars(data, vars, train_size=None, seed=seed, debug=debug, runmode=runmode)
+ sanitized_data = make_train_val_polars(sanitized_data, vars, train_size=None, seed=seed, debug=debug, runmode=runmode)
265-268: Cache the processed artifact, not the raw input.
caching(..., data, ...) writes the unprocessed dict, inconsistent with load_cache expectations.

- caching(cache_dir, cache_file, data, load_cache)
+ caching(cache_dir, cache_file, sanitized_data, load_cache)
621-629: Missing random_state compromises reproducibility.
Add random_state=seed to all Shuffle/StratifiedShuffle splitters.

- outer_cv = StratifiedShuffleSplit(cv_repetitions, train_size=train_size)
+ outer_cv = StratifiedShuffleSplit(cv_repetitions, train_size=train_size, random_state=seed)
@@
- inner_cv = StratifiedShuffleSplit(cv_folds, train_size=0.75, random_state=seed)
+ inner_cv = StratifiedShuffleSplit(cv_folds, train_size=0.75, random_state=seed)
@@
- outer_cv = ShuffleSplit(cv_repetitions, train_size=train_size)
+ outer_cv = ShuffleSplit(cv_repetitions, train_size=train_size, random_state=seed)

Also applies to: 636-640
♻️ Duplicate comments (12)
icu_benchmarks/data/split_process_data.py (4)
285-315: Local check_sanitize_data diverges from utils (drops infinite-removal and column pruning).This regression removes
infinite_removal(...)and HRV column filtering present in utils, changing pipeline semantics.Either delete this override to use utils, or feature-parity it with utils.
-def check_sanitize_data(data: dict[str, pl.DataFrame], vars: dict[str, str | list[str]]) -> dict[str, pl.DataFrame]: # noqa: F811 +def check_sanitize_data(data: dict[str, pl.DataFrame], vars: dict[str, str | list[str]]) -> dict[str, pl.DataFrame]: # noqa: F811 @@ - if DataSegment.dynamic in data.keys(): + if DataSegment.dynamic in data.keys(): old_len = len(data[DataSegment.dynamic]) - - data[DataSegment.dynamic] = data[DataSegment.dynamic].unique( + data[DataSegment.dynamic] = data[DataSegment.dynamic].unique( subset=flatten_column_names(group, sequence), keep=keep, maintain_order=True ) + # Parity with utils.check_sanitize_data + from icu_benchmarks.data.utils import infinite_removal + data[DataSegment.dynamic] = infinite_removal(data[DataSegment.dynamic]) + data[DataSegment.dynamic] = data[DataSegment.dynamic].select(~cs.starts_with("wearable_ppgfeature_HRV_SampEn")) + vars[DataSegment.dynamic] = [c for c in vars.get(DataSegment.dynamic, []) if c in data[DataSegment.dynamic].columns]Or remove this function and rely on the imported one.
317-357: Redundant modality_selection override; prefer shared implementation.Duplicating logic increases drift risk. If specialization is needed, rename and document delta; otherwise, drop this override.
-def modality_selection( # noqa: F811 +def modality_selection( # noqa: F811 @@ - return data, vars + return data, varsRecommendation: delete this function and import from utils, or rename to
_modality_selection_localand only call it when needed.
23-23: Remove import or local redefinitions to avoid F811/behavior drift.You import
check_sanitize_data, modality_selectionbut later redefine both (see Lines 285–315 and 317–357). This shadows the imports and changes behavior from utils (e.g., missing infinite-removal), risking regressions.Apply one of:
- Prefer utils: delete the local function defs.
- Or keep locals: remove this import and rename functions to avoid confusion.
-from .utils import check_sanitize_data, modality_selection +from .utils import check_sanitize_data, modality_selection +# If keeping local versions, remove this import and rename locals to e.g. +# _check_sanitize_data_local, _modality_selection_local
91-91: Restore “SEQUENCE” to required_var_types (or guard all downstream uses).Defaults exclude
"SEQUENCE", but the code expects it (e.g., modality selection, splitting, reduction) and will raise at runtime when missing.- if required_var_types is None: - required_var_types = ["GROUP", "LABEL"] + if required_var_types is None: + required_var_types = ["GROUP", "LABEL", "SEQUENCE"]icu_benchmarks/models/wrappers.py (8)
405-410: Remove duplicate variable assignmentThe
valid_idxvariable is computed twice with identical logic on lines 406 and 409.# valid_idx is (batch, seq_len) boolean mask valid_idx = mask.detach().cpu().numpy().astype(bool) # indicators is (batch, seq_len, ...) or (batch, seq_len, 2) - indicators_np = indicators.detach().cpu().numpy()[valid_idx] - predictors = prediction.detach().cpu() - target_np = target.detach().cpu().numpy() + indicators_np = indicators.detach().cpu().numpy()[valid_idx] + predictors = prediction.detach().cpu() + target_np = target.detach().cpu().numpy()
50-52: Debug flag should default toFalsefor productionThe debug flag is set to
Trueby default, which causes model outputs to be saved automatically. This is inconsistent with typical CLI defaults (which useFalse) and can lead to unexpected disk usage in production environments.Apply this diff to fix the default value:
- debug = True + debug = False # We do not want to explain features by default as it is expensive (and not needed during hp tuning) explain_features = FalseAdditionally, ensure that the debug flag is properly propagated from runtime arguments when models are instantiated.
155-155: Fix logging inconsistencySame logging issue as above - the message refers to "row_indicators.csv" instead of "pred_indicators.csv".
- logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / 'row_indicators.csv'}") + logging.debug(f"Saved pred indicators to {Path(self.logger.save_dir) / 'pred_indicators.csv'}")
451-451: Guard_save_model_outputscall with debug flagThe
_save_model_outputsmethod is called unconditionally inDLPredictionWrapper.step_fn, which could cause unnecessary I/O operations.# Save predictions to file - self._save_model_outputs(indicators_np, transformed_predictors, target_np) + if self.debug: + self._save_model_outputs(indicators_np, transformed_predictors, target_np)
117-134: Fix logging message and format string mismatchLine 134 logs "row_indicators.csv" but the actual file saved is "pred_indicators.csv". Additionally, the format string needs to handle both classification and regression cases correctly.
- logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / 'row_indicators.csv'}") + logging.debug(f"Saved pred indicators to {Path(self.logger.save_dir) / 'pred_indicators.csv'}")
156-167: Fix logging and remove extraneous f-string prefixLine 167 has an unnecessary nested f-string and incorrect file reference.
- logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / f'row_indicators.csv'}") + logging.debug(f"Saved pred indicators to {Path(self.logger.save_dir) / 'pred_indicators.csv'}")
646-681: Fix logging issue for pre-calibration lossThe "Original val loss" is computed after the model is already replaced, showing calibrated predictions instead of the original uncalibrated ones.
Store the original predictions before calibration:
def setup_calibration(self, val_data, val_labels, method='sigmoid'): """ Setup model calibration using validation data for better probability estimates. Args: val_data: Validation features for calibration val_labels: Validation labels for calibration method: 'isotonic' or 'sigmoid' calibration method Returns: float: Validation loss after calibration """ if self.run_mode != RunMode.classification: logging.warning("Calibration only supported for classification tasks") return None + + # Store original predictions before calibration + original_val_pred = self.model.predict_proba(val_data) + original_val_loss = self.loss(val_labels, original_val_pred) logging.info(f"Applying {method} calibration using validation data") # Create calibrated version using validation data as holdout set self.calibrated_model = CalibratedClassifierCV( self.model, method=method, cv='prefit' # Use prefit model with holdout validation set ) self.calibrated_model.fit(val_data, val_labels) # Calculate calibrated validation score cal_val_pred = self.calibrated_model.predict_proba(val_data) cal_val_loss = self.loss(val_labels, cal_val_pred) logging.info(f"Calibration complete. " - f"Original val loss: {self.loss(val_labels, self.model.predict_proba(val_data)):.4f}, " + f"Original val loss: {original_val_loss:.4f}, " f"Calibrated val loss: {cal_val_loss:.4f}")
465-559: Handle multiclass limitation for isotonic calibrationIsotonic calibration is not supported for multiclass problems. The code should detect this and either fallback to sigmoid or raise a clear warning.
Add multiclass detection in the
setup_calibrationmethod:def setup_calibration(self, val_loader, method='isotonic'): """ Setup model calibration using validation data for better probability estimates. Args: val_loader: Validation dataloader for calibration method: 'isotonic' or 'sigmoid' calibration method Returns: float: Validation loss after calibration """ if self.run_mode != RunMode.classification: logging.warning("Calibration only supported for classification tasks") return None + + # Check if multiclass and isotonic + if hasattr(self, 'logit') and self.logit.out_features > 2 and method == 'isotonic': + logging.warning("Isotonic calibration not supported for multiclass; falling back to sigmoid") + method = 'sigmoid'
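For context, the prefit holdout-calibration pattern these comments rely on looks roughly like this in scikit-learn; a standalone sketch with synthetic data, not the wrapper code itself (note that cv="prefit" is deprecated in newer scikit-learn releases but still widely available):

```python
# Standalone sketch of holdout calibration with an already-fitted model.
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=2000, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.3, random_state=0)

model = LogisticRegression(max_iter=1000).fit(X_train, y_train)
# Capture the uncalibrated validation loss before wrapping the model.
uncalibrated_loss = log_loss(y_val, model.predict_proba(X_val))

# cv="prefit" keeps the fitted model and only fits the calibration map on the holdout set.
calibrated = CalibratedClassifierCV(model, method="sigmoid", cv="prefit").fit(X_val, y_val)
calibrated_loss = log_loss(y_val, calibrated.predict_proba(X_val))

print(f"val loss before calibration: {uncalibrated_loss:.4f}, after: {calibrated_loss:.4f}")
```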
🧹 Nitpick comments (7)
icu_benchmarks/data/split_process_data.py (5)
120-135: exclude_preproc typing should be Optional.Use Optional in the signature per PEP 484 and the Ruff hint; also coerce to set to deduplicate.
- exclude_preproc: list[str] = None, + exclude_preproc: Optional[list[str]] = None, @@ - vars_to_exclude = [] + vars_to_exclude: list[str] = [] @@ - for modality in exclude_preproc: + for modality in exclude_preproc or []: if modality in modality_mapping: vars_to_exclude.extend(modality_mapping.get(modality)) @@ - logging.info( - f"Excluding modalities in {exclude_preproc}. Total vars excluded from preprocessing: {len(vars_to_exclude)}" - ) + vars_to_exclude = list(set(vars_to_exclude)) + logging.info(f"Excluding modalities in {exclude_preproc}. Total vars excluded: {len(vars_to_exclude)}")
512-512: Inconsistent train_size defaults across helpers.
make_single_split_pandasdefaults to 0.80, whilemake_single_split_polarsandmake_single_splituseNone. Align for predictable behavior and docs.-def make_single_split_polars( +def make_single_split_polars( @@ - train_size: Optional[float] = None, + train_size: Optional[float] = 0.80,And/or set wrapper default to 0.80.
Also applies to: 591-591, 674-674
841-843: Minor: use iterable unpacking in concat (RUF005).Slightly cleaner; no behavior change.
- data[DataSegment.dynamic] = pl.concat([unchanged_dynamic] + reduced_dynamic_parts).sort( + data[DataSegment.dynamic] = pl.concat([unchanged_dynamic, *reduced_dynamic_parts]).sort( [group_var, sequence_var]) @@ - data[DataSegment.outcome] = pl.concat([unchanged_outcome] + reduced_outcome_parts).sort( + data[DataSegment.outcome] = pl.concat([unchanged_outcome, *reduced_outcome_parts]).sort( [group_var, sequence_var])Also applies to: 863-865
158-159: Style: long exception messages.Ruff TRY003 flags long messages. Consider shorter messages or custom exception types, but this is non-blocking.
Also applies to: 767-768
211-221: Log message says “Dropping columns with nulls” but no drop occurs.Adjust wording to avoid confusion.
- logging.debug("Dropping columns with nulls") + logging.debug("Columns with remaining nulls (no drop performed):")icu_benchmarks/models/wrappers.py (2)
694-700: Remove duplicate calibration flag assignmentLines 694 and 697 both set
calibrate = True, which is redundant.- calibrate = True - method = "sigmoid" - # Apply calibration if desired calibrate = True + method = "sigmoid" if calibrate and self.run_mode == RunMode.classification:
765-766: Clean up commented codeLine 765 contains incomplete commented code that should be removed.
test_pred = self.predict(test_rep) - # test_pred_uncalibrated = self.predict(test_rep, use_calibrated=False)False self.compare_rankings(test_rep, test_label)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- icu_benchmarks/data/split_process_data.py (15 hunks)
- icu_benchmarks/models/wrappers.py (16 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
icu_benchmarks/data/split_process_data.py (4)
icu_benchmarks/constants.py (1)
RunMode(4-7)icu_benchmarks/data/preprocessor.py (6)
PolarsClassificationPreprocessor(90-268)Preprocessor(36-87)PolarsClassificationPreprocessor(71-246)apply(563-591)apply(102-176)apply(355-402)icu_benchmarks/data/constants.py (3)
DataSegment(7-11)DataSplit(1-4)VarType(14-17)icu_benchmarks/data/utils.py (2)
check_sanitize_data(26-51)modality_selection(54-81)
icu_benchmarks/models/wrappers.py (3)
icu_benchmarks/models/utils.py (2)
log_single_metric_to_file(290-315)name(159-160)icu_benchmarks/constants.py (1)
RunMode(4-7)icu_benchmarks/models/ml_models/lgbm.py (1)
predict(36-46)
🪛 Ruff (0.12.2)
icu_benchmarks/data/split_process_data.py
35-35: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
158-158: Avoid specifying long messages outside the exception class
(TRY003)
767-767: Avoid specifying long messages outside the exception class
(TRY003)
841-841: Consider [unchanged_dynamic, *reduced_dynamic_parts] instead of concatenation
Replace with [unchanged_dynamic, *reduced_dynamic_parts]
(RUF005)
863-863: Consider [unchanged_outcome, *reduced_outcome_parts] instead of concatenation
Replace with [unchanged_outcome, *reduced_outcome_parts]
(RUF005)
icu_benchmarks/models/wrappers.py
167-167: f-string without any placeholders
Remove extraneous f prefix
(F541)
447-447: Avoid specifying long messages outside the exception class
(TRY003)
495-495: Create your own exception
(TRY002)
495-495: Avoid specifying long messages outside the exception class
(TRY003)
818-818: Unused method argument: pred_indicators
(ARG002)
🔇 Additional comments (3)
icu_benchmarks/models/wrappers.py (3)
114-116: Good addition of feature explanation placeholderThe
_explain_modelmethod provides a clear interface for subclasses to implement feature explanation functionality with an appropriate NotImplementedError message.
786-798: Well-implemented calibration fallback logicThe
predictmethod properly handles both calibrated and uncalibrated prediction paths with appropriate fallbacks for both classification and regression modes.
818-829: Good implementation of metric logging to CSVThe
log_curvesmethod properly exports metrics to CSV files using thelog_single_metric_to_fileutility, providing valuable persistence for analysis.
| logging.debug(f"Checking for NaNs and nulls in {data.keys()}.") | ||
| for _dict in sanitized_data.values(): | ||
| for key, val in _dict.items(): | ||
| for key, dataframe in _dict.items(): | ||
| logging.debug(f"Data type: {key}") | ||
| logging.debug("Is NaN:") | ||
| sel = _dict[key].select(pl.selectors.numeric().is_nan().max()) | ||
| logging.debug(sel.select(col.name for col in sel if col.item(0))) | ||
| logging.debug("Has nulls:") | ||
| sel = _dict[key].select(pl.all().has_nulls()) | ||
| logging.debug(sel.select(col.name for col in sel if col.item(0))) | ||
| _dict[key] = val.fill_null(strategy="zero") | ||
| _dict[key] = val.fill_nan(0) | ||
| _dict[key] = dataframe.fill_null(strategy="zero") | ||
| _dict[key] = dataframe.fill_nan(0) | ||
| logging.debug("Dropping columns with nulls") | ||
| sel = _dict[key].select(pl.all().has_nulls()) | ||
| logging.debug(sel.select(col.name for col in sel if col.item(0))) | ||
| logging.debug("Checking for infinite values.") | ||
| for col in dataframe.select(cs.numeric()).columns: | ||
| if dataframe[col].is_infinite().any(): | ||
| logging.warning(f"Column '{col}' contains infinite values. Datatype: {dataframe[col].dtype}") | ||
|
|
||
| max_float64 = 0 | ||
| # Replace infinite values with the maximum value for float64 | ||
| dataframe = dataframe.with_columns( | ||
| [ | ||
| pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col) | ||
| for col in dataframe.columns | ||
| if dataframe[col].dtype == pl.Float64 | ||
| ] | ||
| ) | ||
| _dict[key] = dataframe | ||
| logging.debug(f"Amount of columns: {len(dataframe.columns)}") | ||
| logging.info(f"{len(sanitized_data[DataSplit.train][DataSegment.features].columns)} columns in dynamic data.") |
Fix NaN/null/Inf handling; current logic loses earlier fills and replaces Inf with 0.
- Uses stale
dataframeafter assigning to_dict[key], discarding previous fills. - Replaces Inf with
0and only in Float64 columns.
- logging.debug(f"Checking for NaNs and nulls in {data.keys()}.")
+ logging.debug(f"Checking for NaNs and nulls.")
@@
- for key, dataframe in _dict.items():
+ for key, dataframe in _dict.items():
@@
- _dict[key] = dataframe.fill_null(strategy="zero")
- _dict[key] = dataframe.fill_nan(0)
+ # Fill nulls/NaNs in numeric columns
+ dataframe = dataframe.fill_null(strategy="zero").fill_nan(0)
+ _dict[key] = dataframe
@@
- logging.debug("Checking for infinite values.")
- for col in dataframe.select(cs.numeric()).columns:
- if dataframe[col].is_infinite().any():
- logging.warning(f"Column '{col}' contains infinite values. Datatype: {dataframe[col].dtype}")
+ logging.debug("Checking for infinite values.")
+ for col in dataframe.select(cs.numeric()).columns:
+ if dataframe[col].is_infinite().any():
+ logging.warning(f"Column '{col}' contains infinite values. Datatype: {dataframe[col].dtype}")
@@
- max_float64 = 0
- # Replace infinite values with the maximum value for float64
- dataframe = dataframe.with_columns(
- [
- pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col)
- for col in dataframe.columns
- if dataframe[col].dtype == pl.Float64
- ]
- )
+ # Replace +/-inf in all numeric columns with max finite float64
+ import numpy as np
+ max_float64 = np.finfo(np.float64).max
+ num_cols = dataframe.select(cs.numeric()).columns
+ dataframe = dataframe.with_columns([
+ pl.when(pl.col(c).is_infinite()).then(pl.lit(max_float64)).otherwise(pl.col(c)).alias(c)
+ for c in num_cols
+ ])
_dict[key] = dataframe
logging.debug(f"Amount of columns: {len(dataframe.columns)}")📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| logging.debug(f"Checking for NaNs and nulls in {data.keys()}.") | |
| for _dict in sanitized_data.values(): | |
| for key, val in _dict.items(): | |
| for key, dataframe in _dict.items(): | |
| logging.debug(f"Data type: {key}") | |
| logging.debug("Is NaN:") | |
| sel = _dict[key].select(pl.selectors.numeric().is_nan().max()) | |
| logging.debug(sel.select(col.name for col in sel if col.item(0))) | |
| logging.debug("Has nulls:") | |
| sel = _dict[key].select(pl.all().has_nulls()) | |
| logging.debug(sel.select(col.name for col in sel if col.item(0))) | |
| _dict[key] = val.fill_null(strategy="zero") | |
| _dict[key] = val.fill_nan(0) | |
| _dict[key] = dataframe.fill_null(strategy="zero") | |
| _dict[key] = dataframe.fill_nan(0) | |
| logging.debug("Dropping columns with nulls") | |
| sel = _dict[key].select(pl.all().has_nulls()) | |
| logging.debug(sel.select(col.name for col in sel if col.item(0))) | |
| logging.debug("Checking for infinite values.") | |
| for col in dataframe.select(cs.numeric()).columns: | |
| if dataframe[col].is_infinite().any(): | |
| logging.warning(f"Column '{col}' contains infinite values. Datatype: {dataframe[col].dtype}") | |
| max_float64 = 0 | |
| # Replace infinite values with the maximum value for float64 | |
| dataframe = dataframe.with_columns( | |
| [ | |
| pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col) | |
| for col in dataframe.columns | |
| if dataframe[col].dtype == pl.Float64 | |
| ] | |
| ) | |
| _dict[key] = dataframe | |
| logging.debug(f"Amount of columns: {len(dataframe.columns)}") | |
| logging.info(f"{len(sanitized_data[DataSplit.train][DataSegment.features].columns)} columns in dynamic data.") | |
| logging.debug(f"Checking for NaNs and nulls.") | |
| for _dict in sanitized_data.values(): | |
| for key, dataframe in _dict.items(): | |
| logging.debug(f"Data type: {key}") | |
| logging.debug("Is NaN:") | |
| sel = _dict[key].select(pl.selectors.numeric().is_nan().max()) | |
| logging.debug(sel.select(col.name for col in sel if col.item(0))) | |
| logging.debug("Has nulls:") | |
| sel = _dict[key].select(pl.all().has_nulls()) | |
| logging.debug(sel.select(col.name for col in sel if col.item(0))) | |
| # Fill nulls/NaNs in numeric columns | |
| dataframe = dataframe.fill_null(strategy="zero").fill_nan(0) | |
| _dict[key] = dataframe | |
| logging.debug("Dropping columns with nulls") | |
| sel = _dict[key].select(pl.all().has_nulls()) | |
| logging.debug(sel.select(col.name for col in sel if col.item(0))) | |
| logging.debug("Checking for infinite values.") | |
| for col in dataframe.select(cs.numeric()).columns: | |
| if dataframe[col].is_infinite().any(): | |
| logging.warning(f"Column '{col}' contains infinite values. Datatype: {dataframe[col].dtype}") | |
| # Replace +/-inf in all numeric columns with max finite float64 | |
| import numpy as np | |
| max_float64 = np.finfo(np.float64).max | |
| num_cols = dataframe.select(cs.numeric()).columns | |
| dataframe = dataframe.with_columns([ | |
| pl.when(pl.col(c).is_infinite()).then(pl.lit(max_float64)).otherwise(pl.col(c)).alias(c) | |
| for c in num_cols | |
| ]) | |
| _dict[key] = dataframe | |
| logging.debug(f"Amount of columns: {len(dataframe.columns)}") | |
| logging.info(f"{len(sanitized_data[DataSplit.train][DataSegment.features].columns)} columns in dynamic data.") |
🤖 Prompt for AI Agents
In icu_benchmarks/data/split_process_data.py around lines 211-242, the code
fills nulls/NaNs into _dict[key] but then continues to use the older dataframe
variable (losing the fills) and replaces infinite values with 0 only for
Float64; fix by performing fills and Inf-replacement on the same dataframe
variable and assign back to _dict[key] once at the end. Specifically: use
dataframe = dataframe.fill_null(strategy="zero") and dataframe =
dataframe.fill_nan(0) (or equivalent) so subsequent checks operate on the
updated DataFrame; when replacing infinities compute a replacement per numeric
float column (either the column's finite max or the dtype max via numpy.finfo)
while ignoring inf/NaN, and apply the replacement to all float dtypes (Float32
and Float64) using dataframe.with_columns([...]); finally assign the fully
sanitized dataframe to _dict[key] and only then run final logging/column checks.
train_incidence = sanitized_data[DataSplit.test][DataSegment.outcome][vars[VarType.label]]
val_incidence = sanitized_data[DataSplit.val][DataSegment.outcome][vars[VarType.label]]
test_incidence = sanitized_data[DataSplit.test][DataSegment.outcome][vars[VarType.label]]
Wrong splits used for incidence stats (train uses test split).
Use the corresponding split for each metric.
- train_incidence = sanitized_data[DataSplit.test][DataSegment.outcome][vars[VarType.label]]
+ train_incidence = sanitized_data[DataSplit.train][DataSegment.outcome][vars[VarType.label]]📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
train_incidence = sanitized_data[DataSplit.test][DataSegment.outcome][vars[VarType.label]]
val_incidence = sanitized_data[DataSplit.val][DataSegment.outcome][vars[VarType.label]]
test_incidence = sanitized_data[DataSplit.test][DataSegment.outcome][vars[VarType.label]]
train_incidence = sanitized_data[DataSplit.train][DataSegment.outcome][vars[VarType.label]]
val_incidence = sanitized_data[DataSplit.val][DataSegment.outcome][vars[VarType.label]]
test_incidence = sanitized_data[DataSplit.test][DataSegment.outcome][vars[VarType.label]]
🤖 Prompt for AI Agents
In icu_benchmarks/data/split_process_data.py around lines 246 to 248, the
incidence stats use the wrong data splits (train_incidence is using
DataSplit.test); change train_incidence to read from DataSplit.train, keep
val_incidence using DataSplit.val, and test_incidence using DataSplit.test so
each metric pulls from its corresponding split.
else:
    raise ValueError(f"Run mode {self.run_mode} not yet supported. Please implement it.")
transformed_output = self.output_transform((prediction, target))
Remove duplicate output transform assignment
Line 448 reassigns transformed_output immediately after it was already assigned based on calibration status (lines 440-442).
else:
transformed_output = self.output_transform((prediction, target))
elif self.run_mode == RunMode.regression:
# Regression task
loss = self.loss(prediction[:, 0], target.float()) + aux_loss
+ transformed_output = self.output_transform((prediction, target))
else:
raise ValueError(f"Run mode {self.run_mode} not yet supported. Please implement it.")
- transformed_output = self.output_transform((prediction, target))📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
else:
    raise ValueError(f"Run mode {self.run_mode} not yet supported. Please implement it.")
transformed_output = self.output_transform((prediction, target))
if self.run_mode == RunMode.classification:
    # Classification task (with optional calibration)
    if self.calibrated:
        prediction = self.calibrator.predict_proba(raw_prediction)
        transformed_output = self.output_transform((prediction, target))
    else:
        transformed_output = self.output_transform((prediction, target))
elif self.run_mode == RunMode.regression:
    # Regression task
    loss = self.loss(prediction[:, 0], target.float()) + aux_loss
    transformed_output = self.output_transform((prediction, target))
else:
    raise ValueError(f"Run mode {self.run_mode} not yet supported. Please implement it.")
🧰 Tools
🪛 Ruff (0.12.2)
447-447: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In icu_benchmarks/models/wrappers.py around lines 446 to 448, there is a
duplicate assignment to transformed_output immediately after it was already set
based on calibration status; remove the second assignment (the one after the
run_mode branch) so transformed_output is only assigned once from the earlier
calibration-aware branch, or if the intent was to use a different input there,
replace the duplicate with the correct variable—ensure transformed_output is
computed exactly once and used thereafter.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
icu_benchmarks/data/split_process_data.py (2)
266-270: Cache the preprocessed data and honorgenerate_cacheCurrently writes raw
dataand gated byload_cache, so caches aren’t created whengenerate_cacheis True.- if generate_cache: - caching(cache_dir, cache_file, data, load_cache) + if generate_cache: + caching(cache_dir, cache_file, sanitized_data, use_cache=True)
287-317: Fix return type annotation to match actual returnFunction returns
(data, vars)but is annotated as a dict.-def check_sanitize_data(data: dict[str, pl.DataFrame], vars: dict[str, str | list[str]]) -> dict[str, pl.DataFrame]: # noqa: F811 +def check_sanitize_data( + data: dict[str, pl.DataFrame], vars: dict[str, str | list[str]] +) -> tuple[dict[str, pl.DataFrame], dict[str, str | list[str]]]: # noqa: F811
♻️ Duplicate comments (7)
icu_benchmarks/run_utils.py (3)
218-226: De-duplicate SHAP aggregation and log exceptions with tracebackThe two concat+write blocks are duplicates; also log with logging.exception to retain stack trace.
- if explain_features: - if explainer_values_test: - shap_values = pl.concat(explainer_values_test) - shap_values.write_parquet(log_dir / "aggregated_explainer_values.parquet") - try: - shap_values = pl.concat(explainer_values_test) - shap_values.write_parquet(log_dir / "aggregated_explainer_values.parquet") - except Exception as e: - logging.error(f"Error aggregating or writing SHAP values: {e}") + if explain_features and explainer_values_test: + try: + pl.concat(explainer_values_test).write_parquet(log_dir / "aggregated_explainer_values.parquet") + except Exception: + logging.exception("Error aggregating or writing explainer values")
233-235: Guard math.isnan for ints to avoid TypeError
math.isnan(int)raises TypeError.- if isinstance(score, (float, int)): - if math.isnan(score): + if isinstance(score, float): + if math.isnan(score): logging.warning(f"Score for metric {metric} is NaN, adding 0 instead.") score = 0
102-106: Fix collision handling (type error and weak uniqueness)
- String–int concat raises TypeError.
- Small range risks repeated collisions; add safer suffix and bounded retries.
- log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")) - while log_dir_run.exists(): - log_dir_run = log_dir_run.with_name(log_dir_run.name + random.randint(1, 10)) + log_dir_run = log_dir / datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f") + for _ in range(50): + if not log_dir_run.exists(): + break + log_dir_run = log_dir_run.with_name(f"{log_dir_run.name}_{random.randint(1, 1000)}") + else: + # Last resort: timestamp + random salt + log_dir_run = log_dir / f"{datetime.now().strftime('%Y-%m-%dT%H-%M-%S-%f')}_{random.randint(1000, 9999)}"icu_benchmarks/data/split_process_data.py (4)
248-250: Use the correct split for train incidence
train_incidencereads fromtest. Usetrain.- train_incidence = sanitized_data[DataSplit.test][DataSegment.outcome][vars[VarType.label]] + train_incidence = sanitized_data[DataSplit.train][DataSegment.outcome][vars[VarType.label]]
23-24: Avoid redefining imported names in the same moduleThis file imports
check_sanitize_dataandmodality_selectionbut also defines them later (F811). Drop the imports and use the local definitions.-from .utils import check_sanitize_data, modality_selection
91-93: Include SEQUENCE in required_var_typesDownstream code indexes
vars[VarType.sequence]; omitting it will raise KeyError.- if required_var_types is None: - required_var_types = ["GROUP", "LABEL"] + if required_var_types is None: + required_var_types = ["GROUP", "LABEL", "SEQUENCE"]
213-244: Fix NaN/Inf sanitization: operate on the updated frame and replace Inf with max float for all numeric cols
- Current code updates
_dict[key]but keeps transforming the staledataframe.- Replaces Inf with
0and only for Float64.- _dict[key] = dataframe.fill_null(strategy="zero") - _dict[key] = dataframe.fill_nan(0) + dataframe = dataframe.fill_null(strategy="zero").fill_nan(0) @@ - logging.debug("Checking for infinite values.") - for col in dataframe.select(cs.numeric()).columns: - if dataframe[col].is_infinite().any(): - logging.warning(f"Column '{col}' contains infinite values. Datatype: {dataframe[col].dtype}") + logging.debug("Checking for infinite values.") + for col in dataframe.select(cs.numeric()).columns: + if dataframe[col].is_infinite().any(): + logging.warning(f"Column '{col}' contains infinite values. Datatype: {dataframe[col].dtype}") @@ - max_float64 = 0 - # Replace infinite values with the maximum value for float64 - dataframe = dataframe.with_columns( - [ - pl.when(pl.col(col).is_infinite()).then(max_float64).otherwise(pl.col(col)).alias(col) - for col in dataframe.columns - if dataframe[col].dtype == pl.Float64 - ] - ) - _dict[key] = dataframe + # Replace +/-inf in all numeric columns with max finite float64 + import numpy as np + max_float64 = np.finfo(np.float64).max + num_cols = dataframe.select(cs.numeric()).columns + dataframe = dataframe.with_columns([ + pl.when(pl.col(c).is_infinite()).then(pl.lit(max_float64)).otherwise(pl.col(c)).alias(c) + for c in num_cols + ]) + _dict[key] = dataframeBased on learnings (Polars selectors).
🧹 Nitpick comments (6)
icu_benchmarks/models/ml_models/lgbm.py (3)
40-49: Make background sampling configurable to cap SHAP costCurrent default uses the full train set as background, which can be very slow. Add a configurable cap (e.g., background_size=1000) and sample without replacement.
- reduce_samples = False - if reduce_samples: - n_samples = min(1000, len(train_data)) - indices = np.random.choice(len(train_data), size=n_samples, replace=False) - else: - indices = np.arange(len(train_data)) + background_size = getattr(self, "background_size", 1000) + n_samples = min(background_size, len(train_data)) + indices = np.random.choice(len(train_data), size=n_samples, replace=False)
59-61: Drop redundant/ineffectiveverbosekw and rely onverbosity
set_model_argsfilters by constructor signature;verboseisn’t a valid init kw for scikit LGBM and gets ignored.verbosityis already set to -1.- self.model = self.set_model_args(lgbm.LGBMClassifier, *args, **kwargs, verbose=-1) + self.model = self.set_model_args(lgbm.LGBMClassifier, *args, **kwargs)
7-9: Remove unused imports
os,wandb, andwandb_lgbmaren’t used.-import os -import wandb -from wandb.integration.lightgbm import wandb_callback as wandb_lgbmicu_benchmarks/run_utils.py (2)
163-179: Replace prints and NumPy conversion with Polars opsUse logging and Polars min/max to avoid materializing full columns.
- print(f"Loading file: {file_path}") + logging.info(f"Loading file: {file_path}") @@ - original_ids = df.select(pl.col(id_column)).to_numpy().flatten() - print(f" Original ID range: {original_ids.min()} - {original_ids.max()}") + min_id, max_id = df.select( + pl.col(id_column).min().alias("min_id"), + pl.col(id_column).max().alias("max_id") + ).row(0) + logging.info(f" Original ID range: {min_id} - {max_id}") @@ - modified_ids = df.select(pl.col(id_column)).to_numpy().flatten() - print(f" Modified ID range: {modified_ids.min()} - {modified_ids.max()}") - print(f" Counter: {counter}") - print() + new_min_id, new_max_id = df.select( + pl.col(id_column).min().alias("min_id"), + pl.col(id_column).max().alias("max_id") + ).row(0) + logging.info(f" Modified ID range: {new_min_id} - {new_max_id}; counter={counter}")
240-249: Avoid shadowing built-inlistin comprehensionsRename to
valsfor readability and to avoid confusion.- averaged_scores = {metric: (mean(list)) for metric, list in list_scores.items()} + averaged_scores = {metric: mean(vals) for metric, vals in list_scores.items()} @@ - std_scores = {metric: (pstdev(list) / sqrt(len(list))) for metric, list in list_scores.items()} + std_scores = {metric: pstdev(vals) / sqrt(len(vals)) for metric, vals in list_scores.items()} @@ - confidence_interval = { - metric: (stats.t.interval(0.95, len(list) - 1, loc=mean(list), scale=stats.sem(list))) - for metric, list in list_scores.items() - } + confidence_interval = { + metric: stats.t.interval(0.95, len(vals) - 1, loc=mean(vals), scale=stats.sem(vals)) + for metric, vals in list_scores.items() + }icu_benchmarks/data/split_process_data.py (1)
820-867: Optional: Vectorize stay reduction for performance.
The per-stay loop filters each stay individually (O(stays × rows)). Consider using groupby + cumcount to slice rows without Python loops.
I can provide a vectorized Polars variant if desired.
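If that vectorization is picked up, one possible shape is a grouped head instead of the per-stay Python loop; a sketch with made-up column names (stay_id, max_len), not the repo's identifiers, using the group_by API of recent Polars:

```python
# Sketch: keep at most `max_len` rows per stay without looping over stays.
import polars as pl

df = pl.DataFrame({
    "stay_id": [1, 1, 1, 2, 2, 3],
    "value": [10, 11, 12, 20, 21, 30],
})
max_len = 2

# group_by(...).head(n) keeps the first n rows of each group; maintain_order=True
# preserves the order in which stays first appear.
reduced = df.group_by("stay_id", maintain_order=True).head(max_len)
print(reduced)
```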
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- icu_benchmarks/data/split_process_data.py (15 hunks)
- icu_benchmarks/models/ml_models/lgbm.py (4 hunks)
- icu_benchmarks/run_utils.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
icu_benchmarks/run_utils.py (2)
icu_benchmarks/utils.py (1)
parse_dict(5-22)icu_benchmarks/wandb_utils.py (1)
wandb_log(47-54)
icu_benchmarks/data/split_process_data.py (4)
icu_benchmarks/constants.py (1)
RunMode(4-7)icu_benchmarks/data/preprocessor.py (6)
PolarsClassificationPreprocessor(90-268)Preprocessor(36-87)PolarsClassificationPreprocessor(71-246)apply(563-591)apply(102-176)apply(355-402)icu_benchmarks/data/constants.py (3)
DataSegment(7-11)DataSplit(1-4)VarType(14-17)icu_benchmarks/data/utils.py (2)
check_sanitize_data(26-51)modality_selection(54-81)
icu_benchmarks/models/ml_models/lgbm.py (2)
icu_benchmarks/models/wrappers.py (4)
fit(683-714)fit(994-995)set_model_args(892-902)_explain_model(114-115)icu_benchmarks/models/ml_models/xgboost.py (2)
set_model_args(67-77)_explain_model(79-85)
🪛 Ruff (0.14.0)
icu_benchmarks/run_utils.py
104-104: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
188-188: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
225-225: Do not catch blind exception: Exception
(BLE001)
226-226: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
229-229: Loop control variable repetition not used within loop body
(B007)
230-230: Loop control variable fold not used within loop body
Rename unused fold to _fold
(B007)
icu_benchmarks/data/split_process_data.py
35-35: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
158-158: Avoid specifying long messages outside the exception class
(TRY003)
769-769: Avoid specifying long messages outside the exception class
(TRY003)
843-843: Consider [unchanged_dynamic, *reduced_dynamic_parts] instead of concatenation
Replace with [unchanged_dynamic, *reduced_dynamic_parts]
(RUF005)
865-865: Consider [unchanged_outcome, *reduced_outcome_parts] instead of concatenation
Replace with [unchanged_outcome, *reduced_outcome_parts]
(RUF005)
icu_benchmarks/models/ml_models/lgbm.py
77-77: Avoid specifying long messages outside the exception class
(TRY003)
# Redirect stderr and stdout to suppress warnings
stderr_backup = sys.stderr
stdout_backup = sys.stdout
sys.stderr = StringIO()
sys.stdout = StringIO()
try:
    self.model = self.model.fit(
        train_data,
        train_labels,
        eval_set=(val_data, val_labels),
        callbacks=callbacks,
    )
finally:
    sys.stderr = stderr_backup
    # to surpress: [Warning] No further splits with positive gain, best gain: -inf with the current config
    sys.stdout = stdout_backup
reduce_samples = False
Fix eval_set and avoid process-wide stdout/stderr hijack
- Pass eval_set as a list of tuples; the current tuple form can skip validation logging/metrics and leave best_score_ unset.
- Replace global sys redirection with scoped context managers to avoid hiding logs from other threads/process parts.
Apply:
+from contextlib import redirect_stdout, redirect_stderr
+from io import StringIO
@@
- # Redirect stderr and stdout to suppress warnings
- stderr_backup = sys.stderr
- stdout_backup = sys.stdout
- sys.stderr = StringIO()
- sys.stdout = StringIO()
- try:
- self.model = self.model.fit(
- train_data,
- train_labels,
- eval_set=(val_data, val_labels),
- callbacks=callbacks,
- )
- finally:
- sys.stderr = stderr_backup
- # to surpress: [Warning] No further splits with positive gain, best gain: -inf with the current config
- sys.stdout = stdout_backup
+ # Suppress LightGBM chatter locally
+ with redirect_stderr(StringIO()), redirect_stdout(StringIO()):
+ self.model = self.model.fit(
+ train_data,
+ train_labels,
+ eval_set=[(val_data, val_labels)],
+ callbacks=callbacks,
+ )Also applies to: 30-35
🤖 Prompt for AI Agents
In icu_benchmarks/models/ml_models/lgbm.py around lines 24-40 (and also apply
same fix to lines 30-35), the code currently passes eval_set as a single tuple
and hijacks sys.stdout/sys.stderr by reassigning them; change eval_set to a list
of tuples (e.g. [(val_data, val_labels)]) so LightGBM records validation metrics
and best_score_, and replace the manual sys reassignment with scoped context
managers (use contextlib.redirect_stdout and contextlib.redirect_stderr or an
equivalent suppression context) so output suppression is limited to the fit call
and does not affect other threads or parts of the process; ensure stdout/stderr
are automatically restored by the context managers.
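For reference, a self-contained sketch of the scoped-suppression pattern with eval_set passed as a list. The toy data and the bare LGBMClassifier here are illustrative stand-ins, not the repository's wrapper or training path:

from contextlib import redirect_stderr, redirect_stdout
from io import StringIO

import lightgbm as lgb
import numpy as np

# Toy data standing in for the wrapper's train/validation splits.
rng = np.random.default_rng(0)
train_data, train_labels = rng.normal(size=(200, 5)), rng.integers(0, 2, 200)
val_data, val_labels = rng.normal(size=(50, 5)), rng.integers(0, 2, 50)

model = lgb.LGBMClassifier(n_estimators=50)

# Scoped suppression: only output emitted inside the with-block is swallowed,
# and sys.stdout / sys.stderr are restored automatically on exit.
with redirect_stderr(StringIO()), redirect_stdout(StringIO()):
    model.fit(
        train_data,
        train_labels,
        eval_set=[(val_data, val_labels)],  # list of tuples, not a bare tuple
        callbacks=[lgb.early_stopping(10, verbose=False)],
    )

# With eval_set as a list, validation metrics and best_score_ are populated.
print(model.best_score_)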
    def _explain_model(self, reps, labels):
        if not hasattr(self.model, "feature_importances_"):
            raise ValueError("Model has not been fit yet. Call fit_model() before getting feature importances.")
        # feature_importances = self.model.feature_importances_
        shap_values = self.explainer.shap_values(reps, labels)
        # feature_importances = np.abs(shap_values).mean(axis=1)
        return shap_values
🛠️ Refactor suggestion | 🟠 Major
Use a robust fitted-model check
Checking for feature_importances_ is brittle. Prefer scikit-learn's check_is_fitted or LightGBM's booster_ attribute.
+from sklearn.utils.validation import check_is_fitted
@@
-        if not hasattr(self.model, "feature_importances_"):
-            raise ValueError("Model has not been fit yet. Call fit_model() before getting feature importances.")
+        check_is_fitted(self.model, attributes=["booster_"])

🧰 Tools
🪛 Ruff (0.14.0)
77-77: Avoid specifying long messages outside the exception class (TRY003)
🤖 Prompt for AI Agents
In icu_benchmarks/models/ml_models/lgbm.py around lines 75 to 81, the current
fitted-model check relies on the brittle attribute feature_importances_; replace it
with a robust check using sklearn.utils.validation.check_is_fitted(self.model)
(import it at the top) or, for LightGBM-specific safety, verify that the model
exposes a booster_ attribute (e.g., getattr(self.model, "booster_", None) is not
None); wrap the check in a try/except that raises the same ValueError message if
the model is not fitted, then proceed to compute shap_values and return them.
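For reference, a minimal standalone sketch of the check_is_fitted-based guard on a LightGBM estimator. The toy data and the bare LGBMClassifier are illustrative assumptions, not the wrapper class under review:

import lightgbm as lgb
import numpy as np
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted

# Toy data; any small binary-classification set works for the illustration.
X = np.random.default_rng(0).normal(size=(100, 4))
y = (X[:, 0] > 0).astype(int)

model = lgb.LGBMClassifier(n_estimators=20)

# Before fitting, the guard raises sklearn's NotFittedError with a clear message.
try:
    check_is_fitted(model, attributes=["booster_"])
except NotFittedError as err:
    print(f"Not fitted yet: {err}")

model.fit(X, y)

# After fitting, booster_ exists and the check passes silently.
check_is_fitted(model, attributes=["booster_"])
print("Fitted, booster_ available:", model.booster_ is not None)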
    # Process fold by fold across all iterations
    for fold_name in sorted(all_folds):
        for iteration in iterations:
            fold_path = iteration / fold_name
            if fold_path.exists():
                for file_path in sorted(fold_path.rglob(filename)):
                    print(f"Loading file: {file_path}")

                    # Load the CSV file
                    df = pl.read_csv(file_path)

                    # Check original ID range
                    original_ids = df.select(pl.col(id_column)).to_numpy().flatten()
                    print(f" Original ID range: {original_ids.min()} - {original_ids.max()}")

                    # Apply offset to prevent clashes
                    df = df.with_columns(pl.col(id_column) + counter * max_id)

                    # Check modified ID range
                    modified_ids = df.select(pl.col(id_column)).to_numpy().flatten()
                    print(f" Modified ID range: {modified_ids.min()} - {modified_ids.max()}")
                    print(f" Counter: {counter}")
                    print()
                    dataframes.append(df)
        counter += 1

    # Concatenate all DataFrames vertically
ID offset logic collides across iterations within the same fold
The counter increments once per fold rather than once per file, so files from different iterations of the same fold all receive the same offset.
- counter = 0
+ counter = 0
@@
-    for fold_name in sorted(all_folds):
-        for iteration in iterations:
+    for fold_idx, fold_name in enumerate(sorted(all_folds)):
+        for iter_idx, iteration in enumerate(iterations):
             fold_path = iteration / fold_name
             if fold_path.exists():
                 for file_path in sorted(fold_path.rglob(filename)):
-                    print(f"Loading file: {file_path}")
+                    logging.info(f"Loading file: {file_path}")
@@
-                    # Apply offset to prevent clashes
-                    df = df.with_columns(pl.col(id_column) + counter * max_id)
+                    # Apply offset to prevent clashes
+                    df = df.with_columns(pl.col(id_column) + counter * max_id)
@@
-                    dataframes.append(df)
-        counter += 1
+                    dataframes.append(df)
+                    counter += 1
In icu_benchmarks/run_utils.py around lines 157 to 183, the ID offset counter is
only incremented once per fold causing files from different iterations of the
same fold to get the same offset; move the counter increment to occur after each
file is processed (inside the file loop) or compute the offset from a unique
per-file index (e.g., iteration index + fold index) so each file receives a
distinct counter before applying df.with_columns(pl.col(id_column) + counter *
max_id); ensure the counter update happens immediately after appending the
modified df to dataframes so offsets do not collide.
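A small self-contained sketch of the per-file offset the prompt describes, using toy in-memory frames in place of the fold/iteration directory layout; the stay_id column name and max_id value are assumptions for illustration:

import polars as pl

# Toy stand-ins for CSVs loaded from different folds/iterations; each reuses IDs 1-3.
raw_frames = [
    pl.DataFrame({"stay_id": [1, 2, 3], "pred": [0.1, 0.4, 0.9]}),
    pl.DataFrame({"stay_id": [1, 2, 3], "pred": [0.2, 0.5, 0.8]}),
    pl.DataFrame({"stay_id": [1, 2, 3], "pred": [0.3, 0.6, 0.7]}),
]

max_id = 1_000  # upper bound on any original ID, so offsets cannot overlap
dataframes = []

# Increment the counter once per file, so every file gets a distinct offset
# even when two files come from different iterations of the same fold.
for counter, df in enumerate(raw_frames):
    dataframes.append(df.with_columns(pl.col("stay_id") + counter * max_id))

combined = pl.concat(dataframes, how="vertical")
print(combined)  # IDs 1-3, 1001-1003, 2001-2003: no clashes after concatenation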
Summary by CodeRabbit
New Features
Enhancements
Documentation
Chores