-
Notifications
You must be signed in to change notification settings - Fork 19.7k
[OpenVINO backend] Support for LSTM #22313
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
5f9c2df
ac9384e
d08e3db
d8b91b8
9683ce2
f4ab9bc
c2e50b1
107f5a8
c6cff76
4e5eee4
8c2fd6a
15d8b3b
7943a2c
5a2f450
261acd8
486ecb6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |||||
| from openvino import Model | ||||||
| from openvino import Type | ||||||
|
|
||||||
| from keras.src import activations | ||||||
| from keras.src import tree | ||||||
| from keras.src.backend.openvino.core import OpenVINOKerasTensor | ||||||
| from keras.src.backend.openvino.core import get_ov_output | ||||||
|
|
@@ -105,7 +106,7 @@ def _slice_at_0(x): | |||||
| x_ov = get_ov_output(x) | ||||||
| slice_0 = ov_opset.gather( | ||||||
| x_ov, | ||||||
| ov_opset.constant([0], dtype=Type.i32).output(0), | ||||||
| ov_opset.constant(0, dtype=Type.i32).output(0), | ||||||
| ov_opset.constant(0, dtype=Type.i32).output(0), | ||||||
| ).output(0) | ||||||
| return OpenVINOKerasTensor(slice_0) | ||||||
|
|
@@ -130,10 +131,10 @@ def _slice_at_0(x): | |||||
| inp_ov = get_ov_output(inp) | ||||||
| pshape = inp_ov.get_partial_shape() | ||||||
| if pshape.rank.is_static: | ||||||
| new_shape = list(pshape)[1:] | ||||||
| new_shape = [1] + list(pshape)[1:] | ||||||
| else: | ||||||
| new_shape = ( | ||||||
| [-1] * (pshape.rank.get_length() - 1) | ||||||
| [-1] * (pshape.rank.get_length()) | ||||||
| if pshape.rank.is_static | ||||||
| else None | ||||||
| ) | ||||||
|
|
@@ -144,7 +145,7 @@ def _slice_at_0(x): | |||||
| if mask is not None: | ||||||
| mask_ov = get_ov_output(mask) | ||||||
| pshape = mask_ov.get_partial_shape() | ||||||
| new_shape = list(pshape)[1:] if pshape.rank.is_static else None | ||||||
| new_shape = [1] + list(pshape)[1:] if pshape.rank.is_static else None | ||||||
| param = ov_opset.parameter(new_shape, mask_ov.get_element_type()) | ||||||
| sliced_mask_params.append(param) | ||||||
| params.append(param) | ||||||
|
|
@@ -172,9 +173,13 @@ def _slice_at_0(x): | |||||
| ) | ||||||
| constants_params.append(param) | ||||||
| params.append(param) | ||||||
| sliced_inputs_t = [ | ||||||
| OpenVINOKerasTensor(p.output(0)) for p in sliced_inputs_params | ||||||
| ] | ||||||
| sliced_inputs_t = [] | ||||||
| for p in sliced_inputs_params: | ||||||
| p_out = p.output(0) | ||||||
| p_out = ov_opset.squeeze( | ||||||
| p_out, ov_opset.constant([0], dtype=Type.i32).output(0) | ||||||
| ).output(0) | ||||||
| sliced_inputs_t.append(OpenVINOKerasTensor(p_out)) | ||||||
| merged_states_t = [ | ||||||
| OpenVINOKerasTensor(p.output(0)) for p in merged_states_params | ||||||
| ] | ||||||
|
|
@@ -192,6 +197,9 @@ def _slice_at_0(x): | |||||
| final_last_output_list = [] | ||||||
| if mask is not None: | ||||||
| mask_t = sliced_mask_params[0].output(0) | ||||||
| mask_t = ov_opset.squeeze( | ||||||
| mask_t, ov_opset.constant([0], dtype=Type.i32).output(0) | ||||||
| ).output(0) | ||||||
| for i, (new_st, old_st) in enumerate( | ||||||
| zip(flat_step_new_states, merged_states_t) | ||||||
| ): | ||||||
|
|
@@ -225,7 +233,15 @@ def _slice_at_0(x): | |||||
| final_states_list = [get_ov_output(x) for x in flat_step_new_states] | ||||||
| final_output_list = [get_ov_output(x) for x in flat_step_output] | ||||||
| final_last_output_list = [get_ov_output(x) for x in flat_step_output] | ||||||
| zero_const = ov_opset.constant([0], dtype=Type.i32).output(0) | ||||||
| final_output_list = [ | ||||||
| ov_opset.unsqueeze(x, zero_const).output(0) for x in final_output_list | ||||||
| ] | ||||||
| cond_const = ov_opset.constant(True, Type.boolean).output(0) | ||||||
| zero_const = ov_opset.constant([0], dtype=Type.i32).output(0) | ||||||
| final_output_list = [ | ||||||
| ov_opset.unsqueeze(x, zero_const).output(0) for x in final_output_list | ||||||
| ] | ||||||
| results = ( | ||||||
| [cond_const] | ||||||
| + final_states_list | ||||||
|
|
@@ -284,8 +300,162 @@ def _slice_at_0(x): | |||||
| return last_output, outputs, new_states | ||||||
|
|
||||||
|
|
||||||
| def lstm(*args, **kwargs): | ||||||
| raise NotImplementedError("`lstm` is not supported with openvino backend") | ||||||
| def lstm( | ||||||
| inputs, | ||||||
| initial_state_h, | ||||||
| initial_state_c, | ||||||
| mask, | ||||||
| kernel, | ||||||
| recurrent_kernel, | ||||||
| bias, | ||||||
| activation, | ||||||
| recurrent_activation, | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||||||
| return_sequences=False, | ||||||
| go_backwards=False, | ||||||
| unroll=False, | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||
| time_major=False, | ||||||
|
Comment on lines
+315
to
+316
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||
| ): | ||||||
| if mask is not None: | ||||||
| raise NotImplementedError("lstm sequence with mask is not supported") | ||||||
| mask = get_ov_output(mask) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||
| inputs = get_ov_output(inputs) | ||||||
| initial_state_h = get_ov_output(initial_state_h) | ||||||
| initial_state_c = get_ov_output(initial_state_c) | ||||||
| kernel = get_ov_output(kernel) | ||||||
| recurrent_kernel = get_ov_output(recurrent_kernel) | ||||||
| shape = ov_opset.shape_of(inputs, Type.i32).output(0) | ||||||
| if not time_major: | ||||||
| batch_size = ov_opset.gather( | ||||||
| shape, | ||||||
| ov_opset.constant([0], Type.i32), | ||||||
| ov_opset.constant(0, Type.i32), | ||||||
| ).output(0) | ||||||
| seq_length = ov_opset.gather( | ||||||
| shape, | ||||||
| ov_opset.constant([1], Type.i32), | ||||||
| ov_opset.constant(0, Type.i32), | ||||||
| ).output(0) | ||||||
| else: | ||||||
| batch_size = ov_opset.gather( | ||||||
| shape, | ||||||
| ov_opset.constant([1], Type.i32), | ||||||
| ov_opset.constant(0, Type.i32), | ||||||
| ).output(0) | ||||||
| seq_length = ov_opset.gather( | ||||||
| shape, | ||||||
| ov_opset.constant([0], Type.i32), | ||||||
| ov_opset.constant(0, Type.i32), | ||||||
| ).output(0) | ||||||
| seq_length_tensor = ov_opset.broadcast(seq_length, batch_size).output(0) | ||||||
| if time_major: | ||||||
| axes = ov_opset.constant([1, 0, 2], Type.i32).output(0) | ||||||
| inputs = ov_opset.transpose(inputs, axes).output(0) | ||||||
| kernel_T = ov_opset.transpose( | ||||||
| kernel, ov_opset.constant([1, 0], Type.i32) | ||||||
| ).output(0) | ||||||
| k_i, k_f, k_c, k_o = ov_opset.split( | ||||||
| kernel_T, ov_opset.constant(0, Type.i32), 4 | ||||||
| ).outputs() | ||||||
| W = ov_opset.concat([k_f, k_i, k_c, k_o], axis=0).output(0) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The concatenation order for weights
Suggested change
|
||||||
| W = ov_opset.unsqueeze(W, ov_opset.constant([0], Type.i32)).output(0) | ||||||
| recurrent_kernel_T = ov_opset.transpose( | ||||||
| recurrent_kernel, ov_opset.constant([1, 0], Type.i32) | ||||||
| ).output(0) | ||||||
| r_i, r_f, r_c, r_o = ov_opset.split( | ||||||
| recurrent_kernel_T, ov_opset.constant(0, Type.i32), 4 | ||||||
| ).outputs() | ||||||
| R = ov_opset.concat([r_f, r_i, r_c, r_o], axis=0).output(0) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the kernel weights
Suggested change
|
||||||
| R = ov_opset.unsqueeze(R, ov_opset.constant([0], Type.i32)).output(0) | ||||||
| if bias is not None: | ||||||
| bias = get_ov_output(bias) | ||||||
| b_i, b_f, b_c, b_o = ov_opset.split( | ||||||
| bias, ov_opset.constant(0, Type.i32), 4 | ||||||
| ).outputs() | ||||||
| B = ov_opset.concat([b_f, b_i, b_c, b_o], axis=0).output(0) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the weights, the concatenation order for biases
Suggested change
|
||||||
| else: | ||||||
| W_shape = ov_opset.shape_of(W, Type.i32).output(0) | ||||||
| B_size = ov_opset.gather( | ||||||
| W_shape, | ||||||
| ov_opset.constant([1], Type.i32), | ||||||
| ov_opset.constant(0, Type.i32), | ||||||
| ).output(0) | ||||||
| B = ov_opset.broadcast( | ||||||
| ov_opset.constant(0, inputs.get_element_type()), B_size | ||||||
| ).output(0) | ||||||
| B = ov_opset.unsqueeze(B, ov_opset.constant([0], Type.i32)).output(0) | ||||||
| initial_state_h = ov_opset.unsqueeze( | ||||||
| initial_state_h, ov_opset.constant([1], Type.i32) | ||||||
| ).output(0) | ||||||
| initial_state_c = ov_opset.unsqueeze( | ||||||
| initial_state_c, ov_opset.constant([1], Type.i32) | ||||||
| ).output(0) | ||||||
| direction = "forward" | ||||||
| if go_backwards: | ||||||
| direction = "reverse" | ||||||
|
|
||||||
| def get_activation_name(act): | ||||||
|
|
||||||
| if act == activations.tanh: | ||||||
| return "tanh" | ||||||
| if act == activations.sigmoid: | ||||||
| return "sigmoid" | ||||||
| if act == activations.relu: | ||||||
| return "relu" | ||||||
| if act == activations.linear: | ||||||
| return "linear" | ||||||
| if hasattr(act, "__name__"): | ||||||
| return act.__name__ | ||||||
| return "tanh" | ||||||
|
Comment on lines
+396
to
+408
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function can be made more concise and maintainable by using a dictionary to map activation functions to their string names. This is a common pattern for this type of logic and would improve readability. For example:
ACTIVATION_MAP = {
activations.tanh: "tanh",
activations.sigmoid: "sigmoid",
activations.relu: "relu",
activations.linear: "linear",
}
def get_activation_name(act):
name = ACTIVATION_MAP.get(act)
if name:
return name
if hasattr(act, "__name__"):
return act.__name__
return "tanh" |
||||||
|
|
||||||
| act_names = [ | ||||||
| get_activation_name(recurrent_activation), | ||||||
| get_activation_name(activation), | ||||||
| get_activation_name(activation), | ||||||
| ] | ||||||
| hidden_size = ( | ||||||
| R.get_partial_shape()[2].get_length() | ||||||
| if R.get_partial_shape().rank.is_static | ||||||
| and R.get_partial_shape()[2].is_static | ||||||
| else -1 | ||||||
| ) | ||||||
| lstm_node = ov_opset.lstm_sequence( | ||||||
| X=inputs, | ||||||
| initial_hidden_state=initial_state_h, | ||||||
| initial_cell_state=initial_state_c, | ||||||
| sequence_lengths=seq_length_tensor, | ||||||
| W=W, | ||||||
| R=R, | ||||||
| B=B, | ||||||
| hidden_size=hidden_size, | ||||||
| direction=direction, | ||||||
| activations=act_names, | ||||||
| ) | ||||||
| Y = lstm_node.output(0) | ||||||
| Ho = lstm_node.output(1) | ||||||
| Co = lstm_node.output(2) | ||||||
| Y = ov_opset.squeeze(Y, ov_opset.constant([1], Type.i32)).output(0) | ||||||
| Ho = ov_opset.squeeze(Ho, ov_opset.constant([1], Type.i32)).output(0) | ||||||
| Co = ov_opset.squeeze(Co, ov_opset.constant([1], Type.i32)).output(0) | ||||||
| if not time_major: | ||||||
| outputs = Y | ||||||
| else: | ||||||
| axes = ov_opset.constant([1, 0, 2], Type.i32).output(0) | ||||||
| outputs = ov_opset.transpose(Y, axes).output(0) | ||||||
| if not return_sequences: | ||||||
| if time_major: | ||||||
| outputs = ov_opset.unsqueeze( | ||||||
| Ho, ov_opset.constant([0], Type.i32) | ||||||
| ).output(0) | ||||||
| else: | ||||||
| outputs = ov_opset.unsqueeze( | ||||||
| Ho, ov_opset.constant([1], Type.i32) | ||||||
| ).output(0) | ||||||
|
Comment on lines
+444
to
+452
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||
| last_output = Ho | ||||||
| return ( | ||||||
| OpenVINOKerasTensor(last_output), | ||||||
| OpenVINOKerasTensor(outputs), | ||||||
| [OpenVINOKerasTensor(Ho), OpenVINOKerasTensor(Co)], | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| def gru(*args, **kwargs): | ||||||
|
|
@@ -305,4 +475,4 @@ def numpy_scan(f, init, xs, reverse=False, mask=None): | |||||
|
|
||||||
|
|
||||||
| def cudnn_ok(*args, **kwargs): | ||||||
| return False | ||||||
| return False | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This block of code, which redefines
`zero_const` and re-assigns `final_output_list` by unsqueezing its elements, is a duplicate of the block on lines 236-239. This is redundant and should be removed to avoid confusion and potential bugs.