Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 86 additions & 123 deletions examples/graph/gat_node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
Title: Graph attention network (GAT) for node classification
Author: [akensert](https://github.com/akensert)
Date created: 2021/09/13
Last modified: 2021/12/26
Last modified: 2026/02/17
Description: An implementation of a Graph Attention Network (GAT) for node classification.
Accelerator: GPU
Converted to Keras 3 by: [LakshmiKalaKadali](https://github.com/LakshmiKalaKadali)
"""

"""
Expand Down Expand Up @@ -32,18 +33,23 @@
### Import packages
"""

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import os


os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import layers
from keras import ops
import numpy as np
import pandas as pd
import os
import warnings

warnings.filterwarnings("ignore")
pd.set_option("display.max_columns", 6)
pd.set_option("display.max_rows", 6)
np.random.seed(2)

keras.utils.set_random_seed(2)

"""
## Obtain the dataset
Expand All @@ -56,21 +62,20 @@
of seven labels (the *subject* of the paper).
"""


# Download (and extract) the Cora archive into the Keras cache.
zip_file = keras.utils.get_file(
    fname="cora.tgz",
    origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
    extract=True,
)

# NOTE(review): with `extract=True`, `get_file` appears to return the
# extraction directory in recent Keras 3 releases and the archive path in
# older ones -- confirm against the pinned Keras version. Handling both keeps
# the example runnable either way.
extract_root = zip_file if os.path.isdir(zip_file) else os.path.dirname(zip_file)
data_dir = os.path.join(extract_root, "cora")

# One row per citation: the paper in `target` is cited by the paper in `source`.
citations = pd.read_csv(
    os.path.join(data_dir, "cora.cites"),
    sep="\t",
    header=None,
    names=["target", "source"],
)

papers = pd.read_csv(
os.path.join(data_dir, "cora.content"),
sep="\t",
Expand Down Expand Up @@ -116,10 +121,11 @@
test_labels = test_data["subject"].to_numpy()

# Define graph, namely an edge tensor and a node feature tensor.
# The frames are converted to NumPy first and given explicit dtypes so the
# resulting tensors do not depend on how a backend infers types from pandas.
edges = ops.convert_to_tensor(
    citations[["target", "source"]].to_numpy(), dtype="int32"
)
node_states = ops.convert_to_tensor(
    papers.sort_values("paper_id").iloc[:, 1:-1].to_numpy(), dtype="float32"
)

# Print shapes of the graph
print("Edges shape:\t\t", edges.shape)
print("Node features shape:", node_states.shape)

Expand Down Expand Up @@ -193,39 +199,37 @@ def build(self, input_shape):
def call(self, inputs):
    """Run one graph-attention head over the whole graph.

    Args:
        inputs: A `(node_states, edges)` pair. `edges` is indexed as
            `edges[:, 0]` (target node) and `edges[:, 1]` (source node).

    Returns:
        A tensor with one row per node (`ops.shape(node_states)[0]` rows):
        the attention-weighted sum of the linearly transformed states of
        the node's source neighbors.
    """
    node_states, edges = inputs
    target_indices = edges[:, 0]
    source_indices = edges[:, 1]
    num_nodes = ops.shape(node_states)[0]

    # Linearly transform node states
    z = ops.matmul(node_states, self.kernel)

    # (1) Compute pair-wise attention scores from the concatenated
    # [target, source] states of every edge.
    z_target = ops.take(z, target_indices, axis=0)
    z_source = ops.take(z, source_indices, axis=0)
    z_concat = ops.concatenate([z_target, z_source], axis=-1)
    attention_scores = ops.leaky_relu(ops.matmul(z_concat, self.kernel_attention))
    # The attention kernel yields one score per edge; drop the trailing axis.
    attention_scores = ops.squeeze(attention_scores, -1)

    # (2) Normalize attention scores over the edges sharing a target node.
    # Clipping before `exp` bounds the scores and keeps the sum stable.
    attention_scores = ops.exp(ops.clip(attention_scores, -2, 2))
    attention_sum = ops.segment_sum(
        attention_scores, target_indices, num_segments=num_nodes
    )
    # Broadcast each target's sum back to its edges.
    attention_sum_per_edge = ops.take(attention_sum, target_indices, axis=0)
    attention_norm = attention_scores / (attention_sum_per_edge + 1e-8)

    # (3) Apply attention scores to the neighbor states and aggregate them
    # per target node. `z_source` already holds the gathered neighbor
    # states, so it is reused instead of gathering a second time.
    weighted_neighbors = z_source * ops.expand_dims(attention_norm, axis=-1)
    return ops.segment_sum(
        weighted_neighbors, target_indices, num_segments=num_nodes
    )


class MultiHeadGraphAttention(layers.Layer):
Expand All @@ -236,31 +240,24 @@ def __init__(self, units, num_heads=8, merge_type="concat", **kwargs):
self.attention_layers = [GraphAttention(units) for _ in range(num_heads)]

def call(self, inputs):
    """Apply every attention head and merge their outputs.

    Args:
        inputs: A `(node_states, edges)` pair, forwarded unchanged to each
            layer in `self.attention_layers`.

    Returns:
        The ReLU-activated merged node states: head outputs are
        concatenated along the last axis when `self.merge_type` is
        `"concat"`, and averaged across heads otherwise.
    """
    node_states, edges = inputs

    # Obtain outputs from each attention head
    head_outputs = [head([node_states, edges]) for head in self.attention_layers]

    # Concatenate or average the node states from each head
    if self.merge_type == "concat":
        merged = ops.concatenate(head_outputs, axis=-1)
    else:
        merged = ops.mean(ops.stack(head_outputs, axis=0), axis=0)

    # Activate and return node states
    return ops.relu(merged)

"""
### Implement training logic with custom `train_step`, `test_step`, and `predict_step` methods

Notice, the GAT model operates on the entire graph (namely, `node_states` and
`edges`) in all phases (training, validation and testing). Hence, `node_states` and
`edges` are passed to the constructor of the `keras.Model` and used as attributes.
The difference between the phases are the indices (and labels), which gathers
certain outputs (`tf.gather(outputs, indices)`).
"""
### Implement the Graph Attention Network

The GAT model operates on the entire graph (both `node_states` and `edges`) in all
phases (training, validation and testing). To stay backend-agnostic and reuse the
built-in Keras 3 `fit()`, `evaluate()` and `predict()` loops, the graph is stored on
the model as attributes, and `call` takes only the indices of the nodes whose
outputs are requested.
"""


Expand All @@ -284,103 +281,69 @@ def __init__(
]
self.output_layer = layers.Dense(output_dim)

def call(self, inputs, training=False):
    """Compute outputs for the requested nodes.

    The forward pass always runs over the full graph held in
    `self.node_states` and `self.edges`; `inputs` only selects which rows
    of the result are returned. Because of that, the built-in `fit()`,
    `evaluate()` and `predict()` loops can be fed batches of node indices
    directly and no custom `train_step`/`test_step`/`predict_step` is
    needed.

    Args:
        inputs: Indices of the nodes to return outputs for.
        training: Accepted for Keras API compatibility; not used here.

    Returns:
        The rows of the output layer's result selected by `inputs`. No
        softmax is applied in this method.
    """
    indices = inputs

    x = self.preprocess(self.node_states)
    for attention_layer in self.attention_layers:
        # Residual connection around every attention layer.
        x = attention_layer([x, self.edges]) + x
    outputs = self.output_layer(x)

    # Return only the requested node states
    return ops.take(outputs, indices, axis=0)


"""
### Train and evaluate
"""

# Define hyper-parameters
HIDDEN_UNITS = 100
NUM_HEADS = 8
NUM_LAYERS = 3
OUTPUT_DIM = len(class_values)

NUM_EPOCHS = 100
BATCH_SIZE = 256
VALIDATION_SPLIT = 0.1
# NOTE(review): matches the rate of the last `compile` call in the original
# (0.003); the earlier constant was 3e-1 -- confirm which one is intended.
LEARNING_RATE = 3e-3
MOMENTUM = 0.9

loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = keras.optimizers.SGD(learning_rate=LEARNING_RATE, momentum=MOMENTUM)
# The metric is named explicitly so that the name monitored by early stopping
# ("val_" + metric name) is guaranteed to exist in the training logs.
accuracy_fn = keras.metrics.SparseCategoricalAccuracy(name="accuracy")
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_accuracy", patience=5, restore_best_weights=True
)

# Build and compile model
gat_model = GraphAttentionNetwork(
    node_states, edges, HIDDEN_UNITS, NUM_HEADS, NUM_LAYERS, OUTPUT_DIM
)
gat_model.compile(loss=loss_fn, optimizer=optimizer, metrics=[accuracy_fn])

gat_model.fit(
    x=train_indices,
    y=train_labels,
    validation_split=VALIDATION_SPLIT,
    batch_size=BATCH_SIZE,
    epochs=NUM_EPOCHS,
    callbacks=[early_stopping],
    verbose=2,
)

_, test_accuracy = gat_model.evaluate(x=test_indices, y=test_labels, verbose=0)

print("--" * 38 + f"\nTest Accuracy {test_accuracy*100:.1f}%")


"""
### Predict (probabilities)
"""
# Run prediction once: the model outputs logits, which are turned into
# probabilities here and converted to NumPy for plain Python iteration.
test_logits = gat_model.predict(x=test_indices)
test_probs = ops.softmax(test_logits)
test_probs_np = ops.convert_to_numpy(test_probs)

# Invert `class_idx` so a label index maps back to its class name.
mapping = {v: k for (k, v) in class_idx.items()}

for i, (probs, label) in enumerate(zip(test_probs_np[:10], test_labels[:10])):
    print(f"Example {i+1}: {mapping[label]}")
    for prob, class_name in zip(probs, class_idx):
        print(f"\tProbability of {class_name: <24} = {prob*100:7.3f}%")
Expand Down
Loading