Skip to content

Commit f9f8412

Browse files
Merge pull request #76 from JuliaReinforcementLearning/jpsl/format
Run code format
2 parents 2d25c54 + b5347d6 commit f9f8412

24 files changed

+928
-578
lines changed

src/common/CircularArraySARTSATraces.jl

+11-10
Original file line numberDiff line numberDiff line change
@@ -9,27 +9,28 @@ const CircularArraySARTSATraces = Traces{
99
<:MultiplexTraces{AA′,<:Trace{<:CircularArrayBuffer}},
1010
<:Trace{<:CircularArrayBuffer},
1111
<:Trace{<:CircularArrayBuffer},
12-
}
12+
},
1313
}
1414

1515
function CircularArraySARTSATraces(;
1616
capacity::Int,
17-
state=Int => (),
18-
action=Int => (),
19-
reward=Float32 => (),
20-
terminal=Bool => ()
17+
state = Int => (),
18+
action = Int => (),
19+
reward = Float32 => (),
20+
terminal = Bool => (),
2121
)
2222
state_eltype, state_size = state
2323
action_eltype, action_size = action
2424
reward_eltype, reward_size = reward
2525
terminal_eltype, terminal_size = terminal
2626

27-
MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+1)) +
28-
MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity+1)) +
27+
MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
28+
MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
2929
Traces(
30-
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
31-
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
30+
reward = CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
31+
terminal = CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
3232
)
3333
end
3434

35-
CircularArrayBuffers.capacity(t::CircularArraySARTSATraces) = minimum(map(capacity,t.traces))
35+
CircularArrayBuffers.capacity(t::CircularArraySARTSATraces) =
36+
minimum(map(capacity, t.traces))

src/common/CircularArraySARTSTraces.jl

+10-9
Original file line numberDiff line numberDiff line change
@@ -9,27 +9,28 @@ const CircularArraySARTSTraces = Traces{
99
<:Trace{<:CircularArrayBuffer},
1010
<:Trace{<:CircularArrayBuffer},
1111
<:Trace{<:CircularArrayBuffer},
12-
}
12+
},
1313
}
1414

1515
function CircularArraySARTSTraces(;
1616
capacity::Int,
17-
state=Int => (),
18-
action=Int => (),
19-
reward=Float32 => (),
20-
terminal=Bool => ()
17+
state = Int => (),
18+
action = Int => (),
19+
reward = Float32 => (),
20+
terminal = Bool => (),
2121
)
2222
state_eltype, state_size = state
2323
action_eltype, action_size = action
2424
reward_eltype, reward_size = reward
2525
terminal_eltype, terminal_size = terminal
2626

27-
MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+1)) +
27+
MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
2828
Traces(
2929
action = CircularArrayBuffer{action_eltype}(action_size..., capacity),
30-
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
31-
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
30+
reward = CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
31+
terminal = CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
3232
)
3333
end
3434

35-
CircularArrayBuffers.capacity(t::CircularArraySARTSTraces) = minimum(map(capacity,t.traces))
35+
CircularArrayBuffers.capacity(t::CircularArraySARTSTraces) =
36+
minimum(map(capacity, t.traces))

src/common/CircularArraySLARTTraces.jl

+16-10
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@ const CircularArraySLARTTraces = Traces{
88
<:MultiplexTraces{AA′,<:Trace{<:CircularArrayBuffer}},
99
<:Trace{<:CircularArrayBuffer},
1010
<:Trace{<:CircularArrayBuffer},
11-
}
11+
},
1212
}
1313

1414
function CircularArraySLARTTraces(;
1515
capacity::Int,
16-
state=Int => (),
17-
legal_actions_mask=Bool => (),
18-
action=Int => (),
19-
reward=Float32 => (),
20-
terminal=Bool => ()
16+
state = Int => (),
17+
legal_actions_mask = Bool => (),
18+
action = Int => (),
19+
reward = Float32 => (),
20+
terminal = Bool => (),
2121
)
2222
state_eltype, state_size = state
2323
action_eltype, action_size = action
@@ -26,12 +26,18 @@ function CircularArraySLARTTraces(;
2626
terminal_eltype, terminal_size = terminal
2727

2828
MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
29-
MultiplexTraces{LL′}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) +
29+
MultiplexTraces{LL′}(
30+
CircularArrayBuffer{legal_actions_mask_eltype}(
31+
legal_actions_mask_size...,
32+
capacity + 1,
33+
),
34+
) +
3035
MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
3136
Traces(
32-
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
33-
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
37+
reward = CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
38+
terminal = CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
3439
)
3540
end
3641

37-
CircularArrayBuffers.capacity(t::CircularArraySLARTTraces) = minimum(map(capacity,t.traces))
42+
CircularArrayBuffers.capacity(t::CircularArraySLARTTraces) =
43+
minimum(map(capacity, t.traces))

src/common/CircularPrioritizedTraces.jl

+9-4
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,18 @@ struct CircularPrioritizedTraces{T,names,Ts} <: AbstractTraces{names,Ts}
99
default_priority::Float32
1010
end
1111

12-
function CircularPrioritizedTraces(traces::AbstractTraces{names,Ts}; default_priority) where {names,Ts}
12+
function CircularPrioritizedTraces(
13+
traces::AbstractTraces{names,Ts};
14+
default_priority,
15+
) where {names,Ts}
1316
new_names = (:key, :priority, names...)
1417
new_Ts = Tuple{Int,Float32,Ts.parameters...}
1518
c = capacity(traces)
1619
CircularPrioritizedTraces{typeof(traces),new_names,new_Ts}(
1720
CircularVectorBuffer{Int}(c),
1821
SumTree(c),
1922
traces,
20-
default_priority
23+
default_priority,
2124
)
2225
end
2326

@@ -60,6 +63,8 @@ function Base.getindex(ts::CircularPrioritizedTraces, s::Symbol)
6063
end
6164
end
6265

63-
Base.getindex(t::CircularPrioritizedTraces{<:Any,names}, i) where {names} = NamedTuple{names}(map(k -> t[k][i], names))
66+
Base.getindex(t::CircularPrioritizedTraces{<:Any,names}, i) where {names} =
67+
NamedTuple{names}(map(k -> t[k][i], names))
6468

65-
capacity(t::CircularPrioritizedTraces) = ReinforcementLearningTrajectories.capacity(t.traces)
69+
capacity(t::CircularPrioritizedTraces) =
70+
ReinforcementLearningTrajectories.capacity(t.traces)

src/common/ElasticArraySARTSATraces.jl

+7-8
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@ const ElasticArraySARTSATraces = Traces{
77
<:MultiplexTraces{AA′,<:Trace{<:ElasticArray}},
88
<:Trace{<:ElasticArray},
99
<:Trace{<:ElasticArray},
10-
}
10+
},
1111
}
1212

1313
function ElasticArraySARTSATraces(;
14-
state=Int => (),
15-
action=Int => (),
16-
reward=Float32 => (),
17-
terminal=Bool => ()
14+
state = Int => (),
15+
action = Int => (),
16+
reward = Float32 => (),
17+
terminal = Bool => (),
1818
)
1919
state_eltype, state_size = state
2020
action_eltype, action_size = action
@@ -24,8 +24,7 @@ function ElasticArraySARTSATraces(;
2424
MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) +
2525
MultiplexTraces{AA′}(ElasticArray{action_eltype}(undef, action_size..., 0)) +
2626
Traces(
27-
reward=ElasticArray{reward_eltype}(undef, reward_size..., 0),
28-
terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
27+
reward = ElasticArray{reward_eltype}(undef, reward_size..., 0),
28+
terminal = ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
2929
)
3030
end
31-

src/common/ElasticArraySARTSTraces.jl

+10-10
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,24 @@ const ElasticArraySARTSTraces = Traces{
77
<:Trace{<:ElasticArray},
88
<:Trace{<:ElasticArray},
99
<:Trace{<:ElasticArray},
10-
}
10+
},
1111
}
1212

1313
function ElasticArraySARTSTraces(;
14-
state=Int => (),
15-
action=Int => (),
16-
reward=Float32 => (),
17-
terminal=Bool => ())
18-
14+
state = Int => (),
15+
action = Int => (),
16+
reward = Float32 => (),
17+
terminal = Bool => (),
18+
)
19+
1920
state_eltype, state_size = state
2021
action_eltype, action_size = action
2122
reward_eltype, reward_size = reward
2223
terminal_eltype, terminal_size = terminal
2324

24-
MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) +
25-
Traces(
25+
MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) + Traces(
2626
action = ElasticArray{action_eltype}(undef, action_size..., 0),
27-
reward=ElasticArray{reward_eltype}(undef, reward_size..., 0),
28-
terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
27+
reward = ElasticArray{reward_eltype}(undef, reward_size..., 0),
28+
terminal = ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
2929
)
3030
end

src/common/ElasticArraySLARTTraces.jl

+11-9
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@ const ElasticArraySLARTTraces = Traces{
88
<:MultiplexTraces{AA′,<:Trace{<:ElasticArray}},
99
<:Trace{<:ElasticArray},
1010
<:Trace{<:ElasticArray},
11-
}
11+
},
1212
}
1313

1414
function ElasticArraySLARTTraces(;
1515
capacity::Int,
16-
state=Int => (),
17-
legal_actions_mask=Bool => (),
18-
action=Int => (),
19-
reward=Float32 => (),
20-
terminal=Bool => ()
16+
state = Int => (),
17+
legal_actions_mask = Bool => (),
18+
action = Int => (),
19+
reward = Float32 => (),
20+
terminal = Bool => (),
2121
)
2222
state_eltype, state_size = state
2323
action_eltype, action_size = action
@@ -26,10 +26,12 @@ function ElasticArraySLARTTraces(;
2626
terminal_eltype, terminal_size = terminal
2727

2828
MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) +
29-
MultiplexTraces{LL′}(ElasticArray{legal_actions_mask_eltype}(undef, legal_actions_mask_size..., 0)) +
29+
MultiplexTraces{LL′}(
30+
ElasticArray{legal_actions_mask_eltype}(undef, legal_actions_mask_size..., 0),
31+
) +
3032
MultiplexTraces{AA′}(ElasticArray{action_eltype}(undef, action_size..., 0)) +
3133
Traces(
32-
reward=ElasticArray{reward_eltype}(undef, reward_size..., 0),
33-
terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
34+
reward = ElasticArray{reward_eltype}(undef, reward_size..., 0),
35+
terminal = ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
3436
)
3537
end

src/common/sum_tree.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ function correct_sample(t::SumTree, leaf_ind)
139139
p = t.tree[leaf_ind]
140140
# walk backwards until p != 0 or until leftmost leaf reached
141141
tmp_ind = leaf_ind
142-
while iszero(p) && (tmp_ind-1)*2 > length(t.tree)
142+
while iszero(p) && (tmp_ind - 1) * 2 > length(t.tree)
143143
tmp_ind -= 1
144144
p = t.tree[tmp_ind]
145145
end
@@ -151,7 +151,7 @@ function correct_sample(t::SumTree, leaf_ind)
151151
end
152152
return p, tmp_ind
153153
end
154-
154+
155155

156156
function Base.get(t::SumTree, v)
157157
parent_ind = 1
@@ -185,7 +185,7 @@ Random.rand(t::SumTree) = rand(Random.GLOBAL_RNG, t)
185185

186186
function Random.rand(rng::AbstractRNG, t::SumTree{T}, n::Int) where {T}
187187
inds, priorities = Vector{Int}(undef, n), Vector{T}(undef, n)
188-
for i in 1:n
188+
for i = 1:n
189189
v = (i - 1 + rand(rng, T)) / n
190190
ind, p = get(t, v * t.tree[1])
191191
inds[i] = ind

src/controllers.jl

+11-9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
export InsertSampleRatioController, AsyncInsertSampleRatioController, EpisodeSampleRatioController
1+
export InsertSampleRatioController,
2+
AsyncInsertSampleRatioController, EpisodeSampleRatioController
23

34
"""
45
InsertSampleRatioController(;ratio=1., threshold=1)
@@ -43,18 +44,19 @@ end
4344
function AsyncInsertSampleRatioController(
4445
ratio,
4546
threshold,
46-
; ch_in_sz=1,
47-
ch_out_sz=1,
48-
n_inserted=0,
49-
n_sampled=0
47+
;
48+
ch_in_sz = 1,
49+
ch_out_sz = 1,
50+
n_inserted = 0,
51+
n_sampled = 0,
5052
)
5153
AsyncInsertSampleRatioController(
5254
ratio,
5355
threshold,
5456
n_inserted,
5557
n_sampled,
5658
Channel(ch_in_sz),
57-
Channel(ch_out_sz)
59+
Channel(ch_out_sz),
5860
)
5961
end
6062

@@ -75,14 +77,14 @@ end
7577

7678
function on_insert!(c::EpisodeSampleRatioController, n::Int, x::NamedTuple)
7779
if n > 0
78-
c.n_episodes += sum(x.terminal)
80+
c.n_episodes += sum(x.terminal)
7981
end
8082
end
8183

8284
function on_sample!(c::EpisodeSampleRatioController)
83-
if c.n_episodes >= c.threshold && c.n_sampled <= (c.n_episodes - c.threshold) * c.ratio
85+
if c.n_episodes >= c.threshold && c.n_sampled <= (c.n_episodes - c.threshold) * c.ratio
8486
c.n_sampled += 1
8587
return true
8688
end
8789
return false
88-
end
90+
end

0 commit comments

Comments
 (0)