
Commit 89722bd

Allow MpDeviceLoader to shard dictionaries of tensors with different shapes

1 parent 5dbdb8d commit 89722bd

File tree: 4 files changed, +228 -15 lines changed


docs/spmd_advanced.md

Lines changed: 18 additions & 0 deletions
@@ -14,6 +14,24 @@ train_loader = pl.MpDeviceLoader(
     input_sharding=xs.ShardingSpec(input_mesh, ('data', None, None, None)))
 ```

+It is also possible to specify a different `input_sharding` for each element of the batch if they have different shapes:
+
+```python
+# if batch = next(train_loader) looks like
+# {'x': <tensor of shape [s1, s2, s3, s4]>, 'y': <tensor of shape [s1, s2]>}
+
+# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
+train_loader = pl.MpDeviceLoader(
+    train_loader,  # wraps PyTorch DataLoader
+    device,
+    # specify a different sharding for each input of the batch.
+    input_sharding={
+        'x': xs.ShardingSpec(input_mesh, ('data', None, None, None)),
+        'y': xs.ShardingSpec(input_mesh, ('data', None))
+    }
+)
+```
+
 ### Virtual Device Optimization

 PyTorch/XLA normally transfers tensor data asynchronously from host to device once the tensor is defined, so that the data transfer overlaps with the graph tracing time. However, because GSPMD allows the user to modify the tensor sharding _after_ the tensor has been defined, we need an optimization to prevent unnecessary transfers of tensor data back and forth between host and device. We introduce Virtual Device Optimization, a technique that places the tensor data on a virtual device, SPMD:0, first and uploads it to the physical devices only once all sharding decisions are finalized. Every tensor in SPMD mode is placed on this virtual device, which is exposed to the user as an XLA device, XLA:0, while the actual shards live on physical devices such as TPU:0, TPU:1, etc.
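As a rough, standalone illustration of the behavior described above (not part of this diff; it assumes a multi-device, SPMD-capable runtime and uses the public `torch_xla` SPMD helpers `xs.get_1d_mesh` and `xs.mark_sharding`):

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

# The tensor is defined first; its data lives on the virtual SPMD device,
# which is reported to the user as the XLA device (typically "xla:0").
t = torch.randn(16, 128).to(xm.xla_device())
print(t.device)

# The sharding decision is made after the tensor already exists; the data is
# only distributed to the physical devices (TPU:0, TPU:1, ...) once finalized.
mesh = xs.get_1d_mesh('data')
xs.mark_sharding(t, mesh, ('data', None))
print(torch_xla._XLAC._get_xla_sharding_spec(t))  # e.g. {devices=[N,1]0,1,...,N-1}
```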

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -245,6 +245,7 @@ function run_xla_op_tests3 {
   run_test "$CDIR/spmd/test_dtensor_integration2.py"
   run_test "$CDIR/spmd/test_xla_auto_sharding.py"
   run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py"
+  run_test "$CDIR/spmd/test_mp_input_sharding.py"
   run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
   run_test "$CDIR/test_input_output_aliases.py"
   run_test "$CDIR/test_torch_distributed_xla_backend.py"

test/spmd/test_mp_input_sharding.py

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
import sys
import numpy as np
import unittest

import torch
import torch_xla
from torch_xla import runtime as xr
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd import Mesh
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.parallel_loader as pl

xr.use_spmd()


class MpInputShardingTest(unittest.TestCase):

  class fake_dataloader:

    def __init__(self, batch, size=1):
      self.batch = batch
      self.batch_size = size
      self.counter = 0

    def __iter__(self):
      return self

    def __next__(self):
      if self.counter < self.batch_size:
        self.counter += 1
        return self.batch
      raise StopIteration

  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                       "Multiple devices required for tupled partition spec")
  def test_multiple_inputs(self):
    device = xm.xla_device()
    batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))}
    train_loader = self.fake_dataloader(batch)
    num_devices = xr.global_runtime_device_count()
    mesh = xs.get_1d_mesh('x')

    train_loader = pl.MpDeviceLoader(
        train_loader,
        device,
        input_sharding={
            'x': xs.ShardingSpec(mesh, ('x', None)),
            'y': xs.ShardingSpec(mesh, ('x', None, None))
        })
    train_loader = iter(train_loader)
    data = next(train_loader)
    annotation_x = '{devices=[%d,1]%s}' % (num_devices, ','.join(
        [str(i) for i in range(num_devices)]))
    annotation_y = '{devices=[%d,1,1]%s}' % (num_devices, ','.join(
        [str(i) for i in range(num_devices)]))
    self.assertEqual(annotation_x,
                     torch_xla._XLAC._get_xla_sharding_spec(data['x']))
    self.assertEqual(annotation_y,
                     torch_xla._XLAC._get_xla_sharding_spec(data['y']))

  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                       "Multiple devices required for tupled partition spec")
  def test_single_tensor(self):
    device = xm.xla_device()
    batch = torch.randn((16, 128))
    train_loader = self.fake_dataloader(batch)
    num_devices = xr.global_runtime_device_count()
    mesh = xs.get_1d_mesh('x')

    train_loader = pl.MpDeviceLoader(
        train_loader, device, input_sharding=xs.ShardingSpec(mesh, ('x', None)))
    train_loader = iter(train_loader)
    data = next(train_loader)
    annotation = '{devices=[%d,1]%s}' % (num_devices, ','.join(
        [str(i) for i in range(num_devices)]))
    self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(data))

  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                       "Multiple devices required for tupled partition spec")
  def test_error_single_tensor_with_input_sharding_dict(self):
    device = xm.xla_device()
    batch = torch.randn((16, 128))
    train_loader = self.fake_dataloader(batch)
    num_devices = xr.global_runtime_device_count()
    mesh = xs.get_1d_mesh('x')

    train_loader = pl.MpDeviceLoader(
        train_loader, device, input_sharding={'x': xs.ShardingSpec(mesh, ('x', None))})
    train_loader = iter(train_loader)
    with self.assertRaises(ValueError):
      data = next(train_loader)

  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                       "Multiple devices required for tupled partition spec")
  def test_input_sharding_none(self):
    device = xm.xla_device()
    batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))}
    train_loader = self.fake_dataloader(batch)
    num_devices = xr.global_runtime_device_count()

    train_loader = pl.MpDeviceLoader(train_loader, device, input_sharding=None)
    train_loader = iter(train_loader)
    data = next(train_loader)
    annotation = '{replicated}'
    self.assertEqual(annotation,
                     torch_xla._XLAC._get_xla_sharding_spec(data['x']))
    self.assertEqual(annotation,
                     torch_xla._XLAC._get_xla_sharding_spec(data['y']))

  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                       "Multiple devices required for tupled partition spec")
  def test_error_missing_keys(self):
    device = xm.xla_device()
    batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))}
    train_loader = self.fake_dataloader(batch)
    mesh = xs.get_1d_mesh('x')
    train_loader = pl.MpDeviceLoader(
        train_loader,
        device,
        input_sharding={'x': xs.ShardingSpec(mesh, ('x', None))})
    train_loader = iter(train_loader)
    with self.assertRaises(KeyError):
      data = next(train_loader)

  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                       "Multiple devices required for tupled partition spec")
  def test_input_sharding_not_dict(self):
    device = xm.xla_device()
    num_devices = xr.global_runtime_device_count()
    batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128))}
    train_loader = self.fake_dataloader(batch)
    mesh = xs.get_1d_mesh('x')
    train_loader = pl.MpDeviceLoader(
        train_loader, device, input_sharding=xs.ShardingSpec(mesh, ('x', None)))
    train_loader = iter(train_loader)
    data = next(train_loader)
    annotation_x = '{devices=[%d,1]%s}' % (num_devices, ','.join(
        [str(i) for i in range(num_devices)]))
    annotation_y = '{devices=[%d,1]%s}' % (num_devices, ','.join(
        [str(i) for i in range(num_devices)]))
    self.assertEqual(annotation_x,
                     torch_xla._XLAC._get_xla_sharding_spec(data['x']))
    self.assertEqual(annotation_y,
                     torch_xla._XLAC._get_xla_sharding_spec(data['y']))


if __name__ == '__main__':
  test = unittest.main()
  sys.exit(0 if test.result.wasSuccessful() else 1)
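
To make the assertions above concrete, here is the annotation string they construct for a hypothetical run with four devices (plain Python string formatting, no XLA runtime required; the real device count comes from `xr.global_runtime_device_count()`):

```python
# Rebuilds the expected sharding annotations from the tests, assuming
# num_devices == 4; in a real run the count depends on the runtime.
num_devices = 4
device_ids = ','.join(str(i) for i in range(num_devices))
annotation_x = '{devices=[%d,1]%s}' % (num_devices, device_ids)    # 2-D input, batch dim sharded
annotation_y = '{devices=[%d,1,1]%s}' % (num_devices, device_ids)  # 3-D input, batch dim sharded
print(annotation_x)  # {devices=[4,1]0,1,2,3}
print(annotation_y)  # {devices=[4,1,1]0,1,2,3}
```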

torch_xla/distributed/parallel_loader.py

Lines changed: 60 additions & 15 deletions
@@ -12,7 +12,7 @@ class PerDeviceQueue(object):

   def __init__(self, device, loader_prefetch_size, device_prefetch_size):
     self.device = device
-    self.loader_queue = kq.Queue(maxsize=loader_prefetch_size)
+    self.cpu_loader_queue = kq.Queue(maxsize=loader_prefetch_size)
     self.queue = kq.Queue(maxsize=device_prefetch_size)
     self.close_queue_count = itertools.count()

@@ -46,6 +46,8 @@ def next(self):
       self._batches_yielded += 1

     item = self._loader.next_item(self._device)
+    if isinstance(item, Exception):
+      raise item
     if item is None:
       xm.mark_step()
       raise StopIteration
@@ -56,7 +58,7 @@ class ParallelLoader(object):
   """Wraps an existing PyTorch DataLoader with background data upload.

   Args:
-    loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
+    cpu_loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
       wrapped.
     devices (`torch.device`...): The list of devices where the data has to be
       sent. The i-th sample returned by the `loader` will be sent to `devices[i
@@ -74,21 +76,20 @@ class ParallelLoader(object):
     host_to_device_transfer_threads (int, optional): The number of threads that
       work in parallel to transfer data from loader queue to device queue.
       Default: 1
-    input_sharding (ShardingSpec, optional): Sharding spec to apply to
-      compatible input tensors after loading.
-      Default: None
+    input_sharding (ShardingSpec, Dict(str, ShardingSpec), optional): Sharding
+      spec to apply to compatible input tensors after loading.
   """

   def __init__(self,
-               loader,
+               cpu_loader,
                devices,
                batchdim=0,
                batches_per_execution=1,
                loader_prefetch_size=16,
                device_prefetch_size=8,
                host_to_device_transfer_threads=1,
                input_sharding=None):
-    self._loader = loader
+    self._cpu_loader = cpu_loader
     self._devices = [torch.device(x) for x in devices]
     self._batchdim = batchdim
     self._batches_per_execution = batches_per_execution
@@ -140,7 +141,7 @@ def close(self):
     self._done = True
     for dqueue in self._queues.values():
       dqueue.queue.close()
-      dqueue.loader_queue.close()
+      dqueue.cpu_loader_queue.close()

     for thread in self._threads:
       thread.join()
@@ -151,7 +152,7 @@ def batches_per_execution(self):

   def _loader_worker(self):
     queues = list(self._queues.values())
-    data_iter = enumerate(self._loader)
+    data_iter = enumerate(self._cpu_loader)
     batch = []

     try:
@@ -163,21 +164,66 @@ def _loader_worker(self):
         batch.append(data)
         if len(batch) == len(self._devices):
           for queue_no, device_batch in enumerate(batch):
-            queues[queue_no].loader_queue.put(device_batch)
+            queues[queue_no].cpu_loader_queue.put(device_batch)
           batch = []
     finally:
       for dqueue in queues:
-        dqueue.loader_queue.close_write()
+        dqueue.cpu_loader_queue.close_write()

   def _get_batch(self, dqueue):
     batch = []
-    while dqueue.queue.max_size() > len(batch):
-      item = dqueue.loader_queue.get()
+    while len(batch) < dqueue.queue.max_size():
+      item = dqueue.cpu_loader_queue.get()
       if item is None:
         break
       batch.append(item)
     return batch

+  def send_cpu_data_to_device(self, batches, device):
+    """Move a batch of CPU data to the device.
+
+    Args:
+      batches -> List(torch.Tensor) or List(Dict(str: torch.Tensor)): Input
+        batches present in CPU memory.
+      device: The target device (e.g. a TPU device) to which the batches
+        should be moved.
+
+    Returns:
+      result -> List(torch.Tensor) or List(Dict(str: torch.Tensor)): A list of
+        dicts if the input batches are dicts, otherwise a list of torch.Tensor.
+    """
+    result = None
+    if isinstance(self._input_sharding, dict):
+      if not isinstance(batches[0], dict):
+        return [
+            ValueError(
+                "input batch should be a dict when input sharding is a dict.")
+        ]
+      result = []
+      for batch in batches:
+        xla_batch = {}
+        missing_keys = []
+        for key, tensor in batch.items():
+          assert type(tensor) == torch.Tensor
+          sharding_spec = None
+          if self._input_sharding:
+            if key not in self._input_sharding:
+              missing_keys.append(key)
+              continue
+            sharding_spec = self._input_sharding[key]
+
+          # xla_tensor is a list of tensors.
+          xla_tensor = xm.send_cpu_data_to_device(tensor, device, sharding_spec)
+          xla_batch[key] = xla_tensor[0]
+        if len(missing_keys) != 0:
+          # Return the exception instead of raising it: an exception raised in
+          # the data-loading thread would not surface in the main thread.
+          return [
+              KeyError(f"Keys: {missing_keys} are missing from input_sharding.")
+          ]
+        result.append(xla_batch)
+    else:
+      result = xm.send_cpu_data_to_device(batches, device, self._input_sharding)
+    return result
+
   def _worker(self, dqueue, host_to_device_transfer_threads):
     device = torch.device(dqueue.device)

@@ -187,8 +233,7 @@ def _worker(self, dqueue, host_to_device_transfer_threads):
       if not batch:
         break
       with torch.no_grad():
-        batch = xm.send_cpu_data_to_device(batch, device,
-                                           self._input_sharding)
+        batch = self.send_cpu_data_to_device(batch, device)
       for data in batch:
         dqueue.queue.put(data)
     finally:
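
The `raise item` check added to `PerDeviceLoader.next()` and the `ValueError`/`KeyError` instances returned by `send_cpu_data_to_device` work together: the data-loading thread hands the exception object through the queue instead of raising it, and the consumer re-raises it on its own thread. A minimal, standalone sketch of that pattern (plain Python with hypothetical names, not code from this diff):

```python
import queue
import threading

def worker(out_q):
  # In the data-loading thread, raising would only kill the worker silently,
  # so the exception object itself is enqueued as the "item".
  try:
    raise KeyError("keys missing from input_sharding")
  except Exception as exc:
    out_q.put(exc)

def consume(out_q):
  item = out_q.get()
  if isinstance(item, Exception):
    raise item  # re-raised on the caller's thread, as in PerDeviceLoader.next()
  return item

q = queue.Queue()
threading.Thread(target=worker, args=(q,)).start()
try:
  consume(q)
except KeyError as exc:
  print("surfaced in main thread:", exc)
```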
