Skip to content

Commit 55e6ab5

Browse files
Format .jl files
1 parent dd17ad6 commit 55e6ab5

File tree

69 files changed

+713
-630
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

69 files changed

+713
-630
lines changed

docs/experiments/experiments/CFR/JuliaRL_DeepCFR_OpenSpiel.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,5 +62,11 @@ function RL.Experiment(
6262
initializer = glorot_normal(CUDA.CURAND.default_rng()),
6363
)
6464
# nash_conv ≈ 0.23
65-
Experiment(p, env, StopAfterStep(500, is_show_progress=!haskey(ENV, "CI")), EmptyHook(), "# run DeepcCFR on leduc_poker")
66-
end
65+
Experiment(
66+
p,
67+
env,
68+
StopAfterStep(500, is_show_progress = !haskey(ENV, "CI")),
69+
EmptyHook(),
70+
"# run DeepcCFR on leduc_poker",
71+
)
72+
end

docs/experiments/experiments/CFR/JuliaRL_TabularCFR_OpenSpiel.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,15 @@ function RL.Experiment(
2323
π = TabularCFRPolicy(; rng = rng)
2424

2525
description = "# Play `$game` in OpenSpiel with TabularCFRPolicy"
26-
Experiment(π, env, StopAfterStep(300, is_show_progress=!haskey(ENV, "CI")), EmptyHook(), description)
26+
Experiment(
27+
π,
28+
env,
29+
StopAfterStep(300, is_show_progress = !haskey(ENV, "CI")),
30+
EmptyHook(),
31+
description,
32+
)
2733
end
2834

2935
#+ tangle=false
3036
ex = E`JuliaRL_TabularCFR_OpenSpiel(kuhn_poker)`
31-
run(ex)
37+
run(ex)

docs/experiments/experiments/DQN/Dopamine_DQN_Atari.jl

Lines changed: 35 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -79,39 +79,35 @@ function atari_env_factory(
7979
repeat_action_probability = 0.25,
8080
n_replica = nothing,
8181
)
82-
init(seed) =
83-
RewardOverriddenEnv(
84-
StateCachedEnv(
85-
StateTransformedEnv(
86-
AtariEnv(;
87-
name = string(name),
88-
grayscale_obs = true,
89-
noop_max = 30,
90-
frame_skip = 4,
91-
terminal_on_life_loss = false,
92-
repeat_action_probability = repeat_action_probability,
93-
max_num_frames_per_episode = n_frames * max_episode_steps,
94-
color_averaging = false,
95-
full_action_space = false,
96-
seed = seed,
97-
);
98-
state_mapping=Chain(
99-
ResizeImage(state_size...),
100-
StackFrames(state_size..., n_frames)
101-
),
102-
state_space_mapping= _ -> Space(fill(0..256, state_size..., n_frames))
103-
)
82+
init(seed) = RewardOverriddenEnv(
83+
StateCachedEnv(
84+
StateTransformedEnv(
85+
AtariEnv(;
86+
name = string(name),
87+
grayscale_obs = true,
88+
noop_max = 30,
89+
frame_skip = 4,
90+
terminal_on_life_loss = false,
91+
repeat_action_probability = repeat_action_probability,
92+
max_num_frames_per_episode = n_frames * max_episode_steps,
93+
color_averaging = false,
94+
full_action_space = false,
95+
seed = seed,
96+
);
97+
state_mapping = Chain(
98+
ResizeImage(state_size...),
99+
StackFrames(state_size..., n_frames),
100+
),
101+
state_space_mapping = _ -> Space(fill(0..256, state_size..., n_frames)),
104102
),
105-
r -> clamp(r, -1, 1)
106-
)
103+
),
104+
r -> clamp(r, -1, 1),
105+
)
107106

108107
if isnothing(n_replica)
109108
init(seed)
110109
else
111-
envs = [
112-
init(isnothing(seed) ? nothing : hash(seed + i))
113-
for i in 1:n_replica
114-
]
110+
envs = [init(isnothing(seed) ? nothing : hash(seed + i)) for i in 1:n_replica]
115111
states = Flux.batch(state.(envs))
116112
rewards = reward.(envs)
117113
terminals = is_terminated.(envs)
@@ -172,7 +168,7 @@ function RL.Experiment(
172168
::Val{:Atari},
173169
name::AbstractString;
174170
save_dir = nothing,
175-
seed = nothing
171+
seed = nothing,
176172
)
177173
rng = Random.GLOBAL_RNG
178174
Random.seed!(rng, seed)
@@ -190,7 +186,7 @@ function RL.Experiment(
190186
name,
191187
STATE_SIZE,
192188
N_FRAMES;
193-
seed = isnothing(seed) ? nothing : hash(seed + 1)
189+
seed = isnothing(seed) ? nothing : hash(seed + 1),
194190
)
195191
N_ACTIONS = length(action_space(env))
196192
init = glorot_uniform(rng)
@@ -254,17 +250,15 @@ function RL.Experiment(
254250
end,
255251
DoEveryNEpisode() do t, agent, env
256252
with_logger(lg) do
257-
@info "training" episode_length = step_per_episode.steps[end] reward = reward_per_episode.rewards[end] log_step_increment = 0
253+
@info "training" episode_length = step_per_episode.steps[end] reward =
254+
reward_per_episode.rewards[end] log_step_increment = 0
258255
end
259256
end,
260-
DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env
257+
DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env
261258
@info "evaluating agent at $t step..."
262259
p = agent.policy
263260
p = @set p.explorer = EpsilonGreedyExplorer(0.001; rng = rng) # set evaluation epsilon
264-
h = ComposedHook(
265-
TotalOriginalRewardPerEpisode(),
266-
StepsPerEpisode(),
267-
)
261+
h = ComposedHook(TotalOriginalRewardPerEpisode(), StepsPerEpisode())
268262
s = @elapsed run(
269263
p,
270264
atari_env_factory(
@@ -281,16 +275,18 @@ function RL.Experiment(
281275
avg_score = mean(h[1].rewards[1:end-1])
282276
avg_length = mean(h[2].steps[1:end-1])
283277

284-
@info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = avg_score
278+
@info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score =
279+
avg_score
285280
with_logger(lg) do
286-
@info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = 0
281+
@info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment =
282+
0
287283
end
288284
end,
289285
)
290286

291287
stop_condition = StopAfterStep(
292288
haskey(ENV, "CI") ? 1_000 : 50_000_000,
293-
is_show_progress=!haskey(ENV, "CI")
289+
is_show_progress = !haskey(ENV, "CI"),
294290
)
295291
Experiment(agent, env, stop_condition, hook, "# DQN <-> Atari($name)")
296292
end

docs/experiments/experiments/DQN/Dopamine_IQN_Atari.jl

Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -84,39 +84,35 @@ function atari_env_factory(
8484
repeat_action_probability = 0.25,
8585
n_replica = nothing,
8686
)
87-
init(seed) =
88-
RewardOverriddenEnv(
89-
StateCachedEnv(
90-
StateTransformedEnv(
91-
AtariEnv(;
92-
name = string(name),
93-
grayscale_obs = true,
94-
noop_max = 30,
95-
frame_skip = 4,
96-
terminal_on_life_loss = false,
97-
repeat_action_probability = repeat_action_probability,
98-
max_num_frames_per_episode = n_frames * max_episode_steps,
99-
color_averaging = false,
100-
full_action_space = false,
101-
seed = seed,
102-
);
103-
state_mapping=Chain(
104-
ResizeImage(state_size...),
105-
StackFrames(state_size..., n_frames)
106-
),
107-
state_space_mapping= _ -> Space(fill(0..256, state_size..., n_frames))
108-
)
87+
init(seed) = RewardOverriddenEnv(
88+
StateCachedEnv(
89+
StateTransformedEnv(
90+
AtariEnv(;
91+
name = string(name),
92+
grayscale_obs = true,
93+
noop_max = 30,
94+
frame_skip = 4,
95+
terminal_on_life_loss = false,
96+
repeat_action_probability = repeat_action_probability,
97+
max_num_frames_per_episode = n_frames * max_episode_steps,
98+
color_averaging = false,
99+
full_action_space = false,
100+
seed = seed,
101+
);
102+
state_mapping = Chain(
103+
ResizeImage(state_size...),
104+
StackFrames(state_size..., n_frames),
105+
),
106+
state_space_mapping = _ -> Space(fill(0..256, state_size..., n_frames)),
109107
),
110-
r -> clamp(r, -1, 1)
111-
)
108+
),
109+
r -> clamp(r, -1, 1),
110+
)
112111

113112
if isnothing(n_replica)
114113
init(seed)
115114
else
116-
envs = [
117-
init(isnothing(seed) ? nothing : hash(seed + i))
118-
for i in 1:n_replica
119-
]
115+
envs = [init(isnothing(seed) ? nothing : hash(seed + i)) for i in 1:n_replica]
120116
states = Flux.batch(state.(envs))
121117
rewards = reward.(envs)
122118
terminals = is_terminated.(envs)
@@ -195,7 +191,12 @@ function RL.Experiment(
195191
N_FRAMES = 4
196192
STATE_SIZE = (84, 84)
197193

198-
env = atari_env_factory(name, STATE_SIZE, N_FRAMES; seed = isnothing(seed) ? nothing : hash(seed + 2))
194+
env = atari_env_factory(
195+
name,
196+
STATE_SIZE,
197+
N_FRAMES;
198+
seed = isnothing(seed) ? nothing : hash(seed + 2),
199+
)
199200
N_ACTIONS = length(action_space(env))
200201
Nₑₘ = 64
201202

@@ -274,7 +275,7 @@ function RL.Experiment(
274275
steps_per_episode.steps[end] log_step_increment = 0
275276
end
276277
end,
277-
DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env
278+
DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env
278279
@info "evaluating agent at $t step..."
279280
p = agent.policy
280281
p = @set p.explorer = EpsilonGreedyExplorer(0.001; rng = rng) # set evaluation epsilon
@@ -286,7 +287,7 @@ function RL.Experiment(
286287
STATE_SIZE,
287288
N_FRAMES,
288289
MAX_EPISODE_STEPS_EVAL;
289-
seed = isnothing(seed) ? nothing : hash(seed + t)
290+
seed = isnothing(seed) ? nothing : hash(seed + t),
290291
),
291292
StopAfterStep(125_000; is_show_progress = false),
292293
h,
@@ -295,16 +296,18 @@ function RL.Experiment(
295296
avg_score = mean(h[1].rewards[1:end-1])
296297
avg_length = mean(h[2].steps[1:end-1])
297298

298-
@info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = avg_score
299+
@info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score =
300+
avg_score
299301
with_logger(lg) do
300-
@info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = 0
302+
@info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment =
303+
0
301304
end
302305
end,
303306
)
304307

305308
stop_condition = StopAfterStep(
306309
haskey(ENV, "CI") ? 10_000 : 50_000_000,
307-
is_show_progress=!haskey(ENV, "CI")
310+
is_show_progress = !haskey(ENV, "CI"),
308311
)
309312
Experiment(agent, env, stop_condition, hook, "# IQN <-> Atari($name)")
310313
end

docs/experiments/experiments/DQN/Dopamine_Rainbow_Atari.jl

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -83,39 +83,35 @@ function atari_env_factory(
8383
repeat_action_probability = 0.25,
8484
n_replica = nothing,
8585
)
86-
init(seed) =
87-
RewardOverriddenEnv(
88-
StateCachedEnv(
89-
StateTransformedEnv(
90-
AtariEnv(;
91-
name = string(name),
92-
grayscale_obs = true,
93-
noop_max = 30,
94-
frame_skip = 4,
95-
terminal_on_life_loss = false,
96-
repeat_action_probability = repeat_action_probability,
97-
max_num_frames_per_episode = n_frames * max_episode_steps,
98-
color_averaging = false,
99-
full_action_space = false,
100-
seed = seed,
101-
);
102-
state_mapping=Chain(
103-
ResizeImage(state_size...),
104-
StackFrames(state_size..., n_frames)
105-
),
106-
state_space_mapping= _ -> Space(fill(0..256, state_size..., n_frames))
107-
)
86+
init(seed) = RewardOverriddenEnv(
87+
StateCachedEnv(
88+
StateTransformedEnv(
89+
AtariEnv(;
90+
name = string(name),
91+
grayscale_obs = true,
92+
noop_max = 30,
93+
frame_skip = 4,
94+
terminal_on_life_loss = false,
95+
repeat_action_probability = repeat_action_probability,
96+
max_num_frames_per_episode = n_frames * max_episode_steps,
97+
color_averaging = false,
98+
full_action_space = false,
99+
seed = seed,
100+
);
101+
state_mapping = Chain(
102+
ResizeImage(state_size...),
103+
StackFrames(state_size..., n_frames),
104+
),
105+
state_space_mapping = _ -> Space(fill(0..256, state_size..., n_frames)),
108106
),
109-
r -> clamp(r, -1, 1)
110-
)
107+
),
108+
r -> clamp(r, -1, 1),
109+
)
111110

112111
if isnothing(n_replica)
113112
init(seed)
114113
else
115-
envs = [
116-
init(isnothing(seed) ? nothing : hash(seed + i))
117-
for i in 1:n_replica
118-
]
114+
envs = [init(isnothing(seed) ? nothing : hash(seed + i)) for i in 1:n_replica]
119115
states = Flux.batch(state.(envs))
120116
rewards = reward.(envs)
121117
terminals = is_terminated.(envs)
@@ -191,7 +187,12 @@ function RL.Experiment(
191187

192188
N_FRAMES = 4
193189
STATE_SIZE = (84, 84)
194-
env = atari_env_factory(name, STATE_SIZE, N_FRAMES; seed = isnothing(seed) ? nothing : hash(seed + 1))
190+
env = atari_env_factory(
191+
name,
192+
STATE_SIZE,
193+
N_FRAMES;
194+
seed = isnothing(seed) ? nothing : hash(seed + 1),
195+
)
195196
N_ACTIONS = length(action_space(env))
196197
N_ATOMS = 51
197198
init = glorot_uniform(rng)
@@ -262,7 +263,7 @@ function RL.Experiment(
262263
steps_per_episode.steps[end] log_step_increment = 0
263264
end
264265
end,
265-
DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env
266+
DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env
266267
@info "evaluating agent at $t step..."
267268
p = agent.policy
268269
p = @set p.explorer = EpsilonGreedyExplorer(0.001; rng = rng) # set evaluation epsilon
@@ -282,16 +283,18 @@ function RL.Experiment(
282283
avg_length = mean(h[2].steps[1:end-1])
283284
avg_score = mean(h[1].rewards[1:end-1])
284285

285-
@info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = avg_score
286+
@info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score =
287+
avg_score
286288
with_logger(lg) do
287-
@info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = 0
289+
@info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment =
290+
0
288291
end
289292
end,
290293
)
291294

292295
stop_condition = StopAfterStep(
293296
haskey(ENV, "CI") ? 10_000 : 50_000_000,
294-
is_show_progress=!haskey(ENV, "CI")
297+
is_show_progress = !haskey(ENV, "CI"),
295298
)
296299

297300
Experiment(agent, env, stop_condition, hook, "# Rainbow <-> Atari($name)")

docs/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ function RL.Experiment(
5151
state = Vector{Float32} => (ns,),
5252
),
5353
)
54-
stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
54+
stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
5555
hook = TotalRewardPerEpisode()
5656
Experiment(policy, env, stop_condition, hook, "# BasicDQN <-> CartPole")
5757
end

0 commit comments

Comments (0)