[Neuron] Auto-select StaticCache when device is Neuron #44748

@dacorvo

Description

Context

On Neuron devices, StaticCache is required for correct generation — dynamic tensor shapes trigger per-step recompilations. Currently, users must explicitly pass past_key_values=StaticCache(...) or cache_implementation="static" to model.generate(). Ideally, model.generate() should auto-select StaticCache when running on Neuron, providing a zero-configuration experience.

Depends on: #44742 (StaticCache-friendly _sample must exist first — otherwise auto-selecting StaticCache still hits the dynamic-shape _sample path)

Proposed change

In _prepare_cache_for_generation (around line 1879 of src/transformers/generation/utils.py), when no cache is specified:

# Auto-select the static cache on Neuron when the user has not chosen one
if generation_config.cache_implementation is None and self.device.type == "neuron":
    generation_config.cache_implementation = "static"

Precedent: The generation code already has device-aware logic:

  • Line 2021: self.device.type in ["cuda", "xpu"] for compile criteria
  • Line 1870: Forces "dynamic_full" for assisted generation

Device-aware cache selection fits the existing pattern.
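The selection rule can be sketched as a small standalone function. This is an illustrative sketch only — the function name resolve_cache_implementation is hypothetical and not part of transformers; it just isolates the decision the proposed patch makes inside _prepare_cache_for_generation:

```python
def resolve_cache_implementation(cache_implementation, device_type):
    """Decide which cache implementation generate() should use.

    Auto-selects "static" on Neuron only when the user did not specify
    a cache; an explicit choice is always preserved on every device.
    (Hypothetical helper for illustration, not actual transformers code.)
    """
    if cache_implementation is None and device_type == "neuron":
        return "static"
    return cache_implementation


# On Neuron with no explicit choice -> "static"
print(resolve_cache_implementation(None, "neuron"))   # static
# On CUDA with no explicit choice -> unchanged (None, i.e. DynamicCache)
print(resolve_cache_implementation(None, "cuda"))     # None
# Explicit override wins on any device
print(resolve_cache_implementation("static", "cuda")) # static
```

The key property is that the default only fills a gap: any user-supplied cache_implementation (or past_key_values) takes precedence, matching the "Explicit override still works" behavior below.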

Expected behavior

# On Neuron — auto-selects StaticCache, uses static-shape _sample path:
model.generate(input_ids)

# On CUDA — unchanged, uses DynamicCache + standard _sample:
model.generate(input_ids)

# Explicit override still works on any device:
model.generate(input_ids, cache_implementation="static")
model.generate(input_ids, past_key_values=my_static_cache)
