diff --git a/dig/sslgraph/method/contrastive/views_fn/structure.py b/dig/sslgraph/method/contrastive/views_fn/structure.py index e9d0ff97..f24484f7 100644 --- a/dig/sslgraph/method/contrastive/views_fn/structure.py +++ b/dig/sslgraph/method/contrastive/views_fn/structure.py @@ -32,7 +32,9 @@ def do_trans(self, data): if self.drop: idx_remain = dropout_adj(data.edge_index, p=self.ratio)[0] - + else: + idx_remain = data.edge_index + if self.add: idx_add = torch.randint(node_num, (2, perturb_num), device=device)