@@ -87,42 +87,42 @@ import ReinforcementLearningTrajectories.fetch
     push!(eb, (state = i+1, action = i+1, reward = i, terminal = i == 5))
 end
 push!(eb, (state = 7, action = 7))
-for (j,i) = enumerate(8:11)
+for (j,i) = enumerate(8:12)
     push!(eb, (state = i, action = i, reward = i-1, terminal = false))
 end
 weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb)
-@test weights == [0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0]
-@test ns == [3, 3, 3, 2, 1, -1, 3, 3, 2, 1, 0] # the -1 is due to ep_lengths[6] being that of the 2nd episode but step_numbers[6] being that of the 1st episode
+@test weights == [0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0]
+@test ns == [3, 3, 2, 1, -1, 3, 3, 3, 2, 1, 0] # the -1 is due to ep_lengths[5] being that of the 2nd episode but step_numbers[5] being that of the 1st episode
 inds = [i for i in eachindex(weights) if weights[i] == 1]
 batch = sample(s1, eb)
 for key in keys(eb)
     @test haskey(batch, key)
 end
 # state: samples with stacksize
 states = ReinforcementLearningTrajectories.fetch(s1, eb[:state], Val(:state), inds, ns[inds])
-@test states == [1 2 3 4 7 8 9;
-                 2 3 4 5 8 9 10]
+@test states == [1 2 3 6 7 8;
+                 2 3 4 7 8 9]
 @test all(in(eachcol(states)), unique(eachcol(batch[:state])))
 # next_state: samples with stacksize and nsteps forward
 next_states = ReinforcementLearningTrajectories.fetch(s1, eb[:next_state], Val(:next_state), inds, ns[inds])
-@test next_states == [4 5 5 5 10 10 10;
-                      5 6 6 6 11 11 11]
+@test next_states == [4 4 4 9 10 10;
+                      5 5 5 10 11 11]
 @test all(in(eachcol(next_states)), unique(eachcol(batch[:next_state])))
 # action: samples normally
 actions = ReinforcementLearningTrajectories.fetch(s1, eb[:action], Val(:action), inds, ns[inds])
-@test actions == inds
+@test actions == [3, 4, 5, 8, 9, 10]
 @test all(in(actions), unique(batch[:action]))
 # next_action: is a multiplex trace: should automatically sample nsteps forward
 next_actions = ReinforcementLearningTrajectories.fetch(s1, eb[:next_action], Val(:next_action), inds, ns[inds])
-@test next_actions == [5, 6, 6, 6, 11, 11, 11]
+@test next_actions == [6, 6, 6, 11, 12, 12]
 @test all(in(next_actions), unique(batch[:next_action]))
 # reward: discounted sum
 rewards = ReinforcementLearningTrajectories.fetch(s1, eb[:reward], Val(:reward), inds, ns[inds])
-@test rewards ≈ [2 + 0.99*3 + 0.99^2*4, 3 + 0.99*4 + 0.99^2*5, 4 + 0.99*5, 5, 8 + 0.99*9 + 0.99^2*10, 9 + 0.99*10, 10]
+@test rewards ≈ [2 + 0.99*3 + 0.99^2*4, 3 + 0.99*4, 4, 7 + 0.99*8 + 0.99^2*9, 8 + 0.99*9 + 0.99^2*10, 9 + 0.99*10]
 @test all(in(rewards), unique(batch[:reward]))
 # terminal: nsteps forward
 terminals = ReinforcementLearningTrajectories.fetch(s1, eb[:terminal], Val(:terminal), inds, ns[inds])
-@test terminals == [0, 1, 1, 1, 0, 0, 0]
+@test terminals == [0, 0, 0, 0, 0, 0]
 
 ### CircularPrioritizedTraces and NStepBatchSampler
 γ = 0.99
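
For reference, the updated reward expectations are plain discounted n-step sums over the test buffer, in which the stored rewards satisfy reward[i] = i. A minimal sketch of that computation, using a hypothetical nstep_return helper rather than the library's implementation:

```julia
# Hypothetical helper, not ReinforcementLearningTrajectories code:
# the discounted n-step return starting at index i,
# sum(γ^(k-1) * reward[i+k-1] for k = 1:n).
nstep_return(reward, i, n, γ) = sum(γ^(k-1) * reward[i+k-1] for k in 1:n)

γ = 0.99
reward = collect(1:11)           # assumed flat view: the test buffer stores reward[i] = i
nstep_return(reward, 2, 3, γ)    # == 2 + 0.99*3 + 0.99^2*4, the first expected value
nstep_return(reward, 4, 1, γ)    # == 4, the return truncated at the episode boundary
```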
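The next_state and next_action expectations follow the same pattern of looking n steps ahead. From the weights/ns tests, the valid indices are inds = [2, 3, 4, 7, 8, 9] with horizons ns[inds] = [3, 2, 1, 3, 3, 2], and in this buffer state[i] = i and action[i] = i + 1. Under those assumptions, a sketch of the forward indexing (not the library's fetch logic):

```julia
# Assumed flat views of the test buffer's traces.
state  = collect(1:12)            # state[i] = i
action = collect(2:13)            # action[i] = i + 1 (multiplexed with state)
inds    = [2, 3, 4, 7, 8, 9]      # indices where weights == 1
ns_inds = [3, 2, 1, 3, 3, 2]      # ns[inds]: per-index n-step horizons

# next_state is the stacksize-2 frame stack ending n steps ahead, at i + n:
hcat([state[i+n-1:i+n] for (i, n) in zip(inds, ns_inds)]...)
# == [4 4 4 9 10 10; 5 5 5 10 11 11]

# next_action is simply the action stored at i + n:
[action[i+n] for (i, n) in zip(inds, ns_inds)]
# == [6, 6, 6, 11, 12, 12]
```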