@@ -22,13 +22,13 @@ import ReinforcementLearningBase: RLBase
q_values = NN(rand(Float32, 2))
@test size(q_values) == (3,)

- gs = gradient(params(NN)) do
+ gs = gradient(NN) do
sum(NN(rand(Float32, 2, 5)))
end

- old_params = deepcopy(collect(params(NN).params))
+ old_params = deepcopy(collect(Flux.trainable(NN).params))
push!(NN, gs)
- new_params = collect(params(NN).params)
+ new_params = collect(Flux.trainable(NN).params)

@test old_params != new_params
end
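Editor's note: the pattern in this hunk is Zygote's move from implicit `Flux.params` gradients to explicit gradients taken with respect to the model itself. A minimal sketch of the two styles, using a plain `Chain` rather than the package's approximator type (the names `m` and `x` are illustrative, not part of this diff):

using Flux

# Illustrative model and batch; not the NN used in the test above.
m = Chain(Dense(2 => 16, relu), Dense(16 => 3))
x = rand(Float32, 2, 5)

# Old implicit style (Zygote "params" mode), as removed by this diff:
# gs = gradient(() -> sum(m(x)), Flux.params(m))

# New explicit style: differentiate with respect to the model itself.
gs = gradient(model -> sum(model(x)), m)
# gs[1] is a NamedTuple that mirrors the structure of `m`.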
@@ -72,42 +72,40 @@ import ReinforcementLearningBase: RLBase
end
@testset "Correctness of gradients" begin
@testset "One action per state" begin
- @test Flux.params(gn) == Flux.Params([gn.pre.weight, gn.pre.bias, gn.μ.weight, gn.μ.bias, gn.σ.weight, gn.σ.bias])
+ @test Flux.trainable(gn).pre == gn.pre
+ @test Flux.trainable(gn).μ == gn.μ
+ @test Flux.trainable(gn).σ == gn.σ
action_saver = Matrix[]
- g = Flux.gradient(Flux.params(gn)) do
- a, logp = gn(state, is_sampling = true, is_return_log_prob = true)
+ g = Flux.gradient(gn) do model
+ a, logp = model(state, is_sampling = true, is_return_log_prob = true)
ChainRulesCore.ignore_derivatives() do
push!(action_saver, a)
end
sum(logp)
end
- g2 = Flux.gradient(Flux.params(gn)) do
- logp = gn(state, only(action_saver))
+ g2 = Flux.gradient(gn) do model
+ logp = model(state, only(action_saver))
sum(logp)
end
# Check that gradients are identical
- for (grad1, grad2) in zip(g, g2)
- @test grad1 ≈ grad2
- end
+ @test g == g2
end
@testset "Multiple actions per state" begin
# Same with multiple actions sampled
action_saver = []
state = unsqueeze(state, dims = 2)
- g = Flux.gradient(Flux.params(gn)) do
- a, logp = gn(state, 3)
+ g1 = Flux.gradient(gn) do model
+ a, logp = model(state, 3)
ChainRulesCore.ignore_derivatives() do
push!(action_saver, a)
end
sum(logp)
end
- g2 = Flux.gradient(Flux.params(gn)) do
- logp = gn(state, only(action_saver))
+ g2 = Flux.gradient(gn) do model
+ logp = model(state, only(action_saver))
sum(logp)
end
- for (grad1, grad2) in zip(g, g2)
- @test grad1 ≈ grad2
- end
+ @test g1 == g2
end
end
end
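Editor's note: the other recurring change in this hunk is the gradient comparison. Implicit-mode results were `Zygote.Grads` objects that the old tests compared entry by entry, whereas explicit-mode results are plain nested NamedTuples, so two gradients of the same model can be compared with a single `==`. A small sketch (illustrative names, deterministic loss so the equality holds exactly):

using Flux

m = Dense(3 => 2)
x = rand(Float32, 3, 4)

# Two explicit gradients of the same deterministic loss.
g1 = gradient(model -> sum(model(x)), m)
g2 = gradient(model -> sum(model(x)), m)

# NamedTuples and arrays compare element-wise, so one test suffices,
# replacing the old per-parameter loop over zipped Grads.
g1 == g2   # true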
@@ -117,7 +115,6 @@ import ReinforcementLearningBase: RLBase
gn = GaussianNetwork(Dense(20, 15), Dense(15, 10), Dense(15, 10, softplus)) |> gpu
state = rand(Float32, 20, 3) |> gpu # batch of 3 states
@testset "Forward pass compatibility" begin
- @test Flux.params(gn) == Flux.Params([gn.pre.weight, gn.pre.bias, gn.μ.weight, gn.μ.bias, gn.σ.weight, gn.σ.bias])
m, L = gn(state)
@test size(m) == size(L) == (10, 3)
a, logp = gn(CUDA.CURAND.RNG(), state, is_sampling = true, is_return_log_prob = true)
@@ -134,15 +131,15 @@ import ReinforcementLearningBase: RLBase
@testset "Backward pass compatibility" begin
@testset "One action sampling" begin
action_saver = CuMatrix[]
- g = Flux.gradient(Flux.params(gn)) do
- a, logp = gn(CUDA.CURAND.RNG(), state, is_sampling = true, is_return_log_prob = true)
+ g = Flux.gradient(gn) do model
+ a, logp = model(CUDA.CURAND.RNG(), state, is_sampling = true, is_return_log_prob = true)
ChainRulesCore.ignore_derivatives() do
push!(action_saver, a)
end
sum(logp)
end
- g2 = Flux.gradient(Flux.params(gn)) do
- logp = gn(state, only(action_saver))
+ g2 = Flux.gradient(gn) do model
+ logp = model(state, only(action_saver))
sum(logp)
end
# Check that gradients are identical
@@ -153,15 +150,15 @@ import ReinforcementLearningBase: RLBase
@testset "Multiple actions sampling" begin
action_saver = []
state = unsqueeze(state, dims = 2)
- g = Flux.gradient(Flux.params(gn)) do
+ g = Flux.gradient(gn) do
a, logp = gn(CUDA.CURAND.RNG(), state, 3)
ChainRulesCore.ignore_derivatives() do
push!(action_saver, a)
end
sum(logp)
end
- g2 = Flux.gradient(Flux.params(gn)) do
- logp = gn(state, only(action_saver))
+ g2 = Flux.gradient(gn) do model
+ logp = model(state, only(action_saver))
sum(logp)
end
for (grad1, grad2) in zip(g, g2)
@@ -202,7 +199,10 @@ import ReinforcementLearningBase: RLBase
μ = Dense(15, 10)
Σ = Dense(15, 10 * 11 ÷ 2)
gn = CovGaussianNetwork(pre, μ, Σ)
- @test Flux.params(gn) == Flux.Params([pre.weight, pre.bias, μ.weight, μ.bias, Σ.weight, Σ.bias])
+ @test Flux.trainable(gn).pre == pre
+ @test Flux.trainable(gn).μ == μ
+ @test Flux.trainable(gn).Σ == Σ
+
state = rand(Float32, 20, 3) # batch of 3 states
# Check that it works in 2D
m, L = gn(state)
@@ -233,35 +233,34 @@ import ReinforcementLearningBase: RLBase
logp_truth = [logpdf(mvn, a) for (mvn, a) in zip(mvnormals, eachslice(as, dims = 3))]
@test stack(logp_truth; dims = 2) ≈ dropdims(logps, dims = 1) # test against ground truth
action_saver = []
- g = Flux.gradient(Flux.params(gn)) do
- a, logp = gn(Flux.unsqueeze(state, dims = 2), is_sampling = true, is_return_log_prob = true)
+ g1 = Flux.gradient(gn) do model
+ a, logp = model(Flux.unsqueeze(state, dims = 2), is_sampling = true, is_return_log_prob = true)
ChainRulesCore.ignore_derivatives() do
push!(action_saver, a)
end
mean(logp)
end
- g2 = Flux.gradient(Flux.params(gn)) do
- logp = gn(Flux.unsqueeze(state, dims = 2), only(action_saver))
+ g2 = Flux.gradient(gn) do model
+ logp = model(Flux.unsqueeze(state, dims = 2), only(action_saver))
mean(logp)
end
- for (grad1, grad2) in zip(g, g2)
- @test grad1 ≈ grad2
- end
+ @test g1 == g2
+
empty!(action_saver)
- g3 = Flux.gradient(Flux.params(gn)) do
- a, logp = gn(Flux.unsqueeze(state, dims = 2), 3)
+
+ g3 = Flux.gradient(gn) do model
+ a, logp = model(Flux.unsqueeze(state, dims = 2), is_sampling = true, is_return_log_prob = true)
ChainRulesCore.ignore_derivatives() do
push!(action_saver, a)
end
mean(logp)
end
- g4 = Flux.gradient(Flux.params(gn)) do
- logp = gn(Flux.unsqueeze(state, dims = 2), only(action_saver))
+ g4 = Flux.gradient(gn) do model
+ logp = model(Flux.unsqueeze(state, dims = 2), only(action_saver))
mean(logp)
end
- for (grad1, grad2) in zip(g4, g3)
- @test grad1 ≈ grad2
- end
+
+ @test g4 == g3
end
@testset "CUDA" begin
if (@isdefined CUDA) && CUDA.functional()
@@ -271,7 +270,6 @@ import ReinforcementLearningBase: RLBase
μ = Dense(15, 10) |> gpu
Σ = Dense(15, 10 * 11 ÷ 2) |> gpu
gn = CovGaussianNetwork(pre, μ, Σ)
- @test Flux.params(gn) == Flux.Params([pre.weight, pre.bias, μ.weight, μ.bias, Σ.weight, Σ.bias])
state = rand(Float32, 20, 3) |> gpu # batch of 3 states
m, L = gn(Flux.unsqueeze(state, dims = 2))
@test size(m) == (10, 1, 3)
@@ -292,31 +290,31 @@ import ReinforcementLearningBase: RLBase
logp_truth = [logpdf(mvn, cpu(a)) for (mvn, a) in zip(mvnormals, eachslice(as, dims = 3))]
@test reduce(hcat, collect(logp_truth)) ≈ dropdims(cpu(logps); dims = 1) # test against ground truth
action_saver = []
- g = Flux.gradient(Flux.params(gn)) do
- a, logp = gn(rng, Flux.unsqueeze(state, dims = 2), is_sampling = true, is_return_log_prob = true)
+ g = Flux.gradient(gn) do model
+ a, logp = model(rng, Flux.unsqueeze(state, dims = 2), is_sampling = true, is_return_log_prob = true)
ChainRulesCore.ignore_derivatives() do
push!(action_saver, a)
end
mean(logp)
end

- g2 = Flux.gradient(Flux.params(gn)) do
- logp = gn(Flux.unsqueeze(state, dims = 2), only(action_saver))
+ g2 = Flux.gradient(gn) do model
+ logp = model(Flux.unsqueeze(state, dims = 2), only(action_saver))
mean(logp)
end
for (grad1, grad2) in zip(g, g2)
@test grad1 ≈ grad2
end
empty!(action_saver)
- g3 = Flux.gradient(Flux.params(gn)) do
- a, logp = gn(rng, Flux.unsqueeze(state, dims = 2), 3)
+ g3 = Flux.gradient(gn) do model
+ a, logp = model(rng, Flux.unsqueeze(state, dims = 2), 3)
ChainRulesCore.ignore_derivatives() do
push!(action_saver, a)
end
mean(logp)
end
- g4 = Flux.gradient(Flux.params(gn)) do
- logp = gn(Flux.unsqueeze(state, dims = 2), only(action_saver))
+ g4 = Flux.gradient(gn) do model
+ logp = model(Flux.unsqueeze(state, dims = 2), only(action_saver))
mean(logp)
end
for (grad1, grad2) in zip(g4, g3)
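Editor's note on the `Flux.trainable(gn).pre == gn.pre` style assertions that replace the `Flux.params(gn) == Flux.Params([...])` checks: for a struct registered with Functors, the default `trainable` returns a NamedTuple of the child fields, which is what these tests inspect. A hedged sketch with a stand-in struct (`TinyGaussian` is illustrative, not the package's `GaussianNetwork`):

using Flux

# Stand-in for a three-headed network; fields mirror pre/μ/σ.
struct TinyGaussian{P,M,S}
    pre::P
    μ::M
    σ::S
end
Flux.@functor TinyGaussian

tg = TinyGaussian(Dense(4 => 8, relu), Dense(8 => 2), Dense(8 => 2, softplus))

# With the default `trainable`, the children are exposed by field name.
Flux.trainable(tg).pre === tg.pre   # true
Flux.trainable(tg).μ === tg.μ       # true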