1
-
2
-
3
- from io import BytesIO
4
1
import json
5
2
import os
3
+ import pickle as pkl
6
4
import random
7
-
8
5
from glob import glob
6
+ from io import BytesIO
9
7
from pathlib import Path
10
- import pickle as pkl
11
8
from typing import Callable
12
9
10
+ import h5py as h5
13
11
import numpy as np
14
-
15
12
import tensorflow as tf
16
- from tensorflow import keras
17
-
18
- from tqdm .auto import tqdm
19
- from matplotlib import pyplot as plt
20
-
21
13
import zstd
22
- import h5py as h5
23
-
24
- from keras .src .saving .legacy import hdf5_format
25
- from keras .src .layers .convolutional .base_conv import Conv
26
- from keras .layers import Dense
27
-
28
14
from HGQ .bops import trace_minmax
15
+ from keras .layers import Dense
16
+ from keras .src .layers .convolutional .base_conv import Conv
17
+ from keras .src .saving .legacy import hdf5_format
18
+ from matplotlib import pyplot as plt
19
+ from tensorflow import keras
20
+ from tqdm .auto import tqdm
29
21
30
22
31
23
class NumpyFloatValuesEncoder (json .JSONEncoder ):
@@ -36,14 +28,15 @@ def default(self, obj):
36
28
37
29
38
30
class SaveTopN (keras .callbacks .Callback ):
39
- def __init__ (self ,
40
- metric_fn : Callable [[dict ], float ],
41
- n : int ,
42
- path : str | Path ,
43
- side : str = 'max' ,
44
- fname_format = 'epoch={epoch}-metric={metric:.4e}.h5' ,
45
- cond_fn : Callable [[dict ], bool ] = lambda x : True ,
46
- ):
31
+ def __init__ (
32
+ self ,
33
+ metric_fn : Callable [[dict ], float ],
34
+ n : int ,
35
+ path : str | Path ,
36
+ side : str = 'max' ,
37
+ fname_format = 'epoch={epoch}-metric={metric:.4e}.h5' ,
38
+ cond_fn : Callable [[dict ], bool ] = lambda x : True ,
39
+ ):
47
40
self .n = n
48
41
self .metric_fn = metric_fn
49
42
self .path = Path (path )
@@ -188,9 +181,11 @@ def absorb_batchNorm(model_target, model_original):
188
181
if layer .__class__ .__name__ == 'Functional' :
189
182
absorb_batchNorm (layer , model_original .get_layer (layer .name ))
190
183
continue
191
- if (isinstance (layer , Dense ) or isinstance (layer , Conv )) and \
192
- len (nodes := model_original .get_layer (layer .name )._outbound_nodes ) > 0 and \
193
- isinstance (nodes [0 ].outbound_layer , keras .layers .BatchNormalization ):
184
+ if (
185
+ (isinstance (layer , Dense ) or isinstance (layer , Conv ))
186
+ and len (nodes := model_original .get_layer (layer .name )._outbound_nodes ) > 0
187
+ and isinstance (nodes [0 ].outbound_layer , keras .layers .BatchNormalization )
188
+ ):
194
189
_gamma , _beta , _mu , _var = model_original .get_layer (layer .name )._outbound_nodes [0 ].outbound_layer .get_weights ()
195
190
_ratio = _gamma / np .sqrt (0.001 + _var )
196
191
_bias = - _gamma * _mu / np .sqrt (0.001 + _var ) + _beta
@@ -213,7 +208,7 @@ def absorb_batchNorm(model_target, model_original):
213
208
weights = layer .get_weights ()
214
209
new_weights = model_original .get_layer (layer .name ).get_weights ()
215
210
l = len (new_weights )
216
- layer .set_weights ([* new_weights , * weights [l :]][:len (weights )])
211
+ layer .set_weights ([* new_weights , * weights [l :]][: len (weights )])
217
212
218
213
219
214
def set_seed (seed ):
@@ -225,9 +220,10 @@ def set_seed(seed):
225
220
tf .config .experimental .enable_op_determinism ()
226
221
227
222
228
- import h5py as h5
229
223
import json
230
224
225
+ import h5py as h5
226
+
231
227
232
228
def get_best_ckpt (save_path : Path , take_min = False ):
233
229
ckpts = list (save_path .glob ('*.h5' ))
@@ -245,13 +241,14 @@ def rank(ckpt: Path):
245
241
246
242
247
243
class PeratoFront (keras .callbacks .Callback ):
248
- def __init__ (self ,
249
- path : str | Path ,
250
- fname_format : str ,
251
- metrics_names : list [str ],
252
- sides : list [int ],
253
- cond_fn : Callable [[dict ], bool ] = lambda x : True ,
254
- ):
244
+ def __init__ (
245
+ self ,
246
+ path : str | Path ,
247
+ fname_format : str ,
248
+ metrics_names : list [str ],
249
+ sides : list [int ],
250
+ cond_fn : Callable [[dict ], bool ] = lambda x : True ,
251
+ ):
255
252
self .path = Path (path )
256
253
self .fname_format = fname_format
257
254
os .makedirs (path , exist_ok = True )
0 commit comments