Context
AWS Neuron devices (Trainium/Inferentia) compile a separate NEFF (Neuron Executable File Format) for every unique tensor shape. Any code path that changes tensor shapes between iterations — growing masks, torch.cat on outputs, variable padding — triggers expensive recompilations (2–60s per NEFF depending on model size).
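To illustrate the constraint (not transformers code): a decode loop that grows a tensor with torch.cat presents a new shape every step, while a preallocated buffer written in place keeps shapes static. Function names and shapes below are purely illustrative.

```python
import torch

# Growing with torch.cat: the tensor shape changes every iteration, so a
# shape-specialized compiler (one NEFF per shape) recompiles each step.
def grow_with_cat(steps: int) -> list[torch.Size]:
    seq = torch.zeros(1, 0, 4)
    shapes = []
    for _ in range(steps):
        seq = torch.cat([seq, torch.zeros(1, 1, 4)], dim=1)
        shapes.append(seq.shape)
    return shapes

# Static alternative: preallocate to a maximum length and write in place,
# so every iteration sees the same tensor shape (one compilation total).
def grow_in_place(steps: int, max_len: int = 8) -> list[torch.Size]:
    buf = torch.zeros(1, max_len, 4)
    shapes = []
    for i in range(steps):
        buf[:, i] = 1.0
        shapes.append(buf.shape)
    return shapes

print(len(set(grow_with_cat(4))))   # 4 distinct shapes
print(len(set(grow_in_place(4))))   # 1 shape
```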
This umbrella issue tracks non-generation-loop changes needed in transformers for Neuron compatibility. Generation loop changes (GenerationMixin._prefill, _sample, generate) are tracked separately in #44742.
Performance numbers below are measured on Llama-3.1-8B at TP=2 on trn2.3xlarge.
Tracked issues
Kernel mask function dispatch
Status: PR open — #44680 (issue #44679)
load_and_register_attn_kernel() hardcodes the mask function to flash_attention_2 for all custom kernels. Kernels that need SDPA-style 4D boolean masks get the wrong mask format. This is needed for an upcoming NKI (Neuron Kernel Interface) SDPA kernel that will be published as a HuggingFace Hub kernel. Fix: let kernel modules declare MASK_FUNCTION = "sdpa" (or any valid mask type), falling back to flash_attention_2 for backward compatibility.
index_select instead of fancy indexing in batched_mm_experts_forward
Status: PR open — #44669 (issue #44678)
batched_mm_experts_forward uses fancy indexing (self.gate_up_proj[expert_ids]) which is ambiguous for non-CUDA compiler backends. torch.index_select is the explicit, unambiguous API. Semantically identical, no behavioral change.
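The equivalence is easy to check in isolation; the expert count and projection shapes below are made up for illustration.

```python
import torch

# Illustrative shapes: 8 experts, each with a (4, 6) projection weight.
gate_up_proj = torch.randn(8, 4, 6)
expert_ids = torch.tensor([2, 5, 2, 7])

# Fancy indexing, as in the current batched_mm_experts_forward:
fancy = gate_up_proj[expert_ids]

# Explicit equivalent along dim 0, unambiguous for compiler backends:
explicit = torch.index_select(gate_up_proj, dim=0, index=expert_ids)

assert torch.equal(fancy, explicit)
print(explicit.shape)  # torch.Size([4, 4, 6])
```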
Auto-select StaticCache on Neuron device
Status: Issue open — #44748
When no cache is specified and device.type == "neuron", auto-select StaticCache in _prepare_cache_for_generation. Depends on #44742 (StaticCache-friendly _sample).
Auto-dispatch NKI SDPA kernel on Neuron device
Status: Open — needs design
infer_device() already returns "neuron" for Neuron devices, but _check_and_adjust_attn_implementation() has no auto-dispatch path for Neuron — unlike CUDA which has FLASH_ATTN_KERNEL_FALLBACK. When device.type == "neuron" and no explicit attn_implementation is set, the model falls back to eager, missing the NKI SDPA kernel that will be published as a Hub kernel.
Proposed: add a Neuron-specific fallback in _check_and_adjust_attn_implementation() (analogous to FLASH_ATTN_KERNEL_FALLBACK) that auto-calls load_and_register_attn_kernel() with the NKI SDPA Hub repo when running on a Neuron device. Depends on the mask dispatch fix (#44679) since the NKI kernel uses SDPA-style masks.
Resolved issues (no upstream changes needed)
- NKI flash attention backend — resolved via the HuggingFace `kernels` library (kernels PR #285). A dedicated NKI kernel for Neuron SDPA will be published as a HuggingFace Hub kernel.
- StaticCache allocation — resolved by passing a pre-built `StaticCache` as `past_key_values=` (existing API)
Related
- Generation loop changes: [Neuron] Static-shape generation loop for compilation-friendly inference #44742
- OLMoE TP plan: Add base_model_tp_plan to OlmoeConfig #44677 / PR Add `base_model_tp_plan` to `OlmoeConfig` #44668 (general feature, not Neuron-specific)