
feat: add fsspec streaming weight loader for fast GCS transport #1

Merged
ahmeda14960 merged 2 commits into marin from feat/fast-tpu-bootstrap-v0.13.2
Mar 22, 2026

Conversation

@ahmeda14960 (Collaborator)

Summary

  • Add fsspec-based streaming weight loader (streaming_weights.py) that bypasses slow RunAI single-threaded HTTP (53 MiB/s) using fsspec byte-range downloads from GCS
  • Extend TpuBootstrapConfig with weight_loader field ("default" | "fsspec_streamer")
  • Make load_hf_weights() iterator path support both jax.Array (fsspec) and torch.Tensor (RunAI) with lazy torchax init
  • Add fsspec and gcsfs dependencies
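The fsspec loader's byte-range approach depends on first reading only the safetensors header to learn each tensor's offsets. A minimal stdlib sketch of that step (the real parser is in streaming_weights.py, ported from Levanter's fsspec_safetensor.py; the function name and return shape here are illustrative):

```python
import json
import struct

def parse_safetensors_header(blob: bytes):
    """Return {tensor_name: (dtype, shape, abs_start, abs_end)} from a safetensors blob."""
    # Safetensors layout: 8-byte little-endian u64 header length, then a JSON
    # header mapping tensor names to {dtype, shape, data_offsets}, then raw data.
    (header_len,) = struct.unpack("<Q", blob[:8])
    header = json.loads(blob[8 : 8 + header_len])
    data_start = 8 + header_len
    out = {}
    for name, meta in header.items():
        if name == "__metadata__":  # optional metadata entry, not a tensor
            continue
        begin, end = meta["data_offsets"]  # offsets relative to the data section
        out[name] = (meta["dtype"], tuple(meta["shape"]),
                     data_start + begin, data_start + end)
    return out

# Build a tiny in-memory safetensors blob with one F32 tensor of 4 elements.
payload = b"\x00" * 16
meta = {"w": {"dtype": "F32", "shape": [4], "data_offsets": [0, 16]}}
hdr = json.dumps(meta).encode()
blob = struct.pack("<Q", len(hdr)) + hdr + payload
print(parse_safetensors_header(blob))
```

With these absolute offsets in hand, a remote loader only needs byte-range reads (e.g. fsspec `open` + `seek`/`read`) rather than downloading whole shards up front.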

Test plan

  • Unit tests for safetensors header parsing, chunk building, shard discovery
  • End-to-end local file iterator tests (single shard, multi-shard, BF16, CPU pinning)
  • Config parsing and validation tests for weight_loader field
  • Dispatch mock test for fsspec_streamer branch in _build_abstract_model_and_load_weights
  • Type dispatch test: np.ndarray rejected with TypeError
  • Smoke test on TPU with --engine-kwargs-json '{"additional_config": {"tpu_bootstrap": {"model_bootstrap": "abstract_load", "weight_loader": "fsspec_streamer"}}}'

🤖 Generated with Claude Code

Add an fsspec-based streaming weight loader that bypasses the slow RunAI
single-threaded HTTP path (53 MiB/s) by using fsspec byte-range downloads.
Weights are streamed shard-by-shard, chunk-by-chunk with bounded RAM (~2 GiB
peak per chunk) and all arrays materialized on CPU to avoid TPU HBM pressure.
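The bounded-RAM property comes from never holding more than one chunk of a byte range in memory at a time. A sketch of that pattern over any seekable file object (the real loader works over fsspec file handles and uses a much larger chunk bound, reportedly near 2 GiB; the tiny `chunk_bytes` here is only for illustration):

```python
import io

def stream_byte_range(f, start, end, chunk_bytes=4):
    """Yield bytes [start, end) in chunks of at most chunk_bytes each,
    so peak RAM stays roughly one chunk regardless of range size."""
    f.seek(start)
    remaining = end - start
    while remaining > 0:
        data = f.read(min(chunk_bytes, remaining))
        if not data:
            raise EOFError("unexpected end of stream")
        yield data
        remaining -= len(data)

buf = io.BytesIO(bytes(range(32)))
chunks = list(stream_byte_range(buf, 8, 20))
print([len(c) for c in chunks])  # [4, 4, 4]
```

Each tensor's `(abs_start, abs_end)` range from the header can be fed through such a generator and assembled on CPU before any device transfer.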

Changes:
- New streaming_weights.py: fsspec iterator ported from Levanter's
  fsspec_safetensor.py (header parsing, chunk building, shard discovery)
- TpuBootstrapConfig: add weight_loader field ("default" | "fsspec_streamer")
- model_loader: fsspec_streamer branch in _build_abstract_model_and_load_weights
- weight_utils: lazy torchax init + jax.Array/torch.Tensor type dispatch
- requirements.txt: add fsspec, gcsfs dependencies
- Tests for config parsing, dispatch, and streaming iterator
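The `weight_loader` field validation described above can be sketched as a small dataclass; this is a simplified stand-in for TpuBootstrapConfig (the real config has more fields, and the exact validation mechanism is an assumption):

```python
from dataclasses import dataclass

_VALID_WEIGHT_LOADERS = ("default", "fsspec_streamer")

@dataclass
class TpuBootstrapConfig:
    """Simplified sketch; only the fields relevant to this PR are shown."""
    model_bootstrap: str = "abstract_load"
    weight_loader: str = "default"

    def __post_init__(self):
        # Reject unknown loader names early, at config-parse time.
        if self.weight_loader not in _VALID_WEIGHT_LOADERS:
            raise ValueError(f"unknown weight_loader: {self.weight_loader!r}")

cfg = TpuBootstrapConfig(weight_loader="fsspec_streamer")
print(cfg.weight_loader)  # fsspec_streamer
```

This mirrors the JSON shape used in the smoke test, where `additional_config.tpu_bootstrap` carries both `model_bootstrap` and `weight_loader`.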

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@github-actions

Description

Start with a short description of what the PR does and how it differs from the
previous behavior.

The rest of the description should include relevant details and context, for example:

  • why is this change being made,
  • the problem being solved and any relevant context,
  • why this is a good solution,
  • some information about the specific implementation,
  • shortcomings of the solution and possible future improvements.

If the change fixes a GitHub issue, please include a link, e.g.:
FIXES: #123456

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

…patch

The test_load_hf_weights_rejects_ndarray test used a MagicMock as the model,
but load_hf_weights() calls nnx.state(model) first, which fails on a
non-Module. Patch nnx.state and nnx.get_named_sharding so the test
actually exercises the iterator type-dispatch branch.
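The patching pattern described above can be sketched with unittest.mock. Everything here is a toy stand-in: the `nnx` namespace, `load_hf_weights`, and the `hasattr(a, "dtype")` dispatch are illustrative (the real code does isinstance checks against jax.Array/torch.Tensor, and the patch targets would be dotted paths into the module under test):

```python
from types import SimpleNamespace
from unittest import mock

def _state_requires_module(model):
    # Stand-in for nnx.state(): rejects anything that is not a real Module.
    raise TypeError("not an nnx.Module")

# Toy namespace standing in for flax.nnx.
nnx = SimpleNamespace(state=_state_requires_module,
                      get_named_sharding=lambda state: {})

def load_hf_weights(model, arrays):
    state = nnx.state(model)       # fails on a MagicMock unless patched
    nnx.get_named_sharding(state)
    for a in arrays:               # toy type-dispatch branch
        if not hasattr(a, "dtype"):
            raise TypeError(f"unsupported array type: {type(a).__name__}")

# Patch the pre-dispatch calls so a MagicMock model sails past them and the
# TypeError we observe comes from the dispatch branch, not from nnx.state.
with mock.patch.object(nnx, "state", return_value={}), \
     mock.patch.object(nnx, "get_named_sharding", return_value={}):
    try:
        load_hf_weights(mock.MagicMock(), [object()])
    except TypeError as e:
        msg = str(e)
print(msg)  # unsupported array type: object
```

Without the patches, the same call would fail earlier inside `nnx.state`, and the test would pass for the wrong reason.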

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@ahmeda14960 ahmeda14960 merged commit ecc1308 into marin Mar 22, 2026
0 of 2 checks passed
