
Commit 0904e49

Fix gigaspeech dataset iterator. (k2-fsa#2045)
Previously, the iterator was reset at the start of every epoch, which could cause training to always use the first part of the GigaSpeech dataset if you choose a small --giga-prob.
1 parent 693f069 commit 0904e49
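
To see why this matters, here is a minimal, self-contained sketch of the bug and of the fix. Plain Python lists stand in for the LibriSpeech and GigaSpeech dataloaders, and the names (giga_prob, dl_weights, train_one_epoch) only mirror the diffs below; this is not the actual icefall code.

import random

libri = list(range(1000))            # stand-in for the LibriSpeech dataloader
giga = list(range(10_000, 11_000))   # stand-in for the GigaSpeech dataloader
giga_prob = 0.1                      # a small --giga-prob

def train_one_epoch(train_dl, iter_giga, rng, num_batches=20):
    # Each batch is drawn from LibriSpeech or GigaSpeech according to dl_weights.
    dl_weights = [1 - giga_prob, giga_prob]
    iter_libri = iter(train_dl)      # LibriSpeech restarts every epoch (intended)
    seen_giga = []
    for _ in range(num_batches):
        idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
        if idx == 0:
            next(iter_libri)
        else:
            seen_giga.append(next(iter_giga))
    return seen_giga

rng = random.Random(42)

# Before this commit, iter_giga was created inside train_one_epoch, so it was
# effectively reset every epoch -- simulated here by passing a fresh iterator
# on every call. With a small giga_prob, only the first few GigaSpeech batches
# are ever consumed, epoch after epoch.
for epoch in range(3):
    print("before:", train_one_epoch(libri, iter(giga), rng))

# After this commit, the iterator is created once in run() and passed in, so it
# keeps advancing across epochs and the whole GigaSpeech subset gets used.
iter_giga = iter(giga)
for epoch in range(3):
    print("after: ", train_one_epoch(libri, iter_giga, rng))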

File tree

  • egs/librispeech/ASR

4 files changed (+32, -27 lines)

egs/librispeech/ASR/lstm_transducer_stateless2/train.py

Lines changed: 5 additions & 4 deletions
@@ -53,7 +53,7 @@
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union
 
 import k2
 import optim
@@ -770,7 +770,7 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
-    giga_train_dl: torch.utils.data.DataLoader,
+    iter_giga: Iterator,
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
     scaler: "GradScaler",
@@ -826,7 +826,6 @@ def train_one_epoch(
     dl_weights = [1 - params.giga_prob, params.giga_prob]
 
     iter_libri = iter(train_dl)
-    iter_giga = iter(giga_train_dl)
 
     batch_idx = 0
 
@@ -1177,6 +1176,8 @@ def run(rank, world_size, args):
     else:
         logging.info("Skip scan_pessimistic_batches_for_oom")
 
+    iter_giga = iter(giga_train_dl)
+
     scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
@@ -1200,7 +1201,7 @@ def run(rank, world_size, args):
            scheduler=scheduler,
            sp=sp,
            train_dl=train_dl,
-           giga_train_dl=giga_train_dl,
+           iter_giga=iter_giga,
            valid_dl=valid_dl,
            rng=rng,
            scaler=scaler,

egs/librispeech/ASR/pruned_transducer_stateless3/train.py

Lines changed: 18 additions & 15 deletions
@@ -53,7 +53,7 @@
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union
 
 import k2
 import optim
@@ -753,7 +753,7 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
-    giga_train_dl: torch.utils.data.DataLoader,
+    iter_giga: Iterator,
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
     scaler: "GradScaler",
@@ -806,7 +806,6 @@ def train_one_epoch(
     dl_weights = [1 - params.giga_prob, params.giga_prob]
 
     iter_libri = iter(train_dl)
-    iter_giga = iter(giga_train_dl)
 
     batch_idx = 0
 
@@ -950,9 +949,9 @@ def remove_short_and_long_utt(c: Cut):
        # an utterance duration distribution for your dataset to select
        # the threshold
        if c.duration < 1.0 or c.duration > 20.0:
-           logging.warning(
-               f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-           )
+           # logging.warning(
+           #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+           # )
            return False
 
        # In pruned RNN-T, we require that T >= S
@@ -965,14 +964,14 @@ def remove_short_and_long_utt(c: Cut):
        tokens = sp.encode(c.supervisions[0].text, out_type=str)
 
        if T < len(tokens):
-           logging.warning(
-               f"Exclude cut with ID {c.id} from training. "
-               f"Number of frames (before subsampling): {c.num_frames}. "
-               f"Number of frames (after subsampling): {T}. "
-               f"Text: {c.supervisions[0].text}. "
-               f"Tokens: {tokens}. "
-               f"Number of tokens: {len(tokens)}"
-           )
+           # logging.warning(
+           #     f"Exclude cut with ID {c.id} from training. "
+           #     f"Number of frames (before subsampling): {c.num_frames}. "
+           #     f"Number of frames (after subsampling): {T}. "
+           #     f"Text: {c.supervisions[0].text}. "
+           #     f"Tokens: {tokens}. "
+           #     f"Number of tokens: {len(tokens)}"
+           # )
            return False
 
        return True
@@ -1117,6 +1116,8 @@ def run(rank, world_size, args):
    # It's time consuming to include `giga_train_dl` here
    # for dl in [train_dl, giga_train_dl]:
    for dl in [train_dl]:
+       # You can skip scan_pessimistic_batches_for_oom() if you are sure
+       # your selected params won't cause OOM
        if params.start_batch <= 0:
            scan_pessimistic_batches_for_oom(
                model=model,
@@ -1127,6 +1128,8 @@ def run(rank, world_size, args):
                warmup=0.0 if params.start_epoch == 0 else 1.0,
            )
 
+    iter_giga = iter(giga_train_dl)
+
    scaler = create_grad_scaler(enabled=params.use_fp16)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
@@ -1149,7 +1152,7 @@ def run(rank, world_size, args):
            scheduler=scheduler,
            sp=sp,
            train_dl=train_dl,
-           giga_train_dl=giga_train_dl,
+           iter_giga=iter_giga,
            valid_dl=valid_dl,
            rng=rng,
            scaler=scaler,

egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py

Lines changed: 4 additions & 4 deletions
@@ -50,7 +50,7 @@
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union
 
 import k2
 import optim
@@ -798,7 +798,7 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
-    giga_train_dl: torch.utils.data.DataLoader,
+    iter_giga: Iterator,
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
     scaler: "GradScaler",
@@ -849,7 +849,6 @@ def train_one_epoch(
     # This sets the probabilities for choosing which datasets
     dl_weights = [1 - params.giga_prob, params.giga_prob]
     iter_libri = iter(train_dl)
-    iter_giga = iter(giga_train_dl)
 
     batch_idx = 0
 
@@ -1223,6 +1222,7 @@ def run(rank, world_size, args):
    #     sp=sp,
    #     params=params,
    # )
+    iter_giga = iter(giga_train_dl)
 
    scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
@@ -1247,7 +1247,7 @@ def run(rank, world_size, args):
            scheduler=scheduler,
            sp=sp,
            train_dl=train_dl,
-           giga_train_dl=giga_train_dl,
+           iter_giga=iter_giga,
            valid_dl=valid_dl,
            rng=rng,
            scaler=scaler,

egs/librispeech/ASR/pruned_transducer_stateless8/train.py

Lines changed: 5 additions & 4 deletions
@@ -55,7 +55,7 @@
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union
 
 import k2
 import optim
@@ -793,7 +793,7 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
-    giga_train_dl: torch.utils.data.DataLoader,
+    iter_giga: Iterator,
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
     scaler: "GradScaler",
@@ -849,7 +849,6 @@ def train_one_epoch(
     dl_weights = [1 - params.giga_prob, params.giga_prob]
 
     iter_libri = iter(train_dl)
-    iter_giga = iter(giga_train_dl)
 
     batch_idx = 0
 
@@ -1225,6 +1224,8 @@ def run(rank, world_size, args):
        params=params,
    )
 
+    iter_giga = iter(giga_train_dl)
+
    scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
@@ -1248,7 +1249,7 @@ def run(rank, world_size, args):
            scheduler=scheduler,
            sp=sp,
            train_dl=train_dl,
-           giga_train_dl=giga_train_dl,
+           iter_giga=iter_giga,
            valid_dl=valid_dl,
            rng=rng,
            scaler=scaler,
