@@ -946,24 +946,18 @@ def learn_on_processed_samples(self) -> ResultDict:
946946 self .batches_to_place_on_learner .clear ()
947947 # If there are no learner workers, learning happens directly on the driver.
948948 # In that case we can't do async updates, so we need to block.
949- blocking = self .config .num_learner_workers == 0
949+ async_update = self .config .num_learner_workers > 0
950950 results = []
951951 for batch in batches :
952- if blocking :
953- result = self .learner_group .update (
954- batch ,
955- reduce_fn = _reduce_impala_results ,
956- num_iters = self .config .num_sgd_iter ,
957- minibatch_size = self .config .minibatch_size ,
958- )
952+ result = self .learner_group .update_from_batch (
953+ batch = batch ,
954+ async_update = async_update ,
955+ reduce_fn = _reduce_impala_results ,
956+ num_iters = self .config .num_sgd_iter ,
957+ minibatch_size = self .config .minibatch_size ,
958+ )
959+ if not async_update :
959960 results = [result ]
960- else :
961- results = self .learner_group .async_update (
962- batch ,
963- reduce_fn = _reduce_impala_results ,
964- num_iters = self .config .num_sgd_iter ,
965- minibatch_size = self .config .minibatch_size ,
966- )
967961
968962 for r in results :
969963 self ._counters [NUM_ENV_STEPS_TRAINED ] += r [ALL_MODULES ].pop (
@@ -973,14 +967,14 @@ def learn_on_processed_samples(self) -> ResultDict:
973967 NUM_AGENT_STEPS_TRAINED
974968 )
975969
976- self ._counters .update (self .learner_group .get_in_queue_stats ())
970+ self ._counters .update (self .learner_group .get_stats ())
977971 # If there are results, reduce-mean over each individual value and return.
978972 if results :
979973 return tree .map_structure (lambda * x : np .mean (x ), * results )
980974
981975 # Nothing on the queue -> Don't send requests to learner group
982- # or no results ready (from previous `self.learner_group.update ()` calls) for
983- # reducing.
976+ # or no results ready (from previous `self.learner_group.update_from_batch ()`
977+ # calls) for reducing.
984978 return {}
985979
986980 def place_processed_samples_on_learner_thread_queue (self ) -> None :
0 commit comments