Accessing DataPipe state with MultiProcessingReadingService #1033
Description
Hi TorchData team,
I'm wondering how to access the state of a datapipe in the multiprocessing context with DataLoader2 + MultiProcessingReadingService. When no reading service is used, we can simply access the graph via dataloader.datapipe and then read the state of the datapipe with the code shown below.
However, in the multiprocessing case, the datapipe graph is replaced with QueueWrapper instances, and I cannot find any way to communicate with the workers to access the state of the datapipe (I get the error that my StatefulIterator cannot be found on the datapipe). If I access dl2._datapipe_before_reading_service_adapt, I only get the initial state, which makes sense since there is no state sync between the main and worker processes.
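For concreteness, this is roughly what I'm seeing (get_datapipe, get_datapipe_state, and the reading service setup are from the repro below):

    # The pre-adapt graph is still reachable, but its state never changes,
    # so this prints the initial value (0) no matter how far iteration got.
    dl2 = DataLoader2(get_datapipe(), reading_service=MultiProcessingReadingService(num_workers=4))
    it = iter(dl2)
    next(it)
    print(get_datapipe_state(dl2._datapipe_before_reading_service_adapt))  # 0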
As far as I understand, this will also block state capture for proper DataLoader checkpointing when MultiProcessingReadingService is used.
Potentially, could we add a getstate communication primitive in communication.messages in order to capture the state (via getstate) of a datapipe in a worker process?
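Something like the following is what I have in mind: a hypothetical request/response pair mirroring the style of the existing classes in torchdata.dataloader2.communication.messages. The names are placeholders, not existing torchdata API:

    # Hypothetical sketch only -- these classes do not exist in torchdata today.
    class GetStateRequest:
        """Ask a worker process to capture the state of its datapipe graph."""


    class GetStateResponse:
        """Carries the captured state (e.g. from getstate) back to the main process."""

        def __init__(self, state):
            self.state = state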
We're also open to using sharding_round_robin_dispatch to keep more of the graph in the main process, but I'm a bit confused about how to use it. Do you have some sample code for the case below? My current best guess is sketched right after this paragraph.
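For reference, this is my best guess at how the pipeline construction would change (using the StatefulIterator and the dp import from the repro below); I'm not sure it actually keeps the state reachable outside the workers:

    from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES

    def get_datapipe_round_robin():
        # My (unverified) understanding: everything before
        # sharding_round_robin_dispatch runs in a single dispatching process,
        # and items are handed out to the workers round-robin after it.
        initial_data = dp.iter.IterableWrapper([1, 2, 3, 4])
        stateful_data = StatefulIterator(initial_data)
        return stateful_data.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)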
Running against today's master (commit a3b34a0):
import torchdata.datapipes as dp
from torch.utils.data.graph_settings import get_all_graph_pipes, traverse_dps
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService


class StatefulIterator(dp.iter.IterDataPipe):
    """Wraps a datapipe and tracks how many items have been yielded."""

    def __init__(self, datapipe):
        self.datapipe = datapipe
        self.custom_index = 0

    def __iter__(self):
        self.custom_index = 0
        for item in self.datapipe:
            self.custom_index += 1
            yield item
        self.custom_index = 0


def get_datapipe():
    initial_data = dp.iter.IterableWrapper([1, 2, 3, 4])
    stateful_data = StatefulIterator(initial_data)
    sharded_data = stateful_data.sharding_filter()
    return sharded_data


def get_datapipe_state(datapipe):
    # Walk the graph and return the state of the first StatefulIterator found.
    graph = traverse_dps(datapipe)
    all_pipes = get_all_graph_pipes(graph)
    for pipe in all_pipes:
        if hasattr(pipe, "custom_index"):
            return pipe.custom_index
    raise ValueError("This datapipe does not contain a StatefulIterator.")


def main_no_multiprocessing():
    datapipe = get_datapipe()
    dl2 = DataLoader2(datapipe)
    for item in dl2:
        print("Custom index", get_datapipe_state(dl2.datapipe))
        print("Item", item)


def main_multiprocessing():
    datapipe = get_datapipe()
    dl2 = DataLoader2(datapipe, reading_service=MultiProcessingReadingService(num_workers=4))
    for item in dl2:
        # Fails: after the reading service adapts the graph, it only
        # contains QueueWrapper pipes, so the ValueError above is raised.
        print("Custom index", get_datapipe_state(dl2.datapipe))
        print("Item", item)


if __name__ == "__main__":
    main_no_multiprocessing()
    main_multiprocessing()