Skip to content

Commit bd607f6

Browse files
update few spike compiler
1 parent 0ff8c34 commit bd607f6

1 file changed

Lines changed: 7 additions & 4 deletions

File tree

ml_genn/ml_genn/compilers/few_spike_compiler.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,9 @@ def evaluate_batch_iter(self, inputs, outputs, data: Iterator,
111111
num_batches=num_batches)
112112
callback_list.on_test_begin()
113113

114+
# Create metric state
115+
metric_state = {o: m.create_state() for o, m in metrics.items()}
116+
114117
# Build deque to hold y
115118
y_pipe_queue = {p: deque(maxlen=d + 1)
116119
for p, d in y_pipe_depth.items()}
@@ -168,21 +171,21 @@ def evaluate_batch_iter(self, inputs, outputs, data: Iterator,
168171
batch_y_pred = self.get_readout(o)
169172

170173
# Update metrics
171-
metrics[o].update(batch_y_true,
174+
metrics[o].update(metric_state[o], batch_y_true,
172175
batch_y_pred[:len(batch_y_true)],
173176
self.communicator)
174177

175178
# End batch
176-
callback_list.on_batch_end(batch_i, metrics)
179+
callback_list.on_batch_end(batch_i, metric_state)
177180

178181
# Next batch
179182
batch_i += 1
180183

181184
# End testing
182-
callback_list.on_test_end(metrics)
185+
callback_list.on_test_end(metric_state)
183186

184187
# Return metrics
185-
return metrics, callback_list.get_data()
188+
return metric_state, callback_list.get_data()
186189

187190

188191
# Because we want the converter class to be reusable, we don't want

0 commit comments

Comments
 (0)