Commit ed752fb

Merge pull request #12 from audiohacking/copilot/fix-mps-device-type-issue
Fix autocast for MPS devices in music generation pipeline
2 parents: 5f7dafc + a1ebee3
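
For context, the bug this merge fixes: on PyTorch builds where autocast does not accept the 'mps' device type, wrapping generation in `torch.autocast` fails on Apple Silicon. A minimal reproduction sketch (hypothetical; the exact error text varies across PyTorch versions, and newer builds may accept 'mps'):

```python
# Hypothetical reproduction, assuming a PyTorch build where torch.autocast
# rejects the 'mps' device type (the situation this commit works around).
import torch

try:
    # On affected builds this raises RuntimeError before the block runs,
    # e.g. "User specified an unsupported autocast device_type 'mps'".
    with torch.autocast(device_type="mps", dtype=torch.float16):
        pass
except RuntimeError as err:
    print(f"autocast unavailable for MPS: {err}")
```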

File tree

1 file changed: +27 -2 lines changed

backend/app/services/music_service.py

Lines changed: 27 additions & 2 deletions
```diff
@@ -6,6 +6,7 @@
 import torch._dynamo
 import torchaudio
 import logging
+from contextlib import nullcontext
 from typing import Optional, Callable, Union, Dict
 from tqdm import tqdm
 from backend.app.models import GenerationRequest, Job, JobStatus
@@ -93,6 +94,30 @@ def is_mps_available() -> bool:
     return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
 
 
+def get_autocast_context(device_type: str, dtype: torch.dtype):
+    """
+    Get the appropriate autocast context manager for the given device type.
+
+    PyTorch's autocast only supports 'cuda', 'cpu', and 'xpu' device types.
+    For MPS (Apple Metal), autocast is not supported, so we use a nullcontext (no-op).
+    Since MPS pipelines already use float32, no autocast is needed.
+
+    Args:
+        device_type: Device type string ('cuda', 'cpu', 'mps', 'xpu')
+        dtype: Data type for autocast
+
+    Returns:
+        Context manager for autocast or nullcontext for unsupported devices
+    """
+    # torch.autocast doesn't support the MPS device type.
+    # MPS pipelines already use float32, so autocast is not needed.
+    if device_type == 'mps':
+        return nullcontext()
+
+    # For supported devices (cuda, cpu, xpu), use autocast.
+    return torch.autocast(device_type=device_type, dtype=dtype)
+
+
 def detect_optimal_gpu_config() -> dict:
     """
     Auto-detect the optimal GPU configuration based on available VRAM.
@@ -705,7 +730,7 @@ def generate_with_callback(inputs, callback=None, **kwargs):
     bs_size = 2 if cfg_scale != 1.0 else 1
     pipeline.mula.setup_caches(bs_size)
 
-    with torch.autocast(device_type=pipeline.mula_device.type, dtype=pipeline.mula_dtype):
+    with get_autocast_context(pipeline.mula_device.type, pipeline.mula_dtype):
         curr_token = pipeline.mula.generate_frame(
             tokens=prompt_tokens,
             tokens_mask=prompt_tokens_mask,
@@ -739,7 +764,7 @@ def _pad_audio_token(token):
 
     for i in tqdm(range(max_audio_frames), desc="Generating audio"):
         curr_token, curr_token_mask = _pad_audio_token(curr_token)
-        with torch.autocast(device_type=pipeline.mula_device.type, dtype=pipeline.mula_dtype):
+        with get_autocast_context(pipeline.mula_device.type, pipeline.mula_dtype):
             curr_token = pipeline.mula.generate_frame(
                 tokens=curr_token,
                 tokens_mask=curr_token_mask,
```
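
As a quick check of the new helper's behavior, here is a self-contained sketch that inlines the same logic; the `isinstance` assertions and sample dtypes are illustrative and not part of the commit:

```python
# Self-contained sketch mirroring get_autocast_context from this commit;
# the assertions below are illustrative checks, not part of the change.
from contextlib import nullcontext
import torch

def get_autocast_context(device_type: str, dtype: torch.dtype):
    if device_type == 'mps':
        return nullcontext()  # no-op: MPS paths already run in float32
    return torch.autocast(device_type=device_type, dtype=dtype)

# MPS gets a no-op context; supported devices get real autocast.
assert isinstance(get_autocast_context("mps", torch.float32), nullcontext)
assert isinstance(get_autocast_context("cpu", torch.bfloat16), torch.autocast)

with get_autocast_context("cpu", torch.bfloat16):
    y = torch.ones(2, 2) @ torch.ones(2, 2)  # matmul runs under CPU autocast
print(y.dtype)  # torch.bfloat16 under CPU autocast
```

Because the helper returns the context manager instead of entering it, the `with` statements at both call sites keep their shape; only the expression inside `with` changes.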
