Skip to content

Commit bde3b78

Browse files
authored
Merge pull request #90 from sarulab-speech/fix-resample-gpu-tensor
Fix device mismatch in torchaudio Resample when input is on CUDA
2 parents 8ba9ff4 + 6921dc1 commit bde3b78

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

utmosv2/_core/model/_common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def _prepare_data(
198198
data = torch.from_numpy(data) if isinstance(data, np.ndarray) else data
199199
data = torchaudio.transforms.Resample(
200200
orig_freq=sr, new_freq=self._cfg.sr
201-
)(data)
201+
).to(data.device)(data)
202202
assert data is not None # for mypy
203203
return InMemoryData(
204204
data=data if isinstance(data, np.ndarray) else data.cpu().numpy(),
@@ -290,6 +290,6 @@ def _predict_impl(
290290
with autocast():
291291
output = self._model(*x).squeeze(1)
292292
pred.append(output.cpu().numpy())
293-
res += np.concatenate(pred) / num_repetitions
293+
res += np.concatenate(pred) / num_repetitions # type: ignore
294294
assert isinstance(res, np.ndarray)
295295
return res

0 commit comments

Comments
 (0)