
SentencePieceTokenizer errors with jax-metal #2135

Open
@t-kalinowski

Description

Describe the bug

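On an Apple Silicon (Arm) Mac using the JAX backend, SentencePieceTokenizer.tokenize() raises jaxlib.xla_extension.XlaRuntimeError (UNIMPLEMENTED: default_memory_space is not supported) when jax-metal is installed.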
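Because the report ties the failure to whether jax-metal is present, it can help to confirm what each environment actually contains. The following is a minimal sketch that uses only the standard library's importlib.metadata. The helper name installed_version is hypothetical and is not part of Keras, Keras Hub, or JAX.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Hypothetical helper: return a distribution's installed version, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        # The distribution is not installed in the current environment.
        return None


if __name__ == "__main__":
    # Print the JAX-related packages that differ between the two uv runs.
    for name in ("jax", "jaxlib", "jax-metal"):
        print(f"{name}: {installed_version(name) or 'not installed'}")
```

Run under both uv invocations, only the second should report a jax-metal version.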

To Reproduce

Given the following file, bug.py:

# /// script
# dependencies = [
#   "keras",
#   "keras-hub",
#   "jax"
# ]
# ///

import os
os.environ["KERAS_BACKEND"] = "jax"

import keras
import keras_hub

vocabulary_file = keras.utils.get_file(
    origin="https://huggingface.co/mattdangerw/sentencepiece-example/resolve/main/vocabulary.proto"
)

tokenizer = keras_hub.tokenizers.SentencePieceTokenizer(vocabulary_file)
tokens = tokenizer.tokenize("The quick brown fox.")
print(tokens)

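Running uv run --python 3.11 bug.py succeeds without error.
Running it with jax-metal added (uv run --python 3.11 --with jax-metal bug.py) produces the following. Note that the captured log is from a run of the script saved as bug2.py: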

tomasz@tomaszkalinows-WQVX deep_learning_with_r_3e % uv run --python 3.11 --with jax-metal bug2.py
Installed 9 packages in 28ms
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
WARNING:2025-03-12 07:33:22,783:jax._src.xla_bridge:997: Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1741779202.784024   79609 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M4 Max

systemMemory: 128.00 GB
maxCacheSize: 48.00 GB

I0000 00:00:1741779202.802302   79609 service.cc:145] XLA service 0x600003880600 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1741779202.802324   79609 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1741779202.803547   79609 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1741779202.803557   79609 mps_client.cc:384] XLA backend will use up to 103078739968 bytes on device 0 for SimpleAllocator.
Traceback (most recent call last):
  File "/Users/tomasz/github/t-kalinowski/deep_learning_with_r_3e/bug2.py", line 22, in <module>
    tokens = tokenizer.tokenize("The quick brown fox.")
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras_hub/src/utils/tensor_utils.py", line 50, in wrapper
    return convert_preprocessing_outputs(x)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras_hub/src/utils/tensor_utils.py", line 191, in convert_preprocessing_outputs
    return keras.tree.map_structure(convert, x)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras/src/tree/tree_api.py", line 192, in map_structure
    return tree_impl.map_structure(func, *structures)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras/src/tree/optree_impl.py", line 108, in map_structure
    return optree.tree_map(
           ^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/optree/ops.py", line 766, in tree_map
    return treespec.unflatten(map(func, *flat_args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras_hub/src/utils/tensor_utils.py", line 189, in convert
    return ops.convert_to_tensor(x, dtype=dtype)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras/src/ops/core.py", line 958, in convert_to_tensor
    return backend.core.convert_to_tensor(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras/src/backend/jax/core.py", line 80, in convert_to_tensor
    return jnp.asarray(x, dtype=dtype)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 5732, in asarray
    return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 5566, in array
    out_array: Array = lax_internal._convert_element_type(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 1414, in _convert_element_type
    return convert_element_type_p.bind(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/core.py", line 502, in bind
    return self._true_bind(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/core.py", line 520, in _true_bind
    return self.bind_with_trace(prev_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 4371, in _convert_element_type_bind_with_trace
    operand = core.Primitive.bind_with_trace(convert_element_type_p, trace, args, params)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/core.py", line 525, in bind_with_trace
    return trace.process_primitive(self, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/core.py", line 1024, in process_primitive
    return primitive.impl(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/dispatch.py", line 90, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: default_memory_space is not supported.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
I0000 00:00:1741779203.113335   79609 mps_client.h:209] MetalClient destroyed.
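One possible workaround, untested against the setup above, is to keep JAX off the experimental METAL platform by setting the JAX_PLATFORMS environment variable to cpu. This trades Metal GPU acceleration for the CPU path that works without jax-metal. The sketch below assumes JAX reads JAX_PLATFORMS when it is first imported, so the variable has to be set before `import keras` pulls in JAX. The helper force_jax_cpu is hypothetical.

```python
import os
import sys
import warnings


def force_jax_cpu() -> None:
    """Hypothetical helper: restrict JAX to the CPU platform.

    Assumption: JAX reads JAX_PLATFORMS at import time, so this must run
    before `import jax`, or before `import keras` with the JAX backend.
    """
    if "jax" in sys.modules:
        # JAX has likely already read its configuration, so this may be ignored.
        warnings.warn(
            "jax is already imported; setting JAX_PLATFORMS may have no effect",
            RuntimeWarning,
            stacklevel=2,
        )
    os.environ["JAX_PLATFORMS"] = "cpu"


# Top of the script, before any Keras or JAX import:
force_jax_cpu()
os.environ["KERAS_BACKEND"] = "jax"
```

In bug.py these lines would replace the existing os.environ line above `import keras`. jax-metal stays installed, but JAX should no longer select it.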

Metadata

Labels

type:Bug (Something isn't working)
