
Missing sharding specs when annotating sharding over views #8662

Open
@rpsilva-aws

Description

🐛 Bug

The HLO instruction for the custom sharding call is missing its sharding spec, leading to `has_sharding` check failures in XLA:

RuntimeError: Bad StatusOr access: INVALID_ARGUMENT: HloOptimization: error condition !(status.ok()): 13RET_CHECK failure (external/xla/xla/service/sharding_propagation.cc:1464) instruction->has_sharding() Sharding instruction must have a sharding attribute

This issue was earlier identified in #8427, though there it surfaced with manual sharding. @JackCaoG did some investigation, but we haven't fully root-caused it yet. The issue can be minimally reproduced with `mark_sharding` as well, and we observe the same problem when adding any custom sharding just before the input layer normalization in Llama3.

To Reproduce

  1. Similar underlying behavior to the embedding (a sketch for dumping the HLO follows after this list):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()  # enable SPMD execution mode

device_ids = list(range(32))
mesh = xs.Mesh(device_ids, (1, 1, 32), ('data', 'other', 'model'))
device = xm.xla_device()

indices = torch.zeros((1, 4096), dtype=torch.int64).to(device)  # p0.1 shape
weight = torch.randn((128256, 4096), dtype=torch.float32).to(device)  # p1.3 shape
xs.mark_sharding(weight, mesh, ("model", None))

r0 = torch.index_select(weight, 0, indices.view(-1)).view(1, 4096, 4096)  # or reshape

# Annotate sharding on the view, then view the annotated tensor again.
xs.mark_sharding(r0, mesh, ("data", "model", None))

r0 = r0.view(1, 4096, 4096)  # or reshape

print(r0)
HloModule IrToHlo.11, entry_computation_layout={(s64[1,4096]{1,0}, f32[128256,4096]{1,0})->(f32[1,4096,4096]{2,1,0})}

ENTRY %IrToHlo.11 (p0.1: s64[1,4096], p1.3: f32[128256,4096]) -> (f32[1,4096,4096]) {
  %p1.3 = f32[128256,4096]{1,0} parameter(1), sharding={devices=[32,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}
  %p0.1 = s64[1,4096]{1,0} parameter(0), sharding={replicated}
  %reshape.2 = s64[4096]{0} reshape(s64[1,4096]{1,0} %p0.1)
  %convert.4 = u32[4096]{0} convert(s64[4096]{0} %reshape.2)
  %gather.5 = f32[4096,4096]{1,0} gather(f32[128256,4096]{1,0} %p1.3, u32[4096]{0} %convert.4), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,4096}
  %reshape.6 = f32[1,4096,4096]{2,1,0} reshape(f32[4096,4096]{1,0} %gather.5)
  %custom-call.7 = f32[1,4096,4096]{2,1,0} custom-call(f32[1,4096,4096]{2,1,0} %reshape.6), custom_call_target="Sharding"
  %reshape.8 = f32[4096,4096]{1,0} reshape(f32[1,4096,4096]{2,1,0} %custom-call.7)
  %reshape.9 = f32[1,4096,4096]{2,1,0} reshape(f32[4096,4096]{1,0} %reshape.8)
  ROOT %tuple.10 = (f32[1,4096,4096]{2,1,0}) tuple(f32[1,4096,4096]{2,1,0} %reshape.9)
}
  2. flash_attention: support also cross attention. #8427
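For reference, the pending HLO in 1) above can also be dumped directly from Python without executing the graph. This is a minimal sketch assuming the private `torch_xla._XLAC._get_xla_tensors_hlo` debug binding (the one commonly used in torch_xla's SPMD tests); note how %custom-call.7 in the dump carries no sharding attribute:

import torch_xla

# Dump the pending HLO for r0; the Sharding custom-call should carry a
# sharding attribute matching the mark_sharding annotation, but does not.
print(torch_xla._XLAC._get_xla_tensors_hlo([r0]))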

Expected behavior

We expect the appropriate sharding spec to be present on the custom sharding call; for repro 1) above, it should include:

sharding={devices=[1,32,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}
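A quick way to observe the annotation from Python (a sketch, assuming the private `torch_xla._XLAC._get_xla_sharding_spec` binding, which returns a tensor's OpSharding string):

import torch_xla

# After mark_sharding, the annotated tensor should report the expected
# OpSharding string for the (data, model, None) partition spec.
print(torch_xla._XLAC._get_xla_sharding_spec(r0))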

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CPU
  • torch_xla version: 2.6
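
Additional context

A variant that may be useful when bisecting this (hypothetical, not verified to avoid the bug): moving the `mark_sharding` call after the final view changes where the Sharding custom-call is emitted relative to the reshapes.

# Hypothetical ordering experiment, not a verified workaround: annotate
# after the final view so the Sharding custom-call is emitted on the
# last reshape instead of between %reshape.6 and %reshape.8.
r0 = torch.index_select(weight, 0, indices.view(-1)).view(1, 4096, 4096)
r0 = r0.view(1, 4096, 4096)
xs.mark_sharding(r0, mesh, ("data", "model", None))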
