diff --git a/paxml/learners.py b/paxml/learners.py
index 80c2aa825..e3e153905 100644
--- a/paxml/learners.py
+++ b/paxml/learners.py
@@ -185,9 +185,30 @@ def get_grad_tx(
         self._hparams.repeat_prefix_sep,
     )
 
+  def get_individual_grad_norms(
+      self,
+      raw_grads,
+      optimizer_name):
+    p = self._hparams
+    # Compute gradient norm.
+
+    if p.grad_norm_individual_vars:
+      grad_norms = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)), raw_grads)
+      var_keys = py_utils.extract_prefixed_keys_from_nested_map(grad_norms)
+
+      def add_grad_norm_summary(key, value):
+        base_layer.add_global_summary(
+            f'per_var_grad_norm/{optimizer_name}{key}',
+            value,
+            SummaryType.AGGREGATE_SCALAR,
+        )
+
+      jax.tree_map(add_grad_norm_summary, var_keys, grad_norms)
+
   def scale_gradients(
       self,
       raw_grads: NestedMap,
+      raw_grad_norm: JTensor,
       optimizer_name: Optional[str] = None,
       clip_gradient_norm_to_value: Optional[float] = None,
       clip_gradient_single_norm_to_value: Optional[float] = None,
@@ -209,57 +230,20 @@ def scale_gradients(
         have anomaly detected (e.g. Nan or Inf, or excessively big gradient
         norm) and should not be skipped.
     """
+
     p = self._hparams
+
     if optimizer_name is None:
       optimizer_name = ''
     else:
       optimizer_name = optimizer_name + '/'
+
     if clip_gradient_norm_to_value is None:
       clip_gradient_norm_to_value = p.optimizer.clip_gradient_norm_to_value
     if clip_gradient_single_norm_to_value is None:
       clip_gradient_single_norm_to_value = (
           p.optimizer.clip_gradient_single_norm_to_value
       )
-    # Compute gradient norm.
-
-    if p.grad_norm_individual_vars:
-      grad_norms = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)), raw_grads)
-      var_keys = py_utils.extract_prefixed_keys_from_nested_map(grad_norms)
-
-      def add_grad_norm_summary(key, value):
-        base_layer.add_global_summary(
-            f'per_var_grad_norm/{optimizer_name}{key}',
-            value,
-            SummaryType.AGGREGATE_SCALAR,
-        )
-
-      jax.tree_map(add_grad_norm_summary, var_keys, grad_norms)
-
-    if (
-        p.grad_norm_summary
-        or p.check_valid_step
-        or clip_gradient_norm_to_value
-        or clip_gradient_single_norm_to_value
-    ):
-      raw_grad_norm = _compute_grad_norm(raw_grads)
-      if p.grad_norm_summary:
-        base_layer.add_global_summary(
-            'learning/' + optimizer_name + 'raw_grad_norm',
-            raw_grad_norm,
-            SummaryType.AGGREGATE_SCALAR,
-        )
-    else:
-      raw_grad_norm = None
-
-    def keep_step(grad_norm):
-      keep_threshold = p.skip_step_gradient_norm_value
-      if keep_threshold:
-        return jnp.logical_and(
-            jnp.all(jnp.isfinite(grad_norm)),
-            jnp.all(jnp.less(grad_norm, keep_threshold)),
-        )
-      else:
-        return jnp.all(jnp.isfinite(grad_norm))
 
     def clip_grads(grads, grad_norm):
       if clip_gradient_norm_to_value:
@@ -288,17 +272,6 @@ def scale_gradient(grad, norm):
         grad_scale = jnp.array(1.0)
       return grads, grad_scale
 
-    if p.check_valid_step:
-      # Mark the step as invalid if any gradient anomaly is detected (e.g. Nan
-      # or Inf, or excessively big gradient norm).
-      valid_step = keep_step(raw_grad_norm)
-      base_layer.add_global_summary(
-          'learning/' + optimizer_name + 'is_valid_step',
-          valid_step.astype(jnp.float32),
-          SummaryType.AGGREGATE_SCALAR,
-      )
-    else:
-      valid_step = True
     grads, grad_scale = clip_grads(raw_grads, raw_grad_norm)
     base_layer.add_global_summary(
         'learning/' + optimizer_name + 'grad_scale',
@@ -313,7 +286,71 @@ def scale_gradient(grad, norm):
         clipped_grad_norm,
         SummaryType.AGGREGATE_SCALAR,
     )
-    return grads, valid_step  # pytype: disable=bad-return-type  # jax-ndarray
+    return grads  # pytype: disable=bad-return-type  # jax-ndarray
+
+
+  def get_grad_norm_valid_step(
+      self,
+      raw_grads,
+      optimizer_name: Optional[str] = None,
+      clip_gradient_norm_to_value: Optional[float] = None,
+      clip_gradient_single_norm_to_value: Optional[float] = None
+  ) -> Tuple[JTensor, JTensor]:
+
+    p = self._hparams
+
+    if optimizer_name is None:
+      optimizer_name = ''
+    else:
+      optimizer_name = optimizer_name + '/'
+    self.get_individual_grad_norms(raw_grads, optimizer_name)
+
+    if clip_gradient_norm_to_value is None:
+      clip_gradient_norm_to_value = p.optimizer.clip_gradient_norm_to_value
+    if clip_gradient_single_norm_to_value is None:
+      clip_gradient_single_norm_to_value = (
+          p.optimizer.clip_gradient_single_norm_to_value
+      )
+
+    if (
+        p.grad_norm_summary
+        or p.check_valid_step
+        or clip_gradient_norm_to_value
+        or clip_gradient_single_norm_to_value
+    ):
+      raw_grad_norm = _compute_grad_norm(raw_grads)
+      if p.grad_norm_summary:
+        base_layer.add_global_summary(
+            'learning/' + optimizer_name + 'raw_grad_norm',
+            raw_grad_norm,
+            SummaryType.AGGREGATE_SCALAR,
+        )
+    else:
+      raw_grad_norm = None
+
+    def keep_step(grad_norm):
+      keep_threshold = p.skip_step_gradient_norm_value
+      if keep_threshold:
+        return jnp.logical_and(
+            jnp.all(jnp.isfinite(grad_norm)),
+            jnp.all(jnp.less(grad_norm, keep_threshold)),
+        )
+      else:
+        return jnp.all(jnp.isfinite(grad_norm))
+
+    if p.check_valid_step:
+      # Mark the step as invalid if any gradient anomaly is detected (e.g. Nan
+      # or Inf, or excessively big gradient norm).
+      valid_step = keep_step(raw_grad_norm)
+      base_layer.add_global_summary(
+          'learning/' + optimizer_name + 'is_valid_step',
+          valid_step.astype(jnp.float32),
+          SummaryType.AGGREGATE_SCALAR,
+      )
+    else:
+      valid_step = True
+
+    return raw_grad_norm, valid_step
 
   def update_states(
       self,
@@ -335,7 +372,15 @@ def update_states(
     """
     p = self._hparams
 
-    grads, valid_step = self.scale_gradients(grads)
+    grad_norm, valid_step = self.get_grad_norm_valid_step(grads)
+
+    using_grad_accum = hasattr(p.optimizer, 'num_sub_batches')
+
+    # When using gradient accumulation, gradient scaling happens within the
+    # base optimizer update.
+    if not using_grad_accum:
+      grads = self.scale_gradients(grads, grad_norm)
+
     transformed_grad, new_states = self.get_grad_tx(var_weight_hparams).update(
         grads, states, old_vars
     )
@@ -357,6 +402,7 @@ def _update(updated, original):
       new_states = jax.tree_map(
          _update, new_states, states, is_leaf=py_utils.is_optax_masked_node
      )
 
+    # Final applied grad norm.
     if p.grad_norm_summary:
       applied_grad_norm = _compute_grad_norm(transformed_grad)
@@ -588,8 +634,16 @@ def scale_gradients_by_optimizer(
   ) -> Tuple[NestedMap, JTensor]:
     optimizer_mask, default_mask = self.get_masks(var_weight_hparams)
 
-    all_grads, all_valid_step = self.scale_gradients(
-        jax.tree_map(lambda x, y: x * y, raw_grads, default_mask),
+    grads_after_default_mask = jax.tree_map(lambda x, y: x * y, raw_grads, default_mask)
+
+    grad_norm, all_valid_step = self.get_grad_norm_valid_step(
+        grads_after_default_mask,
+        optimizer_name='main',
+    )
+
+    all_grads = self.scale_gradients(
+        grads_after_default_mask,
+        grad_norm,
         optimizer_name='main',
     )
 
@@ -600,8 +654,15 @@ def scale_gradients_by_optimizer(
     ):
       assert optimizer.clip_gradient_norm_to_value is not None
       assert optimizer.clip_gradient_single_norm_to_value is not None
-      grads, valid_step = self.scale_gradients(
-          jax.tree_map(lambda x, y: x * y, raw_grads, mask),
+
+      grads_after_mask = jax.tree_map(lambda x, y: x * y, raw_grads, mask)
+      grad_norm, valid_step = self.get_grad_norm_valid_step(
+          grads_after_mask,
+          optimizer_name=name,
+      )
+      grads = self.scale_gradients(
+          grads_after_mask,
+          grad_norm,
           optimizer_name=name,
           clip_gradient_norm_to_value=optimizer.clip_gradient_norm_to_value,
           clip_gradient_single_norm_to_value=optimizer.clip_gradient_single_norm_to_value,
@@ -633,7 +694,8 @@ def update_states(
           grads, var_weight_hparams
       )
     else:
-      grads, valid_step = self.scale_gradients(grads)
+      grad_norm, valid_step = self.get_grad_norm_valid_step(grads)
+      grads = self.scale_gradients(grads, grad_norm)
     grad_tx = self.get_grad_tx(var_weight_hparams)
     transformed_grad, new_states = grad_tx.update(grads, states, old_vars)
     if self._hparams.enable_skip_step_on_gradient_anomalies:
diff --git a/paxml/learners_test.py b/paxml/learners_test.py
index 92cc76bf9..45b6d50c1 100644
--- a/paxml/learners_test.py
+++ b/paxml/learners_test.py
@@ -64,7 +64,8 @@ def test_learner_clip_gradients(self, g1a, g1b, g2, global_clip_norm,
         grad2=jnp.array([g2], dtype=jnp.float32))
 
     with base_layer.JaxContext.new_context():
-      transformed_grads, _ = learner_instance.scale_gradients(grads)
+      grad_norm, valid_step = learner_instance.get_grad_norm_valid_step(grads)
+      transformed_grads = learner_instance.scale_gradients(grads, grad_norm)
 
     global_norm = np.linalg.norm([g1a, g1b, g2])
     local_norm1 = np.linalg.norm([g1a, g1b])
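
Usage note: the patch splits the old scale_gradients() into two steps, as exercised in the learners_test.py hunk above. Below is a minimal sketch of the new call sequence; learner_instance and grads are assumed to be constructed as in the test, and the variable names are illustrative only, not part of the patch.

    # Sketch only: mirrors the updated call pattern from learners_test.py.
    with base_layer.JaxContext.new_context():
      # Step 1: compute the raw global gradient norm and the step-validity flag.
      grad_norm, valid_step = learner_instance.get_grad_norm_valid_step(grads)
      # Step 2: clip/scale the gradients using the precomputed norm.
      transformed_grads = learner_instance.scale_gradients(grads, grad_norm)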