1212from itertools import pairwise
1313import math
1414import re
15- from typing import Any , Iterable , Literal , Mapping
15+ from typing import Any , Iterable , Literal , Mapping , TypeAlias
1616
1717import jax
1818from jax ._src .core import Literal as JaxprLiteral
4141)
4242_NAME_STACK_PART_RE = re .compile (r"^(?P<wrapper>[A-Za-z_][A-Za-z0-9_]*)\((?P<inner>.*)\)$" )
4343_STAT_NAMES = ("norm" , "rms" , "rms_scaled" , "mean_abs" , "max_abs" , "max_abs_scaled" , "finite_fraction" )
44- _FLOW_SITES = ("in" , "out" )
45- _TENSOR_KINDS = ("activation" , "gradient" )
44+ BackwardFlowSite : TypeAlias = Literal ["in" , "out" ]
45+ BackwardFlowTensorKind : TypeAlias = Literal ["activation" , "gradient" ]
46+ BACKWARD_FLOW_SITE_IN : BackwardFlowSite = "in"
47+ BACKWARD_FLOW_SITE_OUT : BackwardFlowSite = "out"
48+ BACKWARD_FLOW_KIND_ACTIVATION : BackwardFlowTensorKind = "activation"
49+ BACKWARD_FLOW_KIND_GRADIENT : BackwardFlowTensorKind = "gradient"
50+ _FLOW_SITES = (BACKWARD_FLOW_SITE_IN , BACKWARD_FLOW_SITE_OUT )
51+ _TENSOR_KINDS = (BACKWARD_FLOW_KIND_ACTIVATION , BACKWARD_FLOW_KIND_GRADIENT )
4652_FLOW_DIRECTIONS = ("tb" , "lr" )
4753_DEFAULT_PREFIX = "backward_flow"
4854_DEFAULT_RESIDUAL_GAIN_HORIZON = 50
@@ -196,7 +202,7 @@ def normalize_name_stack(name_stack: str) -> str:
196202 return "/" .join (parts )
197203
198204
199- def log_backward_activation (x : jax .Array , * , site : str = "out" ) -> jax .Array :
205+ def log_backward_activation (x : jax .Array , * , site : BackwardFlowSite = BACKWARD_FLOW_SITE_OUT ) -> jax .Array :
200206 """Return ``x`` unchanged while logging activation and backward-gradient scale when enabled."""
201207 context = _ACTIVE_CONTEXT .get ()
202208 if context is None :
@@ -213,7 +219,9 @@ def log_backward_activation(x: jax.Array, *, site: str = "out") -> jax.Array:
213219 return _tagged_identity_with_scale (f"{ context .prefix } /{ name_stack } " , site , context .gradient_scale , x )
214220
215221
216- def trace_backward_activation (x : jax .Array , name : str , * , site : str = "out" ) -> jax .Array :
222+ def trace_backward_activation (
223+ x : jax .Array , name : str , * , site : BackwardFlowSite = BACKWARD_FLOW_SITE_OUT
224+ ) -> jax .Array :
217225 """Return ``x`` unchanged while logging under an extra JAX named scope."""
218226 if not name :
219227 raise ValueError ("name must be non-empty" )
@@ -223,43 +231,59 @@ def trace_backward_activation(x: jax.Array, name: str, *, site: str = "out") ->
223231
224232
@functools.partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def _tagged_identity(metric_prefix: str, site: BackwardFlowSite, x: jax.Array) -> jax.Array:
    """Return ``x`` unchanged; primal function of a ``jax.custom_vjp`` identity.

    ``metric_prefix`` and ``site`` are marked non-differentiable
    (``nondiff_argnums=(0, 1)``), so only ``x`` takes part in differentiation.
    The function itself does no logging; that happens in the forward and
    backward rules attached with ``defvjp``.

    Args:
        metric_prefix: Prefix under which the attached rules emit metrics.
        site: Flow site label (``"in"`` or ``"out"``) passed to the rules.
        x: Tensor passed through unchanged.

    Returns:
        ``x`` itself.
    """
    return x
228236
229237
def _tagged_identity_fwd(metric_prefix: str, site: BackwardFlowSite, x: jax.Array) -> tuple[jax.Array, None]:
    """Forward rule for ``_tagged_identity``: log activation metrics for ``x``.

    Args:
        metric_prefix: Prefix handed to ``_tensor_metrics`` for metric names.
        site: Flow site label (``"in"`` or ``"out"``).
        x: Tensor whose metrics are logged and which is returned unchanged.

    Returns:
        ``(x, None)``: the primal output and an empty residual, since the
        backward rule needs nothing saved from the forward pass.
    """
    levanter.tracker.jit_log(
        _tensor_metrics(metric_prefix, x, site=site, kind=BACKWARD_FLOW_KIND_ACTIVATION), step=None
    )
    return x, None
233243
234244
def _tagged_identity_bwd(
    metric_prefix: str, site: BackwardFlowSite, _residual: None, cotangent: jax.Array
) -> tuple[jax.Array]:
    """Backward rule for ``_tagged_identity``: log gradient metrics, pass the cotangent on.

    Args:
        metric_prefix: Prefix handed to ``_tensor_metrics`` for metric names.
        site: Flow site label (``"in"`` or ``"out"``).
        _residual: Unused; the forward rule saves ``None``.
        cotangent: Incoming gradient with respect to the identity's output.

    Returns:
        A 1-tuple holding ``cotangent`` unchanged, the gradient for ``x``.
    """
    levanter.tracker.jit_log(
        _tensor_metrics(metric_prefix, cotangent, site=site, kind=BACKWARD_FLOW_KIND_GRADIENT), step=None
    )
    # custom_vjp expects one cotangent per differentiable input; only ``x`` is differentiable here.
    return (cotangent,)
238252
239253
# Attach the logging forward/backward rules to the identity primal.
_tagged_identity.defvjp(_tagged_identity_fwd, _tagged_identity_bwd)
241255
242256
@functools.partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def _tagged_identity_with_scale(
    metric_prefix: str, site: BackwardFlowSite, gradient_scale: jax.Array, x: jax.Array
) -> jax.Array:
    """Return ``x`` unchanged; primal of a ``jax.custom_vjp`` identity that carries a gradient scale.

    ``metric_prefix`` and ``site`` are non-differentiable
    (``nondiff_argnums=(0, 1)``). ``gradient_scale`` and ``x`` are both
    differentiable inputs as far as ``custom_vjp`` is concerned, so the
    backward rule must return a cotangent for each of them.

    Args:
        metric_prefix: Prefix under which the attached rules emit metrics.
        site: Flow site label (``"in"`` or ``"out"``).
        gradient_scale: Scale made available to the backward rule; it does not
            affect the returned value.
        x: Tensor passed through unchanged.

    Returns:
        ``x`` itself.
    """
    return x
246262
247263
def _tagged_identity_with_scale_fwd(
    metric_prefix: str, site: BackwardFlowSite, gradient_scale: jax.Array, x: jax.Array
) -> tuple[jax.Array, jax.Array]:
    """Forward rule for ``_tagged_identity_with_scale``: log activation metrics for ``x``.

    Args:
        metric_prefix: Prefix handed to ``_tensor_metrics`` for metric names.
        site: Flow site label (``"in"`` or ``"out"``).
        gradient_scale: Scale saved as the residual for the backward rule.
        x: Tensor whose metrics are logged and which is returned unchanged.

    Returns:
        ``(x, gradient_scale)``: the primal output and the residual.
    """
    levanter.tracker.jit_log(
        _tensor_metrics(metric_prefix, x, site=site, kind=BACKWARD_FLOW_KIND_ACTIVATION), step=None
    )
    # The scale is not used on the forward pass; it is saved so the backward rule can log scaled statistics.
    return x, gradient_scale
253271
254272
def _tagged_identity_with_scale_bwd(
    metric_prefix: str,
    site: BackwardFlowSite,
    gradient_scale: jax.Array,
    cotangent: jax.Array,
) -> tuple[jax.Array, jax.Array]:
    """Backward rule for ``_tagged_identity_with_scale``: log gradient metrics and pass the cotangent on.

    Args:
        metric_prefix: Prefix handed to ``_tensor_metrics`` for metric names.
        site: Flow site label (``"in"`` or ``"out"``).
        gradient_scale: Residual saved by the forward rule; forwarded to
            ``_tensor_metrics`` along with the cotangent.
        cotangent: Incoming gradient with respect to the identity's output.

    Returns:
        ``(zeros_like(gradient_scale), cotangent)``: the cotangents for the
        two differentiable inputs ``gradient_scale`` and ``x``, in that order.
    """
    levanter.tracker.jit_log(
        _tensor_metrics(
            metric_prefix,
            cotangent,
            site=site,
            kind=BACKWARD_FLOW_KIND_GRADIENT,
            gradient_scale=gradient_scale,
        ),
        step=None,
    )
    # No gradient flows into the scale itself; the cotangent for ``x`` passes through untouched.
    return jnp.zeros_like(gradient_scale), cotangent
@@ -494,13 +518,13 @@ def _tensor_metrics(
494518 metric_prefix : str ,
495519 tensor : jax .Array ,
496520 * ,
497- site : str ,
498- kind : str ,
521+ site : BackwardFlowSite ,
522+ kind : BackwardFlowTensorKind ,
499523 gradient_scale : jax .Array | None = None ,
500524) -> dict [str , jax .Array ]:
501525 summary = SummaryStats .from_tensor (tensor )
502526 metrics = summary .to_metrics (f"{ metric_prefix } /{ site } _{ kind } " )
503- if kind == "gradient" and gradient_scale is not None :
527+ if kind == BACKWARD_FLOW_KIND_GRADIENT and gradient_scale is not None :
504528 gradient_scale = jnp .asarray (gradient_scale , dtype = jnp .float32 )
505529 metrics [f"{ metric_prefix } /{ site } _{ kind } _rms_scaled" ] = summary .rms * gradient_scale
506530 metrics [f"{ metric_prefix } /{ site } _{ kind } _max_abs_scaled" ] = summary .max_abs * gradient_scale
@@ -1091,49 +1115,61 @@ def _is_supported_metric_name(metric_name: str) -> bool:
10911115 return False
10921116
10931117
def _metric_value(
    stats: Mapping[str, float], site: BackwardFlowSite, kind: BackwardFlowTensorKind, metric: str
) -> float | None:
    """Look up one statistic in ``stats`` by its ``{site}_{kind}_{metric}`` key.

    Args:
        stats: Mapping from composed metric keys to values.
        site: Flow site label (``"in"`` or ``"out"``).
        kind: Tensor kind (``"activation"`` or ``"gradient"``).
        metric: Statistic suffix, e.g. ``"rms"``.

    Returns:
        The stored value, or ``None`` when the key is absent.
    """
    return stats.get(f"{site}_{kind}_{metric}")
10961122
10971123
def _preferred_gradient_rms(stats: Mapping[str, float]) -> float | None:
    """Return the gradient RMS, taking the scaled statistic when one is recorded.

    Both lookups go through ``_preferred_metric`` with the ``"in"`` site
    preferred. ``rms_scaled`` is tried first; plain ``rms`` is used only when
    that lookup yields ``None``.

    Args:
        stats: Mapping from composed metric keys to values.

    Returns:
        The selected RMS value, or ``None`` when neither statistic is found.
    """
    scaled = _preferred_metric(stats, BACKWARD_FLOW_KIND_GRADIENT, "rms_scaled", preferred_site=BACKWARD_FLOW_SITE_IN)
    if scaled is not None:
        return scaled
    return _preferred_metric(stats, BACKWARD_FLOW_KIND_GRADIENT, "rms", preferred_site=BACKWARD_FLOW_SITE_IN)
11031129
11041130
def _has_scaled_gradient_rms(stats: Mapping[str, float]) -> bool:
    """Report whether ``stats`` holds a ``rms_scaled`` gradient statistic at any flow site.

    Args:
        stats: Mapping from composed metric keys to values.

    Returns:
        ``True`` if ``{site}_gradient_rms_scaled`` is present for at least one
        site in ``_FLOW_SITES``, otherwise ``False``.
    """
    return any(
        _metric_value(stats, site, BACKWARD_FLOW_KIND_GRADIENT, "rms_scaled") is not None for site in _FLOW_SITES
    )
11071135
11081136
def _preferred_gradient_max_abs(stats: Mapping[str, float]) -> float | None:
    """Return the gradient max-abs, taking the scaled statistic when one is recorded.

    Both lookups go through ``_preferred_metric`` with the ``"in"`` site
    preferred. ``max_abs_scaled`` is tried first; plain ``max_abs`` is used
    only when that lookup yields ``None``.

    Args:
        stats: Mapping from composed metric keys to values.

    Returns:
        The selected max-abs value, or ``None`` when neither statistic is found.
    """
    scaled = _preferred_metric(
        stats, BACKWARD_FLOW_KIND_GRADIENT, "max_abs_scaled", preferred_site=BACKWARD_FLOW_SITE_IN
    )
    if scaled is not None:
        return scaled
    return _preferred_metric(stats, BACKWARD_FLOW_KIND_GRADIENT, "max_abs", preferred_site=BACKWARD_FLOW_SITE_IN)
11141144
11151145
def _has_scaled_gradient_max_abs(stats: Mapping[str, float]) -> bool:
    """Report whether ``stats`` holds a ``max_abs_scaled`` gradient statistic at any flow site.

    Args:
        stats: Mapping from composed metric keys to values.

    Returns:
        ``True`` if ``{site}_gradient_max_abs_scaled`` is present for at least
        one site in ``_FLOW_SITES``, otherwise ``False``.
    """
    return any(
        _metric_value(stats, site, BACKWARD_FLOW_KIND_GRADIENT, "max_abs_scaled") is not None for site in _FLOW_SITES
    )
11181150
11191151
def _preferred_activation_rms(stats: Mapping[str, float]) -> float | None:
    """Return the activation RMS, looked up via ``_preferred_metric`` with the ``"out"`` site preferred.

    Args:
        stats: Mapping from composed metric keys to values.

    Returns:
        The selected RMS value, or ``None`` when the lookup finds nothing.
    """
    return _preferred_metric(stats, BACKWARD_FLOW_KIND_ACTIVATION, "rms", preferred_site=BACKWARD_FLOW_SITE_OUT)
11221154
11231155
def _preferred_finite_fraction(stats: Mapping[str, float]) -> float | None:
    """Return the finite-fraction statistic, taking the gradient value over the activation value.

    The gradient ``finite_fraction`` is looked up first with the ``"in"`` site
    preferred. Only when that yields ``None`` is the activation
    ``finite_fraction`` looked up, with the ``"out"`` site preferred.

    Args:
        stats: Mapping from composed metric keys to values.

    Returns:
        The selected finite fraction, or ``None`` when neither lookup finds one.
    """
    gradient_fraction = _preferred_metric(
        stats, BACKWARD_FLOW_KIND_GRADIENT, "finite_fraction", preferred_site=BACKWARD_FLOW_SITE_IN
    )
    if gradient_fraction is not None:
        return gradient_fraction
    return _preferred_metric(
        stats, BACKWARD_FLOW_KIND_ACTIVATION, "finite_fraction", preferred_site=BACKWARD_FLOW_SITE_OUT
    )
11291165
11301166
11311167def _preferred_metric (
11321168 stats : Mapping [str , float ],
1133- kind : str ,
1169+ kind : BackwardFlowTensorKind ,
11341170 metric : str ,
11351171 * ,
1136- preferred_site : str ,
1172+ preferred_site : BackwardFlowSite ,
11371173) -> float | None :
11381174 preferred = _metric_value (stats , preferred_site , kind , metric )
11391175 if preferred is not None :
0 commit comments