-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathnatural_niches_fn.py
More file actions
148 lines (127 loc) · 4.77 KB
/
natural_niches_fn.py
File metadata and controls
148 lines (127 loc) · 4.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import jax
import jax.numpy as jnp
from collections import defaultdict
from tqdm import tqdm
from model import num_params
from data import load_data
from model import mlp, get_acc
from helper_fn import crossover, crossover_without_splitpoint, mutate, get_pre_trained_models
def sample_parents(
    archive: jnp.ndarray,
    scores: jnp.ndarray,
    rand_key: jnp.ndarray,
    alpha: float,
    use_matchmaker: bool,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Draw two parent parameter vectors from the archive.

    Fitness is niche-shared: each datapoint column of ``scores`` is divided
    by the (alpha-powered) number of archive members that solve it, so
    solving rare datapoints is worth more than solving common ones. With
    the matchmaker enabled, the second parent is chosen to complement the
    first; otherwise both parents are drawn from the same distribution.
    """
    key_a, key_b = jax.random.split(rand_key)

    # Per-datapoint solver counts; columns nobody solves are mapped to 1
    # before the exponent so the division below never hits zero.
    solver_counts = scores.sum(axis=0)
    denom = jnp.where(solver_counts, solver_counts, 1) ** alpha
    shared = scores / denom[None, :]  # (pop_size, num_datapoints)
    fitness = shared.sum(axis=1)
    probs = fitness / fitness.sum()

    if not use_matchmaker:
        # Two i.i.d. draws from the shared-fitness distribution.
        idx_b, idx_a = jax.random.choice(key_a, probs.size, shape=(2,), p=probs)
        return archive[idx_a], archive[idx_b]

    # Matchmaker: first parent by shared fitness, second parent weighted by
    # how much it improves on the first, summed over datapoints.
    idx_a = jax.random.choice(key_a, probs.size, shape=(1,), p=probs)[0]
    complement = jnp.maximum(0, shared - shared[idx_a, :]).sum(axis=1)
    probs = complement / complement.sum()
    idx_b = jax.random.choice(key_b, probs.size, shape=(1,), p=probs)[0]
    return archive[idx_a], archive[idx_b]
@jax.jit
def update_archive(
    score: jnp.ndarray,
    param: jnp.ndarray,
    archive: jnp.ndarray,
    scores: jnp.ndarray,
    alpha: float,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Insert a candidate into the archive if it beats the worst member.

    Args:
        score: boolean per-datapoint solved mask of the candidate,
            shape (num_datapoints,).
        param: candidate parameter vector.
        archive: current parameter archive, shape (pop_size, num_params).
        scores: per-member solved masks, shape (pop_size, num_datapoints).
        alpha: niching exponent applied to the per-datapoint solver counts.

    Returns:
        The (possibly updated) ``(archive, scores)`` pair. If the candidate
        itself has the lowest niche-shared fitness, both are returned
        unchanged.
    """
    # Append the candidate as the last row: index pop_size in the extension.
    ext_scores = jnp.concatenate(
        [scores, score[None, ...]], axis=0
    )  # (pop_size + 1, num_datapoints)
    # Per-datapoint solver counts. Guard zero columns *before* raising to
    # alpha so the normalisation matches sample_parents for every alpha
    # (the previous order computed 0 ** alpha first, which diverges for
    # alpha <= 0); for alpha > 0 the result is identical.
    z = jnp.sum(ext_scores, axis=0)
    z = jnp.where(z, z, 1) ** alpha
    fitness = jnp.sum(ext_scores / z[None, :], axis=1)  # (pop_size + 1,)
    # Replace the worst-performing row, but only when the worst is an
    # existing member; worst_ix == pop_size means the candidate lost.
    worst_ix = jnp.argmin(fitness)
    update_mask = worst_ix < scores.shape[0]
    # When worst_ix is out of bounds, the .at[].set scatter is dropped by
    # JAX and the select keeps the (clamped-index) original row, so the
    # arrays come back unchanged.
    scores = scores.at[worst_ix].set(
        jax.lax.select(update_mask, score, scores[worst_ix])
    )
    archive = archive.at[worst_ix].set(
        jax.lax.select(update_mask, param, archive[worst_ix])
    )
    return archive, scores
def run_natural_niches(
    runs: int,
    pop_size: int,
    total_forward_passes: int,
    store_train_results: bool,
    no_matchmaker: bool,
    no_crossover: bool,
    no_splitpoint: bool,
    use_pre_trained: bool,
    alpha: float = 1.0,
) -> list:
    """Run the natural-niches evolutionary loop.

    For each run: seed the archive with two models (pre-trained or random),
    then repeatedly sample parents, build a child via crossover and/or
    mutation, evaluate it on the training set, and let the archive keep the
    fittest population. Test accuracy of the best member is logged every
    forward pass.

    Returns:
        A list with one ``defaultdict(list)`` per run, holding the keys
        ``"evals"``, ``"test_values"`` and (optionally) ``"train_values"``.
    """
    (x_train, y_train), (x_test, y_test) = load_data()
    # The caller passes negated ablation flags; flip them once up front.
    use_matchmaker = not no_matchmaker
    use_crossover = not no_crossover
    use_splitpoint = not no_splitpoint
    results: list = []
    if use_pre_trained:
        seed_model_1, seed_model_2 = get_pre_trained_models()
    for run in tqdm(range(runs), desc="Runs"):
        result = defaultdict(list)
        results.append(result)
        key = jax.random.PRNGKey(42 + run)
        # Fresh archive: one parameter row per slot plus a boolean
        # per-datapoint score sheet for each slot.
        archive = jnp.zeros([pop_size, num_params])
        scores = jnp.zeros([pop_size, len(x_train)], dtype=jnp.bool)
        if not use_pre_trained:
            # Starting from scratch: two small random parameter vectors.
            key, key1, key2 = jax.random.split(key, 3)
            seed_model_1 = jax.random.normal(key1, (num_params,)) * 0.01
            seed_model_2 = jax.random.normal(key2, (num_params,)) * 0.01
        for seed_model in (seed_model_1, seed_model_2):
            logits = mlp(seed_model, x_train)
            score = jnp.argmax(logits, axis=1) == y_train
            archive, scores = update_archive(score, seed_model, archive, scores, alpha)
        for i in tqdm(range(total_forward_passes), desc="Forward passes"):
            k1, k2, k3, key = jax.random.split(key, 4)
            parents = sample_parents(archive, scores, k1, alpha, use_matchmaker)
            # Build the child according to the ablation flags.
            if not use_crossover:
                child = parents[0]
            elif use_splitpoint:
                child = crossover(parents, k2)
            else:
                child = crossover_without_splitpoint(parents, k2)
            if not use_pre_trained:
                # Mutation is applied only when starting from scratch.
                child = mutate(child, k3)
            logits = mlp(child, x_train)
            score = jnp.argmax(logits, axis=1) == y_train
            archive, scores = update_archive(score, child, archive, scores, alpha)
            # Bookkeeping for this forward pass.
            result["evals"].append(i)
            train_acc = scores.mean(axis=1)
            if store_train_results:
                # Best train accuracy across the archive.
                result["train_values"].append(train_acc.max())
            # Test accuracy of the current best train-accuracy member.
            best_individual = jnp.argmax(train_acc)
            logits = mlp(archive[best_individual], x_test)
            acc = get_acc(logits, y_test)
            result["test_values"].append(acc)
            if i % 1000 == 0:
                print(f"Run {run}, Forward pass {i}, Test acc: {acc}")
    return results