 from libs.nnet3.train.dropout_schedule import _get_dropout_proportions
 from model import get_chain_model
 from options import get_args
+from sgd_max_change import SgdMaxChange


 def get_objf(batch, model, device, criterion, opts, den_graph, training, optimizer=None, dropout=0.):
     feature, supervision = batch
@@ -67,20 +68,20 @@ def get_objf(batch, model, device, criterion, opts, den_graph, training, optimiz
                                         supervision, nnet_output,
                                         xent_output)
         objf = objf_l2_term_weight[0]
+        change = 0
         if training:
             optimizer.zero_grad()
             objf.backward()
-            clip_grad_value_(model.parameters(), 5.0)
-            optimizer.step()
+            # clip_grad_value_(model.parameters(), 5.0)
+            _, change = optimizer.step()

     objf_l2_term_weight = objf_l2_term_weight.detach().cpu()

     total_objf = objf_l2_term_weight[0].item()
     total_weight = objf_l2_term_weight[2].item()
     total_frames = nnet_output.shape[0]

-    return total_objf, total_weight, total_frames
-
+    return total_objf, total_weight, total_frames, change


 def get_validation_objf(dataloader, model, device, criterion, opts, den_graph):
     total_objf = 0.
@@ -90,7 +91,7 @@ def get_validation_objf(dataloader, model, device, criterion, opts, den_graph):
     model.eval()

     for batch_idx, (pseudo_epoch, batch) in enumerate(dataloader):
-        objf, weight, frames = get_objf(
+        objf, weight, frames, _ = get_objf(
             batch, model, device, criterion, opts, den_graph, False)
         total_objf += objf
         total_weight += weight
@@ -116,7 +117,7 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, crit
             len(dataloader)) / (len(dataloader) * num_epochs)
         _, dropout = _get_dropout_proportions(
             dropout_schedule, data_fraction)[0]
-        curr_batch_objf, curr_batch_weight, curr_batch_frames = get_objf(
+        curr_batch_objf, curr_batch_weight, curr_batch_frames, curr_batch_change = get_objf(
             batch, model, device, criterion, opts, den_graph, True, optimizer, dropout=dropout)

         total_objf += curr_batch_objf
@@ -127,13 +128,13 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, crit
             logging.info(
                 'Device ({}) processing batch {}, current pseudo-epoch is {}/{}({:.6f}%), '
                 'global average objf: {:.6f} over {} '
-                'frames, current batch average objf: {:.6f} over {} frames, epoch {}'
+                'frames, current batch average objf: {:.6f} over {} frames, minibatch change: {:.6f}, epoch {}'
                 .format(
                     device.index, batch_idx, pseudo_epoch, len(dataloader),
                     float(pseudo_epoch) / len(dataloader) * 100,
                     total_objf / total_weight, total_frames,
                     curr_batch_objf / curr_batch_weight,
-                    curr_batch_frames, current_epoch))
+                    curr_batch_frames, curr_batch_change, current_epoch))

         if valid_dataloader and batch_idx % 1000 == 0:
             total_valid_objf, total_valid_weight, total_valid_frames = get_validation_objf(
@@ -167,6 +168,11 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, crit
                 dropout,
                 pseudo_epoch + current_epoch * len(dataloader))

+            tf_writer.add_scalar(
+                'train/current_batch_change',
+                curr_batch_change,
+                pseudo_epoch + current_epoch * len(dataloader))
+
             state_dict = model.state_dict()
             for key, value in state_dict.items():
                 # skip batchnorm parameters
@@ -301,7 +307,7 @@ def process_job(learning_rate, device_id=None, local_rank=None):
     else:
         valid_dataloader = None

-    optimizer = optim.Adam(model.parameters(),
+    optimizer = SgdMaxChange(model.parameters(),
                            lr=learning_rate,
                            weight_decay=5e-4)
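The `SgdMaxChange` optimizer itself is not part of this diff. What the diff does establish is its interface: `step()` returns a pair unpacked as `_, change = optimizer.step()`, and `change` is logged as the "minibatch change" and written to TensorBoard. Below is a minimal sketch of what `sgd_max_change.py` plausibly contains, assuming it is plain SGD with a Kaldi-style max-change constraint: the L2 norm of each step's total parameter update is capped at `max_change`, and `change` is the norm of the update actually applied. The constructor signature (`lr`, `weight_decay`) matches the call in the last hunk; the `max_change` default and the global-norm capping strategy are assumptions, not taken from the PR.

```python
# Hypothetical sketch of sgd_max_change.py (assumed, not shown in this diff):
# SGD whose per-step parameter change is capped in L2 norm, Kaldi-style.
import torch
from torch.optim.optimizer import Optimizer


class SgdMaxChange(Optimizer):
    def __init__(self, params, lr=1e-3, weight_decay=0., max_change=0.75):
        defaults = dict(lr=lr, weight_decay=weight_decay, max_change=max_change)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        change = 0.
        for group in self.param_groups:
            lr = group['lr']
            wd = group['weight_decay']
            max_change = group['max_change']

            # Proposed update d_p = lr * (grad + wd * p), and its global norm.
            deltas = []
            norm_sq = 0.
            for p in group['params']:
                if p.grad is None:
                    deltas.append(None)
                    continue
                d_p = p.grad.clone()
                if wd != 0:
                    d_p.add_(p, alpha=wd)
                d_p.mul_(lr)
                deltas.append(d_p)
                norm_sq += float(d_p.pow(2).sum())
            norm = norm_sq ** 0.5

            # Scale the whole update down if it would exceed max_change.
            scale = min(1., max_change / (norm + 1e-20))
            for p, d_p in zip(group['params'], deltas):
                if d_p is not None:
                    p.add_(d_p, alpha=-scale)
            change += norm * scale

        # Matches the diff's calling convention: _, change = optimizer.step()
        return loss, change
```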
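A quick self-contained smoke test of the sketch above (the `Linear` model and squared loss here are placeholders, not from the PR), mirroring the training-loop pattern in `get_objf`:

```python
import torch

model = torch.nn.Linear(10, 2)
optimizer = SgdMaxChange(model.parameters(), lr=0.1, weight_decay=5e-4)

objf = model(torch.randn(4, 10)).pow(2).mean()
optimizer.zero_grad()
objf.backward()
_, change = optimizer.step()  # change <= max_change by construction
print(change)
```

Note the design choice visible in the diff: `clip_grad_value_` is commented out rather than kept alongside the new optimizer, since a max-change constraint limits the parameter update itself (after the learning rate is applied) instead of clipping raw gradients element-wise, and returning `change` from `step()` lets the loop log how strongly the constraint is biting (`train/current_batch_change` in TensorBoard).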