We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents 8ba9ff4 + 6921dc1 commit bde3b78Copy full SHA for bde3b78
utmosv2/_core/model/_common.py
@@ -198,7 +198,7 @@ def _prepare_data(
198
data = torch.from_numpy(data) if isinstance(data, np.ndarray) else data
199
data = torchaudio.transforms.Resample(
200
orig_freq=sr, new_freq=self._cfg.sr
201
- )(data)
+ ).to(data.device)(data)
202
assert data is not None # for mypy
203
return InMemoryData(
204
data=data if isinstance(data, np.ndarray) else data.cpu().numpy(),
@@ -290,6 +290,6 @@ def _predict_impl(
290
with autocast():
291
output = self._model(*x).squeeze(1)
292
pred.append(output.cpu().numpy())
293
- res += np.concatenate(pred) / num_repetitions
+ res += np.concatenate(pred) / num_repetitions # type: ignore
294
assert isinstance(res, np.ndarray)
295
return res
0 commit comments