diff --git a/docs/refs.bib b/docs/refs.bib
index fe7fb834e..fcfff2906 100644
--- a/docs/refs.bib
+++ b/docs/refs.bib
@@ -124,3 +124,15 @@ @book{higham2022accuracy
   url = {https://epubs.siam.org/doi/abs/10.1137/1.9780898718027},
   eprint = {https://epubs.siam.org/doi/pdf/10.1137/1.9780898718027}
 }
+
+@inproceedings{NIPS2007_d045c59a,
+  author = {Yu, Kai and Chu, Wei},
+  booktitle = {Advances in Neural Information Processing Systems},
+  editor = {J. Platt and D. Koller and Y. Singer and S. Roweis},
+  pages = {},
+  publisher = {Curran Associates, Inc.},
+  title = {Gaussian Process Models for Link Analysis and Transfer Learning},
+  url = {https://proceedings.neurips.cc/paper_files/paper/2007/file/d045c59a90d7587d8d671b5f5aec4e7c-Paper.pdf},
+  volume = {20},
+  year = {2007}
+}
diff --git a/examples/graph_edge_kernels.py b/examples/graph_edge_kernels.py
new file mode 100644
index 000000000..e90134abb
--- /dev/null
+++ b/examples/graph_edge_kernels.py
@@ -0,0 +1,142 @@
+# %% [markdown]
+# # Graph Edge Kernels: edge classification on a medium random graph (~3,000 edges)
+
+# %%
+import random
+
+from jax import config
+import jax.numpy as jnp
+import jax.random as jr
+import matplotlib.pyplot as plt
+import networkx as nx
+import numpy as np
+import optax as ox
+from sklearn.metrics import (
+    accuracy_score,
+    f1_score,
+    roc_auc_score,
+    roc_curve,
+)
+from sklearn.model_selection import train_test_split
+
+import gpjax as gpx
+from gpjax.kernels.non_euclidean.graph_edge import GraphEdgeKernel
+from gpjax.parameters import Parameter
+
+# %% [markdown]
+# ## Configuration
+# %%
+# Enable float64 so the int64 edge indices and float64 labels below are not downcast.
+config.update("jax_enable_x64", True)
+
+SEED = 123
+np.random.seed(SEED)
+random.seed(SEED)
+key = jr.key(42)
+
+# %% [markdown]
+# ## Construct a medium-sized random graph (~3,000 edges)
+# %%
+n_nodes = 150
+target_edges = 3000
+p_edge = target_edges / (n_nodes * (n_nodes - 1) / 2)
+G = nx.erdos_renyi_graph(n_nodes, p_edge, seed=SEED)
+
+# Prune randomly chosen edges until the graph is at (or below) the target edge count.
+while G.number_of_edges() > target_edges:
+    u, v = random.choice(list(G.edges()))
+    G.remove_edge(u, v)
+
+print(f"Graph has {G.number_of_nodes()} nodes and {G.number_of_edges()} edges")
+# %%
+pos = nx.spring_layout(G, seed=SEED)
+plt.figure(figsize=(6, 5))
+nx.draw(G, pos, node_size=40, edge_color="black", with_labels=False)
+plt.title(
+    f"Random graph with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges"
+)
+plt.show()
+
+# %% [markdown]
+# ## Node features
+# %%
+node_feature_dim = 20
+node_feature_matrix = np.random.uniform(
+    low=0.5, high=13.3, size=(n_nodes, node_feature_dim)
+).astype(np.float64)
+
+# %% [markdown]
+# ## Prepare edges and labels
+# %%
+edge_list = jnp.array(G.edges).astype(jnp.int64)
+num_edges = edge_list.shape[0]
+
+# Synthetic binary edge labels: half positive, half negative, randomly shuffled.
+pos_frac = 0.5
+n_pos = int(pos_frac * num_edges)
+labels = np.array([1] * n_pos + [0] * (num_edges - n_pos), dtype=np.float64)
+np.random.shuffle(labels)
+labels_jnp = jnp.array(labels).reshape(-1, 1).astype(jnp.float64)
+
+# %% [markdown]
+# ## Train / test split
+# %%
+edge_idx = np.arange(num_edges)
+train_idx, test_idx = train_test_split(
+    edge_idx, test_size=0.2, random_state=SEED, stratify=labels
+)
+
+edge_train = edge_list[train_idx]
+edge_test = edge_list[test_idx]
+y_train = labels_jnp[train_idx]
+y_test = labels_jnp[test_idx]
+
+print(f"Training edges: {len(train_idx)}, Test edges: {len(test_idx)}")
+
+# %% [markdown]
+# ## Model definition
+# %%
+base_kernel = gpx.kernels.RBF()
+graph_kernel = GraphEdgeKernel(feature_mat=node_feature_matrix, base_kernel=base_kernel)
+meanf = gpx.mean_functions.Constant()
+prior = gpx.gps.Prior(mean_function=meanf, kernel=graph_kernel)
+likelihood = gpx.likelihoods.Bernoulli(num_datapoints=len(train_idx))
+posterior = prior * likelihood
+
+# %% [markdown]
+# ## Train model
+# %%
+D_train = gpx.Dataset(X=jnp.array(edge_train), y=y_train)
+D_test = gpx.Dataset(X=jnp.array(edge_test), y=y_test)
+
+optimiser = ox.adamw(learning_rate=0.1)
+num_iters = 2000
+
+opt_posterior, history = gpx.fit(
+    model=posterior,
+    objective=lambda p, d: -gpx.objectives.log_posterior_density(p, d),
+    train_data=D_train,
+    optim=optimiser,
+    num_iters=num_iters,
+    key=key,
+    trainable=Parameter,
+)
+
+# %% [markdown]
+# ## Predictions on test edges
+# %%
+pred_dist = opt_posterior.likelihood(opt_posterior(edge_test, D_train))
+pred_mean = pred_dist.mean
+
+y_prob_np = np.array(pred_mean)
+y_test_np = np.array(y_test)
+
+pred_labels = np.where(y_prob_np > 0.5, 1, 0)
+auc = roc_auc_score(y_test_np, y_prob_np)
+acc = accuracy_score(y_test_np, pred_labels)
+f1 = f1_score(y_test_np, pred_labels)
+print(f"Test ROC-AUC: {auc:.4f}, Accuracy: {acc:.4f}, F1: {f1:.4f}")
+
+# %% [markdown]
+# ## Plot ROC curve
+# %%
+fpr, tpr, _ = roc_curve(y_test_np, y_prob_np)
+plt.figure(figsize=(5, 4))
+plt.plot(fpr, tpr)
+plt.xlabel("False Positive Rate")
+plt.ylabel("True Positive Rate")
+plt.title("ROC Curve on Test Edges")
+plt.grid(True)
+plt.show()
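The example above treats each edge as a pair of node indices; the covariance it relies on can be sanity-checked against the product-of-node-kernels form described in the kernel's docstring. Below is a minimal sketch, not part of the diff, that assumes only the `GraphEdgeKernel` and `RBF` APIs used above and compares one kernel entry with a hand-computed value:

```python
import jax.numpy as jnp
import numpy as np

import gpjax as gpx
from gpjax.kernels.non_euclidean.graph_edge import GraphEdgeKernel

# Sanity-check sketch (assumes the GraphEdgeKernel API added in this PR).
# Toy node-feature matrix: 4 nodes with 3 features each (illustrative values only).
X = jnp.array(np.random.default_rng(0).normal(size=(4, 3)))
base = gpx.kernels.RBF()
kern = GraphEdgeKernel(base_kernel=base, feature_mat=X, directed=False)

edges = jnp.array([[0, 1], [2, 3]])
K_edges = kern(edges, edges)  # 2 x 2 covariance between the two edges

# Hand-computed entry for the undirected kernel:
# K((i, j), (i', j')) = k(x_i, x_i') k(x_j, x_j') + k(x_i, x_j') k(x_j, x_i')
K_nodes = base.gram(X).to_dense()
i, j, ip, jp = 0, 1, 2, 3
manual = K_nodes[i, ip] * K_nodes[j, jp] + K_nodes[i, jp] * K_nodes[j, ip]
print(K_edges[0, 1], manual)  # the two values should agree
```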
diff --git a/gpjax/kernels/__init__.py b/gpjax/kernels/__init__.py
index 3844ebfda..dc399d5c7 100644
--- a/gpjax/kernels/__init__.py
+++ b/gpjax/kernels/__init__.py
@@ -30,7 +30,10 @@
     DiagonalKernelComputation,
     EigenKernelComputation,
 )
-from gpjax.kernels.non_euclidean import GraphKernel
+from gpjax.kernels.non_euclidean import (
+    GraphEdgeKernel,
+    GraphKernel,
+)
 from gpjax.kernels.nonstationary import (
     ArcCosine,
     Linear,
@@ -53,6 +56,7 @@
     "Constant",
     "RBF",
     "GraphKernel",
+    "GraphEdgeKernel",
     "Matern12",
     "Matern32",
     "Matern52",
diff --git a/gpjax/kernels/computations/__init__.py b/gpjax/kernels/computations/__init__.py
index 880433bba..595c7a959 100644
--- a/gpjax/kernels/computations/__init__.py
+++ b/gpjax/kernels/computations/__init__.py
@@ -21,6 +21,7 @@
 from gpjax.kernels.computations.dense import DenseKernelComputation
 from gpjax.kernels.computations.diagonal import DiagonalKernelComputation
 from gpjax.kernels.computations.eigen import EigenKernelComputation
+from gpjax.kernels.computations.graph_edge import GraphEdgeKernelComputation
 
 __all__ = [
     "AbstractKernelComputation",
@@ -29,4 +30,5 @@
     "DenseKernelComputation",
     "DiagonalKernelComputation",
     "EigenKernelComputation",
+    "GraphEdgeKernelComputation",
 ]
diff --git a/gpjax/kernels/computations/graph_edge.py b/gpjax/kernels/computations/graph_edge.py
new file mode 100644
index 000000000..72b58bcd6
--- /dev/null
+++ b/gpjax/kernels/computations/graph_edge.py
@@ -0,0 +1,38 @@
+# Copyright 2025 The JaxGaussianProcesses Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import beartype.typing as tp
+from jaxtyping import (
+    Float,
+    Int,
+)
+
+import gpjax  # noqa: F401
+from gpjax.kernels.computations.base import AbstractKernelComputation
+from gpjax.typing import Array
+
+K = tp.TypeVar("K", bound="gpjax.kernels.base.AbstractKernel")  # noqa: F821
+
+
+class GraphEdgeKernelComputation(AbstractKernelComputation):
+    r"""Computation engine for graph edge kernels.
+
+    The edge kernel's ``__call__`` already operates on whole batches of node-index
+    pairs, so the cross-covariance is obtained by evaluating the kernel directly on
+    the two batches rather than mapping over individual input pairs.
+    """
+
+    def cross_covariance(
+        self, kernel: K, x: Int[Array, "N D"], y: Int[Array, "M D"]
+    ) -> Float[Array, "N M"]:
+        # Delegate to the kernel, which returns the dense (N, M) covariance block.
+        cross_cov = kernel(x, y)
+        return cross_cov
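The engine above exists because the edge kernel consumes whole batches of index pairs at once, so the cross-covariance is a single kernel evaluation rather than a pairwise map as in `DenseKernelComputation`. A small usage sketch, not part of the diff, assuming the classes introduced in this patch:

```python
import jax.numpy as jnp
import numpy as np

import gpjax as gpx
from gpjax.kernels.computations import GraphEdgeKernelComputation
from gpjax.kernels.non_euclidean import GraphEdgeKernel

# Usage sketch (assumes the classes added in this PR).
# Ten nodes with four features each (illustrative values only).
features = jnp.array(np.random.default_rng(1).normal(size=(10, 4)))
kern = GraphEdgeKernel(base_kernel=gpx.kernels.RBF(), feature_mat=features)

train_edges = jnp.array([[0, 1], [2, 3], [4, 5]])
test_edges = jnp.array([[6, 7], [8, 9]])

# The engine forwards both batches of node-index pairs straight to the kernel.
engine = GraphEdgeKernelComputation()
Kxy = engine.cross_covariance(kern, train_edges, test_edges)
print(Kxy.shape)  # expected: (3, 2)
```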
diff --git a/gpjax/kernels/non_euclidean/__init__.py b/gpjax/kernels/non_euclidean/__init__.py
index d364bc71b..4dabf6a06 100644
--- a/gpjax/kernels/non_euclidean/__init__.py
+++ b/gpjax/kernels/non_euclidean/__init__.py
@@ -14,5 +14,6 @@
 # ==============================================================================
 
 from gpjax.kernels.non_euclidean.graph import GraphKernel
+from gpjax.kernels.non_euclidean.graph_edge import GraphEdgeKernel
 
-__all__ = ["GraphKernel"]
+__all__ = ["GraphKernel", "GraphEdgeKernel"]
diff --git a/gpjax/kernels/non_euclidean/graph_edge.py b/gpjax/kernels/non_euclidean/graph_edge.py
new file mode 100644
index 000000000..69e956343
--- /dev/null
+++ b/gpjax/kernels/non_euclidean/graph_edge.py
@@ -0,0 +1,88 @@
+import beartype.typing as tp
+import jax.numpy as jnp
+from jaxtyping import (
+    Int,
+    Num,
+)
+
+from gpjax.kernels.base import AbstractKernel
+from gpjax.kernels.computations import (
+    AbstractKernelComputation,
+    GraphEdgeKernelComputation,
+)
+from gpjax.typing import (
+    Array,
+)
+
+
+class GraphEdgeKernel(AbstractKernel):
+    r"""A kernel defined on the edge set of a graph.
+
+    The kernel is an implementation of Yu & Chu, "Gaussian Process Models for Link
+    Analysis and Transfer Learning", NeurIPS 2007:
+    https://papers.nips.cc/paper_files/paper/2007/hash/d045c59a90d7587d8d671b5f5aec4e7c-Abstract.html
+
+    Directed graphs:   K((i, j), (i', j')) = ⟨x_i ⊗ x_j, x_i' ⊗ x_j'⟩
+    Undirected graphs: K((i, j), (i', j')) = ⟨x_i ⊗ x_j, x_i' ⊗ x_j'⟩ + ⟨x_i ⊗ x_j, x_j' ⊗ x_i'⟩
+    Bipartite graphs:  K((i, j), (i', j')) = ⟨x_i ⊗ z_j, x_i' ⊗ z_j'⟩
+
+    Here x_i denotes the feature vector of node i and, for bipartite graphs, z_j the
+    feature vector of a node in the second node set. In this implementation the inner
+    products are replaced by evaluations of ``base_kernel`` on the node features.
+    """
+
+    name: str = "Graph Edge"
+
+    def __init__(
+        self,
+        base_kernel: AbstractKernel,
+        feature_mat: Num[Array, "N M"],
+        directed: bool = False,
+        active_dims: tp.Union[list[int], slice, None] = None,
+        n_dims: tp.Union[int, None] = None,
+        compute_engine: AbstractKernelComputation = GraphEdgeKernelComputation(),
+    ):
+        """Initialises the kernel.
+
+        Args:
+            base_kernel: the kernel applied to the node features.
+            feature_mat: the node feature matrix of size (number of nodes, number of features).
+            directed: whether the graph edges are directed.
+            active_dims: the indices of the input dimensions that the kernel operates on.
+            n_dims: the number of input dimensions.
+            compute_engine: the computation engine that the kernel uses to compute the
+                covariance matrix.
+        """
+        self.base_kernel = base_kernel
+        self.dense_feature_mat = feature_mat
+        self.directed = directed
+
+        super().__init__(active_dims, n_dims, compute_engine)
+
+    def __call__(  # TODO: not consistent with the general kernel interface
+        self,
+        X: Int[Array, "N 2"],
+        y: Int[Array, "M 2"],
+        *,
+        S=None,
+        **kwargs,
+    ):
+        r"""Evaluate the kernel on two batches of edges.
+
+        Args:
+            X: sender/receiver node-index pairs for each edge in the first batch, shape [N, 2].
+            y: sender/receiver node-index pairs for each edge in the second batch, shape [M, 2].
+
+        Returns:
+            The kernel matrix between the two batches of edges, shape [N, M].
+        """
+        sender, receiver = X[:, 0], X[:, 1]
+        sender_test, receiver_test = y[:, 0], y[:, 1]
+
+        # Gram matrix of the base kernel over all node features.
+        cov = self.base_kernel.gram(self.dense_feature_mat).to_dense()
+
+        # K((i, j), (i', j')) = k(x_i, x_i') * k(x_j, x_j')
+        cov_edges = jnp.take(
+            jnp.take(cov, sender, axis=0), sender_test, axis=1
+        ) * jnp.take(jnp.take(cov, receiver, axis=0), receiver_test, axis=1)
+
+        if not self.directed:
+            # Symmetrise over edge orientation: add k(x_i, x_j') * k(x_j, x_i').
+            cov_edges += jnp.take(
+                jnp.take(cov, sender, axis=0), receiver_test, axis=1
+            ) * jnp.take(jnp.take(cov, receiver, axis=0), sender_test, axis=1)
+
+        return cov_edges.squeeze()
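The implementation above multiplies entries of the node-level Gram matrix rather than forming tensor products explicitly. This rests on the identity ⟨x_i ⊗ x_j, x_i' ⊗ x_j'⟩ = ⟨x_i, x_i'⟩⟨x_j, x_j'⟩, with the inner products replaced by `base_kernel` evaluations in the code. A short numerical check of the identity (the vectors below are arbitrary, illustrative values):

```python
import jax.numpy as jnp

# Node features for two edges (i, j) and (i', j'); values are arbitrary.
xi, xj = jnp.array([1.0, 2.0]), jnp.array([0.5, 1.0])
xip, xjp = jnp.array([3.0, 0.0]), jnp.array([2.0, 1.0])

# Left-hand side: inner product of the tensor-product (outer-product) edge features.
lhs = jnp.sum(jnp.outer(xi, xj) * jnp.outer(xip, xjp))

# Right-hand side: product of the node-level inner products.
rhs = jnp.dot(xi, xip) * jnp.dot(xj, xjp)

print(lhs, rhs)  # both equal 6.0, so edge covariances factorise over nodes
```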
diff --git a/mkdocs.yml b/mkdocs.yml
index 6629b1d1a..e30eef1f4 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -24,6 +24,7 @@
     - Barycentres: _examples/barycentres.md
     - Deep kernel learning: _examples/deep_kernels.md
     - Graph kernels: _examples/graph_kernels.md
+    - Graph edge kernels: _examples/graph_edge_kernels.md
     - Sparse GPs: _examples/collapsed_vi.md
     - Stochastic sparse GPs: _examples/uncollapsed_vi.md
     - Multi-output GPs for Ocean Modelling: _examples/oceanmodelling.md
diff --git a/tests/test_kernels/test_non_euclidean.py b/tests/test_kernels/test_non_euclidean.py
index a4f12e609..019d1d8ab 100644
--- a/tests/test_kernels/test_non_euclidean.py
+++ b/tests/test_kernels/test_non_euclidean.py
@@ -13,8 +13,13 @@
 from jax import config
 import jax.numpy as jnp
 import networkx as nx
+import numpy as np
 
-from gpjax.kernels.non_euclidean import GraphKernel
+from gpjax.kernels.non_euclidean import (
+    GraphEdgeKernel,
+    GraphKernel,
+)
+from gpjax.kernels.stationary import RBF
 from gpjax.linalg.operators import Identity
 
 # # Enable Float64 for more stable matrix inversions.
@@ -48,3 +53,29 @@
     Kxx += Identity(Kxx.shape[0]) * 1e-6
     eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense())
     assert all(eigen_values > 0)
+
+
+def test_graph_edge_kernel():
+    # Create a random graph, G, with integer node indices.
+    n_vertices = 20
+    n_edges = 40
+    G = nx.gnm_random_graph(n_vertices, n_edges, seed=123)
+    edge_indices = jnp.array(G.edges).astype(jnp.int64)
+
+    # Create the graph edge kernel from random node features.
+    kern = GraphEdgeKernel(
+        base_kernel=RBF(),
+        feature_mat=jnp.array(
+            np.random.uniform(low=0.5, high=13.3, size=(n_vertices, 5))
+        ),
+    )
+    assert isinstance(kern, GraphEdgeKernel)
+
+    # Compute gram matrix
+    Kxx = kern.gram(edge_indices)
+    assert Kxx.shape == (n_edges, n_edges)
+
+    # Check positive definiteness
+    Kxx += Identity(Kxx.shape[0]) * 1e-6
+    eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense())
+    assert all(eigen_values > 0)
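One further check that could be added here (a sketch, not part of the diff): the undirected kernel should be invariant to flipping the orientation of the edges in either argument, which exercises the symmetrisation branch of `GraphEdgeKernel.__call__`.

```python
import jax.numpy as jnp
import numpy as np

from gpjax.kernels.non_euclidean import GraphEdgeKernel
from gpjax.kernels.stationary import RBF


def test_graph_edge_kernel_orientation_invariance():
    # Hypothetical additional test: six nodes with three random features each.
    features = jnp.array(np.random.default_rng(0).uniform(size=(6, 3)))
    kern = GraphEdgeKernel(base_kernel=RBF(), feature_mat=features, directed=False)

    edges = jnp.array([[0, 1], [2, 3], [4, 5]])
    flipped = edges[:, ::-1]  # same edges with sender/receiver swapped

    # For an undirected graph, edge orientation must not change the covariance.
    assert jnp.allclose(kern(edges, edges), kern(edges, flipped))
    assert jnp.allclose(kern(edges, edges), kern(flipped, flipped))
```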