Skip to content

Commit f72768d

Browse files
Generated .md and .ipynb files for GAT
1 parent 05307c1 commit f72768d

File tree

3 files changed

+604
-365
lines changed

3 files changed

+604
-365
lines changed

examples/graph/gat_node_classification.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535

3636
import os
3737

38-
3938
os.environ["KERAS_BACKEND"] = "tensorflow"
4039

4140
import keras
@@ -220,7 +219,7 @@ def call(self, inputs):
220219

221220
# Broadcast sum back to edges to normalize
222221
attention_sum_per_edge = ops.take(attention_sum, target_indices, axis=0)
223-
attention_norm = attention_scores / (attention_sum_per_edge + 1e-8)
222+
attention_norm = attention_scores / ops.maximum(attention_sum_per_edge, 1e-8)
224223

225224
node_states_neighbors = ops.take(z, source_indices, axis=0)
226225
weighted_neighbors = node_states_neighbors * ops.expand_dims(
@@ -255,8 +254,8 @@ def call(self, inputs):
255254
### Implement the Graph Attention Network
256255
257256
The GAT model operates on the entire graph (both node_states and edges) during all phases.
258-
To maintain backend agnosticism and leverage Keras 3's built-in training optimizations,
259-
we store the graph data as internal tensors and design the call method to accept
257+
To maintain backend agnosticism and leverage Keras 3's built-in training optimizations,
258+
we store the graph data as internal tensors and design the call method to accept
260259
the target node indices as its primary input.
261260
"""
262261

0 commit comments

Comments
 (0)