22
33import jax
44from jax import numpy as jnp , random , jit
5- from ngcsimlib .context import Context
65import numpy as np
76np .random .seed (42 )
87from ngclearn .components .synapses .modulated .REINFORCESynapse import REINFORCESynapse , gaussian_logpdf
9- from ngcsimlib .compilers import compile_command , wrap_command
108from numpy .testing import assert_array_equal
119
12- from ngcsimlib .compilers .process import Process , transition
13- from ngcsimlib .component import Component
14- from ngcsimlib .compartment import Compartment
15- from ngcsimlib .context import Context
10+ from ngclearn import Context , MethodProcess
1611
1712import jax
1813import jax .numpy as jnp
@@ -39,22 +34,16 @@ def test_REINFORCESynapse1():
3934 scalar_stddev = - 1.0
4035 )
4136
42- evolve_process = (Process ("evolve_proc" ) >> a .evolve )
43- ctx . wrap_and_add_command ( jit ( evolve_process . pure ), name = "adapt" )
37+ evolve_process = (MethodProcess ("evolve_proc" ) >> a .evolve )
38+ reset_process = ( MethodProcess ( "reset_proc" ) >> a . reset )
4439
45- reset_process = (Process ("reset_proc" ) >> a .reset )
46- ctx .wrap_and_add_command (jit (reset_process .pure ), name = "reset" )
47-
48- @Context .dynamicCommand
4940 def clamp_inputs (x ):
5041 a .inputs .set (x )
5142
52- @Context .dynamicCommand
5343 def clamp_rewards (x ):
5444 assert x .ndim == 1 , "Rewards must be a 1D array"
5545 a .rewards .set (x )
5646
57- @Context .dynamicCommand
5847 def clamp_weights (x ):
5948 a .weights .set (x )
6049
@@ -80,7 +69,7 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
8069 expected_weights = jnp .concatenate ([expected_weights_mu , expected_weights_logstd ], axis = - 1 )
8170 initial_ngclearn_weights = jnp .concatenate ([expected_weights_mu , expected_weights_logstd ], axis = - 1 )[None ]
8271 expected_gradient_list = []
83- ctx . reset ()
72+ reset_process . run ()
8473
8574 # Loop through 3 steps
8675 for step in range (10 ):
@@ -94,12 +83,12 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
9483 clamp_weights (initial_ngclearn_weights )
9584 clamp_rewards (outputs )
9685 clamp_inputs (inputs )
97- ctx . adapt (t = 1. , dt = dt )
98- print (f"[ngclearn] objective: { a .objective .value } " )
99- print (f"[ngclearn] weights: { a .weights .value } " )
100- print (f"[ngclearn] dWeights: { a .dWeights .value } " )
101- print (f"[ngclearn] step_count: { a .step_count .value } " )
102- print (f"[ngclearn] accumulated_gradients: { a .accumulated_gradients .value } " )
86+ evolve_process . run (t = 1. , dt = dt )
87+ print (f"[ngclearn] objective: { a .objective .get () } " )
88+ print (f"[ngclearn] weights: { a .weights .get () } " )
89+ print (f"[ngclearn] dWeights: { a .dWeights .get () } " )
90+ print (f"[ngclearn] step_count: { a .step_count .get () } " )
91+ print (f"[ngclearn] accumulated_gradients: { a .accumulated_gradients .get () } " )
10392 # -------- Expectation ---------
10493 print ("--------------" )
10594 expected_objective , expected_grads = grad_fn (
@@ -116,12 +105,12 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
116105 print (f"[Expectation] dWeights: { expected_grads } " )
117106 print (f"[Expectation] objective: { expected_objective } " )
118107 np .testing .assert_allclose (
119- a .dWeights .value [0 ],
108+ a .dWeights .get () [0 ],
120109 expected_grads ,
121110 atol = 1e-8
122111 )
123112 np .testing .assert_allclose (
124- a .objective .value ,
113+ a .objective .get () ,
125114 expected_objective ,
126115 atol = 1e-8
127116 )
@@ -131,13 +120,13 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
131120 decay_list = jnp .asarray ([decay ** i for i in range (len (expected_gradient_list ))])[::- 1 ]
132121 expected_accumulated_gradients = jnp .mean (jnp .stack (expected_gradient_list , 0 ) * decay_list [:, None , None ], axis = 0 )
133122 np .testing .assert_allclose (
134- a .accumulated_gradients .value [0 ],
123+ a .accumulated_gradients .get () [0 ],
135124 expected_accumulated_gradients ,
136125 atol = 1e-9
137126 )
138127
139128
140- # test_REINFORCESynapse1()
129+ test_REINFORCESynapse1 ()
141130
142131
143132def test_REINFORCESynapse2 ():
@@ -163,22 +152,16 @@ def test_REINFORCESynapse2():
163152 scalar_stddev = scalar_stddev
164153 )
165154
166- evolve_process = (Process ("evolve_proc" ) >> a .evolve )
167- ctx .wrap_and_add_command (jit (evolve_process .pure ), name = "adapt" )
168-
169- reset_process = (Process ("reset_proc" ) >> a .reset )
170- ctx .wrap_and_add_command (jit (reset_process .pure ), name = "reset" )
155+ evolve_process = (MethodProcess ("evolve_proc" ) >> a .evolve )
156+ reset_process = (MethodProcess ("reset_proc" ) >> a .reset )
171157
172- @Context .dynamicCommand
173158 def clamp_inputs (x ):
174159 a .inputs .set (x )
175160
176- @Context .dynamicCommand
177161 def clamp_rewards (x ):
178162 assert x .ndim == 1 , "Rewards must be a 1D array"
179163 a .rewards .set (x )
180164
181- @Context .dynamicCommand
182165 def clamp_weights (x ):
183166 a .weights .set (x )
184167
@@ -205,7 +188,7 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
205188 expected_weights = jnp .concatenate ([expected_weights_mu , expected_weights_logstd ], axis = - 1 )
206189 initial_ngclearn_weights = jnp .concatenate ([expected_weights_mu , expected_weights_logstd ], axis = - 1 )[None ]
207190 expected_gradient_list = []
208- ctx . reset ()
191+ reset_process . run ()
209192
210193 # Loop through 3 steps
211194 for step in range (10 ):
@@ -219,12 +202,12 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
219202 clamp_weights (initial_ngclearn_weights )
220203 clamp_rewards (outputs )
221204 clamp_inputs (inputs )
222- ctx . adapt (t = 1. , dt = dt )
223- print (f"[ngclearn] objective: { a .objective .value } " )
224- print (f"[ngclearn] weights: { a .weights .value } " )
225- print (f"[ngclearn] dWeights: { a .dWeights .value } " )
226- print (f"[ngclearn] step_count: { a .step_count .value } " )
227- print (f"[ngclearn] accumulated_gradients: { a .accumulated_gradients .value } " )
205+ evolve_process . run (t = 1. , dt = dt )
206+ print (f"[ngclearn] objective: { a .objective .get () } " )
207+ print (f"[ngclearn] weights: { a .weights .get () } " )
208+ print (f"[ngclearn] dWeights: { a .dWeights .get () } " )
209+ print (f"[ngclearn] step_count: { a .step_count .get () } " )
210+ print (f"[ngclearn] accumulated_gradients: { a .accumulated_gradients .get () } " )
228211 # -------- Expectation ---------
229212 print ("--------------" )
230213 expected_objective , expected_grads = grad_fn (
@@ -241,12 +224,12 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
241224 print (f"[Expectation] dWeights: { expected_grads } " )
242225 print (f"[Expectation] objective: { expected_objective } " )
243226 np .testing .assert_allclose (
244- a .dWeights .value [0 ],
227+ a .dWeights .get () [0 ],
245228 expected_grads ,
246229 atol = 1e-8
247230 )
248231 np .testing .assert_allclose (
249- a .objective .value ,
232+ a .objective .get () ,
250233 expected_objective ,
251234 atol = 1e-8
252235 )
@@ -256,10 +239,10 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
256239 decay_list = jnp .asarray ([decay ** i for i in range (len (expected_gradient_list ))])[::- 1 ]
257240 expected_accumulated_gradients = jnp .mean (jnp .stack (expected_gradient_list , 0 ) * decay_list [:, None , None ], axis = 0 )
258241 np .testing .assert_allclose (
259- a .accumulated_gradients .value [0 ],
242+ a .accumulated_gradients .get () [0 ],
260243 expected_accumulated_gradients ,
261244 atol = 1e-9
262245 )
263246
264- # test_REINFORCESynapse2()
247+ test_REINFORCESynapse2 ()
265248
0 commit comments