
[Neuron] Improve transformers compatibility with AWS Neuron devices #44741

@dacorvo

Description


Context

AWS Neuron devices (Trainium/Inferentia) compile a separate NEFF (Neuron Executable File Format) for every unique tensor shape. Any code path that changes tensor shapes between iterations — growing masks, torch.cat on outputs, variable padding — triggers expensive recompilations (2–60s per NEFF depending on model size).
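As an illustration (plain PyTorch, not Neuron-specific): concatenating onto a buffer with `torch.cat` presents a new shape to the compiler on every step, while writing into a preallocated buffer keeps shapes static across iterations. The second pattern is the one that avoids NEFF recompilation.

```python
import torch

max_len, dim = 8, 4

# Shape-changing pattern: torch.cat grows the buffer every iteration,
# so steps see shapes (1, 4), (2, 4), (3, 4), ... — each a new NEFF.
grown = torch.zeros(0, dim)
for step in range(3):
    grown = torch.cat([grown, torch.randn(1, dim)], dim=0)

# Shape-stable pattern: preallocate to max_len and write in place;
# the buffer's shape never changes, so one compiled graph is reused.
static = torch.zeros(max_len, dim)
for step in range(3):
    static[step] = torch.randn(dim)

print(tuple(grown.shape), tuple(static.shape))  # (3, 4) (8, 4)
```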

This umbrella issue tracks non-generation-loop changes needed in transformers for Neuron compatibility. Generation loop changes (GenerationMixin._prefill, _sample, generate) are tracked separately in #44742.

Performance numbers below are measured on Llama-3.1-8B at TP=2 on trn2.3xlarge.


Tracked issues

Kernel mask function dispatch

Status: PR open — #44680 (issue #44679)

load_and_register_attn_kernel() hardcodes the mask function to flash_attention_2 for all custom kernels. Kernels that need SDPA-style 4D boolean masks get the wrong mask format. This is needed for an upcoming NKI (Neuron Kernel Interface) SDPA kernel that will be published as a HuggingFace Hub kernel. Fix: let kernel modules declare MASK_FUNCTION = "sdpa" (or any valid mask type), falling back to flash_attention_2 for backward compatibility.
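The proposed dispatch is essentially an attribute lookup with a default. A minimal sketch (the kernel-module classes here are stand-ins, not real Hub kernels; `MASK_FUNCTION` is the attribute name proposed above):

```python
# Current hardcoded behavior, kept as the backward-compatible default.
DEFAULT_MASK_FUNCTION = "flash_attention_2"

class SdpaKernelModule:
    """Stand-in for a Hub kernel that needs SDPA-style 4D boolean masks."""
    MASK_FUNCTION = "sdpa"

class LegacyKernelModule:
    """Stand-in for an existing kernel that declares nothing."""

def resolve_mask_function(kernel_module):
    # Kernels may opt in by declaring MASK_FUNCTION; everything else
    # keeps the flash_attention_2 mask format it gets today.
    return getattr(kernel_module, "MASK_FUNCTION", DEFAULT_MASK_FUNCTION)

print(resolve_mask_function(SdpaKernelModule))    # sdpa
print(resolve_mask_function(LegacyKernelModule))  # flash_attention_2
```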

index_select instead of fancy indexing in batched_mm_experts_forward

Status: PR open — #44669 (issue #44678)

batched_mm_experts_forward uses fancy indexing (self.gate_up_proj[expert_ids]) which is ambiguous for non-CUDA compiler backends. torch.index_select is the explicit, unambiguous API. Semantically identical, no behavioral change.
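The change is a one-line substitution; both forms gather the same expert weight slices along dim 0. A minimal sketch with a toy weight tensor (not the actual `batched_mm_experts_forward` code):

```python
import torch

num_experts, din, dout = 4, 3, 5
gate_up_proj = torch.randn(num_experts, din, dout)  # toy expert weights
expert_ids = torch.tensor([2, 0, 2, 1])             # expert chosen per token

# Fancy indexing: ambiguous for some non-CUDA compiler backends.
fancy = gate_up_proj[expert_ids]

# torch.index_select: explicit gather along dim 0, identical result.
explicit = torch.index_select(gate_up_proj, dim=0, index=expert_ids)

print(torch.equal(fancy, explicit))  # True
```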

Auto-select StaticCache on Neuron device

Status: Issue open — #44748

When no cache is specified and device.type == "neuron", auto-select StaticCache in _prepare_cache_for_generation. Depends on #44742 (StaticCache-friendly _sample).
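The selection logic could be sketched as below. The helper name and string return values are hypothetical; the real change would live inside `_prepare_cache_for_generation` and construct an actual `StaticCache`.

```python
# Hypothetical sketch: auto-select a static cache on Neuron when the
# caller did not specify one. "static" stands in for StaticCache and
# "dynamic" for the default DynamicCache.
def select_cache_implementation(device_type, requested_cache=None):
    if requested_cache is not None:
        return requested_cache  # an explicit user choice always wins
    if device_type == "neuron":
        return "static"         # shape-stable cache avoids NEFF recompiles
    return "dynamic"            # current default on other devices

print(select_cache_implementation("neuron"))             # static
print(select_cache_implementation("cuda"))               # dynamic
print(select_cache_implementation("neuron", "sliding"))  # sliding
```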

Auto-dispatch NKI SDPA kernel on Neuron device

Status: Open — needs design

infer_device() already returns "neuron" for Neuron devices, but _check_and_adjust_attn_implementation() has no auto-dispatch path for Neuron — unlike CUDA, which has FLASH_ATTN_KERNEL_FALLBACK. When device.type == "neuron" and no explicit attn_implementation is set, the model falls back to eager attention, missing the NKI SDPA kernel that will be published as a Hub kernel.

Proposed: add a Neuron-specific fallback in _check_and_adjust_attn_implementation() (analogous to FLASH_ATTN_KERNEL_FALLBACK) that auto-calls load_and_register_attn_kernel() with the NKI SDPA Hub repo when running on a Neuron device. Depends on the mask dispatch fix (#44679) since the NKI kernel uses SDPA-style masks.
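Structurally, the Neuron branch would mirror the CUDA fallback. A hedged sketch with hypothetical names throughout — the Hub repo id is a placeholder (not a published value), and the list append stands in for the real `load_and_register_attn_kernel()` call:

```python
# Hypothetical sketch of a Neuron auto-dispatch fallback, analogous to
# FLASH_ATTN_KERNEL_FALLBACK on CUDA. The repo id below is a placeholder.
NEURON_SDPA_KERNEL_REPO = "<nki-sdpa-hub-repo>"  # not a real repo id

def adjust_attn_implementation(device_type, attn_implementation, registered):
    """Return the attention implementation to use; record kernel loads."""
    if attn_implementation is not None:
        return attn_implementation  # an explicit user choice always wins
    if device_type == "neuron":
        # Stands in for: load_and_register_attn_kernel(NEURON_SDPA_KERNEL_REPO)
        registered.append(NEURON_SDPA_KERNEL_REPO)
        return "kernel"
    return "eager"                  # current fallback behavior today

loads = []
print(adjust_attn_implementation("neuron", None, loads))  # kernel
print(adjust_attn_implementation("cpu", None, loads))     # eager
```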


Resolved issues (no upstream changes needed)

  • NKI flash attention backend — resolved via HuggingFace kernels library (kernels PR #285). A dedicated NKI kernel for Neuron SDPA will be published as a HuggingFace Hub kernel.
  • StaticCache allocation — resolved by passing pre-built StaticCache as past_key_values= (existing API)
