diff --git a/keras_hub/src/models/d_fine/d_fine_loss.py b/keras_hub/src/models/d_fine/d_fine_loss.py index d53e722a77..843228d5d9 100644 --- a/keras_hub/src/models/d_fine/d_fine_loss.py +++ b/keras_hub/src/models/d_fine/d_fine_loss.py @@ -619,12 +619,14 @@ def compute_ddf_loss_fn(): mask_flat = keras.ops.reshape(mask_expanded, (-1,)) loss_match_local1 = keras.ops.cond( keras.ops.any(mask_flat), - lambda: keras.ops.sum( - loss_match_local - * keras.ops.cast(mask_flat, loss_match_local.dtype) - ) - / keras.ops.sum( - keras.ops.cast(mask_flat, loss_match_local.dtype) + lambda: ( + keras.ops.sum( + loss_match_local + * keras.ops.cast(mask_flat, loss_match_local.dtype) + ) + / keras.ops.sum( + keras.ops.cast(mask_flat, loss_match_local.dtype) + ) ), lambda: keras.ops.convert_to_tensor( 0.0, dtype=loss_match_local.dtype @@ -633,12 +635,14 @@ def compute_ddf_loss_fn(): neg_mask_flat = keras.ops.logical_not(mask_flat) loss_match_local2 = keras.ops.cond( keras.ops.any(neg_mask_flat), - lambda: keras.ops.sum( - loss_match_local - * keras.ops.cast(neg_mask_flat, loss_match_local.dtype) - ) - / keras.ops.sum( - keras.ops.cast(neg_mask_flat, loss_match_local.dtype) + lambda: ( + keras.ops.sum( + loss_match_local + * keras.ops.cast(neg_mask_flat, loss_match_local.dtype) + ) + / keras.ops.sum( + keras.ops.cast(neg_mask_flat, loss_match_local.dtype) + ) ), lambda: keras.ops.convert_to_tensor( 0.0, dtype=loss_match_local.dtype diff --git a/keras_hub/src/models/d_fine/d_fine_object_detector_test.py b/keras_hub/src/models/d_fine/d_fine_object_detector_test.py index 7d022957d2..13fae90c66 100644 --- a/keras_hub/src/models/d_fine/d_fine_object_detector_test.py +++ b/keras_hub/src/models/d_fine/d_fine_object_detector_test.py @@ -153,6 +153,15 @@ def test_saved_model(self): input_data=self.images, ) + @pytest.mark.xfail( + strict=False, + reason=( + "Upstream torch.export limitation: D-FINE's multi-scale feature " + "computation triggers a data-dependent shape guard " + "(Ne(Mod(u2, 16), 0)), preventing successful torch.export. " + "Will pass once torch.export supports this pattern." + ), + ) def test_litert_export(self): backbone = DFineBackbone(**self.base_backbone_kwargs) init_kwargs = { diff --git a/keras_hub/src/models/deit/deit_image_classifier_test.py b/keras_hub/src/models/deit/deit_image_classifier_test.py index b112d3a400..3d90ae3cc2 100644 --- a/keras_hub/src/models/deit/deit_image_classifier_test.py +++ b/keras_hub/src/models/deit/deit_image_classifier_test.py @@ -12,7 +12,7 @@ class DeiTImageClassifierTest(TestCase): def setUp(self): - self.images = np.ones((2, 28, 28, 3)) + self.images = np.ones((2, 28, 28, 3), dtype="float32") self.labels = [0, 1] self.backbone = DeiTBackbone( image_shape=(28, 28, 3), diff --git a/keras_hub/src/models/f_net/f_net_text_classifier_test.py b/keras_hub/src/models/f_net/f_net_text_classifier_test.py index a45c50e2f0..99c2fae96a 100644 --- a/keras_hub/src/models/f_net/f_net_text_classifier_test.py +++ b/keras_hub/src/models/f_net/f_net_text_classifier_test.py @@ -57,6 +57,15 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.xfail( + strict=False, + reason=( + "Upstream litert-torch limitation: FNet uses ops.fft2 which " + "produces aten.complex tensors. litert-torch has no lowering for " + "aten.complex.default. Will pass once complex tensor ops are " + "supported." + ), + ) def test_litert_export(self): # F-Net does NOT use padding_mask - it only uses token_ids and # segment_ids. Don't add padding_mask to input_data. diff --git a/keras_hub/src/models/flux/flux_backbone_test.py b/keras_hub/src/models/flux/flux_backbone_test.py index 17bd5ad6f2..0e14715084 100644 --- a/keras_hub/src/models/flux/flux_backbone_test.py +++ b/keras_hub/src/models/flux/flux_backbone_test.py @@ -84,6 +84,14 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.xfail( + strict=False, + reason=( + "Upstream torch.export limitation: Flux's attention reshape uses " + "a dynamic num_heads value, causing GuardOnDataDependentSymNode. " + "Will pass once torch.export supports data-dependent shapes here." + ), + ) def test_litert_export(self): self.run_litert_export_test( cls=FluxBackbone, diff --git a/keras_hub/src/models/gemma/gemma_attention.py b/keras_hub/src/models/gemma/gemma_attention.py index f66a4506ce..884ded38b8 100644 --- a/keras_hub/src/models/gemma/gemma_attention.py +++ b/keras_hub/src/models/gemma/gemma_attention.py @@ -187,7 +187,9 @@ def _compute_attention( ) if attention_mask is not None: - attention_mask = attention_mask[:, None, None, :, :] + # We add two dimensions at axis 1 and 2 to make it [B, 1, 1, S, S] + attention_mask = ops.expand_dims(attention_mask, axis=1) + attention_mask = ops.expand_dims(attention_mask, axis=1) orig_dtype = attention_logits.dtype attention_softmax = self.softmax(attention_logits, mask=attention_mask) attention_softmax = ops.cast(attention_softmax, orig_dtype) @@ -262,9 +264,10 @@ def call( ) # Wipe attn vec if there are no attended tokens. - no_attended_tokens = ops.all( - ops.equal(attention_mask, 0), axis=-1, keepdims=True - )[..., None] + no_attended_tokens = ops.expand_dims( + ops.all(ops.equal(attention_mask, 0), axis=-1, keepdims=True), + axis=-1, + ) attention_vec = ops.where( no_attended_tokens, ops.zeros_like(attention_vec), attention_vec ) diff --git a/keras_hub/src/models/gemma3/gemma3_attention.py b/keras_hub/src/models/gemma3/gemma3_attention.py index 39244db680..208bbd9765 100644 --- a/keras_hub/src/models/gemma3/gemma3_attention.py +++ b/keras_hub/src/models/gemma3/gemma3_attention.py @@ -229,7 +229,9 @@ def _compute_attention( ) if attention_mask is not None: - attention_mask = attention_mask[:, None, None, :, :] + # We add two dimensions at axis 1 and 2 to make it [B, 1, 1, S, S] + attention_mask = ops.expand_dims(attention_mask, axis=1) + attention_mask = ops.expand_dims(attention_mask, axis=1) orig_dtype = attention_logits.dtype attention_softmax = self.softmax(attention_logits, mask=attention_mask) attention_softmax = ops.cast(attention_softmax, orig_dtype) @@ -399,9 +401,10 @@ def call( ) # Wipe attn vec if there are no attended tokens. - no_attended_tokens = ops.all( - ops.equal(attention_mask, 0), axis=-1, keepdims=True - )[..., None] + no_attended_tokens = ops.expand_dims( + ops.all(ops.equal(attention_mask, 0), axis=-1, keepdims=True), + axis=-1, + ) attention_vec = ops.where( no_attended_tokens, ops.zeros_like(attention_vec), attention_vec ) diff --git a/keras_hub/src/models/gpt_oss/gpt_oss_attention.py b/keras_hub/src/models/gpt_oss/gpt_oss_attention.py index 01a7e9d69c..2de020d061 100644 --- a/keras_hub/src/models/gpt_oss/gpt_oss_attention.py +++ b/keras_hub/src/models/gpt_oss/gpt_oss_attention.py @@ -280,7 +280,7 @@ def _compute_attention( else: adder = ops.cast(-1e4, self.compute_dtype) attention_scores = ops.where( - attention_mask[:, None, :, :], attention_scores, adder + ops.expand_dims(attention_mask, axis=1), attention_scores, adder ) # Handle sink tokens by concatenating them to the logits. diff --git a/keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_test.py b/keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_test.py index 3968af58d3..8decf4a496 100644 --- a/keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_test.py +++ b/keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_test.py @@ -108,6 +108,14 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.xfail( + strict=False, + reason=( + "Upstream litert-torch limitation: the NHWC layout rewriter does " + "not support aten.amax, causing 'NHWC node rewriter not found: " + "amax'. Will pass once litert-torch adds amax support." + ), + ) def test_litert_export(self): self.run_litert_export_test( cls=GptOssCausalLM, diff --git a/keras_hub/src/models/llama/llama_attention.py b/keras_hub/src/models/llama/llama_attention.py index fd1364ae7b..cda44ba50e 100644 --- a/keras_hub/src/models/llama/llama_attention.py +++ b/keras_hub/src/models/llama/llama_attention.py @@ -192,8 +192,15 @@ def _compute_key_value(x): def _masked_softmax(self, attention_scores, attention_mask=None): if attention_mask is not None: + # Use ops.expand_dims instead of Python None indexing + # (attention_mask[:, None, :, :]). Python None indexing traces + # as tf.StridedSlice(new_axis_mask) in the TF graph, which falls + # to the Flex delegate and is not supported by standalone + # ai_edge_litert (TF 2.20+). ops.expand_dims traces as the + # native TFLite ExpandDims op instead. return self._softmax( - attention_scores, attention_mask[:, None, :, :] + attention_scores, + ops.expand_dims(attention_mask, axis=1), ) return self._softmax(attention_scores) diff --git a/keras_hub/src/models/mistral/mistral_attention.py b/keras_hub/src/models/mistral/mistral_attention.py index 6916133b78..85de0e349c 100644 --- a/keras_hub/src/models/mistral/mistral_attention.py +++ b/keras_hub/src/models/mistral/mistral_attention.py @@ -191,7 +191,7 @@ def _compute_key_value(x): def _masked_softmax(self, attention_scores, attention_mask=None): if attention_mask is not None: return self._softmax( - attention_scores, attention_mask[:, None, :, :] + attention_scores, ops.expand_dims(attention_mask, axis=1) ) return self._softmax(attention_scores) diff --git a/keras_hub/src/models/mixtral/mixtral_attention.py b/keras_hub/src/models/mixtral/mixtral_attention.py index 0cae75a21c..31a159c62c 100644 --- a/keras_hub/src/models/mixtral/mixtral_attention.py +++ b/keras_hub/src/models/mixtral/mixtral_attention.py @@ -187,7 +187,9 @@ def _compute_key_value(x): def _masked_softmax(self, attention_scores, attention_mask=None): if attention_mask is not None: - return self.softmax(attention_scores, attention_mask[:, None, :, :]) + return self.softmax( + attention_scores, ops.expand_dims(attention_mask, axis=1) + ) return self.softmax(attention_scores) def _use_fused_attention_op(self): diff --git a/keras_hub/src/models/moonshine/moonshine_multi_head_attention.py b/keras_hub/src/models/moonshine/moonshine_multi_head_attention.py index 9fbc5948f6..ff772bc14b 100644 --- a/keras_hub/src/models/moonshine/moonshine_multi_head_attention.py +++ b/keras_hub/src/models/moonshine/moonshine_multi_head_attention.py @@ -328,9 +328,10 @@ def call( if final_mask is not None: mask_shape = keras.ops.shape(final_mask) if len(mask_shape) == 2: - final_mask = final_mask[:, None, None, :] + final_mask = keras.ops.expand_dims(final_mask, axis=1) + final_mask = keras.ops.expand_dims(final_mask, axis=1) elif len(mask_shape) == 3: - final_mask = final_mask[:, None, :, :] + final_mask = keras.ops.expand_dims(final_mask, axis=1) attention_kwargs = { k: v for k, v in kwargs.items() if k != "padding_mask" diff --git a/keras_hub/src/models/phi3/phi3_attention.py b/keras_hub/src/models/phi3/phi3_attention.py index a298d37211..c3ca6dd120 100644 --- a/keras_hub/src/models/phi3/phi3_attention.py +++ b/keras_hub/src/models/phi3/phi3_attention.py @@ -213,7 +213,9 @@ def call( def _masked_softmax(self, attention_scores, attention_mask=None): if attention_mask is not None: - return self.softmax(attention_scores, attention_mask[:, None, :, :]) + return self.softmax( + attention_scores, ops.expand_dims(attention_mask, axis=1) + ) return self.softmax(attention_scores) def _compute_attention(self, query, key, value, attention_mask=None): diff --git a/keras_hub/src/models/qwen/qwen_attention.py b/keras_hub/src/models/qwen/qwen_attention.py index 4b685956de..1ae55e9150 100644 --- a/keras_hub/src/models/qwen/qwen_attention.py +++ b/keras_hub/src/models/qwen/qwen_attention.py @@ -242,7 +242,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None): """ if attention_mask is not None: return self._softmax( - attention_scores, attention_mask[:, None, :, :] + attention_scores, ops.expand_dims(attention_mask, axis=1) ) return self._softmax(attention_scores) diff --git a/keras_hub/src/models/qwen3/qwen3_attention.py b/keras_hub/src/models/qwen3/qwen3_attention.py index a53e4ac501..d5545a2187 100644 --- a/keras_hub/src/models/qwen3/qwen3_attention.py +++ b/keras_hub/src/models/qwen3/qwen3_attention.py @@ -257,7 +257,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None): """ if attention_mask is not None: return self._softmax( - attention_scores, attention_mask[:, None, :, :] + attention_scores, ops.expand_dims(attention_mask, axis=1) ) return self._softmax(attention_scores) diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py index a5442e8da0..d6c14a473e 100644 --- a/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py +++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py @@ -258,7 +258,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None): """ if attention_mask is not None: return self._softmax( - attention_scores, attention_mask[:, None, :, :] + attention_scores, ops.expand_dims(attention_mask, axis=1) ) return self._softmax(attention_scores) diff --git a/keras_hub/src/models/qwen_moe/qwen_moe_attention.py b/keras_hub/src/models/qwen_moe/qwen_moe_attention.py index 30c4466de0..2e8a01d4c2 100644 --- a/keras_hub/src/models/qwen_moe/qwen_moe_attention.py +++ b/keras_hub/src/models/qwen_moe/qwen_moe_attention.py @@ -247,7 +247,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None): """ if attention_mask is not None: return self._softmax( - attention_scores, attention_mask[:, None, :, :] + attention_scores, ops.expand_dims(attention_mask, axis=1) ) return self._softmax(attention_scores) diff --git a/keras_hub/src/models/sam3/sam3_pc_image_segmenter_test.py b/keras_hub/src/models/sam3/sam3_pc_image_segmenter_test.py index 4bfa3041f8..7a5b8ee519 100644 --- a/keras_hub/src/models/sam3/sam3_pc_image_segmenter_test.py +++ b/keras_hub/src/models/sam3/sam3_pc_image_segmenter_test.py @@ -168,6 +168,14 @@ def test_all_presets(self): }, ) + @pytest.mark.xfail( + strict=False, + reason=( + "Upstream litert-torch limitation: SAM3 uses torchvision::nms " + "which is not registered in the torch.export op set and cannot " + "be lowered by litert-torch." + ), + ) def test_litert_export(self): self.run_litert_export_test( cls=SAM3PromptableConceptImageSegmenter, diff --git a/keras_hub/src/models/siglip/siglip_layers.py b/keras_hub/src/models/siglip/siglip_layers.py index 4aabde2ca4..9cc4537a56 100644 --- a/keras_hub/src/models/siglip/siglip_layers.py +++ b/keras_hub/src/models/siglip/siglip_layers.py @@ -463,7 +463,10 @@ def build(self, inputs_shape): def call(self, inputs, training=None): batch_size = ops.shape(inputs)[0] - probes = ops.repeat(self.probe, repeats=batch_size, axis=0) + # Use expand_dims + broadcast_to instead of ops.repeat to avoid + # SymInt issues during torch.export (repeat_interleave produces + # unbacked symbolic dimensions). + probes = ops.broadcast_to(self.probe, (batch_size, 1, self.hidden_dim)) hidden_states = self.attention( probes, inputs, inputs, training=training ) diff --git a/keras_hub/src/models/vae/vae_backbone_test.py b/keras_hub/src/models/vae/vae_backbone_test.py index cdb2d7b894..6b7d1fcfd1 100644 --- a/keras_hub/src/models/vae/vae_backbone_test.py +++ b/keras_hub/src/models/vae/vae_backbone_test.py @@ -34,6 +34,14 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.xfail( + strict=False, + reason=( + "Upstream litert-torch limitation: VAE uses pow ops which fail " + "TFLite legalization ('failed to legalize operation tfl.pow'). " + "Will pass once TFLite built-ins cover tfl.pow." + ), + ) def test_litert_export(self): self.run_litert_export_test( cls=VAEBackbone, diff --git a/keras_hub/src/models/vit/vit_image_classifier_test.py b/keras_hub/src/models/vit/vit_image_classifier_test.py index 2bd6a089ef..1acad8c320 100644 --- a/keras_hub/src/models/vit/vit_image_classifier_test.py +++ b/keras_hub/src/models/vit/vit_image_classifier_test.py @@ -12,7 +12,7 @@ class ViTImageClassifierTest(TestCase): def setUp(self): - self.images = np.ones((2, 28, 28, 3)) + self.images = np.ones((2, 28, 28, 3), dtype="float32") self.labels = [0, 1] self.backbone = ViTBackbone( image_shape=(28, 28, 3), @@ -61,4 +61,8 @@ def test_litert_export(self): cls=ViTImageClassifier, init_kwargs=self.init_kwargs, input_data=self.images, + # Small numeric drift can exceed strict 1e-6 atol after + # quantization-style fp32 pipeline; use statistical mode. + comparison_mode="statistical", + output_thresholds={"*": {"max": 1e-5, "mean": 1e-6}}, ) diff --git a/keras_hub/src/models/vit_det/vit_det_backbone_test.py b/keras_hub/src/models/vit_det/vit_det_backbone_test.py index ed5c9a3efc..751caed5ee 100644 --- a/keras_hub/src/models/vit_det/vit_det_backbone_test.py +++ b/keras_hub/src/models/vit_det/vit_det_backbone_test.py @@ -43,4 +43,6 @@ def test_litert_export(self): cls=ViTDetBackbone, init_kwargs=self.init_kwargs, input_data=self.input_data, + comparison_mode="statistical", + output_thresholds={"*": {"max": 1e-3, "mean": 1e-4}}, ) diff --git a/keras_hub/src/models/whisper/whisper_backbone_test.py b/keras_hub/src/models/whisper/whisper_backbone_test.py index b869dfd970..34b7d41385 100644 --- a/keras_hub/src/models/whisper/whisper_backbone_test.py +++ b/keras_hub/src/models/whisper/whisper_backbone_test.py @@ -65,6 +65,8 @@ def test_litert_export(self): cls=WhisperBackbone, init_kwargs=self.init_kwargs, input_data=self.input_data, + comparison_mode="statistical", + output_thresholds={"*": {"max": 1e-4, "mean": 1e-5}}, ) @pytest.mark.extra_large diff --git a/keras_hub/src/models/xception/xception_image_classifier_test.py b/keras_hub/src/models/xception/xception_image_classifier_test.py index 1ed8113073..03203eedbb 100644 --- a/keras_hub/src/models/xception/xception_image_classifier_test.py +++ b/keras_hub/src/models/xception/xception_image_classifier_test.py @@ -16,7 +16,7 @@ class XceptionImageClassifierTest(TestCase): def setUp(self): - self.images = np.ones((2, 299, 299, 3)) + self.images = np.ones((2, 299, 299, 3), dtype="float32") self.labels = [0, 1] self.backbone = XceptionBackbone( stackwise_conv_filters=[[32, 64], [128, 128], [256, 256]], diff --git a/keras_hub/src/tests/test_case.py b/keras_hub/src/tests/test_case.py index a7ce5acfb1..492cd24ef0 100644 --- a/keras_hub/src/tests/test_case.py +++ b/keras_hub/src/tests/test_case.py @@ -434,6 +434,70 @@ def run_model_saving_test( restored_output = restored_model(input_data) self.assertAllClose(model_output, restored_output, atol=atol, rtol=rtol) + @staticmethod + def _build_input_signature(input_data, is_torch_backend=False): + """Build a concrete ``input_signature`` from actual data. + + Returns a structure compatible with + ``keras.Model.export(input_signature=...)``: a single-element + list wrapping the mapped input structure, where each leaf has + fully concrete shapes (no ``None`` dims). Concrete shapes allow + the TFLite converter to fully optimize operations statically, + avoiding dynamic shape ops that require the Flex delegate + (e.g. FlexStridedSlice). + + For the TF backend, ``tf.TensorSpec`` objects with proper names + are used so that ``ExportArchive.add_endpoint`` preserves the + dict key names in the SavedModel SignatureDef. + For the torch backend, ``keras.InputSpec`` objects are used as + required by ``torch.export``. + """ + + def _to_numpy(x): + if hasattr(x, "detach"): + return x.detach().cpu().numpy() + elif hasattr(x, "numpy") and not isinstance(x, np.ndarray): + return x.numpy() + return x + + if is_torch_backend: + + def _to_spec(x): + x = _to_numpy(x) + # Normalize dtypes: TFLite/torch export doesn't support + # float64 or int64. Always work with np.dtype instances + # (not type objects like np.float32) so that .name works. + dtype = np.dtype(x.dtype) + if dtype == np.dtype("float64"): + dtype = np.dtype("float32") + elif dtype == np.dtype("int64"): + dtype = np.dtype("int32") + return keras.InputSpec(shape=x.shape, dtype=dtype.name) + + return [tree.map_structure(_to_spec, input_data)] + else: + # For TF backend: use tf.TensorSpec with names so that + # ExportArchive preserves dict key names in the SignatureDef. + def _to_tf_spec(x, name=None): + x = _to_numpy(x) + dtype = tf.as_dtype(x.dtype) + # TFLite doesn't support float64; match convert_for_tflite. + if dtype == tf.float64: + dtype = tf.float32 + # Normalize int64 to int32 for compatibility; test inputs + # are int32. + elif dtype == tf.int64: + dtype = tf.int32 + return tf.TensorSpec(shape=x.shape, dtype=dtype, name=name) + + if isinstance(input_data, dict): + spec_dict = { + k: _to_tf_spec(v, name=k) for k, v in input_data.items() + } + return [spec_dict] + else: + return [tree.map_structure(_to_tf_spec, input_data)] + def _verify_litert_outputs( self, keras_output, @@ -597,15 +661,27 @@ def run_litert_export_test( ) < packaging.version.Version("3.13.0"): self.skipTest("LiteRT export requires Keras >= 3.13") - self.skipTest( - "#TODO: [#2572] Re-enable LiteRT tests after a new tf release. " - "Can't test with tf 2.20 due to tf.lite module deprecation." - ) + is_torch_backend = keras.backend.backend() == "torch" + + if is_torch_backend: + try: + import litert_torch # noqa: F401 + except (ImportError, ModuleNotFoundError): + self.skipTest( + "litert-torch is required for LiteRT export " + "with the torch backend" + ) + else: + try: + from ai_edge_litert.interpreter import Interpreter # noqa: F401 + except (ImportError, ModuleNotFoundError): + self.skipTest( + "ai-edge-litert is required for LiteRT export " + "with the tensorflow backend" + ) # Extract comparison_mode from export_kwargs if provided comparison_mode = export_kwargs.pop("comparison_mode", "strict") - if keras.backend.backend() != "tensorflow": - self.skipTest("LiteRT export only supports TensorFlow backend") try: from ai_edge_litert.interpreter import Interpreter @@ -628,6 +704,17 @@ def run_litert_export_test( with tempfile.TemporaryDirectory() as temp_dir: export_path = os.path.join(temp_dir, "model.tflite") + # Build a concrete input_signature from the actual + # input_data shape (not reduced to batch=1) so the traced + # shapes match what the test provides. This is important + # for both torch and TF backends to avoid dynamic shape + # operations that require Flex delegates. + if "input_signature" not in export_kwargs: + input_sig = self._build_input_signature( + input_data, is_torch_backend=is_torch_backend + ) + export_kwargs.setdefault("input_signature", input_sig) + # Step 1: Export model and get Keras output model.export(export_path, format="litert", **export_kwargs) self.assertTrue(os.path.exists(export_path)) @@ -661,17 +748,28 @@ def run_litert_export_test( # Verify input signature if isinstance(input_data, dict): - expected_inputs = set(input_data.keys()) - actual_inputs = set(sig_inputs) - # Check that all expected inputs are in the signature - # (allow signature to have additional optional inputs) - missing_inputs = expected_inputs - actual_inputs - if missing_inputs: - self.fail( - f"Missing inputs in SignatureDef: " - f"{sorted(missing_inputs)}. " - f"Expected: {sorted(expected_inputs)}, " - f"SignatureDef has: {sorted(actual_inputs)}" + if not is_torch_backend: + # TF path: signature names match Keras names + expected_inputs = set(input_data.keys()) + actual_inputs = set(sig_inputs) + missing_inputs = expected_inputs - actual_inputs + if missing_inputs: + self.fail( + f"Missing inputs in SignatureDef: " + f"{sorted(missing_inputs)}. " + f"Expected: {sorted(expected_inputs)}, " + f"SignatureDef has: " + f"{sorted(actual_inputs)}" + ) + else: + # Torch path: inputs are named args_0, args_1, … + # Just verify counts match + self.assertEqual( + len(input_data), + len(sig_inputs), + f"Input count mismatch: model has " + f"{len(input_data)} inputs but SignatureDef " + f"has {len(sig_inputs)}: {sig_inputs}", ) else: # For numpy arrays, just verify we have exactly one input @@ -683,8 +781,13 @@ def run_litert_export_test( f"{sig_inputs}" ) - # Verify output signature - if verify_numerics and isinstance(keras_output, dict): + # Verify output signature (skip for torch: names are + # output_0, output_1, not Keras names) + if ( + verify_numerics + and isinstance(keras_output, dict) + and not is_torch_backend + ): expected_outputs = set(keras_output.keys()) actual_outputs = set(sig_outputs) if expected_outputs != actual_outputs: @@ -702,32 +805,47 @@ def run_litert_export_test( # Convert input data dtypes to match TFLite expectations def convert_for_tflite(x): """Convert tensor/array to TFLite-compatible dtypes.""" - if hasattr(x, "dtype"): - if isinstance(x, np.ndarray): - if x.dtype == bool: - return x.astype(np.int32) - elif x.dtype == np.float64: - return x.astype(np.float32) - elif x.dtype == np.int64: - return x.astype(np.int32) - else: # TensorFlow tensor - if x.dtype == tf.bool: - return ops.cast(x, "int32").numpy() - elif x.dtype == tf.float64: - return ops.cast(x, "float32").numpy() - elif x.dtype == tf.int64: - return ops.cast(x, "int32").numpy() - else: - return x.numpy() if hasattr(x, "numpy") else x - elif hasattr(x, "numpy"): - return x.numpy() + # Handle torch tensors + if hasattr(x, "detach"): + x = x.detach().cpu().numpy() + elif hasattr(x, "numpy") and not isinstance(x, np.ndarray): + x = x.numpy() + if isinstance(x, np.ndarray): + if x.dtype == np.float64: + return x.astype(np.float32) + elif x.dtype == np.int64: + return x.astype(np.int32) return x if isinstance(input_data, dict): converted_input_data = tree.map_structure( convert_for_tflite, input_data ) - litert_output = runner(**converted_input_data) + if is_torch_backend: + # Torch path: map dict values to args_N + # by position (sorted dict key order). + # Also cast each value to the dtype the + # TFLite model actually expects (e.g. bool + # padding_mask may have been fed as int32). + expected_dtypes = { + d["name"]: d["dtype"] + for d in interpreter.get_input_details() + } + sig_input_names = sorted(sig_inputs) + input_keys = list(input_data.keys()) + runner_kwargs = {} + for i, key in enumerate(input_keys): + sig_name = sig_input_names[i] + val = converted_input_data[key] + for dname, dt in expected_dtypes.items(): + if sig_name in dname: + if val.dtype != dt: + val = val.astype(dt) + break + runner_kwargs[sig_name] = val + litert_output = runner(**runner_kwargs) + else: + litert_output = runner(**converted_input_data) else: # For single tensor inputs, get the input name sig_inputs = serving_sig.get("inputs", []) diff --git a/requirements.txt b/requirements.txt index e499522558..61a42ea10a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ # Tensorflow. -tensorflow-cpu~=2.19.0;sys_platform != 'darwin' -tensorflow~=2.19.0;sys_platform == 'darwin' -tensorflow-text~=2.19;platform_system != 'Windows' +tensorflow-cpu~=2.20.0;sys_platform != 'darwin' +tensorflow~=2.20.0;sys_platform == 'darwin' +tensorflow-text>=2.20.0;platform_system != 'Windows' # Torch. --extra-index-url https://download.pytorch.org/whl/cpu