-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrainer_comb.pyx
33 lines (26 loc) · 1.08 KB
/
trainer_comb.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
"""
Combines stored seeds into a single multi-phase parameter set.
Takes the number of phases as its config (defaulting to the number of seeds),
splitting the level into that many equal parts. If the phase count does not
match the seed count, the parameter values will be repeated by the base bot.
"""
from cython import ccall, cclass, returns
from numpy import concatenate, linspace
from trainer_base cimport BaseTrainer
@cclass
class Trainer(BaseTrainer):
    """Trainer that merges the parameters of all stored seeds.

    ``train`` concatenates each seed bot's parameter values along their last
    axis and attaches a ``_phases`` array of phase end points.
    """

    def __init__(self, level, config, *args, **kwargs):
        """Set up the phase end points.

        ``config[0]``, when a non-empty config is given, is the number of
        phases; otherwise one phase per stored seed is used. The base trainer
        is always given an empty config.

        Raises ValueError if the resulting phase count is not positive.
        """
        super().__init__(level, (), *args, **kwargs)
        phase_count = int(config[0]) if config else len(self.seeds)
        if phase_count < 1:
            # Without this, no config plus no seeds (or an explicit "0")
            # fails as an opaque ZeroDivisionError in ``1 / phase_count``.
            raise ValueError(
                "phase count must be positive, got {}".format(phase_count))
        # End points 1/n, 2/n, ..., 1, i.e. n equal-width slices.
        self.phases = linspace(1 / phase_count, 1, num=phase_count, dtype='f4')

    @ccall
    @returns('tuple')
    def train(self):
        """Combine the stored seeds into one parameter set.

        Returns a ``(params, histories)`` tuple. ``params`` maps each key of
        the first bot's ``shapes(0, 0, 0)`` to the seeds' ``params[key]``
        values concatenated along the last axis, plus ``'_phases'``.
        ``histories`` is a one-element list holding the tuple of the seeds'
        histories.

        Raises ValueError if there are no stored seeds.
        """
        if len(self.seeds) == 0:
            # zip(*[]) would otherwise fail with a cryptic unpacking error.
            raise ValueError("no stored seeds to combine")
        bots, histories = zip(*self.seeds)
        # Keys come from the first bot only; presumably every seed bot shares
        # the same parameter keys (a missing key raises KeyError) - confirm.
        param_keys = bots[0].shapes(0, 0, 0).keys()
        combined_params = {
            key: concatenate([bot.params[key] for bot in bots], axis=-1)
            for key in param_keys
        }
        combined_params['_phases'] = self.phases
        return combined_params, [histories]