106 changes: 98 additions & 8 deletions olive/passes/onnx/conversion.py
@@ -57,6 +57,83 @@ def forward(self, *input_data, **input_dict):
return self.model(*input_data, **input_dict)


def _patch_dynamic_layer_for_export():
"""Patch DynamicLayer.lazy_initialization for torch.export compatibility (transformers >= 5.0).

The original implementation uses torch.tensor([]), which creates a 1D empty tensor (shape [0]).
torch.export needs consistent tensor ranks, so we use torch.narrow + torch.empty_like
to preserve the full shape (e.g. [batch, heads, 0, head_dim]).
"""
from transformers.cache_utils import DynamicLayer

if not hasattr(DynamicLayer, "lazy_initialization"):
return

def patched_lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor = None):
self.dtype, self.device = key_states.dtype, key_states.device
like = torch.narrow(key_states, dim=-2, start=0, length=0)
if hasattr(key_states, "fake_mode"):
with key_states.fake_mode:
self.keys = torch.empty_like(like, dtype=self.dtype, device=self.device)
self.values = torch.empty_like(like, dtype=self.dtype, device=self.device)
else:
self.keys = torch.empty_like(like, dtype=self.dtype, device=self.device)
self.values = torch.empty_like(like, dtype=self.dtype, device=self.device)
self.is_initialized = True

DynamicLayer.lazy_initialization = patched_lazy_initialization
logger.debug("Patched DynamicLayer.lazy_initialization for torch.export compatibility.")


def _convert_past_key_values_to_dynamic_cache(dummy_kwargs: dict, config=None) -> dict:
"""Convert legacy list-format past_key_values to DynamicCache (transformers >= 5.0).

Transformers 5.0 models expect DynamicCache objects, not lists of (key, value) tensors.
When config is provided, the DynamicCache will create correct layer types (e.g.
DynamicSlidingWindowLayer for models using sliding window attention).
"""
pkv = dummy_kwargs.get("past_key_values")
if pkv is None or not isinstance(pkv, (list, tuple)):
return dummy_kwargs

# Check if it's legacy format: list of [key, value] pairs (each with exactly 2 elements)
if not pkv or not isinstance(pkv[0], (list, tuple)) or len(pkv[0]) != 2:
return dummy_kwargs

from transformers.cache_utils import DynamicCache

dc = DynamicCache(config=config)
for layer_idx, kv in enumerate(pkv):
dc.update(kv[0], kv[1], layer_idx=layer_idx)
dummy_kwargs["past_key_values"] = dc
logger.debug("Converted past_key_values from legacy list format to DynamicCache.")
return dummy_kwargs
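
A usage illustration for the helper above (a sketch assuming transformers >= 5.0; tensor sizes are made up):

    legacy_pkv = [
        [torch.zeros(1, 4, 8, 16), torch.zeros(1, 4, 8, 16)],  # layer 0: (key, value)
        [torch.zeros(1, 4, 8, 16), torch.zeros(1, 4, 8, 16)],  # layer 1: (key, value)
    ]
    kwargs = _convert_past_key_values_to_dynamic_cache({"past_key_values": legacy_pkv})
    # kwargs["past_key_values"] is now a transformers DynamicCache with two layers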


def _convert_dynamic_shapes_for_dynamic_cache(dynamic_shapes: dict) -> dict:
"""Convert dynamic_shapes for past_key_values from nested list to DynamicCache pytree format.

The old format is: [[key_shape, val_shape], ...] (one pair per layer)
The DynamicCache pytree layout groups the shapes by kind: [[key0, key1, ...], [val0, val1, ...]],
matching the flattened order from register_dynamic_cache_export_support().
"""
pkv_shapes = dynamic_shapes.get("past_key_values")
if pkv_shapes is None or not isinstance(pkv_shapes, (list, tuple)):
return dynamic_shapes

if not pkv_shapes or not isinstance(pkv_shapes[0], (list, tuple)) or len(pkv_shapes[0]) != 2:
return dynamic_shapes

# Convert [[key0, val0], [key1, val1], ...] -> [[key0, key1, ...], [val0, val1, ...]]
# matching DynamicCache pytree: _dict_flatten({"key_cache": [...], "value_cache": [...]})
dynamic_shapes["past_key_values"] = [
[layer[0] for layer in pkv_shapes],
[layer[1] for layer in pkv_shapes],
]
logger.debug("Converted dynamic_shapes for past_key_values to DynamicCache pytree format.")
return dynamic_shapes
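
And a sketch of the shape conversion (dimension names here are illustrative placeholders, not the actual io_config values):

    per_layer = {
        "past_key_values": [
            [{0: "batch", 2: "seq"}, {0: "batch", 2: "seq"}],  # layer 0: [key_shape, value_shape]
            [{0: "batch", 2: "seq"}, {0: "batch", 2: "seq"}],  # layer 1: [key_shape, value_shape]
        ]
    }
    converted = _convert_dynamic_shapes_for_dynamic_cache(per_layer)
    # converted["past_key_values"] == [[key0, key1], [val0, val1]]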


def _patch_model_if_necessary(pytorch_model: torch.nn.Module):
if not isinstance(pytorch_model, PreTrainedModel):
return
@@ -179,9 +256,6 @@ def _export_pytorch_model(
if torch_dtype:
pytorch_model = pytorch_model.to(torch_dtype)

# Apply any necessary patches
_patch_model_if_necessary(pytorch_model)

# get input and output names, and dynamic axes
assert io_config is not None, "Cannot get io_config for the model."
io_config = validate_config(io_config, IoConfig)
Expand Down Expand Up @@ -212,24 +286,40 @@ def _export_pytorch_model(
"Please upgrade PyTorch to 2.6.0 or above."
)

# Register DynamicCache export support
from transformers.integrations.executorch import register_dynamic_cache_export_support

register_dynamic_cache_export_support()

if isinstance(dummy_inputs, dict):
dummy_kwargs = dummy_inputs
dummy_inputs = ()
else:
dummy_kwargs = {}
dummy_inputs = tuple(dummy_inputs)

# Apply patches for DynamicCache / past_key_values compatibility
if version.parse(transformers.__version__) >= version.parse("5.0"):
# transformers >= 5.0: DynamicCache refactored to use DynamicLayer
from transformers.integrations.executorch import register_dynamic_cache_export_support
Review thread:

Contributor:
register_dynamic_cache_export_support does not have the right ordering for KV caches. I would just use the same patch regardless of the transformers version.

Collaborator (author):
Do you mean _patch_model_if_necessary? transformers 5.0 updated DynamicCache and I got the error "AttributeError: 'DynamicCache' object has no attribute 'to_legacy_cache'".

Contributor (@justinchuby, Feb 10, 2026):
We can update the patch code (_patch_model_if_necessary) so that it works universally. There is no need to call to_legacy_cache. The executorch integration is not reliable for our usages.

Contributor (@justinchuby, Feb 10, 2026):
@titaiwangms, suggestions on what the patch logic should be?

Contributor:
So can we get rid of register_dynamic_cache_export_support now?

register_dynamic_cache_export_support()
_patch_dynamic_layer_for_export()
model_config = getattr(pytorch_model, "config", None)
dummy_kwargs = _convert_past_key_values_to_dynamic_cache(dummy_kwargs, config=model_config)
if io_config.dynamic_shapes:
io_config.dynamic_shapes = _convert_dynamic_shapes_for_dynamic_cache(io_config.dynamic_shapes)
else:
# transformers < 5.0: patch forward to convert list <-> DynamicCache
_patch_model_if_necessary(pytorch_model)

# NOTE: Usually validation is done in io_config.py, but because
# dynamic_shapes has nested complexity, and it can't be validated multiple
# times like others, we validate it here.
io_config.dynamic_shapes, dummy_inputs, dummy_kwargs = _validate_dynamic_shapes(
io_config.dynamic_shapes, dummy_inputs, dummy_kwargs, pytorch_model
)
# torch.export requires strict type match between inputs and dynamic_shapes;
# _validate_dynamic_shapes may return OrderedDict, so convert back to plain dict
if isinstance(io_config.dynamic_shapes, collections.OrderedDict):
io_config.dynamic_shapes = dict(io_config.dynamic_shapes)
if isinstance(dummy_kwargs, collections.OrderedDict):
dummy_kwargs = dict(dummy_kwargs)

# When dynamo=True, PyTorch prefers dynamic_shapes over dynamic_axes.
# If dynamic_shapes is None and fallback is enabled, don't pass dynamic_axes
5 changes: 5 additions & 0 deletions olive/passes/pytorch/train_utils.py
@@ -79,6 +79,11 @@ def create_training_args(self) -> transformers.TrainingArguments:
if version.parse(transformers_version) < version.parse("4.41") and "eval_strategy" in args:
args["evaluation_strategy"] = args.pop("eval_strategy")
extra_args = args.pop("extra_args")
# Filter out fields that are not valid TrainingArguments parameters (e.g. overwrite_output_dir
# was removed in transformers 5.0 but is still used by Olive's own logic) and None values
# so that transformers uses its own defaults
training_args_fields = {f.name for f in dataclasses.fields(transformers.TrainingArguments) if f.init}
args = {k: v for k, v in args.items() if k in training_args_fields and v is not None}
return transformers.TrainingArguments(**args, **extra_args)
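
As an illustration of the filtering above (a sketch, not part of the diff; the exact field set depends on the installed transformers version):

    import dataclasses
    import transformers

    raw = {"output_dir": "out", "overwrite_output_dir": True, "learning_rate": None}
    valid = {f.name for f in dataclasses.fields(transformers.TrainingArguments) if f.init}
    # overwrite_output_dir is not a TrainingArguments field on transformers >= 5.0,
    # and None values are dropped so transformers falls back to its own defaults.
    filtered = {k: v for k, v in raw.items() if k in valid and v is not None}
    # filtered -> {"output_dir": "out"}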


30 changes: 11 additions & 19 deletions test/model/test_hf_model.py
@@ -27,23 +27,16 @@ def setup(self):
self.local_path = huggingface_hub.snapshot_download(self.model_name, revision=self.revision)

@pytest.mark.parametrize("local", [True, False])
@pytest.mark.parametrize("trust_remote_code", [True, False])
def test_load_model(self, local, trust_remote_code):
def test_load_model(self, local):
olive_model = HfModelHandler(
model_path=self.local_path if local else self.model_name,
task=self.task,
load_kwargs={"trust_remote_code": trust_remote_code, "revision": self.revision},
load_kwargs={"revision": self.revision},
)

pytorch_model = olive_model.load_model()
actual_class_path = f"{pytorch_model.__module__}.{pytorch_model.__class__.__name__}"
if trust_remote_code:
# When using remote code, the model is loaded from transformers_modules
assert actual_class_path.startswith("transformers_modules.")
assert actual_class_path.endswith(".modeling_phi3.Phi3ForCausalLM")
else:
# When not using remote code, the model is loaded from transformers
assert actual_class_path == "transformers.models.phi3.modeling_phi3.Phi3ForCausalLM"
assert actual_class_path == "transformers.models.phi3.modeling_phi3.Phi3ForCausalLM"

@pytest.mark.parametrize("local", [True, False])
def test_load_model_with_kwargs(self, local):
@@ -73,19 +66,18 @@ def test_save_metadata(self, local, trust_remote_code, tokenizer_exists, tmp_pat
if tokenizer_exists:
olive_model.get_hf_tokenizer().save_pretrained(tmp_path)
saved_filepaths = olive_model.save_metadata(tmp_path)
# transformers>=4.53.x
assert len(saved_filepaths) == (4 if tokenizer_exists else 10)
# transformers>=5.0.0
assert len(saved_filepaths) == (4 if tokenizer_exists else 7)
assert all(Path(fp).exists() for fp in saved_filepaths)
assert isinstance(transformers.AutoConfig.from_pretrained(tmp_path), transformers.Phi3Config)
assert isinstance(transformers.AutoTokenizer.from_pretrained(tmp_path), transformers.LlamaTokenizerFast)
assert isinstance(transformers.AutoTokenizer.from_pretrained(tmp_path), transformers.PreTrainedTokenizerBase)

@pytest.mark.parametrize("local", [True, False])
@pytest.mark.parametrize("trust_remote_code", [True, False])
def test_save_pretrained_metadata(self, local, trust_remote_code, tmp_path):
def test_save_pretrained_metadata(self, local, tmp_path):
olive_model = HfModelHandler(
model_path=self.local_path if local else self.model_name,
task=self.task,
load_kwargs={"trust_remote_code": trust_remote_code, "revision": self.revision},
load_kwargs={"revision": self.revision},
)

# modify the config and save the model
@@ -94,8 +86,8 @@ def test_save_pretrained_metadata(self, local, trust_remote_code, tmp_path):
loaded_model.save_pretrained(tmp_path)

saved_filepaths = olive_model.save_metadata(tmp_path)
# generation config is also saved, transformers>=4.53.x
assert len(saved_filepaths) == 9
# generation config is also saved, transformers>=5.0.0
assert len(saved_filepaths) == 6

with open(tmp_path / "config.json") as f:
config = json.load(f)
@@ -126,7 +118,7 @@ def test_save_metadata_with_module_files(trust_remote_code, tmp_path):
assert f"{config.__module__}.{config.__class__.__name__}" == expected_class_name
assert isinstance(
transformers.AutoTokenizer.from_pretrained(tmp_path, **load_kwargs),
transformers.LlamaTokenizerFast,
transformers.PreTrainedTokenizerBase,
)


3 changes: 2 additions & 1 deletion test/passes/pytorch/test_rotate.py
@@ -35,7 +35,8 @@ def common_test_rotate(rotate_pass, tmp_path, model_path, rotate_mode, atol, **c
with torch.no_grad():
original_output = original_model(i)
rotated_output = rotated_model(i)
assert torch.allclose(original_output.logits, rotated_output.logits, atol=atol)
# Cast to same dtype before comparison since rotated model may be saved/loaded in a different dtype
assert torch.allclose(original_output.logits.float(), rotated_output.logits.float(), atol=atol)


@pytest.mark.parametrize("model_path", ["tiny-phi3", "tiny-llama"])
2 changes: 0 additions & 2 deletions test/requirements-test.txt
@@ -37,5 +37,3 @@ sentencepiece
soundfile
tabulate
torchvision
# Remove version pin when the tests are fixed
Review thread:

Member:
Should you add transformers>=5, unless there are some tests checking multiple versions of transformers?

Collaborator (author):
Instead of transformers>=5, I want to support both cases by adding a transformers version check.

transformers<5.0.0