@@ -225,7 +225,11 @@ def __init__(
225225 self .consts = consts
226226 self .in_tree = in_tree
227227 self .params_index = params_index
228- self .layer_tags , self .loss_tags = extract_tags (jaxpr )
228+ # clean jaxpr of layer tags at this point
229+ closed_jaxpr = jex .core .ClosedJaxpr (jaxpr = self .jaxpr , consts = self .consts )
230+ self .jaxpr , self .layer_tags = tgm .clean_layer_tags_jaxpr (closed_jaxpr )
231+ self .jaxpr = self .jaxpr .jaxpr
232+ _ , self .loss_tags = extract_tags (self .jaxpr )
229233 name_layer_tags (self .layer_tags )
230234 self .layer_tags , self .layer_indices = order_layer_tags (
231235 params_vars_flat = self .params_vars_flat ,
@@ -579,7 +583,7 @@ def _loss_tags_vjp(
579583 p_jaxpr : ProcessedJaxpr ,
580584 primal_func_args : FuncArgs ,
581585) -> LossTagsVjp :
582- """Computes a (backward-mode) vector-Jacobian product w.r.t. all loss tags.
586+ """Computes a (backward-mode) vector-Jacobian product for the vector of losses given by the loss tags.
583587
584588 The function has similar interface to :func:`jax.vjp`. It takes as inputs the
585589 concrete values of the primals at which the Jacobian will be evaluated. It
@@ -651,7 +655,7 @@ def _loss_tags_jvp(
651655 primal_func_args : FuncArgs ,
652656 params_tangents : Params ,
653657) -> LossTagsJvp :
654- """Computes a (forward-mode) Jacobian-vector product w.r.t. all loss tags.
658+ """Computes a (forward-mode) Jacobian-vector product for the losses given by the loss tags.
655659
656660 The function has similar interface to :func:`jax.jvp`. It takes as inputs the
656660 concrete values of the primals at which the Jacobian will be evaluated and
@@ -795,7 +799,7 @@ def _layer_tag_vjp(
795799 Args:
796800 processed_jaxpr: The :class:`~ProcessedJaxpr` representing the model
797801 function. This must include at least one loss tag.
798- primal_func_args: The primals at which to evaluate the Hessian .
802+ primal_func_args: The primals at which to evaluate the Jacobian .
799803
800804 Returns:
801805 The computed ``losses`` and ``vjp_func`` pair.
@@ -824,14 +828,12 @@ def forward() -> tuple[Array, ...]:
824828 # Loop through equations and evaluate them
825829 num_losses_passed = 0
826830 for eqn in processed_jaxpr .jaxpr .eqns :
827-
828- write (eqn .outvars , tgm .eval_jaxpr_eqn (eqn , read (eqn .invars )))
829-
830831 if isinstance (eqn .primitive , tags .LossTag ):
831832 num_losses_passed += 1
832833 if num_losses_passed == len (processed_jaxpr .loss_tags ):
833834 break
834-
835+ else :
836+ write (eqn .outvars , tgm .eval_jaxpr_eqn (eqn , read (eqn .invars )))
835837 assert num_losses_passed == len (processed_jaxpr .loss_tags )
836838
837839 return tuple (read (layer_input_vars ))
@@ -884,7 +886,6 @@ def write(variables: list[jex.core.Var], values: list[Array]) -> None:
884886 for eqn in processed_jaxpr .jaxpr .eqns :
885887
886888 input_values = read (eqn .invars )
887- write (eqn .outvars , tgm .eval_jaxpr_eqn (eqn , input_values ))
888889
889890 if isinstance (eqn .primitive , tags .LossTag ):
890891 loss : LossFunction = tags .loss_eqn_construct_loss (eqn , * input_values )
@@ -896,6 +897,8 @@ def write(variables: list[jex.core.Var], values: list[Array]) -> None:
896897
897898 if num_losses_passed == len (processed_jaxpr .loss_tags ):
898899 break
900+ else :
901+ write (eqn .outvars , tgm .eval_jaxpr_eqn (eqn , input_values ))
899902
900903 assert num_losses_passed == len (processed_jaxpr .loss_tags )
901904
0 commit comments