Description
I want to use the CONCH encoder, so I ran the following command:
CUDA_VISIBLE_DEVICES=0 python main.py --drop_out 0.25 --early_stopping --lr 2e-4 --k 10 --exp_code task_1_tumor_vs_normal_CLAM_conch_sb --weighted_sample --bag_loss ce --inst_loss svm --task task_1_tumor_vs_normal --model_type clam_sb --log_data --data_root_dir features_conch/ --embed_dim 512
Why is the following error occurring?
Load Dataset
label column: label
label dictionary: {'normal': 0, 'tumor': 1}
number of classes: 2
slide-level counts:
label
1 44
0 34
Name: count, dtype: int64
Patient-LVL; Number of samples registered in class 0: 34
Slide-LVL; Number of samples registered in class 0: 34
Patient-LVL; Number of samples registered in class 1: 44
Slide-LVL; Number of samples registered in class 1: 44
split_dir: splits/task_1_tumor_vs_normal_100
################# Settings ###################
num_splits: 10
k_start: -1
k_end: -1
task: task_1_tumor_vs_normal
max_epochs: 200
results_dir: ./results
lr: 0.0002
experiment: task_1_tumor_vs_normal_CLAM_conch_sb
reg: 1e-05
label_frac: 1.0
bag_loss: ce
seed: 1
model_type: clam_sb
model_size: small
use_drop_out: 0.25
weighted_sample: True
opt: adam
bag_weight: 0.7
inst_loss: svm
B: 8
split_dir: splits/task_1_tumor_vs_normal_100
Training Fold 0!
Init train/val/test splits...
Done!
Training on 64 samples
Validating on 7 samples
Testing on 7 samples
Init loss function... Done!
Init Model... Setting tau to 1.0
Done!
CLAM_SB(
  (attention_net): Sequential(
    (0): Linear(in_features=512, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.25, inplace=False)
    (3): Attn_Net_Gated(
      (attention_a): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): Tanh()
        (2): Dropout(p=0.25, inplace=False)
      )
      (attention_b): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): Sigmoid()
        (2): Dropout(p=0.25, inplace=False)
      )
      (attention_c): Linear(in_features=256, out_features=1, bias=True)
    )
  )
  (classifiers): Linear(in_features=512, out_features=2, bias=True)
  (instance_classifiers): ModuleList(
    (0-1): 2 x Linear(in_features=512, out_features=2, bias=True)
  )
  (instance_loss_fn): SmoothTop1SVM()
)
Total number of parameters: 528647
Total number of trainable parameters: 528647
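The printed total can be reproduced from the layer shapes in the printout, which shows that every layer of this model was sized for 512-dimensional input features. A minimal plain-Python sketch; `linear_params` and `clam_sb_param_count` are hypothetical helpers, and treating 512 and 256 as fixed hidden sizes (with only the first layer's `in_features` following `--embed_dim`) is an assumption read off this printout, not taken from CLAM's source:

```python
# Recompute the CLAM_SB parameter count from the layer shapes printed above.
# Assumption (hypothetical helper, not CLAM code): only the first Linear layer's
# in_features depends on --embed_dim; hidden sizes 512 and 256 are held fixed.


def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Parameters of a torch.nn.Linear: weight (out x in) plus optional bias (out)."""
    return in_features * out_features + (out_features if bias else 0)


def clam_sb_param_count(embed_dim: int, n_classes: int = 2,
                        hidden: int = 512, attn_dim: int = 256) -> int:
    layers = [
        (embed_dim, hidden),   # attention_net[0]
        (hidden, attn_dim),    # attention_a[0]
        (hidden, attn_dim),    # attention_b[0]
        (attn_dim, 1),         # attention_c
        (hidden, n_classes),   # classifiers
    ]
    # instance_classifiers: one Linear(hidden, 2) per class
    layers += [(hidden, 2)] * n_classes
    # SmoothTop1SVM is not counted; the sum matches the log without it
    return sum(linear_params(i, o) for i, o in layers)


print(clam_sb_param_count(512))   # 528647, matching the log
print(clam_sb_param_count(1024))  # 790791 under the same assumption
```

Only the first layer, `Linear(in_features=512, out_features=512)`, touches the raw patch features, so it is the layer whose input size has to agree with the feature files.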
Init optimizer ... Done!
Init Loaders... Done!
Setup EarlyStopping... Done!
Traceback (most recent call last):
  File "/data2/project/CLAM-master/main.py", line 213, in <module>
    results = main(args)
  File "/data2/project/CLAM-master/main.py", line 52, in main
    results, test_auc, val_auc, test_acc, val_acc = train(datasets, i, args)
  File "/data2/project/CLAM-master/utils/core_utils.py", line 185, in train
    train_loop_clam(epoch, model, train_loader, optimizer, args.n_classes, args.bag_weight, writer, loss_fn)
  File "/data2/project/CLAM-master/utils/core_utils.py", line 237, in train_loop_clam
    logits, Y_prob, Y_hat, _, instance_dict = model(data, label=label, instance_eval=True)
  File "/data2/anaconda3/envs/clam_lateat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data2/anaconda3/envs/clam_lateat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data2/project/CLAM-master/models/model_clam.py", line 139, in forward
    A, h = self.attention_net(h) # NxK
  File "/data2/anaconda3/envs/clam_lateat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data2/anaconda3/envs/clam_lateat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data2/anaconda3/envs/clam_lateat/lib/python3.10/site-packages/torch/nn/modules/container.py", line 219, in forward
    input = module(input)
  File "/data2/anaconda3/envs/clam_lateat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data2/anaconda3/envs/clam_lateat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data2/anaconda3/envs/clam_lateat/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 117, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1628x1024 and 512x512)
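Reading the last line of the traceback: `F.linear(input, weight, bias)` computes `input @ weight.T + bias`, where `weight` has shape `(out_features, in_features)`. Here the bag of 1628 patch features is 1628x1024, while the first layer of `attention_net` was built as `Linear(in_features=512, ...)` from `--embed_dim 512`, so the inner dimensions (1024 vs. 512) do not agree. This suggests the files under `features_conch/` hold 1024-dimensional vectors rather than the 512-dimensional vectors the model was configured for. A plain-Python sketch of the shape rule; `linear_output_shape` is a hypothetical helper, not part of PyTorch or CLAM, and only mimics the wording of PyTorch's message:

```python
def linear_output_shape(input_shape, weight_shape):
    """Output shape of F.linear(input, weight) for a 2-D input.

    F.linear computes input @ weight.T, and weight has shape
    (out_features, in_features), so the input's last dimension
    must equal in_features.
    """
    n, d_in = input_shape
    out_features, in_features = weight_shape
    if d_in != in_features:
        # mat2 is weight.T, i.e. (in_features x out_features)
        raise ValueError(
            f"mat1 and mat2 shapes cannot be multiplied "
            f"({n}x{d_in} and {in_features}x{out_features})"
        )
    return (n, out_features)


# Shapes from the traceback: 1628 patches x 1024 features, weight 512x512
try:
    linear_output_shape((1628, 1024), (512, 512))
except ValueError as err:
    print(err)  # mat1 and mat2 shapes cannot be multiplied (1628x1024 and 512x512)

# With 512-dimensional features the first layer would accept the bag
print(linear_output_shape((1628, 512), (512, 512)))  # (1628, 512)
```

Two ways to make the shapes agree follow from this: re-extract the features so they really are 512-dimensional CONCH embeddings, or pass `--embed_dim 1024` to match the existing files. The second option only makes sense if those 1024-dimensional features were produced by the encoder you actually intend to use, since otherwise the model would train on non-CONCH features.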