Add Grug backward-flow logging #5036

Open
dlwh wants to merge 1 commit into main from codex/grug-backward-flow-logging

dlwh (Member) commented Apr 22, 2026

Add sampled backward-flow probes for Grug, a reusable Levanter analysis renderer, and W&B HTML logging. The run path logs activation and cotangent summaries and renders residual-stream DAG artifacts; the probes stay disabled by default on the normal train-step path. Adds design and recipe docs plus regression tests. ArrayStacked support is tracked separately in #5030.
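
For illustration, a minimal sketch of an interval-based switch that is off by default. SampledProbeConfig and should_sample are hypothetical names chosen for this example; this is an assumption about the general shape, not the PR's BackwardFlowConfig.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SampledProbeConfig:
    """Hypothetical config: interval=0 means probing is disabled."""

    interval: int = 0

    @property
    def is_enabled(self) -> bool:
        # Only a positive interval turns probing on.
        return self.interval > 0

    def should_sample(self, step: int) -> bool:
        # Sample every `interval` steps; never sample when disabled.
        return self.is_enabled and step % self.interval == 0


default_cfg = SampledProbeConfig()
sampled_cfg = SampledProbeConfig(interval=4)
sampled_steps = [s for s in range(10) if sampled_cfg.should_sample(s)]
```

With the default config no step is sampled, so a caller that checks should_sample before doing any probe work pays nothing on ordinary steps.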

dlwh added the agent-generated label Apr 22, 2026
claude (Bot, Contributor) commented Apr 22, 2026

Claude finished @dlwh's task in 10m 41s.


Code Review

Reviewed the diff for bugs and CLAUDE.md/AGENTS.md compliance. One issue found and posted as an inline comment.

Findings

  • ⚠️ _is_literal ad-hoc compatibility hack: lib/levanter/src/levanter/analysis/backward_flow.py:1160-1170 layers three fallbacks (getattr(jax.core, "Literal", None), type(...).__name__ == "Literal", hash-based TypeError probe) for what should be a direct isinstance(value, jax.core.Literal) check. This violates AGENTS.md:63 ("No ad-hoc compatibility hacks").

Considered but not flagged (reviewed and dismissed as false positives):

  • __all__ in lib/levanter/src/levanter/analysis/__init__.py — pre-existing, pervasive pattern across 13+ levanter __init__.py files; this PR only extends the existing list.
  • TYPE_CHECKING guard in lib/levanter/src/levanter/tracker/wandb.py — pre-existing, not added in this PR.
  • Tracker.log_html default no-op — deliberate optional-capability pattern (same shape as __enter__/__exit__ on Tracker); the rule against "swallowing exceptions" governs try/except: pass, not optional-method defaults. The HTML file is also still written to disk regardless of the tracker.
  • Compute/I/O mixing in _write_backward_flow_artifact — heavy compute (render_backward_flow_html) is already separated at the module level; the wrapper bundles cheap glue with the write call.
  • HTML/SVG injection — all user-controlled strings (node names from jax.named_scope, plate labels, titles) flow through html.escape() before being embedded.
  • custom_vjp correctness — fwd/bwd signatures match for both _tagged_identity and _tagged_identity_with_scale; cotangents pass through identity.
  • JIT specialization — compute_backward_flow is correctly listed in static_argnames.
  • Disabled-by-default — BackwardFlowConfig(interval=0).is_enabled is checked consistently at the trainer config, train-step entry, and per-step decision.
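
As background for the custom_vjp bullet above, here is a minimal sketch of an identity function whose backward rule passes the cotangent through unchanged. The name tagged_identity and its body are illustrative assumptions, not the PR's _tagged_identity; the point is only the fwd/bwd signature contract that the review checked.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for an identity probe; not the PR's code.
@jax.custom_vjp
def tagged_identity(x):
    return x


def _tagged_identity_fwd(x):
    # fwd takes the same arguments as the primal function and
    # returns (primal output, residuals). No residuals are needed.
    return x, None


def _tagged_identity_bwd(residuals, cotangent):
    # bwd takes (residuals, output cotangent) and returns a tuple
    # with one cotangent per primal input. Identity: pass it through.
    return (cotangent,)


tagged_identity.defvjp(_tagged_identity_fwd, _tagged_identity_bwd)


def loss_with_tag(x):
    return jnp.sum(tagged_identity(x) ** 2)


def loss_without_tag(x):
    return jnp.sum(x ** 2)


x = jnp.array([1.0, 2.0, 3.0])
grad_with_tag = jax.grad(loss_with_tag)(x)
grad_without_tag = jax.grad(loss_without_tag)(x)
```

Because the backward rule is the identity, the gradient with the tag in place matches the gradient without it; a real probe would additionally record a summary of the cotangent at that point.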

LGTM otherwise — the design is clean, the kill switch is real, and test coverage in test_backward_flow.py is solid.
• Branch: codex/grug-backward-flow-logging

Comment on lines +1160 to +1170
def _is_literal(value: Any) -> bool:
    literal_type = getattr(jax.core, "Literal", None)
    if literal_type is not None and isinstance(value, literal_type):
        return True
    if type(value).__name__ == "Literal":
        return True
    try:
        hash(value)
    except TypeError:
        return True
    return False

_is_literal layers three defensive fallbacks (getattr(jax.core, "Literal", None), type(...).__name__ == "Literal", and a hash-based TypeError probe) to detect jax.core.Literal instances. This violates AGENTS.md:63:

No ad-hoc compatibility hacks (hasattr(m, "old_attr")); update code consistently.

jax.core.Literal is a long-stable public symbol present in the pinned JAX version, and eqn.invars/eqn.outvars only ever contain jax.core.Atom (Var or Literal), so the direct isinstance check is sufficient. The hash-based fallback is also unsound as a type test: it treats any unhashable value as a Literal, so it infers the type from an incidental property instead of checking it.

Suggested change
- def _is_literal(value: Any) -> bool:
-     literal_type = getattr(jax.core, "Literal", None)
-     if literal_type is not None and isinstance(value, literal_type):
-         return True
-     if type(value).__name__ == "Literal":
-         return True
-     try:
-         hash(value)
-     except TypeError:
-         return True
-     return False
+ def _is_literal(value: Any) -> bool:
+     return isinstance(value, jax.core.Literal)
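
To make the difference between the two versions concrete, here is a small sketch that runs without JAX. The Literal class below is a hypothetical stand-in for jax.core.Literal, and the two helper names are invented for this example; it only shows which inputs the layered fallbacks accept that a direct isinstance check rejects.

```python
from typing import Any


class Literal:
    """Stand-in for jax.core.Literal so the sketch runs without JAX."""

    def __init__(self, val: Any) -> None:
        self.val = val


def is_literal_layered(value: Any) -> bool:
    # Mirrors the three layered checks from the original helper.
    if isinstance(value, Literal):
        return True
    if type(value).__name__ == "Literal":
        return True
    try:
        hash(value)
    except TypeError:
        return True
    return False


def is_literal_direct(value: Any) -> bool:
    # The suggested replacement: one type check, nothing else.
    return isinstance(value, Literal)


# An unrelated class that merely shares the name "Literal".
Impostor = type("Literal", (), {})

real = Literal(1.0)
impostor = Impostor()
unhashable = [1, 2, 3]  # lists raise TypeError on hash()

layered_results = [is_literal_layered(v) for v in (real, impostor, unhashable)]
direct_results = [is_literal_direct(v) for v in (real, impostor, unhashable)]
```

The layered version accepts all three values, while the direct check accepts only the real instance.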
