Description
Hello, I am working with a small data set using the SHEPHERD checkpoint model files to do causal gene discovery. I ran the shortest paths script with no issues, and am at the predict.py step:
python predict.py \
    --run_type causal_gene_discovery \
    --patient_data my_data \
    --edgelist KG_edgelist_mask.txt \
    --node_map KG_node_map.txt \
    --saved_node_embeddings_path checkpoints/pretrain.ckpt \
    --best_ckpt checkpoints/causal_gene_discovery.ckpt
Here are my logs from the run which includes the error:
Global seed set to 33
Predict hparams: {'seed': 33, 'n_gpus': 0, 'num_workers': 4, 'profiler': 'simple', 'pin_memory': False, 'time': False, 'log_gpu_memory': False, 'debug': False, 'augment_genes': True, 'n_sim_genes': 3, 'aug_gene_w': 0.5, 'wandb_save_dir': PosixPath('/mnt/ds-nas/projects/shepherd/data_download/wandb'), 'saved_checkpoint_path': PosixPath('/mnt/ds-nas/projects/shepherd/data_download/checkpoints/pretrain.ckpt'), 'test_n_cand_diseases': -1, 'candidate_disease_type': 'all_kg_nodes', 'only_hard_distractors': False, 'patient_similarity_type': 'gene', 'n_similar_patients': 2, 'model_type': 'aligner', 'loss': 'gene_multisimilarity', 'use_diseases': False, 'add_cand_diseases': False, 'add_similar_patients': False, 'wandb_project_name': 'causal-gene-discovery', 'train_data': PosixPath('/mnt/ds-nas/projects/shepherd/XXXXX/disease_split_train_sim_patients_8.9.21_kg.txt'), 'validation_data': PosixPath('/mnt/ds-nas/projects/shepherd/XXXXX/disease_split_val_sim_patients_8.9.21_kg.txt'), 'test_data': PosixPath('/mnt/ds-nas/projects/shepherd/XXXXX/shep_export0_cases_YYYYY.jsonl'), 'spl': '/mnt/ds-nas/projects/shepherd/genedx_data/patient_set_0_YYYYY_aggmean_spl_matrix.npy', 'spl_index': '/mnt/ds-nas/projects/shepherd/genedx_data/patient_set_0_YYYYY_spl_index_dict.pkl'}
Loading SPL...
Loaded SPL information
Dataset filepath: /mnt/ds-nas/projects/shepherd/genedx_data/shep_export0_YYYYY.jsonl
Number of patients: 100
Finished initalizing dataset
There are 100 patients in the test dataset
batch size: 100
Traceback (most recent call last):
  File "predict.py", line 174, in <module>
    predict(args)
  File "/mnt/fb1/home/vustach/anaconda3/envs/shepherd/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "predict.py", line 116, in predict
    dataloader = PatientNeighborSampler('predict', all_data.edge_index, all_data.edge_index[:,all_data.test_mask],
  File "../shepherd/samplers.py", line 296, in __init__
    if hparams["alpha"] < 1: self.gp_spl = gp_spl
KeyError: 'alpha'
I don't see alpha in the hparams:
Lines 208 to 239 in e61281f
def get_predict_hparams(args):
    hparams = {
        'seed': 33,
        'n_gpus': 0,  # NOTE: currently predict scripts only work with CPU
        'num_workers': 4,
        'profiler': 'simple',
        'pin_memory': False,
        'time': False,
        'log_gpu_memory': False,
        'debug': False,
        'augment_genes': True,
        'n_sim_genes': 3,
        'aug_gene_w': 0.5,
        'wandb_save_dir': project_config.PROJECT_DIR / 'wandb',
        'saved_checkpoint_path': project_config.PROJECT_DIR / f'{args.saved_node_embeddings_path}',
        'test_n_cand_diseases': -1,
        'candidate_disease_type': 'all_kg_nodes',
        'only_hard_distractors': False,  # Flag when true only uses the curated hard distractors at train time
        'patient_similarity_type': 'gene',  # How we determine labels for similar patients in "Patients Like Me"
        'n_similar_patients': 2,  # (Patients Like Me only) Number of patients with the same gene/disease that we add to the batch
    }
    # Get hyperparameters based on run type arguments
    hparams = get_run_type_args(args, hparams)
    hparams.update({'add_similar_patients': False})
    hparams = get_patient_data_args(args, hparams)
    print('Predict hparams: ', hparams)
    return hparams
Please let me know if you can reproduce this error. Is there a way to set alpha for prediction, and if so, what value do you suggest?
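For reference, here is a minimal sketch of the stopgap I am considering. It is not SHEPHERD code: the helper `with_default_alpha` is a hypothetical name of mine, the dict is a trimmed copy of the logged hparams, and the default of 1.0 is only a placeholder. All I can say from the traceback is that a value of 1 or more skips the `hparams["alpha"] < 1` branch in samplers.py; I don't know whether alpha is read anywhere else or what value the released checkpoint expects.

```python
# Hypothetical stopgap, not part of SHEPHERD: make sure 'alpha' exists
# before the sampler evaluates hparams["alpha"] < 1.

def with_default_alpha(hparams, default_alpha=1.0):
    """Return a copy of hparams that is guaranteed to contain 'alpha'.

    default_alpha=1.0 is an assumed placeholder, not a recommended value.
    """
    patched = dict(hparams)  # leave the caller's dict untouched
    patched.setdefault("alpha", default_alpha)
    return patched


# Trimmed copy of the logged predict hparams; note there is no 'alpha' key.
predict_hparams = {"seed": 33, "n_gpus": 0, "model_type": "aligner"}

# The same lookup that fails in the traceback.
try:
    predict_hparams["alpha"] < 1
except KeyError as err:
    print("reproduced:", repr(err))

patched = with_default_alpha(predict_hparams)

# With the placeholder value, the `< 1` branch from the traceback is skipped.
uses_gp_spl = patched["alpha"] < 1
print("alpha:", patched["alpha"], "uses gp_spl branch:", uses_gp_spl)
```

If this is roughly the right direction, the equivalent change in the repo would presumably be adding an 'alpha' entry to the dict in get_predict_hparams, but I would rather confirm the intended value first.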