🐛 Bug
The HLO instruction for the custom sharding call ("Sharding" custom-call) is missing its sharding spec, which leads to has_sharding RET_CHECK failures in XLA's sharding propagation:
RuntimeError: Bad StatusOr access: INVALID_ARGUMENT: HloOptimization: error condition !(status.ok()): 13RET_CHECK failure (external/xla/xla/service/sharding_propagation.cc:1464) instruction->has_sharding() Sharding instruction must have a sharding attribute
This issue was earlier identified in #8427, but with manual sharding. @JackCaoG did some investigation, but we haven't fully root-caused the issue yet. The issue can be minimally reproduced with mark_sharding as well, and we observe the same problem when adding any custom sharding prior to the input layer normalization in Llama3.
To Reproduce
- The snippet below exhibits the same underlying behavior as the Llama3 embedding lookup:
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable SPMD mode so mark_sharding takes effect

device_ids = list(range(32))
mesh = xs.Mesh(device_ids, (1, 1, 32), ('data', 'other', 'model'))
device = xm.xla_device()
indices = torch.zeros((1, 4096), dtype=torch.int64).to(device)  # p0.1 shape
weight = torch.randn((128256, 4096), dtype=torch.float32).to(device)  # p1.3 shape
xs.mark_sharding(weight, mesh, ("model", None))
# Embedding-style lookup, then mark a custom sharding on the result.
r0 = torch.index_select(weight, 0, indices.view(-1)).view(1, 4096, 4096)  # or reshape
xs.mark_sharding(r0, mesh, ("data", "model", None))
r0 = r0.view(1, 4096, 4096)  # or reshape
print(r0)
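The HLO below was dumped for the pending computation on r0. One way to obtain such a dump (a sketch, using the internal torch_xla._XLAC._get_xla_tensors_hlo binding that torch_xla's own tests use; the exact entry point may differ across releases):

import torch_xla
# Print the HLO text for the lazy IR graph rooted at r0.
print(torch_xla._XLAC._get_xla_tensors_hlo([r0]))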
HloModule IrToHlo.11, entry_computation_layout={(s64[1,4096]{1,0}, f32[128256,4096]{1,0})->(f32[1,4096,4096]{2,1,0})}
ENTRY %IrToHlo.11 (p0.1: s64[1,4096], p1.3: f32[128256,4096]) -> (f32[1,4096,4096]) {
%p1.3 = f32[128256,4096]{1,0} parameter(1), sharding={devices=[32,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}
%p0.1 = s64[1,4096]{1,0} parameter(0), sharding={replicated}
%reshape.2 = s64[4096]{0} reshape(s64[1,4096]{1,0} %p0.1)
%convert.4 = u32[4096]{0} convert(s64[4096]{0} %reshape.2)
%gather.5 = f32[4096,4096]{1,0} gather(f32[128256,4096]{1,0} %p1.3, u32[4096]{0} %convert.4), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,4096}
%reshape.6 = f32[1,4096,4096]{2,1,0} reshape(f32[4096,4096]{1,0} %gather.5)
%custom-call.7 = f32[1,4096,4096]{2,1,0} custom-call(f32[1,4096,4096]{2,1,0} %reshape.6), custom_call_target="Sharding"
%reshape.8 = f32[4096,4096]{1,0} reshape(f32[1,4096,4096]{2,1,0} %custom-call.7)
%reshape.9 = f32[1,4096,4096]{2,1,0} reshape(f32[4096,4096]{1,0} %reshape.8)
ROOT %tuple.10 = (f32[1,4096,4096]{2,1,0}) tuple(f32[1,4096,4096]{2,1,0} %reshape.9)
}
Expected behavior
We expect the appropriate sharding spec to be present on the Sharding custom-call; for the repro above, %custom-call.7 should include:
sharding={devices=[1,32,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}
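For illustration, with that spec attached the custom-call in the dump above would read roughly as follows (a sketch combining %custom-call.7 from the module above with the expected sharding attribute):

%custom-call.7 = f32[1,4096,4096]{2,1,0} custom-call(f32[1,4096,4096]{2,1,0} %reshape.6), custom_call_target="Sharding", sharding={devices=[1,32,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}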
Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]: CPU (see the note after this list for getting 32 CPU devices)
- torch_xla version: 2.6
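Note: the repro assumes 32 devices. On a single CPU host, one way to get them (an assumption: this relies on the CPU_NUM_DEVICES environment variable of torch_xla's PJRT CPU client, and a hypothetical repro.py containing the snippet above) is:

# Assumed single-host CPU setup with 32 virtual devices; repro.py is hypothetical.
PJRT_DEVICE=CPU CPU_NUM_DEVICES=32 python repro.py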