Describe the bug
Using KerasHub's CLIP preset `clip_vit_base_patch32`, the `text_encoder` produces different outputs between:
- a direct graph call wrapped in `@tf.function(jit_compile=False)`
- `predict_on_batch()`

for the same model weights and the same input token IDs.
This mismatch is reproducible on GPU (RTX 3090) with:
- TensorFlow 2.20.0
- Keras 3.12.0
- KerasHub 0.25.1
Observed max absolute difference is ~9.86e-03 against a 1e-3 threshold.
Setting `TF_XLA_FLAGS=--tf_xla_auto_jit=0` does not change the mismatch.
On CPU only (`CUDA_VISIBLE_DEVICES=""`) the difference becomes 0.0.
This suggests `predict_on_batch()` and an explicit graph call are not numerically consistent on GPU for this model/path.
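One unconfirmed hypothesis for a GPU-only ~1e-2 discrepancy on an Ampere card such as the RTX 3090 is TF32 matmul execution, which trades a few decimal digits of float32 precision for speed. A quick way to test this (an assumption on my part, not a confirmed cause) is to disable TF32 before the model is built and rerun the repro:

```python
import tensorflow as tf

# Hypothesis check (assumption, not confirmed): Ampere GPUs execute float32
# matmuls in TF32 by default, which can introduce ~1e-2-scale differences
# between execution paths. Disabling it before model construction tests
# whether the graph-vs-predict_on_batch mismatch shrinks.
tf.config.experimental.enable_tensor_float_32_execution(False)
```

If the mismatch disappears with TF32 off, the issue is likely a precision difference between kernels chosen by the two code paths rather than a correctness bug.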
To Reproduce
```python
import os

import numpy as np


def max_abs(a, b):
    a = np.asarray(a, np.float32).reshape(-1)
    b = np.asarray(b, np.float32).reshape(-1)
    m = np.isfinite(a) & np.isfinite(b)
    if not m.any():
        return float("nan")
    return float(np.max(np.abs(a[m] - b[m])))


def to_np(x):
    if isinstance(x, dict):
        for v in x.values():
            y = to_np(v)
            if y is not None:
                return y
        return None
    if hasattr(x, "numpy"):
        return x.numpy()
    try:
        return np.asarray(x)
    except Exception:
        return None


def main():
    os.environ.setdefault("KERAS_BACKEND", "tensorflow")
    import tensorflow as tf
    import keras_hub

    np.random.seed(2021)
    try:
        tf.random.set_seed(2021)
    except Exception:
        pass

    # Reduce VRAM grab.
    try:
        for gpu in tf.config.list_physical_devices("GPU"):
            try:
                tf.config.experimental.set_memory_growth(gpu, True)
            except Exception:
                pass
    except Exception:
        pass

    preset = os.environ.get("CLIP_PRESET", "clip_vit_base_patch32")
    thresh = float(os.environ.get("GCFL_ABS_THRESH", "1e-3"))
    seq_len = int(os.environ.get("SEQ_LEN", "77"))

    clip = keras_hub.models.CLIPBackbone.from_preset(preset)
    text_enc = clip.text_encoder

    vocab = 49408
    token_ids = np.random.randint(0, vocab, size=(1, seq_len), dtype=np.int32)

    @tf.function(jit_compile=False)
    def graph_call(ids):
        return text_enc({"token_ids": ids}, training=False)

    y_graph = to_np(graph_call(token_ids))
    y_pred = to_np(text_enc.predict_on_batch({"token_ids": token_ids}))
    if y_graph is None or y_pred is None:
        raise RuntimeError("Could not materialize outputs")

    d = max_abs(y_graph, y_pred)
    print(f"max_abs_graph_pred={d:.6e} thresh={thresh}")
    if d > thresh:
        raise AssertionError(
            f"predict_on_batch != tf.function: max_abs={d:.6e} > {thresh}"
        )


if __name__ == "__main__":
    main()
```
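A second diagnostic worth trying (a suggestion, not a confirmed cause): rule out non-deterministic GPU kernel selection by enabling deterministic ops before the model is built, and rerunning the repro:

```python
import tensorflow as tf

# Diagnostic (assumption): force deterministic GPU kernel implementations so
# that any remaining graph-vs-predict_on_batch difference cannot come from
# kernel autotuning picking different algorithms on the two code paths.
tf.config.experimental.enable_op_determinism()
```

If the mismatch persists with both determinism on and TF32 off, that would point at a genuine divergence between the `predict_on_batch()` path and a direct `tf.function` call.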
Triggering commands
GPU repro:

```shell
conda activate keras_venv
export CUDA_VISIBLE_DEVICES=0
export KERAS_BACKEND=tensorflow
export CLIP_PRESET=clip_vit_base_patch32
export GCFL_ABS_THRESH=1e-3
set -o pipefail
python repro_min_khub_clip_predict_mismatch.py 2>&1 | tee repro_gpu.log
echo "exit_code=$?"
```

GPU repro with XLA auto-jit disabled (still fails):

```shell
export TF_XLA_FLAGS=--tf_xla_auto_jit=0
set -o pipefail
python repro_min_khub_clip_predict_mismatch.py 2>&1 | tee repro_gpu_no_xla.log
echo "exit_code=$?"
```

CPU-only control (passes):

```shell
export CUDA_VISIBLE_DEVICES=""
unset TF_XLA_FLAGS
set -o pipefail
python repro_min_khub_clip_predict_mismatch.py 2>&1 | tee repro_cpu.log
echo "exit_code=$?"
```
Actual behavior
GPU (`CUDA_VISIBLE_DEVICES=0`):

```
max_abs_graph_pred=9.863973e-03 thresh=0.001
AssertionError: predict_on_batch != tf.function: max_abs=9.863973e-03 > 0.001
exit_code=1
```

With `TF_XLA_FLAGS=--tf_xla_auto_jit=0` (same result):

```
max_abs_graph_pred=9.863973e-03 thresh=0.001
AssertionError: predict_on_batch != tf.function: max_abs=9.863973e-03 > 0.001
```

CPU only (`CUDA_VISIBLE_DEVICES=""`):

```
max_abs_graph_pred=0.000000e+00 thresh=0.001
exit_code=0
```
Expected behavior
`predict_on_batch()` should match an explicit graph call (`tf.function`) for the same model and input within a small tolerance. A consistent max absolute error of ~1e-2 on GPU is unexpected.
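To make the tolerance expectation concrete, here is a standalone sketch (synthetic vectors, not model outputs) of the check the repro applies, seeding the observed GPU discrepancy of ~9.86e-03 against the 1e-3 threshold:

```python
import numpy as np

# Synthetic stand-ins for the two output tensors, differing by the observed
# max absolute discrepancy from the GPU run.
y_graph = np.zeros(4, dtype=np.float32)
y_pred = y_graph.copy()
y_pred[1] += 9.863973e-03  # observed max abs difference on GPU

max_abs = float(np.max(np.abs(y_graph - y_pred)))
assert not np.allclose(y_graph, y_pred, rtol=0.0, atol=1e-3)  # GPU run fails 1e-3
assert np.allclose(y_graph, y_pred, rtol=0.0, atol=1e-2)      # but sits within ~1e-2
print(f"max_abs={max_abs:.6e}")
```

In other words, the observed error is an order of magnitude above the tolerance the test allows, which is why the assertion trips.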
Traceback (full)
```
Traceback (most recent call last):
  File "/home/talha/dl_testing/repro_min_khub_clip_predict_mismatch.py", line 77, in <module>
    main()
  File "/home/talha/dl_testing/repro_min_khub_clip_predict_mismatch.py", line 74, in main
    raise AssertionError(f"predict_on_batch != tf.function: max_abs={d:.6e} > {thresh}")
AssertionError: predict_on_batch != tf.function: max_abs=9.863973e-03 > 0.001
```
Environment
- Python: 3.10.19
- TensorFlow: 2.20.0
- Keras: 3.12.0
- KerasHub: 0.25.1
- tensorflow-text: 2.20.0 (installed as dependency of keras-hub)
- GPU: NVIDIA GeForce RTX 3090
- Driver: 550.78
- CUDA Version: 12.4
- KERAS_BACKEND: tensorflow
- OS: Linux (server)
- 4× RTX 3090 present
Extra note:
This repro uses a cached preset download via kagglehub under:
`~/.cache/kagglehub/models/keras/clip/keras/clip_vit_base_patch32/...`