Add MAGIC CLI with runtime DTensor double-backward patch #174
luciaquirke merged 17 commits into magic
Conversation
@dataclass
class DoubleBackwardConfig:
This is the original RunConfig, but it uses a DataConfig for the query rather than the first item of the dataset, renames save_dir to run_path, and uses a DataConfig for the training data.
"""Random seed for subset permutation."""

def compute_query_gradients(
We could technically use build here with all the Trackstar bells and whistles turned off, but this seems more readable. Technically not DRY. It currently lacks TRL-style tokenization/masking support.
Duplication is fine for now (and maybe forever)
77497a6 to 1be9172
…ight support

- Add bergson/magic_patch.py: runtime monkey-patch for twice-differentiable DTensor redistribution (pytorch/pytorch#160509), replacing the old magic_wmdp_setup.sh that modified torch source files on disk
- Add per_token mode to DataStream for [n_examples, max_length] weight tensors
- Support 2D [B, T] per-token weights in weighted_causal_lm_ce
- Fix backward weight_grads accumulation when autograd returns None
1be9172 to 97fe18f
for more information, see https://pre-commit.ci
The weight gradient from autograd.grad should always be a tensor, since data.weights participates in the computation graph via weighted_causal_lm_ce.
Multiple concurrent DCP async_save calls each create their own Gloo process group. With consecutive saves at steps 20-24 (last_start logic), up to 5 saves were in-flight simultaneously. Background threads from these saves may call distributed operations that conflict, causing all ranks to deadlock in fut.result() until the NCCL watchdog times out.

Limit to one concurrent save at a time: wait for the previous save to complete before starting the next one. Each save still overlaps with at least one training step, so the async I/O benefit is preserved.
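The one-at-a-time logic can be sketched generically. `SerializedSaver` and its names are illustrative stand-ins, not the PR's actual DCP code:

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Optional

class SerializedSaver:
    """At most one in-flight async save: a new save first waits for the
    previous one to finish, then runs in the background so it can still
    overlap with a training step. Illustrative sketch, not the PR code."""

    def __init__(self) -> None:
        # Two workers so that serialization comes from the explicit wait
        # below, not from the pool size.
        self._pool = ThreadPoolExecutor(max_workers=2)
        self._pending: Optional[Future] = None

    def async_save(self, save_fn: Callable[[], None]) -> Future:
        if self._pending is not None:
            self._pending.result()  # block until the previous save completes
        self._pending = self._pool.submit(save_fn)
        return self._pending
```

The explicit `result()` call before submitting is the whole fix: the caller still returns quickly relative to I/O, but there is never more than one save thread doing distributed work at once.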
Raises a clear ValueError at init time when the dataset doesn't have enough examples for the requested number of batches, instead of crashing with an IndexError mid-training.
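A minimal version of that guard (the class and argument names here are hypothetical, not the actual DataStream signature):

```python
class BatchPlan:
    """Validate sizes at construction so a too-small dataset fails with a
    clear ValueError instead of an IndexError mid-training. Hypothetical
    sketch of the guard, not the PR's actual code."""

    def __init__(self, n_examples: int, batch_size: int, num_batches: int) -> None:
        needed = batch_size * num_batches
        if n_examples < needed:
            raise ValueError(
                f"Dataset has {n_examples} examples, but {num_batches} "
                f"batches of size {batch_size} require {needed}."
            )
        self.n_examples = n_examples
        self.batch_size = batch_size
        self.num_batches = num_batches
```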
PyTorch's Future.result() waits for done callbacks to complete before returning. The destroy_process_group callback was invoked from DCP's background thread after each save, but destroy_process_group may do a barrier on the Gloo group. Since ranks complete their I/O at different times, the fast rank would deadlock waiting for the slow rank to also call destroy_process_group, while the slow rank was still in fut.result().

DCP holds its own reference to the process group, keeping it alive for the duration of the background I/O. GC will clean it up afterwards.
Strip the per_token parameter from DataStream and the 2D weight path from weighted_causal_lm_ce to keep the merge scope minimal. The per-token code is preserved on the magic-per-token branch.
"""Runtime monkey-patch for twice-differentiable DTensor redistribution.

Implements pytorch/pytorch#160509 at runtime, avoiding the need to modify
torch source files on disk. Call `apply_dtensor_patch()` before any DTensor
Ah, good idea using monkey patching rather than actually changing the files
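For readers unfamiliar with the approach, an idempotent runtime monkey-patch looks roughly like this. `Redistribute` here is a stand-in class, not the real DTensor autograd function, and `apply_patch` is a generic helper, not the PR's `apply_dtensor_patch`:

```python
def apply_patch(cls, attr: str, new_fn) -> None:
    """Replace cls.attr with new_fn, stashing the original under
    _orig_<attr>. A second call is a no-op (idempotent)."""
    if getattr(cls, attr, None) is new_fn:
        return  # already applied
    setattr(cls, f"_orig_{attr}", getattr(cls, attr))
    setattr(cls, attr, new_fn)

class Redistribute:
    """Stand-in for the real twice-differentiable patch target."""
    @staticmethod
    def backward(grad):
        return grad

def patched_backward(grad):
    # Illustrative replacement: delegate to the stashed original, then modify.
    return Redistribute._orig_backward(grad) * 2

apply_patch(Redistribute, "backward", patched_backward)
apply_patch(Redistribute, "backward", patched_backward)  # safe no-op
```

The idempotence check matters because the patch may be applied from several entry points (CLI startup, tests) without double-wrapping the original method.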
data: DataConfig = field(default_factory=DataConfig)
"""Training dataset."""

query: DataConfig = field(default_factory=lambda: DataConfig())
nitpick: you should write field(default_factory=DataConfig) like you did above lol
@norabelrose could you please update the query handling so it works for numbers of queries that aren't divisible by the world size, without dropping data? This should enable the double backward example script to replicate. |
I think once we can replicate the good spearman correlations with the CLI it's basically mergeable |
Actually I'm going to merge now so we can get the wandb logging up, can we please add the query logic in a follow up work? |
My Changes Summary
Things which could go in this PR or a follow-up
Claude Changes Summary
bergson/magic_patch.py: Monkey-patches Redistribute.backward and _ToTorchTensor.backward at runtime to make FSDP redistribution twice-differentiable (Add support for twice-differentiable DTensor redistribution, pytorch/pytorch#160509). Replaces the old magic_wmdp_setup.sh that modified torch source files on disk. Idempotent: call apply_dtensor_patch() before any DTensor double-backward operations.

Test plan