Description
I recently got the error below while training GraphCL on an in-house dataset. I believe the cause is that when `EdgePerturbation` is added to `views_fn` with its default constructor, `drop=False`. In `do_trans`, `idx_remain` is only assigned when `self.drop` is true. With the default it is never bound, yet it is referenced a few lines later in `torch.cat((idx_remain, idx_add), dim=1)`.
```
    test_acc_mean, test_acc_std = evaluator.evaluate(learning_model=graphcl, encoder=encoder, pred_head=pred_head)
  File "/home/.local/lib/python3.9/site-packages/dig/sslgraph/evaluation/eval_graph.py", line 395, in evaluate
    encoder = next(learning_model.train(encoder, pretrain_loader, p_optimizer, self.p_epoch))
  File "/home/.local/lib/python3.9/site-packages/dig/sslgraph/method/contrastive/model/graphcl.py", line 70, in train
    for enc, proj in super(GraphCL, self).train(encoders, data_loader,
  File "/home/.local/lib/python3.9/site-packages/dig/sslgraph/method/contrastive/model/contrastive.py", line 139, in train
    for enc in train_fn(encoder, data_loader, optimizer, epochs):
  File "/home/.local/lib/python3.9/site-packages/dig/sslgraph/method/contrastive/model/contrastive.py", line 173, in train_encoder_graph
    views = [v_fn(data) for v_fn in self.views_fn]
  File "/home/.local/lib/python3.9/site-packages/dig/sslgraph/method/contrastive/model/contrastive.py", line 173, in <listcomp>
    views = [v_fn(data) for v_fn in self.views_fn]
  File "/home/.local/lib/python3.9/site-packages/dig/sslgraph/method/contrastive/views_fn/structure.py", line 24, in __call__
    return self.views_fn(data)
  File "/home/.local/lib/python3.9/site-packages/dig/sslgraph/method/contrastive/views_fn/structure.py", line 53, in views_fn
    dlist = [self.do_trans(d) for d in data.to_data_list()]
  File "/home/.local/lib/python3.9/site-packages/dig/sslgraph/method/contrastive/views_fn/structure.py", line 53, in <listcomp>
    dlist = [self.do_trans(d) for d in data.to_data_list()]
  File "/home/.local/lib/python3.9/site-packages/dig/sslgraph/method/contrastive/views_fn/structure.py", line 39, in do_trans
    new_edge_index = torch.cat((idx_remain, idx_add), dim=1)
UnboundLocalError: local variable 'idx_remain' referenced before assignment
```
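To illustrate the control flow described above, here is a minimal pure-Python sketch. It is not DIG's actual `do_trans`. The edge index is simplified to a list of `(src, dst)` tuples instead of a `torch` tensor, and `buggy_do_trans` / `fixed_do_trans` are hypothetical names. The buggy version binds `idx_remain` only inside the `drop` branch. The suggested fix initializes it to the full edge set before any branching, so `add=True, drop=False` simply keeps every edge and appends random ones.

```python
import random


def buggy_do_trans(edges, num_nodes, add=True, drop=False, ratio=0.1):
    """Hypothetical simplification of the pattern described in this issue."""
    perturb_num = int(len(edges) * ratio)
    if drop:
        # idx_remain is only bound when drop=True.
        idx_remain = random.sample(edges, len(edges) - perturb_num)
    idx_add = []
    if add:
        idx_add = [(random.randrange(num_nodes), random.randrange(num_nodes))
                   for _ in range(perturb_num)]
    # With drop=False this raises UnboundLocalError, as in the traceback.
    return idx_remain + idx_add


def fixed_do_trans(edges, num_nodes, add=True, drop=False, ratio=0.1):
    """Same logic, but idx_remain always has a default value."""
    perturb_num = int(len(edges) * ratio)
    # Default: keep every existing edge, so idx_remain is always bound.
    idx_remain = list(edges)
    idx_add = []
    if drop:
        # Keep a random subset, dropping perturb_num edges.
        idx_remain = random.sample(edges, len(edges) - perturb_num)
    if add:
        # Add perturb_num random edges between existing nodes.
        idx_add = [(random.randrange(num_nodes), random.randrange(num_nodes))
                   for _ in range(perturb_num)]
    return idx_remain + idx_add
```

In DIG itself the equivalent change would presumably be to initialize `idx_remain` from the input `edge_index` (and `idx_add` to an empty tensor of compatible shape) before the `if self.drop:` check. That is an assumption about the fix, not the library's confirmed code.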