|
24 | 24 | @test length(t) == 0 |
25 | 25 | end |
26 | 26 |
|
27 | | -@testset "CircularArraySARTSTraces" begin |
| 27 | +@testset "CircularArraySARTSATraces" begin |
28 | 28 | t = CircularArraySARTSATraces(; |
29 | 29 | capacity=3, |
30 | 30 | state=Float32 => (2, 3), |
|
35 | 35 |
|
36 | 36 | @test t isa CircularArraySARTSATraces |
37 | 37 |
|
38 | | - push!(t, (state=ones(Float32, 2, 3), action=ones(Float32, 2)) |> gpu) |
| 38 | + push!(t, (state=ones(Float32, 2, 3),)) |
| 39 | + push!(t, (action=ones(Float32, 2), next_state=ones(Float32, 2, 3) * 2) |> gpu) |
39 | 40 | @test length(t) == 0 |
40 | 41 |
|
41 | 42 | push!(t, (reward=1.0f0, terminal=false) |> gpu) |
42 | | - @test length(t) == 0 # next_state and next_action is still missing |
| 43 | + @test length(t) == 0 # next_action is still missing |
43 | 44 |
|
44 | | - push!(t, (next_state=ones(Float32, 2, 3) * 2, next_action=ones(Float32, 2) * 2) |> gpu) |
| 45 | + push!(t, (state=ones(Float32, 2, 3) * 3, action=ones(Float32, 2) * 2) |> gpu) |
45 | 46 | @test length(t) == 1 |
46 | 47 |
|
47 | 48 | # this will trigger the scalar indexing of CuArray |
|
55 | 56 | ) |
56 | 57 |
|
57 | 58 | push!(t, (reward=2.0f0, terminal=false)) |
58 | | - push!(t, (state=ones(Float32, 2, 3) * 3, action=ones(Float32, 2) * 3) |> gpu) |
| 59 | + push!(t, (state=ones(Float32, 2, 3) * 4, action=ones(Float32, 2) * 3) |> gpu) |
59 | 60 |
|
60 | 61 | @test length(t) == 2 |
61 | 62 |
|
62 | 63 | push!(t, (reward=3.0f0, terminal=false)) |
63 | | - push!(t, (state=ones(Float32, 2, 3) * 4, action=ones(Float32, 2) * 4) |> gpu) |
| 64 | + push!(t, (state=ones(Float32, 2, 3) * 5, action=ones(Float32, 2) * 4) |> gpu) |
64 | 65 |
|
65 | 66 | @test length(t) == 3 |
66 | 67 |
|
67 | 68 | push!(t, (reward=4.0f0, terminal=false)) |
68 | | - push!(t, (state=ones(Float32, 2, 3) * 5, action=ones(Float32, 2) * 5) |> gpu) |
| 69 | + push!(t, (state=ones(Float32, 2, 3) * 6, action=ones(Float32, 2) * 5) |> gpu) |
| 70 | + push!(t, (reward=5.0f0, terminal=false)) |
69 | 71 |
|
70 | 72 | @test length(t) == 3 |
71 | 73 |
|
|
127 | 129 | @test t isa CircularArraySLARTTraces |
128 | 130 | end |
129 | 131 |
|
130 | | -@testset "CircularPrioritizedTraces" begin |
| 132 | +@testset "CircularPrioritizedTraces-SARTS" begin |
131 | 133 | t = CircularPrioritizedTraces( |
132 | | - CircularArraySARTSATraces(; |
| 134 | + CircularArraySARTSTraces(; |
133 | 135 | capacity=3 |
134 | 136 | ), |
135 | 137 | default_priority=1.0f0 |
|
160 | 162 |
|
161 | 163 | #EpisodesBuffer |
162 | 164 | t = CircularPrioritizedTraces( |
163 | | - CircularArraySARTSATraces(; |
| 165 | + CircularArraySARTSTraces(; |
164 | 166 | capacity=10 |
165 | 167 | ), |
166 | 168 | default_priority=1.0f0 |
|
186 | 188 | eb[:priority, [1, 2]] = [0, 0] |
187 | 189 | @test eb[:priority] == [zeros(2);ones(8)] |
188 | 190 | end |
| 191 | + |
| 192 | +@testset "CircularPrioritizedTraces-SARTSA" begin |
| 193 | + t = CircularPrioritizedTraces( |
| 194 | + CircularArraySARTSATraces(; |
| 195 | + capacity=3 |
| 196 | + ), |
| 197 | + default_priority=1.0f0 |
| 198 | + ) |
| 199 | + |
| 200 | + push!(t, (state=0, action=0)) |
| 201 | + |
| 202 | + for i in 1:5 |
| 203 | + push!(t, (reward=1.0f0, terminal=false, state=i, action=i)) |
| 204 | + end |
| 205 | + |
| 206 | + @test length(t) == 3 |
| 207 | + |
| 208 | + s = BatchSampler(5) |
| 209 | + |
| 210 | + b = sample(s, t) |
| 211 | + |
| 212 | + t[:priority, [1, 2]] = [0, 0] |
| 213 | + |
| 214 | + # shouldn't be changed since [1,2] are old keys |
| 215 | + @test t[:priority] == [1.0f0, 1.0f0, 1.0f0] |
| 216 | + |
| 217 | + t[:priority, [3, 4, 5]] = [0, 1, 0] |
| 218 | + |
| 219 | + b = sample(s, t) |
| 220 | + |
| 221 | + @test b.key == [4, 4, 4, 4, 4] # the priority of the rest transitions are set to 0 |
| 222 | + |
| 223 | + #EpisodesBuffer |
| 224 | + t = CircularPrioritizedTraces( |
| 225 | + CircularArraySARTSATraces(; |
| 226 | + capacity=10 |
| 227 | + ), |
| 228 | + default_priority=1.0f0 |
| 229 | + ) |
| 230 | + |
| 231 | + eb = EpisodesBuffer(t) |
| 232 | + push!(eb, (state = 1,)) |
| 233 | + for i = 1:5 |
| 234 | + push!(eb, (state = i+1, action =i, reward = i, terminal = false)) |
| 235 | + end |
| 236 | + push!(eb, PartialNamedTuple((action = 6,))) |
| 237 | + push!(eb, (state = 7,)) |
| 238 | + for (j,i) = enumerate(8:11) |
| 239 | + push!(eb, (state = i, action =i-1, reward = i-1, terminal = false)) |
| 240 | + end |
| 241 | + push!(eb, PartialNamedTuple((action=12,))) |
| 242 | + s = BatchSampler(1000) |
| 243 | + b = sample(s, eb) |
| 244 | + cm = counter(b[:state]) |
| 245 | + @test !haskey(cm, 6) |
| 246 | + @test !haskey(cm, 11) |
| 247 | + @test all(in(keys(cm)), [1:5;7:10]) |
| 248 | +end |
0 commit comments