@@ -148,14 +148,20 @@ def __init__(
148148 action_scaling = action_scaling ,
149149 action_bound_method = action_bound_method ,
150150 )
151- if action_scaling and not np .isclose (actor .max_action , 1.0 ):
152- warnings .warn (
153- "action_scaling and action_bound_method are only intended "
154- "to deal with unbounded model action space, but find actor model "
155- f"bound action space with max_action={ actor .max_action } . "
156- "Consider using unbounded=True option of the actor model, "
157- "or set action_scaling to False and action_bound_method to None." ,
158- )
151+ if action_scaling :
152+ try :
153+ max_action = float (actor .max_action ) # type: ignore
154+ if not np .isclose (max_action , 1.0 ):
155+ warnings .warn (
156+ "action_scaling and action_bound_method are only intended "
157+ "to deal with unbounded model action space, but find actor model "
158+ f"bound action space with max_action={ actor .max_action } . "
159+ "Consider using unbounded=True option of the actor model, "
160+ "or set action_scaling to False and action_bound_method to None." ,
161+ )
162+ except Exception :
163+ pass
164+
159165 self .actor = actor
160166 self .dist_fn = dist_fn
161167 self ._eps = 1e-8
@@ -286,7 +292,7 @@ def add_discounted_returns(
286292 should be marked by done flag, unfinished (or collecting) episodes will be
287293 recognized by buffer.unfinished_index().
288294 :param buffer: the corresponding replay buffer.
289- :param numpy.ndarray indices: tell batch's location in buffer, batch is equal
295+ :param indices: tell batch's location in buffer, batch is equal
290296 to buffer[indices].
291297 """
292298 v_s_ = np .full (indices .shape , self .ret_rms .mean )
@@ -306,8 +312,7 @@ def add_discounted_returns(
306312 self .ret_rms .update (unnormalized_returns )
307313 else :
308314 batch .returns = unnormalized_returns
309- batch : BatchWithReturnsProtocol
310- return batch
315+ return cast (BatchWithReturnsProtocol , batch )
311316
312317
313318class Reinforce (OnPolicyAlgorithm [ActorPolicyProbabilistic ]):
@@ -316,7 +321,7 @@ class Reinforce(OnPolicyAlgorithm[ActorPolicyProbabilistic]):
316321 def __init__ (
317322 self ,
318323 * ,
319- policy : TActorPolicy ,
324+ policy : ActorPolicyProbabilistic ,
320325 gamma : float = 0.99 ,
321326 return_standardization : bool = False ,
322327 optim : OptimizerFactory ,
0 commit comments