Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 51 additions & 12 deletions metaworld/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def make_mt_envs(

def _make_ml_envs_inner(
benchmark: Benchmark,
meta_batch_size: int,
meta_batch_size: int | None = None,
seed: int | None = None,
total_tasks_per_cls: int | None = None,
split: Literal["train", "test"] = "train",
Expand All @@ -527,10 +527,20 @@ def _make_ml_envs_inner(
benchmark.train_classes if split == "train" else benchmark.test_classes
)
all_tasks = benchmark.train_tasks if split == "train" else benchmark.test_tasks
assert (
meta_batch_size % len(all_classes) == 0
), "meta_batch_size must be divisible by envs_per_task"
tasks_per_env = meta_batch_size // len(all_classes)

num_classes = len(all_classes)
if meta_batch_size is None:
meta_batch_size = num_classes

assert meta_batch_size >= num_classes, (
f"meta_batch_size ({meta_batch_size}) must be >= the number of environment "
f"classes ({num_classes}). Each class needs at least one sub-environment."
)
assert meta_batch_size % num_classes == 0, (
f"meta_batch_size ({meta_batch_size}) must be divisible by the number of "
f"environment classes ({num_classes})."
)
tasks_per_env = meta_batch_size // num_classes

env_tuples = []
for env_name, env_cls in all_classes.items():
Expand Down Expand Up @@ -565,7 +575,7 @@ def _make_ml_envs_inner(
def make_ml_envs(
name: str,
seed: int | None = None,
meta_batch_size: int = 20,
meta_batch_size: int | None = None,
total_tasks_per_cls: int | None = None,
split: Literal["train", "test"] = "train",
vector_strategy: Literal["sync", "async"] = "sync",
Expand Down Expand Up @@ -636,7 +646,7 @@ def _ml_bench_vector_entry_point(
| str = gym.vector.AutoresetMode.SAME_STEP,
total_tasks_per_cls: int | None = None,
seed: int | None = None,
meta_batch_size: int = 20,
meta_batch_size: int | None = None,
num_envs=None,
**lamb_kwargs,
):
Expand Down Expand Up @@ -666,12 +676,41 @@ def _ml_bench_vector_entry_point(
kwargs={},
)

def _ml1_entry_point(
    env_name: str,
    split: Literal["train", "test"],
    seed: int | None = None,
    total_tasks_per_cls: int | None = None,
    **lamb_kwargs,
):
    """Build a single ML1 environment for ``env_name`` on the given split.

    Constructs an ML1 benchmark, gathers the tasks belonging to
    ``env_name`` from the requested split (optionally capped at
    ``total_tasks_per_cls``), and delegates environment construction to
    ``_init_each_env``. Environments on the ``"test"`` split terminate
    on success; ``"train"`` environments do not.
    """
    benchmark = ML1(env_name, seed=seed)
    # Pick the task pool that corresponds to the requested split.
    if split == "train":
        pool = benchmark.train_tasks
    else:
        pool = benchmark.test_tasks
    matching = [t for t in pool if t.env_name == env_name]
    # Optionally cap how many tasks this class contributes.
    limit = total_tasks_per_cls
    selected = matching if limit is None else matching[:limit]
    return _init_each_env(
        env_cls=benchmark.train_classes[env_name],
        tasks=selected,
        seed=seed,
        terminate_on_success=(split == "test"),
        task_select="pseudorandom",
        **lamb_kwargs,
    )

for split in ["train", "test"]:
register(
id=f"Meta-World/ML1-{split}",
vector_entry_point=lambda env_name, vector_strategy="sync", autoreset_mode=gym.vector.AutoresetMode.SAME_STEP, total_tasks_per_cls=None, meta_batch_size=20, seed=None, num_envs=None, **kwargs: _ml_bench_vector_entry_point(
entry_point=lambda env_name, _split=split, seed=None, total_tasks_per_cls=None, num_envs=None, **kwargs: _ml1_entry_point(
env_name,
_split,
seed,
total_tasks_per_cls,
**kwargs,
),
vector_entry_point=lambda env_name, _split=split, vector_strategy="sync", autoreset_mode=gym.vector.AutoresetMode.SAME_STEP, total_tasks_per_cls=None, meta_batch_size=None, seed=None, num_envs=None, **kwargs: _ml_bench_vector_entry_point(
env_name,
split, # type: ignore[arg-type]
_split,
vector_strategy,
autoreset_mode,
total_tasks_per_cls,
Expand Down Expand Up @@ -724,7 +763,7 @@ def _ml_bench_vector_entry_point(
for split in ["train", "test"]:
register(
id=f"Meta-World/{ml_bench}-{split}",
vector_entry_point=lambda _ml_bench=ml_bench, _split=split, vector_strategy="sync", autoreset_mode=gym.vector.AutoresetMode.SAME_STEP, total_tasks_per_cls=None, seed=None, meta_batch_size=20, num_envs=None, **kwargs: _ml_bench_vector_entry_point(
vector_entry_point=lambda _ml_bench=ml_bench, _split=split, vector_strategy="sync", autoreset_mode=gym.vector.AutoresetMode.SAME_STEP, total_tasks_per_cls=None, seed=None, meta_batch_size=None, num_envs=None, **kwargs: _ml_bench_vector_entry_point(
_ml_bench,
_split,
vector_strategy,
Expand Down Expand Up @@ -788,7 +827,7 @@ def _custom_ml_vector_entry_point(
autoreset_mode: gym.vector.AutoresetMode
| str = gym.vector.AutoresetMode.SAME_STEP,
total_tasks_per_cls: int | None = None,
meta_batch_size: int = 20,
meta_batch_size: int | None = None,
seed=None,
num_envs=None,
**lamb_kwargs,
Expand All @@ -805,7 +844,7 @@ def _custom_ml_vector_entry_point(

register(
id="Meta-World/custom-ml-envs",
vector_entry_point=lambda vector_strategy, train_envs, test_envs, autoreset_mode=gym.vector.AutoresetMode.SAME_STEP, total_tasks_per_cls=None, meta_batch_size=20, seed=None, num_envs=None, **kwargs: _custom_ml_vector_entry_point(
vector_entry_point=lambda vector_strategy, train_envs, test_envs, autoreset_mode=gym.vector.AutoresetMode.SAME_STEP, total_tasks_per_cls=None, meta_batch_size=None, seed=None, num_envs=None, **kwargs: _custom_ml_vector_entry_point(
vector_strategy,
train_envs,
test_envs,
Expand Down
2 changes: 1 addition & 1 deletion metaworld/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def __init__(
self,
env: Env,
tasks: list[Task],
sample_tasks_on_reset: bool = False,
sample_tasks_on_reset: bool = True,
):
super().__init__(env)
self.sample_tasks_on_reset = sample_tasks_on_reset
Expand Down