diff --git a/.vscode/settings.json b/.vscode/settings.json index f7ef33f9..9aa399fc 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -27,5 +27,7 @@ "editor.codeActionsOnSave": { "source.organizeImports": "explicit" } - } + }, + "python.testing.pytestArgs": ["packages/backend"], + "python.testing.pytestEnabled": true } diff --git a/packages/backend/embedding_atlas/cli.py b/packages/backend/embedding_atlas/cli.py index 568c2100..0c1dc4bb 100644 --- a/packages/backend/embedding_atlas/cli.py +++ b/packages/backend/embedding_atlas/cli.py @@ -227,6 +227,14 @@ def import_modules(names: list[str]): "neighbors_column", help='Column containing pre-computed nearest neighbors in format: {"ids": [n1, n2, ...], "distances": [d1, d2, ...]}. IDs should be zero-based row indices.', ) +@click.option( + "--pagerank", + "pagerank_column", + default=None, + is_flag=False, + flag_value="__compute__", + help="Compute PageRank scores from the neighbor graph, or specify a column containing pre-computed scores. Automatically computed when --image is specified.", +) @click.option( "--query", default=None, @@ -352,6 +360,7 @@ def main( x_column: str | None, y_column: str | None, neighbors_column: str | None, + pagerank_column: str | None, query: str | None, sample: int | None, umap_n_neighbors: int | None, @@ -475,12 +484,32 @@ def main( labels_df = load_pandas_data(labels) labels_resolved = labels_df.to_dict("records") + # Compute PageRank from neighbor graph when requested or when --image is specified + should_compute_pagerank = (pagerank_column == "__compute__") or ( + image is not None and pagerank_column is None + ) + if ( + should_compute_pagerank + and neighbors_column is not None + and neighbors_column in df.columns + ): + from embedding_atlas.pagerank import compute_pagerank_column + + logger.info("Computing PageRank scores from neighbor graph...") + pagerank_column = find_column_name(df.columns, "pagerank") + df[pagerank_column] = compute_pagerank_column(df, neighbors=neighbors_column) + elif pagerank_column == "__compute__": + logger.warning("Cannot compute PageRank: no neighbor data available.") + pagerank_column = None + props = make_embedding_atlas_props( row_id=id_column, x=x_column, y=y_column, neighbors=neighbors_column, + importance=pagerank_column, text=text, + image=image, point_size=point_size, stop_words=stop_words_resolved, labels=labels_resolved, diff --git a/packages/backend/embedding_atlas/options.py b/packages/backend/embedding_atlas/options.py index 45202d36..954a510a 100644 --- a/packages/backend/embedding_atlas/options.py +++ b/packages/backend/embedding_atlas/options.py @@ -17,6 +17,13 @@ class EmbeddingAtlasOptions(TypedDict, total=False): text: The column name for the textual data. + image: + The column name for image data. + + importance: + The column name for importance scores (e.g., PageRank). Used with ``image`` to select + representative images for cluster labels. Maps to ``importance`` in the frontend API. + neighbors: The column name containing precomputed K-nearest neighbors for each point. 
         Each value in the column should be a dictionary with the format:
@@ -60,6 +67,8 @@ class EmbeddingAtlasOptions(TypedDict, total=False):
     x: str | None
     y: str | None
     text: str | None
+    image: str | None
+    importance: str | None
     neighbors: str | None
     point_size: float | None
@@ -109,6 +118,8 @@ def set_prop(key: str, value):
     if options.get("x") is not None and options.get("y") is not None:
         set_prop("data.projection", {"x": options.get("x"), "y": options.get("y")})
     set_prop("data.text", options.get("text"))
+    set_prop("data.image", options.get("image"))
+    set_prop("data.importance", options.get("importance"))
     set_prop("data.neighbors", options.get("neighbors"))

     # Embedding View
diff --git a/packages/backend/embedding_atlas/pagerank.py b/packages/backend/embedding_atlas/pagerank.py
new file mode 100644
index 00000000..d54c64ee
--- /dev/null
+++ b/packages/backend/embedding_atlas/pagerank.py
@@ -0,0 +1,319 @@
+from collections.abc import Sequence
+
+import numpy as np
+import pandas as pd
+import torch
+
+
+def pagerank(
+    edges: Sequence[tuple[int, int] | tuple[int, int, float]],
+    *,
+    n: int,
+    damping: float = 0.85,
+    max_iterations: int = 100,
+    tolerance: float = 1e-9,
+) -> np.ndarray:
+    """
+    Compute PageRank scores for a graph given as a list of edges, using
+    PyTorch sparse matrix power iteration. The graph can be either unweighted
+    (each edge consists of a source node ID and a target node ID) or weighted
+    (each edge has an additional third element: the edge weight).
+
+    Args:
+        edges: List of tuples representing edges. Can be:
+            - Unweighted: [(source1, target1), (source2, target2), ...]
+            - Weighted: [(source1, target1, weight1), (source2, target2, weight2), ...]
+            Weighted vs. unweighted is auto-detected based on tuple length.
+        n: Number of nodes in the graph. The returned array will have this length.
+        damping: PageRank damping factor (default: 0.85).
+        max_iterations: Maximum number of iterations (default: 100).
+        tolerance: Convergence tolerance (default: 1e-9).
+
+    Returns:
+        np.ndarray of shape (n,) containing PageRank scores.
+        Scores are ordered by node index (scores[i] is the score for node i).
+
+    Example:
+        >>> edges = [(0, 1, 0.5), (0, 2, 1.0), (1, 2, 0.8), (2, 0, 1.0)]
+        >>> scores = pagerank(edges, n=3)
+        >>> scores  # scores[i] is the PageRank score for node i
+        array([0.41..., 0.16..., 0.42...])
+
+        # With KNN arrays:
+        >>> edges = knn_to_edges(knn_indices, knn_distances)
+        >>> scores = pagerank(edges, n=len(knn_indices))
+    """
+    if len(edges) == 0:
+        if n > 0:
+            return np.full(n, 1.0 / n)
+        return np.array([])
+
+    # Parse edges into source, target, weight arrays
+    sources = []
+    targets = []
+    weights = []
+    for edge in edges:
+        sources.append(edge[0])
+        targets.append(edge[1])
+        weights.append(float(edge[2]) if len(edge) == 3 else 1.0)
+
+    # Validate n covers all node IDs in the edge list
+    max_node_id = max(max(sources), max(targets))
+    if n <= max_node_id:
+        raise ValueError(
+            f"n={n} but edges contain node ID {max_node_id} (n must be > max node ID)"
+        )
+
+    # Build sparse transition matrix M where M[j, i] = weight(i -> j) / out_weight(i)
+    # This means column i represents outgoing edges from node i.
+    # We need to normalize each column by its sum.
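+    # For example, with edges [(0, 1, 2.0), (0, 2, 2.0)], node 0's outgoing
+    # weights sum to 4.0, so M[1, 0] = M[2, 0] = 0.5 after normalization.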
+    src = torch.tensor(sources, dtype=torch.long)
+    tgt = torch.tensor(targets, dtype=torch.long)
+    w = torch.tensor(weights, dtype=torch.float64)
+
+    # Compute column sums (out-weight per source node)
+    col_sums = torch.zeros(n, dtype=torch.float64)
+    col_sums.scatter_add_(0, src, w)
+
+    # Normalize weights by column sum to get transition probabilities
+    # Avoid division by zero for dangling nodes (handled separately)
+    safe_col_sums = col_sums[src]
+    safe_col_sums[safe_col_sums == 0] = 1.0
+    normalized_w = w / safe_col_sums
+
+    # Build sparse matrix: M[tgt, src] = normalized_w
+    # This is the column-stochastic transition matrix
+    indices = torch.stack([tgt, src])
+    M = torch.sparse_coo_tensor(indices, normalized_w, size=(n, n), dtype=torch.float64)
+    M = M.coalesce()
+
+    # Identify dangling nodes (no outgoing edges)
+    is_dangling = col_sums == 0
+
+    # Initialize rank vector uniformly
+    r = torch.full((n,), 1.0 / n, dtype=torch.float64)
+
+    teleport = (1.0 - damping) / n
+
+    for _ in range(max_iterations):
+        # Dangling node contribution: their rank is distributed uniformly
+        dangling_sum = r[is_dangling].sum().item()
+
+        # Power iteration step
+        r_new = damping * torch.mv(M, r) + teleport + damping * dangling_sum / n
+
+        # Check convergence (L1 norm)
+        diff = torch.abs(r_new - r).sum().item()
+        r = r_new
+
+        if diff < tolerance:
+            break
+
+    return r.numpy()
+
+
+def knn_to_edges(
+    knn_indices: np.ndarray,
+    knn_distances: np.ndarray,
+    local_connectivity: float = 1.0,
+) -> list[tuple[int, int, float]]:
+    """
+    Convert raw UMAP k-nearest-neighbor (KNN) arrays into a weighted edge
+    list, which can then be passed into pagerank().
+
+    Raw KNN distances are not directly usable as edge weights because higher
+    distance means weaker connection, and distances are not normalized across
+    points with varying local density. This function transforms raw distances
+    into UMAP-style membership strengths in [0, 1], where higher values
+    indicate stronger connections. The transformation is density-adaptive:
+    each point's distances are normalized relative to its local neighborhood
+    via per-point sigma and rho parameters.
+
+    The raw arrays come from Projection.knn_indices and
+    Projection.knn_distances (see projection.py), which store raw distances
+    from umap.umap_.nearest_neighbors(). During UMAP's fit_transform(),
+    these raw distances are internally converted to membership strengths
+    via smooth_knn_dist() and compute_membership_strengths(), but those
+    intermediate results are not exposed. Since Projection only stores the
+    raw distances, this function re-derives the membership weights by calling
+    the same UMAP functions:
+
+    1. smooth_knn_dist() computes per-point sigma (bandwidth) and rho
+       (distance to nearest neighbor) values. rho ensures every point has
+       at least one neighbor with membership strength ~1.0. sigma controls
+       how fast the strength decays for farther neighbors.
+
+    2. compute_membership_strengths() transforms each raw distance into a
+       membership weight via exp(-(distance - rho) / sigma), producing
+       values in [0, 1]. Distances <= rho are clamped to weight 1.0.
+
+    Args:
+        knn_indices: Array of shape (N, k) where knn_indices[i] contains
+                     the 0-indexed row IDs of the k nearest neighbors of
+                     row i (may include i itself).
+        knn_distances: Array of shape (N, k) where knn_distances[i] contains
+                       the raw distances to the k nearest neighbors of row i,
+                       aligned with knn_indices (distances[j] corresponds to
+                       indices[j]).
+ local_connectivity: UMAP local_connectivity parameter (default: 1.0). + The default of 1.0 matches UMAP's own default, so this + does not need to be provided unless local_connectivity + was explicitly set to a non-default value in umap_args + when computing the projection (see projection.py). In + that case, the same value must be passed here to ensure + the membership weights are consistent with the projection. + + Returns: + List of (source, target, weight) tuples, with self-loops excluded. + + Example: + >>> indices = np.array([[1, 2], [0, 2], [0, 1]]) + >>> distances = np.array([[0.1, 0.2], [0.1, 0.3], [0.2, 0.3]]) + >>> edges = knn_to_edges(indices, distances) + """ + from umap.umap_ import compute_membership_strengths, smooth_knn_dist + + n_neighbors = knn_distances.shape[1] + + # Compute sigmas and rhos + sigmas, rhos = smooth_knn_dist( + knn_distances, + k=n_neighbors, + local_connectivity=local_connectivity, + ) + + # Compute membership strengths (edge weights) + result = compute_membership_strengths( + knn_indices.astype(np.int32), + knn_distances.astype(np.float32), + sigmas.astype(np.float32), + rhos.astype(np.float32), + return_dists=False, + ) + rows, cols, vals = result[0], result[1], result[2] + + # Convert to edge list, filtering out self-loops + edges = [(int(r), int(c), float(v)) for r, c, v in zip(rows, cols, vals) if r != c] + + return edges + + +def compute_pagerank_column( + dataframe: pd.DataFrame, + *, + neighbors: str = "__neighbors", + local_connectivity: float = 1.0, + damping: float = 0.85, +): + """ + Compute PageRank scores from a DataFrame that contains a neighbors column. + + The neighbors column contains one dict per row with two parallel arrays: + - 'ids': 0-indexed row IDs of the k nearest neighbors (int[]) + - 'distances': raw distances to those neighbors (float[]) + + The arrays are aligned: ids[j] is the neighbor and distances[j] is its + distance. A row's own ID typically appears in its own ids array (often + at position 0 with distance 0.0), but it is not guaranteed to be first + because other neighbors can also have distance 0.0. For example: + + Row 0: ids=[0, 110431, 61815, ...], distances=[0.0, 0.07, 0.11, ...] + Row 4: ids=[113494, 75640, 4, ...], distances=[0.0, 0.0, 0.0, ...] + + This is the format produced by compute_text_projection, + compute_vector_projection, and compute_image_projection in projection.py. + + Args: + dataframe: pandas DataFrame containing the neighbor data. + neighbors: Column name containing the neighbors dicts. + local_connectivity: UMAP local_connectivity parameter (default: 1.0). + See knn_to_edges() for when this needs to be changed. + damping: PageRank damping factor (default: 0.85). + + Returns: + np.ndarray of shape (len(dataframe),) containing PageRank scores. + """ + neighbors_col = dataframe[neighbors] + knn_indices = np.stack([np.array(row["ids"]) for row in neighbors_col]) + knn_distances = np.stack([np.array(row["distances"]) for row in neighbors_col]) + + edges = knn_to_edges( + knn_indices, knn_distances, local_connectivity=local_connectivity + ) + scores = pagerank(edges, n=len(dataframe), damping=damping) + + return scores + + +if __name__ == "__main__": + import argparse + import time + + import pyarrow.parquet as pq + + parser = argparse.ArgumentParser( + description="""\ + Compute PageRank scores from a parquet file containing KNN neighbor data. 
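+
+    Example invocation (assuming the embedding_atlas package is importable):
+        python -m embedding_atlas.pagerank --in data.parquet --out data_with_pagerank.parquet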
+ + Input parquet file must contain a '__neighbors' column where each row is a + dict with two parallel arrays: + - 'ids': 0-indexed row IDs of the k nearest neighbors (int[]) + e.g. [0, 110431, 61815, ...] or [113494, 75640, 4, ...] + - 'distances': raw distances to those neighbors (float[]) + e.g. [0.0, 0.07, 0.11, ...] or [0.0, 0.0, 0.0, ...] + + The arrays are aligned: ids[j] is the neighbor and distances[j] is its + distance. A row's own ID typically appears in its own ids array (often at + position 0 with distance 0.0), but it is not guaranteed to be first because + other neighbors can also have distance 0.0. + + This is the format produced by projection.py (see compute_text_projection, + compute_vector_projection, compute_image_projection). + + Output parquet file contains all original columns plus a 'pagerank' column + with float scores that sum to 1.0, e.g. 0.000312 (higher = more central + in the KNN graph). + """, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--in", + dest="input_file", + type=str, + required=True, + help="Input parquet file with __neighbors column (see above for format)", + ) + parser.add_argument( + "--out", + dest="output_file", + type=str, + required=True, + help="Output parquet file: same as input with a 'pagerank' column added", + ) + args = parser.parse_args() + + # Load parquet and extract KNN arrays from the __neighbors column + print(f"Loading {args.input_file} ...") + df = pq.read_table(args.input_file).to_pandas() + print(f"Loaded {len(df)} rows") + + # Compute PageRank and add as a column + print("Computing PageRank...") + start_time = time.time() + df["pagerank"] = compute_pagerank_column(df) + print(f"PageRank completed in {time.time() - start_time:.4f} seconds") + + # Summary + scores = df["pagerank"].values + top_indices = np.argsort(scores)[::-1][:10] + print("\nTop 10 nodes by PageRank score:") + for idx in top_indices: + print(f" node {idx:>6d} score {scores[idx]:.10f}") + + print( + f"\nScore statistics: min={scores.min():.10f}, max={scores.max():.10f}, mean={scores.mean():.10f}" + ) + + # Write output + df.to_parquet(args.output_file, index=False) + print(f"\nSaved to {args.output_file} with 'pagerank' column added") diff --git a/packages/backend/tests/test_pagerank.py b/packages/backend/tests/test_pagerank.py new file mode 100644 index 00000000..fc633971 --- /dev/null +++ b/packages/backend/tests/test_pagerank.py @@ -0,0 +1,459 @@ +# Copyright (c) 2025 Apple Inc. Licensed under MIT License. + +"""Comprehensive tests for the PyTorch PageRank implementation. + +Covers edge cases, dtype boundaries, convergence behavior, and API contracts. 
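+
+Run with, for example, `pytest packages/backend/tests/test_pagerank.py` from the
+repository root (pytest for packages/backend is configured in .vscode/settings.json).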
+""" + +import numpy as np +import pandas as pd +import pytest +import torch +from embedding_atlas.pagerank import ( + compute_pagerank_column, + knn_to_edges, + pagerank, +) + + +class TestEdgeCases: + def test_empty_edges(self): + """Empty edge list with n=0 should return an empty array.""" + scores = pagerank([], n=0) + assert isinstance(scores, np.ndarray) + assert len(scores) == 0 + + def test_empty_edges_with_n(self): + """Empty edge list with explicit n should return uniform scores.""" + scores = pagerank([], n=3) + assert len(scores) == 3 + assert np.isclose(scores.sum(), 1.0) + assert np.allclose(scores, 1 / 3) + + def test_single_node_self_loop(self): + """A single node with a self-loop should get all the rank.""" + scores = pagerank([(0, 0, 1.0)], n=1) + assert len(scores) == 1 + assert np.isclose(scores[0], 1.0) + + def test_single_edge(self): + """Two nodes, one directed edge: target gets more rank.""" + scores = pagerank([(0, 1, 1.0)], n=2) + assert len(scores) == 2 + assert np.isclose(scores.sum(), 1.0) + # Node 1 receives a link from node 0; node 0 is a dangling node + # whose rank gets redistributed. Node 1 should have higher rank. + assert scores[1] > scores[0] + + def test_fully_disconnected_graph(self): + """Two nodes with self-loops only: equal rank (no cross-edges).""" + scores = pagerank([(0, 0, 1.0), (1, 1, 1.0)], n=2) + assert len(scores) == 2 + assert np.isclose(scores.sum(), 1.0) + assert np.allclose(scores[0], scores[1]) + + def test_self_loops_do_not_break_convergence(self): + """Graph with self-loops mixed with real edges should converge.""" + edges = [ + (0, 0, 1.0), # self-loop + (0, 1, 1.0), + (1, 2, 1.0), + (2, 0, 1.0), + ] + scores = pagerank(edges, n=3) + assert len(scores) == 3 + assert np.isclose(scores.sum(), 1.0) + assert all(s > 0 for s in scores) + + def test_dangling_node(self): + """A node with no outgoing edges (dangling) should still get rank.""" + # 0 -> 1, 0 -> 2; nodes 1 and 2 are dangling + edges = [(0, 1, 1.0), (0, 2, 1.0)] + scores = pagerank(edges, n=3) + assert len(scores) == 3 + assert np.isclose(scores.sum(), 1.0) + # Dangling nodes redistribute their rank uniformly, + # so all nodes should have positive scores. 
+ assert all(s > 0 for s in scores) + + def test_gap_in_node_ids(self): + """Nodes with gaps in IDs: phantom nodes in the gap are dangling.""" + # Nodes 0 and 5 — IDs 1-4 don't appear in edges but exist in the array + edges = [(0, 5, 1.0), (5, 0, 1.0)] + scores = pagerank(edges, n=6) + assert len(scores) == 6 + assert np.isclose(scores.sum(), 1.0) + # Nodes 0 and 5 should have more rank than phantom nodes 1-4 + assert scores[0] > scores[1] + assert scores[5] > scores[1] + + def test_trailing_nodes_with_n(self): + """Explicit n includes trailing nodes that have no edges.""" + edges = [(0, 1, 1.0), (1, 0, 1.0)] + scores = pagerank(edges, n=5) + assert len(scores) == 5 + assert np.isclose(scores.sum(), 1.0) + # Trailing nodes 2-4 are dangling and get some rank + assert all(s > 0 for s in scores) + + def test_n_smaller_than_edges_raises(self): + """If n is too small for the edge node IDs, raise ValueError.""" + edges = [(0, 5, 1.0), (5, 0, 1.0)] + with pytest.raises(ValueError, match="n=3 but edges contain node ID 5"): + pagerank(edges, n=3) + + def test_duplicate_edges_are_summed(self): + """Duplicate edges should be coalesced (weights summed) by sparse tensor.""" + edges_single = [(0, 1, 2.0), (1, 0, 1.0)] + edges_dup = [(0, 1, 1.0), (0, 1, 1.0), (1, 0, 1.0)] + scores_single = pagerank(edges_single, n=2) + scores_dup = pagerank(edges_dup, n=2) + assert np.allclose(scores_single, scores_dup, atol=1e-6) + + def test_unweighted_edges(self): + """Unweighted (2-tuple) edges should work.""" + edges = [(0, 1), (1, 2), (2, 0)] + scores = pagerank(edges, n=3) + assert len(scores) == 3 + assert np.isclose(scores.sum(), 1.0) + # Cycle => all equal + assert np.allclose(scores, scores[0]) + + +class TestDtypeBoundaries: + def test_integer_edge_weights(self): + """Integer weights should be accepted and produce valid results.""" + edges = [(0, 1, 2), (1, 2, 3), (2, 0, 1)] + scores = pagerank(edges, n=3) + assert len(scores) == 3 + assert np.isclose(scores.sum(), 1.0) + + def test_very_small_weights(self): + """Very small weights should not cause division-by-zero.""" + edges = [(0, 1, 1e-30), (1, 0, 1e-30)] + scores = pagerank(edges, n=2) + assert len(scores) == 2 + assert np.isclose(scores.sum(), 1.0) + assert not np.any(np.isnan(scores)) + assert not np.any(np.isinf(scores)) + + def test_very_large_weights(self): + """Very large weights should not overflow.""" + edges = [(0, 1, 1e15), (1, 0, 1e15)] + scores = pagerank(edges, n=2) + assert len(scores) == 2 + assert np.isclose(scores.sum(), 1.0) + assert not np.any(np.isnan(scores)) + + def test_mixed_weight_magnitudes(self): + """Edges with vastly different weight magnitudes. + + In a 2-node graph each node has one outgoing edge, so after + column-normalization both transition probabilities become 1.0 + regardless of the raw weight. Need 3+ nodes to see weight effects. 
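+
+        Here node 0 splits its outflow 1e10 : 1e-10 between nodes 1 and 2, so
+        nearly all of node 0's rank should flow to node 1.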
+ """ + # Node 0 splits outgoing: heavy to 1, light to 2 + # Node 1 -> 0, Node 2 -> 0 + edges = [ + (0, 1, 1e10), + (0, 2, 1e-10), + (1, 0, 1.0), + (2, 0, 1.0), + ] + scores = pagerank(edges, n=3) + assert len(scores) == 3 + assert np.isclose(scores.sum(), 1.0) + # Node 1 should get much more rank than node 2 + assert scores[1] > scores[2] + + def test_output_dtype_is_float64(self): + """Output should be float64 (from PyTorch float64 computation).""" + edges = [(0, 1, 1.0), (1, 0, 1.0)] + scores = pagerank(edges, n=2) + assert scores.dtype == np.float64 + + +class TestConvergence: + def test_damping_zero(self): + """damping=0 means pure teleportation: uniform distribution.""" + edges = [(0, 1, 1.0), (1, 2, 1.0), (2, 0, 1.0)] + scores = pagerank(edges, n=3, damping=0.0) + assert np.isclose(scores.sum(), 1.0) + expected = np.array([1 / 3, 1 / 3, 1 / 3]) + assert np.allclose(scores, expected, atol=1e-6) + + def test_damping_one(self): + """damping=1.0 means no teleportation: pure link-following.""" + edges = [(0, 1, 1.0), (1, 2, 1.0), (2, 0, 1.0)] + scores = pagerank(edges, n=3, damping=1.0) + assert np.isclose(scores.sum(), 1.0) + # Cycle with damping=1 should still be uniform + assert np.allclose(scores, 1 / 3, atol=1e-6) + + def test_damping_one_asymmetric(self): + """damping=1.0 on asymmetric graph follows link structure only.""" + # 0 -> 2, 1 -> 2: two sources feeding into node 2 + edges = [(0, 2, 1.0), (1, 2, 1.0)] + scores = pagerank(edges, n=3, damping=1.0) + assert np.isclose(scores.sum(), 1.0) + # (nodes 0 and 1 are dangling, their rank gets redistributed uniformly, + # then flows to 2 again). In steady state: node 2 gets 60% of rank. + assert scores[2] > scores[0] + assert scores[2] > scores[1] + + def test_max_iterations_respected(self): + """With max_iterations=1, result should not fully converge on asymmetric graph.""" + # Asymmetric graph: 0->1, 0->2, 1->0 (node 2 is dangling) + edges = [(0, 1, 1.0), (0, 2, 1.0), (1, 0, 1.0)] + scores_1 = pagerank(edges, n=3, max_iterations=1, tolerance=0.0) + scores_100 = pagerank(edges, n=3, max_iterations=100) + assert np.isclose(scores_1.sum(), 1.0) + assert np.isclose(scores_100.sum(), 1.0) + # 1 iteration should not match the converged result + assert not np.allclose(scores_1, scores_100, atol=1e-6) + + def test_tight_tolerance_converges(self): + """Very tight tolerance should still produce valid results.""" + edges = [(0, 1, 1.0), (1, 2, 1.0), (2, 0, 1.0)] + scores = pagerank(edges, n=3, tolerance=1e-15, max_iterations=1000) + assert np.isclose(scores.sum(), 1.0) + + def test_large_chain_graph(self): + """A long chain requires many iterations; verify it converges.""" + n = 100 + edges = [(i, i + 1, 1.0) for i in range(n - 1)] + edges.append((n - 1, 0, 1.0)) # close the cycle + scores = pagerank(edges, n=n, max_iterations=500) + assert len(scores) == n + assert np.isclose(scores.sum(), 1.0) + # Cycle => uniform + assert np.allclose(scores, 1 / n, atol=1e-4) + + def test_bipartite_structure(self): + """Bipartite graph: group A links to group B, B links back to A.""" + edges = [ + (0, 2, 1.0), + (0, 3, 1.0), + (1, 2, 1.0), + (1, 3, 1.0), + (2, 0, 1.0), + (2, 1, 1.0), + (3, 0, 1.0), + (3, 1, 1.0), + ] + scores = pagerank(edges, n=4) + assert np.isclose(scores.sum(), 1.0) + # All nodes are symmetric => uniform + assert np.allclose(scores, 0.25, atol=1e-6) + + +class TestAPIContract: + def test_negative_weight_accepted(self): + """Negative weights are not rejected (no validation), but result should + still be a numpy array. 
This documents current behavior.""" + edges = [(0, 1, -1.0), (1, 0, 1.0)] + scores = pagerank(edges, n=2) + assert isinstance(scores, np.ndarray) + assert len(scores) == 2 + + def test_zero_weight_edges(self): + """Zero-weight edges should not cause errors.""" + edges = [(0, 1, 0.0), (1, 0, 1.0)] + scores = pagerank(edges, n=2) + assert len(scores) == 2 + assert np.isclose(scores.sum(), 1.0) + + def test_returns_numpy_array(self): + """Return type must be np.ndarray, not a torch tensor.""" + edges = [(0, 1, 1.0), (1, 0, 1.0)] + scores = pagerank(edges, n=2) + assert isinstance(scores, np.ndarray) + assert not isinstance(scores, torch.Tensor) + + def test_single_node_no_self_loop(self): + """A dangling target node (only incoming edges) should still get rank.""" + edges = [(0, 1, 1.0)] + scores = pagerank(edges, n=2) + # Node 1 is dangling, but should have rank + assert len(scores) == 2 + assert np.isclose(scores.sum(), 1.0) + + def test_large_node_ids(self): + """Large node IDs should work (array is sized max_id + 1).""" + edges = [(0, 999, 1.0), (999, 0, 1.0)] + scores = pagerank(edges, n=1000) + assert len(scores) == 1000 + assert np.isclose(scores.sum(), 1.0) + assert scores[0] > 0 + assert scores[999] > 0 + + +class TestAnalyticalResults: + def test_two_node_cycle(self): + """Two nodes in a cycle: equal rank of 0.5 each.""" + edges = [(0, 1, 1.0), (1, 0, 1.0)] + scores = pagerank(edges, n=2) + assert np.allclose(scores, [0.5, 0.5]) + + def test_three_node_cycle(self): + """Three nodes in a cycle: equal rank of 1/3 each.""" + edges = [(0, 1, 1.0), (1, 2, 1.0), (2, 0, 1.0)] + scores = pagerank(edges, n=3) + assert np.allclose(scores, [1 / 3, 1 / 3, 1 / 3]) + + def test_star_incoming(self): + """All nodes point to hub — hub gets most rank.""" + # 1->0, 2->0, 3->0, 4->0 + edges = [(i, 0, 1.0) for i in range(1, 5)] + scores = pagerank(edges, n=5) + assert np.isclose(scores.sum(), 1.0) + assert scores[0] == max(scores) + + def test_star_outgoing(self): + """Hub points to all — leaves get more rank than hub.""" + # 0->1, 0->2, 0->3, 0->4 + edges = [(0, i, 1.0) for i in range(1, 5)] + scores = pagerank(edges, n=5) + assert np.isclose(scores.sum(), 1.0) + # All leaves are equivalent + leaf_scores = scores[1:] + assert np.allclose(leaf_scores, leaf_scores[0], atol=1e-6) + # Leaves get more rank than hub (hub distributes all its rank out) + assert scores[1] > scores[0] + + def test_weight_sensitivity(self): + """Higher weight edge should direct more rank to its target.""" + edges_equal = [(0, 1, 1.0), (0, 2, 1.0), (1, 0, 1.0), (2, 0, 1.0)] + edges_biased = [(0, 1, 10.0), (0, 2, 1.0), (1, 0, 1.0), (2, 0, 1.0)] + + scores_equal = pagerank(edges_equal, n=3) + scores_biased = pagerank(edges_biased, n=3) + + # With equal weights, nodes 1 and 2 should have equal rank + assert np.isclose(scores_equal[1], scores_equal[2], atol=1e-6) + # With biased weights, node 1 should have more rank than node 2 + assert scores_biased[1] > scores_biased[2] + + +class TestKnnToEdges: + def test_basic_conversion(self): + """Should produce correct number of edges.""" + indices = np.array([[1, 2], [0, 2], [0, 1]]) + distances = np.array([[0.1, 0.2], [0.1, 0.3], [0.2, 0.3]]) + edges = knn_to_edges(indices, distances) + assert len(edges) == 6 # 3 nodes × 2 neighbors + + def test_skips_self_loops(self): + """Should skip edges where source == target.""" + indices = np.array([[0, 1], [1, 0]]) # Node 0 points to itself + distances = np.array([[0.0, 0.5], [0.0, 0.5]]) + edges = knn_to_edges(indices, distances) + # Should have 2 
edges (0->1 and 1->0), not 4 + assert len(edges) == 2 + assert all(src != tgt for src, tgt, _ in edges) + + def test_edge_weights_are_positive(self): + """UMAP-style weights should be positive.""" + indices = np.array([[1, 2], [0, 2], [0, 1]]) + distances = np.array([[0.1, 0.2], [0.1, 0.3], [0.2, 0.3]]) + edges = knn_to_edges(indices, distances) + assert all(w > 0 for _, _, w in edges) + + +class TestComputePagerankColumn: + def test_compute_pagerank_column(self): + """Should add a 'pagerank' column to the DataFrame.""" + indices = np.array([[1, 2], [0, 2], [0, 1]]) + distances = np.array([[0.1, 0.2], [0.1, 0.3], [0.2, 0.3]]) + df = pd.DataFrame( + { + "__neighbors": [ + {"ids": indices[i], "distances": distances[i]} + for i in range(len(indices)) + ] + } + ) + + df["pagerank"] = compute_pagerank_column(df) + + assert len(df["pagerank"]) == 3 + assert np.isclose(df["pagerank"].sum(), 1.0) + assert all(s > 0 for s in df["pagerank"]) + + +class TestIntegration: + def test_knn_to_pagerank_pipeline(self): + """Full pipeline: KNN arrays -> edges -> PageRank scores.""" + # Simple 3-node graph + indices = np.array([[1, 2], [0, 2], [0, 1]]) + distances = np.array([[0.1, 0.2], [0.1, 0.3], [0.2, 0.3]]) + + edges = knn_to_edges(indices, distances) + scores = pagerank(edges, n=3) + + assert len(scores) == 3 + assert np.isclose(scores.sum(), 1.0) + assert all(s > 0 for s in scores) + + +class TestUmapWeightCompatibility: + """Verify that our edge weight computation matches UMAP's implementation.""" + + def test_edge_weights_match_umap(self): + """Our edge weights should match UMAP's compute_membership_strengths.""" + from umap.umap_ import compute_membership_strengths, smooth_knn_dist + + # Generate random KNN data + np.random.seed(123) + n_samples = 50 + n_neighbors = 10 + + # Random neighbor indices (ensuring no self-loops for simplicity) + knn_indices = np.zeros((n_samples, n_neighbors), dtype=np.int32) + for i in range(n_samples): + candidates = [j for j in range(n_samples) if j != i] + knn_indices[i] = np.random.choice(candidates, n_neighbors, replace=False) + + # Random distances (sorted) - use float32 for UMAP compatibility + knn_distances = np.sort( + np.random.rand(n_samples, n_neighbors).astype(np.float32), axis=1 + ) + + # Compute sigmas and rhos using UMAP + umap_sigmas, umap_rhos = smooth_knn_dist( + knn_distances, + k=n_neighbors, + local_connectivity=1.0, + bandwidth=1.0, + ) + # Ensure float32 for numba compatibility + umap_sigmas = umap_sigmas.astype(np.float32) + umap_rhos = umap_rhos.astype(np.float32) + + # Compute membership strengths using UMAP + result = compute_membership_strengths( + knn_indices, knn_distances, umap_sigmas, umap_rhos, return_dists=False + ) + rows, cols, vals = result[0], result[1], result[2] + + # Build a dict of UMAP weights: (source, target) -> weight + umap_weights = {} + for r, c, v in zip(rows, cols, vals): + umap_weights[(int(r), int(c))] = float(v) + + # Compute using our implementation + our_edges = knn_to_edges(knn_indices, knn_distances) + our_weights = {(src, tgt): w for src, tgt, w in our_edges} + + # Compare weights for each edge + mismatches = [] + for (src, tgt), our_w in our_weights.items(): + umap_w = umap_weights.get((src, tgt)) + if umap_w is None: + mismatches.append(f"Edge ({src}, {tgt}) not in UMAP output") + elif not np.isclose(our_w, umap_w, rtol=1e-7): + mismatches.append( + f"Edge ({src}, {tgt}): ours={our_w:.8f}, umap={umap_w:.8f}" + ) + + assert len(mismatches) == 0, "Weight mismatches:\n" + "\n".join(mismatches[:10]) diff --git 
a/packages/component/src/lib/embedding_view/EmbeddingViewImpl.svelte b/packages/component/src/lib/embedding_view/EmbeddingViewImpl.svelte index 6cbf3196..72731f0b 100644 --- a/packages/component/src/lib/embedding_view/EmbeddingViewImpl.svelte +++ b/packages/component/src/lib/embedding_view/EmbeddingViewImpl.svelte @@ -16,7 +16,7 @@ totalCount: number | null; maxDensity: number | null; labels?: Label[] | null; - queryClusterLabels: ((clusters: Rectangle[][]) => Promise<(string | null)[]>) | null; + queryClusterLabels: ((clusters: Rectangle[][]) => Promise<(LabelContent | null)[]>) | null; tooltip: Selection | null; selection: Selection[] | null; querySelection: ((x: number, y: number, unitDistance: number) => Promise) | null; @@ -38,7 +38,7 @@ sumDensity: number; rects: Rectangle[]; bandwidth: number; - label?: string | null; + content?: LabelContent | null; } function viewingParameters( @@ -118,7 +118,7 @@ import { layoutLabels, type LabelWithPlacement } from "./labels.js"; import { simplifyPolygon } from "./simplify_polygon.js"; import { resolveTheme, type ThemeConfig } from "./theme.js"; - import type { Cache, CustomComponent, Label, OverlayProxy } from "./types.js"; + import type { Cache, CustomComponent, Label, LabelContent, OverlayProxy } from "./types.js"; import { findClusters } from "./worker/index.js"; interface SelectionBase { @@ -619,7 +619,7 @@ let cacheKey = await cacheKeyForObject({ autoLabel: { - version: 1, + version: 3, viewport, stopWords: config?.autoLabelStopWords, densityThreshold: config?.autoLabelDensityThreshold, @@ -636,19 +636,24 @@ let newClusters = await generateClusters(renderer, 10, viewport, config?.autoLabelDensityThreshold ?? 0.005); newClusters = newClusters.concat(await generateClusters(renderer, 5, viewport)); - if (queryClusterLabels) { - let labels = await queryClusterLabels(newClusters.map((x) => x.rects)); - for (let i = 0; i < newClusters.length; i++) { - newClusters[i].label = labels[i]; + let labels = await queryClusterLabels(newClusters.map((x) => x.rects)); + for (let i = 0; i < newClusters.length; i++) { + let label = labels[i]; + newClusters[i].content = label; + if (typeof label == "object" && label != null && "x" in label && "y" in label) { + if (label.x != null && label.y != null) { + newClusters[i].x = label.x; + newClusters[i].y = label.y; + } } } let result: Label[] = newClusters - .filter((x) => x.label != null && x.label.length > 0) + .filter((x) => x.content != null && (typeof x.content !== "string" || x.content.length > 0)) .map((x) => ({ x: x.x, y: x.y, - text: x.label!, + content: x.content!, priority: x.sumDensity, level: x.bandwidth == 10 ? 
0 : 1, })); @@ -669,9 +674,14 @@ clusterLabels = await layoutLabels(vp.scale(), labels, resolvedTheme.fontFamily); } else { statusMessage = "Generating labels..."; - let result = await generateLabels(viewport); - clusterLabels = await layoutLabels(vp.scale(), result, resolvedTheme.fontFamily); - statusMessage = null; + try { + let result = await generateLabels(viewport); + clusterLabels = await layoutLabels(vp.scale(), result, resolvedTheme.fontFamily); + } catch (e) { + console.error("Error while generating labels", e); + } finally { + statusMessage = null; + } } } @@ -767,36 +777,49 @@ {#if showClusterLabels} {#each clusterLabels as label} - {@const rows = label.text.split("\n")} {@const location = pointLocation(label.coordinate.x, label.coordinate.y)} {@const scale = resolvedViewport.scale()} {@const isVisible = label.placement != null && label.placement.minScale <= scale && scale <= label.placement.maxScale} {#if isVisible} - - {#each rows as row, index} - - {row} - - {/each} - + {#if typeof label.content !== "string"} + + {:else} + {@const rows = label.content.split("\n")} + + {#each rows as row, index} + + {row} + + {/each} + + {/if} {/if} {/each} diff --git a/packages/component/src/lib/embedding_view/EmbeddingViewMosaic.svelte b/packages/component/src/lib/embedding_view/EmbeddingViewMosaic.svelte index e5933b9a..86e2490a 100644 --- a/packages/component/src/lib/embedding_view/EmbeddingViewMosaic.svelte +++ b/packages/component/src/lib/embedding_view/EmbeddingViewMosaic.svelte @@ -1,5 +1,6 @@ Promise<(string | null)[]>) | null; + queryClusterLabels?: ((clusters: Rectangle[][]) => Promise<(LabelContent | null)[]>) | null; /** A custom renderer to draw the tooltip content. */ customTooltip?: CustomComponent | null; diff --git a/packages/component/src/lib/embedding_view/embedding_view_mosaic_api.ts b/packages/component/src/lib/embedding_view/embedding_view_mosaic_api.ts index d5fea4ad..6e8f33c4 100644 --- a/packages/component/src/lib/embedding_view/embedding_view_mosaic_api.ts +++ b/packages/component/src/lib/embedding_view/embedding_view_mosaic_api.ts @@ -34,6 +34,14 @@ export interface EmbeddingViewMosaicProps { * The text content is also used to generate labels automatically. */ text?: string | null; + /** The name of the image column. + * If specified along with `importance`, cluster labels will display the highest-importance image per region. */ + image?: string | null; + + /** The name of the importance score column (e.g., PageRank, centrality). + * Used together with `image` to select representative images for cluster labels. */ + importance?: string | null; + /** The name of the identifier (aka., id) column. * If specified, the `selection` object will contain an `identifier` property that you can use to identify the point. */ identifier?: string | null; diff --git a/packages/component/src/lib/embedding_view/labels.ts b/packages/component/src/lib/embedding_view/labels.ts index 932ea364..e4850144 100644 --- a/packages/component/src/lib/embedding_view/labels.ts +++ b/packages/component/src/lib/embedding_view/labels.ts @@ -2,11 +2,14 @@ import { measureText } from "../measure_text.js"; import { type Point, type Rectangle } from "../utils.js"; -import type { Label } from "./types.js"; +import type { Label, LabelContent } from "./types.js"; import { dynamicLabelPlacement } from "./worker/index.js"; +/** Maximum size of image labels in pixels. 
*/ +export const IMAGE_LABEL_SIZE = 48; + export interface LabelWithPlacement { - text: string; + content: LabelContent; fontSize: number; bounds: Rectangle; locationAtZero: Point; @@ -28,15 +31,21 @@ export async function layoutLabels( let location = { x: cluster.x, y: cluster.y }; let level = cluster.level ?? 0; let fontSize = level == 0 ? 14 : 12; - let size = measureText({ - text: cluster.text, - fontSize: fontSize, - fontFamily: fontFamily, - }); - size.width += 4; - size.height += 4; + let size; + if (typeof cluster.content !== "string") { + // Use the pre-computed display dimensions for collision detection + size = { width: cluster.content.width + 4, height: cluster.content.height + 4 }; + } else { + size = measureText({ + text: cluster.content, + fontSize: fontSize, + fontFamily: fontFamily, + }); + size.width += 4; + size.height += 4; + } return { - text: cluster.text, + content: cluster.content, fontSize: fontSize, bounds: { xMin: location.x - size.width / 2, diff --git a/packages/component/src/lib/embedding_view/types.ts b/packages/component/src/lib/embedding_view/types.ts index 2d976f79..ef15333b 100644 --- a/packages/component/src/lib/embedding_view/types.ts +++ b/packages/component/src/lib/embedding_view/types.ts @@ -18,13 +18,16 @@ export interface Cache { set: (key: string, value: any) => Promise; } +/** The content of a label: either a text string or an image with display dimensions (and optionally x, y coordinates). */ +export type LabelContent = string | { x?: number; y?: number; image: string; width: number; height: number }; + export interface Label { /** X coordinate. */ x: number; /** Y coordinate. */ y: number; - /** Label text, use "\n" for a new line. */ - text: string; + /** Label content: a text string or an image reference. */ + content: LabelContent; /** Label level. The label will be shown around 2^level zoom factor. */ level?: number | null; /** Placement priority. */ diff --git a/packages/component/src/lib/index.ts b/packages/component/src/lib/index.ts index 2cdc5e77..3bdc0935 100644 --- a/packages/component/src/lib/index.ts +++ b/packages/component/src/lib/index.ts @@ -18,6 +18,7 @@ export type { DataPoint, DataPointID, Label, + LabelContent, OverlayProxy, } from "./embedding_view/types.js"; export type { Point, Rectangle, ViewportState } from "./utils.js"; diff --git a/packages/utils/src/index.ts b/packages/utils/src/index.ts index d59bf9ff..a1cf8ec6 100644 --- a/packages/utils/src/index.ts +++ b/packages/utils/src/index.ts @@ -4,4 +4,5 @@ export { base64Decode, base64Encode, compress, decompress } from "./compression. export { debounce } from "./debounce.js"; export { deepEquals, deepMemo } from "./equals.js"; export { interactionHandler, type CursorValue, type DragHandler } from "./interaction_handler.js"; +export { audioToDataUrl, imageToDataUrl } from "./media.js"; export { applyUpdatesForKeyIfNeeded, applyUpdatesIfNeeded, mergeUpdates } from "./merge_updates.js"; diff --git a/packages/viewer/src/utils/media.ts b/packages/utils/src/media.ts similarity index 100% rename from packages/viewer/src/utils/media.ts rename to packages/utils/src/media.ts diff --git a/packages/viewer/src/EmbeddingAtlas.svelte b/packages/viewer/src/EmbeddingAtlas.svelte index 85a345d8..adddd87f 100644 --- a/packages/viewer/src/EmbeddingAtlas.svelte +++ b/packages/viewer/src/EmbeddingAtlas.svelte @@ -264,6 +264,8 @@ ? { ...data.projection, text: data.text ?? undefined, + image: data.image ?? undefined, + importance: data.importance ?? 
undefined, } : undefined, config: defaultChartsConfig ?? undefined, diff --git a/packages/viewer/src/api.ts b/packages/viewer/src/api.ts index dcd46815..7db5cf93 100644 --- a/packages/viewer/src/api.ts +++ b/packages/viewer/src/api.ts @@ -40,6 +40,12 @@ export interface EmbeddingAtlasProps { /** The column for text. The text will be used as content for the tooltip and search features. */ text?: string | null; + + /** The column for image data. Used with `importance` to select representative images for cluster labels. */ + image?: string | null; + + /** The column for importance scores (e.g., PageRank, centrality). Used with `image` to select representative images for cluster labels. */ + importance?: string | null; }; /** The color scheme. */ diff --git a/packages/viewer/src/charts/default_charts.ts b/packages/viewer/src/charts/default_charts.ts index d2148bf8..4ea915e7 100644 --- a/packages/viewer/src/charts/default_charts.ts +++ b/packages/viewer/src/charts/default_charts.ts @@ -30,7 +30,7 @@ export async function defaultCharts(options: { coordinator: Coordinator; table: string; id: string; - projection?: { x: string; y: string; text?: string }; + projection?: { x: string; y: string; text?: string; image?: string; importance?: string }; config?: DefaultChartsConfig; }): Promise { let { coordinator, table, projection } = options; @@ -49,6 +49,8 @@ export async function defaultCharts(options: { x: projection.x, y: projection.y, text: projection.text, + image: projection.image, + importance: projection.importance, }, }; if (typeof config.embedding == "object") { diff --git a/packages/viewer/src/charts/embedding/Embedding.svelte b/packages/viewer/src/charts/embedding/Embedding.svelte index b1f72277..3fd41c46 100644 --- a/packages/viewer/src/charts/embedding/Embedding.svelte +++ b/packages/viewer/src/charts/embedding/Embedding.svelte @@ -217,6 +217,8 @@ x={spec.data.x} y={spec.data.y} text={spec.data.text} + image={spec.data.image} + importance={spec.data.importance} category={categoryLegend?.indexColumn} categoryColors={categoryLegend?.legend.map((x) => x.color) ?? [theme.embeddingColor]} config={{ diff --git a/packages/viewer/src/charts/embedding/types.ts b/packages/viewer/src/charts/embedding/types.ts index aecfd3ae..4fb5b037 100644 --- a/packages/viewer/src/charts/embedding/types.ts +++ b/packages/viewer/src/charts/embedding/types.ts @@ -10,6 +10,8 @@ export interface EmbeddingSpec { x: string; y: string; text?: string | null; + image?: string | null; + importance?: string | null; category?: string | null; }; diff --git a/packages/viewer/src/embedding/embedding.worker.ts b/packages/viewer/src/embedding/embedding.worker.ts index fd524f85..925dd8aa 100644 --- a/packages/viewer/src/embedding/embedding.worker.ts +++ b/packages/viewer/src/embedding/embedding.worker.ts @@ -1,9 +1,9 @@ // Copyright (c) 2025 Apple Inc. Licensed under MIT License. import { createUMAP } from "@embedding-atlas/umap-wasm"; +import { imageToDataUrl } from "@embedding-atlas/utils"; import { load_image, pipeline } from "@huggingface/transformers"; -import { imageToDataUrl } from "../utils/media.js"; import { WorkerRPC } from "./worker_helper.js"; let { handler, register } = WorkerRPC.runtime(); diff --git a/packages/viewer/src/renderers/AudioContent.svelte b/packages/viewer/src/renderers/AudioContent.svelte index 3864d3b1..e10179a9 100644 --- a/packages/viewer/src/renderers/AudioContent.svelte +++ b/packages/viewer/src/renderers/AudioContent.svelte @@ -1,8 +1,8 @@