diff --git a/axlearn/common/flash_attention/layer.py b/axlearn/common/flash_attention/layer.py
index b22bd3281..c322b04df 100644
--- a/axlearn/common/flash_attention/layer.py
+++ b/axlearn/common/flash_attention/layer.py
@@ -140,6 +140,7 @@ def _compute_attention(
                 cfg.mha_dim_to_partition_spec["btnh"],
                 cfg.mha_dim_to_partition_spec["bsnh"],
                 cfg.mha_dim_to_partition_spec["bsnh"],
+                # NOTE: Please resync this with upstream code.
                 # Bias [batch_size, num_heads, seq_len, seq_len].
                 # cfg.mha_dim_to_partition_spec["bnts"],
                 PartitionSpec(None, None, None, None)
@@ -159,6 +160,8 @@ def _compute_attention(
         q_proj = with_sharding_constraint(q_proj, cfg.mha_dim_to_partition_spec["btnh"])
         k_proj = with_sharding_constraint(k_proj, cfg.mha_dim_to_partition_spec["bsnh"])
         v_proj = with_sharding_constraint(v_proj, cfg.mha_dim_to_partition_spec["bsnh"])
+
+        # NOTE: Please resync this with upstream code.
         # if attention_logit_biases is not None:
         #     attention_logit_biases = with_sharding_constraint(
         #         attention_logit_biases, cfg.mha_dim_to_partition_spec["bnts"]
diff --git a/axlearn/common/flash_attention/neuron_attention.py b/axlearn/common/flash_attention/neuron_attention.py
index 81dad4613..7310ecab6 100644
--- a/axlearn/common/flash_attention/neuron_attention.py
+++ b/axlearn/common/flash_attention/neuron_attention.py
@@ -31,12 +31,15 @@
 from jax import custom_vjp
 
 
-@partial(custom_vjp, nondiff_argnums=(3,4))
-def flash_attention(query, key, value, causal, softmax_scale):
-    out, _ = _mha_forward(query, key, value, causal, softmax_scale)
+@partial(custom_vjp, nondiff_argnums=(4,5))
+def flash_attention(query, key, value, bias, causal, softmax_scale):
+    # NOTE: Merge with upstream. The old code supports both 2d and 4d bias, but upstream only supports 4d.
+    # We no longer need a 2d logit_bias, but we should sync how this check is merged with upstream.
+    assert bias.ndim == 4, f"Neuron flash_attention expects bias.ndim == 4, but got {bias.ndim}"
+    out, _ = _mha_forward(query, key, value, bias, causal, softmax_scale)
     return out
 
-def _mha_forward(query, key, value, causal, softmax_scale):
+def _mha_forward(query, key, value, bias, causal, softmax_scale):
     # Get the batch size, sequence lengths, number of heads, and hidden dimension
     batch_size, q_seq_len, num_heads, d_model = query.shape
     _, kv_seq_len, _, _ = key.shape
@@ -57,8 +60,10 @@ def _mha_forward(query, key, value, causal, softmax_scale):
         import neuronxcc.nki.language as nl
 
         assert (num_heads % 2) == 0 and (num_heads // 2 > 0), f'unexpect num_heads: {num_heads}'
-        attn_output, lse = flash_fwd[batch_size, nl.nc(2) * (num_heads//2)](q, k, v, seed, use_causal_mask=causal, softmax_scale=softmax_scale, mixed_precision=True, dropout_p=0.0)
+        attn_output, lse = flash_fwd[batch_size, nl.nc(2) * (num_heads//2)](q, k, v, seed, bias, use_causal_mask=causal, softmax_scale=softmax_scale, mixed_precision=True, dropout_p=0.0)
     else:
+        # NOTE: Please make a feature request to the Neuron compiler team if this is needed.
+        assert bias is None, "logit_bias is not supported in legacy kernels. Set envvar ENABLED_NEW_UNSHARDED_ATTN_KERNEL to use the new kernel."
         from neuronxcc.nki._private_kernels.legacy.attention import flash_fwd
         from neuronxcc.nki._private_kernels.attention import flash_fwd_shardable
         from neuronxcc.starfish.penguin.targets.nki.private_api import vnc
@@ -71,10 +76,10 @@ def _mha_forward(query, key, value, causal, softmax_scale):
     # Transpose the output back to the original shape
     attn_output = attn_output.transpose(0, 2, 1, 3)  # [batch_size, q_seq_len, num_heads, d_model]
 
-    return attn_output, (lse, attn_output, q, k, v)
+    return attn_output, (lse, attn_output, q, k, v, bias)
 
 
 def _mha_backward(causal, softmax_scale, res, d_attn_output):
-    lse, o, q, k, v = res
+    lse, o, q, k, v, bias = res
     batch_size, num_heads, d_model, seq_len = q.shape
     _, kv_seq_len, _, _ = k.shape
@@ -95,8 +100,10 @@ def _mha_backward(causal, softmax_scale, res, d_attn_output):
         from neuronxcc.nki.kernels.attention import flash_attn_bwd
         import neuronxcc.nki.language as nl
         assert (num_heads % 2) == 0 and (num_heads // 2 > 0), f'unexpected num_heads: {num_heads}'
-        d_query, d_key, d_value = flash_attn_bwd[batch_size, nl.nc(2) * (num_heads//2)](q, k, v, o, dy, lse, seed, use_causal_mask=causal, mixed_precision=True, dropout_p=0.0, softmax_scale=softmax_scale)
+        d_query, d_key, d_value = flash_attn_bwd[batch_size, nl.nc(2) * (num_heads//2)](q, k, v, o, dy, lse, seed, bias, use_causal_mask=causal, mixed_precision=True, dropout_p=0.0, softmax_scale=softmax_scale)
     else:
+        # NOTE: Please make a feature request to the Neuron compiler team if this is needed.
+        assert bias is None, "logit_bias is not supported in legacy kernels. Set envvar ENABLED_NEW_UNSHARDED_ATTN_KERNEL to use the new kernel."
         from neuronxcc.nki._private_kernels.legacy.attention import flash_attn_bwd
         from neuronxcc.nki._private_kernels.attention import flash_attn_bwd_shardable
         from neuronxcc.starfish.penguin.targets.nki.private_api import vnc
@@ -113,7 +120,7 @@ def _mha_backward(causal, softmax_scale, res, d_attn_output):
     d_key = d_key.transpose(0, 3, 1, 2)
     d_value = d_value.transpose(0, 3, 1, 2)
 
-    return d_query, d_key, d_value
+    return d_query, d_key, d_value, None
 
 
 flash_attention.defvjp(_mha_forward, _mha_backward)
diff --git a/axlearn/common/flash_attention/utils.py b/axlearn/common/flash_attention/utils.py
index 6fe7e0601..bbec444b2 100644
--- a/axlearn/common/flash_attention/utils.py
+++ b/axlearn/common/flash_attention/utils.py
@@ -155,7 +155,7 @@ def jit_attn(query, key, value, bias):
         @jax.jit
         def jit_attn(query, key, value, bias):
             return neuron_flash_attention(
-                query, key, value, causal, softmax_scale)
+                query, key, value, bias, causal, softmax_scale)
 
         return jit_attn
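Usage sketch (not part of the patch): a minimal, hypothetical example of how the Neuron flash_attention entry point would be called after this change. It assumes a Neuron-enabled JAX backend with neuronxcc installed and ENABLED_NEW_UNSHARDED_ATTN_KERNEL set; the shapes and the explicit 4-D bias follow the new assert in neuron_attention.py, and causal/softmax_scale are passed positionally because they sit in nondiff_argnums.

    import jax
    import jax.numpy as jnp

    from axlearn.common.flash_attention.neuron_attention import flash_attention

    # Illustrative shapes only; num_heads must be even per the kernel assert.
    batch, q_len, kv_len, num_heads, head_dim = 2, 128, 128, 4, 64
    query = jnp.zeros((batch, q_len, num_heads, head_dim), dtype=jnp.bfloat16)
    key = jnp.zeros((batch, kv_len, num_heads, head_dim), dtype=jnp.bfloat16)
    value = jnp.zeros((batch, kv_len, num_heads, head_dim), dtype=jnp.bfloat16)
    # The kernel now takes an explicit 4-D bias: [batch, num_heads, q_len, kv_len].
    bias = jnp.zeros((batch, num_heads, q_len, kv_len), dtype=jnp.float32)

    # causal=True, softmax_scale=1/sqrt(head_dim), both positional.
    out = flash_attention(query, key, value, bias, True, head_dim**-0.5)

    # Gradients flow through query/key/value; the bias slot receives a None cotangent,
    # matching the extra None returned by _mha_backward.
    dq = jax.grad(lambda q: flash_attention(q, key, value, bias, True, head_dim**-0.5).sum())(query)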