From 0cb904d328198b498c139f52a06799defbca5e9b Mon Sep 17 00:00:00 2001 From: CoreLeader Date: Tue, 5 Apr 2022 22:37:00 +0800 Subject: [PATCH] Initial proposal of unified samplers for PyG backend --- autogl/module/sampling/__init__.py | 0 .../module/sampling/graph_sampler/__init__.py | 25 ++ .../sampling/graph_sampler/_graph_sampler.py | 42 +++ .../sampling/graph_sampler/_pyg/__init__.py | 16 + .../_pyg/_pyg_homogeneous_graph_sampler.py | 311 ++++++++++++++++++ .../graph_sampler/_sampler_utility.py | 78 +++++ .../test_pyg_graph_sampler_pipeline.py | 131 ++++++++ 7 files changed, 603 insertions(+) create mode 100644 autogl/module/sampling/__init__.py create mode 100644 autogl/module/sampling/graph_sampler/__init__.py create mode 100644 autogl/module/sampling/graph_sampler/_graph_sampler.py create mode 100644 autogl/module/sampling/graph_sampler/_pyg/__init__.py create mode 100644 autogl/module/sampling/graph_sampler/_pyg/_pyg_homogeneous_graph_sampler.py create mode 100644 autogl/module/sampling/graph_sampler/_sampler_utility.py create mode 100644 test/module/sampling/test_pyg_graph_sampler/test_pyg_graph_sampler_pipeline.py diff --git a/autogl/module/sampling/__init__.py b/autogl/module/sampling/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/autogl/module/sampling/graph_sampler/__init__.py b/autogl/module/sampling/graph_sampler/__init__.py new file mode 100644 index 00000000..108efed2 --- /dev/null +++ b/autogl/module/sampling/graph_sampler/__init__.py @@ -0,0 +1,25 @@ +import autogl +from ._graph_sampler import GraphSampler, SampledSubgraph, GraphSamplerUniversalRegistry, instantiate_graph_sampler + + +if autogl.backend.DependentBackend.is_pyg(): + from ._pyg import ( + PyGGraphSampler, PyGHomogeneousGraphSampler, PyGSampledSubgraph, + PyGClusterSampler, PyGNeighborSampler, + PyGGraphSAINTNodeSampler, PyGGraphSAINTEdgeSampler, PyGGraphSAINTRandomWalkSampler + ) + + __all__ = [ + 'GraphSampler', + 'SampledSubgraph', + 'GraphSamplerUniversalRegistry', + 'instantiate_graph_sampler', + 'PyGGraphSampler', + 'PyGHomogeneousGraphSampler', + 'PyGSampledSubgraph', + 'PyGClusterSampler', + 'PyGNeighborSampler', + 'PyGGraphSAINTNodeSampler', + 'PyGGraphSAINTEdgeSampler', + 'PyGGraphSAINTRandomWalkSampler' + ] diff --git a/autogl/module/sampling/graph_sampler/_graph_sampler.py b/autogl/module/sampling/graph_sampler/_graph_sampler.py new file mode 100644 index 00000000..9a72e711 --- /dev/null +++ b/autogl/module/sampling/graph_sampler/_graph_sampler.py @@ -0,0 +1,42 @@ +import torch +import typing +from autogl.utils import universal_registry + + +class GraphSampler(torch.nn.Module, typing.Iterable): + def __iter__(self): + raise NotImplementedError + + +class SampledSubgraph: + ... + + +class GraphSamplerUniversalRegistry(universal_registry.UniversalRegistryBase): + @classmethod + def register_graph_sampler(cls, name: str) -> typing.Callable[ + [typing.Type[GraphSampler]], typing.Type[GraphSampler] + ]: + def register_sampler( + graph_sampler: typing.Type[GraphSampler] + ) -> typing.Type[GraphSampler]: + if not issubclass(graph_sampler, GraphSampler): + raise TypeError + else: + cls[name] = graph_sampler + return graph_sampler + + return register_sampler + + @classmethod + def get_graph_sampler(cls, name: str) -> typing.Type[GraphSampler]: + if name not in cls: + raise ValueError(f"Graph Sampler with name \"{name}\" not exist") + else: + return cls[name] + + +def instantiate_graph_sampler( + graph_sampler_name: str, data, sampler_configurations: typing.Mapping[str, typing.Any], **kwargs +) -> GraphSampler: + return GraphSamplerUniversalRegistry[graph_sampler_name](data, sampler_configurations, **kwargs) diff --git a/autogl/module/sampling/graph_sampler/_pyg/__init__.py b/autogl/module/sampling/graph_sampler/_pyg/__init__.py new file mode 100644 index 00000000..f72d98b7 --- /dev/null +++ b/autogl/module/sampling/graph_sampler/_pyg/__init__.py @@ -0,0 +1,16 @@ +from ._pyg_homogeneous_graph_sampler import ( + PyGGraphSampler, PyGHomogeneousGraphSampler, PyGSampledSubgraph, + PyGClusterSampler, PyGNeighborSampler, + PyGGraphSAINTNodeSampler, PyGGraphSAINTEdgeSampler, PyGGraphSAINTRandomWalkSampler +) + +__all__ = [ + 'PyGGraphSampler', + 'PyGHomogeneousGraphSampler', + 'PyGSampledSubgraph', + 'PyGClusterSampler', + 'PyGNeighborSampler', + 'PyGGraphSAINTNodeSampler', + 'PyGGraphSAINTEdgeSampler', + 'PyGGraphSAINTRandomWalkSampler' +] diff --git a/autogl/module/sampling/graph_sampler/_pyg/_pyg_homogeneous_graph_sampler.py b/autogl/module/sampling/graph_sampler/_pyg/_pyg_homogeneous_graph_sampler.py new file mode 100644 index 00000000..6627a860 --- /dev/null +++ b/autogl/module/sampling/graph_sampler/_pyg/_pyg_homogeneous_graph_sampler.py @@ -0,0 +1,311 @@ +import torch +import typing +import torch_geometric.loader +from .. import _graph_sampler, _sampler_utility + + +class PyGGraphSampler(_graph_sampler.GraphSampler): + def __iter__(self): + raise NotImplementedError + + +class PyGHomogeneousGraphSampler(PyGGraphSampler): + def __iter__(self): + raise NotImplementedError + + +class PyGSampledSubgraph(_graph_sampler.SampledSubgraph): + @property + def data(self) -> torch_geometric.data.Data: + raise NotImplementedError + + +class _PyGSampledHomogeneousSubgraph(PyGSampledSubgraph): + @property + def data(self) -> torch_geometric.data.Data: + return self._data + + def __init__(self, data: torch_geometric.data.Data, *_args, **_kwargs): + if not isinstance(data, torch_geometric.data.Data): + raise TypeError + self._data: torch_geometric.data.Data = data + + +class _PyGHomogeneousGraphSamplerIterator(typing.Iterator): + def __init__( + self, iterable: typing.Iterable[torch_geometric.data.Data], + transform: typing.Optional[typing.Callable[[torch_geometric.data.Data], typing.Any]] = ... + ): + self.__iterator: typing.Iterator[torch_geometric.data.Data] = iter(iterable) + self._transform: typing.Optional[typing.Callable[[torch_geometric.data.Data], typing.Any]] = ( + transform if transform is not None and transform is not Ellipsis and callable(transform) else None + ) + + def __iter__(self) -> '_PyGHomogeneousGraphSamplerIterator': + return self + + def __next__(self): + __data: torch_geometric.data.Data = next(self.__iterator) + return self._transform(__data) if self._transform is not None and callable(self._transform) else __data + + +@_graph_sampler.GraphSamplerUniversalRegistry.register_graph_sampler('neighbor_sampler') +@_graph_sampler.GraphSamplerUniversalRegistry.register_graph_sampler('pyg_neighbor_sampler') +class PyGNeighborSampler(PyGHomogeneousGraphSampler): + def __init__( + self, data: torch_geometric.data.Data, + sampler_configurations: typing.Mapping[str, typing.Any], **kwargs + ): + super(PyGNeighborSampler, self).__init__() + __filtered_configurations, remaining_configurations = _sampler_utility.ConfigurationsFilter( + ( + ( + ('num_neighbors', 'sizes', 'FanOuts'.lower()), + lambda num_neighbors: isinstance(num_neighbors, typing.Iterable) and all( + (isinstance(_num_neighbors, int) and (_num_neighbors == -1 or _num_neighbors > 0)) + for _num_neighbors in num_neighbors + ), + lambda num_neighbors: list(num_neighbors), + None, f"specified num_neighbors/sizes/{'FanOuts'.lower()} argument must be list of integer" + ), + ( + ('input_nodes', 'node_idx', 'target_nodes'), + lambda input_nodes: input_nodes is None or isinstance(input_nodes, torch.Tensor), None, + ..., "specified input_nodes/node_idx/target_nodes argument must be either None or Tensor" + ), + (('replace',), ..., lambda replace: bool(replace), ..., None), + (('directed',), ..., lambda directed: bool(directed), ..., None), + ( + ('batch_size',), lambda batch_size: isinstance(batch_size, int) and batch_size > 0, + lambda batch_size: int(batch_size), ..., None + ), + (('shuffle',), ..., lambda shuffle: bool(shuffle), ..., None), + ( + ('transform',), + lambda _transform: _transform is None or _transform is Ellipsis or callable(_transform), + lambda _transform: _transform if callable(_transform) else None, + ..., 'specified transform argument must be either None or callable transform function' + ) + ) + ).filter({**sampler_configurations, **kwargs}) + _filtered_configurations: typing.MutableMapping[str, typing.Any] = dict(__filtered_configurations) + _transform: typing.Optional[typing.Callable[[torch_geometric.data.Data], torch_geometric.data.Data]] = ( + _filtered_configurations.pop('transform', None) + ) + + def transform(__data: torch_geometric.data.Data) -> torch_geometric.data.Data: + if not hasattr(__data, 'batch_size'): + raise ValueError + if not isinstance(__data.batch_size, int) and __data.batch_size > 0: + raise ValueError + __data.target_nodes_index = torch.arange(0, __data.batch_size, device=__data.edge_index.device) + return _transform(__data) if _transform is not None and callable(_transform) else __data + + self._neighbor_loader: torch_geometric.loader.NeighborLoader = torch_geometric.loader.NeighborLoader( + data, **{**_filtered_configurations, **remaining_configurations}, transform=transform + ) + + def __iter__(self) -> typing.Iterator[_PyGSampledHomogeneousSubgraph]: + return _PyGHomogeneousGraphSamplerIterator( + self._neighbor_loader, lambda data: _PyGSampledHomogeneousSubgraph(data) + ) + + +@_graph_sampler.GraphSamplerUniversalRegistry.register_graph_sampler('graph_saint_node_sampler') +@_graph_sampler.GraphSamplerUniversalRegistry.register_graph_sampler('pyg_graph_saint_node_sampler') +class PyGGraphSAINTNodeSampler(PyGHomogeneousGraphSampler): + def __init__( + self, data: torch_geometric.data.Data, + sampler_configurations: typing.Mapping[str, typing.Any], **kwargs + ): + super(PyGGraphSAINTNodeSampler, self).__init__() + _filtered_configurations, _remaining_configurations = _sampler_utility.ConfigurationsFilter( + ( + ( + ('batch_size',), lambda batch_size: isinstance(batch_size, int) and batch_size > 0, ..., None, + "specified batch_size argument MUST be a positive integer " + "representing the approximate number of samples per batch" + ), + ( + ('num_steps', 'num_iterations'), lambda num_steps: isinstance(num_steps, int) and num_steps > 0, + ..., ..., + "specified num_steps/num_iterations argument MUST be a positive integer " + "representing the number of iterations per epoch" + ), + ( + ('sample_coverage',), lambda sample_coverage: isinstance(sample_coverage, int) and sample_coverage >= 0, + ..., ..., + "specified sample_coverage argument MUST be a non-negative argument " + "representing the coverage factor should be used to compute normalization statistics" + ), + ( + ('save_dir',), lambda save_dir: save_dir in (Ellipsis, None) or isinstance(save_dir, str), + lambda save_dir: save_dir if isinstance(save_dir, str) else None, ..., + 'specified save_dir argument must be None or str representing the path of directory ' + 'to save the normalization statistics for faster re-use' + ), + ( + ('log',), lambda _log: isinstance(_log, bool), lambda _log: bool(_log), ..., + "specified log argument MUST be a bool representing whether logging any pre-processing progress" + ) + ) + ).filter({**sampler_configurations, **kwargs}) + self._graph_saint_sampler: torch_geometric.loader.GraphSAINTSampler = ( + torch_geometric.loader.GraphSAINTNodeSampler( + data, **{**_filtered_configurations, **_remaining_configurations} + ) + ) + + def __iter__(self) -> typing.Iterator[_PyGSampledHomogeneousSubgraph]: + return _PyGHomogeneousGraphSamplerIterator( + self._graph_saint_sampler, lambda data: _PyGSampledHomogeneousSubgraph(data) + ) + + +@_graph_sampler.GraphSamplerUniversalRegistry.register_graph_sampler('graph_saint_edge_sampler') +@_graph_sampler.GraphSamplerUniversalRegistry.register_graph_sampler('pyg_graph_saint_edge_sampler') +class PyGGraphSAINTEdgeSampler(PyGHomogeneousGraphSampler): + def __init__( + self, data: torch_geometric.data.Data, + sampler_configurations: typing.Mapping[str, typing.Any], **kwargs + ): + super(PyGGraphSAINTEdgeSampler, self).__init__() + _filtered_configurations, _remaining_configurations = _sampler_utility.ConfigurationsFilter( + ( + ( + ('batch_size',), lambda batch_size: isinstance(batch_size, int) and batch_size > 0, ..., None, + "specified batch_size argument MUST be a positive integer " + "representing the approximate number of samples per batch" + ), + ( + ('num_steps', 'num_iterations'), lambda num_steps: isinstance(num_steps, int) and num_steps > 0, + ..., ..., + "specified num_steps/num_iterations argument MUST be a positive integer " + "representing the number of iterations per epoch" + ), + ( + ('sample_coverage',), lambda sample_coverage: isinstance(sample_coverage, int) and sample_coverage >= 0, + ..., ..., + "specified sample_coverage argument MUST be a non-negative argument " + "representing the coverage factor should be used to compute normalization statistics" + ), + ( + ('save_dir',), lambda save_dir: save_dir in (Ellipsis, None) or isinstance(save_dir, str), + lambda save_dir: save_dir if isinstance(save_dir, str) else None, ..., + 'specified save_dir argument must be None or str representing the path of directory ' + 'to save the normalization statistics for faster re-use' + ), + ( + ('log',), lambda _log: isinstance(_log, bool), lambda _log: bool(_log), ..., + "specified log argument MUST be a bool representing whether logging any pre-processing progress" + ) + ) + ).filter({**sampler_configurations, **kwargs}) + self._graph_saint_sampler: torch_geometric.loader.GraphSAINTSampler = ( + torch_geometric.loader.GraphSAINTEdgeSampler( + data, **{**_filtered_configurations, **_remaining_configurations} + ) + ) + + def __iter__(self) -> typing.Iterator[_PyGSampledHomogeneousSubgraph]: + return _PyGHomogeneousGraphSamplerIterator( + self._graph_saint_sampler, lambda data: _PyGSampledHomogeneousSubgraph(data) + ) + + +@_graph_sampler.GraphSamplerUniversalRegistry.register_graph_sampler('graph_saint_random_walk_sampler') +@_graph_sampler.GraphSamplerUniversalRegistry.register_graph_sampler('pyg_graph_saint_random_walk_sampler') +class PyGGraphSAINTRandomWalkSampler(PyGHomogeneousGraphSampler): + def __init__( + self, data: torch_geometric.data.Data, + sampler_configurations: typing.Mapping[str, typing.Any], **kwargs + ): + super(PyGGraphSAINTRandomWalkSampler, self).__init__() + _filtered_configurations, _remaining_configurations = _sampler_utility.ConfigurationsFilter( + ( + ( + ('batch_size',), lambda batch_size: isinstance(batch_size, int) and batch_size > 0, ..., None, + "specified batch_size argument MUST be a positive integer " + "representing the approximate number of samples per batch" + ), + ( + ('walk_length',), lambda walk_length: isinstance(walk_length, int) and walk_length > 0, ..., None, + "specified walk_length argument MUST be a positive integer " + "representing the length of each random walk" + ), + ( + ('num_steps', 'num_iterations'), lambda num_steps: isinstance(num_steps, int) and num_steps > 0, + ..., ..., + "specified num_steps/num_iterations argument MUST be a positive integer " + "representing the number of iterations per epoch" + ), + ( + ('sample_coverage',), lambda s_coverage: isinstance(s_coverage, int) and s_coverage >= 0, ..., ..., + "specified sample_coverage argument MUST be a non-negative argument " + "representing the coverage factor should be used to compute normalization statistics" + ), + ( + ('save_dir',), lambda save_dir: save_dir in (Ellipsis, None) or isinstance(save_dir, str), + lambda save_dir: save_dir if isinstance(save_dir, str) else None, ..., + 'specified save_dir argument must be None or str representing the path of directory ' + 'to save the normalization statistics for faster re-use' + ), + ( + ('log',), lambda _log: isinstance(_log, bool), lambda _log: bool(_log), ..., + "specified log argument MUST be a bool representing whether logging any pre-processing progress" + ) + ) + ).filter({**sampler_configurations, **kwargs}) + self._graph_saint_sampler: torch_geometric.loader.GraphSAINTSampler = ( + torch_geometric.loader.GraphSAINTRandomWalkSampler( + data, **{**_filtered_configurations, **_remaining_configurations} + ) + ) + + def __iter__(self) -> typing.Iterator[_PyGSampledHomogeneousSubgraph]: + return _PyGHomogeneousGraphSamplerIterator( + self._graph_saint_sampler, lambda data: _PyGSampledHomogeneousSubgraph(data) + ) + + +@_graph_sampler.GraphSamplerUniversalRegistry.register_graph_sampler('cluster_sampler') +@_graph_sampler.GraphSamplerUniversalRegistry.register_graph_sampler('pyg_cluster_sampler') +class PyGClusterSampler(PyGHomogeneousGraphSampler): + def __init__( + self, data: torch_geometric.data.Data, + sampler_configurations: typing.Mapping[str, typing.Any], **kwargs + ): + super(PyGClusterSampler, self).__init__() + _filtered_configurations, _remaining_configurations = _sampler_utility.ConfigurationsFilter( + ( + ( + ('num_parts',), + lambda num_parts: isinstance(num_parts, int) and num_parts > 0, lambda num_parts: int(num_parts), + None, 'specified num_parts argument be positive integer representing the number of partitions' + ), + ( + ('recursive',), lambda recursive: isinstance(recursive, bool), lambda recursive: bool(recursive), + ..., + 'specified recursive argument must be bool ' + 'indicating whether to use multilevel recursive bisection instead of multilevel k-way partitioning' + ), + ( + ('save_dir',), lambda save_dir: save_dir in (Ellipsis, None) or isinstance(save_dir, str), + lambda save_dir: save_dir if isinstance(save_dir, str) else None, ..., + 'specified save_dir argument must be None or str representing the path of directory ' + 'to save the partitioned data for faster re-use' + ), + ( + ('log',), lambda _log: isinstance(_log, bool), lambda _log: bool(_log), ..., + "specified log argument MUST be a bool representing whether logging any pre-processing progress" + ) + ) + ).filter({**sampler_configurations, **kwargs}) + self.__cluster_loader: torch_geometric.loader.ClusterLoader = torch_geometric.loader.ClusterLoader( + torch_geometric.loader.ClusterData(data, **_filtered_configurations), **_remaining_configurations + ) + + def __iter__(self) -> typing.Iterator[_PyGSampledHomogeneousSubgraph]: + return _PyGHomogeneousGraphSamplerIterator( + self.__cluster_loader, lambda data: _PyGSampledHomogeneousSubgraph(data) + ) diff --git a/autogl/module/sampling/graph_sampler/_sampler_utility.py b/autogl/module/sampling/graph_sampler/_sampler_utility.py new file mode 100644 index 00000000..c384e923 --- /dev/null +++ b/autogl/module/sampling/graph_sampler/_sampler_utility.py @@ -0,0 +1,78 @@ +import typing + + +class ConfigurationsFilter: + def __init__( + self, + rules: typing.Iterable[ + typing.Tuple[ + typing.Sequence[str], + typing.Optional[typing.Callable[[typing.Any], bool]], + typing.Optional[typing.Callable[[typing.Any], typing.Any]], + typing.Any, typing.Optional[str] + ] + ] + ): + if not isinstance(rules, typing.Iterable): + raise TypeError + for rule in rules: + if not (isinstance(rule, typing.Sequence) and len(rule) == 5): + raise TypeError + self.__filter_rules: typing.Iterable[ + typing.Tuple[ + typing.Sequence[str], + typing.Optional[typing.Callable[[typing.Any], bool]], + typing.Optional[typing.Callable[[typing.Any], typing.Any]], + typing.Any, typing.Optional[str] + ] + ] = rules + + def filter( + self, configurations: typing.Mapping[str, typing.Any] + ) -> typing.Tuple[typing.Mapping[str, typing.Any], typing.Mapping[str, typing.Any]]: + remaining_configurations: typing.MutableMapping[str, typing.Any] = dict(configurations) + filtered_configurations: typing.MutableMapping[str, typing.Any] = dict() + for rule in self.__filter_rules: + if len(rule[0]) == 0: + continue + _matched: bool = False + for matching_key in rule[0][::-1]: + if matching_key in remaining_configurations: + _matched = True + __configuration_item = remaining_configurations.pop(matching_key) + if rule[1] not in (Ellipsis, None) and callable(rule[1]): + if rule[1](__configuration_item): + filtered_configurations[rule[0][0]] = ( + rule[2](__configuration_item) if rule[2] is not None and callable(rule[2]) + else __configuration_item + ) + else: + filtered_configurations[rule[0][0]] = ( + rule[2](__configuration_item) if rule[2] is not None and callable(rule[2]) + else __configuration_item + ) + if _matched: + if rule[0][0] not in filtered_configurations: + if rule[4] is Ellipsis: + continue + if rule[4] is None: + raise ValueError( + f"One of the following keys {rule[0]} exists in provided configurations " + f"but none of the matched values satisfies certain requirement" + ) + if isinstance(rule[4], str) and len(rule[4].strip()) > 0: + raise ValueError( + f"One of the following keys {rule[0]} exists in provided configurations " + f"but none of the matched values satisfies certain requirement, " + f"the auxiliary information: {rule[4].strip()}" + ) + else: + if rule[3] not in (Ellipsis, None): + filtered_configurations[rule[0][0]] = ( + rule[2](rule[3]) if rule[2] is not None and callable(rule[2]) else rule[3] + ) + if rule[3] is Ellipsis: + continue + if rule[3] is None: + raise KeyError + return filtered_configurations, remaining_configurations diff --git a/test/module/sampling/test_pyg_graph_sampler/test_pyg_graph_sampler_pipeline.py b/test/module/sampling/test_pyg_graph_sampler/test_pyg_graph_sampler_pipeline.py new file mode 100644 index 00000000..41f167bb --- /dev/null +++ b/test/module/sampling/test_pyg_graph_sampler/test_pyg_graph_sampler_pipeline.py @@ -0,0 +1,131 @@ +import os +import argparse + +import tqdm + +os.environ["AUTOGL_BACKEND"] = 'pyg' + +import torch.nn.functional +import autogl.datasets.utils.conversion +import autogl.module.sampling.graph_sampler +import torch_geometric +from torch_geometric.nn import GCNConv + +mock_sampler_configurations = { + 'neighbor_sampler': { + 'num_neighbors': [25, 5], + 'batch_size': 256, + 'shuffle': True + }, + 'graph_saint_node_sampler': { + 'batch_size': 512, + 'num_steps': 10, + 'sample_coverage': 100 + }, + 'graph_saint_edge_sampler': { + 'batch_size': 512, + 'num_steps': 10, + 'sample_coverage': 100 + }, + 'graph_saint_random_walk_sampler': { + 'batch_size': 512, + 'walk_length': 2, + 'num_steps': 10, + 'sample_coverage': 100 + }, + 'cluster_sampler': { + 'num_parts': 50, + 'recursive': False, + 'batch_size': 10, + 'shuffle': True, + 'num_workers': 8 + } +} + + +class GNN(torch.nn.Module): + def __init__(self, input_dimension: int, output_dimension: int): + super(GNN, self).__init__() + self._gcn = GCNConv(input_dimension, output_dimension) + + def forward(self, data: torch_geometric.data.Data) -> torch.Tensor: + """ This model is a trivial 1-layer GCN """ + return torch.log_softmax(self._gcn.forward(data.x, data.edge_index, data.edge_weight), dim=-1) + + +__device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +if __name__ == '__main__': + argument_parser = argparse.ArgumentParser() + argument_parser.add_argument( + '--sampler', default='cluster_sampler', + choices=[ + 'neighbor_sampler', + 'graph_saint_node_sampler', + 'graph_saint_edge_sampler', + 'graph_saint_random_walk_sampler', + 'cluster_sampler' + ] + ) + arguments = argument_parser.parse_args() + + sampler_name = arguments.sampler + sampler_configurations = mock_sampler_configurations[sampler_name] + + cora_dataset = autogl.datasets.utils.conversion.to_pyg_dataset( + autogl.datasets.build_dataset_from_name('cora') + ) + cora_data = cora_dataset[0] + row, col = cora_data.edge_index + ''' Normalized edge_weight by in-degree ''' + cora_data.edge_weight = 1. / torch_geometric.utils.degree(col, cora_data.num_nodes)[col] + + gnn_model = GNN(cora_data.x.size(1), int(cora_data.y.max()) + 1).to(__device) + + graph_sampler = autogl.module.sampling.graph_sampler.instantiate_graph_sampler( + sampler_name, cora_data, sampler_configurations + ) + + optimizer = torch.optim.Adam(gnn_model.parameters(), lr=0.001) + + ''' train ''' + for _epoch in tqdm.tqdm(range(10)): + for sampled_subgraph in graph_sampler: + sampled_batch_data = sampled_subgraph.data + assert isinstance(sampled_batch_data, torch_geometric.data.Data) + sampled_batch_data.to(__device) + + if ( + hasattr(sampled_batch_data, 'node_norm') and + hasattr(sampled_batch_data, 'edge_norm') and + isinstance(sampled_batch_data.node_norm, torch.Tensor) and + isinstance(sampled_batch_data.edge_norm, torch.Tensor) and + torch.is_tensor(sampled_batch_data.node_norm) and + torch.is_tensor(sampled_batch_data.edge_norm) and + sampled_batch_data.node_norm.dim() == sampled_batch_data.edge_norm.dim() == 1 and + sampled_batch_data.node_norm.size(0) == sampled_batch_data.x.size(0) and + sampled_batch_data.edge_norm.size(0) == sampled_batch_data.edge_index.size(1) + ): + sampled_batch_data.edge_weight *= sampled_batch_data.edge_norm + out = gnn_model(sampled_batch_data) + loss = torch.nn.functional.nll_loss(out, sampled_batch_data.y, reduction='none') + loss = (loss * sampled_batch_data.node_norm)[sampled_batch_data.train_mask].sum() + else: + out = gnn_model(sampled_batch_data) + if ( + hasattr(sampled_batch_data, 'target_nodes_index') and + isinstance(sampled_batch_data.target_nodes_index, torch.Tensor) and + torch.is_tensor(sampled_batch_data.target_nodes_index) + ): + loss = torch.nn.functional.nll_loss( + out[sampled_batch_data.target_nodes_index], + sampled_batch_data.y[sampled_batch_data.target_nodes_index] + ) + else: + loss = torch.nn.functional.nll_loss( + out[sampled_batch_data.train_mask], + sampled_batch_data.y[sampled_batch_data.train_mask] + ) + + loss.backward() + optimizer.step()