Skip to content

Commit f2b78ef

Browse files
young01ai and 周志洋 authored
fix roformer models run on cuda (#84)
* fix roformer models run on cuda
* fix roformer models run on cpu/cuda/mps

---------

Co-authored-by: 周志洋 <[email protected]>
1 parent 588a82f commit f2b78ef

File tree

4 files changed: +53 additions, -45 deletions

audio_separator/separator/uvr_lib_v5/attend.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -70,8 +70,11 @@ def flash_attn(self, q, k, v):
7070

7171
config = self.cuda_config if is_cuda else self.cpu_config
7272

73-
# pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
73+
# sdpa_flash kernel only supports float16 on sm80+ architecture gpu
74+
if is_cuda and q.dtype != torch.float16:
75+
config = FlashAttentionConfig(False, True, True)
7476

77+
# pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
7578
with torch.backends.cuda.sdp_kernel(**config._asdict()):
7679
out = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout if self.training else 0.0)
7780

audio_separator/separator/uvr_lib_v5/bs_roformer.py

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -457,11 +457,10 @@ def forward(
457457
"""
458458

459459
original_device = raw_audio.device
460-
461460
x_is_mps = True if original_device.type == 'mps' else False
462461

463-
if x_is_mps:
464-
raw_audio = raw_audio.cpu()
462+
# if x_is_mps:
463+
# raw_audio = raw_audio.cpu()
465464

466465
device = raw_audio.device
467466

@@ -517,13 +516,11 @@ def forward(
517516

518517
x = self.final_norm(x)
519518

520-
num_stems = len(self.mask_estimators)
521-
522519
mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
523520
mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
524521

525-
if x_is_mps:
526-
mask = mask.to('cpu')
522+
# if x_is_mps:
523+
# mask = mask.to('cpu')
527524

528525
# modulate frequency representation
529526

@@ -540,11 +537,14 @@ def forward(
540537

541538
stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
542539

543-
recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False)
540+
recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr,
541+
**self.stft_kwargs,
542+
window=stft_window.cpu() if x_is_mps else stft_window,
543+
return_complex=False).to(device)
544544

545-
recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
545+
recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=self.num_stems)
546546

547-
if num_stems == 1:
547+
if self.num_stems == 1:
548548
recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
549549

550550
# if a target is passed in, calculate loss for learning
@@ -585,15 +585,15 @@ def forward(
585585

586586
if not return_loss_breakdown:
587587
# Move the result back to the original device if it was moved to CPU for MPS compatibility
588-
if x_is_mps:
589-
total_loss = total_loss.to(original_device)
588+
# if x_is_mps:
589+
# total_loss = total_loss.to(original_device)
590590
return total_loss
591591

592592
# For detailed loss breakdown, ensure all components are moved back to the original device for MPS
593-
if x_is_mps:
594-
loss = loss.to(original_device)
595-
multi_stft_resolution_loss = multi_stft_resolution_loss.to(original_device)
596-
weighted_multi_resolution_loss = weighted_multi_resolution_loss.to(original_device)
593+
# if x_is_mps:
594+
# loss = loss.to(original_device)
595+
# multi_stft_resolution_loss = multi_stft_resolution_loss.to(original_device)
596+
# weighted_multi_resolution_loss = weighted_multi_resolution_loss.to(original_device)
597597

598598
return total_loss, (loss, multi_stft_resolution_loss)
599599

audio_separator/separator/uvr_lib_v5/mel_band_roformer.py

Lines changed: 32 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -390,8 +390,8 @@ def forward(
390390
original_device = raw_audio.device
391391
x_is_mps = True if original_device.type == 'mps' else False
392392

393-
if x_is_mps:
394-
raw_audio = raw_audio.cpu()
393+
# if x_is_mps:
394+
# raw_audio = raw_audio.cpu()
395395

396396
device = raw_audio.device
397397

@@ -418,7 +418,8 @@ def forward(
418418

419419
batch_arange = torch.arange(batch, device=device)[..., None]
420420

421-
x = stft_repr[batch_arange, self.freq_indices.cpu()] if x_is_mps else stft_repr[batch_arange, self.freq_indices]
421+
# x = stft_repr[batch_arange, self.freq_indices.cpu()] if x_is_mps else stft_repr[batch_arange, self.freq_indices]
422+
x = stft_repr[batch_arange, self.freq_indices]
422423

423424
x = rearrange(x, 'b f t c -> b t (f c)')
424425

@@ -438,12 +439,10 @@ def forward(
438439

439440
x, = unpack(x, ps, '* f d')
440441

441-
num_stems = len(self.mask_estimators)
442-
443442
masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
444443
masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)
445-
if x_is_mps:
446-
masks = masks.cpu()
444+
# if x_is_mps:
445+
# masks = masks.cpu()
447446

448447
stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
449448

@@ -452,29 +451,35 @@ def forward(
452451

453452
masks = masks.type(stft_repr.dtype)
454453

455-
if x_is_mps:
456-
scatter_indices = repeat(self.freq_indices.cpu(), 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
457-
else:
458-
scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
459-
stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
460-
masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)
454+
# if x_is_mps:
455+
# scatter_indices = repeat(self.freq_indices.cpu(), 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
456+
# else:
457+
# scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
458+
scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=self.num_stems, t=stft_repr.shape[-1])
459+
stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=self.num_stems)
460+
masks_summed = torch.zeros_like(stft_repr_expanded_stems.cpu() if x_is_mps else stft_repr_expanded_stems
461+
).scatter_add_(2, scatter_indices.cpu() if x_is_mps else scatter_indices,
462+
masks.cpu() if x_is_mps else masks).to(device)
461463

462464
denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)
463-
if x_is_mps:
464-
denom = denom.cpu()
465+
# if x_is_mps:
466+
# denom = denom.cpu()
465467

466468
masks_averaged = masks_summed / denom.clamp(min=1e-8)
467469

468470
stft_repr = stft_repr * masks_averaged
469471

470472
stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
471473

472-
recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
474+
recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr,
475+
**self.stft_kwargs,
476+
window=stft_window.cpu() if x_is_mps else stft_window,
477+
return_complex=False,
473478
length=istft_length)
474479

475-
recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)
480+
recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=self.num_stems)
476481

477-
if num_stems == 1:
482+
if self.num_stems == 1:
478483
recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
479484

480485
if not exists(target):
@@ -512,17 +517,17 @@ def forward(
512517

513518

514519
# Move the total loss back to the original device if necessary
515-
if x_is_mps:
516-
total_loss = total_loss.to(original_device)
520+
# if x_is_mps:
521+
# total_loss = total_loss.to(original_device)
517522

518-
if not return_loss_breakdown:
519-
return total_loss
523+
# if not return_loss_breakdown:
524+
# return total_loss
520525

521526
# If detailed loss breakdown is requested, ensure all components are on the original device
522-
return total_loss, (loss.to(original_device) if x_is_mps else loss,
523-
multi_stft_resolution_loss.to(original_device) if x_is_mps else multi_stft_resolution_loss)
527+
# return total_loss, (loss.to(original_device) if x_is_mps else loss,
528+
# multi_stft_resolution_loss.to(original_device) if x_is_mps else multi_stft_resolution_loss)
524529

525-
# if not return_loss_breakdown:
526-
# return total_loss
530+
if not return_loss_breakdown:
531+
return total_loss
527532

528-
# return total_loss, (loss, multi_stft_resolution_loss)
533+
return total_loss, (loss, multi_stft_resolution_loss)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
44

55
[tool.poetry]
66
name = "audio-separator"
7-
version = "0.17.4"
7+
version = "0.17.5"
88
description = "Easy to use audio stem separation, using various models from UVR trained primarily by @Anjok07"
99
authors = ["Andrew Beveridge <[email protected]>"]
1010
license = "MIT"

0 commit comments

Comments
 (0)