Skip to content

Commit 0eab12c

Browse files
committed
crude visualization for narrow act
1 parent 70880e9 commit 0eab12c

File tree

2 files changed

+179
-0
lines changed

2 files changed

+179
-0
lines changed

model.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,18 @@ def _init_layers(self):
7070
self.output.weight = nn.Parameter(output_weight)
7171
self.output.bias = nn.Parameter(output_bias)
7272

73+
def get_narrow_preactivations(self, x, ls_indices):
    """Return the narrow-layer pre-activations for the selected layer stack.

    Computes ``self.l1(x)``, views it as ``(-1, self.count, L2)``, picks for
    every sample the slice addressed by ``ls_indices`` and adds
    ``self.l1_fact(x)``. No activation function is applied to the result.

    Args:
        x: Input batch; ``x.shape[0]`` is used as the batch size.
        ls_indices: Per-sample layer-stack index; flattened before use.

    Returns:
        The selected ``self.l1`` output plus the ``self.l1_fact`` output.
    """
    # precompute and cache the offset for gathers
    # (``is None`` instead of ``== None``: identity is the correct test and
    # does not go through tensor comparison when the cache is populated)
    if self.idx_offset is None or self.idx_offset.shape[0] != x.shape[0]:
        self.idx_offset = torch.arange(
            0, x.shape[0] * self.count, self.count, device=ls_indices.device)

    # After flattening l1's output to rows of width L2, row
    # ``sample * count + ls_index`` holds the chosen slice for that sample.
    indices = ls_indices.flatten() + self.idx_offset

    l1s_ = self.l1(x).reshape((-1, self.count, L2))
    l1f_ = self.l1_fact(x)
    l1c_ = l1s_.view(-1, L2)[indices]
    return l1c_ + l1f_
84+
7385
def forward(self, x, ls_indices):
7486
# precompute and cache the offset for gathers
7587
if self.idx_offset == None or self.idx_offset.shape[0] != x.shape[0]:
@@ -241,6 +253,23 @@ def set_feature_set(self, new_feature_set):
241253
else:
242254
raise Exception('Cannot change feature set from {} to {}.'.format(self.feature_set.name, new_feature_set.name))
243255

256+
def get_narrow_preactivations(self, us, them, white_indices, white_values, black_indices, black_values, psqt_indices, layer_stack_indices):
    """Return the narrow-layer pre-activations grouped by layer-stack bucket.

    Mirrors the first part of ``forward``: runs the input layer, combines the
    two perspectives according to ``us``/``them``, clamps to [0, 1] and asks
    ``self.layer_stacks`` for the narrow-layer pre-activations.

    Args:
        us, them: Multipliers applied to the (white, black) and
            (black, white) concatenations respectively.
        white_indices, white_values, black_indices, black_values: Sparse
            feature inputs forwarded unchanged to ``self.input``.
        psqt_indices: Unused here; kept so the signature matches ``forward``.
        layer_stack_indices: Per-sample bucket index. It is compared against
            each bucket id to build a row mask, so it is expected to be 1-D
            with one entry per sample.

    Returns:
        A list with ``self.num_ls_buckets`` tensors; entry ``i`` contains the
        pre-activation rows of the samples whose bucket index equals ``i``.
    """
    wp, bp = self.input(white_indices, white_values, black_indices, black_values)
    # Only the first chunk (width L1) feeds the layer stacks; the trailing
    # chunk is not needed for this visualization, so it is discarded.
    w, _ = torch.split(wp, L1, dim=1)
    b, _ = torch.split(bp, L1, dim=1)
    l0_ = (us * torch.cat([w, b], dim=1)) + (them * torch.cat([b, w], dim=1))
    # clamp here is used as a clipped relu to (0.0, 1.0)
    l0_ = torch.clamp(l0_, 0.0, 1.0)

    preact = self.layer_stacks.get_narrow_preactivations(l0_, layer_stack_indices)
    # Boolean row indexing selects whole rows, which is what the original
    # repeat/transpose mask + masked_select + reshape achieved.
    return [preact[layer_stack_indices == i] for i in range(self.num_ls_buckets)]
272+
244273
def forward(self, us, them, white_indices, white_values, black_indices, black_values, psqt_indices, layer_stack_indices):
245274
wp, bp = self.input(white_indices, white_values, black_indices, black_values)
246275
w, wpsqt = torch.split(wp, L1, dim=1)

visualize_narrow_preactivation.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
import argparse
2+
import chess
3+
import features
4+
import nnue_dataset
5+
import model as M
6+
import numpy as np
7+
import torch
8+
import matplotlib.pyplot as plt
9+
from matplotlib.gridspec import GridSpec
10+
11+
from serialize import NNUEReader
12+
13+
14+
class NNUEVisualizer():
    """Plot histograms of a model's narrow-layer pre-activations.

    One histogram is drawn per (layer-stack bucket, neuron) pair, using
    positions read from the dataset given in ``args.data``.
    """

    def __init__(self, model, args):
        """Store the model and parsed CLI args and configure matplotlib.

        The model is moved to CUDA because the batches in ``get_data`` are
        requested as CUDA tensors.
        """
        self.model = model
        self.args = args

        self.model.cuda()

        import matplotlib as mpl
        self.dpi = 100
        # Default figure size is given in pixels on the command line;
        # matplotlib expects inches, hence the division by dpi.
        mpl.rcParams["figure.figsize"] = (
            self.args.default_width//self.dpi, self.args.default_height//self.dpi)
        mpl.rcParams["figure.dpi"] = self.dpi

    def _process_fig(self, name, fig=None):
        """Save ``fig`` (or the current pyplot figure) if ``--save-dir`` is set.

        The file is named ``<label>_<name>.jpg``; the label prefix is omitted
        when ``args.label`` is None.
        """
        if self.args.save_dir:
            from os.path import join
            destname = join(
                self.args.save_dir, "{}{}.jpg".format("" if self.args.label is None else self.args.label + "_", name))
            print("Saving {}".format(destname))
            if fig is not None:
                fig.savefig(destname)
            else:
                plt.savefig(destname)

    def get_data(self, count, batch_size):
        """Collect narrow-layer pre-activations, clipped to [0, 1], per bucket.

        Batches of ``batch_size`` FENs are processed until at least ``count``
        positions have been consumed.

        Returns:
            A list with one entry per layer-stack bucket; each entry is a list
            of numpy arrays (one per processed batch).
        """
        fen_batch_provider = nnue_dataset.FenBatchProvider(self.args.data, True, 1, batch_size, False, 10)

        activations_by_bucket = [[] for _ in range(self.model.num_ls_buckets)]
        i = 0
        # Pure inference: no autograd graph is needed for the histograms.
        with torch.no_grad():
            while i < count:
                fens = next(fen_batch_provider)
                batch = nnue_dataset.make_sparse_batch_from_fens(self.model.feature_set, fens, [0] * len(fens), [1] * len(fens), [0] * len(fens))
                us, them, white_indices, white_values, black_indices, black_values, outcome, score, psqt_indices, layer_stack_indices = batch.contents.get_tensors('cuda')
                bucketed_preact = self.model.get_narrow_preactivations(us, them, white_indices, white_values, black_indices, black_values, psqt_indices, layer_stack_indices)

                for a, b in zip(activations_by_bucket, bucketed_preact):
                    a.append(b.cpu().detach().numpy().clip(0, 1))

                i += batch_size
                print('{}/{}'.format(i, count))

        return activations_by_bucket

    def plot(self):
        """Draw a grid of histograms: one row per bucket, one column per neuron."""
        bucketed_preact = self.get_data(self.args.count, self.args.batch_size)
        for i, d in enumerate(bucketed_preact):
            print('Bucket {} has {} entries.'.format(i, sum(a.shape[0] for a in d)))

        # Rows are buckets and columns are neurons, matching the
        # ``axs[bucket_id, i]`` indexing below. ``squeeze=False`` keeps the
        # axes array 2-D even when either dimension is 1.
        fig, axs = plt.subplots(
            self.model.num_ls_buckets, M.L2,
            sharex=True, sharey=True, figsize=(20, 20), dpi=100, squeeze=False)

        for bucket_id, preact in enumerate(bucketed_preact):
            for i in range(M.L2):
                acts = np.concatenate([v[:,i] for v in preact]).flatten()

                ax = axs[bucket_id, i]
                ax.hist(acts, density=True, log=True, bins=128)
                ax.set_xlim([0, 1])
                if i == 0:
                    ax.set_ylabel('Bucket {}'.format(bucket_id))
                if bucket_id == 0:
                    ax.set_xlabel('Neuron {}'.format(i))
                    ax.xaxis.set_label_position('top')

        # Honour --save-dir (no-op when it is not set).
        self._process_fig('narrow_preactivations', fig)

        # Honour --dont-show instead of always opening a window.
        if not self.args.dont_show:
            fig.show()
78+
79+
def load_model(filename, feature_set):
    """Load a model from a ``.pt``, ``.ckpt`` or ``.nnue`` file.

    ``.pt`` files are read with ``torch.load`` and ``.ckpt`` files with
    ``M.NNUE.load_from_checkpoint``; both are switched to eval mode.
    ``.nnue`` files are parsed with ``NNUEReader`` and its model is returned
    as-is.

    Raises:
        Exception: If the filename has none of the supported extensions.
    """
    if filename.endswith(".pt"):
        # NOTE: torch.load unpickles the file; only use with trusted files.
        loaded = torch.load(filename)
    elif filename.endswith(".ckpt"):
        loaded = M.NNUE.load_from_checkpoint(
            filename, feature_set=feature_set)
    elif filename.endswith(".nnue"):
        with open(filename, 'rb') as stream:
            return NNUEReader(stream, feature_set).model
    else:
        raise Exception("Invalid filetype: " + str(filename))

    # Only the .pt / .ckpt paths reach this point.
    loaded.eval()
    return loaded
95+
96+
97+
def main():
    """Parse the command line, load the model and plot its pre-activations.

    Exits through ``parser.error`` when the feature set is unsupported or no
    dataset was given, so the problem is reported before the model is loaded.
    """
    parser = argparse.ArgumentParser(
        description="Visualizes networks in ckpt, pt and nnue format.")
    parser.add_argument(
        "model", help="Source model (can be .ckpt, .pt or .nnue)")
    parser.add_argument(
        "--default-width", default=1600, type=int,
        help="Default width of all plots (in pixels).")
    parser.add_argument(
        "--count", default=1000000, type=int,
        help="Number of positions to process (rounded up to a whole batch).")
    parser.add_argument(
        "--batch_size", default=5000, type=int,
        help="Number of positions evaluated per batch.")
    parser.add_argument(
        "--default-height", default=900, type=int,
        help="Default height of all plots (in pixels).")
    parser.add_argument(
        "--save-dir", type=str, required=False,
        help="Save the plots in this directory.")
    parser.add_argument(
        "--dont-show", action="store_true",
        help="Don't show the plots.")
    parser.add_argument("--data", type=str, help="path to a .bin or .binpack dataset")
    parser.add_argument(
        "--label", type=str, required=False,
        help="Override the label used in plot titles and as prefix of saved files.")
    features.add_argparse_args(parser)
    args = parser.parse_args()

    supported_features = ('HalfKAv2_hm', 'HalfKAv2_hm^')
    # Explicit check instead of ``assert``: asserts vanish under ``python -O``.
    if args.features not in supported_features:
        parser.error("Unsupported feature set '{}'; expected one of: {}".format(
            args.features, ', '.join(supported_features)))
    # The visualizer reads positions from this dataset; fail early without it.
    if args.data is None:
        parser.error("--data is required (path to a .bin or .binpack dataset)")
    feature_set = features.get_feature_set_from_name(args.features)

    from os.path import basename
    label = basename(args.model)

    model = load_model(args.model, feature_set)

    print("Visualizing {}".format(args.model))

    # Fall back to the model's file name when no explicit label was given.
    if args.label is None:
        args.label = label

    visualizer = NNUEVisualizer(model, args)

    visualizer.plot()

    if not args.dont_show:
        plt.show()
149+
if __name__ == '__main__':
150+
main()

0 commit comments

Comments
 (0)