Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions learning/train_rsl_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@
f"{', '.join(mujoco_playground.registry.ALL_ENVS)}"
),
)
_IMPL = flags.DEFINE_enum("impl", "jax", ["jax", "warp"], "MJX implementation")
_NJMAX = flags.DEFINE_integer(
"njmax", None, "The maximum number of constraints per world."
)
_LOAD_RUN_NAME = flags.DEFINE_string(
"load_run_name", None, "Run name to load from (for checkpoint restoration)."
)
Expand Down Expand Up @@ -108,6 +112,9 @@ def main(argv):

# Load default config from registry
env_cfg = registry.get_default_config(_ENV_NAME.value)
env_cfg.impl = _IMPL.value
if _NJMAX.present:
env_cfg.njmax = _NJMAX.value
print(f"Environment config:\n{env_cfg}")

# Generate unique experiment name
Expand Down