Hi, thank you for this helpful library. I am trying to run the Molecular Graph Classification example with a few changes, but I cannot understand the `ConcatOp` error that I get. Below are the details of my specific requirements:
- Updates to the graph spec: I need a regression output rather than a classification output, so I change the dtype of `'label'` to `tf.float32`. Further, my node set features have shape `(None, 9)` as opposed to `(None, 7)` in the example. No other changes are made. I also follow the documentation to write the graphs to TFRecord format, which lets me use the example's input code without changes. Below is the full graph spec as I define it:
  ```python
  import tensorflow as tf
  import tensorflow_gnn as tfgnn

  graph_tensor_spec = tfgnn.GraphTensorSpec.from_piece_specs(
      context_spec=tfgnn.ContextSpec.from_field_specs(features_spec={
          # Regression target: one float per graph.
          'label': tf.TensorSpec(shape=(1,), dtype=tf.float32)
      }),
      node_sets_spec={
          'atoms':
              tfgnn.NodeSetSpec.from_field_specs(
                  features_spec={
                      # 9 input features per atom (7 in the original example).
                      tfgnn.HIDDEN_STATE:
                          tf.TensorSpec((None, 9), tf.float32)
                  },
                  sizes_spec=tf.TensorSpec((1,), tf.int32))
      },
      edge_sets_spec={
          'bonds':
              tfgnn.EdgeSetSpec.from_field_specs(
                  features_spec={
                      # 4 input features per bond.
                      tfgnn.HIDDEN_STATE:
                          tf.TensorSpec((None, 4), tf.float32)
                  },
                  sizes_spec=tf.TensorSpec((1,), tf.int32),
                  adjacency_spec=tfgnn.AdjacencySpec.from_incident_node_sets(
                      'atoms', 'atoms'))
      })
  ```
- I use the same `_build_model` function as provided with the example notebook, with no changes.
- I change the loss to `tf.keras.losses.MeanSquaredError` and the metric to `tf.keras.metrics.RootMeanSquaredError`, and pass both to `model.compile` accordingly. Since the model's output is a linear output, I use the logit output as-is with no changes.
- I do not have a validation set, as this is only a proof of concept (POC) to learn the TF-GNN API.
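For reference, here is what the two Keras objects compute, assuming $y_i$ is the float `'label'` of graph $i$ and $\hat{y}_i$ is the model's linear output for that graph. With its default reduction, `MeanSquaredError` as a loss averages over the $n$ graphs of each batch, while `RootMeanSquaredError` as a metric accumulates squared errors over all graphs seen so far in the epoch before taking the square root:

```math
\mathrm{MSE} = \frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2,
\qquad
\mathrm{RMSE} = \sqrt{\mathrm{MSE}}
```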
When I run `model.fit`, this is the error I get (truncated to the last few lines):
```
Node: 'model_21/graph_update_43/node_set_update_43/simple_conv_44/concat'
ConcatOp : Dimension 0 in both shapes must be equal: shape[0] = [1218,8] vs. shape[1] = [2436,16]
[[{{node model_21/graph_update_43/node_set_update_43/simple_conv_44/concat}}]] [Op:__inference_train_function_125742]
```
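As I read the `SimpleConv` API documentation, setting `sender_edge_feature` adds the edge feature to the per-edge inputs that the layer concatenates along the last axis before calling `message_fn`, so every concatenated input must have one row per edge. The NumPy snippet below is a toy illustration of that rule only, not TF-GNN's implementation: the function `toy_edge_inputs`, the sizes, and the choice to keep just the sender node state plus the edge feature are all made up for illustration.

```python
import numpy as np


def toy_edge_inputs(node_states, edge_states, source):
    """Toy illustration (not TF-GNN code): gather each edge's sender node
    state via its source index, then concatenate it with the edge's own
    feature along the last (feature) axis."""
    sender_states = node_states[source]  # one row per adjacency entry
    # np.concatenate on axis=-1 requires all inputs to have the same number
    # of rows (dimension 0), the same rule tf.concat enforces in the log.
    return np.concatenate([sender_states, edge_states], axis=-1)


# 3 atoms with 8-dim states and 4 directed bonds in the adjacency.
node_states = np.zeros((3, 8), dtype=np.float32)
source = np.array([0, 1, 1, 2])

# Consistent: one 16-dim edge feature row per adjacency entry.
ok = toy_edge_inputs(node_states, np.zeros((4, 16), dtype=np.float32), source)
print(ok.shape)  # (4, 24)

# Inconsistent: twice as many edge feature rows as adjacency entries.
try:
    toy_edge_inputs(node_states, np.zeros((8, 16), dtype=np.float32), source)
except ValueError as err:
    print("ValueError:", err)
```

In the log above, 2436 = 2 × 1218, so one of the concatenated inputs has exactly twice as many rows as the other.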
After some trial and error, I noticed that commenting out `sender_edge_feature=tfgnn.HIDDEN_STATE` in the `SimpleConv` used for message passing lets training run without error. However, as I understand the documentation, this means edge features are no longer included in the convolutions, which is not desirable. Any pointers on what could be wrong and how I can fix this issue?
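One way to narrow this down might be to check each molecule's raw arrays against the graph spec above before serializing it to TFRecord, in particular that the number of bond feature rows equals the number of source/target indices. Below is a minimal NumPy sketch, assuming each molecule is held as plain arrays (`node_features`, `edge_features`, `source`, `target`, `label`); the function `check_molecule` and this data layout are hypothetical and not part of TF-GNN.

```python
import numpy as np

NODE_DIM = 9  # width of the 'atoms' hidden_state in the graph spec above
EDGE_DIM = 4  # width of the 'bonds' hidden_state in the graph spec above


def check_molecule(node_features, edge_features, source, target, label):
    """Hypothetical pre-serialization check for one molecule.

    Returns a list of human-readable problems; an empty list means the
    arrays are consistent with the 'atoms'/'bonds' spec above.
    """
    problems = []
    if node_features.ndim != 2 or node_features.shape[1] != NODE_DIM:
        problems.append(f"atoms features shape {node_features.shape}, "
                        f"expected (num_atoms, {NODE_DIM})")
    if edge_features.ndim != 2 or edge_features.shape[1] != EDGE_DIM:
        problems.append(f"bonds features shape {edge_features.shape}, "
                        f"expected (num_bonds, {EDGE_DIM})")
    num_atoms = node_features.shape[0]
    num_bonds = edge_features.shape[0]
    # Each adjacency index list needs exactly one entry per bond feature row,
    # and every index must point at an existing atom.
    for name, idx in (("source", source), ("target", target)):
        if idx.shape != (num_bonds,):
            problems.append(f"{name} indices shape {idx.shape}, "
                            f"expected ({num_bonds},) to match bonds features")
        elif idx.size and (idx.min() < 0 or idx.max() >= num_atoms):
            problems.append(f"{name} indices out of range for {num_atoms} atoms")
    # The context 'label' is declared with shape (1,).
    if np.asarray(label).shape != (1,):
        problems.append(f"label shape {np.asarray(label).shape}, expected (1,)")
    return problems


atoms = np.zeros((3, NODE_DIM), dtype=np.float32)
src = np.array([0, 1, 1, 2])
tgt = np.array([1, 0, 2, 1])
print(check_molecule(atoms, np.zeros((4, EDGE_DIM), np.float32), src, tgt, [0.5]))  # []
# Bond features written twice per bond: reports the source/target mismatch.
print(check_molecule(atoms, np.zeros((8, EDGE_DIM), np.float32), src, tgt, [0.5]))
```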
Thanks in advance.