Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions keras_hub/src/models/moonshine/moonshine_audio_to_text_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import keras
import numpy as np
import pytest
from keras import ops

from keras_hub.src.models.moonshine.moonshine_audio_converter import (
MoonshineAudioConverter,
Expand Down Expand Up @@ -145,14 +146,31 @@ def test_saved_model(self):
input_data=self.input_data,
)

@pytest.mark.skip(
reason="TODO: Bug with MoonshineAudioToText liteRT export"
)
def test_litert_export(self):
# LiteRT inputs are strict about types.
# The model expects boolean masks, but the test data provides int32.

# 1. Convert ALL inputs to numpy first to avoid "mixing tensors" error.
input_data = {}
for k, v in self.input_data.items():
input_data[k] = ops.convert_to_numpy(v)

# 2. Force masks to boolean
if "encoder_padding_mask" in input_data:
input_data["encoder_padding_mask"] = np.array(
input_data["encoder_padding_mask"], dtype=bool
)

if "decoder_padding_mask" in input_data:
input_data["decoder_padding_mask"] = np.array(
input_data["decoder_padding_mask"], dtype=bool
)
Comment on lines +154 to +167
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This data preparation logic can be made more concise and less repetitive by using a dictionary comprehension for the initial conversion and a loop to handle the type casting for all relevant mask keys. This improves readability and maintainability.

Suggested change
input_data = {}
for k, v in self.input_data.items():
input_data[k] = ops.convert_to_numpy(v)
# 2. Force masks to boolean
if "encoder_padding_mask" in input_data:
input_data["encoder_padding_mask"] = np.array(
input_data["encoder_padding_mask"], dtype=bool
)
if "decoder_padding_mask" in input_data:
input_data["decoder_padding_mask"] = np.array(
input_data["decoder_padding_mask"], dtype=bool
)
input_data = {k: ops.convert_to_numpy(v) for k, v in self.input_data.items()}
# 2. Force masks to boolean
for mask_key in ("encoder_padding_mask", "decoder_padding_mask"):
if mask_key in input_data:
input_data[mask_key] = input_data[mask_key].astype(bool)


self.run_litert_export_test(
cls=MoonshineAudioToText,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
input_data=input_data,
strict_input_types=True,
)

@pytest.mark.extra_large
Expand Down
10 changes: 7 additions & 3 deletions keras_hub/src/tests/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,8 +561,8 @@ def run_litert_export_test(
expected_output_shape=None,
model=None,
verify_numerics=True,
# No LiteRT output in model saving test; remove undefined return
output_thresholds=None,
strict_input_types=False, # Defaults to False to preserve legacy behavior
**export_kwargs,
):
"""Export model to LiteRT format and verify outputs.
Expand All @@ -587,6 +587,9 @@ def run_litert_export_test(
with "max" and "mean" keys. Use "*" as wildcard for defaults.
Example: {"output1": {"max": 1e-4, "mean": 1e-5},
"*": {"max": 1e-3, "mean": 1e-4}}
strict_input_types: bool. If True, boolean inputs are passed through
as bool. If False (default), boolean inputs are cast to int32 for
compatibility with older models. Other dtype conversions (e.g.
float64 to float32) are applied regardless of this flag.
**export_kwargs: Additional keyword arguments to pass to
model.export(), such as allow_custom_ops=True or
enable_select_tf_ops=True.
Expand Down Expand Up @@ -705,14 +708,15 @@ def convert_for_tflite(x):
if hasattr(x, "dtype"):
if isinstance(x, np.ndarray):
if x.dtype == bool:
return x.astype(np.int32)
# Use strict mode if requested, otherwise legacy int32 conversion
return x if strict_input_types else x.astype(np.int32)
elif x.dtype == np.float64:
return x.astype(np.float32)
elif x.dtype == np.int64:
return x.astype(np.int32)
else: # TensorFlow tensor
if x.dtype == tf.bool:
return ops.cast(x, "int32").numpy()
return x.numpy() if strict_input_types else ops.cast(x, "int32").numpy()
elif x.dtype == tf.float64:
return ops.cast(x, "float32").numpy()
elif x.dtype == tf.int64:
Expand Down
Loading