Releases: erfanzar/Spectrax

SpectraX v0.1.2: safer MPMD stage JITs + faster TensorStore restores

07 May 20:46

Release Notes

SpectraX 0.1.2 is a stability and checkpointing release on top of the 0.1.x True MPMD runtime line.

Highlights

  • Added private stage-JIT cache handling for MPMD schedule executables.
  • Improved TensorStore checkpoint loading for large model restores.
  • Added index-only checkpoint restore support when structure sidecars are missing.
  • Added caller-provided checkpoint key aliases for framework/layout migrations.
  • Expanded serialization tests and GCS-auth test support.
  • Package lint/type checks are clean.

MPMD Runtime

  • Stage forward, backward, terminal, and scheduler body JITs now compile under a private cache scope.
  • Avoids unsafe global persistent cache reuse for jit_stage_body-* executables across different MPMD plans.
  • Keeps normal in-process executable reuse after first compile.
  • Preserves True MPMD behavior: forward, backward, and schedule execution are still split and dispatched per stage mesh.
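
One way to picture the private cache scope is an in-process cache keyed by everything a stage executable depends on. The sketch below is a toy model in plain Python, not SpectraX's implementation; `StagePlanKey`, `PrivateStageCache`, and the `compile_fn` callback are hypothetical names chosen for illustration.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Tuple


@dataclass(frozen=True)
class StagePlanKey:
    """Hypothetical key: the plan details a compiled stage body depends on."""
    stage_index: int
    mesh_shape: Tuple[Tuple[str, int], ...]  # e.g. (("dp", 2), ("tp", 4))
    schedule: str                            # e.g. "gpipe-m4"
    kind: str                                # "fwd", "bwd", "terminal", ...


class PrivateStageCache:
    """In-process cache: compiles once per key and is never written to disk."""

    def __init__(self, compile_fn: Callable[[StagePlanKey], object]) -> None:
        self._compile_fn = compile_fn
        self._entries: Dict[StagePlanKey, object] = {}
        self.compile_count = 0

    def get(self, key: StagePlanKey) -> object:
        if key not in self._entries:
            self.compile_count += 1
            self._entries[key] = self._compile_fn(key)
        return self._entries[key]


# Stand-in "compiler" that just wraps the key it was asked to compile.
cache = PrivateStageCache(lambda key: ("exe", key))
key_a = StagePlanKey(0, (("dp", 2), ("tp", 4)), "gpipe-m4", "fwd")
key_b = StagePlanKey(0, (("dp", 4), ("tp", 2)), "gpipe-m4", "fwd")  # different mesh

first = cache.get(key_a)
again = cache.get(key_a)   # same plan: reused without recompiling
other = cache.get(key_b)   # different plan: compiled separately
```

Because the key carries the mesh shape and schedule, two plans that share a stage index never share an executable, while repeated calls within one plan still hit the cache.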

Checkpointing

  • Added can_skip_structure restore path for TensorStore checkpoints that have tensorstore_index.json but no {prefix}_structure.json.
  • Added TensorStore load controls:
    • concurrent_gb
    • tensorstore_io_concurrency
    • tensorstore_copy_concurrency
    • tensorstore_cache_gb
    • tensorstore_assume_metadata
    • tensorstore_metadata_workers
    • show_progress
    • progress_every
  • Added progress reporting for large weight loads.
  • Added template-aware key aliasing so downstream frameworks can map legacy checkpoint names without baking those aliases into SpectraX.
  • Improved metadata/index handling for faster hosted checkpoint restores.
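
The key-aliasing idea can be sketched as a lookup that tries the template key first and then any caller-supplied legacy names. This is a minimal illustration under assumed names; `resolve_checkpoint_keys` and the example keys are hypothetical and do not reflect SpectraX's actual restore API.

```python
from typing import Dict, Iterable, List, Mapping, Sequence


def resolve_checkpoint_keys(
    template_keys: Iterable[str],
    checkpoint_keys: Iterable[str],
    aliases: Mapping[str, Sequence[str]],
) -> Dict[str, str]:
    """Map each template key to the checkpoint key that should fill it.

    `aliases` maps a template key to legacy names to try, in order, when the
    template key itself is absent from the checkpoint.
    """
    available = set(checkpoint_keys)
    resolved: Dict[str, str] = {}
    missing: List[str] = []
    for key in template_keys:
        candidates = [key, *aliases.get(key, ())]
        match = next((c for c in candidates if c in available), None)
        if match is None:
            missing.append(key)
        else:
            resolved[key] = match
    if missing:
        raise KeyError(f"no checkpoint entry for: {missing}")
    return resolved


mapping = resolve_checkpoint_keys(
    template_keys=["embed.weight", "layers.0.attn.q_proj.kernel"],
    checkpoint_keys=["embed.weight", "blocks.0.attention.wq"],
    aliases={"layers.0.attn.q_proj.kernel": ["blocks.0.attention.wq"]},
)
```

Keeping the alias table on the caller's side means the loader only needs to know the template; migration-specific names stay in the downstream framework.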

Validation

uv run ruff check spectrax
uv run basedpyright spectrax

Both pass on the package code.

Upgrade

pip install -U spectrax==0.1.2

Notes

This release intentionally does not make MPMD stage executables reusable through the global persistent disk cache. Those executables are too dependent on stage mesh, rebased shardings, schedule shape, and split jaxpr state. The safer behavior is private first-compile handling plus normal in-process reuse.

Full Changelog: v0.1.0...v0.1.2

SpectraX v0.1.0

05 May 13:33

This release hardens the True MPMD runtime introduced before v0.0.7: stage-local mesh rebasing, persistent-cache safety, cleaner boundary sharding, stricter typing, new CI, and expanded docs/examples.

Runtime And MPMD Fixes

  • Fixed per-rank stage JAXPR mesh rebasing for nested shard_map, pjit, concrete Mesh, and NamedSharding metadata.
  • Added cache-visible stage/rank fingerprints for MPMD stage JITs to avoid reusing executables across incompatible stage meshes.
  • Scoped persistent-cache guarding for scheduled MPMD stage executables, fixing stale cache reuse in MoE / EP paths.
  • Improved stage-local mesh handling: the pipeline axis is treated as program selection and dropped from intra-stage SPMD layouts.
  • Fixed boundary sharding/layout behavior that could cause unintended large KV/cache copies in pipeline inference.
  • Fixed scalar / zero-axis mesh rebasing edge cases during scheduled training plan construction.
  • Added clearer named scopes around MPMD stage, backward, transport, and GPipe execution paths for profiling.
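
To illustrate "the pipeline axis is treated as program selection", the sketch below splits a global device grid into per-stage device groups and removes the pipeline axis from each stage's layout. It is plain Python over integer device ids rather than real JAX meshes; the function name and the default axis name `"pp"` are assumptions.

```python
from itertools import product
from typing import Dict, List, Tuple


def stage_local_meshes(
    axis_sizes: Dict[str, int], pipeline_axis: str = "pp"
) -> List[Tuple[Dict[str, int], List[int]]]:
    """Split a global mesh into one (axes, device_ids) pair per pipeline stage.

    Devices are numbered in row-major order over `axis_sizes`. The pipeline
    axis selects the stage and is absent from each stage-local layout.
    """
    names = list(axis_sizes)
    local_axes = {n: s for n, s in axis_sizes.items() if n != pipeline_axis}
    stages: List[List[int]] = [[] for _ in range(axis_sizes[pipeline_axis])]
    pp_pos = names.index(pipeline_axis)
    coords = product(*(range(axis_sizes[n]) for n in names))
    for device_id, coord in enumerate(coords):
        stages[coord[pp_pos]].append(device_id)
    return [(dict(local_axes), ids) for ids in stages]


meshes = stage_local_meshes({"pp": 2, "dp": 2, "tp": 2})
for stage, (axes, device_ids) in enumerate(meshes):
    print(stage, axes, device_ids)
```

Each stage then sees only its own `dp` and `tp` axes, so intra-stage SPMD shardings never mention the pipeline axis.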

Optimizers

  • Fixed MultiOptimizer slicing for nested dotted paths, including LoRA adapters under child modules.
  • Added stronger optimizer typing and optax protocol handling.
  • Added coverage for nested LoRA optimizer paths.
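
The nested-path case can be pictured as selecting parameters by matching any segment of a dotted path, not only the first. The following is a generic sketch, not the MultiOptimizer implementation; `slice_by_selector`, `is_lora`, and the `lora_a` / `lora_b` naming are assumptions made for the example.

```python
from typing import Callable, Dict, List, Tuple

Selector = Callable[[str], bool]


def slice_by_selector(
    flat_params: Dict[str, object],
    selectors: List[Tuple[str, Selector]],
) -> Dict[str, Dict[str, object]]:
    """Assign each dotted path to the first selector that matches it."""
    groups: Dict[str, Dict[str, object]] = {name: {} for name, _ in selectors}
    for path, value in flat_params.items():
        for name, matches in selectors:
            if matches(path):
                groups[name][path] = value
                break
    return groups


def is_lora(path: str) -> bool:
    # Match a LoRA segment at any depth, not only at the top level.
    return any(part.startswith("lora_") for part in path.split("."))


params = {
    "decoder.layers.0.attn.q_proj.kernel": "W_q",
    "decoder.layers.0.attn.q_proj.lora_a": "A_q",
    "decoder.layers.0.attn.q_proj.lora_b": "B_q",
    "embed.weight": "E",
}
# Selectors are tried in order; the catch-all "base" group comes last.
groups = slice_by_selector(params, [("lora", is_lora), ("base", lambda p: True)])
```

Here the adapters four levels deep land in the `lora` group, and everything else falls through to `base`.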

Docs And Examples

  • Expanded README showcase for True MPMD, stage-local meshes, and sxstage_region usage.
  • Added runnable sxstage_region multimodal example.
  • Added stage-local MPMD mesh layout inspection example.
  • Added MultiOptimizer LoRA example.
  • Added optimizer guide for Optimizer, MultiOptimizer, selector-scoped optimizer state, and mutable collections.
  • Expanded pipeline docs with stage-region and stage-local mesh rules.

Full Changelog: v0.0.7...v0.1.0

schedule-faithful

03 May 22:45

This release makes SpectraX’s MPMD execution path fully schedule-faithful and adds a new forward-only MPMD pipeline executor for inference wavefronts.

Highlights

  • Added MpmdPipelineExecutor and MpmdPipelineDispatchStats for forward-only MPMD inference dispatch.
  • Routed sxcall and spx.run(..., mesh=<MPMD>) through true scheduled sxjit paths.
  • Removed legacy fake-MPMD / hybrid fallback behavior so MPMD-tagged execution now uses the real MPMD scheduler.
  • Added runtime-static placement caching, cached invar assembly plans, and same-sharding device_put skip paths for lower-latency repeated dispatch.
  • Added a single-microbatch non-worker fast path for cached MPMD dispatch.
  • Improved stage-boundary handling, KV/cache sharding behavior, per-stage timing, and JAX named scopes.
  • Expanded regression and fuzz coverage for true MPMD behavior.
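
For readers unfamiliar with the term, a forward-only wavefront is the diagonal pattern in which microbatch m reaches stage s at tick m + s. The sketch below prints that generic pattern; it is not the dispatch logic of MpmdPipelineExecutor, and `forward_wavefront` is a hypothetical helper.

```python
from typing import List, Optional


def forward_wavefront(
    num_stages: int, num_microbatches: int
) -> List[List[Optional[int]]]:
    """Return, per tick, the microbatch index each stage works on (None = idle)."""
    ticks = num_stages + num_microbatches - 1
    table: List[List[Optional[int]]] = []
    for t in range(ticks):
        row: List[Optional[int]] = []
        for s in range(num_stages):
            mb = t - s  # microbatch that entered the pipeline s ticks ago
            row.append(mb if 0 <= mb < num_microbatches else None)
        table.append(row)
    return table


schedule = forward_wavefront(num_stages=3, num_microbatches=2)
for tick, row in enumerate(schedule):
    print(tick, row)
```

With 3 stages and 2 microbatches the pipeline takes 4 ticks, and each stage is idle only while the wavefront fills or drains.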

Breaking Changes

  • Removed the old pipeline_call helper.
  • Removed spectrax.runtime.spmd.hybrid_linear_run; callers should use spectrax.runtime.mpmd.sxjit, sxcall, or spx.run with an MPMD mesh.
  • SPMD APIs now reject MPMD-tagged meshes at the public API boundary.
  • Legacy fuse_1f1b=True / fuse_zb=True knobs are rejected on the true scheduled MPMD path.

Tests

  • Added tests/runtime/test_mpmd_pipeline_executor.py.
  • Added tests/pipeline/test_mpmd_abi_fuzz.py.
  • Updated existing pipeline, marker, hybrid, and MPMD tests for the new true-MPMD behavior.

Documentation

  • Updated README and pipeline docs around true MPMD execution.
  • Updated runtime API docs and removed obsolete hybrid SPMD documentation.

Commits

  • d6436e3 feat(mpmd): add MpmdPipelineExecutor for forward-only inference wavefronts
  • 6bb7237 Make Spectrax MPMD paths fully true scheduled MPMD
  • 9187f5d feat: optimize cached MPMD dispatch assembly
  • 1ea47ab lint: formating lines.

Full Changelog: v0.0.6...v0.0.7

SpectraX MPMD Stage Regions

01 May 18:18

This release adds true scheduled MPMD support for sxstage_region pipelines.

Highlights

  • Enables scheduled sxjit for serial stage-region layouts such as vision encoder followed by text decoder.
  • Adds forward and backward execution across logical region stages without collapsing repeated physical ranks.
  • Supports region forward/backward paths across GPipe, lazy GPipe, KimiK2, and DualPipeV schedulers.
  • Fixes logical-stage keying so paths like V0 -> V1 -> ... then T0 -> T1 -> ... map correctly onto the same TPU ranks.
  • Improves marker handling for region-local sxstage_iter boundaries and metadata operands.
  • Adds TPU-tested regression coverage for stage-region value-and-grad correctness.
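
# Illustrative example: mesh, the vision_block_* / text_block_* functions,
# cross_entropy, params, images, and token_ids are user-defined and not shown.
import jax.numpy as jnp
import spectrax as sx
from spectrax.runtime.mpmd import sxjit, sxvalue_and_grad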

vision_region = sx.sxstage_region("vision")
text_region = sx.sxstage_region("text")

@sxjit(mesh=mesh, schedule=sx.GPipe(microbatches=4), batch_argnums=(0,))
def loss_fn(images, token_ids, params):
    def vision_path(x):
        # logical stages: V0 -> V1 -> V2 -> V3
        x = vision_block_0(params.vision[0], x)
        x = sx.sxstage_iter(x, stage=0)
        x = vision_block_1(params.vision[1], x)
        x = sx.sxstage_iter(x, stage=1)
        x = vision_block_2(params.vision[2], x)
        x = sx.sxstage_iter(x, stage=2)
        return vision_block_3(params.vision[3], x)

    def text_path(hidden):
        # logical stages: T0 -> T1 -> T2 -> T3
        hidden = text_block_0(params.text[0], hidden, token_ids)
        hidden = sx.sxstage_iter(hidden, stage=0)
        hidden = text_block_1(params.text[1], hidden)
        hidden = sx.sxstage_iter(hidden, stage=1)
        hidden = text_block_2(params.text[2], hidden)
        hidden = sx.sxstage_iter(hidden, stage=2)
        return text_block_3(params.text[3], hidden)

    vision_features = vision_region(vision_path)(images)
    logits = text_region(text_path)(vision_features)
    return cross_entropy(logits, token_ids).mean()

loss, (grads,) = sxvalue_and_grad(loss_fn, argnums=(2,))(images, token_ids, params)
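
The logical-stage keying fix can be pictured as keying stages by (region, local stage) instead of by local stage alone. The sketch below assumes each region restarts at rank 0, which is one reading of "map correctly onto the same TPU ranks"; `assign_ranks` is a hypothetical helper, not part of SpectraX.

```python
from typing import Dict, List, Tuple

LogicalStage = Tuple[str, int]  # (region name, stage index within the region)


def assign_ranks(regions: Dict[str, int], num_ranks: int) -> Dict[LogicalStage, int]:
    """Map every (region, local stage) pair to a physical rank.

    Each region restarts at rank 0, so V0 and T0 share a rank while
    remaining distinct logical stages.
    """
    placement: Dict[LogicalStage, int] = {}
    for region, num_stages in regions.items():
        if num_stages > num_ranks:
            raise ValueError(
                f"region {region!r} needs {num_stages} ranks, have {num_ranks}"
            )
        for stage in range(num_stages):
            placement[(region, stage)] = stage
    return placement


placement = assign_ranks({"vision": 4, "text": 4}, num_ranks=4)
# Dict insertion order gives the serial order: V0..V3, then T0..T3.
execution_order: List[LogicalStage] = list(placement)
```

Keying by stage index alone would merge V0 and T0 into one entry; the tuple key keeps eight logical stages over four physical ranks.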

Full Changelog: https://github.com/erfanzar/Spectrax/commits/v0.0.6