Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 5D support for flash_attention #8693

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions torch_xla/experimental/custom_kernel.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't find a test associated with this file. Do we have a unit test that covers it? If yes, please add a case for 5 dimension fowarding.

If not, perhaps creating a bug for this would be appropriate.

Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ def fa_custom_forward(
ab = xs.enable_manual_sharding(
ab, partition_spec, mesh=mesh).global_tensor



Comment on lines +260 to +261
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: Remove unnecessary additional spaces

# It computes the shape and type of o, l, m.
shapes = [q.shape]
dtypes = [q.dtype]
Expand All @@ -279,6 +281,14 @@ def fa_custom_forward(
segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
q_segment_ids, kv_segment_ids)

# support 5D inputs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this comment is unnecessary as the rest of the code makes this self-evident. What might be more useful as a comment is to answer why the following is necessary for 5d inputs.

Applicable to following instances.

if len(q_full_shape) == 5:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this reshaping only applicable for dimensions of 5, or can it be generalized? Is it the case that higher than 4 dimensions require reshaping?

q = q.reshape(-1, *q_full_shape[2:])
k = k.reshape(-1, *q_full_shape[2:])
v = v.reshape(-1, *q_full_shape[2:])
q_segment_ids = q_segment_ids.reshape(-1, *q_segment_ids.shape[2:])
kv_segment_ids = kv_segment_ids.reshape(-1, *kv_segment_ids.shape[2:])

# We can't directly use flash_attention as we need to override the save_residuals flag which returns
# l and m that is needed for the backward. Then we lose all the shape checks.
# TODO: replicate the shape checks on flash_attention.
Expand Down Expand Up @@ -322,6 +332,8 @@ def fa_custom_forward(
o, *aux = o
l, m = (v[..., 0] for v in aux[-2:])

if len(q_full_shape) == 5:
o = o.reshape(q_full_shape)
# SPMD integration
if partition_spec is not None:
o = xs.disable_manual_sharding(
Expand Down Expand Up @@ -397,6 +409,17 @@ def fa_custom_backward(
if ab is not None:
ab = xs.enable_manual_sharding(
ab, partition_spec, mesh=mesh).global_tensor

# support 5D input
if len(q.shape) == 5:
q = q.reshape(-1, *q.shape[2:])
k = k.reshape(-1, *k.shape[2:])
v = v.reshape(-1, *v.shape[2:])
expanded_l = expanded_l.reshape(-1, *expanded_l.shape[2:])
expanded_m = expanded_m.reshape(-1, *expanded_m.shape[2:])
grad_output = grad_output.reshape(-1, *grad_output.shape[2:])
expanded_grad_i = expanded_grad_i.reshape(-1, *expanded_grad_i.shape[2:])

if q_segment_ids is not None and kv_segment_ids is not None:
segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
q_segment_ids, kv_segment_ids)
Expand Down Expand Up @@ -490,6 +513,16 @@ def fa_custom_backward(
if require_grad_v:
grad_v = grads[1]


# support 5D input
if len(q.shape) == 5:
grad_q = grad_q.reshape(q_full_shape)
grad_k = grad_k.reshape(kv_full_shape)
grad_v = grad_v.reshape(kv_full_shape)
grad_v = grad_v.reshape(kv_full_shape)
if ab is not None:
grad_ab = grad_ab.reshape(ab_full_shape)

# SPMD integration
if partition_spec is not None:
grad_q = xs.disable_manual_sharding(
Expand Down
Loading