4242_NAME_STACK_PART_RE = re .compile (r"^(?P<wrapper>[A-Za-z_][A-Za-z0-9_]*)\((?P<inner>.*)\)$" )
4343_STAT_NAMES = ("norm" , "rms" , "rms_scaled" , "mean_abs" , "max_abs" , "max_abs_scaled" , "finite_fraction" )
4444BackwardFlowSite : TypeAlias = Literal ["in" , "out" ]
45- BackwardFlowTensorKind : TypeAlias = Literal ["activation" , "gradient" ]
46- BACKWARD_FLOW_SITE_IN : BackwardFlowSite = "in"
47- BACKWARD_FLOW_SITE_OUT : BackwardFlowSite = "out"
48- BACKWARD_FLOW_KIND_ACTIVATION : BackwardFlowTensorKind = "activation"
49- BACKWARD_FLOW_KIND_GRADIENT : BackwardFlowTensorKind = "gradient"
50- _FLOW_SITES = (BACKWARD_FLOW_SITE_IN , BACKWARD_FLOW_SITE_OUT )
51- _TENSOR_KINDS = (BACKWARD_FLOW_KIND_ACTIVATION , BACKWARD_FLOW_KIND_GRADIENT )
45+ _BackwardFlowTensorKind : TypeAlias = Literal ["activation" , "gradient" ]
46+ BWD_IN : BackwardFlowSite = "in"
47+ BWD_OUT : BackwardFlowSite = "out"
48+ _BWD_KIND_ACTIVATION : _BackwardFlowTensorKind = "activation"
49+ _BWD_KIND_GRADIENT : _BackwardFlowTensorKind = "gradient"
50+ _FLOW_SITES = (BWD_IN , BWD_OUT )
51+ _TENSOR_KINDS = (_BWD_KIND_ACTIVATION , _BWD_KIND_GRADIENT )
5252_FLOW_DIRECTIONS = ("tb" , "lr" )
5353_DEFAULT_PREFIX = "backward_flow"
5454_DEFAULT_RESIDUAL_GAIN_HORIZON = 50
@@ -202,7 +202,7 @@ def normalize_name_stack(name_stack: str) -> str:
202202 return "/" .join (parts )
203203
204204
205- def log_backward_activation (x : jax .Array , * , site : BackwardFlowSite = BACKWARD_FLOW_SITE_OUT ) -> jax .Array :
205+ def log_backward_activation (x : jax .Array , * , site : BackwardFlowSite = BWD_OUT ) -> jax .Array :
206206 """Return ``x`` unchanged while logging activation and backward-gradient scale when enabled."""
207207 context = _ACTIVE_CONTEXT .get ()
208208 if context is None :
@@ -219,9 +219,7 @@ def log_backward_activation(x: jax.Array, *, site: BackwardFlowSite = BACKWARD_F
219219 return _tagged_identity_with_scale (f"{ context .prefix } /{ name_stack } " , site , context .gradient_scale , x )
220220
221221
222- def trace_backward_activation (
223- x : jax .Array , name : str , * , site : BackwardFlowSite = BACKWARD_FLOW_SITE_OUT
224- ) -> jax .Array :
222+ def trace_backward_activation (x : jax .Array , name : str , * , site : BackwardFlowSite = BWD_OUT ) -> jax .Array :
225223 """Return ``x`` unchanged while logging under an extra JAX named scope."""
226224 if not name :
227225 raise ValueError ("name must be non-empty" )
@@ -236,18 +234,14 @@ def _tagged_identity(metric_prefix: str, site: BackwardFlowSite, x: jax.Array) -
236234
237235
def _tagged_identity_fwd(metric_prefix: str, site: BackwardFlowSite, x: jax.Array) -> tuple[jax.Array, None]:
    """Log activation statistics for ``x`` and pass it through unchanged.

    NOTE(review): presumably the forward rule of a ``jax.custom_vjp`` identity;
    the registration is not visible in this excerpt -- confirm.

    Args:
        metric_prefix: Prefix handed to ``_tensor_metrics`` for the metric keys.
        site: Flow site label (``"in"`` or ``"out"``) for the metric keys.
        x: Tensor to summarize and return.

    Returns:
        ``(x, None)`` -- the untouched input and an empty second element.
    """
    activation_metrics = _tensor_metrics(metric_prefix, x, site=site, kind=_BWD_KIND_ACTIVATION)
    # step=None is passed through as-is; the tracker decides how to place the entry.
    levanter.tracker.jit_log(activation_metrics, step=None)
    return x, None
243239
244240
def _tagged_identity_bwd(
    metric_prefix: str, site: BackwardFlowSite, _residual: None, cotangent: jax.Array
) -> tuple[jax.Array]:
    """Log gradient statistics for ``cotangent`` and pass it through unchanged.

    NOTE(review): presumably the backward rule paired with
    ``_tagged_identity_fwd``; the registration is not visible here -- confirm.

    Args:
        metric_prefix: Prefix handed to ``_tensor_metrics`` for the metric keys.
        site: Flow site label (``"in"`` or ``"out"``) for the metric keys.
        _residual: Unused; always ``None`` per the annotation.
        cotangent: Incoming gradient to summarize and return.

    Returns:
        A one-element tuple holding the unmodified ``cotangent``.
    """
    gradient_metrics = _tensor_metrics(metric_prefix, cotangent, site=site, kind=_BWD_KIND_GRADIENT)
    levanter.tracker.jit_log(gradient_metrics, step=None)
    return (cotangent,)
252246
253247
@@ -264,9 +258,7 @@ def _tagged_identity_with_scale(
def _tagged_identity_with_scale_fwd(
    metric_prefix: str, site: BackwardFlowSite, gradient_scale: jax.Array, x: jax.Array
) -> tuple[jax.Array, jax.Array]:
    """Log activation statistics for ``x`` and carry ``gradient_scale`` along.

    NOTE(review): presumably the forward rule of a ``jax.custom_vjp`` identity
    whose backward rule consumes ``gradient_scale`` -- confirm against the
    registration, which is outside this excerpt.

    Args:
        metric_prefix: Prefix handed to ``_tensor_metrics`` for the metric keys.
        site: Flow site label (``"in"`` or ``"out"``) for the metric keys.
        gradient_scale: Value returned untouched as the second tuple element.
        x: Tensor to summarize and return.

    Returns:
        ``(x, gradient_scale)`` with both values unmodified.
    """
    activation_metrics = _tensor_metrics(metric_prefix, x, site=site, kind=_BWD_KIND_ACTIVATION)
    levanter.tracker.jit_log(activation_metrics, step=None)
    return x, gradient_scale
271263
272264
@@ -281,7 +273,7 @@ def _tagged_identity_with_scale_bwd(
281273 metric_prefix ,
282274 cotangent ,
283275 site = site ,
284- kind = BACKWARD_FLOW_KIND_GRADIENT ,
276+ kind = _BWD_KIND_GRADIENT ,
285277 gradient_scale = gradient_scale ,
286278 ),
287279 step = None ,
@@ -519,12 +511,12 @@ def _tensor_metrics(
519511 tensor : jax .Array ,
520512 * ,
521513 site : BackwardFlowSite ,
522- kind : BackwardFlowTensorKind ,
514+ kind : _BackwardFlowTensorKind ,
523515 gradient_scale : jax .Array | None = None ,
524516) -> dict [str , jax .Array ]:
525517 summary = SummaryStats .from_tensor (tensor )
526518 metrics = summary .to_metrics (f"{ metric_prefix } /{ site } _{ kind } " )
527- if kind == BACKWARD_FLOW_KIND_GRADIENT and gradient_scale is not None :
519+ if kind == _BWD_KIND_GRADIENT and gradient_scale is not None :
528520 gradient_scale = jnp .asarray (gradient_scale , dtype = jnp .float32 )
529521 metrics [f"{ metric_prefix } /{ site } _{ kind } _rms_scaled" ] = summary .rms * gradient_scale
530522 metrics [f"{ metric_prefix } /{ site } _{ kind } _max_abs_scaled" ] = summary .max_abs * gradient_scale
@@ -1116,57 +1108,47 @@ def _is_supported_metric_name(metric_name: str) -> bool:
11161108
11171109
def _metric_value(
    stats: Mapping[str, float], site: BackwardFlowSite, kind: _BackwardFlowTensorKind, metric: str
) -> float | None:
    """Fetch the statistic stored under ``"{site}_{kind}_{metric}"``.

    Args:
        stats: Mapping from flattened metric names to values.
        site: Flow site label used as the first key component.
        kind: Tensor kind used as the second key component.
        metric: Statistic name used as the final key component.

    Returns:
        The stored value, or ``None`` when the key is absent.
    """
    key = f"{site}_{kind}_{metric}"
    return stats.get(key)
11221114
11231115
def _preferred_gradient_rms(stats: Mapping[str, float]) -> float | None:
    """Return the gradient RMS, favouring the scaled variant when available.

    ``"rms_scaled"`` is tried first and ``"rms"`` second; both lookups go
    through ``_preferred_metric`` with ``preferred_site=BWD_IN``.

    Returns:
        The first non-``None`` lookup result, or ``None`` if neither is found.
    """
    for metric in ("rms_scaled", "rms"):
        value = _preferred_metric(stats, _BWD_KIND_GRADIENT, metric, preferred_site=BWD_IN)
        if value is not None:
            return value
    return None
11291121
11301122
def _has_scaled_gradient_rms(stats: Mapping[str, float]) -> bool:
    """Report whether any flow site in ``stats`` carries a gradient ``rms_scaled`` value."""
    for site in _FLOW_SITES:
        if _metric_value(stats, site, _BWD_KIND_GRADIENT, "rms_scaled") is not None:
            return True
    return False
11351125
11361126
def _preferred_gradient_max_abs(stats: Mapping[str, float]) -> float | None:
    """Return the gradient max-abs, favouring the scaled variant when available.

    ``"max_abs_scaled"`` is tried first and ``"max_abs"`` second; both lookups
    go through ``_preferred_metric`` with ``preferred_site=BWD_IN``.

    Returns:
        The first non-``None`` lookup result, or ``None`` if neither is found.
    """
    for metric in ("max_abs_scaled", "max_abs"):
        value = _preferred_metric(stats, _BWD_KIND_GRADIENT, metric, preferred_site=BWD_IN)
        if value is not None:
            return value
    return None
11441132
11451133
def _has_scaled_gradient_max_abs(stats: Mapping[str, float]) -> bool:
    """Report whether any flow site in ``stats`` carries a gradient ``max_abs_scaled`` value."""
    for site in _FLOW_SITES:
        if _metric_value(stats, site, _BWD_KIND_GRADIENT, "max_abs_scaled") is not None:
            return True
    return False
11501136
11511137
def _preferred_activation_rms(stats: Mapping[str, float]) -> float | None:
    """Return the activation ``rms`` via ``_preferred_metric`` with ``preferred_site=BWD_OUT``."""
    activation_rms = _preferred_metric(stats, _BWD_KIND_ACTIVATION, "rms", preferred_site=BWD_OUT)
    return activation_rms
11541140
11551141
def _preferred_finite_fraction(stats: Mapping[str, float]) -> float | None:
    """Return a ``finite_fraction`` value, trying gradients before activations.

    The gradient lookup uses ``preferred_site=BWD_IN``; only when it yields
    ``None`` is the activation lookup (``preferred_site=BWD_OUT``) consulted.

    Returns:
        The gradient value if present, otherwise whatever the activation
        lookup returns (possibly ``None``).
    """
    from_gradient = _preferred_metric(stats, _BWD_KIND_GRADIENT, "finite_fraction", preferred_site=BWD_IN)
    if from_gradient is None:
        return _preferred_metric(stats, _BWD_KIND_ACTIVATION, "finite_fraction", preferred_site=BWD_OUT)
    return from_gradient
11651147
11661148
11671149def _preferred_metric (
11681150 stats : Mapping [str , float ],
1169- kind : BackwardFlowTensorKind ,
1151+ kind : _BackwardFlowTensorKind ,
11701152 metric : str ,
11711153 * ,
11721154 preferred_site : BackwardFlowSite ,
0 commit comments