Skip to content

Implementation for Pytorch-geometric dataset #43

@avarbella

Description

@avarbella

I have added a few lines that make the sampler work with PyTorch Geometric datasets. Since a PyTorch Geometric dataset is stored as a list of `Data` objects before being loaded by a PyTorch Geometric `DataLoader`, the modification is simple: read each element's `y` label.
I hope this is helpful to someone.

Best,

Anna

`from typing import Callable

import pandas as pd
import torch
import torch.utils.data
import torchvision

class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
    """Sample indices with probability inversely proportional to class frequency.

    Each candidate index is weighted by ``1 / count(label)``, where the count is
    taken over the candidate indices only, so that every class is drawn with
    roughly equal probability. Indices are drawn with replacement.

    Arguments:
        dataset: the dataset to sample from. Labels come from
            ``callback_get_label`` when given, otherwise they are inferred from
            the dataset type (see ``_get_labels``).
        indices: dataset indices to sample from; defaults to every index of
            ``dataset``.
        num_samples: number of indices yielded per iteration; defaults to
            ``len(indices)``.
        callback_get_label: a callable that takes the dataset (only) and returns
            one label per dataset item, or one label per entry of ``indices``.
    """

    def __init__(self, dataset, indices: list = None, num_samples: int = None, callback_get_label: Callable = None):
        # if indices is not provided, all elements in the dataset will be considered
        self.indices = list(range(len(dataset))) if indices is None else indices

        # define custom callback
        self.callback_get_label = callback_get_label

        # if num_samples is not provided, draw `len(indices)` samples in each iteration
        self.num_samples = len(self.indices) if num_samples is None else num_samples

        labels = self._get_labels(dataset)
        if isinstance(labels, torch.Tensor):
            # 0-d tensors hash by identity; convert to Python scalars before counting
            labels = labels.tolist()

        # Build the labels of the candidate indices in the SAME order as
        # self.indices, so that weight k belongs to self.indices[k] -- this is
        # the pairing __iter__ relies on.
        if len(labels) == len(dataset):
            # one label per dataset item: look up each candidate index
            selected = [labels[i] for i in self.indices]
        elif len(labels) == len(self.indices):
            # labels are already aligned with `indices`
            selected = list(labels)
        else:
            raise ValueError(
                f"got {len(labels)} labels, expected {len(dataset)} (one per dataset item) "
                f"or {len(self.indices)} (one per index)"
            )

        # distribution of classes among the candidate indices
        label_series = pd.Series(selected)
        label_to_count = label_series.value_counts()
        weights = 1.0 / label_series.map(label_to_count)

        self.weights = torch.tensor(weights.to_list(), dtype=torch.double)

    def _get_labels(self, dataset):
        """Return one label per item of ``dataset``.

        Raises:
            NotImplementedError: if labels cannot be inferred for this dataset
                type and no ``callback_get_label`` was given.
        """
        if self.callback_get_label is not None:
            return self.callback_get_label(dataset)
        elif isinstance(dataset, torchvision.datasets.MNIST):
            # `targets` is the current name; `train_labels` is a deprecated alias of it
            return dataset.targets.tolist()
        elif isinstance(dataset, torchvision.datasets.ImageFolder):
            return [x[1] for x in dataset.imgs]
        elif isinstance(dataset, torchvision.datasets.DatasetFolder):
            # `samples` is a list of (path, class_index) pairs
            return [target for _, target in dataset.samples]
        elif isinstance(dataset, torch.utils.data.Subset):
            # labels of the wrapped dataset, restricted to (and ordered by) the subset's indices
            parent_labels = self._get_labels(dataset.dataset)
            if isinstance(parent_labels, torch.Tensor):
                parent_labels = parent_labels.tolist()
            return [parent_labels[int(i)] for i in dataset.indices]
        elif isinstance(dataset, torch.utils.data.Dataset):
            return dataset.get_labels()
        elif isinstance(dataset, list):
            # PyTorch Geometric datasets are lists of Data objects; assumes each
            # `y` is a single-element tensor (TODO confirm for multi-label graphs)
            return [data.y.item() for data in dataset]
        else:
            raise NotImplementedError

    def __iter__(self):
        if self.num_samples == 0:
            # nothing to draw; avoid asking torch.multinomial for zero samples
            return iter(())
        picks = torch.multinomial(self.weights, self.num_samples, replacement=True)
        return (self.indices[k] for k in picks.tolist())

    def __len__(self):
        return self.num_samples

`

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions