Context
On Neuron devices, StaticCache is required for correct generation — dynamic tensor shapes trigger per-step recompilations. Currently, users must explicitly pass past_key_values=StaticCache(...) or cache_implementation="static" to model.generate(). Ideally, model.generate() should auto-select StaticCache when running on Neuron, providing zero-config DX.
Depends on: #44742 (StaticCache-friendly _sample must exist first — otherwise auto-selecting StaticCache still hits the dynamic-shape _sample path)
Proposed change
In `_prepare_cache_for_generation` (around line 1879 of `src/transformers/generation/utils.py`), when no cache is specified:

```python
if generation_config.cache_implementation is None and self.device.type == "neuron":
    generation_config.cache_implementation = "static"
```
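In isolation, the proposed rule behaves as sketched below. This is a standalone illustration, not transformers code: `select_cache_implementation` is a hypothetical name, and `SimpleNamespace` stands in for a `GenerationConfig`.

```python
from types import SimpleNamespace

def select_cache_implementation(generation_config, device_type):
    # Proposed rule: default to the static cache on Neuron, but only when
    # the user has not chosen a cache implementation themselves.
    if generation_config.cache_implementation is None and device_type == "neuron":
        generation_config.cache_implementation = "static"
    return generation_config.cache_implementation

cfg = SimpleNamespace(cache_implementation=None)
assert select_cache_implementation(cfg, "neuron") == "static"   # auto-selected

cfg = SimpleNamespace(cache_implementation=None)
assert select_cache_implementation(cfg, "cuda") is None         # unchanged on CUDA

cfg = SimpleNamespace(cache_implementation="dynamic")
assert select_cache_implementation(cfg, "neuron") == "dynamic"  # explicit choice wins
```

The rule only fills in a default; it never overrides an explicit `cache_implementation`, so existing call sites are unaffected.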
Precedent: the generation code already has device-aware logic:
- Line 2021: `self.device.type in ["cuda", "xpu"]` as part of the compile criteria
- Line 1870: forces `"dynamic_full"` for assisted generation
Device-aware cache selection fits the existing pattern.
Expected behavior
```python
# On Neuron — auto-selects StaticCache, uses the static-shape _sample path:
model.generate(input_ids)

# On CUDA — unchanged, uses DynamicCache + the standard _sample path:
model.generate(input_ids)

# Explicit override still works on any device:
model.generate(input_ids, cache_implementation="static")
model.generate(input_ids, past_key_values=my_static_cache)
```
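The expected behavior above could be guarded by tests along these lines. This is a sketch only: `FakeModel` is a local stand-in carrying just a `device` attribute and the proposed rule; a real test would patch `device` on an actual model and call the full `_prepare_cache_for_generation`.

```python
from types import SimpleNamespace

class FakeModel:
    """Stand-in exposing only what the proposed rule touches."""
    def __init__(self, device_type):
        self.device = SimpleNamespace(type=device_type)

    def _prepare_cache_for_generation(self, generation_config):
        # Only the proposed auto-selection rule; the real method in
        # transformers does much more (cache construction, validation, ...).
        if generation_config.cache_implementation is None and self.device.type == "neuron":
            generation_config.cache_implementation = "static"

def test_neuron_defaults_to_static():
    cfg = SimpleNamespace(cache_implementation=None)
    FakeModel("neuron")._prepare_cache_for_generation(cfg)
    assert cfg.cache_implementation == "static"

def test_cuda_is_unchanged():
    cfg = SimpleNamespace(cache_implementation=None)
    FakeModel("cuda")._prepare_cache_for_generation(cfg)
    assert cfg.cache_implementation is None

def test_explicit_choice_wins():
    cfg = SimpleNamespace(cache_implementation="dynamic")
    FakeModel("neuron")._prepare_cache_for_generation(cfg)
    assert cfg.cache_implementation == "dynamic"

# Run the checks when executed as a script.
test_neuron_defaults_to_static()
test_cuda_is_unchanged()
test_explicit_choice_wins()
```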
Related