We found a bug in the code of rl_trainer/main.py while using this repository to train our own model.
# ... lines 58-66
state_to_training = state[0]  # you define state_to_training here
# ... lines 68-78
while True:
    # ... lines 80-86
    actions = logits_greedy(state_to_training, logits, height, width)  # here you use state_to_training to generate the greedy policy
    # ... lines 87-90
    next_state, reward, done, _, info = env.step(env.encode(actions))
    next_state_to_training = next_state[0]  # a new variable next_state_to_training is created
    next_obs = get_observations(next_state_to_training, ctrl_agent_index, obs_dim, height, width)
    # ... lines 90-116
    model.replay_buffer.push(obs, logits, step_reward, next_obs, done)
    model.update()
    obs = next_obs
    step += 1
# ... lines 123-146

You define state_to_training at the beginning of the code, above the training-episode loop, and during an episode you use state_to_training as the observation for the greedy policy. However, the updated state next_state_to_training is never assigned back to state_to_training, so the greedy policy keeps observing the state from the very beginning of the episode. This does not affect the training of our own model, because the argument passed to get_observations is next_state_to_training, but we suspect this bug effectively breaks the greedy policy and may make it perform worse than a random one.
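To illustrate why a greedy policy acting on a stale observation can degrade this badly, here is a purely hypothetical toy sketch (none of this code is from the repository; greedy_action, run, and the drifting-target setup are invented for illustration): a greedy agent chasing a randomly drifting target catches it far less often when its observation is frozen at the start of the episode, and can even fall behind a random policy.

import random

def greedy_action(agent_pos, target_pos):
    # step toward the target position contained in the observation
    return 1 if target_pos > agent_pos else -1

def run(mode, steps=500, seed=0):
    rng = random.Random(seed)
    agent, target = 0, 10
    obs = (agent, target)              # observation taken at the start of the episode
    hits = 0
    for _ in range(steps):
        a = rng.choice([-1, 1]) if mode == "random" else greedy_action(*obs)
        agent += a
        target += rng.choice([-1, 1])  # the target drifts randomly
        hits += int(agent == target)
        if mode == "greedy_fresh":
            obs = (agent, target)      # correct: re-observe the current state
        # in "greedy_stale" mode obs is never refreshed, mirroring how
        # state_to_training is never overwritten with next_state_to_training
    return hits

for mode in ("greedy_fresh", "greedy_stale", "random"):
    print(mode, run(mode))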
So state_to_training should also be updated at the point where obs is updated. The proposed fix is as follows:
# ... lines 58-66
state_to_training = state[0]  # you define state_to_training here
# ... lines 68-78
while True:
    # ... lines 80-86
    actions = logits_greedy(state_to_training, logits, height, width)  # here you use state_to_training to generate the greedy policy
    # ... lines 87-90
    next_state, reward, done, _, info = env.step(env.encode(actions))
    next_state_to_training = next_state[0]  # a new variable next_state_to_training is created
    next_obs = get_observations(next_state_to_training, ctrl_agent_index, obs_dim, height, width)
    # ... lines 90-116
    model.replay_buffer.push(obs, logits, step_reward, next_obs, done)
    model.update()
    obs = next_obs
    state_to_training = next_state_to_training  # fix: keep the greedy policy's observation in sync with the current state
    step += 1
# ... lines 123-146
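With this one-line change, the greedy policy is computed from the current state at every step rather than from the state at the start of the episode, so it should behave as a genuine greedy opponent again instead of the effectively frozen policy described above.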