Supporting PyTorch GPU compatibility on Apple Silicon chips #914

Open
deathcoder/stable-baselines3
#1
@ryanrudes

🚀 Feature

PyTorch recently added GPU acceleration on Apple Silicon chips, exposed through the "mps" device. stable-baselines3 should support this device as well (I believe passing device = "mps" is the intended way to select it).
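
As a rough illustration of what automatic recognition could look like, here is a minimal sketch that probes PyTorch for an available backend. The helper names `pick_device` and `detect_device` are hypothetical and are not part of stable-baselines3. The CUDA-then-MPS-then-CPU preference order is an assumption, and `torch.backends.mps` is looked up defensively because older PyTorch builds do not ship it.

```python
import importlib.util


def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Return a device string, preferring CUDA, then MPS, then CPU (assumed order)."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def detect_device() -> str:
    """Probe PyTorch for available backends; return "cpu" if torch is not installed."""
    if importlib.util.find_spec("torch") is None:
        return "cpu"
    import torch

    # torch.backends.mps is absent in older PyTorch builds, so probe defensively.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_available = mps_backend is not None and mps_backend.is_available()
    return pick_device(torch.cuda.is_available(), mps_available)


print(detect_device())
```

The returned string could then be passed as the `device` argument when constructing the model, although the traceback in this issue shows that selecting "mps" alone is not enough while some operators are missing from that backend.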

Minimal Example

from stable_baselines3 import PPO
import gym

env = gym.make("QbertNoFrameskip-v0")
ppo = PPO("CnnPolicy", env, device = "mps")

ppo.learn(total_timesteps = 1000000)

The Apple Silicon GPU device is not automatically recognized by stable-baselines3 at the moment, so it defaults to "cpu". If you force it to use the "mps" device, the following stack trace appears:

A.L.E: Arcade Learning Environment (version 0.7.4+069f8bd)
[Powered by Stella]
Traceback (most recent call last):
  File "/Users/ryanrudes/Desktop/pydt/train_min.py", line 7, in <module>
    ppo.learn(total_timesteps = 1000000)
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/stable_baselines3-1.5.1a6-py3.9.egg/stable_baselines3/ppo/ppo.py", line 310, in learn
    return super().learn(
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/stable_baselines3-1.5.1a6-py3.9.egg/stable_baselines3/common/on_policy_algorithm.py", line 247, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/stable_baselines3-1.5.1a6-py3.9.egg/stable_baselines3/common/on_policy_algorithm.py", line 166, in collect_rollouts
    actions, values, log_probs = self.policy(obs_tensor)
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/stable_baselines3-1.5.1a6-py3.9.egg/stable_baselines3/common/policies.py", line 592, in forward
    distribution = self._get_action_dist_from_latent(latent_pi)
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/stable_baselines3-1.5.1a6-py3.9.egg/stable_baselines3/common/policies.py", line 610, in _get_action_dist_from_latent
    return self.action_dist.proba_distribution(action_logits=mean_actions)
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/stable_baselines3-1.5.1a6-py3.9.egg/stable_baselines3/common/distributions.py", line 274, in proba_distribution
    self.distribution = Categorical(logits=action_logits)
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/torch/distributions/categorical.py", line 60, in __init__
    self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
NotImplementedError: Could not run 'aten::amax.out' with arguments from the 'MPS' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::amax.out' is only available for these backends: [Dense, Conjugate, UNKNOWN_TENSOR_TYPE_ID, QuantizedXPU, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, SparseCPU, SparseCUDA, SparseHIP, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, SparseXPU, UNKNOWN_TENSOR_TYPE_ID, SparseVE, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, NestedTensorCUDA, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID].

CPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/build/aten/src/ATen/RegisterCPU.cpp:37386 [kernel]
Meta: registered at /Users/runner/work/pytorch/pytorch/pytorch/build/aten/src/ATen/RegisterMeta.cpp:31637 [kernel]
BackendSelect: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:133 [backend fallback]
Named: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/ConjugateFallback.cpp:18 [backend fallback]
Negative: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
ZeroTensor: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/ADInplaceOrViewType_1.cpp:3288 [kernel]
AutogradOther: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradCPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradCUDA: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
UNKNOWN_TENSOR_TYPE_ID: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradXLA: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradMPS: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradIPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradXPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradHPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
UNKNOWN_TENSOR_TYPE_ID: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradLazy: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradPrivateUse1: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradPrivateUse2: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradPrivateUse3: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
Tracer: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/TraceType_1.cpp:11951 [kernel]
AutocastCPU: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/autocast_mode.cpp:481 [backend fallback]
Autocast: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/autocast_mode.cpp:324 [backend fallback]
Batched: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/BatchingRegistrations.cpp:1064 [backend fallback]
VmapMode: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
Functionalize: registered at /Users/runner/work/pytorch/pytorch/pytorch/build/aten/src/ATen/RegisterFunctionalization_3.cpp:12118 [kernel]
PythonTLSSnapshot: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:137 [backend fallback]
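
The failure comes from `logits.logsumexp(...)` inside `Categorical`, which needs `aten::amax.out`, an operator with no MPS kernel in this build. One possible workaround, sketched below, is PyTorch's `PYTORCH_ENABLE_MPS_FALLBACK` environment variable, which asks PyTorch to run unsupported operators on the CPU instead. This assumes a PyTorch build that honours the variable and that it is set before `torch` is first imported. The helper `categorical_log_normalizer` is a hypothetical name that only repeats the call from the traceback, and the tensor shape is arbitrary.

```python
import importlib.util
import os

# Assumption: the variable only takes effect if set before torch is first imported.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"


def categorical_log_normalizer(device: str):
    """Repeat the logsumexp call from the traceback on the given device."""
    import torch

    logits = torch.zeros(1, 6, device=device)  # 6 logits is an arbitrary choice
    return logits.logsumexp(dim=-1, keepdim=True)


if importlib.util.find_spec("torch") is not None:
    import torch

    # Use "mps" only when the backend exists and reports itself as available.
    mps_backend = getattr(torch.backends, "mps", None)
    use_mps = mps_backend is not None and mps_backend.is_available()
    print(categorical_log_normalizer("mps" if use_mps else "cpu"))
```

Falling back to the CPU for individual operators adds device transfers, so this is likely to be slower than native MPS kernels and is better treated as a stopgap than as a fix.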

Labels: enhancement (New feature or request)