🚀 The feature
A torchdata.nodes.Stack node that would use multiple torch.utils.data.Sampler and maybe torchdata.nodes.Batcher instances to generate independent batches from the same dataset and stack them.
A naive implementation (still working on the actual code):
```python
from typing import Any, Dict, Optional

import torch
from torch import Tensor

from torchdata.nodes import BaseNode


class Stack(BaseNode[Tensor]):
    """Draws one item from each source node and stacks them along a new leading dim."""

    def __init__(self, *sources: BaseNode[Tensor]):
        super().__init__()
        self.sources = sources

    def next(self) -> Tensor:
        # StopIteration from any exhausted source propagates and ends iteration.
        return torch.stack([next(node) for node in self.sources])

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        super().reset(initial_state)
        if initial_state is None:
            for node in self.sources:
                node.reset()
        else:
            for inode, node in enumerate(self.sources):
                node.reset(initial_state[f"source_{inode}"])

    def get_state(self) -> Dict[str, Any]:
        return {f"source_{inode}": node.state_dict() for inode, node in enumerate(self.sources)}
```

A very minimal example usage:
```python
from collections.abc import Sequence
from typing import Callable, Optional

import torch
from torch.utils.data import RandomSampler

from torchdata.nodes import Batcher, Header, Loader, ParallelMapper, SamplerWrapper


def stacked_batch_loader(
    dataset: Sequence,
    ncopies: int,
    batch_size: int,
    collate_fn: Optional[Callable] = None,
    num_workers: int = 0,
    load_only_first: Optional[int] = None,
):
    nodes_to_stack = []
    num_workers_per_copy = num_workers // ncopies
    # Fall back to the dataset's __getitems__ for collating a list of indices.
    map_fn = collate_fn or getattr(dataset, "__getitems__", None)
    if map_fn is None:
        raise ValueError("Either give a collate_fn or define a __getitems__ to sample mini-batches!")
    for icopy in range(ncopies):
        # Each copy gets its own generator so the samplers are independent.
        g = torch.Generator()
        g.manual_seed(icopy % 2**32)
        sampler = RandomSampler(dataset, generator=g)
        node = SamplerWrapper(sampler)
        node = Batcher(node, batch_size=batch_size)
        node = ParallelMapper(
            node,
            map_fn=map_fn,
            num_workers=num_workers_per_copy,
            method="process",
            in_order=True,
        )
        nodes_to_stack.append(node)
    stacked_node = Stack(*nodes_to_stack)
    if load_only_first is not None:
        stacked_node = Header(stacked_node, load_only_first)
    return Loader(stacked_node)
```

I would appreciate any feedback on this. I can do a PR if y'all think it's useful.
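For illustration, a hypothetical end-to-end sketch of how `stacked_batch_loader` might be called; `ToyDataset`, its sizes, and the worker counts are assumptions for the example, not part of the proposal:

```python
import torch


class ToyDataset:
    """Illustrative dataset exposing a __getitems__ that collates a mini-batch."""

    def __init__(self, n: int = 10000, dim: int = 8):
        self.x = torch.randn(n, dim)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        return self.x[i]

    def __getitems__(self, indices):
        # Fancy indexing with a list of indices returns a stacked [len(indices), dim] tensor.
        return self.x[list(indices)]


if __name__ == "__main__":  # guard needed since method="process" spawns worker processes
    # num_workers=10 so each of the 10 copies gets one worker.
    loader = stacked_batch_loader(ToyDataset(), ncopies=10, batch_size=32, num_workers=10)
    for stacked_batch in loader:
        print(stacked_batch.shape)  # expected: torch.Size([10, 32, 8])
        break
```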
Motivation, pitch
I am trying to do ensemble training with independent instances of a model, stacked using torch.func.stack_module_state, using torch.func.vmap to vectorize the forward pass/gradient computation over bootstrap-sampled, stacked batches of data, similar to this tutorial.
For my use case, if my main dataset has 10,000 samples, I might use 10 sub-datasets of 5,000 samples each, drawn with different random seeds, to train 10 instances of my model. I then use the mean and standard deviation of the 10 model predictions to quantify uncertainty in different regions of my feature space.
An iterator that gives stacked batches of different subsamples of data would be convenient for this use case.
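For context, a minimal sketch of how the stacked batches would be consumed, following the torch.func ensembling pattern; TinyNet and the tensor shapes are hypothetical stand-ins:

```python
import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap


class TinyNet(nn.Module):
    """Hypothetical toy model standing in for the real one."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 1)

    def forward(self, x):
        return self.fc(x)


ncopies, batch_size = 10, 4
models = [TinyNet() for _ in range(ncopies)]
params, buffers = stack_module_state(models)  # each leaf gains a leading ncopies dim
base = TinyNet().to("meta")  # stateless skeleton module, as in the ensembling tutorial


def fmodel(p, b, x):
    return functional_call(base, (p, b), (x,))


# `stacked_batch` is what the proposed Stack node would yield:
# one leading dimension per model copy.
stacked_batch = torch.randn(ncopies, batch_size, 8)
preds = vmap(fmodel)(params, buffers, stacked_batch)  # [ncopies, batch_size, 1]
# Mean/std over the copy dimension give the ensemble prediction and its uncertainty.
mean, std = preds.mean(dim=0), preds.std(dim=0)
```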
Alternatives
No response
Additional context
Sorry if I am mixing up any terminology; I am not a software engineer strictly speaking, I just use torch for ML applications in my research.