
[MJX] jax_enable_x64 does not work in mujoco/mjx/viewer.py with put_data #2565

Open
@rivergold

Description


Intro

Hi!

I am a student using MuJoCo for research on robot control, and I want to use MJX and Madrona for RL.

My setup

MuJoCo version: 3.3.0
OS: Ubuntu 22.04
GPU: NVIDIA GeForce RTX 3090
Driver: 565.77
CUDA version: 12.5

What's happening? What did you expect?

When debugging NaN issues in MJX under jax.jit, setting jax.config.update("jax_enable_x64", True) is a way to check whether the problem is caused by numerical precision.
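As a minimal, JAX-only sketch of why switching to 64-bit helps separate precision problems from other bugs: the hypothetical helper `unit_normal` below (not part of MJX) normalizes the difference of two nearly coincident points. In float32 the difference cancels to zero and the division yields 0/0 = NaN, while the same inputs in float64 stay finite. It assumes only `jax` and `numpy`; without `jax_enable_x64`, the float64 request would be silently downcast to float32.

```python
import jax

# Enable 64-bit types; without this flag, float64 requests are silently downcast to float32.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def unit_normal(p, d):
    """Hypothetical helper: unit vector pointing from p to the nearby point p + d."""
    v = (p + d) - p  # becomes exactly zero when d is below the float resolution at p
    return v / jnp.linalg.norm(v)  # 0 / 0 gives NaN when v is all zeros


p = [1.0, 1.0, 1.0]
d = [1e-9, 0.0, 0.0]

# Same computation, same inputs, two precisions.
n32 = unit_normal(jnp.asarray(p, dtype=jnp.float32), jnp.asarray(d, dtype=jnp.float32))
n64 = unit_normal(jnp.asarray(p, dtype=jnp.float64), jnp.asarray(d, dtype=jnp.float64))

print(n32)  # [nan nan nan]: 1 + 1e-9 rounds back to 1 in float32
print(n64)  # finite, close to [1, 0, 0]
```

If a NaN disappears once the inputs are float64, that points to a precision issue rather than a logic bug.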
However, in mujoco/mjx/viewer.py, setting jax.config.update("jax_enable_x64", True) causes errors like:

TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'd.contact.geom1' compiled with int32[496] and called with int64[496]
Argument 'd.contact.geom2' compiled with int32[496] and called with int64[496]
Argument 'd.contact.geom' compiled with int32[496,2] and called with int64[496,2]
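This TypeError is what JAX raises when an ahead-of-time compiled function (`jax.jit(...).lower(...).compile()`) is called with arguments whose dtypes differ from those it was compiled for. The JAX-only sketch below reproduces that mechanism; `toy_step` is a hypothetical stand-in for `mjx.step`, and the int32 initial array only mirrors (by assumption) the int32 contact fields reported above. Because the step returns int64 under x64, its output no longer matches the compiled signature on the next call.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np


def toy_step(geom):
    """Hypothetical stand-in for mjx.step: recomputes an integer index array each step."""
    # int32 input + int64 arange (default int dtype under x64) promotes to int64.
    return geom + jnp.arange(geom.shape[0])


# Converting an int32 NumPy buffer keeps the int32 dtype.
geom0 = jnp.asarray(np.zeros(496, dtype=np.int32))

# Ahead-of-time compilation fixes the argument type to int32[496].
compiled = jax.jit(toy_step).lower(geom0).compile()

geom1 = compiled(geom0)  # matches the compiled signature; returns int64[496]

mismatch_error = None
try:
    compiled(geom1)  # int64[496] no longer matches the int32[496] signature
except TypeError as err:
    mismatch_error = err
    print(err)
```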

This error is caused by put_data: currently, in io.py only make_data supports jax.config.update("jax_enable_x64", True). So mujoco/mjx/viewer.py needs to be changed to use make_data.
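The fix proposed here is to build the initial data with make_data instead of put_data. As a hedged, JAX-only illustration of the underlying idea (not MJX code, and not a substitute for that fix), the sketch below makes the compiled signature agree with the step output by casting each leaf of the initial state to the dtype the step returns, obtained with `jax.eval_shape`. `toy_step` and `match_step_dtypes` are hypothetical names; applying the same cast to an MJX `Data` would assume every leaf is an array and that the step returns a tree with the same structure.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np


def toy_step(state):
    """Hypothetical stand-in for mjx.step on a dict-of-arrays state."""
    return {
        "qpos": state["qpos"] + 0.01,  # float64 in, float64 out
        "geom": state["geom"] + jnp.arange(state["geom"].shape[0]),  # int32 in, int64 out
    }


def match_step_dtypes(step_fn, state):
    """Cast each leaf of `state` to the dtype that `step_fn` returns for that leaf."""
    out_spec = jax.eval_shape(step_fn, state)  # traces for shapes/dtypes only, no compute
    return jax.tree_util.tree_map(
        lambda x, spec: jnp.asarray(x, dtype=spec.dtype), state, out_spec
    )


state0 = {
    "qpos": jnp.zeros(7),  # float64 with x64 enabled
    "geom": jnp.asarray(np.zeros(496, dtype=np.int32)),  # int32, as in the compiled-with dtype above
}
state0 = match_step_dtypes(toy_step, state0)

# Compile against the aligned dtypes, then step repeatedly without a signature mismatch.
compiled = jax.jit(toy_step).lower(state0).compile()
state1 = compiled(state0)
state2 = compiled(state1)
```

The design choice is to derive the target dtypes from the step function itself rather than hard-coding them, so the same cast stays correct whether or not x64 is enabled.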

Also, MuJoCo 3.3.0 does not support int64 for tendons (try, for example, robotiq_2f85). This bug has been fixed on the latest main branch (see this issue and commit 4e0a4f4); very good job.

Steps for reproduction

  1. Edit mujoco/mjx/viewer.py and add jax.config.update("jax_enable_x64", True)
  2. Run python viewer.py --mjcf xxx.xml
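A small sketch of step 1, assuming the flag is placed at the top of the script so it takes effect before the model and data are converted to JAX arrays; the dtype check is only a sanity test that the flag is active.

```python
import jax

# Step 1: enable 64-bit types at the top of the script, before the model and data
# are converted to JAX arrays (put_data in viewer.py).
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Sanity check that the flag is active: default float and int dtypes are now 64-bit.
print(jnp.zeros(1).dtype, jnp.arange(1).dtype)  # float64 int64
```

Alternatively, JAX reads the JAX_ENABLE_X64 environment variable, so JAX_ENABLE_X64=True python viewer.py --mjcf xxx.xml should enable the same mode without editing the file.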

Minimal model for reproduction

report_issue.zip

Code required for reproduction


Metadata


Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
