@@ -363,6 +363,8 @@ def __init__(self, algo_class=None):
         self.grad_clip = None
         self.grad_clip_by = "global_norm"
         self.train_batch_size = 32
+        # Simple logic for now: If None, use `train_batch_size`.
+        self.train_batch_size_per_learner = None
         # TODO (sven): Unsolved problem with RLModules sometimes requiring settings from
         # the main AlgorithmConfig. We should not require the user to provide those
         # settings in both, the AlgorithmConfig (as property) AND the model config
@@ -871,6 +873,7 @@ def build_env_to_module_connector(self, env):
         return pipeline
 
     def build_module_to_env_connector(self, env):
+
         from ray.rllib.connectors.module_to_env import (
             DefaultModuleToEnv,
             ModuleToEnvPipeline,
@@ -1333,11 +1336,11 @@ def environment(
                 Tuple[value1, value2]: Clip at value1 and value2.
             normalize_actions: If True, RLlib will learn entirely inside a normalized
                 action space (0.0 centered with small stddev; only affecting Box
-                components). We will unsquash actions (and clip, just in case) to the
+                components). RLlib will unsquash actions (and clip, just in case) to the
                 bounds of the env's action space before sending actions back to the env.
-            clip_actions: If True, RLlib will clip actions according to the env's bounds
-                before sending them back to the env.
-                TODO: (sven) This option should be deprecated and always be False.
+            clip_actions: If True, the RLlib default ModuleToEnv connector will clip
+                actions according to the env's bounds (before sending them into the
+                `env.step()` call).
             disable_env_checking: If True, disable the environment pre-checking module.
             is_atari: This config can be used to explicitly specify whether the env is
                 an Atari env or not. If not specified, RLlib will try to auto-detect
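
As a quick usage sketch of the two flags documented above (the env name and values are illustrative assumptions, not part of this patch):

from ray.rllib.algorithms.ppo import PPOConfig

# Learn in a normalized action space and let RLlib unsquash actions back to the
# env's bounds; rely on the default ModuleToEnv connector for any clipping.
config = PPOConfig().environment(
    "Pendulum-v1", normalize_actions=True, clip_actions=False
)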
@@ -1678,6 +1681,7 @@ def training(
         grad_clip: Optional[float] = NotProvided,
         grad_clip_by: Optional[str] = NotProvided,
         train_batch_size: Optional[int] = NotProvided,
+        train_batch_size_per_learner: Optional[int] = NotProvided,
         model: Optional[dict] = NotProvided,
         optimizer: Optional[dict] = NotProvided,
         max_requests_in_flight_per_sampler_worker: Optional[int] = NotProvided,
@@ -1726,7 +1730,16 @@ def training(
                 the shapes of these tensors are).
             grad_clip_by: See `grad_clip` for the effect of this setting on gradient
                 clipping. Allowed values are `value`, `norm`, and `global_norm`.
-            train_batch_size: Training batch size, if applicable.
+            train_batch_size_per_learner: Train batch size per individual Learner
+                worker. This setting only applies to the new API stack. The number
+                of Learner workers can be set via `config.resources(
+                num_learner_workers=...)`. The total effective batch size is then
+                `num_learner_workers` x `train_batch_size_per_learner` and can
+                be accessed via the property `AlgorithmConfig.total_train_batch_size`.
+            train_batch_size: Training batch size, if applicable. When on the new API
+                stack, this setting should no longer be used. Instead, use
+                `train_batch_size_per_learner` (in combination with
+                `num_learner_workers`).
             model: Arguments passed into the policy model. See models/catalog.py for a
                 full list of the available model options.
                 TODO: Provide ModelConfig objects instead of dicts.
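
A minimal usage sketch for the new setting described above (the algorithm and numbers are assumptions chosen for illustration):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .training(train_batch_size_per_learner=512)
    .resources(num_learner_workers=4)
)
# Effective total batch size: 4 learners x 512 = 2048.
print(config.total_train_batch_size)  # -> 2048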
@@ -1766,6 +1779,8 @@ def training(
17661779 "or 'global_norm'!"
17671780 )
17681781 self .grad_clip_by = grad_clip_by
1782+ if train_batch_size_per_learner is not NotProvided :
1783+ self .train_batch_size_per_learner = train_batch_size_per_learner
17691784 if train_batch_size is not NotProvided :
17701785 self .train_batch_size = train_batch_size
17711786 if model is not NotProvided :
@@ -2716,20 +2731,29 @@ def uses_new_env_runners(self):
             self.env_runner_cls, RolloutWorker
         )
 
+    @property
+    def total_train_batch_size(self):
+        if self.train_batch_size_per_learner is not None:
+            return self.train_batch_size_per_learner * (self.num_learner_workers or 1)
+        else:
+            return self.train_batch_size
+
+    # TODO: Make `rollout_fragment_length` a read-only property and replace the
+    # current `self.rollout_fragment_length` with a private variable.
     def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
         """Automatically infers a proper rollout_fragment_length setting if "auto".
 
         Uses the simple formula:
-        `rollout_fragment_length` = `train_batch_size` /
+        `rollout_fragment_length` = `total_train_batch_size` /
             (`num_envs_per_worker` * `num_rollout_workers`)
 
         If result is a fraction AND `worker_index` is provided, will make
         those workers add additional timesteps, such that the overall batch size (across
-        the workers) will add up to exactly the `train_batch_size`.
+        the workers) will add up to exactly the `total_train_batch_size`.
 
         Returns:
             The user-provided `rollout_fragment_length` or a computed one (if user
-            provided value is "auto"), making sure `train_batch_size` is reached
+            provided value is "auto"), making sure `total_train_batch_size` is reached
             exactly in each iteration.
         """
         if self.rollout_fragment_length == "auto":
@@ -2739,11 +2763,11 @@ def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
             # 4 workers, 3 envs per worker, 2500 train batch size:
             # -> 2500 / 12 -> 208.333 -> diff=4 (208 * 12 = 2496)
             # -> worker 1: 209, workers 2-4: 208
-            rollout_fragment_length = self.train_batch_size / (
+            rollout_fragment_length = self.total_train_batch_size / (
                 self.num_envs_per_worker * (self.num_rollout_workers or 1)
             )
             if int(rollout_fragment_length) != rollout_fragment_length:
-                diff = self.train_batch_size - int(
+                diff = self.total_train_batch_size - int(
                     rollout_fragment_length
                 ) * self.num_envs_per_worker * (self.num_rollout_workers or 1)
                 if (worker_index * self.num_envs_per_worker) <= diff:
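
To tie the two additions together, a rough worked example (numbers assumed for illustration only): with 2 Learner workers and `train_batch_size_per_learner=1250`, `total_train_batch_size` is 2500; with 4 rollout workers and 3 envs per worker, the "auto" fragment length is 2500 / 12 = 208.33, so worker 1 collects 209 timesteps and workers 2-4 collect 208, matching the example in the code comments above.

# Hypothetical values; mirrors the arithmetic in get_rollout_fragment_length().
total = 1250 * 2                   # train_batch_size_per_learner * num_learner_workers
frag = total / (3 * 4)             # 208.333... -> not an integer
diff = total - int(frag) * 3 * 4   # 2500 - 2496 = 4
# worker 1: 1 * 3 = 3 <= 4 -> 209 timesteps; workers 2-4: 208 timesteps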
@@ -3095,36 +3119,38 @@ def validate_train_batch_size_vs_rollout_fragment_length(self) -> None:
 
         Raises:
             ValueError: If there is a mismatch between user provided
-                `rollout_fragment_length` and `train_batch_size`.
+                `rollout_fragment_length` and `total_train_batch_size`.
         """
         if (
             self.rollout_fragment_length != "auto"
             and not self.in_evaluation
-            and self.train_batch_size > 0
+            and self.total_train_batch_size > 0
         ):
             min_batch_size = (
                 max(self.num_rollout_workers, 1)
                 * self.num_envs_per_worker
                 * self.rollout_fragment_length
             )
             batch_size = min_batch_size
-            while batch_size < self.train_batch_size:
+            while batch_size < self.total_train_batch_size:
                 batch_size += min_batch_size
-            if (
-                batch_size - self.train_batch_size > 0.1 * self.train_batch_size
-                or batch_size - min_batch_size - self.train_batch_size
-                > (0.1 * self.train_batch_size)
+            if batch_size - self.total_train_batch_size > (
+                0.1 * self.total_train_batch_size
+            ) or batch_size - min_batch_size - self.total_train_batch_size > (
+                0.1 * self.total_train_batch_size
             ):
-                suggested_rollout_fragment_length = self.train_batch_size // (
+                suggested_rollout_fragment_length = self.total_train_batch_size // (
                     self.num_envs_per_worker * (self.num_rollout_workers or 1)
                 )
                 raise ValueError(
-                    f"Your desired `train_batch_size` ({self.train_batch_size}) or a "
-                    "value 10% off of that cannot be achieved with your other "
+                    "Your desired `total_train_batch_size` "
+                    f"({self.total_train_batch_size}={self.num_learner_workers} "
+                    f"learners x {self.train_batch_size_per_learner}) "
+                    "or a value 10% off of that cannot be achieved with your other "
                     f"settings (num_rollout_workers={self.num_rollout_workers}; "
                     f"num_envs_per_worker={self.num_envs_per_worker}; "
                     f"rollout_fragment_length={self.rollout_fragment_length})! "
-                    "Try setting `rollout_fragment_length` to 'auto' OR "
+                    "Try setting `rollout_fragment_length` to 'auto' OR to a value of "
                     f"{suggested_rollout_fragment_length}."
                 )
 
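
For intuition on the 10% tolerance enforced above, a hypothetical configuration (values assumed, not taken from the patch): with 2 rollout workers, 1 env per worker, and rollout_fragment_length=200, the smallest reachable batch is 400 and the next multiple is 800; for a total_train_batch_size of 500, the 300-timestep overshoot exceeds the 50-timestep (10%) tolerance, so the check raises and suggests rollout_fragment_length=250 (500 // 2) instead.

# Hypothetical numbers, illustrating the tolerance check only.
min_batch_size = 2 * 1 * 200  # num_rollout_workers * num_envs_per_worker * fragment
# Reachable batch sizes: 400, 800, ... -> first one >= 500 is 800.
# 800 - 500 = 300 > 0.1 * 500 -> ValueError, suggesting 500 // (1 * 2) = 250.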
@@ -3580,8 +3606,7 @@ def _validate_evaluation_settings(self):
35803606 """Checks, whether evaluation related settings make sense."""
35813607 if (
35823608 self .evaluation_interval
3583- and self .env_runner_cls is not None
3584- and not issubclass (self .env_runner_cls , RolloutWorker )
3609+ and self .uses_new_env_runners
35853610 and not self .enable_async_evaluation
35863611 ):
35873612 raise ValueError (