|
1 | 1 | #!/usr/bin/env python3 |
2 | 2 | # |
3 | | -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang, |
4 | | -# Zengwei Yao) |
| 3 | +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, |
| 4 | +# Zengwei Yao) |
5 | 5 | # |
6 | 6 | # See ../../../../LICENSE for clarification regarding multiple authors |
7 | 7 | # |
@@ -540,23 +540,52 @@ def main(): |
540 | 540 | model.to(device) |
541 | 541 | model.load_state_dict(average_checkpoints(filenames, device=device)) |
542 | 542 | else: |
543 | | - assert params.iter == 0 and params.avg > 0 |
544 | | - start = params.epoch - params.avg |
545 | | - assert start >= 1 |
546 | | - filename_start = f"{params.exp_dir}/epoch-{start}.pt" |
547 | | - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" |
548 | | - logging.info( |
549 | | - f"Calculating the averaged model over epoch range from " |
550 | | - f"{start} (excluded) to {params.epoch}" |
551 | | - ) |
552 | | - model.to(device) |
553 | | - model.load_state_dict( |
554 | | - average_checkpoints_with_averaged_model( |
555 | | - filename_start=filename_start, |
556 | | - filename_end=filename_end, |
557 | | - device=device, |
| 543 | + if params.iter > 0: |
| 544 | + filenames = find_checkpoints( |
| 545 | + params.exp_dir, iteration=-params.iter |
| 546 | + )[: params.avg + 1] |
| 547 | + if len(filenames) == 0: |
| 548 | + raise ValueError( |
| 549 | + f"No checkpoints found for" |
| 550 | + f" --iter {params.iter}, --avg {params.avg}" |
| 551 | + ) |
| 552 | + elif len(filenames) < params.avg + 1: |
| 553 | + raise ValueError( |
| 554 | + f"Not enough checkpoints ({len(filenames)}) found for" |
| 555 | + f" --iter {params.iter}, --avg {params.avg}" |
| 556 | + ) |
| 557 | + filename_start = filenames[-1] |
| 558 | + filename_end = filenames[0] |
| 559 | + logging.info( |
| 560 | + "Calculating the averaged model over iteration checkpoints" |
| 561 | + f" from {filename_start} (excluded) to {filename_end}" |
| 562 | + ) |
| 563 | + model.to(device) |
| 564 | + model.load_state_dict( |
| 565 | + average_checkpoints_with_averaged_model( |
| 566 | + filename_start=filename_start, |
| 567 | + filename_end=filename_end, |
| 568 | + device=device, |
| 569 | + ) |
| 570 | + ) |
| 571 | + else: |
| 572 | + assert params.avg > 0 |
| 573 | + start = params.epoch - params.avg |
| 574 | + assert start >= 1 |
| 575 | + filename_start = f"{params.exp_dir}/epoch-{start}.pt" |
| 576 | + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" |
| 577 | + logging.info( |
| 578 | + f"Calculating the averaged model over epoch range from " |
| 579 | + f"{start} (excluded) to {params.epoch}" |
| 580 | + ) |
| 581 | + model.to(device) |
| 582 | + model.load_state_dict( |
| 583 | + average_checkpoints_with_averaged_model( |
| 584 | + filename_start=filename_start, |
| 585 | + filename_end=filename_end, |
| 586 | + device=device, |
| 587 | + ) |
558 | 588 | ) |
559 | | - ) |
560 | 589 |
|
561 | 590 | model.to(device) |
562 | 591 | model.eval() |
|
0 commit comments