Commit 5d297f1

Merge pull request #80 from calad0i/HGQ
Add a minimal HGQ Example
2 parents c3da9b4 + 95bf226 commit 5d297f1

2 files changed (+864, −0 lines)

nn_utils.py

Lines changed: 360 additions & 0 deletions
@@ -0,0 +1,360 @@
import json
import os
import pickle as pkl
import random
from io import BytesIO
from pathlib import Path
from typing import Callable

import h5py as h5
import numpy as np
import tensorflow as tf
import zstd
from HGQ.bops import trace_minmax
from keras.layers import Dense
from keras.src.layers.convolutional.base_conv import Conv
from keras.src.saving.legacy import hdf5_format
from matplotlib import pyplot as plt
from tensorflow import keras
from tqdm.auto import tqdm

class NumpyFloatValuesEncoder(json.JSONEncoder):
    """JSON encoder that casts numpy float32 values to plain Python floats."""

    def default(self, obj):
        if isinstance(obj, np.float32):  # type: ignore
            return float(obj)
        return json.JSONEncoder.default(self, obj)

class SaveTopN(keras.callbacks.Callback):
    """Keep the n best checkpoints seen during training, ranked by a user-supplied metric.

    Weights are saved as .h5 files; the epoch's log is attached to each file
    as a JSON string in the `train_log` HDF5 attribute.
    """

    def __init__(
        self,
        metric_fn: Callable[[dict], float],
        n: int,
        path: str | Path,
        side: str = 'max',
        fname_format='epoch={epoch}-metric={metric:.4e}.h5',
        cond_fn: Callable[[dict], bool] = lambda x: True,
    ):
        self.n = n
        self.metric_fn = metric_fn
        self.path = Path(path)
        self.fname_format = fname_format
        os.makedirs(path, exist_ok=True)
        self.weight_paths = np.full(n, '/dev/null', dtype=object)
        if side == 'max':
            self.best = np.full(n, -np.inf)
            self.side = np.greater
        elif side == 'min':
            self.best = np.full(n, np.inf)
            self.side = np.less
        else:
            raise ValueError(f"side must be 'max' or 'min', got {side!r}")
        self.cond = cond_fn

    def on_epoch_end(self, epoch, logs=None):
        assert isinstance(logs, dict)
        assert isinstance(self.model, keras.models.Model)
        logs = logs.copy()
        logs['epoch'] = epoch
        if not self.cond(logs):
            return
        metric = self.metric_fn(logs)

        # If the new metric beats the worst retained checkpoint, evict it.
        if self.side(metric, self.best[-1]):
            try:
                os.remove(self.weight_paths[-1])
            except OSError:
                pass
            logs['metric'] = metric
            fname = self.path / self.fname_format.format(**logs)
            self.best[-1] = metric
            self.weight_paths[-1] = fname
            self.model.save_weights(fname)
            with h5.File(fname, 'r+') as f:
                log_str = json.dumps(logs, cls=NumpyFloatValuesEncoder)
                f.attrs['train_log'] = log_str
            # Keep the retained checkpoints sorted from best to worst.
            idx = np.argsort(self.best)
            if self.side == np.greater:
                idx = idx[::-1]
            self.best = self.best[idx]
            self.weight_paths = self.weight_paths[idx]

    def rename_ckpts(self, dataset, bsz=65536):
        assert self.weight_paths[0] != '/dev/null', 'No checkpoints to rename'
        assert isinstance(self.model, keras.models.Model)

        # Stash the current weights in memory so they can be restored afterwards.
        weight_buf = BytesIO()
        with h5.File(weight_buf, 'w') as f:
            hdf5_format.save_weights_to_hdf5_group(f, self.model)
        weight_buf.seek(0)

        # Re-evaluate BOPs for each checkpoint, update its stored log and
        # metric, and rename the file to match.
        for i, path in enumerate(tqdm(self.weight_paths, desc='Renaming checkpoints')):
            if path == '/dev/null':
                continue
            self.model.load_weights(path)
            bops = trace_minmax(self.model, dataset, bsz=bsz, verbose=False)
            with h5.File(path, 'r+') as f:
                logs = json.loads(f.attrs['train_log'])  # type: ignore
                logs['bops'] = bops
                metric = self.metric_fn(logs)
                logs['metric'] = metric
                f.attrs['train_log'] = json.dumps(logs, cls=NumpyFloatValuesEncoder)
            self.best[i] = metric
            new_fname = self.path / self.fname_format.format(**logs)
            os.rename(path, new_fname)
            self.weight_paths[i] = new_fname

        idx = np.argsort(self.best)
        self.best = self.best[idx]
        self.weight_paths = self.weight_paths[idx]
        with h5.File(weight_buf, 'r') as f:
            hdf5_format.load_weights_from_hdf5_group_by_name(f, self.model)

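# Illustrative usage of SaveTopN (a sketch, not part of the original file;
# `model`, `X_train` and `y_train` are placeholders from the calling script):
#
#     cb = SaveTopN(metric_fn=lambda log: log['val_loss'], n=5,
#                   path='ckpts', side='min')
#     model.fit(X_train, y_train, validation_split=0.1,
#               epochs=100, callbacks=[cb])
#     best = get_best_ckpt(Path('ckpts'), take_min=True)  # defined below
#     model.load_weights(best)
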
class PBarCallback(tf.keras.callbacks.Callback):
    """Show a single tqdm bar over epochs, with the description rendered from a format string over the epoch's logs."""

    def __init__(self, metric='loss: {loss:.2f}/{val_loss:.2f}'):
        self.pbar = None
        self.template = metric

    def on_epoch_begin(self, epoch, logs=None):
        if self.pbar is None:
            self.pbar = tqdm(total=self.params['epochs'], unit='epoch')

    def on_epoch_end(self, epoch, logs=None):
        assert isinstance(self.pbar, tqdm)
        assert isinstance(logs, dict)
        self.pbar.update(1)
        string = self.template.format(**logs)
        if 'bops' in logs:
            string += f' - BOPs: {logs["bops"]:,.0f}'
        self.pbar.set_description(string)

    def on_train_end(self, logs=None):
        if self.pbar is not None:
            self.pbar.close()

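# Illustrative usage of PBarCallback (a sketch; pair it with `verbose=0` so
# Keras' own per-batch logging does not interleave with the bar):
#
#     pbar = PBarCallback('loss: {loss:.3f}/{val_loss:.3f}')
#     model.fit(X_train, y_train, verbose=0, callbacks=[pbar])
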
def plot_history(history: dict, metrics=('loss', 'val_loss'), ylabel='Loss', logy=False):
    fig, ax = plt.subplots()
    for metric in metrics:
        ax.plot(history[metric], label=metric)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    if logy:
        ax.set_yscale('log')
    ax.legend()
    return fig, ax

def save_model(model: keras.models.Model, path: str):
    _path = Path(path)
    model.save(path)
    if model.history is not None:
        history = model.history.history
    else:
        history = {}
    # The training history is stored next to the model as a zstd-compressed pickle.
    with open(_path.with_suffix('.history'), 'wb') as f:
        f.write(zstd.compress(pkl.dumps(history)))


def load_model(path: str, co=None):
    _path = Path(path)
    model: keras.Model = keras.models.load_model(path, custom_objects=co)  # type: ignore
    with open(_path.with_suffix('.history'), 'rb') as f:
        history: dict[str, list] = pkl.loads(zstd.decompress(f.read()))
    return model, history


def save_history(history, path):
    with open(path, 'wb') as f:
        f.write(zstd.compress(pkl.dumps(history)))


def load_history(path):
    with open(path, 'rb') as f:
        history = pkl.loads(zstd.decompress(f.read()))
    return history

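# Illustrative round trip (a sketch; 'model.h5' is a placeholder path, and
# `custom_objects` should carry the custom layers of the model being loaded,
# e.g. the HGQ layers for an HGQ model):
#
#     save_model(model, 'model.h5')            # model + '.history' sidecar
#     model, history = load_model('model.h5', co=custom_objects)
#     plot_history(history, metrics=('loss', 'val_loss'))
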
def absorb_batchNorm(model_target, model_original):
    for layer in model_target.layers:
        if layer.__class__.__name__ == 'Functional':
            absorb_batchNorm(layer, model_original.get_layer(layer.name))
            continue
        if (
            (isinstance(layer, Dense) or isinstance(layer, Conv))
            and len(nodes := model_original.get_layer(layer.name)._outbound_nodes) > 0
            and isinstance(nodes[0].outbound_layer, keras.layers.BatchNormalization)
        ):
            # Fold the BatchNormalization layer that follows this layer into
            # its kernel and bias (0.001 is Keras' default BN epsilon).
            _gamma, _beta, _mu, _var = model_original.get_layer(layer.name)._outbound_nodes[0].outbound_layer.get_weights()
            _ratio = _gamma / np.sqrt(0.001 + _var)
            _bias = -_gamma * _mu / np.sqrt(0.001 + _var) + _beta

            k, *_b = model_original.get_layer(layer.name).get_weights()
            if _b:
                b = _b[0]
            else:
                b = np.zeros(layer.output_shape[-1])
            nk = np.einsum('...c, c-> ...c', k, _ratio, optimize=True)
            nb = np.einsum('...c, c-> ...c', b, _ratio, optimize=True) + _bias
            extras = layer.get_weights()[2:]
            layer.set_weights([nk, nb, *extras])
        elif hasattr(layer, 'kernel'):
            # If every weight is a bitwidth variable ('_bw'), there is nothing
            # to copy; otherwise overwrite the leading weights with the
            # originals and keep the layer's extra (e.g. quantizer) weights.
            for w in layer.weights:
                if '_bw' not in w.name:
                    break
            else:
                continue
            weights = layer.get_weights()
            new_weights = model_original.get_layer(layer.name).get_weights()
            l = len(new_weights)  # noqa: E741 # If l looks like 1 by any chance, change your font.
            layer.set_weights([*new_weights, *weights[l:]][: len(weights)])

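# The folding above is the standard BatchNorm-fusion identity: with
# r = gamma / sqrt(var + eps), a BN layer applied to a linear layer's output
# satisfies
#
#     BN(W x + b) = r * (W x + b - mu) + beta
#                 = (r * W) x + (r * b + (beta - r * mu)),
#
# so the fused kernel is nk = r * W (scaled per output channel) and the fused
# bias is nb = r * b + _bias, matching the einsum expressions in the code.
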
def set_seed(seed):
    np.random.seed(seed)
    tf.random.set_seed(seed)
    # Note: PYTHONHASHSEED only affects hash randomization if set before the
    # interpreter starts; it is exported here for the benefit of subprocesses.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)

    tf.config.experimental.enable_op_determinism()

def get_best_ckpt(save_path: Path, take_min=False):
    ckpts = list(save_path.glob('*.h5'))
    assert ckpts, f'No checkpoints found in {save_path}'

    def rank(ckpt: Path):
        # Rank checkpoints by the metric recorded in the 'train_log' attribute.
        with h5.File(ckpt, 'r') as f:
            log: dict = f.attrs['train_log']  # type: ignore
            log = json.loads(log)  # type: ignore
            metric = log['metric']  # type: ignore
        return metric

    ckpts = sorted(ckpts, key=rank, reverse=not take_min)
    ckpt = ckpts[0]
    return ckpt

class PeratoFront(keras.callbacks.Callback):
    """Keep every checkpoint on the current Pareto front of the tracked metrics.

    `sides` gives the optimization direction per metric (+1 to maximize,
    -1 to minimize). A new checkpoint is saved unless it is dominated by an
    existing one; any checkpoints it dominates are deleted.
    """

    def __init__(
        self,
        path: str | Path,
        fname_format: str,
        metrics_names: list[str],
        sides: list[int],
        cond_fn: Callable[[dict], bool] = lambda x: True,
    ):
        self.path = Path(path)
        self.fname_format = fname_format
        os.makedirs(path, exist_ok=True)
        self.paths = []
        self.metrics = []
        self.metric_names = metrics_names
        self.sides = np.array(sides)
        self.cond_fn = cond_fn

    def on_epoch_end(self, epoch, logs=None):
        assert isinstance(self.model, keras.models.Model)
        assert isinstance(logs, dict)

        logs = logs.copy()
        logs['epoch'] = epoch

        if not self.cond_fn(logs):
            return
        new_metrics = np.array([logs[metric_name] for metric_name in self.metric_names])
        _rm_idx = []
        for i, old_metrics in enumerate(self.metrics):
            _old_metrics = self.sides * old_metrics
            _new_metrics = self.sides * new_metrics
            # Dominated by an existing checkpoint: skip saving.
            if np.all(_new_metrics <= _old_metrics):
                return
            # Dominates an existing checkpoint: mark it for removal.
            if np.all(_new_metrics >= _old_metrics):
                _rm_idx.append(i)
        for i in _rm_idx[::-1]:
            self.metrics.pop(i)
            p = self.paths.pop(i)
            os.remove(p)

        path = self.path / self.fname_format.format(**logs)
        self.metrics.append(new_metrics)
        self.paths.append(path)
        self.model.save_weights(self.paths[-1])

        with h5.File(path, 'r+') as f:
            log_str = json.dumps(logs, cls=NumpyFloatValuesEncoder)
            f.attrs['train_log'] = log_str

    def rename_ckpts(self, dataset, bsz=65536):
        assert isinstance(self.model, keras.models.Model)

        # Stash the current weights, re-trace BOPs for every saved checkpoint,
        # rewrite its log and filename, then restore the original weights.
        weight_buf = BytesIO()
        with h5.File(weight_buf, 'w') as f:
            hdf5_format.save_weights_to_hdf5_group(f, self.model)
        weight_buf.seek(0)

        for i, path in enumerate(tqdm(self.paths, desc='Renaming checkpoints')):
            self.model.load_weights(path)
            bops = trace_minmax(self.model, dataset, bsz=bsz, verbose=False)
            with h5.File(path, 'r+') as f:
                logs = json.loads(f.attrs['train_log'])  # type: ignore
                logs['bops'] = bops
                f.attrs['train_log'] = json.dumps(logs, cls=NumpyFloatValuesEncoder)
            metrics = np.array([logs[metric_name] for metric_name in self.metric_names])
            self.metrics[i] = metrics
            new_fname = self.path / self.fname_format.format(**logs)
            os.rename(path, new_fname)
            self.paths[i] = new_fname

        with h5.File(weight_buf, 'r') as f:
            hdf5_format.load_weights_from_hdf5_group_by_name(f, self.model)

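# Illustrative usage of PeratoFront (a sketch; it assumes a 'bops' entry in
# the logs, e.g. written by an HGQ BOPs-tracking callback, alongside standard
# Keras metric names):
#
#     pareto = PeratoFront(
#         path='pareto_ckpts',
#         fname_format='epoch={epoch}-acc={val_accuracy:.4f}-bops={bops:.0f}.h5',
#         metrics_names=['val_accuracy', 'bops'],
#         sides=[1, -1],  # maximize accuracy, minimize BOPs
#     )
#     model.fit(X_train, y_train, validation_split=0.1, callbacks=[pareto])
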
class BetaScheduler(keras.callbacks.Callback):
    """Set the `beta` attribute (e.g. HGQ's resource-regularization factor) of every layer that has one, once per epoch."""

    def __init__(self, beta_fn: Callable[[int], float]):
        self.beta_fn = beta_fn

    def on_epoch_begin(self, epoch, logs=None):
        assert isinstance(self.model, keras.models.Model)

        beta = self.beta_fn(epoch)
        for layer in self.model.layers:
            if hasattr(layer, 'beta'):
                layer.beta.assign(keras.backend.constant(beta, dtype=keras.backend.floatx()))

    def on_epoch_end(self, epoch, logs=None):
        assert isinstance(logs, dict)
        logs['beta'] = self.beta_fn(epoch)

    @classmethod
    def from_config(cls, config):
        return cls(get_schedule(config.beta, config.train.epochs))

def get_schedule(beta_conf, total_epochs):
    """Build an epoch -> beta function from a piecewise schedule config.

    Each block in `beta_conf.intervals` provides a start epoch, a
    (beta0, beta1) pair, and an interpolation mode ('linear' or 'log').
    """
    epochs = []
    betas = []
    interpolations = []
    for block in beta_conf.intervals:
        epochs.append(block.epochs)
        betas.append(block.betas)
        interpolation = block.interpolation
        assert interpolation in ['linear', 'log']
        interpolations.append(interpolation == 'log')
    epochs = np.array(epochs + [total_epochs])
    assert np.all(np.diff(epochs) >= 0)
    betas = np.array(betas)
    interpolations = np.array(interpolations)

    def schedule(epoch):
        if epoch >= total_epochs:
            return betas[-1, -1]
        idx = np.searchsorted(epochs, epoch, side='right') - 1
        beta0, beta1 = betas[idx]
        epoch0, epoch1 = epochs[idx], epochs[idx + 1]
        if interpolations[idx]:
            # 'log': geometric interpolation between beta0 and beta1.
            beta = beta0 * (beta1 / beta0) ** ((epoch - epoch0) / (epoch1 - epoch0))
        else:
            # 'linear' interpolation.
            beta = beta0 + (beta1 - beta0) * (epoch - epoch0) / (epoch1 - epoch0)
        return float(beta)

    return schedule
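

# Illustrative usage of BetaScheduler + get_schedule (a sketch; the config is
# mocked with SimpleNamespace here, while in practice it would come from the
# experiment's config file):
#
#     from types import SimpleNamespace as NS
#     conf = NS(intervals=[
#         NS(epochs=0, betas=(1e-8, 1e-8), interpolation='linear'),
#         NS(epochs=10, betas=(1e-8, 1e-5), interpolation='log'),
#     ])
#     beta_cb = BetaScheduler(get_schedule(conf, total_epochs=30))
#     model.fit(X_train, y_train, epochs=30, callbacks=[beta_cb])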
