@@ -426,7 +426,7 @@ def test_gradient_norms_on_various_models(
 
 class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
 
-  # TODO(wkong): Test sparse input tensors when the GitHub CI environment
+  # TODO(weiweikong): Test sparse input tensors when the GitHub CI environment
   # supports them for embeddings.
   @parameterized.product(
       x_batch=[
@@ -541,5 +541,83 @@ def test_gradient_norms_on_various_models(
     self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
 
 
+class ClipGradsComputeClippedGradsAndOutputsTest(
+    tf.test.TestCase, parameterized.TestCase
+):
+
+  def setUp(self):
+    super().setUp()
+    dense_generator = lambda a, b: tf.keras.layers.Dense(b)
+    self._input_dim = 2
+    self._output_dim = 3
+    self._model = make_two_layer_sequential_model(
+        dense_generator, self._input_dim, self._output_dim
+    )
+
+  @parameterized.product(
+      batch_size=[1, 2, 10],
+      l2_norm_clip=[0.1, 1.0, 10],
+      is_eager=[True, False],
+      reduction=['auto', 'sum', 'sum_over_batch_size', 'none'],
+  )
+  def test_clipped_gradients_on_different_losses(
+      self, batch_size, l2_norm_clip, is_eager, reduction
+  ):
+    loss_fn = tf.keras.losses.MeanSquaredError(reduction=reduction)
+    self._model.compile(loss=loss_fn, run_eagerly=is_eager)
+    x_batch = tf.reshape(
+        tf.range(batch_size * self._input_dim, dtype=tf.float32),
+        [batch_size, -1],
+    )
+    y_batch = tf.reshape(
+        1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1]
+    )
+    # Stop early for efficiency.
+    if reduction == 'none':
+      self.assertRaises(
+          NotImplementedError,
+          # function tested
+          clip_grads.compute_clipped_gradients_and_outputs,
+          # function args
+          self._model,
+          x_batch,
+          y_batch,
+          l2_norm_clip,
+          layer_registry.make_default_layer_registry(),
+      )
+      return
+    # NOTE: losses from this point are scalar losses.
+    with tf.GradientTape() as tape:
+      y_pred = self._model(x_batch)
+      loss_value = loss_fn(y_pred, y_batch)
+    true_grads = tape.gradient(loss_value, self._model.trainable_variables)
+    clipped_grads, _, _ = clip_grads.compute_clipped_gradients_and_outputs(
+        self._model,
+        x_batch,
+        y_batch,
+        l2_norm_clip,
+        layer_registry.make_default_layer_registry(),
+    )
+
+    # Computes the L2 norm manually.
+    def compute_l2_norm(t):
+      sqr_sum_fn = lambda x: tf.reduce_sum(tf.square(x))
+      return tf.sqrt(tf.add_n(tf.nest.map_structure(sqr_sum_fn, t)))
+
+    true_norm = compute_l2_norm(true_grads)
+    computed_norm = compute_l2_norm(clipped_grads)
+    norm_bound = (
+        l2_norm_clip * batch_size if reduction == 'sum' else l2_norm_clip
+    )
+    if true_norm >= norm_bound:
+      # All of the per-example gradient norms should be less than the L2 norm
+      # clip value. Hence, by the triangle inequality, the gradient norm of the
+      # summed loss (averaged loss) should be less than the clip value times
+      # the batch size (just the clip value).
+      self.assertLessEqual(computed_norm, norm_bound)
+    else:
+      self.assertAlmostEqual(computed_norm, true_norm)
+
+
 if __name__ == '__main__':
   tf.test.main()
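
For readers skimming the new test, here is a minimal sketch of the property it checks, written outside the test harness. The import path `tensorflow_privacy.privacy.fast_gradient_clipping` and the hand-built two-layer Dense model are assumptions standing in for this file's own imports and its `make_two_layer_sequential_model` helper; the call to `clip_grads.compute_clipped_gradients_and_outputs` mirrors the argument order used in the test above. With a 'sum' loss reduction, clipping each per-example gradient to `l2_norm_clip` bounds the norm of the aggregated gradient by `l2_norm_clip * batch_size` (triangle inequality).

import tensorflow as tf

# Assumed import path for the modules exercised by the test.
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

batch_size, l2_norm_clip = 4, 1.0

# Hypothetical stand-in for make_two_layer_sequential_model(dense_generator, 2, 3).
model = tf.keras.Sequential([tf.keras.layers.Dense(3), tf.keras.layers.Dense(3)])
model.build(input_shape=(None, 2))
model.compile(loss=tf.keras.losses.MeanSquaredError(reduction='sum'))

x_batch = tf.reshape(tf.range(batch_size * 2, dtype=tf.float32), [batch_size, -1])
y_batch = tf.reshape(1.0 + tf.range(batch_size, dtype=tf.float32), [batch_size, -1])

# Same call and argument order as in the test; the first return value holds
# the clipped gradients aggregated over the batch.
clipped_grads, _, _ = clip_grads.compute_clipped_gradients_and_outputs(
    model,
    x_batch,
    y_batch,
    l2_norm_clip,
    layer_registry.make_default_layer_registry(),
)

# Each per-example gradient has norm at most l2_norm_clip, so the aggregated
# gradient can be at most batch_size times larger (with a little float slack).
sqr_sums = tf.nest.map_structure(lambda g: tf.reduce_sum(tf.square(g)), clipped_grads)
clipped_norm = tf.sqrt(tf.add_n(sqr_sums))
assert float(clipped_norm) <= l2_norm_clip * batch_size + 1e-6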