Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions test/loader/test_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,3 +1024,53 @@ def test_temporal_neighbor_loader_single_link():
assert batch['a'].num_nodes == 10
assert batch['b'].num_nodes == 10
assert batch['c'].num_nodes == 0


@onlyNeighborSampler
def test_time_window_homo_neighbor_loader():
    r"""Test that :obj:`time_window` correctly restricts temporal neighbor
    sampling by filtering out neighbors outside the window.

    The number of hops is chosen large enough to walk the whole chain from
    the seed node, so that the test fails if :obj:`time_window` is ignored
    (plain temporal sampling would then also reach nodes 0 and 1).
    """
    # Chain graph: 0-1-2-3-4 (both directions), node timestamps match the
    # node index, i.e. node i has time i.
    edge_index = torch.tensor([
        [0, 1, 1, 2, 2, 3, 3, 4],
        [1, 0, 2, 1, 3, 2, 4, 3],
    ])
    node_time = torch.arange(5, dtype=torch.long)

    data = Data(edge_index=edge_index, time=node_time, num_nodes=5)

    input_time = 4
    time_window = 2

    # Sample from node 4 at time 4 with a window of 2: only nodes with
    # time in [2, 4] should be reachable neighbors.
    loader = NeighborLoader(
        data,
        # Four hops cover the full chain starting at node 4. With only two
        # hops, every reachable node would already lie inside the window
        # and the assertions below could never fail.
        num_neighbors=[-1, -1, -1, -1],
        input_nodes=torch.tensor([4]),
        input_time=torch.tensor([input_time]),
        time_attr='time',
        time_window=time_window,
        batch_size=1,
    )

    batch = next(iter(loader))

    # All sampled nodes must satisfy
    # input_time - time_window <= time <= input_time.
    sampled_times = node_time[batch.n_id]
    assert (sampled_times >= input_time - time_window).all()
    assert (sampled_times <= input_time).all()

    # Every node inside the window is connected to the seed through nodes
    # that are themselves inside the window, so exactly {2, 3, 4} is
    # expected; nodes 0 and 1 must have been filtered out.
    assert set(batch.n_id.tolist()) == {2, 3, 4}


def test_time_window_requires_time_attr():
    r"""Check that :class:`NeighborLoader` refuses a :obj:`time_window`
    when no :obj:`time_attr` is given, by raising a :class:`ValueError`.
    """
    # Minimal two-node graph that carries no timestamp attribute at all.
    graph = Data(
        edge_index=torch.tensor([[0, 1], [1, 0]]),
        num_nodes=2,
    )
    loader_kwargs = dict(num_neighbors=[-1], time_window=5, batch_size=1)

    # Construction itself must fail; the error has to mention the argument.
    with pytest.raises(ValueError, match="time_window"):
        NeighborLoader(graph, **loader_kwargs)
17 changes: 17 additions & 0 deletions torch_geometric/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,16 @@ class NeighborLoader(NodeLoader):
guaranteed to fulfill temporal constraints, *i.e.* neighbors have
an earlier or equal timestamp than the center node.
(default: :obj:`None`)
time_window (int, optional): The size of the time window to restrict
temporal neighbor sampling. If set, only neighbors whose timestamps
satisfy :obj:`input_time - time_window <= neighbor_time <=
input_time` will be considered for sampling.
Requires :obj:`time_attr` to be set to identify the timestamp
attribute, and works alongside :obj:`temporal_strategy` to
determine how neighbors within the window are selected.
:obj:`input_time` (or the node timestamps from :obj:`time_attr`
if :obj:`input_time` is not provided) is used as the upper bound.
The lower bound is currently computed once per mini-batch from the
smallest seed timestamp in that batch, so the effective window can
be wider than :obj:`time_window` for seeds with later timestamps
whenever a mini-batch contains seeds with different timestamps.
(default: :obj:`None`)
weight_attr (str, optional): The name of the attribute that denotes
edge weights in the graph.
If set, weighted/biased sampling will be used such that neighbors
Expand Down Expand Up @@ -211,6 +221,7 @@ def __init__(
disjoint: bool = False,
temporal_strategy: str = 'uniform',
time_attr: Optional[str] = None,
time_window: Optional[int] = None,
weight_attr: Optional[str] = None,
transform: Optional[Callable] = None,
transform_sampler_output: Optional[Callable] = None,
Expand All @@ -225,6 +236,11 @@ def __init__(
"'time_attr' arguments: 'input_time' is set "
"while 'time_attr' is not set.")

if time_window is not None and time_attr is None:
raise ValueError("Received conflicting 'time_window' and "
"'time_attr' arguments: 'time_window' is set "
"while 'time_attr' is not set.")

if neighbor_sampler is None:
neighbor_sampler = NeighborSampler(
data,
Expand All @@ -234,6 +250,7 @@ def __init__(
disjoint=disjoint,
temporal_strategy=temporal_strategy,
time_attr=time_attr,
time_window=time_window,
weight_attr=weight_attr,
is_sorted=is_sorted,
share_memory=kwargs.get('num_workers', 0) > 0,
Expand Down
148 changes: 144 additions & 4 deletions torch_geometric/sampler/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,48 @@
class NeighborSampler(BaseSampler):
r"""An implementation of an in-memory (heterogeneous) neighbor sampler used
by :class:`~torch_geometric.loader.NeighborLoader`.

Args:
data (Data or HeteroData or Tuple[FeatureStore, GraphStore]): The
graph data object.
num_neighbors (List[int] or Dict[EdgeType, List[int]]): The number of
neighbors to sample per hop.
subgraph_type (SubgraphType or str, optional): The type of the
returned subgraph. (default: :obj:`"directional"`)
replace (bool, optional): If :obj:`True`, sample with replacement.
(default: :obj:`False`)
disjoint (bool, optional): If :obj:`True`, each seed node produces
its own disjoint subgraph. (default: :obj:`False`)
temporal_strategy (str, optional): The sampling strategy for temporal
sampling (:obj:`"uniform"` or :obj:`"last"`). Only has effect
when :obj:`time_attr` is set and :obj:`pyg-lib` is installed.
(default: :obj:`"uniform"`)
time_attr (str, optional): The name of the node- or edge-level
timestamp attribute. When set, enables temporal sampling so that
only neighbors with timestamps :obj:`<= seed_time` are considered.
(default: :obj:`None`)
time_window (int, optional): Restricts temporal sampling to a window
:obj:`[seed_time - time_window, seed_time]`. Applied as a
Python-side pre-filter before the sampling kernel, so it works
regardless of whether native kernel support is available.
Requires :obj:`time_attr` to be set. Works in combination with
:obj:`temporal_strategy` — after filtering to the window,
:obj:`temporal_strategy` governs how :obj:`num_neighbors`
neighbors are selected from the remaining candidates.
Note that the current implementation uses a conservative
batch-level floor (:obj:`min(seed_time) - time_window`), so the
effective window may be wider than intended when seed times vary
significantly within a batch. A warning is issued when the seed
time spread exceeds :obj:`time_window`.
(default: :obj:`None`)
weight_attr (str, optional): The name of the edge weight attribute
for weighted/biased sampling. (default: :obj:`None`)
is_sorted (bool, optional): If :obj:`True`, assumes edge indices are
sorted by column (and by time within neighborhoods when
:obj:`time_attr` is set). (default: :obj:`False`)
share_memory (bool, optional): If :obj:`True`, moves tensors to
shared memory for use with multiple worker processes.
(default: :obj:`False`)
"""
def __init__(
self,
Expand All @@ -50,6 +92,7 @@ def __init__(
disjoint: bool = False,
temporal_strategy: str = 'uniform',
time_attr: Optional[str] = None,
time_window: Optional[int] = None,
weight_attr: Optional[str] = None,
is_sorted: bool = False,
share_memory: bool = False,
Expand Down Expand Up @@ -335,6 +378,7 @@ def __init__(
self.subgraph_type = SubgraphType(subgraph_type)
self.disjoint = disjoint
self.temporal_strategy = temporal_strategy
self.time_window = time_window
self.keep_orig_edges = False

@property
Expand Down Expand Up @@ -426,6 +470,95 @@ def edge_permutation(self) -> Union[OptTensor, Dict[EdgeType, OptTensor]]:

# Helper functions ########################################################

def _get_time_window_masked_times(
self,
seed_time: Union[Tensor, Dict[NodeType, Tensor]],
):
r"""Returns masked copies of :obj:`node_time` and :obj:`edge_time`
where timestamps below the time window floor are replaced with a
sentinel value (:obj:`torch.iinfo(torch.int64).max`) so that the
kernel's upper-bound temporal filter (:obj:`<= seed_time`) will
always exclude them.

.. warning::

This is a conservative batch-level approximation. The floor is
computed as :obj:`min(seed_time) - time_window` across the entire
batch, meaning the effective window for seeds with higher
timestamps is wider than :obj:`time_window`. The approximation
is exact when all seed times in the batch are equal, and
degrades as the spread of seed times grows relative to
:obj:`time_window`. For per-seed accuracy, native kernel support
in :obj:`pyg-lib` is required (see
:obj:`WITH_TIME_WINDOW_NEIGHBOR_SAMPLE`).

Uses :obj:`seed_time.min() - time_window` as a conservative global
floor. Edges within :obj:`[min(seed_time) - time_window, seed_time]`
are always included; edges below this floor are always excluded.
"""
import torch_geometric.typing as tgt
sentinel = tgt.MAX_INT64

if isinstance(seed_time, dict):
# Heterogeneous: compute global floor across all node types.
all_times = [v for v in seed_time.values() if v is not None]
if len(all_times) == 0:
return self.node_time, self.edge_time
global_min = torch.stack([t.min() for t in all_times]).min().item()
global_max = torch.stack([t.max() for t in all_times]).max().item()
floor = global_min - self.time_window
if global_max - global_min > self.time_window:
warnings.warn(
f"The seed time spread ({global_max - global_min}) in "
f"this batch exceeds 'time_window' ({self.time_window}). "
f"The time window filter is a batch-level approximation "
f"and will be wider than intended for seeds with higher "
f"timestamps. For per-seed accuracy, native 'pyg-lib' "
f"support is required.", stacklevel=3)

node_time = None
if self.node_time is not None:
node_time = {
k: torch.where(v >= floor, v, torch.full_like(v, sentinel))
for k, v in self.node_time.items()
}
edge_time = None
if self.edge_time is not None:
edge_time = {
k: torch.where(v >= floor, v, torch.full_like(v, sentinel))
for k, v in self.edge_time.items()
}
else:
# Homogeneous: compute global floor across the batch.
seed_min = seed_time.min().item()
seed_max = seed_time.max().item()
floor = seed_min - self.time_window
if seed_max - seed_min > self.time_window:
warnings.warn(
f"The seed time spread ({seed_max - seed_min}) in this "
f"batch exceeds 'time_window' ({self.time_window}). "
f"The time window filter is a batch-level approximation "
f"and will be wider than intended for seeds with higher "
f"timestamps. For per-seed accuracy, native 'pyg-lib' "
f"support is required.", stacklevel=3)

node_time = None
if self.node_time is not None:
node_time = torch.where(
self.node_time >= floor,
self.node_time,
torch.full_like(self.node_time, sentinel),
)
edge_time = None
if self.edge_time is not None:
edge_time = torch.where(
self.edge_time >= floor,
self.edge_time,
torch.full_like(self.edge_time, sentinel),
)

return node_time, edge_time

def _sample(
self,
seed: Union[Tensor, Dict[NodeType, Tensor]],
Expand All @@ -435,6 +568,13 @@ def _sample(
r"""Implements neighbor sampling by calling either :obj:`pyg-lib` (if
installed) or :obj:`torch-sparse` (if installed) sampling routines.
"""
# Apply time window pre-filtering when requested:
if self.time_window is not None and seed_time is not None:
node_time, edge_time = self._get_time_window_masked_times(
seed_time)
else:
node_time, edge_time = self.node_time, self.edge_time

if isinstance(seed, dict): # Heterogeneous sampling:
# TODO Support induced subgraph sampling in `pyg-lib`.
if (torch_geometric.typing.WITH_PYG_LIB
Expand All @@ -451,10 +591,10 @@ def _sample(
self.row_dict,
seed,
self.num_neighbors.get_mapped_values(self.edge_types),
self.node_time,
node_time,
)
if torch_geometric.typing.WITH_EDGE_TIME_NEIGHBOR_SAMPLE:
args += (self.edge_time, )
args += (edge_time, )
args += (seed_time, )
if torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE:
args += (self.edge_weight, )
Expand Down Expand Up @@ -556,10 +696,10 @@ def _sample(
# TODO (matthias) `seed` should inherit dtype from `colptr`
seed.to(self.colptr.dtype),
self.num_neighbors.get_mapped_values(),
self.node_time,
node_time,
)
if torch_geometric.typing.WITH_EDGE_TIME_NEIGHBOR_SAMPLE:
args += (self.edge_time, )
args += (edge_time, )
args += (seed_time, )
if torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE:
args += (self.edge_weight, )
Expand Down
3 changes: 3 additions & 0 deletions torch_geometric/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@
pyg_lib.sampler.neighbor_sample).parameters)
WITH_WEIGHTED_NEIGHBOR_SAMPLE = ('edge_weight' in inspect.signature(
pyg_lib.sampler.neighbor_sample).parameters)
WITH_TIME_WINDOW_NEIGHBOR_SAMPLE = ('time_window' in inspect.signature(
pyg_lib.sampler.neighbor_sample).parameters)
try:
torch.classes.pyg.CPUHashMap # noqa: B018
WITH_CPU_HASH_MAP = True
Expand All @@ -96,6 +98,7 @@
WITH_METIS = False
WITH_EDGE_TIME_NEIGHBOR_SAMPLE = False
WITH_WEIGHTED_NEIGHBOR_SAMPLE = False
WITH_TIME_WINDOW_NEIGHBOR_SAMPLE = False
WITH_CPU_HASH_MAP = False
WITH_CUDA_HASH_MAP = False

Expand Down