GraphAnnDataModule with masking option for spatial_node_loader #60

@FrancescaDr

Description of feature

When the GraphAnnDataModule is initialised, it creates a graph_loader and a spatial_node_loader. Currently, both load the entire dataset in batches. Self-supervised learning strategies require masking some nodes during training. Because the datasets are imbalanced, the masking (or sampling) should take the proportions of, e.g., different cell types into account. I made an initial draft of a spatial node loader that adds a `.mask` attribute to each PyG data object.

import random
from typing import List

import torch
from torch_geometric.data.data import BaseData
from torch_geometric.loader import DataLoader


class GraphAnnDataModule:  # excerpt: only the methods relevant to masking are shown

    def smallest_data_batch_length(self, data_list: List[BaseData]) -> int:
        """Returns the number of nodes in the smallest graph from the list of BaseData."""
        return min(data.num_nodes for data in data_list)

    def _spatial_node_loader(self,
                             data_list: List[BaseData],
                             shuffle: bool = False,
                             **kwargs) -> DataLoader:
        """Adds a random node mask to each Data object. TODO: load each graph multiple times with a different mask.

        Args:
        ----
        data_list (List[BaseData]): the PyTorch Geometric graphs to load
        shuffle (bool, optional): whether to shuffle the data. Defaults to False.
        kwargs: arguments passed to the pyg DataLoader

        Returns
        -------
            DataLoader: the node dataloader
        """
        # The number of nodes to mask is derived from the smallest graph;
        # at least one node must always be masked.
        smallest_length = self.smallest_data_batch_length(data_list)
        num_nodes_to_mask = max(1, int(smallest_length * self.pct_mask_nodes))

        for data in data_list:
            if data.num_nodes < num_nodes_to_mask:
                raise ValueError("Cannot sample more nodes than available in any graph.")

            # Randomly select the nodes to mask (sampling without replacement)
            mask_indices = random.sample(range(data.num_nodes), num_nodes_to_mask)
            data.mask = torch.zeros(data.num_nodes, dtype=torch.bool)
            data.mask[mask_indices] = True

        return DataLoader(
            dataset=data_list,
            shuffle=shuffle,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            # collate_fn=collate_fn,
            **kwargs,
        )
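The draft derives num_nodes_to_mask from the smallest graph, so every graph gets the same absolute number of masked nodes. A minimal alternative is to apply pct_mask_nodes to each graph's own size. The sketch below is written in plain Python (boolean lists stand in for the torch mask tensors), and `per_graph_masks` is a hypothetical helper, not part of the existing module.

```python
import random
from typing import List, Optional


def per_graph_masks(num_nodes_per_graph: List[int],
                    pct_mask_nodes: float,
                    seed: Optional[int] = 0) -> List[List[bool]]:
    """Hypothetical helper: build one boolean mask per graph.

    The number of masked nodes scales with each graph's own size, is at least
    one, and never exceeds the number of nodes in that graph.
    """
    rng = random.Random(seed)
    masks = []
    for num_nodes in num_nodes_per_graph:
        # Scale by this graph's size instead of the smallest graph's size
        k = min(num_nodes, max(1, int(num_nodes * pct_mask_nodes)))
        mask = [False] * num_nodes
        for i in rng.sample(range(num_nodes), k):  # without replacement
            mask[i] = True
        masks.append(mask)
    return masks


masks = per_graph_masks([3, 10, 1200], pct_mask_nodes=0.25)
print([sum(m) for m in masks])  # [1, 2, 300]
```

In the draft, this corresponds to computing num_nodes_to_mask inside the `for data in data_list` loop from data.num_nodes rather than from smallest_length, which would also make the ValueError check unnecessary.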

Some challenges:

  1. My current draft does not take the proportions of, e.g., different cell types into account, so some cell types are never sampled and therefore never predicted. I can think of two solutions: (1) stratify the mask by a column in .obs, or (2) remember which cells have already been masked, so that the next mask covers cells that have not been sampled yet. Personally, I prefer the second option because then all cells are eventually used for prediction and learning.
  2. Graphs have different sizes. Currently I handle this by using the smallest graph as the reference for how many nodes to mask in every graph. The problem is that this may mask only a single node per graph, even though other graphs in the same batch have more than 1,000 nodes.
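Option 2 from the first point could be sketched as a small stateful helper. `CoverageMasker` below is hypothetical and written in plain Python (boolean lists instead of torch tensors). It assumes each graph has a stable id, such as its position in data_list, and a fixed number of nodes. For each graph it keeps a shuffled queue of nodes that have not been masked yet, so every node is masked once before any node is masked again; the mask size also scales with each graph's own size, which touches on the second point.

```python
import random
from typing import Dict, List, Optional


class CoverageMasker:
    """Hypothetical helper: remember which nodes were already masked per graph."""

    def __init__(self, pct_mask_nodes: float, seed: Optional[int] = None):
        self.pct_mask_nodes = pct_mask_nodes
        self.rng = random.Random(seed)
        # graph id -> shuffled node indices not yet masked in the current round
        self._pending: Dict[int, List[int]] = {}

    def next_mask(self, graph_id: int, num_nodes: int) -> List[bool]:
        k = min(num_nodes, max(1, int(num_nodes * self.pct_mask_nodes)))
        chosen: List[int] = []
        while len(chosen) < k:
            pending = self._pending.get(graph_id)
            if not pending:
                # Every node has been masked once: start a new round, leaving out
                # nodes already picked in this call so the mask has k distinct nodes
                taken = set(chosen)
                pending = [i for i in range(num_nodes) if i not in taken]
                self.rng.shuffle(pending)
                self._pending[graph_id] = pending
            chosen.append(pending.pop())
        mask = [False] * num_nodes
        for i in chosen:
            mask[i] = True
        return mask


masker = CoverageMasker(pct_mask_nodes=0.5, seed=0)
masks = [masker.next_mask(graph_id=0, num_nodes=7) for _ in range(3)]
print([sum(m) for m in masks])  # [3, 3, 3]
```

With 7 nodes and pct_mask_nodes=0.5, each call masks 3 nodes and all 7 nodes are covered after three calls. For this to help during training, the masker would have to live on the data module and the masks would have to be regenerated each epoch (the TODO in the draft), e.g. by setting `data.mask = torch.tensor(masker.next_mask(idx, data.num_nodes))`. Note that it guarantees coverage of all cells over time, not cell-type-balanced masks at each step.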

Metadata

- Assignees: no one assigned
- Labels: enhancement (New feature or request)
- Type: no type
- Projects: no projects
- Milestone: no milestone
- Relationships: none yet
- Development: no branches or pull requests

    Issue actions