Commit f82f373

backward with spmd issue
1 parent 8e6ca60 commit f82f373

5 files changed (+586 -253 lines)


test/test_pallas.py

+76 -2
@@ -12,8 +12,9 @@
 import numpy as np
 
 if xr.device_type() == 'TPU':
-  from torch_xla.experimental.custom_kernel import jax_import_guard
-  jax_import_guard()
+  # from torch_xla.experimental.custom_kernel import jax_import_guard
+  # jax_import_guard()
+  torch_xla._XLAC._init_computation_client()
   import jax
   import jax.numpy as jnp
   from jax.experimental import pallas as pl
@@ -488,6 +489,79 @@ def test_flash_attention_backward(self):
       self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
     jax.config.update("jax_default_matmul_precision", "default")
 
+
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                   "This test only works on TPUv4+.")
+  def test_flash_attention_backward_aot_autograd_traceable(self):
+    from functorch.compile import aot_function, make_boxed_func
+    from torch_xla.experimental.custom_kernel import flash_attention, FlashAttention, flash_attention_compilable
+    import torch_xla.core.xla_model as xm
+    jax.config.update("jax_default_matmul_precision", "highest")
+
+    def compiler(gm, _):
+      print("Got graph:")
+      print(gm.code)
+      return make_boxed_func(gm)
+
+    torch.manual_seed(42)
+    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
+    q.retain_grad()
+    k.retain_grad()
+    v.retain_grad()
+    B, N, SEQ, H = q.size()
+    causal = True
+    q_segment_ids = None
+    kv_segment_ids = None
+    sm_scale = 1.0
+    mask = (torch.rand(4, 2, 128, 128) > 0.5).to("xla")
+    # ab = torch.ones(4, 2, 128, 128).to("xla")
+    # ab = ab.masked_fill(mask, torch.finfo(ab.dtype).min).requires_grad_(True)
+    # ab.retain_grad()
+    ab = None
+    partition_spec = ('fsdp', 'tensor', None, None)
+    # partition_spec = None
+    import torch_xla.runtime as xr
+    from torch_xla.distributed.spmd import Mesh
+    xr.use_spmd()
+    num_devices = xr.global_runtime_device_count()
+    mesh_shape = (num_devices // 2, 2)
+    device_ids = np.array(range(num_devices))
+    mesh = Mesh(device_ids, mesh_shape, ('fsdp', 'tensor'))
+
+    def flash_attention_wrapper(q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, partition_spec, mesh):
+      return flash_attention_compilable(q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab=ab, partition_spec=partition_spec, mesh=mesh)
+
+    # An AOT-compatible function only accepts the argument types listed in https://github.com/pytorch/pytorch/blob/82859f61857ef39898b34a5cdf0ae56ec25704d9/torch/_functorch/_aot_autograd/utils.py#L23-L34, so we serialize partition_spec and mesh into strings.
+    # compiled_flash_attention = aot_function(
+    #     flash_attention_wrapper, fw_compiler=compiler)
+    # o_actual = compiled_flash_attention(q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, str(partition_spec), str(mesh))
+    o_actual = flash_attention(q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab=ab, partition_spec=partition_spec, mesh=mesh)
+
+    print(o_actual.sum())
+    o_actual.sum().backward()
+    print(q.grad)
+
+    # if causal:
+    #   attention_mask = torch.triu(torch.ones(SEQ, SEQ), diagonal=1).to("xla")
+    #   # attention_mask = attention_mask.view(1, 1, SEQ, SEQ)
+    #   # attention_mask = attention_mask.expand(q.size(0), q.size(1), -1, -1)
+    # else:
+    #   attention_mask = None
+    # print(attention_mask)
+    # assert False
+    # import torch_xla.distributed.spmd as xs
+    # expected_output = self._attention(q, k, v, attn_mask=attention_mask)
+    # print(expected_output)
+    # self.assertTrue(
+    #     torch.allclose(
+    #         expected_output.cpu(),
+    #         o_actual.cpu(),
+    #         atol=1e-1,
+    #         rtol=1e-1))
+
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
   def test_paged_attention_wrapper(self):
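A note on the commented-out AOT path in this test: as the in-code comment says, AOT Autograd only accepts tensors and a short list of plain Python types as arguments, so `Mesh` and the `partition_spec` tuple have to cross the tracing boundary as strings. Below is a minimal sketch of that round trip. It assumes the `Mesh.from_str` helper added in `xla_sharding.py` later in this commit; the `traceable_flash_attention` wrapper is hypothetical (not part of the commit) and reuses `q`, `k`, `v`, `causal`, `sm_scale`, `partition_spec`, and `mesh` as constructed in the test above.

import ast

from functorch.compile import aot_function, make_boxed_func
from torch_xla.distributed.spmd import Mesh
from torch_xla.experimental.custom_kernel import flash_attention

def traceable_flash_attention(q, k, v, causal, sm_scale, partition_spec_str, mesh_str):
  # Rebuild the structured arguments inside the traced function so that only
  # tensors, bools, floats, and strings ever cross the AOT Autograd boundary.
  partition_spec = ast.literal_eval(partition_spec_str)
  mesh = Mesh.from_str(mesh_str)
  return flash_attention(q, k, v, causal, None, None, sm_scale,
                         partition_spec=partition_spec, mesh=mesh)

compiled = aot_function(traceable_flash_attention,
                        fw_compiler=lambda gm, _: make_boxed_func(gm))
# q, k, v, causal, sm_scale, partition_spec, and mesh as set up in the test above.
out = compiled(q, k, v, causal, sm_scale, str(partition_spec), str(mesh))
# Backward through this path is exactly what the commit message flags as still
# broken ("backward with spmd issue").
out.sum().backward()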

third_party/xla

+1
@@ -0,0 +1 @@
+Subproject commit 6e91ff19dad528ab7d2025a9bb46150618a3bc7d

torch_xla/distributed/spmd/xla_sharding.py

+21 -2
@@ -124,8 +124,27 @@ def get_op_sharding(self,
         partition_spec)
     return torch_xla._XLAC.OpSharding(tile_assignment, group_assignment,
                                       replication_groups, sharding_type)
-
-
+
+  def __str__(self):
+    """Convert Mesh to string representation."""
+    return (f"{{'device_ids': {self.device_ids.tolist()}, "
+            f"'mesh_shape': {self.mesh_shape}, "
+            f"'axis_names': {self.axis_names}}}")
+
+  @classmethod
+  def from_str(cls, mesh_str: str):
+    """Create Mesh from string representation."""
+    import ast
+    import numpy as np
+    # Remove 'Mesh' and parse dict
+    dict_str = mesh_str.replace('Mesh', '')
+    mesh_dict = ast.literal_eval(dict_str)
+    # Convert list back to numpy array for device_ids
+    return cls(
+        device_ids=np.array(mesh_dict['device_ids']),
+        mesh_shape=mesh_dict['mesh_shape'],
+        axis_names=mesh_dict['axis_names']
+    )
 _GLOBAL_MESH: Mesh = None
 
 
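The `__str__`/`from_str` pair above is what lets a `Mesh` cross the AOT Autograd boundary as a plain string (the test's commented-out `aot_function` call passes `str(mesh)`). A quick round-trip sanity check, as a sketch assuming an SPMD-enabled runtime with at least one device:

import numpy as np

import torch_xla.runtime as xr
from torch_xla.distributed.spmd import Mesh

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = Mesh(np.array(range(num_devices)), (num_devices, 1), ('fsdp', 'tensor'))

# __str__ emits a dict literal; from_str strips any 'Mesh' prefix and rebuilds
# the object via ast.literal_eval.
restored = Mesh.from_str(str(mesh))
assert (restored.device_ids == mesh.device_ids).all()
assert restored.mesh_shape == mesh.mesh_shape
assert restored.axis_names == mesh.axis_names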
