Skip to content

Commit da6120e

Browse files
authored
[BugFix] ParallelEnv over MPS envs: default to use_buffers=False, stage pipe data on CPU (#3867)
1 parent 0b09545 commit da6120e

2 files changed

Lines changed: 271 additions & 9 deletions

File tree

test/envs/test_special.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
_td_to_device_mps_safe,
3434
_to_device_mps_safe,
3535
)
36+
from torchrl.envs.env_creator import get_env_metadata
3637
from torchrl.envs.transforms import StepCounter, TransformedEnv
3738
from torchrl.envs.transforms.transforms import Tokenizer
3839
from torchrl.envs.utils import check_env_specs
@@ -999,3 +1000,160 @@ def test_parallel_env_no_buffers_mps_rollout(self):
9991000
assert td["observation"].dtype == torch.float32
10001001
finally:
10011002
env.close(raise_if_closed=False)
1003+
1004+
1005+
_MPS_USE_BUFFERS_WARNING = (
1006+
"The environment specs have leaves on an MPS device, which cannot be placed "
1007+
"in shared memory"
1008+
)
1009+
_MPS_USE_BUFFERS_ERROR = (
1010+
"use_buffers=True is incompatible with environments whose specs have leaves "
1011+
"on an MPS device"
1012+
)
1013+
1014+
1015+
class TestParallelEnvMPSBuffers:
    """ParallelEnv use_buffers checks for MPS sub-envs (issue #3066).

    These tests fake the device map reported by the env metadata, so they run
    on CPU-only machines too.

    Every env that is successfully constructed is closed in a ``finally``
    clause (with ``raise_if_closed=False``) so that cleanup does not depend on
    whether the assertions pass.
    """

    @staticmethod
    def _patch_device_map_to_mps(monkeypatch):
        """Make ``torchrl.envs.batched_envs`` see an all-MPS device map.

        Replaces ``get_env_metadata`` in ``torchrl.envs.batched_envs`` with a
        wrapper that calls the real function and then overwrites every entry
        of the returned ``device_map`` with ``torch.device("mps")``.

        Args:
            monkeypatch: the pytest ``monkeypatch`` fixture used to install
                the wrapper.
        """

        def get_env_metadata_mps(*args, **kwargs):
            meta_data = get_env_metadata(*args, **kwargs)
            meta_data.device_map = {
                key: torch.device("mps") for key in meta_data.device_map
            }
            return meta_data

        monkeypatch.setattr(
            "torchrl.envs.batched_envs.get_env_metadata", get_env_metadata_mps
        )

    def test_parallel_env_mps_leaves_default_use_buffers_false(self, monkeypatch):
        """Default construction warns and ends up with ``_use_buffers`` False."""
        self._patch_device_map_to_mps(monkeypatch)
        with pytest.warns(UserWarning, match=_MPS_USE_BUFFERS_WARNING):
            env = ParallelEnv(2, ContinuousActionVecMockEnv)
        try:
            assert env._use_buffers is False
        finally:
            env.close(raise_if_closed=False)

    def test_parallel_env_mps_leaves_use_buffers_true_raises(self, monkeypatch):
        """Passing ``use_buffers=True`` to the constructor raises."""
        self._patch_device_map_to_mps(monkeypatch)
        with pytest.raises(RuntimeError, match=_MPS_USE_BUFFERS_ERROR):
            ParallelEnv(2, ContinuousActionVecMockEnv, use_buffers=True)

    def test_parallel_env_mps_leaves_configure_parallel_raises(self, monkeypatch):
        """``configure_parallel(use_buffers=True)`` raises after construction."""
        self._patch_device_map_to_mps(monkeypatch)
        with pytest.warns(UserWarning, match=_MPS_USE_BUFFERS_WARNING):
            env = ParallelEnv(2, ContinuousActionVecMockEnv)
        try:
            with pytest.raises(RuntimeError, match=_MPS_USE_BUFFERS_ERROR):
                env.configure_parallel(use_buffers=True)
        finally:
            env.close(raise_if_closed=False)

    def test_parallel_env_mps_leaves_explicit_use_buffers_false(self, monkeypatch):
        """An explicit ``use_buffers=False`` is accepted and kept."""
        self._patch_device_map_to_mps(monkeypatch)
        env = ParallelEnv(2, ContinuousActionVecMockEnv, use_buffers=False)
        try:
            assert env._use_buffers is False
        finally:
            env.close(raise_if_closed=False)

    def test_serial_env_mps_leaves_keeps_buffers(self, monkeypatch):
        """SerialEnv keeps ``_use_buffers`` True with an all-MPS device map."""
        # SerialEnv runs in-process, so MPS buffers are fine there
        self._patch_device_map_to_mps(monkeypatch)
        env = SerialEnv(2, ContinuousActionVecMockEnv)
        try:
            assert env._use_buffers is True
        finally:
            env.close(raise_if_closed=False)
1063+
1064+
1065+
@pytest.mark.skipif(not _has_mps(), reason="MPS device not available")
class TestMPSSubEnvs:
    """ParallelEnv and collectors over sub-envs living on MPS (issue #3066)."""

    class _MPSObsEnv(EnvBase):
        """Minimal env with all spec leaves on MPS.

        The observation mirrors the last action so that the parent-worker
        round-trip can be checked end-to-end.
        """

        def __init__(self, device="mps"):
            super().__init__(device=device)
            self.observation_spec = Composite(
                observation=Unbounded(shape=(3,), device=device), device=device
            )
            self.action_spec = Unbounded(shape=(1,), device=device)
            self.reward_spec = Unbounded(shape=(1,), device=device)

        def _reset(self, tensordict):
            """Return an all-zero observation on the env device."""
            return TensorDict(
                {"observation": torch.zeros(3, device=self.device)},
                batch_size=[],
                device=self.device,
            )

        def _step(self, tensordict):
            """Echo the action into the observation; zero reward, never done."""
            return TensorDict(
                {
                    # broadcast the single action value over the 3 obs entries
                    "observation": tensordict["action"].expand(3).clone(),
                    "reward": torch.zeros(1, device=self.device),
                    "done": torch.zeros(1, dtype=torch.bool, device=self.device),
                },
                batch_size=[],
                device=self.device,
            )

        def _set_seed(self, seed):
            """No internal RNG to seed; hand the seed back unchanged."""
            return seed

    def test_parallel_env_mps_sub_envs_default_warns_and_runs(self):
        """Default construction warns, disables buffers and still rolls out."""
        with pytest.warns(UserWarning, match=_MPS_USE_BUFFERS_WARNING):
            env = ParallelEnv(2, self._MPSObsEnv)
        try:
            assert env._use_buffers is False
            td = env.reset()
            assert td.device.type == "mps"
            assert td["observation"].device.type == "mps"
            policy = RandomPolicy(env.action_spec)
            rollout = env.rollout(max_steps=3, policy=policy)
            assert rollout.device.type == "mps"
            # the worker must have seen the actions sampled in the parent
            assert (rollout["next", "observation"] == rollout["action"]).all()
        finally:
            env.close(raise_if_closed=False)

    def test_parallel_env_mps_sub_envs_use_buffers_true_raises(self):
        """Passing ``use_buffers=True`` with MPS sub-envs raises."""
        with pytest.raises(RuntimeError, match=_MPS_USE_BUFFERS_ERROR):
            ParallelEnv(2, self._MPSObsEnv, use_buffers=True)

    @pytest.mark.parametrize("consolidate", [True, False])
    def test_parallel_env_mps_sub_envs_no_buffers_rollout(self, consolidate):
        """Explicit ``use_buffers=False`` rolls out for both ``consolidate`` values."""
        env = ParallelEnv(
            2, self._MPSObsEnv, use_buffers=False, consolidate=consolidate
        )
        try:
            policy = RandomPolicy(env.action_spec)
            rollout = env.rollout(max_steps=3, policy=policy)
            assert rollout.device.type == "mps"
            assert (rollout["next", "observation"] == rollout["action"]).all()
        finally:
            env.close(raise_if_closed=False)

    def test_collector_parallel_env_mps_sub_envs(self):
        """A Collector over a ParallelEnv of MPS sub-envs yields full batches."""
        # the setup reported in issue #3066
        frames_per_batch = 4
        total_frames = 8
        with pytest.warns(UserWarning, match=_MPS_USE_BUFFERS_WARNING):
            collector = Collector(
                lambda: ParallelEnv(2, self._MPSObsEnv),
                frames_per_batch=frames_per_batch,
                total_frames=total_frames,
            )
        try:
            n_batches = 0
            for data in collector:
                assert data.numel() == frames_per_batch
                n_batches += 1
            # Without this check the loop body (and its assert) could be
            # skipped entirely and the test would pass without collecting.
            assert n_batches == total_frames // frames_per_batch
        finally:
            collector.shutdown()

    def test_serial_env_mps_sub_envs_buffers(self):
        """SerialEnv keeps buffers enabled and rolls out on MPS."""
        env = SerialEnv(2, self._MPSObsEnv)
        try:
            assert env._use_buffers is True
            rollout = env.rollout(max_steps=3)
            assert rollout.device.type == "mps"
        finally:
            env.close(raise_if_closed=False)

0 commit comments

Comments
 (0)