|
33 | 33 | _td_to_device_mps_safe, |
34 | 34 | _to_device_mps_safe, |
35 | 35 | ) |
| 36 | +from torchrl.envs.env_creator import get_env_metadata |
36 | 37 | from torchrl.envs.transforms import StepCounter, TransformedEnv |
37 | 38 | from torchrl.envs.transforms.transforms import Tokenizer |
38 | 39 | from torchrl.envs.utils import check_env_specs |
@@ -999,3 +1000,160 @@ def test_parallel_env_no_buffers_mps_rollout(self): |
999 | 1000 | assert td["observation"].dtype == torch.float32 |
1000 | 1001 | finally: |
1001 | 1002 | env.close(raise_if_closed=False) |
| 1003 | + |
| 1004 | + |
| 1005 | +_MPS_USE_BUFFERS_WARNING = ( |
| 1006 | + "The environment specs have leaves on an MPS device, which cannot be placed " |
| 1007 | + "in shared memory" |
| 1008 | +) |
| 1009 | +_MPS_USE_BUFFERS_ERROR = ( |
| 1010 | + "use_buffers=True is incompatible with environments whose specs have leaves " |
| 1011 | + "on an MPS device" |
| 1012 | +) |
| 1013 | + |
| 1014 | + |
| 1015 | +class TestParallelEnvMPSBuffers: |
| 1016 | + """ParallelEnv use_buffers checks for MPS sub-envs (issue #3066). |
| 1017 | +
|
| 1018 | + These tests fake the device map reported by the env metadata, so they run |
| 1019 | + on CPU-only machines too. |
| 1020 | + """ |
| 1021 | + |
| 1022 | + @staticmethod |
| 1023 | + def _patch_device_map_to_mps(monkeypatch): |
| 1024 | + def get_env_metadata_mps(*args, **kwargs): |
| 1025 | + meta_data = get_env_metadata(*args, **kwargs) |
| 1026 | + meta_data.device_map = { |
| 1027 | + key: torch.device("mps") for key in meta_data.device_map |
| 1028 | + } |
| 1029 | + return meta_data |
| 1030 | + |
| 1031 | + monkeypatch.setattr( |
| 1032 | + "torchrl.envs.batched_envs.get_env_metadata", get_env_metadata_mps |
| 1033 | + ) |
| 1034 | + |
| 1035 | + def test_parallel_env_mps_leaves_default_use_buffers_false(self, monkeypatch): |
| 1036 | + self._patch_device_map_to_mps(monkeypatch) |
| 1037 | + with pytest.warns(UserWarning, match=_MPS_USE_BUFFERS_WARNING): |
| 1038 | + env = ParallelEnv(2, ContinuousActionVecMockEnv) |
| 1039 | + assert env._use_buffers is False |
| 1040 | + |
| 1041 | + def test_parallel_env_mps_leaves_use_buffers_true_raises(self, monkeypatch): |
| 1042 | + self._patch_device_map_to_mps(monkeypatch) |
| 1043 | + with pytest.raises(RuntimeError, match=_MPS_USE_BUFFERS_ERROR): |
| 1044 | + ParallelEnv(2, ContinuousActionVecMockEnv, use_buffers=True) |
| 1045 | + |
| 1046 | + def test_parallel_env_mps_leaves_configure_parallel_raises(self, monkeypatch): |
| 1047 | + self._patch_device_map_to_mps(monkeypatch) |
| 1048 | + with pytest.warns(UserWarning, match=_MPS_USE_BUFFERS_WARNING): |
| 1049 | + env = ParallelEnv(2, ContinuousActionVecMockEnv) |
| 1050 | + with pytest.raises(RuntimeError, match=_MPS_USE_BUFFERS_ERROR): |
| 1051 | + env.configure_parallel(use_buffers=True) |
| 1052 | + |
| 1053 | + def test_parallel_env_mps_leaves_explicit_use_buffers_false(self, monkeypatch): |
| 1054 | + self._patch_device_map_to_mps(monkeypatch) |
| 1055 | + env = ParallelEnv(2, ContinuousActionVecMockEnv, use_buffers=False) |
| 1056 | + assert env._use_buffers is False |
| 1057 | + |
| 1058 | + def test_serial_env_mps_leaves_keeps_buffers(self, monkeypatch): |
| 1059 | + # SerialEnv runs in-process, so MPS buffers are fine there |
| 1060 | + self._patch_device_map_to_mps(monkeypatch) |
| 1061 | + env = SerialEnv(2, ContinuousActionVecMockEnv) |
| 1062 | + assert env._use_buffers is True |
| 1063 | + |
| 1064 | + |
| 1065 | +@pytest.mark.skipif(not _has_mps(), reason="MPS device not available") |
| 1066 | +class TestMPSSubEnvs: |
| 1067 | + """ParallelEnv and collectors over sub-envs living on MPS (issue #3066).""" |
| 1068 | + |
| 1069 | + class _MPSObsEnv(EnvBase): |
| 1070 | + """Minimal env with all spec leaves on MPS. |
| 1071 | +
|
| 1072 | + The observation mirrors the last action so that the parent-worker |
| 1073 | + round-trip can be checked end-to-end. |
| 1074 | + """ |
| 1075 | + |
| 1076 | + def __init__(self, device="mps"): |
| 1077 | + super().__init__(device=device) |
| 1078 | + self.observation_spec = Composite( |
| 1079 | + observation=Unbounded(shape=(3,), device=device), device=device |
| 1080 | + ) |
| 1081 | + self.action_spec = Unbounded(shape=(1,), device=device) |
| 1082 | + self.reward_spec = Unbounded(shape=(1,), device=device) |
| 1083 | + |
| 1084 | + def _reset(self, tensordict): |
| 1085 | + return TensorDict( |
| 1086 | + {"observation": torch.zeros(3, device=self.device)}, |
| 1087 | + batch_size=[], |
| 1088 | + device=self.device, |
| 1089 | + ) |
| 1090 | + |
| 1091 | + def _step(self, tensordict): |
| 1092 | + return TensorDict( |
| 1093 | + { |
| 1094 | + "observation": tensordict["action"].expand(3).clone(), |
| 1095 | + "reward": torch.zeros(1, device=self.device), |
| 1096 | + "done": torch.zeros(1, dtype=torch.bool, device=self.device), |
| 1097 | + }, |
| 1098 | + batch_size=[], |
| 1099 | + device=self.device, |
| 1100 | + ) |
| 1101 | + |
| 1102 | + def _set_seed(self, seed): |
| 1103 | + return seed |
| 1104 | + |
| 1105 | + def test_parallel_env_mps_sub_envs_default_warns_and_runs(self): |
| 1106 | + with pytest.warns(UserWarning, match=_MPS_USE_BUFFERS_WARNING): |
| 1107 | + env = ParallelEnv(2, self._MPSObsEnv) |
| 1108 | + try: |
| 1109 | + assert env._use_buffers is False |
| 1110 | + td = env.reset() |
| 1111 | + assert td.device.type == "mps" |
| 1112 | + assert td["observation"].device.type == "mps" |
| 1113 | + policy = RandomPolicy(env.action_spec) |
| 1114 | + rollout = env.rollout(max_steps=3, policy=policy) |
| 1115 | + assert rollout.device.type == "mps" |
| 1116 | + # the worker must have seen the actions sampled in the parent |
| 1117 | + assert (rollout["next", "observation"] == rollout["action"]).all() |
| 1118 | + finally: |
| 1119 | + env.close(raise_if_closed=False) |
| 1120 | + |
| 1121 | + def test_parallel_env_mps_sub_envs_use_buffers_true_raises(self): |
| 1122 | + with pytest.raises(RuntimeError, match=_MPS_USE_BUFFERS_ERROR): |
| 1123 | + ParallelEnv(2, self._MPSObsEnv, use_buffers=True) |
| 1124 | + |
| 1125 | + @pytest.mark.parametrize("consolidate", [True, False]) |
| 1126 | + def test_parallel_env_mps_sub_envs_no_buffers_rollout(self, consolidate): |
| 1127 | + env = ParallelEnv( |
| 1128 | + 2, self._MPSObsEnv, use_buffers=False, consolidate=consolidate |
| 1129 | + ) |
| 1130 | + try: |
| 1131 | + policy = RandomPolicy(env.action_spec) |
| 1132 | + rollout = env.rollout(max_steps=3, policy=policy) |
| 1133 | + assert rollout.device.type == "mps" |
| 1134 | + assert (rollout["next", "observation"] == rollout["action"]).all() |
| 1135 | + finally: |
| 1136 | + env.close(raise_if_closed=False) |
| 1137 | + |
| 1138 | + def test_collector_parallel_env_mps_sub_envs(self): |
| 1139 | + # the setup reported in issue #3066 |
| 1140 | + with pytest.warns(UserWarning, match=_MPS_USE_BUFFERS_WARNING): |
| 1141 | + collector = Collector( |
| 1142 | + lambda: ParallelEnv(2, self._MPSObsEnv), |
| 1143 | + frames_per_batch=4, |
| 1144 | + total_frames=8, |
| 1145 | + ) |
| 1146 | + try: |
| 1147 | + for data in collector: |
| 1148 | + assert data.numel() == 4 |
| 1149 | + finally: |
| 1150 | + collector.shutdown() |
| 1151 | + |
| 1152 | + def test_serial_env_mps_sub_envs_buffers(self): |
| 1153 | + env = SerialEnv(2, self._MPSObsEnv) |
| 1154 | + try: |
| 1155 | + assert env._use_buffers is True |
| 1156 | + rollout = env.rollout(max_steps=3) |
| 1157 | + assert rollout.device.type == "mps" |
| 1158 | + finally: |
| 1159 | + env.close(raise_if_closed=False) |
0 commit comments