Skip to content

Commit ebebf74

Browse files
author
Alexander Ororbia
committed
Merge branch 'v3' of github.com:NACLab/ngc-learn into v3
2 parents 85d1282 + b0db87f commit ebebf74

File tree

2 files changed

+34
-67
lines changed

2 files changed

+34
-67
lines changed

ngclearn/components/synapses/modulated/REINFORCESynapse.py

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
# %%
22

33
from jax import random, numpy as jnp, jit
4-
from ngcsimlib.logger import info
5-
from ngcsimlib.compartment import Compartment
6-
from ngcsimlib.parser import compilable
4+
from ngclearn import compilable, Compartment
5+
76
from ngclearn.utils.model_utils import clip, d_clip
87
import jax
98
import jax.numpy as jnp
@@ -194,20 +193,19 @@ def evolve(self, dt):
194193
# @transition(output_compartments=["inputs", "outputs", "objective", "rewards", "dWeights", "accumulated_gradients", "step_count", "seed"])
195194
# @staticmethod
196195
@compilable
197-
def reset(self, batch_size, shape):
198-
preVals = jnp.zeros((batch_size, shape[0]))
199-
postVals = jnp.zeros((batch_size, shape[1]))
196+
def reset(self):
197+
preVals = jnp.zeros((self.batch_size, self.shape[0]))
198+
postVals = jnp.zeros((self.batch_size, self.shape[1]))
200199
inputs = preVals
201200
outputs = postVals
202201
objective = jnp.zeros(())
203-
rewards = jnp.zeros((batch_size,))
204-
dWeights = jnp.zeros(shape)
205-
accumulated_gradients = jnp.zeros((shape[0], shape[1] * 2))
202+
rewards = jnp.zeros((self.batch_size,))
203+
dWeights = jnp.zeros(self.shape)
204+
accumulated_gradients = jnp.zeros((self.shape[0], self.shape[1] * 2))
206205
step_count = jnp.zeros(())
207206
seed = jax.random.PRNGKey(42)
208207

209-
210-
not self.inputs.targeted and self.inputs.set(inputs)
208+
hasattr(self.inputs, 'targeted') and not self.inputs.targeted and self.inputs.set(inputs)
211209
self.outputs.set(outputs)
212210
self.objective.set(objective)
213211
self.rewards.set(rewards)
@@ -260,20 +258,6 @@ def help(cls): ## component help function
260258
"hyperparameters": hyperparams}
261259
return info
262260

263-
def __repr__(self):
264-
comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
265-
maxlen = max(len(c) for c in comps) + 5
266-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
267-
for c in comps:
268-
stats = tensorstats(getattr(self, c).get())
269-
if stats is not None:
270-
line = [f"{k}: {v}" for k, v in stats.items()]
271-
line = ", ".join(line)
272-
else:
273-
line = "None"
274-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
275-
return lines
276-
277261

278262
if __name__ == '__main__':
279263
from ngcsimlib.context import Context

tests/components/synapses/modulated/test_REINFORCESynapse.py

Lines changed: 25 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,12 @@
22

33
import jax
44
from jax import numpy as jnp, random, jit
5-
from ngcsimlib.context import Context
65
import numpy as np
76
np.random.seed(42)
87
from ngclearn.components.synapses.modulated.REINFORCESynapse import REINFORCESynapse, gaussian_logpdf
9-
from ngcsimlib.compilers import compile_command, wrap_command
108
from numpy.testing import assert_array_equal
119

12-
from ngcsimlib.compilers.process import Process, transition
13-
from ngcsimlib.component import Component
14-
from ngcsimlib.compartment import Compartment
15-
from ngcsimlib.context import Context
10+
from ngclearn import Context, MethodProcess
1611

1712
import jax
1813
import jax.numpy as jnp
@@ -39,22 +34,16 @@ def test_REINFORCESynapse1():
3934
scalar_stddev=-1.0
4035
)
4136

42-
evolve_process = (Process("evolve_proc") >> a.evolve)
43-
ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
37+
evolve_process = (MethodProcess("evolve_proc") >> a.evolve)
38+
reset_process = (MethodProcess("reset_proc") >> a.reset)
4439

45-
reset_process = (Process("reset_proc") >> a.reset)
46-
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
47-
48-
@Context.dynamicCommand
4940
def clamp_inputs(x):
5041
a.inputs.set(x)
5142

52-
@Context.dynamicCommand
5343
def clamp_rewards(x):
5444
assert x.ndim == 1, "Rewards must be a 1D array"
5545
a.rewards.set(x)
5646

57-
@Context.dynamicCommand
5847
def clamp_weights(x):
5948
a.weights.set(x)
6049

@@ -80,7 +69,7 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
8069
expected_weights = jnp.concatenate([expected_weights_mu, expected_weights_logstd], axis=-1)
8170
initial_ngclearn_weights = jnp.concatenate([expected_weights_mu, expected_weights_logstd], axis=-1)[None]
8271
expected_gradient_list = []
83-
ctx.reset()
72+
reset_process.run()
8473

8574
# Loop through 10 steps
8675
for step in range(10):
@@ -94,12 +83,12 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
9483
clamp_weights(initial_ngclearn_weights)
9584
clamp_rewards(outputs)
9685
clamp_inputs(inputs)
97-
ctx.adapt(t=1., dt=dt)
98-
print(f"[ngclearn] objective: {a.objective.value}")
99-
print(f"[ngclearn] weights: {a.weights.value}")
100-
print(f"[ngclearn] dWeights: {a.dWeights.value}")
101-
print(f"[ngclearn] step_count: {a.step_count.value}")
102-
print(f"[ngclearn] accumulated_gradients: {a.accumulated_gradients.value}")
86+
evolve_process.run(t=1., dt=dt)
87+
print(f"[ngclearn] objective: {a.objective.get()}")
88+
print(f"[ngclearn] weights: {a.weights.get()}")
89+
print(f"[ngclearn] dWeights: {a.dWeights.get()}")
90+
print(f"[ngclearn] step_count: {a.step_count.get()}")
91+
print(f"[ngclearn] accumulated_gradients: {a.accumulated_gradients.get()}")
10392
# -------- Expectation ---------
10493
print("--------------")
10594
expected_objective, expected_grads = grad_fn(
@@ -116,12 +105,12 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
116105
print(f"[Expectation] dWeights: {expected_grads}")
117106
print(f"[Expectation] objective: {expected_objective}")
118107
np.testing.assert_allclose(
119-
a.dWeights.value[0],
108+
a.dWeights.get()[0],
120109
expected_grads,
121110
atol=1e-8
122111
)
123112
np.testing.assert_allclose(
124-
a.objective.value,
113+
a.objective.get(),
125114
expected_objective,
126115
atol=1e-8
127116
)
@@ -131,7 +120,7 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
131120
decay_list = jnp.asarray([decay**i for i in range(len(expected_gradient_list))])[::-1]
132121
expected_accumulated_gradients = jnp.mean(jnp.stack(expected_gradient_list, 0) * decay_list[:, None, None], axis=0)
133122
np.testing.assert_allclose(
134-
a.accumulated_gradients.value[0],
123+
a.accumulated_gradients.get()[0],
135124
expected_accumulated_gradients,
136125
atol=1e-9
137126
)
@@ -163,22 +152,16 @@ def test_REINFORCESynapse2():
163152
scalar_stddev=scalar_stddev
164153
)
165154

166-
evolve_process = (Process("evolve_proc") >> a.evolve)
167-
ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
168-
169-
reset_process = (Process("reset_proc") >> a.reset)
170-
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
155+
evolve_process = (MethodProcess("evolve_proc") >> a.evolve)
156+
reset_process = (MethodProcess("reset_proc") >> a.reset)
171157

172-
@Context.dynamicCommand
173158
def clamp_inputs(x):
174159
a.inputs.set(x)
175160

176-
@Context.dynamicCommand
177161
def clamp_rewards(x):
178162
assert x.ndim == 1, "Rewards must be a 1D array"
179163
a.rewards.set(x)
180164

181-
@Context.dynamicCommand
182165
def clamp_weights(x):
183166
a.weights.set(x)
184167

@@ -205,7 +188,7 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
205188
expected_weights = jnp.concatenate([expected_weights_mu, expected_weights_logstd], axis=-1)
206189
initial_ngclearn_weights = jnp.concatenate([expected_weights_mu, expected_weights_logstd], axis=-1)[None]
207190
expected_gradient_list = []
208-
ctx.reset()
191+
reset_process.run()
209192

210193
# Loop through 10 steps
211194
for step in range(10):
@@ -219,12 +202,12 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
219202
clamp_weights(initial_ngclearn_weights)
220203
clamp_rewards(outputs)
221204
clamp_inputs(inputs)
222-
ctx.adapt(t=1., dt=dt)
223-
print(f"[ngclearn] objective: {a.objective.value}")
224-
print(f"[ngclearn] weights: {a.weights.value}")
225-
print(f"[ngclearn] dWeights: {a.dWeights.value}")
226-
print(f"[ngclearn] step_count: {a.step_count.value}")
227-
print(f"[ngclearn] accumulated_gradients: {a.accumulated_gradients.value}")
205+
evolve_process.run(t=1., dt=dt)
206+
print(f"[ngclearn] objective: {a.objective.get()}")
207+
print(f"[ngclearn] weights: {a.weights.get()}")
208+
print(f"[ngclearn] dWeights: {a.dWeights.get()}")
209+
print(f"[ngclearn] step_count: {a.step_count.get()}")
210+
print(f"[ngclearn] accumulated_gradients: {a.accumulated_gradients.get()}")
228211
# -------- Expectation ---------
229212
print("--------------")
230213
expected_objective, expected_grads = grad_fn(
@@ -241,12 +224,12 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
241224
print(f"[Expectation] dWeights: {expected_grads}")
242225
print(f"[Expectation] objective: {expected_objective}")
243226
np.testing.assert_allclose(
244-
a.dWeights.value[0],
227+
a.dWeights.get()[0],
245228
expected_grads,
246229
atol=1e-8
247230
)
248231
np.testing.assert_allclose(
249-
a.objective.value,
232+
a.objective.get(),
250233
expected_objective,
251234
atol=1e-8
252235
)
@@ -256,7 +239,7 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
256239
decay_list = jnp.asarray([decay**i for i in range(len(expected_gradient_list))])[::-1]
257240
expected_accumulated_gradients = jnp.mean(jnp.stack(expected_gradient_list, 0) * decay_list[:, None, None], axis=0)
258241
np.testing.assert_allclose(
259-
a.accumulated_gradients.value[0],
242+
a.accumulated_gradients.get()[0],
260243
expected_accumulated_gradients,
261244
atol=1e-9
262245
)

0 commit comments

Comments
 (0)