Xai #123

Changes from all commits (149 commits)
76c6097
had to comment out pytorch-cuda otherwise getting version error envi…
Jun 13, 2023
c7657c3
added how to set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1
Jun 13, 2023
cf8696e
cast float64 to float32 to make it work with MPS
Jun 13, 2023
c08ed2c
added components of tft
Jun 20, 2023
0a5caff
added tft
Jun 20, 2023
52d0daa
setting variables for tft and tuning
Jun 22, 2023
8e812a0
modifications to fit in the tft and comments
Jun 22, 2023
92404d5
cleaned gin file
Jun 22, 2023
bed9b87
changed output to logits to match num of classes
Jun 22, 2023
9bda769
added jit import
Jun 22, 2023
960320b
added quantile loss
Jun 22, 2023
3cb823d
added new loader to work with tft but still needs work
Jun 22, 2023
f823e29
added dataloader for tft with correct order
Jun 27, 2023
742905f
tft
Jun 27, 2023
bb7ecc0
added quantile loss
Jun 27, 2023
c4085ef
added gin configurable for quantile loss
Jun 27, 2023
fa7e4b9
remove gin config for quantile loss
Jun 27, 2023
da9725b
added quantile loss
Jun 27, 2023
7c3d511
removed quantile loss
Jun 27, 2023
d410d53
Update ci.yml
youssefmecky96 Jun 27, 2023
38f9dee
Update ci.yml
youssefmecky96 Jun 27, 2023
bd0a9a5
fixing bugs to get tft to work
Jun 27, 2023
dda1a00
Merge branch 'XAI' of https://github.com/rvandewater/YAIB into XAI
Jun 27, 2023
f6d1272
fixing bugs to run tft for classification
Jun 27, 2023
0c44d2f
fixing bugs to run tft for classification
Jun 27, 2023
9963d7b
fixing bugs to run tft
Jun 27, 2023
4a462d2
fixed return shape of data
Jun 29, 2023
1114690
fixed return shape of data
Jun 29, 2023
2a89d21
bug fixes and checks
Jul 4, 2023
b757c66
fixed concat to be along the correct axis
Jul 4, 2023
273b9ae
regression is working now with quantile loss
Jul 4, 2023
5b97aac
changed gin file
Jul 4, 2023
f767499
fixing bugs to get tftpytorch to work
Jul 11, 2023
e72927a
new gin files
Jul 11, 2023
760a25d
created own set_weights for tft pytorch
Jul 11, 2023
a356df6
overrode get_features_names for tftpytorch dataloader
Jul 11, 2023
2e12592
fixed datatypes for old tft and working on new tft
Jul 18, 2023
c0dd05f
fixed datatypes for old tft and working on new tft
Jul 18, 2023
c7b76e6
Merge branch 'XAI' of https://github.com/rvandewater/YAIB into XAI
Aug 1, 2023
96f0400
Merge branch 'XAI' of https://github.com/rvandewater/YAIB into XAI
Aug 1, 2023
9e52574
latest changes
Aug 1, 2023
51f5bf5
Merge branch 'XAI' of https://github.com/rvandewater/YAIB into XAI
Aug 1, 2023
803d677
latest
Aug 8, 2023
c2fcc74
added range for layers and attention heads
Aug 14, 2023
5f18f72
formatted and added comments
Aug 14, 2023
9ef8252
latest
Sep 7, 2023
9b5a6e9
Merge branch 'development' into XAI
Sep 7, 2023
3c9b729
reproducible forced to be there for evaluate
Sep 12, 2023
03ad5c2
added v1 to load model
Sep 12, 2023
537bdc5
Merge branch 'development' into XAI
youssefmecky96 Sep 12, 2023
be61df8
merged
youssefmecky96 Sep 12, 2023
42815d7
match original
youssefmecky96 Sep 12, 2023
b071520
readded tft dataloader
youssefmecky96 Sep 12, 2023
778c619
readded ordered dict
youssefmecky96 Sep 12, 2023
3bcb969
merged with dev
youssefmecky96 Sep 12, 2023
9d4f4a7
changed if condition for loading models
Sep 13, 2023
a15fd40
reformatted
youssefmecky96 Sep 13, 2023
54563e4
added scheduler removed lr
Sep 18, 2023
7399202
removed lr as a required input
Sep 18, 2023
1d4e301
Merge branch 'XAI' of https://github.com/rvandewater/YAIB into XAI
Sep 18, 2023
bf330d4
removed extra test dataloader
Sep 18, 2023
bf47597
added name to tft loader
youssefmecky96 Sep 25, 2023
69c0287
reduced cv repetitions to 2
youssefmecky96 Sep 25, 2023
4196779
added explain flag
youssefmecky96 Sep 26, 2023
0e31937
added explain flag
youssefmecky96 Sep 26, 2023
008aa84
added explain flag
youssefmecky96 Sep 26, 2023
f9ac9ea
added explain flag
youssefmecky96 Sep 26, 2023
3bf60c7
cleaned conditions and added flag for pytorch forecasting
youssefmecky96 Sep 26, 2023
ec011c8
added pytorch tft built in methods
youssefmecky96 Sep 26, 2023
2f1b74a
added comma
youssefmecky96 Sep 26, 2023
86d719a
added attention explanation for tft
youssefmecky96 Sep 26, 2023
00d78dd
changed forward method for tft to allow captum explanations
youssefmecky96 Sep 27, 2023
e877c5f
cleaned files up
Sep 29, 2023
25fedaf
added predict_mode
youssefmecky96 Oct 2, 2023
e120465
added pytorch flag
youssefmecky96 Oct 2, 2023
b4d6b42
added LSTM from pytorch
youssefmecky96 Oct 2, 2023
ec8346d
added RNN implementation from pytorch forecasting and changed dataloa…
youssefmecky96 Oct 2, 2023
e3f6058
added RNN implementation from pytorch forecasting and changed dataloa…
youssefmecky96 Oct 2, 2023
6c05b22
changes to get RNN to work
youssefmecky96 Oct 3, 2023
5c3d6a3
changes to get RNN to work
youssefmecky96 Oct 3, 2023
a5d2a99
changes to get RNN to work
youssefmecky96 Oct 3, 2023
56b0deb
removed grad clipping
Oct 5, 2023
8880ade
small edits
Oct 9, 2023
869a611
changed 0/1 to categorical vars
Oct 9, 2023
ce0831c
changed input of rnn
Oct 9, 2023
2828a3a
latest changes
youssefmecky96 Oct 10, 2023
fec90cf
added quantus explanation and plot
youssefmecky96 Oct 11, 2023
391c7f5
added back gpu line
youssefmecky96 Oct 11, 2023
8746996
quantus looks like it is working
Oct 17, 2023
b7be9f6
added deeparmodel
Oct 24, 2023
8994345
deepar fixes
Oct 24, 2023
986e11f
minor changes
Oct 25, 2023
0ed2fbc
added XAI metric flag
youssefmecky96 Oct 25, 2023
d42bd73
added XAI metric part
youssefmecky96 Oct 25, 2023
46eea23
changes for quantus
youssefmecky96 Oct 25, 2023
16cc5c8
added faithfulness correlation manually
Oct 31, 2023
85216b7
fixed a few bugs to get the faithfulness correlation to work
Oct 31, 2023
119dbcf
added pytorch forecasting wrapper to make it easier
Oct 31, 2023
d03deed
added pytorch forecasting wrapper to make it easier
Oct 31, 2023
3c78359
added data randomization test and fixed few issues with wrapper
Oct 31, 2023
8972c12
removed check for pytorch forecasting from DLwrapper
Oct 31, 2023
bd8c2e5
similarity function from captum
Oct 31, 2023
d4a55ff
used similarity functions already defined
Oct 31, 2023
edd4147
latest changes
Nov 2, 2023
176120f
changed attribution to take into account all timesteps
Nov 2, 2023
edc4ad2
added attention method for randomization
Nov 2, 2023
2be5076
added baseline for faithfulness attribution
Nov 7, 2023
f5e8a3d
added option to load model trained with random labels
Nov 7, 2023
78146cc
added shapley explanations
Nov 7, 2023
48dc500
latest changes
Nov 8, 2023
446ceaa
latest
Nov 11, 2023
73d3909
removed label from input for tft
Nov 13, 2023
2544604
labels for los not shown to RNN and deepAR
Nov 14, 2023
f435717
added shap, saliency explanations and logged values
Nov 15, 2023
ce078a1
renamed dataloader, rescaled target, removed target scale, target c…
Nov 20, 2023
dd3bdac
changed gin to match DL
Nov 20, 2023
2defb37
changes to accommodate los attribution, code more reusable now
Nov 20, 2023
4c8ee4a
changed to allow for timesteps and variable per timestep attribution
Nov 21, 2023
d3ea3d3
latest changes
Nov 21, 2023
54a6f18
latest
Nov 24, 2023
ab115fd
cleaned and fixed aggregation along batch
Nov 28, 2023
64fb77f
added data randomization test
Nov 29, 2023
42ef299
added mask for feature ablation
Nov 29, 2023
7c87164
plotting, added HP range, fixed distance range
Dec 4, 2023
08c3c31
added stability metric
Dec 4, 2023
4893f3e
attention ROS RIS
Dec 5, 2023
38ed6e3
added condition for ROS and RIS
Dec 13, 2023
a39b5c4
added condition for ROS and RIS for attention
Dec 13, 2023
daffaa2
added flags to add explanation methods and metrics
Dec 24, 2023
a9ec05e
changes to allow easier run
Dec 27, 2023
1534c13
fixes
Dec 27, 2023
1586fbb
Revert "added condition for ROS and RIS for attention"
Dec 28, 2023
8b39c6e
reverted changes
Dec 28, 2023
21bf351
added choice of normalizer
Dec 28, 2023
31c8176
small fixes
Dec 28, 2023
07cbc3c
formatted
youssefmecky96 Mar 7, 2024
041b38b
formatting
youssefmecky96 Mar 7, 2024
2ab11a5
formatting
youssefmecky96 Mar 7, 2024
1f904dc
cleaning up and formatting
youssefmecky96 Mar 8, 2024
843c7e8
cleaning up and formatting
youssefmecky96 Mar 8, 2024
fe62ee2
formatting
youssefmecky96 Mar 8, 2024
6782d4f
restored tft gin file
Apr 29, 2024
3670bd6
changes to get tft to work with yaib dataloader
May 13, 2024
8e6db71
inferred datatypes from dict supplied through gin
May 15, 2024
97d41f7
adjusted forward method in tft to handle input as model expects
May 17, 2024
afab7a4
changes to get tft to work
May 23, 2024
f639053
changed size check and sliced along the correct dimension
May 23, 2024
039377a
added changes to get pytorch forecasting to work and gradient clipping
Jun 3, 2024
2b509d2
config file for tft
Jun 3, 2024
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -40,4 +40,4 @@ jobs:
# - name: Setup package
# run: pip install -e .
# - name: Test command line tool
# run: python -m icu_benchmarks.run --help
# run: python -m icu_benchmarks.run --help
1 change: 1 addition & 0 deletions README.md
@@ -199,6 +199,7 @@ icu-benchmarks \
> For a list of available flags, run `icu-benchmarks train -h`.

> Run with `PYTORCH_ENABLE_MPS_FALLBACK=1` on Macs with Metal Performance Shaders.
> You can set the conda environment variable by running `conda env config vars set PYTORCH_ENABLE_MPS_FALLBACK=1`

[//]: # (> Please note that, for Windows based systems, paths need to be formatted differently, e.g: ` r"\..\data\mortality_seq\hirid"`.)
> For Windows based systems, the next line character (\\) needs to be replaced by (^) (Command Prompt) or (`) (Powershell)
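
The fallback can also be enabled from inside a script rather than the shell. A minimal sketch, assuming only that PyTorch consults `PYTORCH_ENABLE_MPS_FALLBACK` when an operator has no MPS kernel; the tensor shapes are illustrative:

```python
# Enable CPU fallback for ops without an MPS kernel; set the variable
# before importing torch so it is in place before any MPS op runs.
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch

# Use the Metal backend when available, otherwise stay on CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
x = torch.randn(8, 4, device=device)
print(x.device)
```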
143 changes: 143 additions & 0 deletions configs/prediction_models/DeepARpytorch.gin
@@ -0,0 +1,143 @@
# Settings for DeepAR model.

# Common settings for DL models
include "configs/prediction_models/common/DLCommon.gin"

# Optimizer params
train_common.model = @DeepARpytorch

optimizer/hyperparameter.class_to_tune = @Adam
optimizer/hyperparameter.weight_decay = 1e-6
optimizer/hyperparameter.lr = 0.00046
#(1e-5, 1e-3)
# Model params

model/hyperparameter.class_to_tune = @DeepARpytorch
model/hyperparameter.hidden = 116
#(4, 120, "log-uniform", 2)
model/hyperparameter.rnn_layers=1

model/hyperparameter.num_classes = %NUM_CLASSES
model/hyperparameter.cell_type='LSTM'
model/hyperparameter.dropout = (0.286968842146375, 0.29)
#model/hyperparameter.lr_scheduler = "exponential"

train_common/hyperparameter.class_to_tune = @train_common
train_common/hyperparameter.batch_size=256
train_common/hyperparameter.gradient_clip_val=0.01
#(0,0.01, 1.0, 100.0)
# Dataset params
PredictionDatasetpytorch.max_encoder_length = 24
PredictionDatasetpytorch.max_prediction_length = 1
PredictionDatasetpytorch.time_varying_known_reals=[]
PredictionDatasetpytorch.add_relative_time_idx=False
PredictionDatasetpytorch.target=[
"alb",
"alp",
"alt",
"ast",
"be",
"bicar",
"bili",
"bili_dir",
"bnd",
"bun",
"ca",
"cai",
"ck",
"ckmb",
"cl",
"crea",
"crp",
"dbp",
"fgn",
"fio2",
"glu",
"hgb",
"hr",
"inr_pt",
"k",
"lact",
"lymph",
"map",
"mch",
"mchc",
"mcv",
"methb",
"mg",
"na",
"neut",
"o2sat",
"pco2",
"ph",
"phos",
"plt",
"po2",
"ptt",
"resp",
"sbp",
"temp",
"tnt",
"urine",
"wbc",
"label",

]

PredictionDatasetpytorch.time_varying_unknown_reals=[
"alb",
"alp",
"alt",
"ast",
"be",
"bicar",
"bili",
"bili_dir",
"bnd",
"bun",
"ca",
"cai",
"ck",
"ckmb",
"cl",
"crea",
"crp",
"dbp",
"fgn",
"fio2",
"glu",
"hgb",
"hr",
"inr_pt",
"k",
"lact",
"lymph",
"map",
"mch",
"mchc",
"mcv",
"methb",
"mg",
"na",
"neut",
"o2sat",
"pco2",
"ph",
"phos",
"plt",
"po2",
"ptt",
"resp",
"sbp",
"temp",
"tnt",
"urine",
"wbc",
"label",

]
PredictionDatasetpytorch.time_varying_unknown_categoricals=[]
PredictionDatasetpytorch.lagged_variables=[]
PredictionDatasetpytorch.targetnormalizer='multi'
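
For readers unfamiliar with gin, bindings like the ones above attach constructor arguments to configurable names when the file is parsed. A minimal, self-contained sketch of that mechanism, using a hypothetical stand-in class; the real `DeepARpytorch` constructor in YAIB may differ:

```python
import gin

# Hypothetical stand-in for the YAIB model class; the constructor
# arguments are assumptions based on the bindings in the file above.
@gin.configurable
class DeepARpytorch:
    def __init__(self, hidden=64, rnn_layers=1, cell_type="GRU"):
        self.hidden = hidden
        self.rnn_layers = rnn_layers
        self.cell_type = cell_type

# Bindings in the same style as DeepARpytorch.gin, parsed from a string.
gin.parse_config("""
DeepARpytorch.hidden = 116
DeepARpytorch.rnn_layers = 1
DeepARpytorch.cell_type = 'LSTM'
""")

model = DeepARpytorch()  # arguments are injected from the bindings
assert (model.hidden, model.cell_type) == (116, "LSTM")
```

The commented tuples next to fixed values, such as `#(4, 120, "log-uniform", 2)`, appear to record the hyperparameter search ranges from which the fixed values were tuned.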


142 changes: 142 additions & 0 deletions configs/prediction_models/RNNpytorch.gin
@@ -0,0 +1,142 @@
# Settings for RNN model.

# Common settings for DL models
include "configs/prediction_models/common/DLCommon.gin"

# Optimizer params
train_common.model = @RNNpytorch

optimizer/hyperparameter.class_to_tune = @Adam
optimizer/hyperparameter.weight_decay = 1e-6
optimizer/hyperparameter.lr = 0.00041

# Model params
model/hyperparameter.class_to_tune = @RNNpytorch
model/hyperparameter.hidden = 214
#(2, 64, "log-uniform", 2)
model/hyperparameter.rnn_layers=1
#(1,3)
model/hyperparameter.num_classes = %NUM_CLASSES
model/hyperparameter.cell_type='LSTM'
model/hyperparameter.dropout = (0.244, 0.2441)
#(0.0, 0.4)
#model/hyperparameter.lr_scheduler = "exponential"

train_common/hyperparameter.class_to_tune = @train_common
train_common/hyperparameter.batch_size=256
#(32,64,128,256,512)
train_common/hyperparameter.gradient_clip_val=100.0
#(0,0.01, 1.0, 100.0)
# Dataset params
PredictionDatasetpytorch.max_encoder_length = 24
PredictionDatasetpytorch.max_prediction_length = 1
PredictionDatasetpytorch.time_varying_known_reals=[]
PredictionDatasetpytorch.add_relative_time_idx=False
PredictionDatasetpytorch.target=[
"alb",
"alp",
"alt",
"ast",
"be",
"bicar",
"bili",
"bili_dir",
"bnd",
"bun",
"ca",
"cai",
"ck",
"ckmb",
"cl",
"crea",
"crp",
"dbp",
"fgn",
"fio2",
"glu",
"hgb",
"hr",
"inr_pt",
"k",
"lact",
"lymph",
"map",
"mch",
"mchc",
"mcv",
"methb",
"mg",
"na",
"neut",
"o2sat",
"pco2",
"ph",
"phos",
"plt",
"po2",
"ptt",
"resp",
"sbp",
"temp",
"tnt",
"urine",
"wbc",
"label",

]

PredictionDatasetpytorch.time_varying_unknown_reals=[
"alb",
"alp",
"alt",
"ast",
"be",
"bicar",
"bili",
"bili_dir",
"bnd",
"bun",
"ca",
"cai",
"ck",
"ckmb",
"cl",
"crea",
"crp",
"dbp",
"fgn",
"fio2",
"glu",
"hgb",
"hr",
"inr_pt",
"k",
"lact",
"lymph",
"map",
"mch",
"mchc",
"mcv",
"methb",
"mg",
"na",
"neut",
"o2sat",
"pco2",
"ph",
"phos",
"plt",
"po2",
"ptt",
"resp",
"sbp",
"temp",
"tnt",
"urine",
"wbc",
"label",

]
PredictionDatasetpytorch.time_varying_unknown_categoricals=[]
PredictionDatasetpytorch.lagged_variables=[]
PredictionDatasetpytorch.targetnormalizer='multi'
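
The dataset bindings in both files mirror the constructor of pytorch-forecasting's `TimeSeriesDataSet`. A minimal sketch of that mapping, assuming `PredictionDatasetpytorch` wraps that class; the DataFrame columns (`stay_id`, `time_idx`, `hr`) are illustrative, not YAIB's actual schema:

```python
import pandas as pd
from pytorch_forecasting import TimeSeriesDataSet

# Toy hourly series for a single stay; real YAIB data has many stays and
# the full variable list bound in the gin files above.
df = pd.DataFrame({
    "stay_id": [0] * 30,
    "time_idx": list(range(30)),
    "hr": [80.0 + i for i in range(30)],
    "label": [0.0] * 30,
})

dataset = TimeSeriesDataSet(
    df,
    time_idx="time_idx",
    group_ids=["stay_id"],
    target=["hr", "label"],          # multi-target, cf. targetnormalizer='multi'
    max_encoder_length=24,           # 24 steps of history, as in the configs
    max_prediction_length=1,         # predict a single step ahead
    time_varying_known_reals=[],
    time_varying_unknown_reals=["hr", "label"],
    add_relative_time_idx=False,
)
loader = dataset.to_dataloader(train=True, batch_size=256)
```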