Closed
Description
Hi, I am trying to use the levanter image, but I get the following error: `ModuleNotFoundError: No module named 'jax.experimental.maps'`.
Was the module renamed? It worked fine yesterday.
Thanks!
The complete error log:
```
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/levanter/src/levanter/main/train_lm.py", line 220, in <module>
    levanter.config.main(main)()
  File "/levanter/src/levanter/config.py", line 84, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/levanter/src/levanter/main/train_lm.py", line 119, in main
    Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping)
  File "/opt/haliax/src/haliax/partitioning.py", line 597, in round_axis_for_partitioning
    size = physical_axis_size(axis, mapping)
  File "/opt/haliax/src/haliax/partitioning.py", line 566, in physical_axis_size
    mesh = _get_mesh()
  File "/opt/haliax/src/haliax/partitioning.py", line 606, in _get_mesh
    from jax.experimental.maps import thread_resources
ModuleNotFoundError: No module named 'jax.experimental.maps'
```
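To check an environment without running the full training script, here is a minimal sketch using only the Python standard library. It reports whether `jax.experimental.maps` can still be located in the installed JAX. The helper names `module_available` and `report` are hypothetical, and `jax.interpreters.pxla` is included only as an assumed alternative location to probe. It is not a confirmed new home for `thread_resources`.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if the module `name` can be located.

    find_spec does not execute the target module itself, but it does import
    the parent packages of a dotted name. A missing parent (for example, JAX
    not installed) raises ModuleNotFoundError, a subclass of ImportError,
    which is treated here as "not available".
    """
    try:
        return importlib.util.find_spec(name) is not None
    except ImportError:
        return False


def report(candidates):
    """Map each candidate module path to whether it can be located."""
    return {name: module_available(name) for name in candidates}


if __name__ == "__main__":
    # Assumed candidate paths; which ones exist depends on the installed
    # JAX version and is not verified here.
    candidates = ["jax", "jax.experimental.maps", "jax.interpreters.pxla"]
    for name, ok in report(candidates).items():
        print(f"{name}: {'found' if ok else 'missing'}")
    if module_available("jax"):
        import jax

        print("jax version:", jax.__version__)
```

Running this inside the image prints one found/missing line per module. If `jax` is found but `jax.experimental.maps` is missing, the installed JAX no longer provides that module. In that case the import in haliax's `_get_mesh` fails exactly as in the traceback above.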