Skip to content
12 changes: 12 additions & 0 deletions docs/refs.bib
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,15 @@ @book{higham2022accuracy
url = {https://epubs.siam.org/doi/abs/10.1137/1.9780898718027},
eprint = {https://epubs.siam.org/doi/pdf/10.1137/1.9780898718027}
}

@inproceedings{NIPS2007_d045c59a,
author = {Yu, Kai and Chu, Wei},
booktitle = {Advances in Neural Information Processing Systems},
editor = {J. Platt and D. Koller and Y. Singer and S. Roweis},
pages = {},
publisher = {Curran Associates, Inc.},
title = {Gaussian Process Models for Link Analysis and Transfer Learning},
url = {https://proceedings.neurips.cc/paper_files/paper/2007/file/d045c59a90d7587d8d671b5f5aec4e7c-Paper.pdf},
volume = {20},
year = {2007}
}
142 changes: 142 additions & 0 deletions examples/graph_edge_kernels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# %% [markdown]
# # Graph Edge Kernels — medium random graph (~3000 edges)

# %%
import random

from jax import config
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import optax as ox
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    roc_auc_score,
    roc_curve,
)
from sklearn.model_selection import train_test_split

import gpjax as gpx
from gpjax.kernels.non_euclidean.graph_edge import GraphEdgeKernel
from gpjax.parameters import Parameter

# %% [markdown]
# ## Configuration
# %%
SEED = 123

# The arrays built later in this script are cast to float64 / int64. JAX only
# honours 64-bit dtypes when x64 mode is enabled; otherwise it downcasts them to
# 32 bits (with a warning). Enable it before any JAX array is created.
config.update("jax_enable_x64", True)

# Seed both global RNGs used by the script (NumPy and the stdlib `random`).
np.random.seed(SEED)
random.seed(SEED)

# Separate JAX PRNG key, later handed to the optimisation routine.
key = jr.key(42)

# %% [markdown]
# ## Construct medium-sized random graph (~3000 edges)
# %%
n_nodes = 150
target_edges = 3000
# Edge probability chosen so that the expected number of edges is target_edges
# (the denominator is the number of unordered node pairs).
p_edge = target_edges / (n_nodes * (n_nodes - 1) / 2)
G = nx.erdos_renyi_graph(n_nodes, p_edge, seed=SEED)

# The random draw only matches target_edges in expectation. If it overshoots,
# drop the surplus in one sampled batch instead of rebuilding the edge list for
# every single removal. If it undershoots, the graph is left as drawn.
n_surplus = G.number_of_edges() - target_edges
if n_surplus > 0:
    G.remove_edges_from(random.sample(list(G.edges()), n_surplus))

print(f"Graph has {G.number_of_nodes()} nodes and {G.number_of_edges()} edges")
# %%
# Seeded spring layout so the node positions are the same on every run.
pos = nx.spring_layout(G, seed=SEED)
graph_title = (
    f"Random graph with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges"
)

plt.figure(figsize=(6, 5))
nx.draw(G, pos, with_labels=False, node_size=40, edge_color="black")
plt.title(graph_title)
plt.show()

# %% [markdown]
# ## Node features
# %%
node_feature_dim = 20
# One feature row per node, drawn uniformly from [0.5, 13.3) with the global
# NumPy RNG.
feature_shape = (n_nodes, node_feature_dim)
node_feature_matrix = np.random.uniform(0.5, 13.3, size=feature_shape).astype(
    np.float64
)

# %% [markdown]
# ## Prepare edges and labels
# %%
pos_frac = 0.5

# Each row of edge_list is one edge, stored as a pair of node indices.
edge_list = jnp.array(G.edges).astype(jnp.int64)
num_edges = edge_list.shape[0]
n_pos = int(pos_frac * num_edges)

# n_pos ones followed by zeros, then shuffled in place: the labels are assigned
# at random and are not derived from the graph.
labels = np.concatenate([np.ones(n_pos), np.zeros(num_edges - n_pos)])
np.random.shuffle(labels)

# Column-vector copy of the labels as a JAX array.
labels_jnp = jnp.array(labels).astype(jnp.float64).reshape(-1, 1)
# %% [markdown]
# ## Train / Test split
# %%
# Split edge positions (not the edges themselves) 80/20, stratified on the labels.
edge_idx = np.arange(num_edges)
train_idx, test_idx = train_test_split(
    edge_idx,
    stratify=labels,
    test_size=0.2,
    random_state=SEED,
)

# Use the index sets to slice the edges and their labels consistently.
edge_train, edge_test = edge_list[train_idx], edge_list[test_idx]
y_train, y_test = labels_jnp[train_idx], labels_jnp[test_idx]

print(f"Training edges: {len(train_idx)}, Test edges: {len(test_idx)}")

# %% [markdown]
# ## Model definition
# %%
# Edge kernel: an RBF base kernel over node features, lifted to pairs of nodes.
base_kernel = gpx.kernels.RBF()
graph_kernel = GraphEdgeKernel(
    base_kernel=base_kernel,
    feature_mat=node_feature_matrix,
)

# Binary labels, so a Bernoulli likelihood sized to the training set.
likelihood = gpx.likelihoods.Bernoulli(num_datapoints=len(train_idx))

meanf = gpx.mean_functions.Constant()
prior = gpx.gps.Prior(kernel=graph_kernel, mean_function=meanf)

# Combining prior and likelihood gives the model that is fitted afterwards.
posterior = prior * likelihood

# %% [markdown]
# ## Train model
# %%
D_train = gpx.Dataset(X=jnp.array(edge_train), y=y_train)
D_test = gpx.Dataset(X=jnp.array(edge_test), y=y_test)

num_iters = 2000
optimiser = ox.adamw(learning_rate=0.1)


def negative_log_posterior(model, data):
    """Return the negated log posterior density, i.e. the quantity to minimise."""
    return -gpx.objectives.log_posterior_density(model, data)


# gpx.fit returns the fitted model together with the objective history.
opt_posterior, history = gpx.fit(
    model=posterior,
    objective=negative_log_posterior,
    train_data=D_train,
    optim=optimiser,
    num_iters=num_iters,
    key=key,
    trainable=Parameter,
)

# %% [markdown]
# ## Predictions on test edges
# %%
# Latent prediction at the test edges, pushed through the fitted likelihood; the
# mean of the resulting distribution is used as the class-1 score.
pred_dist = opt_posterior.likelihood(opt_posterior(edge_test, D_train))
pred_mean = pred_dist.mean

# scikit-learn works on host arrays: convert once and flatten both scores and
# labels to 1-D so they share the same shape (y_test is an (N, 1) column).
y_prob_np = np.asarray(pred_mean).ravel()
y_test_np = np.asarray(y_test).ravel()

# Threshold in NumPy; the inputs are already NumPy arrays, so there is no reason
# to send them back through JAX.
pred_labels = (y_prob_np > 0.5).astype(int)

auc = roc_auc_score(y_test_np, y_prob_np)
acc = accuracy_score(y_test_np, pred_labels)
f1 = f1_score(y_test_np, pred_labels)
print(f"Test ROC-AUC: {auc:.4f}, Accuracy: {acc:.4f}, F1: {f1:.4f}")

# %% [markdown]
# ## Plot ROC curve
# %%
# roc_curve returns three arrays; the third (thresholds) is not plotted.
fpr, tpr, _ = roc_curve(y_test_np, y_prob_np)

plt.figure(figsize=(5, 4))
plt.plot(fpr, tpr)
plt.title("ROC Curve on Test Edges")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.grid(True)
plt.show()
6 changes: 5 additions & 1 deletion gpjax/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
DiagonalKernelComputation,
EigenKernelComputation,
)
from gpjax.kernels.non_euclidean import GraphKernel
from gpjax.kernels.non_euclidean import (
GraphEdgeKernel,
GraphKernel,
)
from gpjax.kernels.nonstationary import (
ArcCosine,
Linear,
Expand All @@ -53,6 +56,7 @@
"Constant",
"RBF",
"GraphKernel",
"GraphEdgeKernel",
"Matern12",
"Matern32",
"Matern52",
Expand Down
2 changes: 2 additions & 0 deletions gpjax/kernels/computations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from gpjax.kernels.computations.dense import DenseKernelComputation
from gpjax.kernels.computations.diagonal import DiagonalKernelComputation
from gpjax.kernels.computations.eigen import EigenKernelComputation
from gpjax.kernels.computations.graph_edge import GraphEdgeKernelComputation

__all__ = [
"AbstractKernelComputation",
Expand All @@ -29,4 +30,5 @@
"DenseKernelComputation",
"DiagonalKernelComputation",
"EigenKernelComputation",
"GraphEdgeKernelComputation",
]
38 changes: 38 additions & 0 deletions gpjax/kernels/computations/graph_edge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2025 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import beartype.typing as tp
from jaxtyping import (
Float,
Int,
)

import gpjax # noqa: F401
from gpjax.kernels.computations.base import AbstractKernelComputation
from gpjax.typing import Array

# Type variable used to annotate kernel arguments. The bound is given as a string
# (forward reference) to gpjax.kernels.base.AbstractKernel.
K = tp.TypeVar("K", bound="gpjax.kernels.base.AbstractKernel") # noqa: F821


class GraphEdgeKernelComputation(AbstractKernelComputation):
    r"""Compute engine for kernels that are evaluated on whole batches of edges.

    The cross-covariance is produced by a single call of the kernel on the two
    input batches; the result of that call is returned unchanged.
    """

    def cross_covariance(
        self, kernel: K, x: Int[Array, "N D"], y: Int[Array, "M D"]
    ) -> Float[Array, "N M"]:
        r"""Evaluate ``kernel`` on the batches ``x`` and ``y``.

        Args:
            kernel: the kernel to evaluate; it is invoked as ``kernel(x, y)``.
            x: integer array annotated with shape ``(N, D)``.
            y: integer array annotated with shape ``(M, D)``.

        Returns:
            The value of ``kernel(x, y)``, annotated as an ``(N, M)`` float array.
        """
        return kernel(x, y)
3 changes: 2 additions & 1 deletion gpjax/kernels/non_euclidean/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@
# ==============================================================================

from gpjax.kernels.non_euclidean.graph import GraphKernel
from gpjax.kernels.non_euclidean.graph_edge import GraphEdgeKernel

__all__ = ["GraphKernel"]
__all__ = ["GraphKernel", "GraphEdgeKernel"]
88 changes: 88 additions & 0 deletions gpjax/kernels/non_euclidean/graph_edge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import beartype.typing as tp
import jax.numpy as jnp
from jaxtyping import (
Int,
Num,
)

from gpjax.kernels.base import AbstractKernel
from gpjax.kernels.computations import (
AbstractKernelComputation,
GraphEdgeKernelComputation,
)
from gpjax.typing import (
Array,
)


# Kernel over graph edges: every input row is a pair of node indices that is
# looked up in a fixed node-feature matrix.
class GraphEdgeKernel(AbstractKernel):
    r"""Kernel defined on the edge set of a graph.

    Implements the edge kernels of Yu and Chu, "Gaussian Process Models for Link
    Analysis and Transfer Learning" (NIPS 2007):
    https://papers.nips.cc/paper_files/paper/2007/hash/d045c59a90d7587d8d671b5f5aec4e7c-Abstract.html

    Directed graphs:   K((i, j), (i', j')) = <x_i ⊗ x_j, x_i' ⊗ x_j'>
    Undirected graphs: K((i, j), (i', j')) = <x_i ⊗ x_j, x_i' ⊗ x_j'>
                                             + <x_i ⊗ x_j, x_j' ⊗ x_i'>
    Bipartite graphs:  K((i, j), (i', j')) = <x_i ⊗ z_j, x_i' ⊗ z_j'>

    With ``k`` the base kernel on node features, this class evaluates the
    directed form as ``k(x_i, x_i') * k(x_j, x_j')`` and, when ``directed`` is
    False, adds the swapped term ``k(x_i, x_j') * k(x_j, x_i')``. A single node
    feature matrix is used for both endpoints, so the bipartite form has no
    separate implementation here.
    """

    # Was "Graph Matérn", copied from the node-level graph kernel.
    name: str = "Graph Edge"

    def __init__(
        self,
        base_kernel: AbstractKernel,
        feature_mat: Num[Array, "N M"],
        directed: bool = False,
        active_dims: tp.Union[list[int], slice, None] = None,
        n_dims: tp.Union[int, None] = None,
        compute_engine: AbstractKernelComputation = GraphEdgeKernelComputation(),
    ):
        """Initializes the kernel.

        Args:
            base_kernel: kernel evaluated on the rows of ``feature_mat`` to obtain
                the node-by-node covariance.
            feature_mat: the node feature matrix of size
                (number of nodes, number of features).
            directed: whether edges are directed. When False the symmetric
                (swapped endpoints) term is added to the kernel.
            active_dims: The indices of the input dimensions that the kernel
                operates on.
            n_dims: passed through to ``AbstractKernel.__init__``.
            compute_engine: The computation engine that the kernel uses to
                compute the covariance matrix.
        """

        self.base_kernel = base_kernel
        self.dense_feature_mat = feature_mat
        self.directed = directed

        super().__init__(active_dims, n_dims, compute_engine)

    def __call__(  # TODO not consistent with general kernel interface
        self,
        X: Int[Array, "N 2"],
        y: Int[Array, "M 2"],
        *,
        S=None,
        **kwargs,
    ):
        r"""Evaluate the kernel between two batches of edges.

        Args:
            X: first batch of edges, shape ``[N, 2]``. Column 0 holds the sending
                node index and column 1 the receiving node index of each edge.
            y: second batch of edges, shape ``[M, 2]``, laid out like ``X``.
            S: unused; accepted for signature compatibility.
            **kwargs: unused.

        Returns:
            The ``[N, M]`` matrix of kernel values between every edge in ``X``
            and every edge in ``y``. The result always keeps both axes, also when
            ``N`` or ``M`` is 1, so that it can be used as a (cross-)covariance
            matrix for single-edge batches.
        """

        sender, receiver = X[:, 0], X[:, 1]
        sender_other, receiver_other = y[:, 0], y[:, 1]

        # Node-by-node covariance from the base kernel. It is recomputed on every
        # call because it depends on the base kernel's current parameters.
        node_cov = self.base_kernel.gram(self.dense_feature_mat).to_dense()

        def _node_block(rows, cols):
            # Sub-matrix of node_cov with the given row / column node indices.
            return jnp.take(jnp.take(node_cov, rows, axis=0), cols, axis=1)

        cov_edges = _node_block(sender, sender_other) * _node_block(
            receiver, receiver_other
        )

        if not self.directed:
            # Undirected edges: (i, j) and (j, i) are the same edge, so add the
            # term with the endpoints of the second edge swapped.
            cov_edges = cov_edges + _node_block(sender, receiver_other) * _node_block(
                receiver, sender_other
            )

        return cov_edges
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ nav:
- Barycentres: _examples/barycentres.md
- Deep kernel learning: _examples/deep_kernels.md
- Graph kernels: _examples/graph_kernels.md
- Graph Edge kernels: _examples/graph_edge_kernels.md
- Sparse GPs: _examples/collapsed_vi.md
- Stochastic sparse GPs: _examples/uncollapsed_vi.md
- Multi-output GPs for Ocean Modelling: _examples/oceanmodelling.md
Expand Down
33 changes: 32 additions & 1 deletion tests/test_kernels/test_non_euclidean.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,13 @@
from jax import config
import jax.numpy as jnp
import networkx as nx
import numpy as np

from gpjax.kernels.non_euclidean import GraphKernel
from gpjax.kernels.non_euclidean import (
GraphEdgeKernel,
GraphKernel,
)
from gpjax.kernels.stationary import RBF
from gpjax.linalg.operators import Identity

# # Enable Float64 for more stable matrix inversions.
Expand Down Expand Up @@ -48,3 +53,29 @@ def test_graph_kernel():
Kxx += Identity(Kxx.shape[0]) * 1e-6
eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense())
assert all(eigen_values > 0)


def test_graph_edge_kernel():
    """Gram matrix of GraphEdgeKernel has the right shape and is positive definite."""
    # Create a random graph, G, with a fixed number of vertices and edges.
    n_vertices = 20
    n_edges = 40
    G = nx.gnm_random_graph(n_vertices, n_edges, seed=123)
    edge_indices = jnp.array(G.edges).astype(jnp.int64)

    # Node features come from a locally seeded generator so that the test does
    # not depend on the state of the global NumPy RNG.
    rng = np.random.default_rng(123)
    feature_mat = jnp.array(rng.uniform(low=0.5, high=13.3, size=(n_vertices, 5)))

    # Create graph edge kernel
    kern = GraphEdgeKernel(base_kernel=RBF(), feature_mat=feature_mat)
    assert isinstance(kern, GraphEdgeKernel)

    # Compute gram matrix
    Kxx = kern.gram(edge_indices)
    assert Kxx.shape == (n_edges, n_edges)

    # Check positive definiteness (after adding a small jitter to the diagonal)
    Kxx += Identity(Kxx.shape[0]) * 1e-6
    eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense())
    assert all(eigen_values > 0)
Loading