
Accessing DataPipe state with MultiProcessingReadingService #1033

Open
@jhoareau

Description

Hi TorchData team,

I'm wondering how to access the state of a datapipe in the multi-processing context with DataLoader2 + MultiProcessingReadingService. With no reading service, we can access the graph via dataloader.datapipe and then easily read the state of the datapipe using the code shown below.

However, in the multi-processing case, the datapipe graph is replaced with QueueWrapper instances, and I cannot find any way to communicate with the workers to retrieve the state of the datapipe (I get an error that my StatefulIterator cannot be found in the graph). If I access dl2._datapipe_before_reading_service_adapt I only get the initial state, which makes sense since there is no state sync between the main and worker processes.

As far as I understand, this will also block state capture for proper DataLoader checkpointing when MultiProcessingReadingService is used.

Potentially, could we add a getstate communication primitive in communication.messages so that the state of a datapipe living in a worker process can be captured (via getstate)?
We're also open to using sharding_round_robin_dispatch to keep more state in the main process, but I'm unsure how to use it. Could you share sample code for the case below?
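To make the proposal concrete, here is a rough sketch of what such a message pair might look like. GetStateRequest/GetStateResponse and the Request/Response base classes here are placeholder names styled after the existing request/response pairs in communication.messages; none of this is real torchdata API today:

```python
# Hypothetical sketch only: these classes do not exist in torchdata.
# They imitate the Request/Response message pairs in
# torchdata.dataloader2.communication.messages, assuming the worker loop
# would answer a GetStateRequest by serializing its local datapipe graph.

class Request:
    """Base class for messages sent from the main process to a worker."""


class Response:
    """Base class for messages sent back from a worker."""


class GetStateRequest(Request):
    """Ask a worker to capture the state of its local datapipe graph."""


class GetStateResponse(Response):
    def __init__(self, state):
        # `state` would be whatever the worker gathered, e.g. the result
        # of calling getstate on each pipe in its local graph
        self.state = state


# How the main process might consume such a response (hypothetical):
resp = GetStateResponse({"custom_index": 3})
print(resp.state["custom_index"])  # 3
```

The worker loop would then need one extra branch that matches GetStateRequest and replies with a GetStateResponse, mirroring how the existing request types are handled.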
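For reference, here is my understanding of what sharding_round_robin_dispatch would buy us, as a toy single-process sketch. This is not the torchdata API, just an illustration of the semantics: the pipes before the dispatch point run in a single dispatching process, so any state they hold (like our counter) stays in one place instead of being replicated across workers.

```python
# Toy illustration of round-robin dispatch semantics (not torchdata API):
# a single dispatcher owns the stateful part of the pipeline, and deals
# items out to N "workers" in turn, so the counter lives in one place.

from collections import deque


def round_robin_dispatch(source, num_workers):
    """Deal items from `source` to per-worker queues in round-robin order."""
    queues = [deque() for _ in range(num_workers)]
    state = {"custom_index": 0}  # held only by the dispatcher
    for i, item in enumerate(source):
        state["custom_index"] += 1
        queues[i % num_workers].append(item)
    return queues, state


queues, state = round_robin_dispatch([1, 2, 3, 4], num_workers=2)
print(queues)  # [deque([1, 3]), deque([2, 4])]
print(state)   # {'custom_index': 4}
```

In torchdata terms, my understanding is that the stateful pipe would sit before the sharding_round_robin_dispatch call so it runs once in the dispatching process, with only the downstream pipes replicated into workers; a confirmation of that reading (and of how to query the dispatching process) would be very helpful.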

Running against today's master (commit a3b34a0):

import torchdata.datapipes as dp
from torch.utils.data.graph_settings import get_all_graph_pipes, traverse_dps
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService


class StatefulIterator(dp.iter.IterDataPipe):
    def __init__(self, datapipe):
        super().__init__()
        self.datapipe = datapipe
        self.custom_index = 0  # the state we want to read from the main process

    def __iter__(self):
        self.custom_index = 0
        for item in self.datapipe:
            self.custom_index += 1
            yield item
        self.custom_index = 0


def get_datapipe():
    initial_data = dp.iter.IterableWrapper([1, 2, 3, 4])
    stateful_data = StatefulIterator(initial_data)
    sharded_data = stateful_data.sharding_filter()
    return sharded_data


def get_datapipe_state(datapipe):
    graph = traverse_dps(datapipe)
    all_pipes = get_all_graph_pipes(graph)
    for pipe in all_pipes:
        if hasattr(pipe, "custom_index"):
            return pipe.custom_index

    raise ValueError("This datapipe does not contain a StatefulIterator.")


def main_no_multiprocessing():
    pipe = get_datapipe()  # named `pipe` to avoid shadowing the `dp` module alias
    dl2 = DataLoader2(pipe)
    for item in dl2:
        print("Custom index", get_datapipe_state(dl2.datapipe))
        print("Item", item)


def main_multiprocessing():
    pipe = get_datapipe()
    dl2 = DataLoader2(pipe, reading_service=MultiProcessingReadingService(num_workers=4))
    for item in dl2:
        # Fails: dl2.datapipe is now a graph of QueueWrapper instances,
        # so the StatefulIterator cannot be found and ValueError is raised.
        print("Custom index", get_datapipe_state(dl2.datapipe))
        print("Item", item)


if __name__ == "__main__":
    main_no_multiprocessing()
    main_multiprocessing()

cc: @ejguan @VitalyFedyunin @NivekT
