Skip to content

Commit cf43a31

Browse files
authored
Support decoding with averaged model when using --iter (k2-fsa#353)
* support decoding with averaged model when using --iter * minor fix * minor fix of copyright date
1 parent d791e40 commit cf43a31

File tree

2 files changed

+50
-21
lines changed

2 files changed

+50
-21
lines changed

egs/librispeech/ASR/pruned_transducer_stateless4/decode.py

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python3
22
#
3-
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang,
4-
# Zengwei Yao)
3+
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
4+
# Zengwei Yao)
55
#
66
# See ../../../../LICENSE for clarification regarding multiple authors
77
#
@@ -540,23 +540,52 @@ def main():
540540
model.to(device)
541541
model.load_state_dict(average_checkpoints(filenames, device=device))
542542
else:
543-
assert params.iter == 0 and params.avg > 0
544-
start = params.epoch - params.avg
545-
assert start >= 1
546-
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
547-
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
548-
logging.info(
549-
f"Calculating the averaged model over epoch range from "
550-
f"{start} (excluded) to {params.epoch}"
551-
)
552-
model.to(device)
553-
model.load_state_dict(
554-
average_checkpoints_with_averaged_model(
555-
filename_start=filename_start,
556-
filename_end=filename_end,
557-
device=device,
543+
if params.iter > 0:
544+
filenames = find_checkpoints(
545+
params.exp_dir, iteration=-params.iter
546+
)[: params.avg + 1]
547+
if len(filenames) == 0:
548+
raise ValueError(
549+
f"No checkpoints found for"
550+
f" --iter {params.iter}, --avg {params.avg}"
551+
)
552+
elif len(filenames) < params.avg + 1:
553+
raise ValueError(
554+
f"Not enough checkpoints ({len(filenames)}) found for"
555+
f" --iter {params.iter}, --avg {params.avg}"
556+
)
557+
filename_start = filenames[-1]
558+
filename_end = filenames[0]
559+
logging.info(
560+
"Calculating the averaged model over iteration checkpoints"
561+
f" from {filename_start} (excluded) to {filename_end}"
562+
)
563+
model.to(device)
564+
model.load_state_dict(
565+
average_checkpoints_with_averaged_model(
566+
filename_start=filename_start,
567+
filename_end=filename_end,
568+
device=device,
569+
)
570+
)
571+
else:
572+
assert params.avg > 0
573+
start = params.epoch - params.avg
574+
assert start >= 1
575+
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
576+
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
577+
logging.info(
578+
f"Calculating the averaged model over epoch range from "
579+
f"{start} (excluded) to {params.epoch}"
580+
)
581+
model.to(device)
582+
model.load_state_dict(
583+
average_checkpoints_with_averaged_model(
584+
filename_start=filename_start,
585+
filename_end=filename_end,
586+
device=device,
587+
)
558588
)
559-
)
560589

561590
model.to(device)
562591
model.eval()

icefall/checkpoint.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
2-
# Zengwei Yao)
1+
# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang,
2+
# Zengwei Yao)
33
#
44
# See ../../LICENSE for clarification regarding multiple authors
55
#
@@ -405,7 +405,7 @@ def average_checkpoints_with_averaged_model(
405405
(3) avg = (model_end + model_start * (weight_start / weight_end))
406406
* weight_end
407407
408-
The model index could be epoch number or checkpoint number.
408+
The model index could be epoch number or iteration number.
409409
410410
Args:
411411
filename_start:

0 commit comments

Comments
 (0)