-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
218 lines (187 loc) · 9.08 KB
/
evaluate.py
File metadata and controls
218 lines (187 loc) · 9.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
from math import ceil
import jax
from jax import numpy as jnp
import numpy as np
from flax import nnx, struct
from tqdm import tqdm
import json
from scripts.build_arc_dataset import inverse_d8_aug, inverse_colour_aug, crop, grid_hash
@struct.dataclass
class Carry:
    """Per-batch recurrent state plus the batch fields used by ``eval_step``.

    Declared with flax ``struct.dataclass``, so it is a JAX pytree and can be
    passed as an argument to the jitted ``eval_step``.
    """
    z: jax.Array  # z state, updated by model(x, y, z); init_carry builds it as (batch_size, seq_len, hidden_dim)
    y: jax.Array  # y state, updated by model(y, z) and decoded by model.output_head; same shape as z from init_carry
    x_input: jax.Array  # batch['x'], embedded by model.input_embedding in eval_step
    aug_puzzle_idx: jax.Array  # batch['aug_puzzle_idx'], (batch_size,); evaluate pads extra rows with -1
    y_true: jax.Array  # batch['y'] targets; cells with value >= 11 are not scored by eval_step
    step: jax.Array  # (batch_size,) int32, zero-initialised; not read by eval_step
    halted: jax.Array  # (batch_size,) bool, False-initialised; not read by eval_step
def init_carry(batch, z_init, y_init, seq_len):
    """Create the starting ``Carry`` for one batch.

    ``z_init`` and ``y_init`` are both broadcast to
    ``(batch_size, seq_len, hidden_dim)``. Here ``batch_size`` is the leading
    dimension of ``batch['x']`` and ``hidden_dim`` is the last dimension of
    ``z_init``; it is also used for ``y_init``. ``step`` starts at 0 (int32) and
    ``halted`` at False for every row. ``x``, ``y`` and ``aug_puzzle_idx`` are
    taken from ``batch`` unchanged.
    """
    n_rows = batch['x'].shape[0]
    state_shape = (n_rows, seq_len, z_init.shape[-1])
    z_state = jnp.broadcast_to(z_init, state_shape)
    y_state = jnp.broadcast_to(y_init, state_shape)
    step_counter = jnp.zeros((n_rows,), dtype=jnp.int32)
    halt_flags = jnp.zeros((n_rows,), dtype=jnp.bool_)
    return Carry(
        z=z_state,
        y=y_state,
        x_input=batch['x'],
        aug_puzzle_idx=batch['aug_puzzle_idx'],
        y_true=batch['y'],
        step=step_counter,
        halted=halt_flags,
    )
@nnx.jit(static_argnames=["N_supervision", "n", "T" ])
def eval_step(model, carry, N_supervision, n, T):
    """Run the full recursion on one batch and score the predictions.

    The input is embedded once. Then ``N_supervision * T`` rounds of
    ``latent_recursion`` are applied; each round updates z ``n`` times with
    ``model(x, y, z)`` and then y once with ``model(y, z)``. Finally y is
    decoded with ``model.output_head`` and argmax-ed into token predictions.

    Scoring rules:
      * a cell is scored only when ``y_true < 11`` (``evaluate`` pads ``y``
        with 11);
      * a row counts as a puzzle only when ``aug_puzzle_idx >= 0``
        (``evaluate`` pads ``aug_puzzle_idx`` with -1).

    Args:
        model: nnx model exposing ``input_embedding``, ``output_head`` and
            a ``__call__`` used for both the z and y updates.
        carry: ``Carry`` as built by ``init_carry``.
        N_supervision, n, T: recursion counts; static under jit, so changing
            them triggers a retrace.

    Returns:
        ``(y_preds, cell_correct_sum, cell_total, puzzle_correct_sum,
        puzzle_total)``.
    """
    def latent_recursion(model, x, y, z, n):
        # n refinements of z conditioned on (x, y), then one update of y from z.
        for _ in range(n):
            z = model(x, y, z)
        y = model(y, z)
        return y, z

    x = model.input_embedding(carry.x_input, carry.aug_puzzle_idx)
    y, z = carry.y, carry.z
    for _ in range(N_supervision):
        for _ in range(T):
            y, z = latent_recursion(model, x, y, z, n)
    y_logits = model.output_head(y)
    y_preds = jnp.argmax(y_logits, axis=-1)

    scored_cells = carry.y_true < 11
    real_rows = carry.aug_puzzle_idx >= 0
    cell_correct = y_preds == carry.y_true
    # `all` over an empty `where` is True, so padding rows (every target cell
    # is 11) would otherwise count as solved even though puzzle_total excludes
    # them. Masking keeps puzzle_correct_sum <= puzzle_total.
    puzzle_correct = cell_correct.all(axis=-1, where=scored_cells) & real_rows
    cell_total = scored_cells.sum()
    cell_correct_sum = cell_correct.sum(where=scored_cells)
    puzzle_total = real_rows.sum()
    puzzle_correct_sum = puzzle_correct.sum()
    return y_preds, cell_correct_sum, cell_total, puzzle_correct_sum, puzzle_total
def get_top_k_preds(example_preds, k):
    """Return the ``k`` most-voted distinct predictions, most votes first.

    Args:
        example_preds: mapping from a prediction (in this file, a grid hash)
            to the number of votes it received.
        k: maximum number of distinct predictions to return; values <= 0
            yield an empty list.

    Returns:
        A list of at most ``k`` keys of ``example_preds`` in descending vote
        order. Ties keep the mapping's insertion order, because Python's sort
        is stable even with ``reverse=True``.
    """
    # Rank by vote count and take k distinct entries. The previous version
    # summed counts until they reached k, so pass@k degenerated to pass@1
    # whenever the top prediction had >= k votes.
    ranked = sorted(example_preds.items(), key=lambda item: item[1], reverse=True)
    return [pred for pred, _count in ranked[:max(k, 0)]]
def _pad_batch(batch, target_size):
    """Return ``batch`` padded along axis 0 to ``target_size`` rows using host-side NumPy."""
    padding_size = target_size - batch['x'].shape[0]
    if padding_size <= 0:
        return batch
    # Grids are padded with 11 (not scored by eval_step); index/augmentation
    # fields with -1 (aug_puzzle_idx < 0 marks a padding row).
    pad_values = {
        'x': 11,
        'y': 11,
        'aug_puzzle_idx': -1,
        'puzzle_idx': -1,
        'example_idx': -1,
        'd8_aug': -1,
        'colour_aug': -1,
    }
    # Keep padding on host (NumPy). If we use `jnp.pad` here, we can end up
    # creating multi-host global arrays under the mesh context, which then
    # cannot be re-sharded via `make_array_from_process_local_data`.
    padded = dict(batch)
    for key, fill in pad_values.items():
        values = padded[key]
        pad_width = [(0, padding_size)] + [(0, 0)] * (np.ndim(values) - 1)
        padded[key] = np.pad(values, pad_width, mode='constant', constant_values=fill)
    return padded


def _canonical_grid_hash(grid, colour_aug, d8_aug):
    """Crop ``grid``, undo its colour then D8 augmentation, and return its hash."""
    return grid_hash(inverse_d8_aug(inverse_colour_aug(crop(grid), colour_aug), d8_aug))


def _pass_rates(preds, pass_ks):
    """Return ``({"pass_<k>": rate}, n_puzzles)`` from the vote table ``preds``.

    A puzzle passes at ``k`` when every one of its examples has its true grid
    hash among the top-``k`` voted predictions.
    """
    # passes = {puzzle_idx: {k: [example passed?, ...]}}
    passes = {}
    for puzzle_idx, data in tqdm(preds.items(), desc="computing passes"):
        for example in data.values():
            y_true = example['y_true']
            for k in pass_ks:
                top_k_preds = get_top_k_preds(example['y_preds'], k)
                passes.setdefault(puzzle_idx, {}).setdefault(k, []).append(y_true in top_k_preds)
    # passes_reduced = {k: number of puzzles passing at k}
    passes_reduced = {}
    for ks in tqdm(passes.values(), desc="computing passes reduced"):
        for k, vs in ks.items():
            passes_reduced[k] = passes_reduced.get(k, 0) + int(all(vs))
    n_puzzles = len(passes)
    rates = {f"pass_{k}": n_true / n_puzzles for k, n_true in passes_reduced.items()}
    return rates, n_puzzles


def evaluate(model, data_loader_factory, y_init, z_init, N_supervision, n, T, pass_ks, shard_data, batch_size, seq_len, input_size):
    """Evaluate ``model`` over one pass of a fresh data loader and return metrics.

    For each batch:
      1. Pad it on the host to ``batch_size // jax.process_count()`` rows.
      2. Hand it to ``shard_data``, then run ``eval_step``.
      3. Gather predictions and metadata from all processes and drop padding
         rows (``aug_puzzle_idx < 0``).
      4. For each remaining row, undo its colour/D8 augmentation, hash the
         cropped grid, and count it as a vote for its
         ``(puzzle_idx, example_idx)``.

    Args:
        model: model passed to ``eval_step``.
        data_loader_factory: zero-argument callable returning an iterable of
            host batches (dicts of arrays). The loader must expose
            ``_data_source`` with a ``len`` (used only for the progress bar).
        y_init, z_init: initial y/z states broadcast by ``init_carry``.
        N_supervision, n, T: recursion counts forwarded to ``eval_step``.
        pass_ks: iterable of k values for the ``pass_<k>`` metrics.
        shard_data: callable turning a padded host batch into the arrays
            ``eval_step`` consumes (presumably global ``jax.Array``s — confirm
            against caller).
        batch_size: global batch size across all processes.
        seq_len: sequence length of the y/z states.
        input_size: side length of the square grids; predictions are reshaped
            to ``(-1, input_size, input_size)``.

    Returns:
        dict with ``pass_<k>`` for each k, plus ``cell_acc`` and ``puzzle_acc``.

    Side effects:
        On process 0, appends raw (un-cropped, still augmented) predictions
        to ``preds.jsonl``. Prints the raw counters.
    """
    # `import jax` alone does not guarantee the `jax.experimental.multihost_utils`
    # submodule is loaded, so import it explicitly.
    from jax.experimental import multihost_utils

    # preds = {puzzle_idx: {example_idx: {"y_true": hash, "y_preds": {hash: votes}}}}
    preds = {}
    cell_corrects = 0
    cell_totals = 0
    puzzle_corrects = 0
    puzzle_totals = 0
    is_main_process = jax.process_index() == 0
    data_loader = data_loader_factory()
    num_batches = ceil(len(data_loader._data_source) / batch_size)
    per_process_batch_size = batch_size // jax.process_count()
    gathered_fields = ('y', 'puzzle_idx', 'aug_puzzle_idx', 'example_idx', 'd8_aug', 'colour_aug')
    for batch in tqdm(data_loader, desc="evaluating", total=num_batches):
        batch = shard_data(_pad_batch(batch, per_process_batch_size))
        carry = init_carry(batch, z_init, y_init, seq_len)
        y_preds, cell_correct_sum, cell_total, puzzle_correct_sum, puzzle_total = eval_step(model, carry, N_supervision, n, T)
        cell_corrects += cell_correct_sum
        cell_totals += cell_total
        puzzle_corrects += puzzle_correct_sum
        puzzle_totals += puzzle_total

        # One gather for the whole pytree instead of one call per field.
        gathered = multihost_utils.process_allgather(
            {'y_preds': y_preds, **{field: batch[field] for field in gathered_fields}},
            tiled=True,
        )
        # The gathered batch combines every process's individually padded
        # shard, so padding rows need not sit at the end. Select real rows by
        # aug_puzzle_idx >= 0 (the convention eval_step uses for puzzle_total)
        # rather than slicing a prefix of the per-process size.
        real_rows = np.asarray(gathered['aug_puzzle_idx']) >= 0
        rows = {field: np.asarray(values)[real_rows] for field, values in gathered.items()}
        grid_preds = rows['y_preds'].reshape(-1, input_size, input_size)
        grid_trues = rows['y'].reshape(-1, input_size, input_size)

        jsonl_lines = []
        for i, y_pred in enumerate(grid_preds):
            # Unwrap scalars from batched fields
            puzzle_idx = int(rows['puzzle_idx'][i][0])
            example_idx = int(rows['example_idx'][i][0])
            d8_aug = int(rows['d8_aug'][i][0])
            colour_aug = rows['colour_aug'][i]
            aug_puzzle_idx = int(rows['aug_puzzle_idx'][i])
            if is_main_process:
                jsonl_lines.append(json.dumps({"aug_puzzle_idx": aug_puzzle_idx, "example_idx": example_idx, "y_pred": y_pred.tolist()}) + "\n")
            example_entry = preds.setdefault(puzzle_idx, {}).setdefault(example_idx, {"y_true": None, "y_preds": dict()})
            example_entry['y_true'] = _canonical_grid_hash(grid_trues[i], colour_aug, d8_aug)
            pred_hash = _canonical_grid_hash(y_pred, colour_aug, d8_aug)
            votes = example_entry['y_preds']
            votes[pred_hash] = votes.get(pred_hash, 0) + 1
        # One append per batch instead of reopening the file for every row.
        if jsonl_lines:
            with open("preds.jsonl", "a") as f:
                f.writelines(jsonl_lines)

    passes_reduced, n_puzzles = _pass_rates(preds, pass_ks)
    print(f"{cell_corrects=}, {cell_totals=}, {puzzle_corrects=}, {puzzle_totals=}, {n_puzzles=}")
    metrics = {
        **passes_reduced,
        "cell_acc": cell_corrects / cell_totals,
        "puzzle_acc": puzzle_corrects / puzzle_totals,
    }
    return metrics