@@ -62,7 +62,7 @@ def update(self, current, values=None):
62
62
if k not in self .stateful_metrics :
63
63
if k not in self ._values :
64
64
self ._values [k ] = [v * (current - self ._seen_so_far ), current - self ._seen_so_far ]
65
- elif (self .smooth_interval is not None ) and (current % self .smooth_interval == 0 ):
65
+ elif (self .smooth_interval is not None ) and (current % self .smooth_interval == 1 ):
66
66
# 如果定义了累积smooth_interval,则需要重新累计
67
67
self ._values [k ] = [v , 1 ]
68
68
else :
@@ -435,7 +435,7 @@ def smooth_values(self, current, values=None):
435
435
if k not in self .stateful_metrics :
436
436
if k not in self ._values :
437
437
self ._values [k ] = [v * (current - self ._seen_so_far ), current - self ._seen_so_far ]
438
- elif (self .smooth_interval is not None ) and (current % self .smooth_interval == 0 ):
438
+ elif (self .smooth_interval is not None ) and (current % self .smooth_interval == 1 ):
439
439
# 如果定义了累积smooth_interval,则需要重新累计
440
440
self ._values [k ] = [v , 1 ]
441
441
else :
@@ -887,15 +887,13 @@ class Tensorboard(Callback):
887
887
:param method: str, 控制是按照epoch还是step来计算,默认为'epoch', 可选{'step', 'epoch'}
888
888
:param interval: int, 保存tensorboard的间隔
889
889
:param prefix: str, tensorboard分栏的前缀,默认为'train'
890
- :param on_epoch_end_scalar_epoch: bool, epoch结束后是横轴是按照epoch还是global_step来记录
891
890
'''
892
- def __init__ (self , log_dir , method = 'epoch' , interval = 10 , prefix = 'train' , on_epoch_end_scalar_epoch = True , ** kwargs ):
891
+ def __init__ (self , log_dir , method = 'epoch' , interval = 10 , prefix = 'train' , ** kwargs ):
893
892
super (Tensorboard , self ).__init__ (** kwargs )
894
893
assert method in {'step' , 'epoch' }, 'Args `method` only support `step` or `epoch`'
895
894
self .method = method
896
895
self .interval = interval
897
896
self .prefix = prefix + '/' if len (prefix .strip ()) > 0 else '' # 控制默认的前缀,用于区分栏目
898
- self .on_epoch_end_scalar_epoch = on_epoch_end_scalar_epoch # 控制on_epoch_end记录的是epoch还是global_step
899
897
900
898
from tensorboardX import SummaryWriter
901
899
os .makedirs (log_dir , exist_ok = True )
@@ -904,8 +902,7 @@ def __init__(self, log_dir, method='epoch', interval=10, prefix='train', on_epoc
904
902
def on_epoch_end (self , global_step , epoch , logs = None ):
905
903
if self .method == 'epoch' :
906
904
# 默认记录的是epoch
907
- log_step = epoch + 1 if self .on_epoch_end_scalar_epoch else global_step + 1
908
- self .process (log_step , logs )
905
+ self .process (epoch + 1 , logs )
909
906
910
907
def on_batch_end (self , global_step , local_step , logs = None ):
911
908
# 默认记录的是global_step
@@ -998,22 +995,22 @@ def __init__(self, receivers, subject='', method='epoch', interval=10, mail_host
998
995
999
996
def on_epoch_end (self , global_step , epoch , logs = None ):
1000
997
if self .method == 'epoch' :
1001
- msg = json .dumps ({k :f'{ v :.5f} ' for k ,v in logs .items () if k != 'size' }, indent = 2 , ensure_ascii = False )
998
+ msg = json .dumps ({k :f'{ v :.5f} ' for k ,v in logs .items () if k not in SKIP_METRICS }, indent = 2 , ensure_ascii = False )
1002
999
subject = f'[INFO] Epoch { epoch + 1 } performance'
1003
1000
if self .subject != '' :
1004
1001
subject = self .subject + ' | ' + subject
1005
1002
self ._email (subject , msg )
1006
1003
1007
1004
def on_batch_end (self , global_step , local_step , logs = None ):
1008
1005
if (self .method == 'step' ) and ((global_step + 1 ) % self .interval == 0 ):
1009
- msg = json .dumps ({k :f'{ v :.5f} ' for k ,v in logs .items () if k != 'size' }, indent = 2 , ensure_ascii = False )
1006
+ msg = json .dumps ({k :f'{ v :.5f} ' for k ,v in logs .items () if k not in SKIP_METRICS }, indent = 2 , ensure_ascii = False )
1010
1007
subject = f'[INFO] Step { global_step } performance'
1011
1008
if self .subject != '' :
1012
1009
subject = self .subject + ' | ' + subject
1013
1010
self ._email (subject , msg )
1014
1011
1015
1012
def on_train_end (self , logs = None ):
1016
- msg = json .dumps ({k :f'{ v :.5f} ' for k ,v in logs .items () if k != 'size' }, indent = 2 , ensure_ascii = False )
1013
+ msg = json .dumps ({k :f'{ v :.5f} ' for k ,v in logs .items () if k not in SKIP_METRICS }, indent = 2 , ensure_ascii = False )
1017
1014
subject = f'[INFO] Finish training'
1018
1015
if self .subject != '' :
1019
1016
subject = self .subject + ' | ' + subject
0 commit comments