Skip to content

[levanter] Share Pallas autotune helpers and restore compile offload#4130

Merged
dlwh merged 8 commits intomarin-community:mainfrom
dlwh:codex/pallas-autotune-helper
Mar 25, 2026
Merged

[levanter] Share Pallas autotune helpers and restore compile offload#4130
dlwh merged 8 commits intomarin-community:mainfrom
dlwh:codex/pallas-autotune-helper

Conversation

@dlwh
Copy link
Copy Markdown
Member

@dlwh dlwh commented Mar 25, 2026

Move shard-aware autotune benchmarking out of fused cross-entropy into a shared Pallas helper and restore compile offload for shard-mapped autotune sweeps. This keeps fused CE behavior intact while making the benchmark path reusable from other kernels.

Fixes #4129

dlwh added 4 commits March 20, 2026 22:10
…n35-a3b-direct-compare

# Conflicts:
#	lib/levanter/src/levanter/kernels/pallas/fused_cross_entropy_loss/api.py
#	lib/levanter/tests/kernels/test_pallas_fused_cross_entropy_loss.py
@dlwh dlwh added the agent-generated Created by automation/agent label Mar 25, 2026
@dlwh
Copy link
Copy Markdown
Member Author

dlwh commented Mar 25, 2026

@claude review

@claude
Copy link
Copy Markdown
Contributor

claude Bot commented Mar 25, 2026

Claude Code is working…

I'll analyze this and get back to you.

View job run

@dlwh
Copy link
Copy Markdown
Member Author

dlwh commented Mar 25, 2026

@dlwh-golem will you review please

@dlwh-golem
Copy link
Copy Markdown
Contributor

No findings.



I reviewed the helper extraction in lib/levanter/src/levanter/kernels/pallas/autotune_utils.py and the fused CE integration in lib/levanter/src/levanter/kernels/pallas/fused_cross_entropy_loss/api.py. The extraction preserved the prior shard-map/manual-sharding/tracer behavior, and the new generic tests cover the moved logic plus the fused CE autotune sweep path.



I also ran:

uv run --project lib/levanter --group test python -m pytest -q lib/levanter/tests/kernels/test_pallas_autotune_utils.py lib/levanter/tests/kernels/test_pallas_fused_cross_entropy_loss.py



Result: 68 passed, 12 skipped in about 9s. Residual risk is low; the only thing not really proven yet is behavior for future non-CE adopters with more complex out_specs, but nothing in this PR looks blocking.

Codex work

Codex transcript

Completed items:
- none captured

Transcript:
No findings.

I reviewed the helper extraction in [lib/levanter/src/levanter/kernels/pallas/autotune_utils.py](/home/golem/oa-jder-bot/worktrees/github.com/marin-community/marin/4130/lib/levanter/src/levanter/kernels/pallas/autotune_utils.py) and the fused CE integration in [lib/levanter/src/levanter/kernels/pallas/fused_cross_entropy_loss/api.py](/home/golem/oa-jder-bot/worktrees/github.com/marin-community/marin/4130/lib/levanter/src/levanter/kernels/pallas/fused_cross_entropy_loss/api.py). The extraction preserved the prior shard-map/manual-sharding/tracer behavior, and the new generic tests cover the moved logic plus the fused CE autotune sweep path.

I also ran:
`uv run --project lib/levanter --group test python -m pytest -q lib/levanter/tests/kernels/test_pallas_autotune_utils.py lib/levanter/tests/kernels/test_pallas_fused_cross_entropy_loss.py`

Result: `68 passed, 12 skipped` in about 9s. Residual risk is low; the only thing not really proven yet is behavior for future non-CE adopters with more complex `out_specs`, but nothing in this PR looks blocking.

@dlwh dlwh changed the title [levanter] Extract shared Pallas autotune helpers [levanter] Share Pallas autotune helpers and restore compile offload Mar 25, 2026
@dlwh dlwh requested a review from yonromai March 25, 2026 06:34
Copy link
Copy Markdown
Contributor

@yonromai yonromai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved. I did not find a clear regression in the helper extraction or in the restored compile-offload path. The shared helper preserves the prior fused CE behavior around manual sharding, shard-map wrapping, tracer lowering, and fallback handling, and the targeted tests cover both the new helper and the fused CE sweep path.

Generated with Codex

@dlwh dlwh enabled auto-merge (squash) March 25, 2026 18:50
@dlwh dlwh disabled auto-merge March 25, 2026 18:50
@dlwh dlwh merged commit 6553f3f into marin-community:main Mar 25, 2026
32 of 33 checks passed
Helw150 pushed a commit that referenced this pull request Apr 8, 2026
…4130)

Move shard-aware autotune benchmarking out of fused cross-entropy into a
shared Pallas helper and restore compile offload for shard-mapped
autotune sweeps. This keeps fused CE behavior intact while making the
benchmark path reusable from other kernels.

Fixes #4129
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

agent-generated Created by automation/agent

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[levanter] Extract shared Pallas autotune sharding helpers

3 participants