diff --git a/openhgnn/config.ini b/openhgnn/config.ini index 8c4ec3a2..2b9a5262 100644 --- a/openhgnn/config.ini +++ b/openhgnn/config.ini @@ -1559,3 +1559,10 @@ mini_batch_flag = True emb_dim=20 hid_dim=64 batch_size=128 + + +[HGSketch] +K = 2 +R = 3 +D = 128 +seed = 0 diff --git a/openhgnn/config.py b/openhgnn/config.py index 4d80af3f..4b70b20f 100644 --- a/openhgnn/config.py +++ b/openhgnn/config.py @@ -41,6 +41,13 @@ def __init__(self, file_path, model, dataset, task, gpu): self.patience = conf.getint("General", "patience") self.mini_batch_flag = conf.getboolean("General", "mini_batch_flag") ############## add config.py ################# + elif self.model_name == 'HGSketch': + self.K = conf.getint("HGSketch", "K") + self.R = conf.getint("HGSketch", "R") + self.D = conf.getint("HGSketch", "D") + self.seed = conf.getint("HGSketch", "seed") + self.max_epoch = 1 # non-parametric, no iterative training + elif self.model_name == 'MHGCN': self.lr = conf.getfloat("MHGCN", "lr") self.weight_decay = conf.getfloat("MHGCN", "weight_decay") diff --git a/openhgnn/experiment.py b/openhgnn/experiment.py index 1b920c85..fc95b43f 100644 --- a/openhgnn/experiment.py +++ b/openhgnn/experiment.py @@ -80,6 +80,7 @@ class Experiment(object): 'Ingram': 'Ingram_trainer', 'DisenKGAT': 'DisenKGAT_trainer', ###################### add trainer_flow here。 【model name】:【register name】 + 'HGSketch': 'HGSketch_trainer', 'BPHGNN':'BPHGNN_trainer', 'HGPrompt':'HGPrompt_trainer', 'HGMAE':'HGMAE_trainer', diff --git a/openhgnn/models/HGSketch.py b/openhgnn/models/HGSketch.py new file mode 100644 index 00000000..ee207338 --- /dev/null +++ b/openhgnn/models/HGSketch.py @@ -0,0 +1,341 @@ +""" +HGSketch Model +============== +HGSketch maps heterogeneous graphs to low-dimensional Hamming space by extracting +simplicial complexes to capture higher-order structures and using Locality-Sensitive +Hashing (LSH) for ultra-fast dimensionality reduction. + +Steps: +1. 
@register_model('HGSketch')
class HGSketch(BaseModel):
    r"""
    Non-parametric sketching model for heterogeneous graph-level representation.

    The graph is viewed as a simplicial complex in which a k-simplex is a
    (k+1)-clique of the undirected, type-agnostic view of the graph. For every
    dimension ``k = 0..K`` the simplices receive multi-hot node-type features
    which are propagated with an operator derived from the Hodge Laplacian and
    binarized by ``R`` rounds of signed random projection. The per-dimension
    codes are flattened and concatenated into a single vector.

    Parameters
    ----------
    K : int
        Maximum simplex dimension.
    R : int
        Number of LSH iterations.
    D : int
        Hash dimension (output dimension per iteration).
    num_node_types : int
        Number of node types in the heterogeneous graph.
    seed : int
        Random seed for reproducibility.
    """

    @classmethod
    def build_model_from_args(cls, args, hg):
        """
        Build the model from ``args`` (reads ``K``, ``R``, ``D``, ``seed``) and a
        reference graph ``hg`` whose ``ntypes`` fixes the number of node types.
        """
        return cls(
            K=args.K,
            R=args.R,
            D=args.D,
            num_node_types=len(hg.ntypes),
            seed=args.seed,
        )

    def __init__(self, K=2, R=3, D=128, num_node_types=1, seed=0):
        super(HGSketch, self).__init__()
        self.K = K
        self.R = R
        self.D = D
        self.num_node_types = num_node_types
        self.seed = seed
        # Dummy parameter so PyTorch recognizes this as a module
        self._dummy = torch.nn.Parameter(torch.empty(0), requires_grad=False)

    def forward(self, hg):
        """
        Generate the graph-level binary hash code for a single heterogeneous graph.

        Parameters
        ----------
        hg : dgl.DGLHeteroGraph
            A heterogeneous graph.

        Returns
        -------
        x_g : np.ndarray
            Graph-level binary feature vector (see :meth:`compute_sketch`).
        """
        return self.compute_sketch(hg)

    @torch.no_grad()
    def compute_sketch(self, hg):
        """
        Run the HGSketch pipeline on a single graph.

        Parameters
        ----------
        hg : dgl.DGLHeteroGraph
            A heterogeneous graph.

        Returns
        -------
        x_g : np.ndarray
            1-D vector obtained by flattening and concatenating, for every
            dimension ``k`` that has simplices, the ``(num_k_simplices, D)``
            sign matrix. Its length therefore depends on the graph. An empty
            ``float32`` vector is returned when there is nothing to encode.
        """
        # A fresh generator per call: every graph is hashed with the same
        # sequence of projection matrices, which keeps codes comparable.
        rng = np.random.RandomState(self.seed)

        # Undirected, homogeneous view of the graph with global node ids.
        nx_g = self._hg_to_nx(hg)
        if nx_g.number_of_nodes() == 0:
            return np.zeros(0, dtype=np.float32)

        node_type_map = self._build_node_type_map(hg)
        simplices_by_dim = self._extract_simplices(nx_g, self.K)

        per_dim_codes = []
        for k in range(self.K + 1):
            simplices_k = simplices_by_dim.get(k, [])
            if len(simplices_k) == 0:
                continue

            L_k = self._build_hodge_laplacian(simplices_by_dim, k)
            H = self._init_hetero_features(simplices_k, node_type_map)

            # Local information amplification M^(k) = L_k ⊙ L_k (Hadamard product)
            M_k = L_k.multiply(L_k)
            # Global structure enhancement N^(k) = (M^(k))^2
            N_k = M_k.dot(M_k)

            for _ in range(self.R):
                # UPDATE: feature propagation with N^(k)
                H_temp = self._update(H, N_k)
                # TRANSFORM: random projection, then sign binarization
                W = rng.randn(H_temp.shape[1], self.D)
                H = np.sign(H_temp @ W)
                # np.sign yields 0 for exact zeros; map those to +1 so the
                # code stays strictly in {-1, +1}.
                H[H == 0] = 1.0

            code_k = H.flatten()
            if code_k.size > 0:
                per_dim_codes.append(code_k)

        # np.concatenate raises on an empty list; mirror the no-node case instead.
        if not per_dim_codes:
            return np.zeros(0, dtype=np.float32)
        return np.concatenate(per_dim_codes)

    def linearize(self, x_g):
        """
        Linearize binary features for linear classifiers.

        Maps a binary vector of length L to a vector of length 2L made of the
        ``+1`` indicators followed by the ``-1`` indicators, scaled by
        ``1 / sqrt(L)``. The inner product of two linearized codes of the same
        length is then the fraction of positions on which they agree.

        Parameters
        ----------
        x_g : np.ndarray
            Binary feature vector with values in {-1, 1}.

        Returns
        -------
        x_lin : np.ndarray
            Linearized feature vector of length 2L (empty when L == 0).
        """
        # Accept any 1-D sequence, not only ndarrays.
        x_g = np.asarray(x_g)
        L = len(x_g)
        if L == 0:
            return np.zeros(0, dtype=np.float32)
        pos = (x_g == 1).astype(np.float64)
        neg = (x_g == -1).astype(np.float64)
        return np.concatenate([pos, neg]) / np.sqrt(L)

    # ==================== Helper Methods ====================

    def _hg_to_nx(self, hg):
        """
        Convert a DGL heterogeneous graph to an undirected NetworkX graph.

        Nodes are numbered globally: the types in ``hg.ntypes`` are laid out one
        after another (same layout as :meth:`_get_node_offset`). Self-loops are
        dropped and parallel/reverse edges collapse into one undirected edge.
        """
        # Compute every type offset once instead of once per edge type.
        offsets = {}
        total = 0
        for ntype in hg.ntypes:
            offsets[ntype] = total
            total += hg.num_nodes(ntype)

        nx_g = nx.Graph()
        nx_g.add_nodes_from(range(total))

        for etype in hg.canonical_etypes:
            src_type, _, dst_type = etype
            src, dst = hg.edges(etype=etype)
            src_offset = offsets[src_type]
            dst_offset = offsets[dst_type]
            # .cpu() makes this work for graphs living on an accelerator;
            # .tolist() yields plain Python ints as node keys.
            for s, d in zip(src.cpu().tolist(), dst.cpu().tolist()):
                u = src_offset + s
                v = dst_offset + d
                if u != v:
                    nx_g.add_edge(u, v)
        return nx_g

    def _get_node_offset(self, hg, ntype):
        """
        Return the global node-id offset of ``ntype``.

        The offset is the total number of nodes of the types preceding
        ``ntype`` in ``hg.ntypes``. If ``ntype`` is not present, the total node
        count is returned.
        """
        offset = 0
        for nt in hg.ntypes:
            if nt == ntype:
                return offset
            offset += hg.num_nodes(nt)
        return offset

    def _build_node_type_map(self, hg):
        """Build a mapping from global node id to the index of its type in ``hg.ntypes``."""
        node_type_map = {}
        offset = 0
        for type_idx, ntype in enumerate(hg.ntypes):
            num = hg.num_nodes(ntype)
            node_type_map.update((offset + i, type_idx) for i in range(num))
            offset += num
        return node_type_map

    def _extract_simplices(self, nx_g, K):
        """
        Extract k-simplices (k=0..K) from the graph.
        A k-simplex is a (k+1)-clique.

        Returns
        -------
        simplices_by_dim : dict
            {k: list of tuples}, each tuple is a sorted k-simplex.
        """
        simplices_by_dim = {k: [] for k in range(K + 1)}

        # 0-simplices are just nodes
        simplices_by_dim[0].extend((node,) for node in nx_g.nodes())

        if K >= 1:
            # enumerate_all_cliques yields cliques by increasing size, so the
            # generator is consumed lazily and abandoned at the first clique
            # that is too large. Materializing it would enumerate every clique
            # of every size, which is exponential on dense graphs.
            for clique in nx.enumerate_all_cliques(nx_g):
                dim = len(clique) - 1  # k-simplex has k+1 nodes
                if dim > K:
                    break
                if dim >= 1:
                    simplices_by_dim[dim].append(tuple(sorted(clique)))

        return simplices_by_dim

    def _build_boundary_matrix(self, simplices_k, simplices_k_minus_1):
        """
        Build boundary matrix B_k mapping k-simplices to (k-1)-simplices.

        B_k has shape (num_(k-1)-simplices, num_k-simplices). The face obtained
        by deleting the i-th vertex of a simplex gets sign ``(-1) ** i``. Faces
        that are not listed in ``simplices_k_minus_1`` are ignored.
        """
        shape = (len(simplices_k_minus_1), len(simplices_k))
        if len(simplices_k) == 0 or len(simplices_k_minus_1) == 0:
            return sp.csr_matrix(shape)

        # Index lookup for (k-1)-simplices
        face_to_idx = {s: i for i, s in enumerate(simplices_k_minus_1)}

        rows, cols, vals = [], [], []
        for j, simplex in enumerate(simplices_k):
            # Each k-simplex has (k+1) faces of dimension (k-1)
            for i_face in range(len(simplex)):
                face = tuple(simplex[:i_face] + simplex[i_face + 1:])
                row = face_to_idx.get(face)
                if row is not None:
                    rows.append(row)
                    cols.append(j)
                    vals.append((-1) ** i_face)

        return sp.csr_matrix((vals, (rows, cols)), shape=shape)

    def _build_hodge_laplacian(self, simplices_by_dim, k):
        """
        Build the Hodge Laplacian L_k.

        L_0 = B_1 @ B_1^T
        L_k = B_k^T @ B_k + B_{k+1} @ B_{k+1}^T   (for 1 <= k < K)
        L_K = B_K^T @ B_K

        Returns
        -------
        L_k : scipy.sparse.csr_matrix
            Square float64 matrix with one row per k-simplex; a ``(0, 0)``
            matrix when there are no k-simplices.
        """
        simplices_k = simplices_by_dim.get(k, [])
        n = len(simplices_k)
        if n == 0:
            return sp.csr_matrix((0, 0))

        L_k = sp.csr_matrix((n, n))

        # Lower part B_k^T @ B_k: not defined for vertices.
        if k != 0:
            B_k = self._build_boundary_matrix(simplices_k, simplices_by_dim.get(k - 1, []))
            L_k = L_k + B_k.T.dot(B_k)

        # Upper part B_{k+1} @ B_{k+1}^T: dropped at the top dimension K
        # (the vertex case k == 0 always keeps it).
        if k == 0 or k != self.K:
            simplices_up = simplices_by_dim.get(k + 1, [])
            if len(simplices_up) > 0:
                B_up = self._build_boundary_matrix(simplices_up, simplices_k)
                L_k = L_k + B_up.dot(B_up.T)

        return L_k.tocsr().astype(np.float64)

    def _init_hetero_features(self, simplices_k, node_type_map):
        """
        Initialize features for k-simplices from the node types they contain.

        Row ``i`` is a multi-hot vector: entry ``t`` is 1 when at least one node
        of simplex ``i`` has type index ``t`` (a union, not a count). Nodes
        missing from ``node_type_map`` are treated as type 0.

        Returns
        -------
        H_in : np.ndarray
            Shape (num_simplices, num_node_types).
        """
        H_in = np.zeros((len(simplices_k), self.num_node_types), dtype=np.float64)
        for i, simplex in enumerate(simplices_k):
            for node in simplex:
                H_in[i, node_type_map.get(node, 0)] = 1.0
        return H_in

    def _update(self, H_in, N_k):
        """
        UPDATE step: propagate features using global operator N^(k).

        H_temp = N^(k) @ H_in + H_in (with residual connection)
        """
        if sp.issparse(N_k):
            return N_k.dot(H_in) + H_in
        return N_k @ H_in + H_in
@register_flow('HGSketch_trainer')
class HGSketchTrainer(BaseFlow):
    """
    Trainer flow for HGSketch graph classification.

    HGSketch has no trainable weights, so "training" means: sketch every graph,
    linearize the sketches, fit a logistic-regression classifier on the train
    split and report metrics on the test split.
    """

    def __init__(self, args):
        # HGSketch is non-parametric, skip standard BaseFlow graph loading
        # (BaseFlow.__init__ is intentionally not called).
        self.args = args
        self.logger = args.logger
        self.model_name = args.model_name
        self.device = args.device

        # Build task to get dataset
        from ..tasks import build_task
        self.task = build_task(args)

    def train(self):
        """
        Main training pipeline:
        1. Build model from the first graph in dataset
        2. Compute HGSketch embeddings for all graphs
        3. Linearize embeddings
        4. Train LogisticRegression and evaluate

        Returns
        -------
        dict
            ``train_acc``, ``test_acc``, ``test_f1_macro`` and ``test_f1_micro``;
            or ``{'accuracy': 0.0}`` when the dataset contains no graphs.
        """
        dataset = self.task.dataset
        graphs, labels = self._load_graph_dataset(dataset)

        if len(graphs) == 0:
            self.logger.train_info("No graphs found in dataset.")
            return {'accuracy': 0.0}

        # Build model using the first graph as reference
        model_cls = build_model(self.args.model_name)
        model = model_cls.build_model_from_args(self.args, graphs[0])

        self.logger.train_info(f"Computing HGSketch embeddings for {len(graphs)} graphs...")
        self.logger.train_info(f"Parameters: K={model.K}, R={model.R}, D={model.D}")

        # Sketch and linearize every graph
        embeddings = [
            model.linearize(model.compute_sketch(g))
            for g in tqdm(graphs, desc="HGSketch")
        ]

        # Pad embeddings to the same length (graphs may have different sizes)
        max_len = max(len(e) for e in embeddings)
        X = np.zeros((len(embeddings), max_len), dtype=np.float32)
        for i, e in enumerate(embeddings):
            X[i, :len(e)] = e

        y = np.array(labels)

        # Split into train/test
        train_mask, test_mask = self._get_split(dataset, len(graphs))

        X_train, y_train = X[train_mask], y[train_mask]
        X_test, y_test = X[test_mask], y[test_mask]

        self.logger.train_info(f"Train size: {len(X_train)}, Test size: {len(X_test)}")
        self.logger.train_info(f"Feature dimension: {max_len}")

        # Train linear classifier. `multi_class` is not passed: 'auto' was the
        # default and the parameter is deprecated in recent scikit-learn.
        clf = LogisticRegression(max_iter=1000, solver='lbfgs', C=1.0)
        clf.fit(X_train, y_train)

        # Evaluate
        y_pred_train = clf.predict(X_train)
        y_pred_test = clf.predict(X_test)

        train_acc = accuracy_score(y_train, y_pred_train)
        test_acc = accuracy_score(y_test, y_pred_test)
        test_f1_macro = f1_score(y_test, y_pred_test, average='macro')
        test_f1_micro = f1_score(y_test, y_pred_test, average='micro')

        self.logger.train_info(f"Train Accuracy: {train_acc:.4f}")
        self.logger.train_info(f"Test Accuracy: {test_acc:.4f}")
        self.logger.train_info(f"Test F1-Macro: {test_f1_macro:.4f}")
        self.logger.train_info(f"Test F1-Micro: {test_f1_micro:.4f}")

        return {
            'train_acc': train_acc,
            'test_acc': test_acc,
            'test_f1_macro': test_f1_macro,
            'test_f1_micro': test_f1_micro,
        }

    def _load_graph_dataset(self, dataset):
        """
        Load graphs and labels from the dataset.

        Two layouts are understood: a dataset exposing ``graphs`` and ``labels``
        attributes, or an indexable dataset whose items are ``(graph, label)``
        tuples (items of any other form are skipped).

        Returns
        -------
        graphs : list of DGLHeteroGraph
        labels : list of int
        """
        graphs = []
        labels = []

        if hasattr(dataset, 'graphs') and hasattr(dataset, 'labels'):
            graphs = dataset.graphs
            labels = dataset.labels
            # Tensors on an accelerator or tracking gradients cannot be
            # converted with .numpy() directly.
            if hasattr(labels, 'detach'):
                labels = labels.detach().cpu()
            if hasattr(labels, 'numpy'):
                labels = labels.numpy().tolist()
            elif isinstance(labels, np.ndarray):
                labels = labels.tolist()
        elif hasattr(dataset, '__len__') and hasattr(dataset, '__getitem__'):
            for i in range(len(dataset)):
                item = dataset[i]
                if isinstance(item, tuple) and len(item) == 2:
                    g, label = item
                    graphs.append(g)
                    labels.append(label.item() if hasattr(label, 'item') else int(label))

        return graphs, labels

    def _get_split(self, dataset, n):
        """
        Get train/test split masks over ``n`` graphs.

        Uses ``dataset.train_mask``/``test_mask`` when present, otherwise
        ``dataset.train_idx``/``test_idx``, otherwise a seeded random 80/20
        split (seed taken from ``self.args.seed``).

        Returns
        -------
        train_mask : np.ndarray of bool
        test_mask : np.ndarray of bool
        """
        if hasattr(dataset, 'train_mask') and hasattr(dataset, 'test_mask'):
            train_mask = np.array(dataset.train_mask, dtype=bool)
            test_mask = np.array(dataset.test_mask, dtype=bool)
            return train_mask, test_mask

        train_mask = np.zeros(n, dtype=bool)
        test_mask = np.zeros(n, dtype=bool)
        if hasattr(dataset, 'train_idx') and hasattr(dataset, 'test_idx'):
            train_mask[dataset.train_idx] = True
            test_mask[dataset.test_idx] = True
        else:
            # Default 80/20 split
            rng = np.random.RandomState(self.args.seed)
            indices = rng.permutation(n)
            split = int(0.8 * n)
            train_mask[indices[:split]] = True
            test_mask[indices[split:]] = True

        return train_mask, test_mask
+""" +import os +import sys +import numpy as np +import pytest + +# Ensure project root is on path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +import dgl +import torch + + +# ============================================================ +# 1. Registration Tests +# ============================================================ + +def test_model_registered(): + """HGSketch should be in MODEL_REGISTRY after import.""" + from openhgnn.models import MODEL_REGISTRY + from openhgnn.models.HGSketch import HGSketch # trigger registration + assert 'HGSketch' in MODEL_REGISTRY, "HGSketch not found in MODEL_REGISTRY" + + +def test_flow_registered(): + """HGSketch_trainer should be in FLOW_REGISTRY after import.""" + from openhgnn.trainerflow import FLOW_REGISTRY + from openhgnn.trainerflow.HGSketch_trainer import HGSketchTrainer # trigger registration + assert 'HGSketch_trainer' in FLOW_REGISTRY, "HGSketch_trainer not found in FLOW_REGISTRY" + + +def test_experiment_binding(): + """HGSketch should be mapped in Experiment.specific_trainerflow.""" + from openhgnn.experiment import Experiment + assert 'HGSketch' in Experiment.specific_trainerflow + assert Experiment.specific_trainerflow['HGSketch'] == 'HGSketch_trainer' + + +def test_supported_models_entry(): + """HGSketch should be in SUPPORTED_MODELS dict.""" + from openhgnn.models import SUPPORTED_MODELS + assert 'HGSketch' in SUPPORTED_MODELS + assert SUPPORTED_MODELS['HGSketch'] == 'openhgnn.models.HGSketch' + + +def test_supported_flows_entry(): + """HGSketch_trainer should be in SUPPORTED_FLOWS dict.""" + from openhgnn.trainerflow import SUPPORTED_FLOWS + assert 'HGSketch_trainer' in SUPPORTED_FLOWS + + +# ============================================================ +# 2. 
Config Tests +# ============================================================ + +def test_config_reads_hgsketch_params(): + """Config should correctly read HGSketch params from config.ini.""" + from openhgnn.config import Config + conf_path = os.path.join(os.path.dirname(__file__), '..', 'openhgnn', 'config.ini') + config = Config(file_path=conf_path, model='HGSketch', dataset='test_ds', task='graph_classification', gpu=-1) + assert config.K == 2 + assert config.R == 3 + assert config.D == 128 + assert config.seed == 0 + assert config.max_epoch == 1 # non-parametric + + +# ============================================================ +# 3. Helper: create a small heterogeneous graph +# ============================================================ + +def _make_small_hg(): + """Create a small heterogeneous graph with 2 node types and 2 edge types.""" + # 4 'user' nodes, 3 'item' nodes + # edges: user->item (likes), item->user (liked_by) + data = { + ('user', 'likes', 'item'): (torch.tensor([0, 1, 2, 0, 3]), torch.tensor([0, 1, 2, 1, 0])), + ('item', 'liked_by', 'user'): (torch.tensor([0, 1, 2, 1, 0]), torch.tensor([0, 1, 2, 0, 3])), + } + hg = dgl.heterograph(data) + return hg + + +def _make_triangle_hg(): + """Create a graph that contains a triangle (3-clique) to test 2-simplices.""" + # 3 nodes of type 'a', fully connected -> forms a triangle + data = { + ('a', 'e1', 'a'): (torch.tensor([0, 1, 0, 2, 1, 2]), torch.tensor([1, 0, 2, 0, 2, 1])), + } + hg = dgl.heterograph(data) + return hg + + +# ============================================================ +# 4. 
Core Algorithm Tests +# ============================================================ + +def test_hg_to_nx_conversion(): + """Test heterogeneous graph to NetworkX conversion.""" + from openhgnn.models.HGSketch import HGSketch + model = HGSketch(K=1, R=1, D=8, num_node_types=2, seed=42) + hg = _make_small_hg() + nx_g = model._hg_to_nx(hg) + # Total nodes: 4 users + 3 items = 7 + assert nx_g.number_of_nodes() == 7 + # Should have edges (undirected, no self-loops) + assert nx_g.number_of_edges() > 0 + + +def test_node_type_map(): + """Test node type mapping.""" + from openhgnn.models.HGSketch import HGSketch + model = HGSketch(K=1, R=1, D=8, num_node_types=2, seed=42) + hg = _make_small_hg() + type_map = model._build_node_type_map(hg) + # DGL sorts ntypes alphabetically: ['item', 'user'] + # First 3 nodes (item) should be type 0, next 4 (user) should be type 1 + assert type_map[0] == 0 + assert type_map[2] == 0 + assert type_map[3] == 1 + assert type_map[6] == 1 + + +def test_simplex_extraction(): + """Test simplex extraction from a triangle graph.""" + from openhgnn.models.HGSketch import HGSketch + model = HGSketch(K=2, R=1, D=8, num_node_types=1, seed=42) + hg = _make_triangle_hg() + nx_g = model._hg_to_nx(hg) + simplices = model._extract_simplices(nx_g, K=2) + + # 0-simplices: 3 nodes + assert len(simplices[0]) == 3 + # 1-simplices: 3 edges + assert len(simplices[1]) == 3 + # 2-simplices: 1 triangle + assert len(simplices[2]) == 1 + + +def test_compute_sketch_output_shape(): + """Test that compute_sketch returns a non-empty array with values in {-1, 1}.""" + from openhgnn.models.HGSketch import HGSketch + model = HGSketch(K=1, R=2, D=16, num_node_types=2, seed=42) + hg = _make_small_hg() + x_g = model.compute_sketch(hg) + + assert isinstance(x_g, np.ndarray) + assert x_g.size > 0 + # All values should be -1 or 1 + unique_vals = set(np.unique(x_g)) + assert unique_vals.issubset({-1.0, 1.0}), f"Unexpected values: {unique_vals}" + + +def 
test_compute_sketch_with_triangle(): + """Test compute_sketch on a graph with higher-order simplices (K=2).""" + from openhgnn.models.HGSketch import HGSketch + model = HGSketch(K=2, R=2, D=16, num_node_types=1, seed=42) + hg = _make_triangle_hg() + x_g = model.compute_sketch(hg) + + assert isinstance(x_g, np.ndarray) + assert x_g.size > 0 + unique_vals = set(np.unique(x_g)) + assert unique_vals.issubset({-1.0, 1.0}) + + +def test_compute_sketch_deterministic(): + """Same graph + same seed should produce identical sketches.""" + from openhgnn.models.HGSketch import HGSketch + model = HGSketch(K=1, R=2, D=16, num_node_types=2, seed=123) + hg = _make_small_hg() + x1 = model.compute_sketch(hg) + x2 = model.compute_sketch(hg) + np.testing.assert_array_equal(x1, x2) + + +# ============================================================ +# 5. Linearization Tests +# ============================================================ + +def test_linearize_output_length(): + """Linearized vector should have length 2L.""" + from openhgnn.models.HGSketch import HGSketch + model = HGSketch(K=1, R=1, D=8, num_node_types=2, seed=42) + hg = _make_small_hg() + x_g = model.compute_sketch(hg) + L = len(x_g) + x_lin = model.linearize(x_g) + assert len(x_lin) == 2 * L + + +def test_linearize_values(): + """Linearized values should be 0 or 1/sqrt(L).""" + from openhgnn.models.HGSketch import HGSketch + model = HGSketch(K=1, R=1, D=8, num_node_types=2, seed=42) + hg = _make_small_hg() + x_g = model.compute_sketch(hg) + L = len(x_g) + x_lin = model.linearize(x_g) + + expected_nonzero = 1.0 / np.sqrt(L) + for v in np.unique(x_lin): + assert np.isclose(v, 0.0, atol=1e-8) or np.isclose(v, expected_nonzero, atol=1e-8), \ + f"Unexpected value: {v}, expected 0 or {expected_nonzero}" + + +def test_linearize_kernel_property(): + """ + Verify the Hamming kernel linearization property: + should equal the normalized Hamming agreement. 
+ """ + from openhgnn.models.HGSketch import HGSketch + model = HGSketch(K=1, R=2, D=16, num_node_types=1, seed=42) + + hg1 = _make_triangle_hg() + hg2 = _make_small_hg() + # Use same num_node_types for both + model2 = HGSketch(K=1, R=2, D=16, num_node_types=2, seed=42) + + x1 = model.compute_sketch(hg1) + x2 = model.compute_sketch(hg1) # same graph + + x1_lin = model.linearize(x1) + x2_lin = model.linearize(x2) + + # For identical graphs, inner product should equal 1.0 (perfect agreement) + L = len(x1) + if L > 0: + inner = np.dot(x1_lin, x2_lin) + # Since x1 == x2, all bits agree, so kernel = L/L = 1.0 + np.testing.assert_almost_equal(inner, 1.0, decimal=5) + + +# ============================================================ +# 6. Hodge Laplacian Tests +# ============================================================ + +def test_hodge_laplacian_shape(): + """L_k should be square with size = number of k-simplices.""" + from openhgnn.models.HGSketch import HGSketch + model = HGSketch(K=2, R=1, D=8, num_node_types=1, seed=42) + hg = _make_triangle_hg() + nx_g = model._hg_to_nx(hg) + simplices = model._extract_simplices(nx_g, K=2) + + for k in range(3): + n_k = len(simplices[k]) + L_k = model._build_hodge_laplacian(simplices, k) + assert L_k.shape == (n_k, n_k), f"L_{k} shape mismatch: {L_k.shape} vs ({n_k}, {n_k})" + + +def test_hodge_laplacian_symmetric(): + """Hodge Laplacian should be symmetric.""" + from openhgnn.models.HGSketch import HGSketch + from scipy import sparse as sp + model = HGSketch(K=2, R=1, D=8, num_node_types=1, seed=42) + hg = _make_triangle_hg() + nx_g = model._hg_to_nx(hg) + simplices = model._extract_simplices(nx_g, K=2) + + for k in range(3): + L_k = model._build_hodge_laplacian(simplices, k) + if L_k.shape[0] > 0: + diff = L_k - L_k.T + assert abs(diff).max() < 1e-10, f"L_{k} is not symmetric" + + +# ============================================================ +# 7. 
Build Model from Args Test +# ============================================================ + +def test_build_model_from_args(): + """Test build_model_from_args class method.""" + from openhgnn.models.HGSketch import HGSketch + + class MockArgs: + K = 2 + R = 3 + D = 64 + seed = 42 + + hg = _make_small_hg() + model = HGSketch.build_model_from_args(MockArgs(), hg) + assert model.K == 2 + assert model.R == 3 + assert model.D == 64 + assert model.num_node_types == 2 # user, item + assert model.seed == 42 + + +# ============================================================ +# 8. Edge Cases +# ============================================================ + +def test_empty_graph(): + """Model should handle a graph with no edges gracefully.""" + from openhgnn.models.HGSketch import HGSketch + model = HGSketch(K=1, R=1, D=8, num_node_types=1, seed=42) + # Graph with 3 nodes, no edges + hg = dgl.heterograph({('a', 'e', 'a'): (torch.zeros(0, dtype=torch.int64), torch.zeros(0, dtype=torch.int64))}, + num_nodes_dict={'a': 3}) + x_g = model.compute_sketch(hg) + assert isinstance(x_g, np.ndarray) + # Should still produce output (at least 0-simplex features) + assert x_g.size > 0 + + +def test_single_node_graph(): + """Model should handle a single-node graph.""" + from openhgnn.models.HGSketch import HGSketch + model = HGSketch(K=1, R=1, D=8, num_node_types=1, seed=42) + hg = dgl.heterograph({('a', 'e', 'a'): (torch.zeros(0, dtype=torch.int64), torch.zeros(0, dtype=torch.int64))}, + num_nodes_dict={'a': 1}) + x_g = model.compute_sketch(hg) + assert isinstance(x_g, np.ndarray) + + +if __name__ == '__main__': + pytest.main([__file__, '-v'])