Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
af584b4
Update gemma3_causal_lm_preprocessor.py
pctablet505 Apr 17, 2025
dc4ae8c
Update gemma3_causal_lm_preprocessor.py
pctablet505 Apr 17, 2025
07c5c77
Update gemma3_causal_lm_preprocessor_test.py
pctablet505 Apr 17, 2025
3fdc7fd
Update reversible_embedding.py
pctablet505 Jun 10, 2025
fa57e33
Merge branch 'master' of https://github.com/pctablet505/keras-hub
pctablet505 Jun 10, 2025
8da3303
updated Gemma3InterleaveEmbeddings
pctablet505 Jun 19, 2025
adac2c6
Update gemma3_interleave_embeddings.py
pctablet505 Jun 19, 2025
bd27ec0
Revert "Update reversible_embedding.py"
pctablet505 Jun 19, 2025
f5163e8
Merge branch 'keras-team:master' into master
pctablet505 Jun 19, 2025
1904136
Update gemma3_interleave_embeddings.py
pctablet505 Jun 19, 2025
552fecb
Merge branch 'keras-team:master' into master
pctablet505 Jul 7, 2025
3aa11e9
Merge branch 'master' of https://github.com/pctablet505/keras-hub
pctablet505 Nov 17, 2025
63d529a
Merge branch 'keras-team:master' into master
pctablet505 Nov 27, 2025
fcada92
Merge branch 'keras-team:master' into master
pctablet505 Dec 8, 2025
2cfe17b
Merge branch 'keras-team:master' into master
pctablet505 Dec 19, 2025
f3f85cb
Merge branch 'keras-team:master' into master
pctablet505 Dec 23, 2025
ddf14d5
Merge branch 'keras-team:master' into master
pctablet505 Dec 23, 2025
2e176e7
Merge branch 'keras-team:master' into master
pctablet505 Jan 6, 2026
a69c99c
Ensure int32 type for indices in NMS layer
pctablet505 Jan 6, 2026
2fd457c
Merge branch 'master' of https://github.com/pctablet505/keras-hub
pctablet505 Jan 6, 2026
d39d485
Update mask assertion in embedding layer test
pctablet505 Jan 6, 2026
527c427
Revert "Update mask assertion in embedding layer test"
pctablet505 Jan 7, 2026
3eaa5f4
Merge branch 'keras-team:master' into master
pctablet505 Jan 8, 2026
9b03ed9
Merge branch 'keras-team:master' into master
pctablet505 Jan 23, 2026
2e84113
Merge branch 'keras-team:master' into master
pctablet505 Jan 27, 2026
ae39725
Merge branch 'keras-team:master' into master
pctablet505 Jan 29, 2026
dbdf2d2
Merge branch 'keras-team:master' into master
pctablet505 Feb 3, 2026
bff2ac9
Merge branch 'keras-team:master' into master
pctablet505 Feb 9, 2026
2b6b844
Merge branch 'keras-team:master' into master
pctablet505 Feb 16, 2026
c5c4c18
Add PyTorch backend support for LiteRT export tests
pctablet505 Feb 16, 2026
b6e5c7a
Reflow lines for consistent wrapping
pctablet505 Feb 16, 2026
d5424cf
Normalize attention masks and LiteRT test fixes
pctablet505 Feb 23, 2026
ca7a4dd
Fix LiteRT export bugs and update ai-edge-torch references
pctablet505 Feb 23, 2026
e7cdfdb
Fix Mermaid diagram rendering: replace emoji and arrows with text alt…
pctablet505 Feb 23, 2026
cf7b0cc
Fix diagrams, improve explanations: root cause analysis, test infrast…
pctablet505 Feb 23, 2026
f92653d
Update test_case.py
pctablet505 Feb 23, 2026
3e472d9
Delete PR_DESCRIPTION.md
pctablet505 Feb 23, 2026
a7f4ff5
deleted files
pctablet505 Feb 23, 2026
6db26fe
Update d_fine_loss.py
pctablet505 Feb 23, 2026
8cbc499
Use keras.ops.expand_dims and wrap test comment
pctablet505 Feb 23, 2026
7f48d52
Update requirements.txt
pctablet505 Feb 23, 2026
6f5cafb
Merge branch 'keras-team:master' into torch-backend-litert-support
pctablet505 Mar 11, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions keras_hub/src/models/d_fine/d_fine_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,12 +619,14 @@ def compute_ddf_loss_fn():
mask_flat = keras.ops.reshape(mask_expanded, (-1,))
loss_match_local1 = keras.ops.cond(
keras.ops.any(mask_flat),
lambda: keras.ops.sum(
loss_match_local
* keras.ops.cast(mask_flat, loss_match_local.dtype)
)
/ keras.ops.sum(
keras.ops.cast(mask_flat, loss_match_local.dtype)
lambda: (
keras.ops.sum(
loss_match_local
* keras.ops.cast(mask_flat, loss_match_local.dtype)
)
/ keras.ops.sum(
keras.ops.cast(mask_flat, loss_match_local.dtype)
)
),
lambda: keras.ops.convert_to_tensor(
0.0, dtype=loss_match_local.dtype
Expand All @@ -633,12 +635,14 @@ def compute_ddf_loss_fn():
neg_mask_flat = keras.ops.logical_not(mask_flat)
loss_match_local2 = keras.ops.cond(
keras.ops.any(neg_mask_flat),
lambda: keras.ops.sum(
loss_match_local
* keras.ops.cast(neg_mask_flat, loss_match_local.dtype)
)
/ keras.ops.sum(
keras.ops.cast(neg_mask_flat, loss_match_local.dtype)
lambda: (
keras.ops.sum(
loss_match_local
* keras.ops.cast(neg_mask_flat, loss_match_local.dtype)
)
/ keras.ops.sum(
keras.ops.cast(neg_mask_flat, loss_match_local.dtype)
)
),
lambda: keras.ops.convert_to_tensor(
0.0, dtype=loss_match_local.dtype
Expand Down
9 changes: 9 additions & 0 deletions keras_hub/src/models/d_fine/d_fine_object_detector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,15 @@ def test_saved_model(self):
input_data=self.images,
)

@pytest.mark.xfail(
strict=False,
reason=(
"Upstream torch.export limitation: D-FINE's multi-scale feature "
"computation triggers a data-dependent shape guard "
"(Ne(Mod(u2, 16), 0)), preventing successful torch.export. "
"Will pass once torch.export supports this pattern."
),
)
def test_litert_export(self):
backbone = DFineBackbone(**self.base_backbone_kwargs)
init_kwargs = {
Expand Down
2 changes: 1 addition & 1 deletion keras_hub/src/models/deit/deit_image_classifier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

class DeiTImageClassifierTest(TestCase):
def setUp(self):
self.images = np.ones((2, 28, 28, 3))
self.images = np.ones((2, 28, 28, 3), dtype="float32")
self.labels = [0, 1]
self.backbone = DeiTBackbone(
image_shape=(28, 28, 3),
Expand Down
9 changes: 9 additions & 0 deletions keras_hub/src/models/f_net/f_net_text_classifier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ def test_saved_model(self):
input_data=self.input_data,
)

@pytest.mark.xfail(
strict=False,
reason=(
"Upstream litert-torch limitation: FNet uses ops.fft2 which "
"produces aten.complex tensors. litert-torch has no lowering for "
"aten.complex.default. Will pass once complex tensor ops are "
"supported."
),
)
def test_litert_export(self):
# F-Net does NOT use padding_mask - it only uses token_ids and
# segment_ids. Don't add padding_mask to input_data.
Expand Down
8 changes: 8 additions & 0 deletions keras_hub/src/models/flux/flux_backbone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ def test_saved_model(self):
input_data=self.input_data,
)

@pytest.mark.xfail(
strict=False,
reason=(
"Upstream torch.export limitation: Flux's attention reshape uses "
"a dynamic num_heads value, causing GuardOnDataDependentSymNode. "
"Will pass once torch.export supports data-dependent shapes here."
),
)
def test_litert_export(self):
self.run_litert_export_test(
cls=FluxBackbone,
Expand Down
11 changes: 7 additions & 4 deletions keras_hub/src/models/gemma/gemma_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,9 @@ def _compute_attention(
)

if attention_mask is not None:
attention_mask = attention_mask[:, None, None, :, :]
# We add two dimensions at axis 1 and 2 to make it [B, 1, 1, S, S]
attention_mask = ops.expand_dims(attention_mask, axis=1)
attention_mask = ops.expand_dims(attention_mask, axis=1)
orig_dtype = attention_logits.dtype
attention_softmax = self.softmax(attention_logits, mask=attention_mask)
attention_softmax = ops.cast(attention_softmax, orig_dtype)
Expand Down Expand Up @@ -262,9 +264,10 @@ def call(
)

# Wipe attn vec if there are no attended tokens.
no_attended_tokens = ops.all(
ops.equal(attention_mask, 0), axis=-1, keepdims=True
)[..., None]
no_attended_tokens = ops.expand_dims(
ops.all(ops.equal(attention_mask, 0), axis=-1, keepdims=True),
axis=-1,
)
attention_vec = ops.where(
no_attended_tokens, ops.zeros_like(attention_vec), attention_vec
)
Expand Down
11 changes: 7 additions & 4 deletions keras_hub/src/models/gemma3/gemma3_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,9 @@ def _compute_attention(
)

if attention_mask is not None:
attention_mask = attention_mask[:, None, None, :, :]
# We add two dimensions at axis 1 and 2 to make it [B, 1, 1, S, S]
attention_mask = ops.expand_dims(attention_mask, axis=1)
attention_mask = ops.expand_dims(attention_mask, axis=1)
orig_dtype = attention_logits.dtype
attention_softmax = self.softmax(attention_logits, mask=attention_mask)
attention_softmax = ops.cast(attention_softmax, orig_dtype)
Expand Down Expand Up @@ -399,9 +401,10 @@ def call(
)

# Wipe attn vec if there are no attended tokens.
no_attended_tokens = ops.all(
ops.equal(attention_mask, 0), axis=-1, keepdims=True
)[..., None]
no_attended_tokens = ops.expand_dims(
ops.all(ops.equal(attention_mask, 0), axis=-1, keepdims=True),
axis=-1,
)
attention_vec = ops.where(
no_attended_tokens, ops.zeros_like(attention_vec), attention_vec
)
Expand Down
2 changes: 1 addition & 1 deletion keras_hub/src/models/gpt_oss/gpt_oss_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def _compute_attention(
else:
adder = ops.cast(-1e4, self.compute_dtype)
attention_scores = ops.where(
attention_mask[:, None, :, :], attention_scores, adder
ops.expand_dims(attention_mask, axis=1), attention_scores, adder
)

# Handle sink tokens by concatenating them to the logits.
Expand Down
8 changes: 8 additions & 0 deletions keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ def test_saved_model(self):
input_data=self.input_data,
)

@pytest.mark.xfail(
strict=False,
reason=(
"Upstream litert-torch limitation: the NHWC layout rewriter does "
"not support aten.amax, causing 'NHWC node rewriter not found: "
"amax'. Will pass once litert-torch adds amax support."
),
)
def test_litert_export(self):
self.run_litert_export_test(
cls=GptOssCausalLM,
Expand Down
9 changes: 8 additions & 1 deletion keras_hub/src/models/llama/llama_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,15 @@ def _compute_key_value(x):

def _masked_softmax(self, attention_scores, attention_mask=None):
if attention_mask is not None:
# Use ops.expand_dims instead of Python None indexing
# (attention_mask[:, None, :, :]). Python None indexing traces
# as tf.StridedSlice(new_axis_mask) in the TF graph, which falls
# to the Flex delegate and is not supported by standalone
# ai_edge_litert (TF 2.20+). ops.expand_dims traces as the
# native TFLite ExpandDims op instead.
return self._softmax(
attention_scores, attention_mask[:, None, :, :]
attention_scores,
ops.expand_dims(attention_mask, axis=1),
)
return self._softmax(attention_scores)

Expand Down
2 changes: 1 addition & 1 deletion keras_hub/src/models/mistral/mistral_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def _compute_key_value(x):
def _masked_softmax(self, attention_scores, attention_mask=None):
if attention_mask is not None:
return self._softmax(
attention_scores, attention_mask[:, None, :, :]
attention_scores, ops.expand_dims(attention_mask, axis=1)
)
return self._softmax(attention_scores)

Expand Down
4 changes: 3 additions & 1 deletion keras_hub/src/models/mixtral/mixtral_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,9 @@ def _compute_key_value(x):

def _masked_softmax(self, attention_scores, attention_mask=None):
if attention_mask is not None:
return self.softmax(attention_scores, attention_mask[:, None, :, :])
return self.softmax(
attention_scores, ops.expand_dims(attention_mask, axis=1)
)
return self.softmax(attention_scores)

def _use_fused_attention_op(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,10 @@ def call(
if final_mask is not None:
mask_shape = keras.ops.shape(final_mask)
if len(mask_shape) == 2:
final_mask = final_mask[:, None, None, :]
final_mask = keras.ops.expand_dims(final_mask, axis=1)
final_mask = keras.ops.expand_dims(final_mask, axis=1)
elif len(mask_shape) == 3:
final_mask = final_mask[:, None, :, :]
final_mask = keras.ops.expand_dims(final_mask, axis=1)

attention_kwargs = {
k: v for k, v in kwargs.items() if k != "padding_mask"
Expand Down
4 changes: 3 additions & 1 deletion keras_hub/src/models/phi3/phi3_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,9 @@ def call(

def _masked_softmax(self, attention_scores, attention_mask=None):
if attention_mask is not None:
return self.softmax(attention_scores, attention_mask[:, None, :, :])
return self.softmax(
attention_scores, ops.expand_dims(attention_mask, axis=1)
)
return self.softmax(attention_scores)

def _compute_attention(self, query, key, value, attention_mask=None):
Expand Down
2 changes: 1 addition & 1 deletion keras_hub/src/models/qwen/qwen_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
"""
if attention_mask is not None:
return self._softmax(
attention_scores, attention_mask[:, None, :, :]
attention_scores, ops.expand_dims(attention_mask, axis=1)
)
return self._softmax(attention_scores)

Expand Down
2 changes: 1 addition & 1 deletion keras_hub/src/models/qwen3/qwen3_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
"""
if attention_mask is not None:
return self._softmax(
attention_scores, attention_mask[:, None, :, :]
attention_scores, ops.expand_dims(attention_mask, axis=1)
)
return self._softmax(attention_scores)

Expand Down
2 changes: 1 addition & 1 deletion keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
"""
if attention_mask is not None:
return self._softmax(
attention_scores, attention_mask[:, None, :, :]
attention_scores, ops.expand_dims(attention_mask, axis=1)
)
return self._softmax(attention_scores)

Expand Down
2 changes: 1 addition & 1 deletion keras_hub/src/models/qwen_moe/qwen_moe_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
"""
if attention_mask is not None:
return self._softmax(
attention_scores, attention_mask[:, None, :, :]
attention_scores, ops.expand_dims(attention_mask, axis=1)
)
return self._softmax(attention_scores)

Expand Down
8 changes: 8 additions & 0 deletions keras_hub/src/models/sam3/sam3_pc_image_segmenter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,14 @@ def test_all_presets(self):
},
)

@pytest.mark.xfail(
strict=False,
reason=(
"Upstream litert-torch limitation: SAM3 uses torchvision::nms "
"which is not registered in the torch.export op set and cannot "
"be lowered by litert-torch."
),
)
def test_litert_export(self):
self.run_litert_export_test(
cls=SAM3PromptableConceptImageSegmenter,
Expand Down
5 changes: 4 additions & 1 deletion keras_hub/src/models/siglip/siglip_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,10 @@ def build(self, inputs_shape):

def call(self, inputs, training=None):
batch_size = ops.shape(inputs)[0]
probes = ops.repeat(self.probe, repeats=batch_size, axis=0)
# Use expand_dims + broadcast_to instead of ops.repeat to avoid
# SymInt issues during torch.export (repeat_interleave produces
# unbacked symbolic dimensions).
probes = ops.broadcast_to(self.probe, (batch_size, 1, self.hidden_dim))
hidden_states = self.attention(
probes, inputs, inputs, training=training
)
Expand Down
8 changes: 8 additions & 0 deletions keras_hub/src/models/vae/vae_backbone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ def test_saved_model(self):
input_data=self.input_data,
)

@pytest.mark.xfail(
strict=False,
reason=(
"Upstream litert-torch limitation: VAE uses pow ops which fail "
"TFLite legalization ('failed to legalize operation tfl.pow'). "
"Will pass once TFLite built-ins cover tfl.pow."
),
)
def test_litert_export(self):
self.run_litert_export_test(
cls=VAEBackbone,
Expand Down
6 changes: 5 additions & 1 deletion keras_hub/src/models/vit/vit_image_classifier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

class ViTImageClassifierTest(TestCase):
def setUp(self):
self.images = np.ones((2, 28, 28, 3))
self.images = np.ones((2, 28, 28, 3), dtype="float32")
self.labels = [0, 1]
self.backbone = ViTBackbone(
image_shape=(28, 28, 3),
Expand Down Expand Up @@ -61,4 +61,8 @@ def test_litert_export(self):
cls=ViTImageClassifier,
init_kwargs=self.init_kwargs,
input_data=self.images,
# Small numeric drift can exceed strict 1e-6 atol after
# quantization-style fp32 pipeline; use statistical mode.
comparison_mode="statistical",
output_thresholds={"*": {"max": 1e-5, "mean": 1e-6}},
)
2 changes: 2 additions & 0 deletions keras_hub/src/models/vit_det/vit_det_backbone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,6 @@ def test_litert_export(self):
cls=ViTDetBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
comparison_mode="statistical",
output_thresholds={"*": {"max": 1e-3, "mean": 1e-4}},
Comment on lines +46 to +47
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The numerical tolerance for this test seems quite high ("max": 1e-3). This could potentially mask subtle numerical regressions in the future. Could you investigate if these thresholds can be tightened? If this level of tolerance is unavoidable, please add a comment explaining the source of the large numerical difference for future reference.

)
2 changes: 2 additions & 0 deletions keras_hub/src/models/whisper/whisper_backbone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def test_litert_export(self):
cls=WhisperBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
comparison_mode="statistical",
output_thresholds={"*": {"max": 1e-4, "mean": 1e-5}},
)

@pytest.mark.extra_large
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

class XceptionImageClassifierTest(TestCase):
def setUp(self):
self.images = np.ones((2, 299, 299, 3))
self.images = np.ones((2, 299, 299, 3), dtype="float32")
self.labels = [0, 1]
self.backbone = XceptionBackbone(
stackwise_conv_filters=[[32, 64], [128, 128], [256, 256]],
Expand Down
Loading
Loading