
CONCH encoder embedding #290

Open
@niuhulu-rui

Description

I want to use CONCH encoder embeddings, and I ran the following command:

CUDA_VISIBLE_DEVICES=0 python main.py --drop_out 0.25 --early_stopping --lr 2e-4 --k 10 --exp_code task_1_tumor_vs_normal_CLAM_conch_sb --weighted_sample --bag_loss ce --inst_loss svm --task task_1_tumor_vs_normal --model_type clam_sb --log_data --data_root_dir features_conch/ --embed_dim 512
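The command passes `--embed_dim 512`, and in the model printout further down that value becomes the `in_features` of the first linear layer. Before training, you can check that the stored bags actually have that width. The sketch below assumes each slide is saved as a single `(num_patches, feature_dim)` tensor in a `.pt` file. Both helper names and the `features_conch/pt_files` path are hypothetical placeholders for wherever the dataset really reads its features.

```python
from pathlib import Path


def load_feature_shapes(pt_dir):
    """Return {slide_id: tensor shape} for every .pt bag found in pt_dir.

    Assumes each file holds a single (num_patches, feature_dim) tensor.
    torch is imported only when there are files to read.
    """
    pt_dir = Path(pt_dir)
    paths = sorted(pt_dir.glob("*.pt")) if pt_dir.is_dir() else []
    if not paths:
        return {}
    import torch

    shapes = {}
    for path in paths:
        features = torch.load(path, map_location="cpu")
        shapes[path.stem] = tuple(features.shape)
    return shapes


def find_dim_mismatches(shapes, embed_dim):
    """Keep only the bags whose last (feature) dimension differs from embed_dim."""
    return {slide: shape for slide, shape in shapes.items() if shape[-1] != embed_dim}


if __name__ == "__main__":
    # Hypothetical location: use the pt_files folder your dataset really loads.
    shapes = load_feature_shapes("features_conch/pt_files")
    bad = find_dim_mismatches(shapes, embed_dim=512)
    print(f"{len(bad)} of {len(shapes)} bags do not match --embed_dim 512")
    for slide, shape in list(bad.items())[:5]:
        print(slide, shape)
```

If the bags report a last dimension other than 512, the features on disk and the `--embed_dim` value disagree. In that case, either the features need to be re-extracted with the intended encoder, or `--embed_dim` has to match the width that is actually stored.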

Why is the following error occurring?

Load Dataset
label column: label
label dictionary: {'normal': 0, 'tumor': 1}
number of classes: 2
slide-level counts:
label
1    44
0    34
Name: count, dtype: int64
Patient-LVL; Number of samples registered in class 0: 34
Slide-LVL; Number of samples registered in class 0: 34
Patient-LVL; Number of samples registered in class 1: 44
Slide-LVL; Number of samples registered in class 1: 44
split_dir: splits/task_1_tumor_vs_normal_100
################# Settings ###################
num_splits: 10
k_start: -1
k_end: -1
task: task_1_tumor_vs_normal
max_epochs: 200
results_dir: ./results
lr: 0.0002
experiment: task_1_tumor_vs_normal_CLAM_conch_sb
reg: 1e-05
label_frac: 1.0
bag_loss: ce
seed: 1
model_type: clam_sb
model_size: small
use_drop_out: 0.25
weighted_sample: True
opt: adam
bag_weight: 0.7
inst_loss: svm
B: 8
split_dir: splits/task_1_tumor_vs_normal_100

Training Fold 0!

Init train/val/test splits...
Done!
Training on 64 samples
Validating on 7 samples
Testing on 7 samples

Init loss function... Done!

Init Model... Setting tau to 1.0
Done!
CLAM_SB(
  (attention_net): Sequential(
    (0): Linear(in_features=512, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.25, inplace=False)
    (3): Attn_Net_Gated(
      (attention_a): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): Tanh()
        (2): Dropout(p=0.25, inplace=False)
      )
      (attention_b): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): Sigmoid()
        (2): Dropout(p=0.25, inplace=False)
      )
      (attention_c): Linear(in_features=256, out_features=1, bias=True)
    )
  )
  (classifiers): Linear(in_features=512, out_features=2, bias=True)
  (instance_classifiers): ModuleList(
    (0-1): 2 x Linear(in_features=512, out_features=2, bias=True)
  )
  (instance_loss_fn): SmoothTop1SVM()
)
Total number of parameters: 528647
Total number of trainable parameters: 528647

Init optimizer ... Done!

Init Loaders... Done!

Setup EarlyStopping... Done!

Traceback (most recent call last):
  File "/data2/project/CLAM-master/main.py", line 213, in <module>
    results = main(args)
  File "/data2/project/CLAM-master/main.py", line 52, in main
    results, test_auc, val_auc, test_acc, val_acc = train(datasets, i, args)
  File "/data2/project/CLAM-master/utils/core_utils.py", line 185, in train
    train_loop_clam(epoch, model, train_loader, optimizer, args.n_classes, args.bag_weight, writer, loss_fn)
  File "/data2/project/CLAM-master/utils/core_utils.py", line 237, in train_loop_clam
    logits, Y_prob, Y_hat, _, instance_dict = model(data, label=label, instance_eval=True)
  File "/data2/anaconda3/envs/clam_lateat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data2/anaconda3/envs/clam_lateat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data2/project/CLAM-master/models/model_clam.py", line 139, in forward
    A, h = self.attention_net(h) # NxK
  File "/data2/anaconda3/envs/clam_lateat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data2/anaconda3/envs/clam_lateat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data2/anaconda3/envs/clam_lateat/lib/python3.10/site-packages/torch/nn/modules/container.py", line 219, in forward
    input = module(input)
  File "/data2/anaconda3/envs/clam_lateat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data2/anaconda3/envs/clam_lateat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data2/anaconda3/envs/clam_lateat/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 117, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1628x1024 and 512x512)
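The traceback ends in `F.linear`, which multiplies the input (mat1, one row per patch) by the transposed weight (mat2, shape `in_features x out_features`). Here mat1 is 1628x1024, which is a bag of 1628 patches with 1024-dimensional features. The first layer of `attention_net`, however, is `Linear(512, 512)`, which is consistent with `--embed_dim 512`. The reported total of 528,647 parameters also matches a model built for 512-dimensional inputs.

Below is a standard-library sketch that reproduces both checks. The helper names are hypothetical, and the layer list is read off the `CLAM_SB` printout above rather than taken from CLAM's source.

```python
def linear_params(in_features, out_features, bias=True):
    """Number of weights (plus biases) in a torch.nn.Linear layer."""
    return in_features * out_features + (out_features if bias else 0)


def clam_sb_small_params(embed_dim, n_classes=2):
    """Parameter count of the CLAM_SB module as printed in the log.

    Layer sizes are copied from the printout; only the first layer's
    in_features depends on embed_dim.
    """
    layers = [
        (embed_dim, 512),   # attention_net[0]
        (512, 256),         # attention_a
        (512, 256),         # attention_b
        (256, 1),           # attention_c
        (512, n_classes),   # classifiers
    ] + [(512, 2)] * n_classes  # instance_classifiers, one per class
    return sum(linear_params(i, o) for i, o in layers)


def check_linear_input(input_shape, in_features, out_features):
    """Mimic the shape rule F.linear enforces for a 2-D input.

    F.linear computes input @ weight.T, so mat2 has shape
    (in_features, out_features) and the inner dimensions must agree.
    """
    rows, cols = input_shape
    if cols != in_features:
        raise RuntimeError(
            "mat1 and mat2 shapes cannot be multiplied "
            f"({rows}x{cols} and {in_features}x{out_features})"
        )
    return (rows, out_features)


if __name__ == "__main__":
    print(clam_sb_small_params(embed_dim=512))         # 528647, as in the log
    print(check_linear_input((1628, 512), 512, 512))   # (1628, 512)
    try:
        check_linear_input((1628, 1024), 512, 512)
    except RuntimeError as err:
        print(err)  # same message as the traceback
```

With `embed_dim=1024`, the same arithmetic gives 790,791 parameters. In other words, the model was built for 512-wide features, but the bags it received are 1024 wide. This suggests that the loaded `.pt` files were not produced at the width `--embed_dim` expects. For example, 1024 is the width of CLAM's default truncated ResNet-50 features, so the dataset may be reading a folder of those features instead of CONCH embeddings.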

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
