Hello
For a model that can be sharded with model parallelism on a TPUv4 (4x32) device, I get the error below at the beginning of training on a TPUv3 (8x16) device. The console message reports a 4x memory expansion due to padding. Even though the TPUv4 and TPUv3 devices have the same total memory, I cannot run the training on the TPUv3 device.
```
Program hbm requirement 15.45G:
    global      2.36M
    scoped      3.88M
    HLO temp   15.45G (60.9% utilization: Unpadded (9.40G) Padded (15.44G), 0.0% fragmentation (5.52M))

  Largest program allocations in hbm:

  1. Size: 4.00G
     Shape: bf16[2048,1,2048,128]{0,1,3,2:T(4,128)(2,1)}
     Unpadded size: 1.00G
     Extra memory due to padding: 3.00G (4.0x expansion)
     XLA label: broadcast.6042.remat3 = broadcast(bitcast.26), dimensions={2,3}
     Allocation type: HLO temp
     ==========================

  2. Size: 4.00G
     Shape: bf16[2048,1,2048,128]{0,1,3,2:T(4,128)(2,1)}
     Unpadded size: 1.00G
     Extra memory due to padding: 3.00G (4.0x expansion)
     XLA label: broadcast.6043.remat3 = broadcast(bitcast.27), dimensions={0,3}
     Allocation type: HLO temp
     ==========================
```
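If I read the layout correctly (my own interpretation, the log does not state this), the tiled layout `{0,1,3,2:T(4,128)(2,1)}` pads the two minor-most dimensions up to the tile shape, so the size-1 dimension is rounded up to 4, which matches the reported numbers:

```python
# Rough arithmetic for allocation 1 above (my interpretation of the padding):
unpadded = 2048 * 1 * 2048 * 128 * 2 / 2**30  # bf16 = 2 bytes -> 1.0 GiB ("Unpadded size: 1.00G")
padded   = 2048 * 4 * 2048 * 128 * 2 / 2**30  # size-1 dim padded to the tile height 4 -> 4.0 GiB
print(unpadded, padded)                       # 1.0 4.0
```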
The lines that cause the 4x expansion are below:
```python
def forward(self, x):  # Activation map volume = 1,128,2048,1
    ...
    x = torch.transpose(x, 1, 3)  # Activation map volume = 1,1,2048,128
    x_batch_0 = x.expand(2048, -1, -1, -1)  # Activation map volume = 2048,1,2048,128
    x_batch_1 = x.repeat_interleave(2048, dim=2).reshape(2048, 1, 2048, 128)  # Activation map volume = 2048,1,2048,128
    x_batch = torch.cat((x_batch_0, x_batch_1), dim=1)  # Activation map volume = 2048,2,2048,128
    ...
```
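One alternative I have been considering (just a sketch, not a confirmed fix) is building `x_batch_1` as a view plus broadcast instead of `repeat_interleave` + `reshape`; the values are the same, but I do not know whether it changes the layout XLA picks:

```python
import torch

# Sketch only: x is assumed to have shape (1, 1, 2048, 128), as in the comments above.
x = torch.randn(1, 1, 2048, 128, dtype=torch.bfloat16)

x_batch_0 = x.expand(2048, -1, -1, -1)                      # (2048, 1, 2048, 128), broadcast view
# Same values as x.repeat_interleave(2048, dim=2).reshape(2048, 1, 2048, 128),
# but written as permute + expand, so no 2048*2048 intermediate is built eagerly.
x_batch_1 = x.permute(2, 1, 0, 3).expand(-1, -1, 2048, -1)  # (2048, 1, 2048, 128)
x_batch = torch.cat((x_batch_0, x_batch_1), dim=1)          # (2048, 2, 2048, 128)
```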
Here are the sharding properties that I set:

```python
mesh_shape = (num_devices, 1, 1, 1)
mesh = xs.Mesh(device_ids, mesh_shape, ('w', 'x', 'y', 'z'))
partition_spec = (0, 1, 2, 3)  # Apply sharding along all axes

for name, layer in model.named_modules():
    if 'conv2d' in name:
        xs.mark_sharding(layer.weight, mesh, partition_spec)
```
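For reference, here is a minimal end-to-end sketch of that setup, plus marking the large activation itself as sharded along the device axis. This is my own sketch: I am assuming `xs` is `torch_xla.distributed.spmd` (`torch_xla.experimental.xla_sharding` on older releases) and that SPMD mode is enabled via `torch_xla.runtime.use_spmd()`, and I have not verified that sharding the activation avoids the padding:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs  # torch_xla.experimental.xla_sharding on older versions

xr.use_spmd()  # enable the SPMD partitioner (assumed; not shown in the snippet above)

num_devices = xr.global_runtime_device_count()
device_ids = np.arange(num_devices)
mesh_shape = (num_devices, 1, 1, 1)
mesh = xs.Mesh(device_ids, mesh_shape, ('w', 'x', 'y', 'z'))
partition_spec = (0, 1, 2, 3)  # tensor dim i -> mesh axis i; only axis 'w' is larger than 1

# Sharding the large broadcast activation along dim 0 ('w') would leave each
# device with a 2048 / num_devices slice of the ~4 GiB buffer.
x_batch = torch.zeros(2048, 2, 2048, 128, dtype=torch.bfloat16, device=xm.xla_device())
xs.mark_sharding(x_batch, mesh, partition_spec)
```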
How can I prevent the 4x expansion?