Description
`MultiThreadEnv` has a property called `terminals`, which is a `BitArray{1}` containing the termination information for all of the wrapped environments. However, the `is_terminated` function behaves in an unexpected way:
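```julia
function RLBase.is_terminated(env::MultiThreadEnv)
    # Refresh the per-environment flags from every wrapped env...
    env.terminals .= is_terminated.(env.envs)
    # ...then return the whole BitArray instead of a single Bool
    env.terminals
end
```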
The function implements the `is_terminated` method required for any `AbstractEnv`; however, instead of returning a `Bool`, as the name and typical usage of the function imply, it returns the `env.terminals` property, which is a `BitArray`.
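A minimal, self-contained sketch makes the consequence concrete. It does not use ReinforcementLearning.jl at all: `ToyEnv` and `ToyMultiEnv` are hypothetical stand-ins for a single environment and for `MultiThreadEnv`, and `is_terminated` here is a local function that only mimics the current method's behavior.

```julia
# Hypothetical stand-ins for a single environment and for MultiThreadEnv,
# so the sketch runs without ReinforcementLearning.jl installed.
struct ToyEnv
    done::Bool
end

struct ToyMultiEnv
    envs::Vector{ToyEnv}
    terminals::BitVector
end

is_terminated(env::ToyEnv) = env.done

# Mirrors the current MultiThreadEnv method: it returns the BitVector itself.
function is_terminated(env::ToyMultiEnv)
    env.terminals .= is_terminated.(env.envs)
    env.terminals
end

menv = ToyMultiEnv([ToyEnv(true), ToyEnv(false)], falses(2))
println(is_terminated(menv) isa BitVector)   # true

# Generic code written against a single env uses the result as a condition.
# Julia refuses a non-Bool in an `if`, so this raises a TypeError.
try
    if is_terminated(menv)
        println("stopping")
    end
catch err
    println(err isa TypeError)   # true
end
```

Julia has no implicit truthiness for arrays, so any caller that writes `if is_terminated(env)` fails as soon as it receives a `BitVector`.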
The problem is that this breaks any code that expects an `is_terminated` call to return a `Bool`, such as the `StopAfterEpisode` hook. I propose changing the function to the following:
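```julia
function RLBase.is_terminated(env::MultiThreadEnv)
    # Keep the per-environment flags up to date for callers that need them
    env.terminals .= is_terminated.(env.envs)
    # Return a single Bool: true only once every wrapped env has terminated
    all(env.terminals)
end
```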
This way, the function behaves as expected in the majority of cases, and anyone who needs the termination status of individual wrapped environments can still read it from `env.terminals`.
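The proposed behavior can be checked with the same kind of hypothetical stand-ins (`ToyEnv`, `ToyMultiEnv`, and a local `is_terminated`, none of which are part of ReinforcementLearning.jl). This sketch reduces the flags with `all`, which is `true` exactly when every element is `true`, and shows that the per-environment flags remain readable after the call.

```julia
# Hypothetical stand-ins; not the real ReinforcementLearning.jl types.
struct ToyEnv
    done::Bool
end

struct ToyMultiEnv
    envs::Vector{ToyEnv}
    terminals::BitVector
end

is_terminated(env::ToyEnv) = env.done

# Proposed behavior: refresh the per-env flags, then reduce them to one Bool.
function is_terminated(env::ToyMultiEnv)
    env.terminals .= is_terminated.(env.envs)
    all(env.terminals)
end

partial  = ToyMultiEnv([ToyEnv(true), ToyEnv(false)], falses(2))
finished = ToyMultiEnv([ToyEnv(true), ToyEnv(true)], falses(2))

println(is_terminated(partial))    # false
println(is_terminated(finished))   # true

# The individual termination flags are still available after the call.
println(partial.terminals == [true, false])   # true
```

One design consequence worth noting: with `all`, the wrapper only reports termination once every environment is done, so callers that want to react to the first finished environment still need to inspect `env.terminals` directly.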