
Commit 42d3f63

[OpenVINO] Connect beam_idx input to Linear Attention Layers (CausalConv, SSM, GDN) (huggingface#1619)
* [OpenVINO] Support beam search for Linear Attention Layers (CausalConv, SSM, GDN)
* Fix formatting issue
* Remove unneeded import
* Apply correct patching for attention mask
* Deprecate arguments in fuse in a correct way
* Apply suggestions from code review
* Correct elements in not_ov_cache_inputs
* Check that cache input names provided
* Fix documentation for fuse_cache_reorder
* Apply suggestions from code review
* Apply suggestion from @rkazants
* Apply suggestions from code review
* Apply suggestions from code review
* Update optimum/exporters/openvino/stateful.py
* Apply suggestion from @rkazants
* Apply suggestion from @rkazants
* Apply suggestion from @rkazants
* Fix internal function get_kv_ssm_tensor_names

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
1 parent d2b21bf commit 42d3f63

2 files changed

Lines changed: 17 additions & 26 deletions
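In practical terms, this change makes beam search work with stateful OpenVINO exports of hybrid SSM models: when beam search permutes hypotheses, the fused beam_idx input now reorders the SSM and conv cache states along with the KV cache. A minimal usage sketch with optimum-intel follows; the checkpoint name is illustrative and any GraniteMoeHybrid-style hybrid model would do.

# Beam search against a stateful OpenVINO export of a hybrid SSM model.
# The checkpoint below is an assumption; substitute any hybrid SSM/attention model.
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "ibm-granite/granite-4.0-tiny-preview"  # assumed hybrid checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("The capital of France is", return_tensors="pt")
# num_beams > 1 triggers cache reordering each step; the fused beam_idx input
# now applies it to KV, SSM, and conv states alike inside the stateful model
outputs = model.generate(**inputs, num_beams=4, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))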

File tree

optimum/exporters/openvino/model_patcher.py

Lines changed: 1 addition & 1 deletion
@@ -7353,7 +7353,7 @@ def granite_moe_hybrid_update_causal_mask(
     return causal_mask
 
 
-class GraniteMoeHybridModelPatcher(ModelPatcher):
+class GraniteMoeHybridModelPatcher(OVDecoderModelPatcher):
     def __init__(
         self,
         config: "OnnxConfig",

optimum/exporters/openvino/stateful.py

Lines changed: 16 additions & 25 deletions
@@ -53,22 +53,23 @@ def fuse_cache_reorder(
     gather_dim: int,
 ):
     """
-    Fuses reored_cache during generate cycle into ov.Model. Used with stateful models, because we can not modify model state directly.
+    Fuses reordered_cache during generate cycle into ov.Model.
+    Used with stateful models, because we can not modify model state directly.
 
-    Adds a new beam_idx parameter and Gather op per each kv-cache input in a given model.
-    Should be run before make_stateful. Implements optimumum's _reorder_cache
+    Adds a new beam_idx parameter and Gather op per each cache input in a given model.
+    Should be run before make_stateful. Implements optimum's _reorder_cache
     inside the model in the beginning of each iteration.
     Gather works along given gather_dim dimension that may vary from model to model.
-    KV-cache inputs are identified based on names in key_value_input_names.
-    Append the new beam_idx parameter to not_kv_inputs.
+    Inputs with cache states (with key, value, and fixed-sized cache) are identified based on names in `key_value_input_names`.
+    Append the new beam_idx parameter to `not_kv_inputs`.
 
     Parameters:
     ov_model (`ov.Model`):
         openvino model for processing
     not_kv_inputs (`List[str]`):
-        list of input nodes in model that not related to past key values
+        list of input nodes in model that not related to cache states
     key_value_input_names (`List[str]`):
-        list of names for key value input layers
+        list of names for input layers with key, value, and fixed-sized cache states
     gather_dim (int):
         dimension for gathering cache during reorder pass
     """
@@ -262,13 +263,11 @@ def insert_state_for_nodes(model: ov.Model, nodes):
 
 
 def patch_stateful_hybrid_ssm(ov_model: ov.Model):
-    from openvino._offline_transformations import apply_make_stateful_transformation
-
     def get_kv_ssm_tensor_names(ssm_prefix_names: list, kv_prefix_names: list, ov_tensors):
         # return tensor names of model inputs/outputs tensors with KV and SSM states
         kv_names = []
         ssm_names = []
-        other_names = []
+        other_tensors = []
         for ov_tensor in ov_tensors:
             ov_tensor_names = ov_tensor.get_names()
             is_kv_or_ssm = False
@@ -282,36 +281,28 @@ def get_kv_ssm_tensor_names(ssm_prefix_names: list, kv_prefix_names: list, ov_te
                         is_kv_or_ssm = True
                         break
             if not is_kv_or_ssm:
-                other_names.append(ov_tensor_name)
-        return kv_names, ssm_names, other_names
+                other_tensors.append(ov_tensor)
+        return kv_names, ssm_names, other_tensors
 
     ssm_prefix_input_names = ["cache_params.past.ssm", "cache_params.past.conv"]
     kv_prefix_input_names = ["cache_params.past.key", "cache_params.past.value"]
-    kv_input_names, ssm_input_names, other_input_names = get_kv_ssm_tensor_names(
+    kv_input_names, ssm_input_names, not_cache_inputs = get_kv_ssm_tensor_names(
         ssm_prefix_input_names, kv_prefix_input_names, ov_model.inputs
     )
-    not_kv_inputs = ssm_input_names + other_input_names
+    cache_inputs = kv_input_names + ssm_input_names
 
     ssm_prefix_output_names = ["cache_params.present.ssm", "cache_params.present.conv"]
     kv_prefix_output_names = ["cache_params.present.key", "cache_params.present.value"]
     kv_output_names, ssm_output_names, _ = get_kv_ssm_tensor_names(
         ssm_prefix_output_names, kv_prefix_output_names, ov_model.outputs
     )
+    cache_outputs = kv_output_names + ssm_output_names
 
     # hybrid models can contain transformer blocks as well
     # so KV tensors must be handled properly
     batch_dim = 0
-    if kv_input_names is not None and len(kv_input_names) > 0:
-        fuse_cache_reorder(ov_model, not_kv_inputs, kv_input_names, batch_dim)
-        make_stateful(ov_model, not_kv_inputs, kv_input_names, kv_output_names, batch_dim)
-
-    # create states for SSM cache
-    input_output_map = {}
-    for cache_name_pair in zip(ssm_input_names, ssm_output_names):
-        input_output_map[cache_name_pair[0]] = cache_name_pair[1]
-
-    apply_make_stateful_transformation(ov_model, input_output_map)
-    build_state_initializer(ov_model, batch_dim)
+    fuse_cache_reorder(ov_model, not_cache_inputs, cache_inputs, batch_dim)
+    make_stateful(ov_model, not_cache_inputs, cache_inputs, cache_outputs, batch_dim)
 
 
 def patch_stateful(config: PretrainedConfig, ov_model: ov.Model):
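With the deleted branch gone, KV, SSM, and conv states flow through a single path: fuse_cache_reorder wires beam_idx to every cache input, and make_stateful converts each past/present cache pair into an internal OpenVINO variable. A rough sketch of that second step, assuming positional pairing of input and output names as in the removed SSM-specific code:

from openvino._offline_transformations import apply_make_stateful_transformation
from optimum.exporters.openvino.stateful import build_state_initializer

def make_stateful_sketch(ov_model, cache_input_names, cache_output_names, batch_dim):
    # pair each past-state input with its present-state output, positionally
    input_output_map = dict(zip(cache_input_names, cache_output_names))
    # fold the paired inputs/outputs into internal ReadValue/Assign state ops
    apply_make_stateful_transformation(ov_model, input_output_map)
    # fresh states start empty, so add an initializer keyed on the batch dimension
    build_state_initializer(ov_model, batch_dim)

Unifying the two paths means the SSM and conv states get the same beam_idx Gather that KV caches already had, which is what fixes beam search for the linear attention layers named in the title.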
