Skip to content

Commit ba5d31b

Browse files
committed
Add DynamicSlidingWindowLayer support
1 parent 524ca2a commit ba5d31b

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

olive/passes/onnx/conversion.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,12 @@ def patched_lazy_initialization(self, key_states: torch.Tensor, value_states: to
8585
logger.debug("Patched DynamicLayer.lazy_initialization for torch.export compatibility.")
8686

8787

88-
def _convert_past_key_values_to_dynamic_cache(dummy_kwargs: dict) -> dict:
88+
def _convert_past_key_values_to_dynamic_cache(dummy_kwargs: dict, config=None) -> dict:
8989
"""Convert legacy list-format past_key_values to DynamicCache (transformers >= 5.0).
9090
9191
Transformers 5.0 models expect DynamicCache objects, not lists of (key, value) tensors.
92+
When config is provided, the DynamicCache will create correct layer types (e.g.
93+
DynamicSlidingWindowLayer for models using sliding window attention).
9294
"""
9395
pkv = dummy_kwargs.get("past_key_values")
9496
if pkv is None or not isinstance(pkv, (list, tuple)):
@@ -100,7 +102,7 @@ def _convert_past_key_values_to_dynamic_cache(dummy_kwargs: dict) -> dict:
100102

101103
from transformers.cache_utils import DynamicCache
102104

103-
dc = DynamicCache()
105+
dc = DynamicCache(config=config)
104106
for layer_idx, kv in enumerate(pkv):
105107
dc.update(kv[0], kv[1], layer_idx=layer_idx)
106108
dummy_kwargs["past_key_values"] = dc
@@ -298,7 +300,8 @@ def _export_pytorch_model(
298300

299301
register_dynamic_cache_export_support()
300302
_patch_dynamic_layer_for_export()
301-
dummy_kwargs = _convert_past_key_values_to_dynamic_cache(dummy_kwargs)
303+
model_config = getattr(pytorch_model, "config", None)
304+
dummy_kwargs = _convert_past_key_values_to_dynamic_cache(dummy_kwargs, config=model_config)
302305
if io_config.dynamic_shapes:
303306
io_config.dynamic_shapes = _convert_dynamic_shapes_for_dynamic_cache(io_config.dynamic_shapes)
304307
else:

0 commit comments

Comments
 (0)