Skip to content

Commit 217399d

Browse files
fix bug
1 parent 6aeb1d8 commit 217399d

File tree

5 files changed

+90
-10
lines changed

5 files changed

+90
-10
lines changed

optimum/exporters/openvino/convert.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -433,20 +433,39 @@ def ts_patched_forward(*args, **kwargs):
433433
if patch_16bit_model:
434434
from openvino.frontend.pytorch.patch_model import __make_16bit_traceable
435435

436+
import psutil
437+
proc = psutil.Process()
438+
print(f"[DEBUG] Before __make_16bit_traceable: RSS={proc.memory_info().rss / 1e9:.1f} GB", flush=True)
436439
__make_16bit_traceable(model)
437-
440+
print(f"[DEBUG] After __make_16bit_traceable: RSS={proc.memory_info().rss / 1e9:.1f} GB", flush=True)
441+
442+
# Allow patcher to free duplicated memory after 16-bit tracing setup.
443+
# __make_16bit_traceable calls module.float() on non-Linear/Embedding modules,
444+
# creating fp32 copies of parameters already captured as bf16 views by the patcher.
445+
# This hook frees those fp32 duplicates to avoid OOM on large MoE models.
446+
if hasattr(patcher, "post_make_16bit_traceable"):
447+
print(f"[DEBUG] Calling post_make_16bit_traceable hook", flush=True)
448+
patcher.post_make_16bit_traceable()
449+
print(f"[DEBUG] After post_make_16bit_traceable: RSS={proc.memory_info().rss / 1e9:.1f} GB", flush=True)
450+
else:
451+
print(f"[DEBUG] No post_make_16bit_traceable hook found on patcher={type(patcher).__name__}", flush=True)
438452
conversion_extensions = getattr(patcher, "conversion_extensions", [])
439453
module_extensions = getattr(patcher, "module_extensions", None)
440454
if module_extensions is not None:
441455
ts_decoder_kwargs["module_extensions"] = module_extensions
442456

443-
ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
457+
example_input = dummy_inputs
458+
459+
print(f"[DEBUG] Before TorchScriptPythonDecoder: RSS={proc.memory_info().rss / 1e9:.1f} GB" if 'proc' in dir() else "[DEBUG] Before TorchScriptPythonDecoder", flush=True)
460+
ts_decoder = TorchScriptPythonDecoder(model, example_input=example_input, **ts_decoder_kwargs)
461+
print(f"[DEBUG] After TorchScriptPythonDecoder: RSS={proc.memory_info().rss / 1e9:.1f} GB" if 'proc' in dir() else "[DEBUG] After TorchScriptPythonDecoder", flush=True)
444462
ov_model = convert_model(
445463
ts_decoder,
446-
example_input=dummy_inputs,
464+
example_input=example_input,
447465
input=[(item.shape, item.type) for item in input_info],
448466
extension=conversion_extensions,
449467
)
468+
print(f"[DEBUG] After convert_model: RSS={proc.memory_info().rss / 1e9:.1f} GB" if 'proc' in dir() else "[DEBUG] After convert_model", flush=True)
450469

451470
ov_model.validate_nodes_and_infer_types() # TODO: remove as unnecessary validation?
452471

optimum/exporters/openvino/model_configs.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1724,6 +1724,19 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
17241724
dummy_inputs["token_type_ids"] = self.orig_export_config.DUMMY_INPUT_GENERATOR_CLASSES[
17251725
0
17261726
].random_int_tensor(token_type_ids_shape, min_value=0, max_value=2)
1727+
1728+
# Generate dummy inputs for any extra entries from input_info_upd (e.g., position_ids)
1729+
if self.input_info_upd:
1730+
dummy_inputs_generators = self.orig_export_config._create_dummy_input_generator_classes(**kwargs)
1731+
for input_name in self.input_info_upd:
1732+
if input_name not in dummy_inputs:
1733+
for dummy_input_gen in dummy_inputs_generators:
1734+
if dummy_input_gen.supports_input(input_name):
1735+
dummy_inputs[input_name] = self.orig_export_config.overwrite_shape_and_generate_input(
1736+
dummy_input_gen, input_name, framework, input_shapes=kwargs,
1737+
)
1738+
break
1739+
17271740
return dummy_inputs
17281741

17291742

@@ -5510,6 +5523,8 @@ def __init__(
55105523
self.batch_size = batch_size
55115524
self.normalized_config = normalized_config
55125525
self.hidden_size = self.normalized_config.hidden_size
5526+
self.num_key_value_heads = self.normalized_config.num_key_value_heads
5527+
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.normalized_config.num_attention_heads)
55135528
self.linear_key_head_dim = config.linear_key_head_dim
55145529
self.linear_value_head_dim = config.linear_value_head_dim
55155530
self.linear_num_key_heads = config.linear_num_key_heads
@@ -5542,7 +5557,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
55425557
self.batch_size,
55435558
self.num_key_value_heads,
55445559
self.sequence_length,
5545-
self.hidden_size // self.num_attention_heads,
5560+
self.head_dim,
55465561
)
55475562
k = self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype)
55485563
v = self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype)

optimum/exporters/openvino/model_patcher.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8590,12 +8590,12 @@ def patched_qwen3_5_moe_sparse_moe_block(self, hidden_states: torch.Tensor) -> t
85908590
hidden_expanded = hidden_states_reshaped.unsqueeze(0).expand(num_experts, -1, -1)
85918591

85928592
# Vectorized expert computation using pre-transposed weights
8593-
gate_up = torch.bmm(hidden_expanded, self._gate_up_projs_t)
8593+
gate_up = torch.bmm(hidden_expanded, self._gate_up_projs_t.to(hidden_expanded.dtype))
85948594
intermediate_size = self.experts.intermediate_dim
85958595
gate = gate_up[:, :, :intermediate_size]
85968596
up = gate_up[:, :, intermediate_size:]
85978597
activated = self.experts.act_fn(gate) * up
8598-
next_states = torch.bmm(activated, self._down_projs_t)
8598+
next_states = torch.bmm(activated, self._down_projs_t.to(activated.dtype))
85998599

86008600
# Weight by routing and sum over experts
86018601
next_states = next_states * new_routing_weights.T.unsqueeze(-1)
@@ -8914,6 +8914,26 @@ def __enter__(self):
89148914
patched_qwen3_5_moe_sparse_moe_block, sparse_moe_block
89158915
)
89168916

8917+
def post_make_16bit_traceable(self):
    """Release redundant fp32 expert weights created by __make_16bit_traceable.

    __make_16bit_traceable upcasts Qwen3_5MoeExperts modules via module.float(),
    which duplicates gate_up_proj and down_proj as fp32 copies. The patched
    forward already runs off the bf16 views captured earlier (_gate_up_projs_t /
    _down_projs_t), so those fp32 duplicates only inflate resident memory.
    Replace them with empty tensors and collect, so large MoE exports do not OOM.
    """
    import gc

    import torch
    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeSparseMoeBlock

    for layer in self._model.model.layers:
        mlp = layer.mlp
        if not isinstance(mlp, Qwen3_5MoeSparseMoeBlock):
            continue
        # Detach the parameter storage; the patched forward never reads these.
        mlp.experts.gate_up_proj.data = torch.empty(0)
        mlp.experts.down_proj.data = torch.empty(0)
    gc.collect()
8936+
89178937
def __exit__(self, exc_type, exc_value, traceback):
89188938
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeSparseMoeBlock
89198939

@@ -9025,6 +9045,8 @@ def has_previous_state(self):
90259045
layer_idx = self.linear_attn_mapping[self.last_linear_layer]
90269046
return self.conv_states[layer_idx] is not None
90279047

9048+
_lm_head_weight = model.lm_head.weight
9049+
90289050
def patched_forward(
90299051
inputs_embeds,
90309052
attention_mask=None,
@@ -9063,7 +9085,7 @@ def patched_forward(
90639085
use_cache=use_cache,
90649086
)
90659087
hidden_states = outputs[0]
9066-
logits = model.lm_head(hidden_states)
9088+
logits = torch.nn.functional.linear(hidden_states, _lm_head_weight.to(hidden_states.dtype))
90679089

90689090
result = {"logits": logits}
90699091

@@ -9178,6 +9200,8 @@ def has_previous_state(self):
91789200
layer_idx = self.linear_attn_mapping[self.last_linear_layer]
91799201
return self.conv_states[layer_idx] is not None
91809202

9203+
_lm_head_weight = model.lm_head.weight
9204+
91819205
def patched_forward(
91829206
inputs_embeds,
91839207
attention_mask=None,
@@ -9216,7 +9240,7 @@ def patched_forward(
92169240
use_cache=use_cache,
92179241
)
92189242
hidden_states = outputs[0]
9219-
logits = model.lm_head(hidden_states)
9243+
logits = torch.nn.functional.linear(hidden_states, _lm_head_weight.to(hidden_states.dtype))
92209244

92219245
result = {"logits": logits}
92229246

@@ -9271,6 +9295,26 @@ def __enter__(self):
92719295
patched_qwen3_5_moe_sparse_moe_block, sparse_moe_block
92729296
)
92739297

9298+
def post_make_16bit_traceable(self):
    """Release redundant fp32 expert weights created by __make_16bit_traceable.

    __make_16bit_traceable upcasts Qwen3_5MoeExperts modules via module.float(),
    which duplicates gate_up_proj and down_proj as fp32 copies. The patched
    forward already runs off the bf16 views captured earlier (_gate_up_projs_t /
    _down_projs_t), so those fp32 duplicates only inflate resident memory.
    Replace them with empty tensors and collect, so large MoE exports do not OOM.
    """
    import gc

    import torch
    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeSparseMoeBlock

    # Multimodal variant: decoder layers live under model.language_model.
    for layer in self._model.model.language_model.layers:
        mlp = layer.mlp
        if not isinstance(mlp, Qwen3_5MoeSparseMoeBlock):
            continue
        # Detach the parameter storage; the patched forward never reads these.
        mlp.experts.gate_up_proj.data = torch.empty(0)
        mlp.experts.down_proj.data = torch.empty(0)
    gc.collect()
9317+
92749318
def __exit__(self, exc_type, exc_value, traceback):
92759319
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeSparseMoeBlock
92769320

optimum/exporters/openvino/stateful.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,14 +285,14 @@ def get_kv_ssm_tensor_names(ssm_prefix_names: list, kv_prefix_names: list, ov_te
285285
other_names.append(ov_tensor_name)
286286
return kv_names, ssm_names, other_names
287287

288-
ssm_prefix_input_names = ["cache_params.past.ssm", "cache_params.past.conv"]
288+
ssm_prefix_input_names = ["cache_params.past.ssm", "cache_params.past.conv", "cache_params.past.recurrent"]
289289
kv_prefix_input_names = ["cache_params.past.key", "cache_params.past.value"]
290290
kv_input_names, ssm_input_names, other_input_names = get_kv_ssm_tensor_names(
291291
ssm_prefix_input_names, kv_prefix_input_names, ov_model.inputs
292292
)
293293
not_kv_inputs = ssm_input_names + other_input_names
294294

295-
ssm_prefix_output_names = ["cache_params.present.ssm", "cache_params.present.conv"]
295+
ssm_prefix_output_names = ["cache_params.present.ssm", "cache_params.present.conv", "cache_params.present.recurrent"]
296296
kv_prefix_output_names = ["cache_params.present.key", "cache_params.present.value"]
297297
kv_output_names, ssm_output_names, _ = get_kv_ssm_tensor_names(
298298
ssm_prefix_output_names, kv_prefix_output_names, ov_model.outputs

optimum/exporters/openvino/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,11 @@ def _get_input_info(
9191
) -> List[InputInfo]:
9292
sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
9393
inputs = config.ordered_inputs(model)
94+
9495
ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
9596
if not ordered_dummy_inputs:
9697
ordered_dummy_inputs = dummy_inputs
98+
9799
ordered_input_names = list(inputs)
98100
flatten_inputs = flattenize_inputs(ordered_dummy_inputs.values())
99101
input_info = []

0 commit comments

Comments
 (0)