
'jax.experimental.maps' import error #962

@MikeMpapa


Hi, I am trying to use the Levanter image, but I get the following error: `ModuleNotFoundError: No module named 'jax.experimental.maps'`.

Was the module renamed? It worked fine yesterday.

Thanks!
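A quick way to check whether the JAX inside the image changed is to print its installed version and test whether the old module path still resolves. This is a minimal sketch: `module_available` and `installed_version` are hypothetical helper names, and the only assumption about JAX is that its distribution is named `jax`.

```python
import importlib.metadata
import importlib.util
from typing import Optional


def module_available(name: str) -> bool:
    """Return True if the (possibly dotted) module `name` can be located."""
    try:
        # For dotted names, find_spec imports the parent packages first.
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        # Raised when a parent package (e.g. `jax` itself) is missing.
        return False


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of distribution `dist`, or None."""
    try:
        return importlib.metadata.version(dist)
    except importlib.metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    print("jax version:", installed_version("jax"))
    print("jax.experimental.maps available:", module_available("jax.experimental.maps"))
```

Running this inside the image and comparing the output with an older image would show whether the JAX version moved underneath haliax.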

The complete error log:

```
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/levanter/src/levanter/main/train_lm.py", line 220, in <module>
    levanter.config.main(main)()
  File "/levanter/src/levanter/config.py", line 84, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/levanter/src/levanter/main/train_lm.py", line 119, in main
    Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping)
  File "/opt/haliax/src/haliax/partitioning.py", line 597, in round_axis_for_partitioning
    size = physical_axis_size(axis, mapping)
  File "/opt/haliax/src/haliax/partitioning.py", line 566, in physical_axis_size
    mesh = _get_mesh()
  File "/opt/haliax/src/haliax/partitioning.py", line 606, in _get_mesh
    from jax.experimental.maps import thread_resources
ModuleNotFoundError: No module named 'jax.experimental.maps'
```
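The traceback shows haliax's `_get_mesh` importing `thread_resources` from `jax.experimental.maps`, a module the JAX installed in this image does not provide. One possible stopgap is a fallback import that tries several locations in order. This is a hypothetical sketch: `import_first`, `THREAD_RESOURCE_CANDIDATES`, and `get_physical_mesh` are illustrative names, and `jax.interpreters.pxla` as an alternative home for `thread_resources` is an assumption to verify against the installed JAX version.

```python
import importlib
from typing import Any, List, Sequence


def import_first(attr: str, candidates: Sequence[str]) -> Any:
    """Return `attr` from the first module in `candidates` that provides it."""
    errors: List[str] = []
    for module_name in candidates:
        try:
            module = importlib.import_module(module_name)
        except ImportError as exc:
            # Module missing in this environment; remember why and move on.
            errors.append(f"{module_name}: {exc}")
            continue
        if hasattr(module, attr):
            return getattr(module, attr)
        errors.append(f"{module_name}: no attribute {attr!r}")
    raise ImportError(f"could not find {attr!r}; tried: " + "; ".join(errors))


# Hypothetical candidate list. The first entry is the path from the traceback;
# the second is an ASSUMED alternative that must be checked per JAX version.
THREAD_RESOURCE_CANDIDATES = (
    "jax.experimental.maps",
    "jax.interpreters.pxla",
)


def get_physical_mesh() -> Any:
    """Hypothetical stand-in for a `_get_mesh`-style helper."""
    thread_resources = import_first("thread_resources", THREAD_RESOURCE_CANDIDATES)
    return thread_resources.env.physical_mesh
```

The helper collects every failure, so if no candidate works the final `ImportError` lists each path that was tried, which makes the next version mismatch easier to diagnose.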
