
How to minimize memory expansion due to padding during sharding #6674

Open
@mfatih7

Description

Hello

For a model that can be sharded with model parallelism on a TPUv4 (4x32) device, I am getting the error below at the beginning of training on a TPUv3 (8x16) device. The console message reports a 4x memory expansion due to padding. Even though both the TPUv4 and the TPUv3 device have the same total memory, I cannot run the training on the TPUv3 device.

Program hbm requirement 15.45G:
    global            2.36M
    scoped            3.88M
    HLO temp         15.45G (60.9% utilization: Unpadded (9.40G) Padded (15.44G), 0.0% fragmentation (5.52M))

  Largest program allocations in hbm:

  1. Size: 4.00G
     Shape: bf16[2048,1,2048,128]{0,1,3,2:T(4,128)(2,1)}
     Unpadded size: 1.00G
     Extra memory due to padding: 3.00G (4.0x expansion)
     XLA label: broadcast.6042.remat3 = broadcast(bitcast.26), dimensions={2,3}
     Allocation type: HLO temp
     ==========================

  2. Size: 4.00G
     Shape: bf16[2048,1,2048,128]{0,1,3,2:T(4,128)(2,1)}
     Unpadded size: 1.00G
     Extra memory due to padding: 3.00G (4.0x expansion)
     XLA label: broadcast.6043.remat3 = broadcast(bitcast.27), dimensions={0,3}
     Allocation type: HLO temp
     ==========================
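
For reference, the unpadded size in the report matches the tensor shape, and my reading (which may be wrong) is that the 4.0x expansion comes from the size-1 dimension being rounded up to the tile dimension of the T(4,128)(2,1) layout. A quick arithmetic check (plain Python, no TPU needed; the "padded to 4" part is my assumption):

bytes_per_elem = 2  # bf16
unpadded = 2048 * 1 * 2048 * 128 * bytes_per_elem   # size-1 dim left as-is
padded   = 2048 * 4 * 2048 * 128 * bytes_per_elem   # size-1 dim padded to 4 (my assumption)

print(unpadded / 2**30)  # 1.0 -> matches "Unpadded size: 1.00G"
print(padded / 2**30)    # 4.0 -> matches "Size: 4.00G", i.e. the 4.0x expansion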

The lines that cause the 4x expansion are below:

def forward(self, x):   # Activation map volume = 1,128,2048,1
   ...
   ...
   x = torch.transpose(x, 1, 3)  # Activation map volume = 1,1,2048,128

   x_batch_0 = x.expand(2048, -1, -1, -1)  # Activation map volume = 2048,1,2048,128

   x_batch_1 = x.repeat_interleave(2048, dim=2).reshape(2048, 1, 2048, 128) # Activation map volume = 2048,1,2048,128

   x_batch = torch.cat((x_batch_0, x_batch_1), dim=1) # Activation map volume = 2048,2,2048,128

   ...
   ...
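
The shapes in the comments can be reproduced on CPU with a small standalone snippet (sizes hard-coded to match the excerpt; at these sizes it allocates a few GiB of host RAM, so shrink 2048 for a lighter check). Note that in eager PyTorch x_batch_0 is a zero-copy view until the cat materializes it:

import torch

x = torch.empty(1, 128, 2048, 1)  # activation map entering the excerpt
x = torch.transpose(x, 1, 3)                                               # (1, 1, 2048, 128)

x_batch_0 = x.expand(2048, -1, -1, -1)                                     # (2048, 1, 2048, 128), view
x_batch_1 = x.repeat_interleave(2048, dim=2).reshape(2048, 1, 2048, 128)   # (2048, 1, 2048, 128), copy
x_batch = torch.cat((x_batch_0, x_batch_1), dim=1)                         # (2048, 2, 2048, 128)

print(x_batch_0.shape, x_batch_1.shape, x_batch.shape)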

Here are the sharding settings that I use.

mesh_shape = (num_devices, 1, 1, 1)

mesh = xs.Mesh(device_ids, mesh_shape, ('w', 'x', 'y', 'z'))
partition_spec = (0, 1, 2, 3)  # Apply sharding along all axes

for name, layer in model.named_modules():
    if 'conv2d' in name:
        xs.mark_sharding(layer.weight, mesh, partition_spec)
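
For completeness, the snippet above assumes roughly this setup (torch_xla 2.x SPMD API; variable names not shown above are mine):

import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable the SPMD execution mode

num_devices = xr.global_runtime_device_count()
device_ids = np.array(range(num_devices))
# ... followed by the mesh_shape / mesh / partition_spec / mark_sharding lines above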

How can I prevent the 4x expansion?
