Skip to content

Commit 6256023

Browse files
Format .jl files
1 parent 596cf10 commit 6256023

File tree

105 files changed

+1427
-1189
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

105 files changed

+1427
-1189
lines changed

docs/experiments/experiments/CFR/JuliaRL_DeepCFR_OpenSpiel.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,5 +61,11 @@ function RL.Experiment(
6161
batch_size_Π = 2048,
6262
initializer = glorot_normal(CUDA.CURAND.default_rng()),
6363
)
64-
Experiment(p, env, StopAfterStep(500, is_show_progress=!haskey(ENV, "CI")), EmptyHook(), "# run DeepcCFR on leduc_poker")
65-
end
64+
Experiment(
65+
p,
66+
env,
67+
StopAfterStep(500, is_show_progress = !haskey(ENV, "CI")),
68+
EmptyHook(),
69+
"# run DeepcCFR on leduc_poker",
70+
)
71+
end

docs/experiments/experiments/CFR/JuliaRL_TabularCFR_OpenSpiel.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,14 @@ function RL.Experiment(
2323
π = TabularCFRPolicy(; rng = rng)
2424

2525
description = "# Play `$game` in OpenSpiel with TabularCFRPolicy"
26-
Experiment(π, env, StopAfterStep(300, is_show_progress=!haskey(ENV, "CI")), EmptyHook(), description)
26+
Experiment(
27+
π,
28+
env,
29+
StopAfterStep(300, is_show_progress = !haskey(ENV, "CI")),
30+
EmptyHook(),
31+
description,
32+
)
2733
end
2834

2935
ex = E`JuliaRL_TabularCFR_OpenSpiel(kuhn_poker)`
30-
run(ex)
36+
run(ex)

docs/experiments/experiments/DQN/Dopamine_DQN_Atari.jl

Lines changed: 35 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -79,39 +79,35 @@ function atari_env_factory(
7979
repeat_action_probability = 0.25,
8080
n_replica = nothing,
8181
)
82-
init(seed) =
83-
RewardOverriddenEnv(
84-
StateCachedEnv(
85-
StateTransformedEnv(
86-
AtariEnv(;
87-
name = string(name),
88-
grayscale_obs = true,
89-
noop_max = 30,
90-
frame_skip = 4,
91-
terminal_on_life_loss = false,
92-
repeat_action_probability = repeat_action_probability,
93-
max_num_frames_per_episode = n_frames * max_episode_steps,
94-
color_averaging = false,
95-
full_action_space = false,
96-
seed = seed,
97-
);
98-
state_mapping=Chain(
99-
ResizeImage(state_size...),
100-
StackFrames(state_size..., n_frames)
101-
),
102-
state_space_mapping= _ -> Space(fill(0..256, state_size..., n_frames))
103-
)
82+
init(seed) = RewardOverriddenEnv(
83+
StateCachedEnv(
84+
StateTransformedEnv(
85+
AtariEnv(;
86+
name = string(name),
87+
grayscale_obs = true,
88+
noop_max = 30,
89+
frame_skip = 4,
90+
terminal_on_life_loss = false,
91+
repeat_action_probability = repeat_action_probability,
92+
max_num_frames_per_episode = n_frames * max_episode_steps,
93+
color_averaging = false,
94+
full_action_space = false,
95+
seed = seed,
96+
);
97+
state_mapping = Chain(
98+
ResizeImage(state_size...),
99+
StackFrames(state_size..., n_frames),
100+
),
101+
state_space_mapping = _ -> Space(fill(0..256, state_size..., n_frames)),
104102
),
105-
r -> clamp(r, -1, 1)
106-
)
103+
),
104+
r -> clamp(r, -1, 1),
105+
)
107106

108107
if isnothing(n_replica)
109108
init(seed)
110109
else
111-
envs = [
112-
init(isnothing(seed) ? nothing : hash(seed + i))
113-
for i in 1:n_replica
114-
]
110+
envs = [init(isnothing(seed) ? nothing : hash(seed + i)) for i in 1:n_replica]
115111
states = Flux.batch(state.(envs))
116112
rewards = reward.(envs)
117113
terminals = is_terminated.(envs)
@@ -172,7 +168,7 @@ function RL.Experiment(
172168
::Val{:Atari},
173169
name::AbstractString;
174170
save_dir = nothing,
175-
seed = nothing
171+
seed = nothing,
176172
)
177173
rng = Random.GLOBAL_RNG
178174
Random.seed!(rng, seed)
@@ -190,7 +186,7 @@ function RL.Experiment(
190186
name,
191187
STATE_SIZE,
192188
N_FRAMES;
193-
seed = isnothing(seed) ? nothing : hash(seed + 1)
189+
seed = isnothing(seed) ? nothing : hash(seed + 1),
194190
)
195191
N_ACTIONS = length(action_space(env))
196192
init = glorot_uniform(rng)
@@ -254,17 +250,15 @@ function RL.Experiment(
254250
end,
255251
DoEveryNEpisode() do t, agent, env
256252
with_logger(lg) do
257-
@info "training" episode_length = step_per_episode.steps[end] reward = reward_per_episode.rewards[end] log_step_increment = 0
253+
@info "training" episode_length = step_per_episode.steps[end] reward =
254+
reward_per_episode.rewards[end] log_step_increment = 0
258255
end
259256
end,
260-
DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env
257+
DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env
261258
@info "evaluating agent at $t step..."
262259
p = agent.policy
263260
p = @set p.explorer = EpsilonGreedyExplorer(0.001; rng = rng) # set evaluation epsilon
264-
h = ComposedHook(
265-
TotalOriginalRewardPerEpisode(),
266-
StepsPerEpisode(),
267-
)
261+
h = ComposedHook(TotalOriginalRewardPerEpisode(), StepsPerEpisode())
268262
s = @elapsed run(
269263
p,
270264
atari_env_factory(
@@ -281,16 +275,18 @@ function RL.Experiment(
281275
avg_score = mean(h[1].rewards[1:end-1])
282276
avg_length = mean(h[2].steps[1:end-1])
283277

284-
@info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = avg_score
278+
@info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score =
279+
avg_score
285280
with_logger(lg) do
286-
@info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = 0
281+
@info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment =
282+
0
287283
end
288284
end,
289285
)
290286

291287
stop_condition = StopAfterStep(
292288
haskey(ENV, "CI") ? 1_000 : 50_000_000,
293-
is_show_progress=!haskey(ENV, "CI")
289+
is_show_progress = !haskey(ENV, "CI"),
294290
)
295291
Experiment(agent, env, stop_condition, hook, "# DQN <-> Atari($name)")
296292
end

docs/experiments/experiments/DQN/Dopamine_IQN_Atari.jl

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -84,39 +84,35 @@ function atari_env_factory(
8484
repeat_action_probability = 0.25,
8585
n_replica = nothing,
8686
)
87-
init(seed) =
88-
RewardOverriddenEnv(
89-
StateCachedEnv(
90-
StateTransformedEnv(
91-
AtariEnv(;
92-
name = string(name),
93-
grayscale_obs = true,
94-
noop_max = 30,
95-
frame_skip = 4,
96-
terminal_on_life_loss = false,
97-
repeat_action_probability = repeat_action_probability,
98-
max_num_frames_per_episode = n_frames * max_episode_steps,
99-
color_averaging = false,
100-
full_action_space = false,
101-
seed = seed,
102-
);
103-
state_mapping=Chain(
104-
ResizeImage(state_size...),
105-
StackFrames(state_size..., n_frames)
106-
),
107-
state_space_mapping= _ -> Space(fill(0..256, state_size..., n_frames))
108-
)
87+
init(seed) = RewardOverriddenEnv(
88+
StateCachedEnv(
89+
StateTransformedEnv(
90+
AtariEnv(;
91+
name = string(name),
92+
grayscale_obs = true,
93+
noop_max = 30,
94+
frame_skip = 4,
95+
terminal_on_life_loss = false,
96+
repeat_action_probability = repeat_action_probability,
97+
max_num_frames_per_episode = n_frames * max_episode_steps,
98+
color_averaging = false,
99+
full_action_space = false,
100+
seed = seed,
101+
);
102+
state_mapping = Chain(
103+
ResizeImage(state_size...),
104+
StackFrames(state_size..., n_frames),
105+
),
106+
state_space_mapping = _ -> Space(fill(0..256, state_size..., n_frames)),
109107
),
110-
r -> clamp(r, -1, 1)
111-
)
108+
),
109+
r -> clamp(r, -1, 1),
110+
)
112111

113112
if isnothing(n_replica)
114113
init(seed)
115114
else
116-
envs = [
117-
init(isnothing(seed) ? nothing : hash(seed + i))
118-
for i in 1:n_replica
119-
]
115+
envs = [init(isnothing(seed) ? nothing : hash(seed + i)) for i in 1:n_replica]
120116
states = Flux.batch(state.(envs))
121117
rewards = reward.(envs)
122118
terminals = is_terminated.(envs)
@@ -195,7 +191,12 @@ function RL.Experiment(
195191
N_FRAMES = 4
196192
STATE_SIZE = (84, 84)
197193

198-
env = atari_env_factory(name, STATE_SIZE, N_FRAMES; seed = isnothing(seed) ? nothing : hash(seed + 2))
194+
env = atari_env_factory(
195+
name,
196+
STATE_SIZE,
197+
N_FRAMES;
198+
seed = isnothing(seed) ? nothing : hash(seed + 2),
199+
)
199200
N_ACTIONS = length(action_space(env))
200201
Nₑₘ = 64
201202

@@ -250,7 +251,7 @@ function RL.Experiment(
250251
),
251252
),
252253
trajectory = CircularArraySARTTrajectory(
253-
capacity = haskey(ENV, "CI") : 1_000 : 1_000_000,
254+
capacity = haskey(ENV, "CI"):1_000:1_000_000,
254255
state = Matrix{Float32} => STATE_SIZE,
255256
),
256257
)
@@ -274,7 +275,7 @@ function RL.Experiment(
274275
steps_per_episode.steps[end] log_step_increment = 0
275276
end
276277
end,
277-
DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env
278+
DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env
278279
@info "evaluating agent at $t step..."
279280
p = agent.policy
280281
p = @set p.explorer = EpsilonGreedyExplorer(0.001; rng = rng) # set evaluation epsilon
@@ -286,7 +287,7 @@ function RL.Experiment(
286287
STATE_SIZE,
287288
N_FRAMES,
288289
MAX_EPISODE_STEPS_EVAL;
289-
seed = isnothing(seed) ? nothing : hash(seed + t)
290+
seed = isnothing(seed) ? nothing : hash(seed + t),
290291
),
291292
StopAfterStep(125_000; is_show_progress = false),
292293
h,
@@ -295,16 +296,18 @@ function RL.Experiment(
295296
avg_score = mean(h[1].rewards[1:end-1])
296297
avg_length = mean(h[2].steps[1:end-1])
297298

298-
@info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = avg_score
299+
@info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score =
300+
avg_score
299301
with_logger(lg) do
300-
@info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = 0
302+
@info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment =
303+
0
301304
end
302305
end,
303306
)
304307

305308
stop_condition = StopAfterStep(
306309
haskey(ENV, "CI") ? 10_000 : 50_000_000,
307-
is_show_progress=!haskey(ENV, "CI")
310+
is_show_progress = !haskey(ENV, "CI"),
308311
)
309312
Experiment(agent, env, stop_condition, hook, "# IQN <-> Atari($name)")
310313
end

0 commit comments

Comments
 (0)