Refactor and make (final v3 pr) #533

Open
wants to merge 3 commits into base: master
Changes from 1 commit
new tests
reginald-mclean committed Mar 17, 2025
commit 067654ed90f74777409235d30a54684b585161be
12 changes: 6 additions & 6 deletions tests/integration/test_memory_usage.py
@@ -1,7 +1,7 @@
import memory_profiler
import pytest

from metaworld.envs.mujoco.env_dict import ALL_V2_ENVIRONMENTS
from metaworld.env_dict import ALL_V3_ENVIRONMENTS
from tests.helpers import step_env


@@ -21,7 +21,7 @@ def build_and_step_all(classes):
@pytest.fixture(scope="module")
def mt50_usage():
profile = {}
for env_cls in ALL_V2_ENVIRONMENTS.values():
for env_cls in ALL_V3_ENVIRONMENTS.values():
target = (build_and_step, [env_cls], {})
memory_usage = memory_profiler.memory_usage(target)
profile[env_cls] = max(memory_usage)
@@ -30,7 +30,7 @@ def mt50_usage():


@pytest.mark.skip
@pytest.mark.parametrize("env_cls", ALL_V2_ENVIRONMENTS.values())
@pytest.mark.parametrize("env_cls", ALL_V3_ENVIRONMENTS.values())
def test_max_memory_usage(env_cls, mt50_usage):
# No env should use more than 250MB
#
@@ -43,14 +43,14 @@ def test_max_memory_usage(env_cls, mt50_usage):
@pytest.mark.skip
def test_avg_memory_usage():
# average usage no greater than 60MB/env
target = (build_and_step_all, [ALL_V2_ENVIRONMENTS.values()], {})
target = (build_and_step_all, [ALL_V3_ENVIRONMENTS.values()], {})
usage = memory_profiler.memory_usage(target)
average = max(usage) / len(ALL_V2_ENVIRONMENTS)
average = max(usage) / len(ALL_V3_ENVIRONMENTS)
assert average < 60


@pytest.mark.skip
def test_from_task_memory_usage():
target = (ALL_V2_ENVIRONMENTS["reach-v1"], (), {})
target = (ALL_V3_ENVIRONMENTS["reach-v3"], (), {})
usage = memory_profiler.memory_usage(target)
assert max(usage) < 250
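As a side note on the call pattern above: `memory_profiler.memory_usage` accepts a `(callable, args, kwargs)` tuple and returns a list of memory samples in MiB, so `max(...)` gives the peak usage of that call. A minimal standalone sketch (the `build_list` helper is only illustrative):

```python
import memory_profiler


def build_list(n):
    # Stand-in for build_and_step: any callable whose peak memory we want to sample.
    return list(range(n))


# memory_usage runs the callable with the given args/kwargs and samples memory in MiB.
samples = memory_profiler.memory_usage((build_list, [1_000_000], {}))
print(f"peak memory: {max(samples):.1f} MiB")
```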
18 changes: 9 additions & 9 deletions tests/integration/test_new_api.py
@@ -251,15 +251,15 @@ def check_target_poss_unique(env_instances, env_rand_vecs):
"""Verify that all the state_goals are unique for the different rand_vecs that are sampled.

Note: The following envs randomize object initial position but not state_goal.
['hammer-v2', 'sweep-into-v2', 'bin-picking-v2', 'basketball-v2']
['hammer-v3', 'sweep-into-v3', 'bin-picking-v3', 'basketball-v3']

"""
for env_name, rand_vecs in env_rand_vecs.items():
if env_name in {
"hammer-v2",
"sweep-into-v2",
"bin-picking-v2",
"basketball-v2",
"hammer-v3",
"sweep-into-v3",
"bin-picking-v3",
"basketball-v3",
}:
continue
env = env_instances[env_name]
@@ -289,13 +289,13 @@ def helper_neq(env, env_2):
assert not (rand_vec_1 == rand_vec_2).all()

# testing MT1
mt1_1 = metaworld.MT1("sweep-into-v2", seed=10)
mt1_2 = metaworld.MT1("sweep-into-v2", seed=10)
mt1_1 = metaworld.MT1("sweep-into-v3", seed=10)
mt1_2 = metaworld.MT1("sweep-into-v3", seed=10)
helper(mt1_1, mt1_2)

# testing ML1
ml1_1 = metaworld.ML1("sweep-into-v2", seed=10)
ml1_2 = metaworld.ML1("sweep-into-v2", seed=10)
ml1_1 = metaworld.ML1("sweep-into-v3", seed=10)
ml1_2 = metaworld.ML1("sweep-into-v3", seed=10)
helper(ml1_1, ml1_2)

# testing MT10
48 changes: 24 additions & 24 deletions tests/integration/test_single_goal_envs.py
@@ -1,19 +1,19 @@
import numpy as np

from metaworld.envs import (
ALL_V2_ENVIRONMENTS_GOAL_HIDDEN,
ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE,
from metaworld.env_dict import (
ALL_V3_ENVIRONMENTS_GOAL_HIDDEN,
ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE,
)
from tests.helpers import step_env


def test_hidden_goal_envs():
for env_key, env_cls in ALL_V2_ENVIRONMENTS_GOAL_HIDDEN.items():
for env_key, env_cls in ALL_V3_ENVIRONMENTS_GOAL_HIDDEN.items():
assert "goal-hidden" in env_key
assert "GoalHidden" in env_cls.__name__
state_before = np.random.get_state()
env = env_cls(seed=5)
env2 = env_cls(seed=5)
enV3 = env_cls(seed=5)
step_env(env, max_path_length=3, iterations=3, render=False)

first_target = env._target_pos
@@ -22,8 +22,8 @@ def test_hidden_goal_envs():

assert (first_target == second_target).all()
env.reset()
env2.reset()
assert (env._target_pos == env2._target_pos).all()
enV3.reset()
assert (env._target_pos == enV3._target_pos).all()
state_after = np.random.get_state()
for idx, (state_before_idx, state_after_idx) in enumerate(
zip(state_before, state_after)
@@ -35,12 +35,12 @@ def test_hidden_goal_envs():


def test_observable_goal_envs():
for env_key, env_cls in ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE.items():
for env_key, env_cls in ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE.items():
assert "goal-observable" in env_key
assert "GoalObservable" in env_cls.__name__
state_before = np.random.get_state()
env = env_cls(seed=10)
env2 = env_cls(seed=10)
enV3 = env_cls(seed=10)
step_env(env, max_path_length=3, iterations=3, render=False)

first_target = env._target_pos
@@ -49,8 +49,8 @@ def test_observable_goal_envs():

assert (first_target == second_target).all()
env.reset()
env2.reset()
assert (env._target_pos == env2._target_pos).all()
enV3.reset()
assert (env._target_pos == enV3._target_pos).all()
state_after = np.random.get_state()
for idx, (state_before_idx, state_after_idx) in enumerate(
zip(state_before, state_after)
@@ -62,21 +62,21 @@ def test_observable_goal_envs():


def test_seeding_observable():
door_open_goal_observable_cls = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[
"door-open-v2-goal-observable"
door_open_goal_observable_cls = ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE[
"door-open-v3-goal-observable"
]

env1 = door_open_goal_observable_cls(seed=5)
env2 = door_open_goal_observable_cls(seed=5)
enV3 = door_open_goal_observable_cls(seed=5)

env1.reset() # Reset environment
env2.reset()
enV3.reset()
a1 = env1.action_space.sample() # Sample an action
a2 = env2.action_space.sample()
a2 = enV3.action_space.sample()
next_obs1, _, _, _, _ = env1.step(
a1
) # Step the environment with the sampled random action
next_obs2, _, _, _, _ = env2.step(a2)
next_obs2, _, _, _, _ = enV3.step(a2)
assert (
next_obs1[-3:] == next_obs2[-3:]
).all() # 2 envs initialized with the same seed will have the same goal
@@ -105,23 +105,23 @@ def test_seeding_observable():


def test_seeding_hidden():
door_open_goal_hidden_cls = ALL_V2_ENVIRONMENTS_GOAL_HIDDEN[
"door-open-v2-goal-hidden"
door_open_goal_hidden_cls = ALL_V3_ENVIRONMENTS_GOAL_HIDDEN[
"door-open-v3-goal-hidden"
]

env1 = door_open_goal_hidden_cls(seed=5)
env2 = door_open_goal_hidden_cls(seed=5)
enV3 = door_open_goal_hidden_cls(seed=5)

env1.reset() # Reset environment
env2.reset()
enV3.reset()
a1 = env1.action_space.sample() # Sample an action
a2 = env2.action_space.sample()
a2 = enV3.action_space.sample()
next_obs1, _, _, _, _ = env1.step(
a1
) # Step the environment with the sampled random action
next_obs2, _, _, _, _ = env2.step(a2)
next_obs2, _, _, _, _ = enV3.step(a2)
assert (
env1._target_pos == env2._target_pos
env1._target_pos == enV3._target_pos
).all() # 2 envs initialized with the same seed will have the same goal
assert (next_obs2[-3:] == np.zeros(3)).all() and (
next_obs1[-3] == np.zeros(3)
6 changes: 3 additions & 3 deletions tests/metaworld/envs/mujoco/sawyer_xyz/test_obs_space_hand.py
@@ -1,10 +1,10 @@
import numpy as np
import pytest

from metaworld.envs.mujoco.env_dict import ALL_V2_ENVIRONMENTS
from metaworld.envs.mujoco.sawyer_xyz import SawyerXYZEnv
from metaworld.env_dict import ALL_V3_ENVIRONMENTS
from metaworld.policies.action import Action
from metaworld.policies.policy import Policy, move
from metaworld.sawyer_xyz_env import SawyerXYZEnv


class SawyerRandomReachPolicy(Policy):
@@ -43,7 +43,7 @@ def sample_spherical(num_points, radius=1.0):

@pytest.mark.parametrize("target", sample_spherical(100, 10.0))
def test_reaching_limit(target):
env = ALL_V2_ENVIRONMENTS["reach-v2"]()
env = ALL_V3_ENVIRONMENTS["reach-v3"]()
env._partially_observable = False
env._freeze_rand_vec = False
env._set_task_called = True
123 changes: 12 additions & 111 deletions tests/metaworld/envs/mujoco/sawyer_xyz/test_scripted_policies.py
@@ -1,120 +1,22 @@
import random

import numpy as np
import pytest

from metaworld import MT1
from metaworld.policies import (
SawyerAssemblyV2Policy,
SawyerBasketballV2Policy,
SawyerBinPickingV2Policy,
SawyerBoxCloseV2Policy,
SawyerButtonPressTopdownV2Policy,
SawyerButtonPressTopdownWallV2Policy,
SawyerButtonPressV2Policy,
SawyerButtonPressWallV2Policy,
SawyerCoffeeButtonV2Policy,
SawyerCoffeePullV2Policy,
SawyerCoffeePushV2Policy,
SawyerDialTurnV2Policy,
SawyerDisassembleV2Policy,
SawyerDoorCloseV2Policy,
SawyerDoorLockV2Policy,
SawyerDoorOpenV2Policy,
SawyerDoorUnlockV2Policy,
SawyerDrawerCloseV2Policy,
SawyerDrawerOpenV2Policy,
SawyerFaucetCloseV2Policy,
SawyerFaucetOpenV2Policy,
SawyerHammerV2Policy,
SawyerHandInsertV2Policy,
SawyerHandlePressSideV2Policy,
SawyerHandlePressV2Policy,
SawyerHandlePullSideV2Policy,
SawyerHandlePullV2Policy,
SawyerLeverPullV2Policy,
SawyerPegInsertionSideV2Policy,
SawyerPegUnplugSideV2Policy,
SawyerPickOutOfHoleV2Policy,
SawyerPickPlaceV2Policy,
SawyerPickPlaceWallV2Policy,
SawyerPlateSlideBackSideV2Policy,
SawyerPlateSlideBackV2Policy,
SawyerPlateSlideSideV2Policy,
SawyerPlateSlideV2Policy,
SawyerPushBackV2Policy,
SawyerPushV2Policy,
SawyerPushWallV2Policy,
SawyerReachV2Policy,
SawyerReachWallV2Policy,
SawyerShelfPlaceV2Policy,
SawyerSoccerV2Policy,
SawyerStickPullV2Policy,
SawyerStickPushV2Policy,
SawyerSweepIntoV2Policy,
SawyerSweepV2Policy,
SawyerWindowCloseV2Policy,
SawyerWindowOpenV2Policy,
)

policies = dict(
{
"assembly-v2": SawyerAssemblyV2Policy,
"basketball-v2": SawyerBasketballV2Policy,
"bin-picking-v2": SawyerBinPickingV2Policy,
"box-close-v2": SawyerBoxCloseV2Policy,
"button-press-topdown-v2": SawyerButtonPressTopdownV2Policy,
"button-press-topdown-wall-v2": SawyerButtonPressTopdownWallV2Policy,
"button-press-v2": SawyerButtonPressV2Policy,
"button-press-wall-v2": SawyerButtonPressWallV2Policy,
"coffee-button-v2": SawyerCoffeeButtonV2Policy,
"coffee-pull-v2": SawyerCoffeePullV2Policy,
"coffee-push-v2": SawyerCoffeePushV2Policy,
"dial-turn-v2": SawyerDialTurnV2Policy,
"disassemble-v2": SawyerDisassembleV2Policy,
"door-close-v2": SawyerDoorCloseV2Policy,
"door-lock-v2": SawyerDoorLockV2Policy,
"door-open-v2": SawyerDoorOpenV2Policy,
"door-unlock-v2": SawyerDoorUnlockV2Policy,
"drawer-close-v2": SawyerDrawerCloseV2Policy,
"drawer-open-v2": SawyerDrawerOpenV2Policy,
"faucet-close-v2": SawyerFaucetCloseV2Policy,
"faucet-open-v2": SawyerFaucetOpenV2Policy,
"hammer-v2": SawyerHammerV2Policy,
"hand-insert-v2": SawyerHandInsertV2Policy,
"handle-press-side-v2": SawyerHandlePressSideV2Policy,
"handle-press-v2": SawyerHandlePressV2Policy,
"handle-pull-v2": SawyerHandlePullV2Policy,
"handle-pull-side-v2": SawyerHandlePullSideV2Policy,
"peg-insert-side-v2": SawyerPegInsertionSideV2Policy,
"lever-pull-v2": SawyerLeverPullV2Policy,
"peg-unplug-side-v2": SawyerPegUnplugSideV2Policy,
"pick-out-of-hole-v2": SawyerPickOutOfHoleV2Policy,
"pick-place-v2": SawyerPickPlaceV2Policy,
"pick-place-wall-v2": SawyerPickPlaceWallV2Policy,
"plate-slide-back-side-v2": SawyerPlateSlideBackSideV2Policy,
"plate-slide-back-v2": SawyerPlateSlideBackV2Policy,
"plate-slide-side-v2": SawyerPlateSlideSideV2Policy,
"plate-slide-v2": SawyerPlateSlideV2Policy,
"reach-v2": SawyerReachV2Policy,
"reach-wall-v2": SawyerReachWallV2Policy,
"push-back-v2": SawyerPushBackV2Policy,
"push-v2": SawyerPushV2Policy,
"push-wall-v2": SawyerPushWallV2Policy,
"shelf-place-v2": SawyerShelfPlaceV2Policy,
"soccer-v2": SawyerSoccerV2Policy,
"stick-pull-v2": SawyerStickPullV2Policy,
"stick-push-v2": SawyerStickPushV2Policy,
"sweep-into-v2": SawyerSweepIntoV2Policy,
"sweep-v2": SawyerSweepV2Policy,
"window-close-v2": SawyerWindowCloseV2Policy,
"window-open-v2": SawyerWindowOpenV2Policy,
}
)
from metaworld.policies import ENV_POLICY_MAP


@pytest.mark.parametrize("env_name", MT1.ENV_NAMES)
def test_policy(env_name):
mt1 = MT1(env_name)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

mt1 = MT1(env_name, seed=SEED)
env = mt1.train_classes[env_name]()
p = policies[env_name]()
env.seed(SEED)
p = ENV_POLICY_MAP[env_name]()
completed = 0
for task in mt1.train_tasks:
env.set_task(task)
@@ -130,5 +32,4 @@ def test_policy(env_name):
if int(info["success"]) == 1:
completed += 1
break
print(float(completed) / 50)
assert (float(completed) / 50) > 0.80
assert (float(completed) / 50) >= 0.80
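Review note: the hand-maintained `policies` dict is replaced by `ENV_POLICY_MAP` from `metaworld.policies`. A minimal usage sketch of the same lookup pattern the test uses, assuming `"reach-v3"` is a valid MT1 task name and that the v3 scripted policies keep the `get_action(obs)` interface seen elsewhere in these tests:

```python
from metaworld import MT1
from metaworld.policies import ENV_POLICY_MAP

mt1 = MT1("reach-v3", seed=42)          # "reach-v3" assumed to be a valid task name
env = mt1.train_classes["reach-v3"]()
env.set_task(mt1.train_tasks[0])

policy = ENV_POLICY_MAP["reach-v3"]()   # same lookup the parametrized test performs
obs, _ = env.reset()
action = policy.get_action(obs)         # scripted policies expose get_action(obs)
obs, reward, terminated, truncated, info = env.step(action)
```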
14 changes: 7 additions & 7 deletions tests/metaworld/envs/mujoco/sawyer_xyz/test_seeded_rand_vec.py
@@ -3,24 +3,24 @@
import numpy as np
import pytest

from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
from metaworld.env_dict import ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE


@pytest.mark.parametrize("env_name", sorted(ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE.keys()))
@pytest.mark.parametrize("env_name", sorted(ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE.keys()))
def test_observations_match(env_name):
seed = random.randrange(1000)
env1 = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_name](seed=seed)
env1 = ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE[env_name](seed=seed)
env1.seeded_rand_vec = True
env2 = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_name](seed=seed)
env2.seeded_rand_vec = True
enV3 = ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE[env_name](seed=seed)
enV3.seeded_rand_vec = True

(obs1, _), (obs2, _) = env1.reset(), env2.reset()
(obs1, _), (obs2, _) = env1.reset(), enV3.reset()
assert (obs1 == obs2).all()

for i in range(env1.max_path_length):
a = np.random.uniform(low=-1, high=1, size=4)
obs1, r1, done1, _, _ = env1.step(a)
obs2, r2, done2, _, _ = env2.step(a)
obs2, r2, done2, _, _ = enV3.step(a)
assert (obs1 == obs2).all()
assert r1 == r2
assert not done1
151 changes: 151 additions & 0 deletions tests/metaworld/test_evaluation.py
@@ -0,0 +1,151 @@
from __future__ import annotations

import random

import gymnasium as gym
import numpy as np
import numpy.typing as npt
import pytest

import metaworld # noqa: F401
from metaworld import evaluation
from metaworld.policies import ENV_POLICY_MAP


class ScriptedPolicyAgent(evaluation.MetaLearningAgent):
def __init__(
self,
envs: gym.vector.SyncVectorEnv | gym.vector.AsyncVectorEnv,
num_rollouts: int | None = None,
max_episode_steps: int | None = None,
):
env_task_names = evaluation._get_task_names(envs)
self.policies = [ENV_POLICY_MAP[task]() for task in env_task_names] # type: ignore
self.num_rollouts = num_rollouts
self.max_episode_steps = max_episode_steps
self.adapt_calls = 0

def adapt_action(
self, observations: npt.NDArray[np.float64]
) -> tuple[npt.NDArray[np.float64], dict[str, npt.NDArray]]:
actions: list[npt.NDArray[np.float32]] = []
num_envs = len(self.policies)
for env_idx in range(num_envs):
actions.append(self.policies[env_idx].get_action(observations[env_idx]))
stacked_actions = np.stack(actions, axis=0, dtype=np.float64)
return stacked_actions, {
"log_probs": np.ones((num_envs,)),
"means": stacked_actions,
"stds": np.zeros((num_envs,)),
}

def eval_action(
self, observations: npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]:
actions: list[npt.NDArray[np.float32]] = []
num_envs = len(self.policies)
for env_idx in range(num_envs):
actions.append(self.policies[env_idx].get_action(observations[env_idx]))
stacked_actions = np.stack(actions, axis=0, dtype=np.float64)
return stacked_actions

def adapt(self, rollouts: evaluation.Rollout) -> None:
assert self.num_rollouts is not None

for key in [
"observations",
"rewards",
"actions",
"dones",
"log_probs",
"means",
"stds",
]:
assert len(getattr(rollouts, key).shape) >= 3
assert getattr(rollouts, key).shape[0] == len(self.policies)
assert getattr(rollouts, key).shape[1] == self.num_rollouts
assert getattr(rollouts, key).shape[2] == self.max_episode_steps

self.adapt_calls += 1


class RemovePartialObservabilityWrapper(gym.vector.VectorWrapper):
def get_attr(self, name):
return self.env.get_attr(name)

def set_attr(self, name, values):
return self.env.set_attr(name, values)

def call(self, name, *args, **kwargs):
return self.env.call(name, *args, **kwargs)

def step(self, actions):
self.env.set_attr("_partially_observable", False)
return super().step(actions)


def test_evaluation():
SEED = 42
max_episode_steps = 300 # To speed up the test
num_episodes = 50

random.seed(SEED)
np.random.seed(SEED)
envs = gym.make_vec(
"Meta-World/MT50",
seed=SEED,
max_episode_steps=max_episode_steps,
vector_strategy="async",
)
agent = ScriptedPolicyAgent(envs)
mean_success_rate, mean_returns, success_rate_per_task = evaluation.evaluation(
agent, envs, num_episodes=num_episodes
)
assert isinstance(mean_returns, float)
assert mean_success_rate >= 0.80
assert len(success_rate_per_task) == envs.num_envs
assert np.all(np.array(list(success_rate_per_task.values())) >= 0.80)


@pytest.mark.parametrize("benchmark", ("ML10", "ML45"))
def test_metalearning_evaluation(benchmark):
SEED = 42

max_episode_steps = 300
meta_batch_size = 10 # Number of parallel envs

adaptation_steps = 2 # Number of adaptation iterations
adaptation_episodes = 2  # Number of training episodes collected per task in the meta batch per adaptation iteration
num_evals = 50 # Number of different task vectors tested for each task
num_episodes = 1 # Number of test episodes per task vector

random.seed(SEED)
np.random.seed(SEED)
envs = gym.make_vec(
f"Meta-World/{benchmark}-test",
seed=SEED,
vector_strategy="async",
meta_batch_size=meta_batch_size,
max_episode_steps=max_episode_steps,
)
envs = RemovePartialObservabilityWrapper(envs)
agent = ScriptedPolicyAgent(envs, adaptation_episodes, max_episode_steps)
(
mean_success_rate,
mean_returns,
success_rate_per_task,
) = evaluation.metalearning_evaluation(
agent,
envs,
max_episode_steps=max_episode_steps,
num_episodes=num_episodes,
adaptation_episodes=adaptation_episodes,
adaptation_steps=adaptation_steps,
num_evals=num_evals,
)
assert isinstance(mean_returns, float)
assert mean_success_rate >= 0.80
assert len(success_rate_per_task) == len(set(evaluation._get_task_names(envs)))
assert np.all(np.array(list(success_rate_per_task.values())) >= 0.80)
assert agent.adapt_calls == num_evals * adaptation_steps
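Review note: `ScriptedPolicyAgent` above also documents the interface that `evaluation.evaluation` and `evaluation.metalearning_evaluation` expect from an agent. A minimal sketch of another conforming agent, assuming the interface is exactly the three methods exercised here (`adapt_action`, `eval_action`, `adapt`):

```python
from __future__ import annotations

import gymnasium as gym
import numpy as np
import numpy.typing as npt

from metaworld import evaluation


class RandomAgent(evaluation.MetaLearningAgent):
    """Takes uniformly random actions and performs no adaptation (illustrative only)."""

    def __init__(self, envs: gym.vector.SyncVectorEnv | gym.vector.AsyncVectorEnv):
        self.action_space = envs.single_action_space
        self.num_envs = envs.num_envs

    def adapt_action(
        self, observations: npt.NDArray[np.float64]
    ) -> tuple[npt.NDArray[np.float64], dict[str, npt.NDArray]]:
        actions = np.stack(
            [self.action_space.sample() for _ in range(self.num_envs)], axis=0
        )
        # Mirror the auxiliary keys ScriptedPolicyAgent returns for the Rollout buffer.
        return actions, {
            "log_probs": np.ones((self.num_envs,)),
            "means": actions,
            "stds": np.zeros((self.num_envs,)),
        }

    def eval_action(
        self, observations: npt.NDArray[np.float64]
    ) -> npt.NDArray[np.float64]:
        return np.stack(
            [self.action_space.sample() for _ in range(self.num_envs)], axis=0
        )

    def adapt(self, rollouts: evaluation.Rollout) -> None:
        pass  # a learning agent would update its policy from the collected rollouts
```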
204 changes: 204 additions & 0 deletions tests/metaworld/test_gym_make.py
@@ -0,0 +1,204 @@
from __future__ import annotations

import random
from typing import Literal

import gymnasium as gym
import numpy as np
import pytest

import metaworld # noqa: F401
from metaworld import _N_GOALS, SawyerXYZEnv
from metaworld.env_dict import (
ALL_V3_ENVIRONMENTS,
ALL_V3_ENVIRONMENTS_GOAL_HIDDEN,
ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE,
ML10_V3,
ML45_V3,
MT10_V3,
MT50_V3,
EnvDict,
TrainTestEnvDict,
)


def _get_task_names(
envs: gym.vector.SyncVectorEnv | gym.vector.AsyncVectorEnv,
) -> list[str]:
metaworld_cls_to_task_name = {v.__name__: k for k, v in ALL_V3_ENVIRONMENTS.items()}
return [
metaworld_cls_to_task_name[task_name]
for task_name in envs.get_attr("task_name")
]


@pytest.mark.parametrize("benchmark,env_dict", (("MT10", MT10_V3), ("MT50", MT50_V3)))
@pytest.mark.parametrize("vector_strategy", ("sync", "async"))
def test_mt_benchmarks(benchmark: str, env_dict: EnvDict, vector_strategy: str):
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

max_episode_steps = 10

envs = gym.make_vec(
f"Meta-World/{benchmark}",
vector_strategy=vector_strategy,
seed=SEED,
use_one_hot=True,
max_episode_steps=max_episode_steps,
)

# Assert vec is correct
expected_vectorisation = getattr(
gym.vector, f"{vector_strategy.capitalize()}VectorEnv"
)
assert isinstance(envs, expected_vectorisation)

# Assert envs are correct
task_names = _get_task_names(envs)
assert envs.num_envs == len(env_dict.keys())
assert set(task_names) == set(env_dict.keys())

# Assert every env has N_GOALS goals
envs_tasks = envs.get_attr("tasks")
for env_tasks in envs_tasks:
assert len(env_tasks) == _N_GOALS

# Test wrappers: one hot obs, task sampling, max path length
obs, _ = envs.reset()
original_vecs = envs.get_attr("_last_rand_vec")

has_truncated = False
for _ in range(max_episode_steps + 1):
obs, _, _, truncated, _ = envs.step(envs.action_space.sample())
env_one_hots = obs[:, -envs.num_envs :]
env_ids = np.argmax(env_one_hots, axis=1)
assert set(env_ids) == set(range(envs.num_envs))

if any(truncated):
has_truncated = True

assert has_truncated

new_vecs = envs.get_attr("_last_rand_vec")
task_has_changed = False
for og_vec, new_vec in zip(original_vecs, new_vecs):
if np.any(og_vec != new_vec):
task_has_changed = True
assert task_has_changed

partially_observable = all(envs.get_attr("_partially_observable"))
assert not partially_observable


@pytest.mark.parametrize("env_name", ALL_V3_ENVIRONMENTS.keys())
def test_mt1(env_name: str):
metaworld_cls_to_task_name = {v.__name__: k for k, v in ALL_V3_ENVIRONMENTS.items()}
env = gym.make("Meta-World/MT1", env_name=env_name)
assert isinstance(env.unwrapped, SawyerXYZEnv)
assert len(env.get_wrapper_attr("tasks")) == _N_GOALS
assert metaworld_cls_to_task_name[env.unwrapped.task_name] == env_name

env.reset()
assert not env.unwrapped._partially_observable


@pytest.mark.parametrize("env_name", ALL_V3_ENVIRONMENTS_GOAL_HIDDEN.keys())
def test_goal_hidden(env_name: str):
env = gym.make("Meta-World/goal_hidden", env_name=env_name, seed=None)
assert isinstance(env.unwrapped, SawyerXYZEnv)

env.reset()
assert env.unwrapped._partially_observable


@pytest.mark.parametrize("env_name", ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE.keys())
def test_goal_observable(env_name: str):
env = gym.make("Meta-World/goal_observable", env_name=env_name, seed=None)
assert isinstance(env.unwrapped, SawyerXYZEnv)

env.reset()
assert not env.unwrapped._partially_observable


@pytest.mark.parametrize("env_name", ALL_V3_ENVIRONMENTS.keys())
@pytest.mark.parametrize("split", ("train", "test"))
@pytest.mark.parametrize("vector_strategy", ("sync", "async"))
def test_ml1(env_name, split, vector_strategy):
meta_batch_size = 10
max_episode_steps = 10

envs = gym.make_vec(
f"Meta-World/ML1-{split}",
env_name=env_name,
vector_strategy=vector_strategy,
meta_batch_size=meta_batch_size,
max_episode_steps=max_episode_steps,
)
assert envs.num_envs == meta_batch_size
task_names = _get_task_names(envs)
assert all([task_name == env_name for task_name in task_names])

# Assert vec is correct
expected_vectorisation = getattr(
gym.vector, f"{vector_strategy.capitalize()}VectorEnv"
)
assert isinstance(envs, expected_vectorisation)

envs_tasks = envs.get_attr("tasks")
total_tasks = sum([len(env_tasks) for env_tasks in envs_tasks])
assert total_tasks == _N_GOALS

partially_observable = all(envs.get_attr("_partially_observable"))
assert partially_observable


@pytest.mark.parametrize("benchmark,env_dict", (("ML10", ML10_V3), ("ML45", ML45_V3)))
@pytest.mark.parametrize("split", ("train", "test"))
@pytest.mark.parametrize("vector_strategy", ("sync", "async"))
def test_ml_benchmarks(
benchmark: str,
env_dict: TrainTestEnvDict,
split: Literal["train", "test"],
vector_strategy: str,
):
meta_batch_size = 20 if benchmark != "ML45" else 45
total_tasks_per_cls = _N_GOALS
if benchmark == "ML45":
total_tasks_per_cls = 45
elif benchmark == "ML10" and split == "test":
total_tasks_per_cls = 40
max_episode_steps = 10

envs = gym.make_vec(
f"Meta-World/{benchmark}-{split}",
vector_strategy=vector_strategy,
meta_batch_size=meta_batch_size,
max_episode_steps=max_episode_steps,
total_tasks_per_cls=total_tasks_per_cls,
)
assert envs.num_envs == meta_batch_size
task_names = _get_task_names(envs) # type: ignore
assert set(task_names) == set(env_dict[split].keys())

# Assert vec is correct
expected_vectorisation = getattr(
gym.vector, f"{vector_strategy.capitalize()}VectorEnv"
)
assert isinstance(envs, expected_vectorisation)

envs_tasks = envs.get_attr("tasks")
tasks_per_env = {}
for task in env_dict[split].keys():
tasks_per_env[task] = 0

for env_tasks, env_name in zip(envs_tasks, task_names):
tasks_per_env[env_name] += len(env_tasks)

for task in env_dict[split].keys():
assert tasks_per_env[task] == total_tasks_per_cls

partially_observable = all(envs.get_attr("_partially_observable"))
assert partially_observable