@@ -492,7 +492,7 @@ def compute_progress(dir, iter, run_opts):
492492def combine_models (dir , num_iters , models_to_combine , num_chunk_per_minibatch_str ,
493493 egs_dir , leaky_hmm_coefficient , l2_regularize ,
494494 xent_regularize , run_opts ,
495- sum_to_one_penalty = 0.0 ):
495+ max_objective_evaluations = 30 ):
496496 """ Function to do model combination
497497
498498 In the nnet3 setup, the logic
@@ -505,9 +505,6 @@ def combine_models(dir, num_iters, models_to_combine, num_chunk_per_minibatch_st
505505
506506 models_to_combine .add (num_iters )
507507
508- # TODO: if it turns out the sum-to-one-penalty code is not useful,
509- # remove support for it.
510-
511508 for iter in sorted (models_to_combine ):
512509 model_file = '{0}/{1}.mdl' .format (dir , iter )
513510 if os .path .exists (model_file ):
@@ -528,12 +525,9 @@ def combine_models(dir, num_iters, models_to_combine, num_chunk_per_minibatch_st
528525
529526 common_lib .execute_command (
530527 """{command} {combine_queue_opt} {dir}/log/combine.log \
531- nnet3-chain-combine --num-iters={opt_iters} \
528+ nnet3-chain-combine \
529+ --max-objective-evaluations={max_objective_evaluations} \
532530 --l2-regularize={l2} --leaky-hmm-coefficient={leaky} \
533- --separate-weights-per-component={separate_weights} \
534- --enforce-sum-to-one={hard_enforce} \
535- --sum-to-one-penalty={penalty} \
536- --enforce-positive-weights=true \
537531 --verbose=3 {dir}/den.fst {raw_models} \
538532 "ark,bg:nnet3-chain-copy-egs ark:{egs_dir}/combine.cegs ark:- | \
539533 nnet3-chain-merge-egs --minibatch-size={num_chunk_per_mb} \
@@ -542,12 +536,9 @@ def combine_models(dir, num_iters, models_to_combine, num_chunk_per_minibatch_st
542536 {dir}/final.mdl""" .format (
543537 command = run_opts .command ,
544538 combine_queue_opt = run_opts .combine_queue_opt ,
545- opt_iters = (20 if sum_to_one_penalty <= 0 else 80 ),
546- separate_weights = (sum_to_one_penalty > 0 ),
539+ max_objective_evaluations = max_objective_evaluations ,
547540 l2 = l2_regularize , leaky = leaky_hmm_coefficient ,
548541 dir = dir , raw_models = " " .join (raw_model_strings ),
549- hard_enforce = (sum_to_one_penalty <= 0 ),
550- penalty = sum_to_one_penalty ,
551542 num_chunk_per_mb = num_chunk_per_minibatch_str ,
552543 num_iters = num_iters ,
553544 egs_dir = egs_dir ))
0 commit comments