Hi all,
I installed DIG and tried to explain the graph model that I have built using PyG as shown below:
`import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.fc = torch.nn.Linear(hidden_channels, 1)  # One output logit for binary classification

    def forward(self, x, edge_index, batch=None):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        if batch is not None:
            # Global mean pooling to aggregate node features per graph
            x = global_mean_pool(x, batch)
        else:
            # If batch is not provided, assume a single graph and aggregate all nodes
            x = x.mean(dim=0, keepdim=True)
        x = self.fc(x)  # Final classification layer
        return x.squeeze()  # Return logits for binary classification

model = GCN(in_channels=10, hidden_channels=16)  # No out_channels needed; the output is a single logit
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.BCEWithLogitsLoss()  # BCE loss on logits for binary classification`
When I call the SubgraphX explainer with the following commands:
`from dig.xgraph.method import SubgraphX

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
explainer = SubgraphX(model, num_classes=1, device=device, explain_graph=True,
                      reward_method='gnn_score')
explainer(val_data[0].x, val_data[0].edge_index)`
I get the following error during the computation of the scores: `TypeError: GCN.forward() got an unexpected keyword argument 'data'`.
Do you have any idea how to resolve this error?
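One possible workaround, sketched under assumptions rather than confirmed against the DIG source: the traceback suggests that the score computation calls the model as `model(data=batch)`, while the explainer entry point passes `x` and `edge_index` positionally. A thin adapter that accepts both call styles might avoid the `TypeError`. The names `DataKwargWrapper` and `TinyModel` below are hypothetical; `TinyModel` only stands in for the `GCN` above so that the snippet runs without PyG. The adapter also expands the single logit `z` into two columns `[0, z]`, on the assumption that the explainer applies a softmax over class logits, in which case column 1 would equal `sigmoid(z)`.

```python
from types import SimpleNamespace

import torch


class TinyModel(torch.nn.Module):
    """Hypothetical stand-in with the same forward signature as the GCN above."""

    def __init__(self, in_channels=10):
        super().__init__()
        self.fc = torch.nn.Linear(in_channels, 1)

    def forward(self, x, edge_index, batch=None):
        # Mean over all nodes, then one logit (edge_index is ignored here)
        return self.fc(x.mean(dim=0, keepdim=True)).squeeze()


class DataKwargWrapper(torch.nn.Module):
    """Hypothetical adapter: accepts forward(x, edge_index, batch) or forward(data=...)."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x=None, edge_index=None, batch=None, data=None):
        if data is not None:
            # Unpack a PyG-style object that carries x, edge_index and batch
            x, edge_index = data.x, data.edge_index
            batch = getattr(data, "batch", None)
        logit = self.model(x, edge_index, batch).reshape(-1, 1)
        # Two-column logits [0, z]: softmax over them gives [1 - sigmoid(z), sigmoid(z)]
        return torch.cat([torch.zeros_like(logit), logit], dim=-1)


wrapped = DataKwargWrapper(TinyModel()).eval()
x = torch.randn(4, 10)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])

out_positional = wrapped(x, edge_index)
out_keyword = wrapped(data=SimpleNamespace(x=x, edge_index=edge_index, batch=None))
```

With the real model, the wrapped module would be passed to the explainer in place of `model`, presumably with `num_classes=2` since the adapter now emits two class logits. Whether this matches what `SubgraphX` expects in your installed DIG version is something to verify against its source.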