
CLIPBackbone text_encoder: predict_on_batch() output differs from tf.function on GPU (TF backend) #22380

@griffinstalha

Description


Describe the bug

Using KerasHub’s CLIP preset clip_vit_base_patch32, the text_encoder produces different outputs between:

  • a direct graph call wrapped in @tf.function(jit_compile=False)
  • predict_on_batch()

for the same model weights and same input token IDs.
This mismatch is reproducible on GPU (RTX 3090) with:

  • TensorFlow 2.20.0
  • Keras 3.12.0
  • KerasHub 0.25.1

The observed max absolute difference is ~9.86e-03, against a threshold of 1e-3.
Setting TF_XLA_FLAGS=--tf_xla_auto_jit=0 does not change the mismatch.
On CPU only (CUDA_VISIBLE_DEVICES=""), the difference drops to exactly 0.0.

This suggests predict_on_batch() and an explicit graph call are not numerically consistent on GPU for this model/path.
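A mismatch that appears only on GPU is consistent with the two execution paths selecting different kernels or reduction orders: float32 addition is not associative, so accumulation order alone changes results. A minimal, self-contained illustration of that effect (a generic float32 example, unrelated to the CLIP weights themselves):

```python
import numpy as np

# float32 addition is not associative: the same three terms grouped
# differently produce different rounded results.
a = np.float32(1e8)
b = np.float32(-1e8)
c = np.float32(1.0)

left = (a + c) + b   # a + c rounds back to 1e8, so the 1.0 is lost
right = (a + b) + c  # a + b cancels exactly, so the 1.0 survives

print(left, right)   # 0.0 1.0
```

Whether reduction order fully explains a gap of ~1e-2 across this model is exactly what the repro below is meant to let maintainers investigate.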

To Reproduce

import os
import numpy as np

def max_abs(a, b):
    # Max |a - b| over positions where both inputs are finite;
    # NaN if no such positions exist.
    a = np.asarray(a, np.float32).reshape(-1)
    b = np.asarray(b, np.float32).reshape(-1)
    m = np.isfinite(a) & np.isfinite(b)
    if not m.any():
        return float("nan")
    return float(np.max(np.abs(a[m] - b[m])))

def to_np(x):
    # Pull the first array-like leaf out of a (possibly nested) dict of
    # outputs; fall back to np.asarray, or None if nothing converts.
    if isinstance(x, dict):
        for v in x.values():
            y = to_np(v)
            if y is not None:
                return y
        return None
    if hasattr(x, "numpy"):
        return x.numpy()
    try:
        return np.asarray(x)
    except Exception:
        return None

def main():
    os.environ.setdefault("KERAS_BACKEND", "tensorflow")

    import tensorflow as tf
    import keras_hub

    np.random.seed(2021)
    try:
        tf.random.set_seed(2021)
    except Exception:
        pass

    # Reduce upfront VRAM allocation so the repro coexists with other jobs.
    for gpu in tf.config.list_physical_devices("GPU"):
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except Exception:
            pass

    preset = os.environ.get("CLIP_PRESET", "clip_vit_base_patch32")
    thresh = float(os.environ.get("GCFL_ABS_THRESH", "1e-3"))
    seq_len = int(os.environ.get("SEQ_LEN", "77"))

    clip = keras_hub.models.CLIPBackbone.from_preset(preset)
    text_enc = clip.text_encoder

    vocab = 49408
    token_ids = np.random.randint(0, vocab, size=(1, seq_len), dtype=np.int32)

    @tf.function(jit_compile=False)
    def graph_call(ids):
        return text_enc({"token_ids": ids}, training=False)

    y_graph = to_np(graph_call(token_ids))
    y_pred = to_np(text_enc.predict_on_batch({"token_ids": token_ids}))

    if y_graph is None or y_pred is None:
        raise RuntimeError("Could not materialize outputs")

    d = max_abs(y_graph, y_pred)
    print(f"max_abs_graph_pred={d:.6e} thresh={thresh}")

    if d > thresh:
        raise AssertionError(
            f"predict_on_batch != tf.function: max_abs={d:.6e} > {thresh}"
        )

if __name__ == "__main__":
    main()
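As a sanity check that the comparison metric itself behaves as intended (skipping non-finite entries rather than letting a stray NaN poison the result), here is the same max_abs helper exercised on toy arrays; the values are made up purely for illustration:

```python
import numpy as np

def max_abs(a, b):
    # Same helper as in the repro: max |a - b| over positions where
    # both arrays are finite; NaN if no such positions exist.
    a = np.asarray(a, np.float32).reshape(-1)
    b = np.asarray(b, np.float32).reshape(-1)
    m = np.isfinite(a) & np.isfinite(b)
    if not m.any():
        return float("nan")
    return float(np.max(np.abs(a[m] - b[m])))

x = np.array([1.0, 2.0, np.nan], dtype=np.float32)
y = np.array([1.0, 2.5, 3.0], dtype=np.float32)
print(max_abs(x, y))  # 0.5 -- the NaN position is ignored
```

So the reported 9.86e-03 is a genuine elementwise gap on finite values, not an artifact of NaN handling.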

Triggering commands

GPU repro:


conda activate keras_venv
export CUDA_VISIBLE_DEVICES=0
export KERAS_BACKEND=tensorflow
export CLIP_PRESET=clip_vit_base_patch32
export GCFL_ABS_THRESH=1e-3

set -o pipefail
python repro_min_khub_clip_predict_mismatch.py 2>&1 | tee repro_gpu.log
echo "exit_code=$?"

GPU repro with XLA auto-jit disabled (still fails):

export TF_XLA_FLAGS=--tf_xla_auto_jit=0

set -o pipefail
python repro_min_khub_clip_predict_mismatch.py 2>&1 | tee repro_gpu_no_xla.log
echo "exit_code=$?"

CPU-only control (passes):

export CUDA_VISIBLE_DEVICES=""
unset TF_XLA_FLAGS

set -o pipefail
python repro_min_khub_clip_predict_mismatch.py 2>&1 | tee repro_cpu.log
echo "exit_code=$?"

Actual behavior

GPU (CUDA_VISIBLE_DEVICES=0):

max_abs_graph_pred=9.863973e-03 thresh=0.001
AssertionError: predict_on_batch != tf.function: max_abs=9.863973e-03 > 0.001
exit_code=1

TF_XLA_FLAGS=--tf_xla_auto_jit=0:

max_abs_graph_pred=9.863973e-03 thresh=0.001
AssertionError: predict_on_batch != tf.function: max_abs=9.863973e-03 > 0.001

CPU-only (CUDA_VISIBLE_DEVICES=""):

max_abs_graph_pred=0.000000e+00 thresh=0.001
exit_code=0

Expected behavior

predict_on_batch() should match an explicit tf.function graph call for the same weights and input within a small floating-point tolerance. A consistent max-abs error of ~1e-2 on GPU is far outside any such tolerance and is unexpected.
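If the eventual fix still leaves small device-dependent noise, a check combining relative and absolute tolerances (as np.testing.assert_allclose does) would be a more robust acceptance criterion than a single max-abs threshold. A sketch with made-up arrays, showing that the observed ~9.9e-3 gap would still fail such a check:

```python
import numpy as np

ref = np.array([0.0, 1.0, 100.0], dtype=np.float32)
out = ref + np.float32(5e-4)  # simulated small numeric drift

# Passes: 5e-4 is within atol=1e-3 even where ref == 0, and the
# rtol term only loosens the bound for large entries.
np.testing.assert_allclose(out, ref, rtol=1e-4, atol=1e-3)

# A drift the size of the observed GPU gap fails the same check:
bad = ref + np.float32(9.9e-3)
try:
    np.testing.assert_allclose(bad, ref, rtol=1e-4, atol=1e-3)
except AssertionError:
    print("mismatch beyond tolerance")
```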

Traceback (full)

Traceback (most recent call last):
  File "/home/talha/dl_testing/repro_min_khub_clip_predict_mismatch.py", line 77, in <module>
    main()
  File "/home/talha/dl_testing/repro_min_khub_clip_predict_mismatch.py", line 74, in main
    raise AssertionError(f"predict_on_batch != tf.function: max_abs={d:.6e} > {thresh}")
AssertionError: predict_on_batch != tf.function: max_abs=9.863973e-03 > 0.001

Environment

- Python: 3.10.19
- TensorFlow: 2.20.0
- Keras: 3.12.0
- KerasHub: 0.25.1
- tensorflow-text: 2.20.0 (installed as dependency of keras-hub)
- GPU: NVIDIA GeForce RTX 3090
- Driver: 550.78
- CUDA Version: 12.4
- KERAS_BACKEND: tensorflow
- OS: Linux (server)

- NVIDIA-SMI 550.78, CUDA 12.4
- 4× RTX 3090 present

Extra note:
This repro uses a cached preset download via kagglehub under:
~/.cache/kagglehub/models/keras/clip/keras/clip_vit_base_patch32/...

Attached logs: repro_gpu.log, repro_gpu_no_xla.log
