|
6 | 6 | import torch._dynamo |
7 | 7 | import torchaudio |
8 | 8 | import logging |
| 9 | +from contextlib import nullcontext |
9 | 10 | from typing import Optional, Callable, Union, Dict |
10 | 11 | from tqdm import tqdm |
11 | 12 | from backend.app.models import GenerationRequest, Job, JobStatus |
@@ -93,6 +94,30 @@ def is_mps_available() -> bool: |
93 | 94 | return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() |
94 | 95 |
|
95 | 96 |
|
def get_autocast_context(device_type: str, dtype: torch.dtype):
    """Return an autocast context for *device_type*, or a no-op fallback.

    ``torch.autocast`` only accepts the 'cuda', 'cpu' and 'xpu' device types;
    Apple's MPS backend is not among them. MPS pipelines here already run in
    float32, so mixed precision is unnecessary there and a do-nothing context
    is the correct substitute.

    Args:
        device_type: Device type string ('cuda', 'cpu', 'mps', 'xpu').
        dtype: Target dtype handed through to ``torch.autocast``.

    Returns:
        A ``torch.autocast`` context manager for supported backends,
        otherwise ``contextlib.nullcontext()``.
    """
    if device_type != 'mps':
        # Supported backends get a genuine autocast region.
        return torch.autocast(device_type=device_type, dtype=dtype)

    # MPS: autocast is unsupported, so hand back an inert context manager.
    return nullcontext()
| 120 | + |
96 | 121 | def detect_optimal_gpu_config() -> dict: |
97 | 122 | """ |
98 | 123 | Auto-detect the optimal GPU configuration based on available VRAM. |
@@ -705,7 +730,7 @@ def generate_with_callback(inputs, callback=None, **kwargs): |
705 | 730 | bs_size = 2 if cfg_scale != 1.0 else 1 |
706 | 731 | pipeline.mula.setup_caches(bs_size) |
707 | 732 |
|
708 | | - with torch.autocast(device_type=pipeline.mula_device.type, dtype=pipeline.mula_dtype): |
| 733 | + with get_autocast_context(pipeline.mula_device.type, pipeline.mula_dtype): |
709 | 734 | curr_token = pipeline.mula.generate_frame( |
710 | 735 | tokens=prompt_tokens, |
711 | 736 | tokens_mask=prompt_tokens_mask, |
@@ -739,7 +764,7 @@ def _pad_audio_token(token): |
739 | 764 |
|
740 | 765 | for i in tqdm(range(max_audio_frames), desc="Generating audio"): |
741 | 766 | curr_token, curr_token_mask = _pad_audio_token(curr_token) |
742 | | - with torch.autocast(device_type=pipeline.mula_device.type, dtype=pipeline.mula_dtype): |
| 767 | + with get_autocast_context(pipeline.mula_device.type, pipeline.mula_dtype): |
743 | 768 | curr_token = pipeline.mula.generate_frame( |
744 | 769 | tokens=curr_token, |
745 | 770 | tokens_mask=curr_token_mask, |
|
0 commit comments