copy back skyrl code to top level and delete skyrl-train and skyrl-tx #1137
Conversation
b15bbfb into NovaSky-AI:recover_git_history3
Code Review
This pull request is a large refactoring that moves code from skyrl-train and skyrl-tx subdirectories into the top-level skyrl directory, and deletes the old directories. The changes primarily consist of moving files and updating import paths, which have been done consistently. There are also several improvements, such as bug fixes in shell scripts, better handling of gradient checkpointing in the JAX backend, and lazy initialization of inference engines. I've found one minor issue in the .gitignore file with duplicate entries that should be cleaned up.
```gitignore
uv.lock

# PyInstaller
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Jupyter Notebook
.ipynb_checkpoints

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# MkDocs build output
site/

# IDEs and editors
.idea/
.vscode/

# OS generated files
.DS_Store
Thumbs.db

# Hydra outputs
outputs/

# Local artifacts
tinker.db
uv.lock

# Alembic - don't track pycache
tx/tinker/alembic/__pycache__/

# SQLite databases (tracked in git by default, but ignore if created locally)
*.db
```
```python
    input_ids,
    attention_mask=attention_mask,
    adapter_indices=adapter_indices,
    is_training=True,
)
```
🔴 Removing is_training=True from the jax backend causes unnecessary KV cache allocation during Qwen3 training
During the migration, is_training=True was removed from the _model_forward call in the jax backend. For Qwen3 models (which still use StackedDecoderLayers with the is_training flag), this causes is_training to default to False, making the scan body compute and accumulate KV cache tensors for all layers during training.
Root Cause and Impact
The old code at skyrl-tx/tx/tinker/backends/jax.py:278 passed is_training=True:
```python
output = model(
    input_ids,
    attention_mask=attention_mask,
    adapter_indices=adapter_indices,
    is_training=True,
)
```

The new code at skyrl/backends/jax.py:274-278 omits this flag:
```python
output = model(
    input_ids,
    attention_mask=attention_mask,
    adapter_indices=adapter_indices,
)
```

Qwen3 models still accept and use is_training (see skyrl/tx/models/qwen3.py:349,419), which propagates to StackedDecoderLayers.__call__ (skyrl/tx/layers/stacked.py:269-270). When is_training=False (the default), the scan body does NOT zero out k and v, so the scan accumulates full KV cache tensors for every layer. At skyrl/tx/layers/stacked.py:279-285, a full KVCache is then constructed and returned, wasting GPU memory proportional to num_layers × batch_size × seq_len × num_heads × head_dim × 2.
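The mechanism can be sketched in plain Python. This is a dependency-free stand-in for the scan body described above, not the actual StackedDecoderLayers code (which uses jax.lax.scan and real attention tensors); it only shows how the is_training flag controls whether per-layer k/v outputs accumulate across the scan.

```python
def scan_body(hidden, layer_weight, is_training):
    """Stand-in for one decoder-layer step in the scan: returns the new hidden
    state plus the k/v entries this layer would contribute to the KV cache."""
    hidden = [h * layer_weight for h in hidden]   # placeholder layer math
    k = [0.0] * 8                                 # placeholder key tensor
    v = [0.0] * 8                                 # placeholder value tensor
    if is_training:
        # Empty outputs keep the scan from accumulating a per-layer cache.
        k, v = [], []
    return hidden, (k, v)

def run_layers(x, weights, is_training):
    """Loop standing in for jax.lax.scan over the stacked decoder layers."""
    cached = 0
    for w in weights:
        x, (k, v) = scan_body(x, w, is_training)
        cached += len(k) + len(v)                 # cache elements accumulated
    return x, cached

_, train_cache = run_layers([1.0] * 4, [1.0] * 36, is_training=True)
_, infer_cache = run_layers([1.0] * 4, [1.0] * 36, is_training=False)
print(train_cache, infer_cache)  # 0 576
```

With 36 toy layers, the training path accumulates nothing while the inference path stacks 36 × (8 + 8) = 576 cache elements; in the real model each of those placeholder entries is a full batch × seq_len × heads × head_dim tensor.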
Additionally, since config.gradient_checkpointing still propagates through the Qwen3 model to StackedDecoderLayers (lines 274-275 of stacked.py), the outer jax.checkpoint wrapping _model_forward at skyrl/backends/jax.py:284 results in redundant double gradient checkpointing.
Impact: Significant unnecessary GPU memory consumption during training for Qwen3 models, potentially causing OOM on memory-constrained setups. The training results are still correct since the returned KV cache is unused by the caller.
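To put a rough number on the waste, the memory formula above can be evaluated directly. The model dimensions below are illustrative assumptions for a mid-sized Qwen3-style config, not the actual training setup:

```python
def wasted_kv_cache_bytes(num_layers, batch_size, seq_len,
                          num_kv_heads, head_dim, dtype_bytes=2):
    """Approximate memory held by the unused KV cache; the trailing
    factor of 2 covers the separate key and value tensors."""
    return (num_layers * batch_size * seq_len *
            num_kv_heads * head_dim * 2 * dtype_bytes)

# Illustrative numbers (assumed, not the real config):
# 36 layers, batch 8, 4096-token sequences, 8 KV heads, head_dim 128, bf16.
waste = wasted_kv_cache_bytes(36, 8, 4096, 8, 128, dtype_bytes=2)
print(f"{waste / 2**30:.1f} GiB")  # 4.5 GiB
```

Even at these modest assumed sizes the dead cache costs several GiB per forward pass, which is easily the difference between fitting and OOMing on a memory-constrained device.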
(Refers to lines 274-278)