@@ -390,8 +390,8 @@ def forward(
390
390
original_device = raw_audio .device
391
391
x_is_mps = True if original_device .type == 'mps' else False
392
392
393
- if x_is_mps :
394
- raw_audio = raw_audio .cpu ()
393
+ # if x_is_mps:
394
+ # raw_audio = raw_audio.cpu()
395
395
396
396
device = raw_audio .device
397
397
@@ -418,7 +418,8 @@ def forward(
418
418
419
419
batch_arange = torch .arange (batch , device = device )[..., None ]
420
420
421
- x = stft_repr [batch_arange , self .freq_indices .cpu ()] if x_is_mps else stft_repr [batch_arange , self .freq_indices ]
421
+ # x = stft_repr[batch_arange, self.freq_indices.cpu()] if x_is_mps else stft_repr[batch_arange, self.freq_indices]
422
+ x = stft_repr [batch_arange , self .freq_indices ]
422
423
423
424
x = rearrange (x , 'b f t c -> b t (f c)' )
424
425
@@ -438,12 +439,10 @@ def forward(
438
439
439
440
x , = unpack (x , ps , '* f d' )
440
441
441
- num_stems = len (self .mask_estimators )
442
-
443
442
masks = torch .stack ([fn (x ) for fn in self .mask_estimators ], dim = 1 )
444
443
masks = rearrange (masks , 'b n t (f c) -> b n f t c' , c = 2 )
445
- if x_is_mps :
446
- masks = masks .cpu ()
444
+ # if x_is_mps:
445
+ # masks = masks.cpu()
447
446
448
447
stft_repr = rearrange (stft_repr , 'b f t c -> b 1 f t c' )
449
448
@@ -452,29 +451,35 @@ def forward(
452
451
453
452
masks = masks .type (stft_repr .dtype )
454
453
455
- if x_is_mps :
456
- scatter_indices = repeat (self .freq_indices .cpu (), 'f -> b n f t' , b = batch , n = num_stems , t = stft_repr .shape [- 1 ])
457
- else :
458
- scatter_indices = repeat (self .freq_indices , 'f -> b n f t' , b = batch , n = num_stems , t = stft_repr .shape [- 1 ])
459
- stft_repr_expanded_stems = repeat (stft_repr , 'b 1 ... -> b n ...' , n = num_stems )
460
- masks_summed = torch .zeros_like (stft_repr_expanded_stems ).scatter_add_ (2 , scatter_indices , masks )
454
+ # if x_is_mps:
455
+ # scatter_indices = repeat(self.freq_indices.cpu(), 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
456
+ # else:
457
+ # scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
458
+ scatter_indices = repeat (self .freq_indices , 'f -> b n f t' , b = batch , n = self .num_stems , t = stft_repr .shape [- 1 ])
459
+ stft_repr_expanded_stems = repeat (stft_repr , 'b 1 ... -> b n ...' , n = self .num_stems )
460
+ masks_summed = torch .zeros_like (stft_repr_expanded_stems .cpu () if x_is_mps else stft_repr_expanded_stems
461
+ ).scatter_add_ (2 , scatter_indices .cpu () if x_is_mps else scatter_indices ,
462
+ masks .cpu () if x_is_mps else masks ).to (device )
461
463
462
464
denom = repeat (self .num_bands_per_freq , 'f -> (f r) 1' , r = channels )
463
- if x_is_mps :
464
- denom = denom .cpu ()
465
+ # if x_is_mps:
466
+ # denom = denom.cpu()
465
467
466
468
masks_averaged = masks_summed / denom .clamp (min = 1e-8 )
467
469
468
470
stft_repr = stft_repr * masks_averaged
469
471
470
472
stft_repr = rearrange (stft_repr , 'b n (f s) t -> (b n s) f t' , s = self .audio_channels )
471
473
472
- recon_audio = torch .istft (stft_repr , ** self .stft_kwargs , window = stft_window , return_complex = False ,
474
+ recon_audio = torch .istft (stft_repr .cpu () if x_is_mps else stft_repr ,
475
+ ** self .stft_kwargs ,
476
+ window = stft_window .cpu () if x_is_mps else stft_window ,
477
+ return_complex = False ,
473
478
length = istft_length )
474
479
475
- recon_audio = rearrange (recon_audio , '(b n s) t -> b n s t' , b = batch , s = self .audio_channels , n = num_stems )
480
+ recon_audio = rearrange (recon_audio , '(b n s) t -> b n s t' , b = batch , s = self .audio_channels , n = self . num_stems )
476
481
477
- if num_stems == 1 :
482
+ if self . num_stems == 1 :
478
483
recon_audio = rearrange (recon_audio , 'b 1 s t -> b s t' )
479
484
480
485
if not exists (target ):
@@ -512,17 +517,17 @@ def forward(
512
517
513
518
514
519
# Move the total loss back to the original device if necessary
515
- if x_is_mps :
516
- total_loss = total_loss .to (original_device )
520
+ # if x_is_mps:
521
+ # total_loss = total_loss.to(original_device)
517
522
518
- if not return_loss_breakdown :
519
- return total_loss
523
+ # if not return_loss_breakdown:
524
+ # return total_loss
520
525
521
526
# If detailed loss breakdown is requested, ensure all components are on the original device
522
- return total_loss , (loss .to (original_device ) if x_is_mps else loss ,
523
- multi_stft_resolution_loss .to (original_device ) if x_is_mps else multi_stft_resolution_loss )
527
+ # return total_loss, (loss.to(original_device) if x_is_mps else loss,
528
+ # multi_stft_resolution_loss.to(original_device) if x_is_mps else multi_stft_resolution_loss)
524
529
525
- # if not return_loss_breakdown:
526
- # return total_loss
530
+ if not return_loss_breakdown :
531
+ return total_loss
527
532
528
- # return total_loss, (loss, multi_stft_resolution_loss)
533
+ return total_loss , (loss , multi_stft_resolution_loss )
0 commit comments