Skip to content

Commit db9ebc0

Browse files
committed
add HGQ
1 parent 72bcb27 commit db9ebc0

File tree

2 files changed

+871
-0
lines changed

2 files changed

+871
-0
lines changed

nn_utils.py

Lines changed: 369 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,369 @@
1+
2+
3+
from io import BytesIO
4+
import json
5+
import os
6+
import random
7+
8+
from glob import glob
9+
from pathlib import Path
10+
import pickle as pkl
11+
from typing import Callable
12+
13+
import numpy as np
14+
15+
import tensorflow as tf
16+
from tensorflow import keras
17+
18+
from tqdm.auto import tqdm
19+
from matplotlib import pyplot as plt
20+
21+
import zstd
22+
import h5py as h5
23+
24+
from keras.src.saving.legacy import hdf5_format
25+
from keras.src.layers.convolutional.base_conv import Conv
26+
from keras.layers import Dense
27+
28+
from HGQ.bops import trace_minmax
29+
30+
31+
class NumpyFloatValuesEncoder(json.JSONEncoder):
32+
def default(self, obj):
33+
if isinstance(obj, np.float32): # type: ignore
34+
return float(obj)
35+
return json.JSONEncoder.default(self, obj)
36+
37+
38+
class SaveTopN(keras.callbacks.Callback):
39+
def __init__(self,
40+
metric_fn: Callable[[dict], float],
41+
n: int,
42+
path: str | Path,
43+
side: str = 'max',
44+
fname_format='epoch={epoch}-metric={metric:.4e}.h5',
45+
cond_fn: Callable[[dict], bool] = lambda x: True,
46+
):
47+
self.n = n
48+
self.metric_fn = metric_fn
49+
self.path = Path(path)
50+
self.fname_format = fname_format
51+
os.makedirs(path, exist_ok=True)
52+
self.weight_paths = np.full(n, '/dev/null', dtype=object)
53+
if side == 'max':
54+
self.best = np.full(n, -np.inf)
55+
self.side = np.greater
56+
elif side == 'min':
57+
self.best = np.full(n, np.inf)
58+
self.side = np.less
59+
self.cond = cond_fn
60+
61+
def on_epoch_end(self, epoch, logs=None):
62+
assert isinstance(logs, dict)
63+
assert isinstance(self.model, keras.models.Model)
64+
logs = logs.copy()
65+
logs['epoch'] = epoch
66+
if not self.cond(logs):
67+
return
68+
metric = self.metric_fn(logs)
69+
70+
if self.side(metric, self.best[-1]):
71+
try:
72+
os.remove(self.weight_paths[-1])
73+
except OSError:
74+
pass
75+
logs['metric'] = metric
76+
fname = self.path / self.fname_format.format(**logs)
77+
self.best[-1] = metric
78+
self.weight_paths[-1] = fname
79+
self.model.save_weights(fname)
80+
with h5.File(fname, 'r+') as f:
81+
log_str = json.dumps(logs, cls=NumpyFloatValuesEncoder)
82+
f.attrs['train_log'] = log_str
83+
idx = np.argsort(self.best)
84+
if self.side == np.greater:
85+
idx = idx[::-1]
86+
self.best = self.best[idx]
87+
self.weight_paths = self.weight_paths[idx]
88+
89+
def rename_ckpts(self, dataset, bsz=65536):
90+
assert self.weight_paths[0] != '/dev/null', 'No checkpoints to rename'
91+
assert isinstance(self.model, keras.models.Model)
92+
93+
weight_buf = BytesIO()
94+
with h5.File(weight_buf, 'w') as f:
95+
hdf5_format.save_weights_to_hdf5_group(f, self.model)
96+
weight_buf.seek(0)
97+
98+
for i, path in enumerate(tqdm(self.weight_paths, desc='Renaming checkpoints')):
99+
if path == '/dev/null':
100+
continue
101+
self.model.load_weights(path)
102+
bops = trace_minmax(self.model, dataset, bsz=bsz, verbose=False)
103+
with h5.File(path, 'r+') as f:
104+
logs = json.loads(f.attrs['train_log']) # type: ignore
105+
logs['bops'] = bops
106+
metric = self.metric_fn(logs)
107+
logs['metric'] = metric
108+
f.attrs['train_log'] = json.dumps(logs, cls=NumpyFloatValuesEncoder)
109+
self.best[i] = metric
110+
new_fname = self.path / self.fname_format.format(**logs)
111+
os.rename(path, new_fname)
112+
self.weight_paths[i] = new_fname
113+
114+
idx = np.argsort(self.best)
115+
self.best = self.best[idx]
116+
self.weight_paths = self.weight_paths[idx]
117+
with h5.File(weight_buf, 'r') as f:
118+
hdf5_format.load_weights_from_hdf5_group_by_name(f, self.model)
119+
120+
121+
class PBarCallback(tf.keras.callbacks.Callback):
122+
def __init__(self, metric='loss: {loss:.2f}/{val_loss:.2f}'):
123+
self.pbar = None
124+
self.template = metric
125+
126+
def on_epoch_begin(self, epoch, logs=None):
127+
if self.pbar is None:
128+
self.pbar = tqdm(total=self.params['epochs'], unit='epoch')
129+
130+
def on_epoch_end(self, epoch, logs=None):
131+
assert isinstance(self.pbar, tqdm)
132+
assert isinstance(logs, dict)
133+
self.pbar.update(1)
134+
string = self.template.format(**logs)
135+
if 'bops' in logs:
136+
string += f' - BOPs: {logs["bops"]:,.0f}'
137+
self.pbar.set_description(string)
138+
139+
def on_train_end(self, logs=None):
140+
if self.pbar is not None:
141+
self.pbar.close()
142+
143+
144+
def plot_history(histry: dict, metrics=('loss', 'val_loss'), ylabel='Loss', logy=False):
145+
fig, ax = plt.subplots()
146+
for metric in metrics:
147+
ax.plot(histry[metric], label=metric)
148+
ax.set_xlabel('Epoch')
149+
ax.set_ylabel(ylabel)
150+
if logy:
151+
ax.set_yscale('log')
152+
ax.legend()
153+
return fig, ax
154+
155+
156+
def save_model(model: keras.models.Model, path: str):
157+
_path = Path(path)
158+
model.save(path)
159+
if model.history is not None:
160+
history = model.history.history
161+
else:
162+
history = {}
163+
with open(_path.with_suffix('.history'), 'wb') as f:
164+
f.write(zstd.compress(pkl.dumps(history)))
165+
166+
167+
def load_model(path: str, co=None):
168+
_path = Path(path)
169+
model: keras.Model = keras.models.load_model(path, custom_objects=co) # type: ignore
170+
with open(_path.with_suffix('.history'), 'rb') as f:
171+
history: dict[str, list] = pkl.loads(zstd.decompress(f.read()))
172+
return model, history
173+
174+
175+
def save_history(history, path):
176+
with open(path, 'wb') as f:
177+
f.write(zstd.compress(pkl.dumps(history)))
178+
179+
180+
def load_history(path):
181+
with open(path, 'rb') as f:
182+
history = pkl.loads(zstd.decompress(f.read()))
183+
return history
184+
185+
186+
def absorb_batchNorm(model_target, model_original):
187+
for layer in model_target.layers:
188+
if layer.__class__.__name__ == 'Functional':
189+
absorb_batchNorm(layer, model_original.get_layer(layer.name))
190+
continue
191+
if (isinstance(layer, Dense) or isinstance(layer, Conv)) and \
192+
len(nodes := model_original.get_layer(layer.name)._outbound_nodes) > 0 and \
193+
isinstance(nodes[0].outbound_layer, keras.layers.BatchNormalization):
194+
_gamma, _beta, _mu, _var = model_original.get_layer(layer.name)._outbound_nodes[0].outbound_layer.get_weights()
195+
_ratio = _gamma / np.sqrt(0.001 + _var)
196+
_bias = -_gamma * _mu / np.sqrt(0.001 + _var) + _beta
197+
198+
k, *_b = model_original.get_layer(layer.name).get_weights()
199+
if _b:
200+
b = _b[0]
201+
else:
202+
b = np.zeros(layer.output_shape[-1])
203+
nk = np.einsum('...c, c-> ...c', k, _ratio, optimize=True)
204+
nb = np.einsum('...c, c-> ...c', b, _ratio, optimize=True) + _bias
205+
extras = layer.get_weights()[2:]
206+
layer.set_weights([nk, nb, *extras])
207+
elif hasattr(layer, 'kernel'):
208+
for w in layer.weights:
209+
if '_bw' not in w.name:
210+
break
211+
else:
212+
continue
213+
weights = layer.get_weights()
214+
new_weights = model_original.get_layer(layer.name).get_weights()
215+
l = len(new_weights)
216+
layer.set_weights([*new_weights, *weights[l:]][:len(weights)])
217+
218+
219+
def set_seed(seed):
220+
np.random.seed(seed)
221+
tf.random.set_seed(seed)
222+
os.environ['PYTHONHASHSEED'] = str(seed)
223+
random.seed(seed)
224+
225+
tf.config.experimental.enable_op_determinism()
226+
227+
228+
import h5py as h5
229+
import json
230+
231+
232+
def get_best_ckpt(save_path: Path, take_min=False):
233+
ckpts = list(save_path.glob('*.h5'))
234+
235+
def rank(ckpt: Path):
236+
with h5.File(ckpt, 'r') as f:
237+
log: dict = f.attrs['train_log'] # type: ignore
238+
log = json.loads(log) # type: ignore
239+
metric = log['metric'] # type: ignore
240+
return metric
241+
242+
ckpts = sorted(ckpts, key=rank, reverse=not take_min)
243+
ckpt = ckpts[0]
244+
return ckpt
245+
246+
247+
class PeratoFront(keras.callbacks.Callback):
248+
def __init__(self,
249+
path: str | Path,
250+
fname_format: str,
251+
metrics_names: list[str],
252+
sides: list[int],
253+
cond_fn: Callable[[dict], bool] = lambda x: True,
254+
):
255+
self.path = Path(path)
256+
self.fname_format = fname_format
257+
os.makedirs(path, exist_ok=True)
258+
self.paths = []
259+
self.metrics = []
260+
self.metric_names = metrics_names
261+
self.sides = np.array(sides)
262+
self.cond_fn = cond_fn
263+
264+
def on_epoch_end(self, epoch, logs=None):
265+
assert isinstance(self.model, keras.models.Model)
266+
assert isinstance(logs, dict)
267+
268+
logs = logs.copy()
269+
logs['epoch'] = epoch
270+
271+
if not self.cond_fn(logs):
272+
return
273+
new_metrics = np.array([logs[metric_name] for metric_name in self.metric_names])
274+
_rm_idx = []
275+
for i, old_metrics in enumerate(self.metrics):
276+
_old_metrics = self.sides * old_metrics
277+
_new_metrics = self.sides * new_metrics
278+
if np.all(_new_metrics <= _old_metrics):
279+
return
280+
if np.all(_new_metrics >= _old_metrics):
281+
_rm_idx.append(i)
282+
for i in _rm_idx[::-1]:
283+
self.metrics.pop(i)
284+
p = self.paths.pop(i)
285+
os.remove(p)
286+
287+
path = self.path / self.fname_format.format(**logs)
288+
self.metrics.append(new_metrics)
289+
self.paths.append(path)
290+
self.model.save_weights(self.paths[-1])
291+
292+
with h5.File(path, 'r+') as f:
293+
log_str = json.dumps(logs, cls=NumpyFloatValuesEncoder)
294+
f.attrs['train_log'] = log_str
295+
296+
def rename_ckpts(self, dataset, bsz=65536):
297+
assert isinstance(self.model, keras.models.Model)
298+
299+
weight_buf = BytesIO()
300+
with h5.File(weight_buf, 'w') as f:
301+
hdf5_format.save_weights_to_hdf5_group(f, self.model)
302+
weight_buf.seek(0)
303+
304+
for i, path in enumerate(tqdm(self.paths, desc='Renaming checkpoints')):
305+
self.model.load_weights(path)
306+
bops = trace_minmax(self.model, dataset, bsz=bsz, verbose=False)
307+
with h5.File(path, 'r+') as f:
308+
logs = json.loads(f.attrs['train_log']) # type: ignore
309+
logs['bops'] = bops
310+
f.attrs['train_log'] = json.dumps(logs, cls=NumpyFloatValuesEncoder)
311+
metrics = np.array([logs[metric_name] for metric_name in self.metric_names])
312+
self.metrics[i] = metrics
313+
new_fname = self.path / self.fname_format.format(**logs)
314+
os.rename(path, new_fname)
315+
self.paths[i] = new_fname
316+
317+
with h5.File(weight_buf, 'r') as f:
318+
hdf5_format.load_weights_from_hdf5_group_by_name(f, self.model)
319+
320+
321+
class BetaScheduler(keras.callbacks.Callback):
322+
def __init__(self, beta_fn: Callable[[int], float]):
323+
self.beta_fn = beta_fn
324+
325+
def on_epoch_begin(self, epoch, logs=None):
326+
assert isinstance(self.model, keras.models.Model)
327+
328+
beta = self.beta_fn(epoch)
329+
for layer in self.model.layers:
330+
if hasattr(layer, 'beta'):
331+
layer.beta.assign(keras.backend.constant(beta, dtype=keras.backend.floatx()))
332+
333+
def on_epoch_end(self, epoch, logs=None):
334+
assert isinstance(logs, dict)
335+
logs['beta'] = self.beta_fn(epoch)
336+
337+
@classmethod
338+
def from_config(cls, config):
339+
return cls(get_schedule(config.beta, config.train.epochs))
340+
341+
342+
def get_schedule(beta_conf, total_epochs):
343+
epochs = []
344+
betas = []
345+
interpolations = []
346+
for block in beta_conf.intervals:
347+
epochs.append(block.epochs)
348+
betas.append(block.betas)
349+
interpolation = block.interpolation
350+
assert interpolation in ['linear', 'log']
351+
interpolations.append(interpolation == 'log')
352+
epochs = np.array(epochs + [total_epochs])
353+
assert np.all(np.diff(epochs) >= 0)
354+
betas = np.array(betas)
355+
interpolations = np.array(interpolations)
356+
357+
def schedule(epoch):
358+
if epoch >= total_epochs:
359+
return betas[-1, -1]
360+
idx = np.searchsorted(epochs, epoch, side='right') - 1
361+
beta0, beta1 = betas[idx]
362+
epoch0, epoch1 = epochs[idx], epochs[idx + 1]
363+
if interpolations[idx]:
364+
beta = beta0 * (beta1 / beta0) ** ((epoch - epoch0) / (epoch1 - epoch0))
365+
else:
366+
beta = beta0 + (beta1 - beta0) * (epoch - epoch0) / (epoch1 - epoch0)
367+
return float(beta)
368+
369+
return schedule

0 commit comments

Comments
 (0)