Allow MpDeviceLoader to shard dictionaries of tensors #8202

Merged · 4 commits · Oct 2, 2024
18 changes: 18 additions & 0 deletions docs/spmd_advanced.md
@@ -14,6 +14,24 @@ train_loader = pl.MpDeviceLoader(
input_sharding=xs.ShardingSpec(input_mesh, ('data', None, None, None)))
```

If the batch is a dictionary of tensors with different shapes, `input_sharding` can also be a dictionary that maps each key to its own `ShardingSpec`:

```python
# Suppose batch = next(train_loader) looks like
# {'x': <tensor of shape [s1, s2, s3, s4]>, 'y': <tensor of shape [s1, s2]>}

# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
train_loader, # wraps PyTorch DataLoader
device,
# specify different sharding for each input of the batch.
input_sharding={
'x': xs.ShardingSpec(input_mesh, ('data', None, None, None)),
'y': xs.ShardingSpec(input_mesh, ('data', None))
}
)
```
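
As a quick sanity check (a minimal sketch, not part of the documented API), the sharding annotation applied to each element can be inspected with the internal helper that the new test below also uses:

```python
import torch_xla

batch = next(iter(train_loader))
# Each element carries the annotation from its own ShardingSpec.
print(torch_xla._XLAC._get_xla_sharding_spec(batch['x']))
print(torch_xla._XLAC._get_xla_sharding_spec(batch['y']))
```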

### Virtual Device Optimization

PyTorch/XLA normally transfers tensor data asynchronously from host to device as soon as the tensor is defined, so that the transfer overlaps with graph tracing. However, because GSPMD allows the user to modify the tensor sharding _after_ the tensor has been defined, we need an optimization that avoids transferring tensor data back and forth between host and device unnecessarily. Virtual Device Optimization is a technique that first places the tensor data on a virtual device, SPMD:0, and only uploads it to the physical devices once all sharding decisions are finalized. In SPMD mode, every tensor's data is placed on this virtual device, which is exposed to the user as the XLA device XLA:0, with the actual shards residing on physical devices such as TPU:0, TPU:1, etc.
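
For illustration, a minimal sketch (assuming an SPMD-enabled multi-device runtime; `mark_sharding` is the standard SPMD annotation API, not part of this PR): a tensor reports the virtual `xla:0` device even after its sharding is assigned, while the shards live on the physical devices.

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla import runtime as xr

xr.use_spmd()

mesh = xs.get_1d_mesh('data')
t = torch.randn(16, 128).to(xm.xla_device())
print(t.device)  # xla:0 -- the virtual SPMD device

# The sharding can still be changed after the tensor is defined; data is only
# uploaded to the physical devices once the sharding decision is finalized.
xs.mark_sharding(t, mesh, ('data', None))
print(t.device)  # still xla:0; the actual shards live on TPU:0, TPU:1, ...
```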
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -245,6 +245,7 @@ function run_xla_op_tests3 {
run_test "$CDIR/spmd/test_dtensor_integration2.py"
run_test "$CDIR/spmd/test_xla_auto_sharding.py"
run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py"
run_test "$CDIR/spmd/test_mp_input_sharding.py"
run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
run_test "$CDIR/test_input_output_aliases.py"
run_test "$CDIR/test_torch_distributed_xla_backend.py"
151 changes: 151 additions & 0 deletions test/spmd/test_mp_input_sharding.py
@@ -0,0 +1,151 @@
import sys
import numpy as np
import unittest

import torch
import torch_xla
from torch_xla import runtime as xr
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd import Mesh
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.parallel_loader as pl

xr.use_spmd()


class MpInputShardingTest(unittest.TestCase):

class fake_dataloader:

def __init__(self, batch, size=1):
self.batch = batch
self.batch_size = size
self.counter = 0

def __iter__(self):
return self

def __next__(self):
if self.counter < self.batch_size:
self.counter += 1
return self.batch
raise StopIteration

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required for tupled partition spec")
def test_multiple_inputs(self):
device = xm.xla_device()
batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))}
train_loader = self.fake_dataloader(batch)
num_devices = xr.global_runtime_device_count()
mesh = xs.get_1d_mesh('x')

train_loader = pl.MpDeviceLoader(
train_loader,
device,
input_sharding={
'x': xs.ShardingSpec(mesh, ('x', None)),
'y': xs.ShardingSpec(mesh, ('x', None, None))
})
train_loader = iter(train_loader)
data = next(train_loader)
annotation_x = '{devices=[%d,1]%s}' % (num_devices, ','.join(
[str(i) for i in range(num_devices)]))
annotation_y = '{devices=[%d,1,1]%s}' % (num_devices, ','.join(
[str(i) for i in range(num_devices)]))
self.assertEqual(annotation_x,
torch_xla._XLAC._get_xla_sharding_spec(data['x']))
self.assertEqual(annotation_y,
torch_xla._XLAC._get_xla_sharding_spec(data['y']))

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required for tupled partition spec")
def test_single_tensor(self):
device = xm.xla_device()
batch = torch.randn((16, 128))
train_loader = self.fake_dataloader(batch)
num_devices = xr.global_runtime_device_count()
mesh = xs.get_1d_mesh('x')

train_loader = pl.MpDeviceLoader(
train_loader, device, input_sharding=xs.ShardingSpec(mesh, ('x', None)))
train_loader = iter(train_loader)
data = next(train_loader)
annotation = '{devices=[%d,1]%s}' % (num_devices, ','.join(
[str(i) for i in range(num_devices)]))
self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(data))

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required for tupled partition spec")
def test_error_single_tensor_with_input_sharding_dict(self):
device = xm.xla_device()
batch = torch.randn((16, 128))
train_loader = self.fake_dataloader(batch)
num_devices = xr.global_runtime_device_count()
mesh = xs.get_1d_mesh('x')

train_loader = pl.MpDeviceLoader(
train_loader,
device,
input_sharding={'x': xs.ShardingSpec(mesh, ('x', None))})
train_loader = iter(train_loader)
with self.assertRaises(ValueError):
data = next(train_loader)

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required for tupled partition spec")
def test_input_sharding_none(self):
device = xm.xla_device()
batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))}
train_loader = self.fake_dataloader(batch)
num_devices = xr.global_runtime_device_count()

train_loader = pl.MpDeviceLoader(train_loader, device, input_sharding=None)
train_loader = iter(train_loader)
data = next(train_loader)
annotation = '{replicated}'
self.assertEqual(annotation,
torch_xla._XLAC._get_xla_sharding_spec(data['x']))
self.assertEqual(annotation,
torch_xla._XLAC._get_xla_sharding_spec(data['y']))

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required for tupled partition spec")
def test_error_missing_keys(self):
device = xm.xla_device()
batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))}
train_loader = self.fake_dataloader(batch)
mesh = xs.get_1d_mesh('x')
train_loader = pl.MpDeviceLoader(
train_loader,
device,
input_sharding={'x': xs.ShardingSpec(mesh, ('x', None))})
train_loader = iter(train_loader)
with self.assertRaises(KeyError):
data = next(train_loader)

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required for tupled partition spec")
def test_input_sharding_not_dict(self):
device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128))}
train_loader = self.fake_dataloader(batch)
mesh = xs.get_1d_mesh('x')
train_loader = pl.MpDeviceLoader(
train_loader, device, input_sharding=xs.ShardingSpec(mesh, ('x', None)))
train_loader = iter(train_loader)
data = next(train_loader)
annotation_x = '{devices=[%d,1]%s}' % (num_devices, ','.join(
[str(i) for i in range(num_devices)]))
annotation_y = '{devices=[%d,1]%s}' % (num_devices, ','.join(
[str(i) for i in range(num_devices)]))
self.assertEqual(annotation_x,
torch_xla._XLAC._get_xla_sharding_spec(data['x']))
self.assertEqual(annotation_y,
torch_xla._XLAC._get_xla_sharding_spec(data['y']))


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
74 changes: 59 additions & 15 deletions torch_xla/distributed/parallel_loader.py
@@ -12,7 +12,7 @@ class PerDeviceQueue(object):

def __init__(self, device, loader_prefetch_size, device_prefetch_size):
self.device = device
self.loader_queue = kq.Queue(maxsize=loader_prefetch_size)
self.cpu_loader_queue = kq.Queue(maxsize=loader_prefetch_size)
self.queue = kq.Queue(maxsize=device_prefetch_size)
self.close_queue_count = itertools.count()

@@ -46,6 +46,8 @@ def next(self):
self._batches_yielded += 1

item = self._loader.next_item(self._device)
if isinstance(item, Exception):
raise item
if item is None:
xm.mark_step()
raise StopIteration
@@ -56,7 +58,7 @@ class ParallelLoader(object):
"""Wraps an existing PyTorch DataLoader with background data upload.

Args:
loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
cpu_loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
wrapped.
devices (`torch.device`...): The list of devices where the data has to be
sent. The i-th sample returned by the `loader` will be sent to `devices[i
@@ -74,21 +76,20 @@ class ParallelLoader(object):
host_to_device_transfer_threads (int, optional): The number of threads that
work in parallel to transfer data from loader queue to device queue.
Default: 1
input_sharding (ShardingSpec, optional): Sharding spec to apply to
compatible input tensors after loading.
Default: None
input_sharding (ShardingSpec, Dict(str, ShardingSpec), optional): Sharding
spec to apply to compatible input tensors after loading.
"""

def __init__(self,
loader,
cpu_loader,
devices,
batchdim=0,
batches_per_execution=1,
loader_prefetch_size=16,
device_prefetch_size=8,
host_to_device_transfer_threads=1,
input_sharding=None):
self._loader = loader
self._cpu_loader = cpu_loader
self._devices = [torch.device(x) for x in devices]
self._batchdim = batchdim
self._batches_per_execution = batches_per_execution
@@ -140,7 +141,7 @@ def close(self):
self._done = True
for dqueue in self._queues.values():
dqueue.queue.close()
dqueue.loader_queue.close()
dqueue.cpu_loader_queue.close()

for thread in self._threads:
thread.join()
@@ -151,7 +152,7 @@ def batches_per_execution(self):

def _loader_worker(self):
queues = list(self._queues.values())
data_iter = enumerate(self._loader)
data_iter = enumerate(self._cpu_loader)
batch = []

try:
@@ -163,21 +164,65 @@ def _loader_worker(self):
batch.append(data)
if len(batch) == len(self._devices):
for queue_no, device_batch in enumerate(batch):
queues[queue_no].loader_queue.put(device_batch)
queues[queue_no].cpu_loader_queue.put(device_batch)
batch = []
finally:
for dqueue in queues:
dqueue.loader_queue.close_write()
dqueue.cpu_loader_queue.close_write()

def _get_batch(self, dqueue):
batch = []
while dqueue.queue.max_size() > len(batch):
item = dqueue.loader_queue.get()
while len(batch) < dqueue.queue.max_size():
item = dqueue.cpu_loader_queue.get()
if item is None:
break
batch.append(item)
return batch

def send_cpu_data_to_device(self, batches, device):
"""Move batch to device.
Args:
batch -> List(torch.Tensor), List(Dict(str: torch.Tensor)): Input batch
present in the cpu memory
device: TPU device where the batch should be moved

Returns:
result -> List(torch.Tensor), Dict(str: torch.Tensor): Returns a dict if the
input batch is a dict. Otherwise, returns a list of torch.Tensor.
"""
result = None
if isinstance(self._input_sharding, dict):
if not isinstance(batches[0], dict):
return [
ValueError(
f"input batch should be a dict when input sharding is a dict.")
]
result = []
for batch in batches:
xla_batch = {}
missing_keys = []
for key, tensor in batch.items():
assert type(tensor) == torch.Tensor
sharding_spec = None
if self._input_sharding:
if key not in self._input_sharding:
missing_keys.append(key)
continue
sharding_spec = self._input_sharding[key]

# xla_tensor is a list of tensors.
xla_tensor = xm.send_cpu_data_to_device(tensor, device, sharding_spec)
xla_batch[key] = xla_tensor[0]
if len(missing_keys) != 0:
          # Return the exception rather than raising it here: an exception
          # raised in the data-loading thread would not surface in the main thread.
return [
KeyError(f"Keys: {missing_keys} are missing from input_sharding.")
]
result.append(xla_batch)
else:
result = xm.send_cpu_data_to_device(batches, device, self._input_sharding)
return result

def _worker(self, dqueue, host_to_device_transfer_threads):
device = torch.device(dqueue.device)

@@ -187,8 +232,7 @@ def _worker(self, dqueue, host_to_device_transfer_threads):
if not batch:
break
with torch.no_grad():
batch = xm.send_cpu_data_to_device(batch, device,
self._input_sharding)
batch = self.send_cpu_data_to_device(batch, device)
for data in batch:
dqueue.queue.put(data)
finally: