Skip to content

Commit d3106cf

Browse files
authored
Fix _get_strided_batch device (#1303)
1 parent 0673925 commit d3106cf

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

lhotse/features/kaldi/layers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,7 @@ def _get_strided_batch(
776776
if npad_right >= 0:
777777
pad_right = torch.flip(waveform[:, -npad_right:], (1,))
778778
else:
779-
pad_right = torch.zeros(0, dtype=waveform.dtype)
779+
pad_right = torch.zeros(0, dtype=waveform.dtype, device=waveform.device)
780780
waveform = torch.cat((pad_left, waveform, pad_right), dim=1)
781781

782782
strides = (

0 commit comments

Comments
 (0)