Skip to content

Commit 2d25c54

Browse files
Merge pull request #75 from johannes-fischer/fix_buffer
Remove special handling of SARTSA traces
2 parents de01fb3 + bfc1591 commit 2d25c54

9 files changed

+209
-202
lines changed

src/common/CircularArraySARTSATraces.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ function CircularArraySARTSATraces(;
2424
reward_eltype, reward_size = reward
2525
terminal_eltype, terminal_size = terminal
2626

27-
MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+2)) +
27+
MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+1)) +
2828
MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity+1)) +
2929
Traces(
30-
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity+1),
31-
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity+1),
30+
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
31+
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
3232
)
3333
end
3434

35-
CircularArrayBuffers.capacity(t::CircularArraySARTSATraces) = CircularArrayBuffers.capacity(minimum(map(capacity,t.traces)))
35+
CircularArrayBuffers.capacity(t::CircularArraySARTSATraces) = minimum(map(capacity,t.traces))

src/common/CircularArraySARTSTraces.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ function CircularArraySARTSTraces(;
1717
state=Int => (),
1818
action=Int => (),
1919
reward=Float32 => (),
20-
terminal=Bool => ())
21-
20+
terminal=Bool => ()
21+
)
2222
state_eltype, state_size = state
2323
action_eltype, action_size = action
2424
reward_eltype, reward_size = reward
@@ -32,4 +32,4 @@ function CircularArraySARTSTraces(;
3232
)
3333
end
3434

35-
CircularArrayBuffers.capacity(t::CircularArraySARTSTraces) = CircularArrayBuffers.capacity(minimum(map(capacity,t.traces)))
35+
CircularArrayBuffers.capacity(t::CircularArraySARTSTraces) = minimum(map(capacity,t.traces))

src/common/CircularArraySLARTTraces.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,4 @@ function CircularArraySLARTTraces(;
3434
)
3535
end
3636

37-
CircularArrayBuffers.capacity(t::CircularArraySLARTTraces) = CircularArrayBuffers.capacity(minimum(map(capacity,t.traces)))
37+
CircularArrayBuffers.capacity(t::CircularArraySLARTTraces) = minimum(map(capacity,t.traces))

src/common/CircularPrioritizedTraces.jl

+1-21
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,7 @@ end
1212
function CircularPrioritizedTraces(traces::AbstractTraces{names,Ts}; default_priority) where {names,Ts}
1313
new_names = (:key, :priority, names...)
1414
new_Ts = Tuple{Int,Float32,Ts.parameters...}
15-
if traces isa CircularArraySARTSATraces
16-
c = capacity(traces) - 1
17-
else
18-
c = capacity(traces)
19-
end
15+
c = capacity(traces)
2016
CircularPrioritizedTraces{typeof(traces),new_names,new_Ts}(
2117
CircularVectorBuffer{Int}(c),
2218
SumTree(c),
@@ -38,22 +34,6 @@ function Base.push!(t::CircularPrioritizedTraces, x)
3834
end
3935
end
4036

41-
function Base.push!(t::CircularPrioritizedTraces{<:CircularArraySARTSATraces}, x)
42-
initial_length = length(t.traces)
43-
push!(t.traces, x)
44-
if length(t.traces) == 1
45-
push!(t.keys, 1)
46-
push!(t.priorities, t.default_priority)
47-
elseif length(t.traces) > 1 && (initial_length < length(t.traces) || initial_length == capacity(t.traces)-1 )
48-
# only add a key if the length changes after insertion of the tuple
49-
# or if the trace is already at capacity
50-
push!(t.keys, t.keys[end] + 1)
51-
push!(t.priorities, t.default_priority)
52-
else
53-
# may be partial inserting at the first step, ignore it
54-
end
55-
end
56-
5737
function Base.setindex!(t::CircularPrioritizedTraces, vs, k::Symbol, keys)
5838
if k === :priority
5939
@assert length(vs) == length(keys)

src/episodes.jl

+9-34
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ using ElasticArrays: ElasticArray, ElasticVector
55
"""
66
EpisodesBuffer(traces::AbstractTraces)
77
8-
Wraps an `AbstractTraces` object, usually the container of a `Trajectory`.
8+
Wraps an `AbstractTraces` object, usually the container of a `Trajectory`.
99
`EpisodesBuffer` tracks the indexes of the `traces` object that belong to the same episodes.
10-
To that end, it stores
10+
To that end, it stores
1111
1. an vector `sampleable_inds` of Booleans that determine whether an index in Traces is legally sampleable
1212
(i.e., it is not the index of a last state of an episode);
1313
2. a vector `episodes_lengths` that contains the total duration of the episode that each step belong to;
@@ -32,7 +32,7 @@ end
3232
"""
3333
PartialNamedTuple(::NamedTuple)
3434
35-
Wraps a NamedTuple to signal an EpisodesBuffer that it is pushed into that it should
35+
Wraps a NamedTuple to signal an EpisodesBuffer that it is pushed into that it should
3636
ignore the fact that this is a partial insertion. Used at the end of an episode to
3737
complete multiplex traces before moving to the next episode.
3838
"""
@@ -43,15 +43,13 @@ end
4343
# Capacity of an EpisodesBuffer is the capacity of the underlying traces + 1 for certain cases
4444
function is_capacity_plus_one(traces::AbstractTraces)
4545
if any(t->t isa MultiplexTraces, traces.traces)
46-
# MultiplexTraces buffer next_state and next_action, so we need to add one to the capacity
47-
return true
48-
elseif traces isa CircularPrioritizedTraces
49-
# CircularPrioritizedTraces buffer next_state and next_action, so we need to add one to the capacity
46+
# MultiplexTraces buffer next_state or next_action, so we need to add one to the capacity
5047
return true
5148
else
5249
false
5350
end
5451
end
52+
is_capacity_plus_one(traces::CircularPrioritizedTraces) = is_capacity_plus_one(traces.traces)
5553

5654
function EpisodesBuffer(traces::AbstractTraces)
5755
cap = is_capacity_plus_one(traces) ? capacity(traces) + 1 : capacity(traces)
@@ -70,7 +68,7 @@ function EpisodesBuffer(traces::AbstractTraces)
7068
end
7169

7270
function Base.getindex(es::EpisodesBuffer, idx::Int...)
73-
@boundscheck all(es.sampleable_inds[idx...])
71+
@boundscheck all(es.sampleable_inds[idx...]) || throw(BoundsError(es.sampleable_inds, idx))
7472
getindex(es.traces, idx...)
7573
end
7674

@@ -79,6 +77,7 @@ function Base.getindex(es::EpisodesBuffer, idx...)
7977
end
8078

8179
Base.setindex!(eb::EpisodesBuffer, idx...) = setindex!(eb.traces, idx...)
80+
capacity(eb::EpisodesBuffer) = capacity(eb.traces)
8281
Base.size(eb::EpisodesBuffer) = size(eb.traces)
8382
Base.length(eb::EpisodesBuffer) = length(eb.traces)
8483
Base.keys(eb::EpisodesBuffer) = keys(eb.traces)
@@ -118,8 +117,6 @@ pad!(vect::Vector{T}) where {T} = push!(vect, zero(T))
118117
end
119118
elseif traces_signature <: Tuple
120119
traces_signature = traces_signature.parameters
121-
122-
123120
for tr in traces_signature
124121
if !(tr <: MultiplexTraces)
125122
#push a duplicate of last element as a dummy element, should never be sampled.
@@ -148,7 +145,7 @@ function Base.push!(eb::EpisodesBuffer, xs::NamedTuple)
148145
push!(eb.episodes_lengths, 0)
149146
push!(eb.sampleable_inds, 0)
150147
elseif !partial #typical inserting
151-
if haskey(eb,:next_action) && length(eb) < max_length(eb) # if trace has next_action and lengths are mismatched
148+
if haskey(eb,:next_action) # if trace has next_action
152149
if eb.step_numbers[end] > 1 # and if there are sufficient steps in the current episode
153150
eb.sampleable_inds[end-1] = 1 # steps are indexable one step later
154151
end
@@ -171,33 +168,11 @@ function Base.push!(eb::EpisodesBuffer, xs::NamedTuple)
171168
return nothing
172169
end
173170

174-
function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTuple to push without incrementing the step number.
175-
push!(eb.traces, xs.namedtuple)
176-
eb.sampleable_inds[end-1] = 1 #completes the episode trajectory.
177-
end
178-
179-
function Base.push!(eb::EpisodesBuffer{<:Any,<:Any,<:CircularArraySARTSATraces}, xs::PartialNamedTuple)
180-
if max_length(eb) == capacity(eb.traces)
181-
popfirst!(eb)
182-
end
171+
function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTuple to push without incrementing the step number.
183172
push!(eb.traces, xs.namedtuple)
184173
eb.sampleable_inds[end-1] = 1 #completes the episode trajectory.
185174
end
186175

187-
function Base.push!(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces{<:CircularArraySARTSATraces}}, xs::PartialNamedTuple{@NamedTuple{action::Int64}})
188-
if max_length(eb) == capacity(eb.traces)
189-
addition = (name => zero(eltype(eb.traces[name])) for name in [:state, :reward, :terminal])
190-
xs = merge(xs.namedtuple, addition)
191-
push!(eb.traces, xs)
192-
pop!(eb.traces[:state].trace)
193-
pop!(eb.traces[:reward])
194-
pop!(eb.traces[:terminal])
195-
else
196-
push!(eb.traces, xs.namedtuple)
197-
eb.sampleable_inds[end-1] = 1
198-
end
199-
end
200-
201176
for f in (:pop!, :popfirst!)
202177
@eval function Base.$f(eb::EpisodesBuffer)
203178
$f(eb.episodes_lengths)

src/samplers.jl

+23-23
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,10 @@ export MetaSampler
9393
"""
9494
MetaSampler(::NamedTuple)
9595
96-
Wraps a NamedTuple containing multiple samplers. When sampled, returns a named tuple with a
96+
Wraps a NamedTuple containing multiple samplers. When sampled, returns a named tuple with a
9797
batch from each sampler.
9898
Used internally for algorithms that sample multiple times per epoch.
99-
Note that a single "sampling" with a MetaSampler only increases the Trajectory controler
99+
Note that a single "sampling" with a MetaSampler only increases the Trajectory controler
100100
count by 1, not by the number of internal samplers. This should be taken into account when
101101
initializing an agent.
102102
@@ -131,15 +131,15 @@ export MultiBatchSampler
131131
"""
132132
MultiBatchSampler(sampler, n)
133133
134-
Wraps a sampler. When sampled, will sample n batches using sampler. Useful in combination
134+
Wraps a sampler. When sampled, will sample n batches using sampler. Useful in combination
135135
with MetaSampler to allow different sampling rates between samplers.
136-
Note that a single "sampling" with a MultiBatchSampler only increases the Trajectory
136+
Note that a single "sampling" with a MultiBatchSampler only increases the Trajectory
137137
controler count by 1, not by `n`. This should be taken into account when
138138
initializing an agent.
139139
140140
# Example
141141
```
142-
MetaSampler(policy = MultiBatchSampler(BatchSampler(10), 3),
142+
MetaSampler(policy = MultiBatchSampler(BatchSampler(10), 3),
143143
critic = MultiBatchSampler(BatchSampler(100), 5))
144144
```
145145
"""
@@ -169,13 +169,13 @@ export NStepBatchSampler
169169
NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.GLOBAL_RNG)
170170
171171
Used to sample a discounted sum of consecutive rewards in the framework of n-step TD learning.
172-
The "next" element of Multiplexed traces (such as the next_state or the next_action) will be
172+
The "next" element of Multiplexed traces (such as the next_state or the next_action) will be
173173
that in up to `n > 1` steps later in the buffer. The reward will be
174174
the discounted sum of the `n` rewards, with `γ` as the discount factor.
175175
176-
NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stacksize` is set
176+
NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stacksize` is set
177177
to an integer > 1. This samples the (stacksize - 1) previous states. This is useful in the case
178-
of partial observability, for example when the state is approximated by `stacksize` consecutive
178+
of partial observability, for example when the state is approximated by `stacksize` consecutive
179179
frames.
180180
"""
181181
mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}, R <: AbstractRNG}
@@ -187,17 +187,17 @@ mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}, R <: AbstractRN
187187
end
188188

189189
NStepBatchSampler(t::AbstractTraces; kw...) = NStepBatchSampler{keys(t)}(; kw...)
190-
function NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.default_rng()) where {names}
190+
function NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.default_rng()) where {names}
191191
@assert n >= 1 "n must be ≥ 1."
192192
ss = stacksize == 1 ? nothing : stacksize
193193
NStepBatchSampler{names, typeof(ss), typeof(rng)}(n, γ, batchsize, ss, rng)
194194
end
195195

196196
#return a boolean vector of the valid sample indices given the stacksize and the truncated n for each index.
197-
function valid_range(s::NStepBatchSampler, eb::EpisodesBuffer)
197+
function valid_range(s::NStepBatchSampler, eb::EpisodesBuffer)
198198
range = copy(eb.sampleable_inds)
199199
ns = Vector{Int}(undef, length(eb.sampleable_inds))
200-
stacksize = isnothing(s.stacksize) ? 1 : s.stacksize
200+
stacksize = isnothing(s.stacksize) ? 1 : s.stacksize
201201
for idx in eachindex(range)
202202
step_number = eb.step_numbers[idx]
203203
range[idx] = step_number >= stacksize && eb.sampleable_inds[idx]
@@ -258,9 +258,9 @@ end
258258
"""
259259
EpisodesSampler()
260260
261-
A sampler that samples all Episodes present in the Trajectory and divides them into
261+
A sampler that samples all Episodes present in the Trajectory and divides them into
262262
Episode containers. Truncated Episodes (e.g. due to the buffer capacity) are sampled as well.
263-
There will be at most one truncated episode and it will always be the first one.
263+
There will be at most one truncated episode and it will always be the first one.
264264
"""
265265
struct EpisodesSampler{names}
266266
end
@@ -295,7 +295,7 @@ function StatsBase.sample(::EpisodesSampler, t::EpisodesBuffer, names)
295295
idx += 1
296296
end
297297
end
298-
298+
299299
return [make_episode(t, r, names) for r in ranges]
300300
end
301301

@@ -304,29 +304,29 @@ end
304304
"""
305305
MultiStepSampler{names}(batchsize, n, stacksize, rng)
306306
307-
Sampler that fetches steps `[x, x+1, ..., x + n -1]` for each trace of each sampled index
308-
`x`. The samples are returned in an array of batchsize elements. For each element, n is
309-
truncated by the end of its episode. This means that the dimensions of each sample are not
310-
the same.
307+
Sampler that fetches steps `[x, x+1, ..., x + n -1]` for each trace of each sampled index
308+
`x`. The samples are returned in an array of batchsize elements. For each element, n is
309+
truncated by the end of its episode. This means that the dimensions of each sample are not
310+
the same.
311311
"""
312312
struct MultiStepSampler{names, S <: Union{Nothing,Int}, R <: AbstractRNG}
313313
n::Int
314314
batchsize::Int
315315
stacksize::S
316-
rng::R
316+
rng::R
317317
end
318318

319319
MultiStepSampler(t::AbstractTraces; kw...) = MultiStepSampler{keys(t)}(; kw...)
320-
function MultiStepSampler{names}(; n::Int, batchsize, stacksize=nothing, rng=Random.default_rng()) where {names}
320+
function MultiStepSampler{names}(; n::Int, batchsize, stacksize=nothing, rng=Random.default_rng()) where {names}
321321
@assert n >= 1 "n must be ≥ 1."
322322
ss = stacksize == 1 ? nothing : stacksize
323323
MultiStepSampler{names, typeof(ss), typeof(rng)}(n, batchsize, ss, rng)
324324
end
325325

326-
function valid_range(s::MultiStepSampler, eb::EpisodesBuffer)
326+
function valid_range(s::MultiStepSampler, eb::EpisodesBuffer)
327327
range = copy(eb.sampleable_inds)
328328
ns = Vector{Int}(undef, length(eb.sampleable_inds))
329-
stacksize = isnothing(s.stacksize) ? 1 : s.stacksize
329+
stacksize = isnothing(s.stacksize) ? 1 : s.stacksize
330330
for idx in eachindex(range)
331331
step_number = eb.step_numbers[idx]
332332
range[idx] = step_number >= stacksize && eb.sampleable_inds[idx]
@@ -353,7 +353,7 @@ function fetch(::MultiStepSampler, trace, ::Val, inds, ns)
353353
[trace[idx:(idx + ns[i] - 1)] for (i,idx) in enumerate(inds)]
354354
end
355355

356-
function fetch(s::MultiStepSampler{names, Int}, trace::AbstractTrace, ::Union{Val{:state}, Val{:next_state}}, inds, ns) where {names}
356+
function fetch(s::MultiStepSampler{names, Int}, trace::AbstractTrace, ::Union{Val{:state}, Val{:next_state}}, inds, ns) where {names}
357357
[trace[[idx + i + n - 1 for i in -s.stacksize+1:0, n in 1:ns[j]]] for (j,idx) in enumerate(inds)]
358358
end
359359

0 commit comments

Comments
 (0)