
copy back skyrl code to top level and delete skyrl-train and skyrl-tx #1137

Merged

erictang000 merged 1 commit into NovaSky-AI:recover_git_history3 from erictang000:renaming3 on Feb 16, 2026

Conversation

erictang000 (Collaborator) commented on Feb 16, 2026

erictang000 merged commit b15bbfb into NovaSky-AI:recover_git_history3 on Feb 16, 2026
2 of 6 checks passed
erictang000 deleted the renaming3 branch on February 16, 2026 at 00:48
gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request is a large refactoring that moves code from the skyrl-train and skyrl-tx subdirectories into the top-level skyrl directory and deletes the old directories. The changes primarily consist of moving files and updating import paths, which has been done consistently. There are also several improvements, such as bug fixes in shell scripts, better handling of gradient checkpointing in the JAX backend, and lazy initialization of inference engines. I've found one minor issue in the .gitignore file: duplicate entries that should be cleaned up.

Comment thread: .gitignore, lines +79 to +137
uv.lock

# PyInstaller
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Jupyter Notebook
.ipynb_checkpoints

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# MkDocs build output
site/

# IDEs and editors
.idea/
.vscode/

# OS generated files
.DS_Store
Thumbs.db

# Hydra outputs
outputs/

# Local artifacts
tinker.db
uv.lock

# Alembic - don't track pycache
tx/tinker/alembic/__pycache__/

# SQLite databases (tracked in git by default, but ignore if created locally)
*.db
gemini-code-assist (Bot, Contributor) commented:

Severity: medium

There are a few duplicate entries in the newly added lines. uv.lock is added on lines 79 and 131, and *.db is added on line 137. These entries seem to be duplicates of existing entries in the file. Please remove the redundant lines to keep the .gitignore file clean and maintainable.
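
A quick way to flag such duplicates mechanically is a small one-off script. The sketch below is illustrative and not part of this PR; it assumes it is run from the repository root and that comment and blank lines should be skipped:

from collections import Counter
from pathlib import Path

# Collect non-comment, non-blank patterns from .gitignore.
lines = Path(".gitignore").read_text().splitlines()
patterns = [line.strip() for line in lines
            if line.strip() and not line.lstrip().startswith("#")]

# Report any pattern that appears more than once.
duplicates = [p for p, count in Counter(patterns).items() if count > 1]
print(duplicates)  # for this diff, expected to include 'uv.lock' and '*.db'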

devin-ai-integration (Bot, Contributor) left a comment

Devin Review found 1 potential issue.

View 7 additional findings in Devin Review.

Comment thread: skyrl/backends/jax.py, lines 275 to 278
input_ids,
attention_mask=attention_mask,
adapter_indices=adapter_indices,
is_training=True,
)
devin-ai-integration (Bot, Contributor) commented:

🔴 Removing is_training=True from the jax backend causes unnecessary KV cache allocation during Qwen3 training

During the migration, is_training=True was removed from the _model_forward call in the jax backend. For Qwen3 models (which still use StackedDecoderLayers with the is_training flag), this causes is_training to default to False, making the scan body compute and accumulate KV cache tensors for all layers during training.

Root Cause and Impact

The old code at skyrl-tx/tx/tinker/backends/jax.py:278 passed is_training=True:

output = model(
    input_ids,
    attention_mask=attention_mask,
    adapter_indices=adapter_indices,
    is_training=True,
)

The new code at skyrl/backends/jax.py:274-278 omits this flag:

output = model(
    input_ids,
    attention_mask=attention_mask,
    adapter_indices=adapter_indices,
)

Qwen3 models still accept and use is_training (see skyrl/tx/models/qwen3.py:349,419), which propagates to StackedDecoderLayers.__call__ (skyrl/tx/layers/stacked.py:269-270). When is_training=False (the default), the scan body does NOT zero out k and v, so the scan accumulates full KV cache tensors for every layer. At skyrl/tx/layers/stacked.py:279-285, a full KVCache is then constructed and returned, wasting GPU memory proportional to num_layers × batch_size × seq_len × num_heads × head_dim × 2.
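
To make that memory cost concrete, here is a back-of-the-envelope calculation in Python. The shapes are illustrative assumptions (loosely Qwen3-sized), not values taken from this PR:

# Hypothetical shapes, for illustration only:
num_layers = 36
batch_size = 8
seq_len = 4096
num_heads = 8      # KV heads (assumed)
head_dim = 128
bytes_per_el = 2   # bf16

# num_layers x batch_size x seq_len x num_heads x head_dim x 2 (K and V):
kv_bytes = (num_layers * batch_size * seq_len
            * num_heads * head_dim * 2 * bytes_per_el)
print(f"~{kv_bytes / 2**30:.1f} GiB of KV cache")  # ~4.5 GiB, all of it unused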

Additionally, since config.gradient_checkpointing still propagates through the Qwen3 model to StackedDecoderLayers (lines 274-275 of stacked.py), the outer jax.checkpoint wrapping _model_forward at skyrl/backends/jax.py:284 creates redundant double gradient checkpointing.
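
For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of nested jax.checkpoint (remat) boundaries. The function and shapes are hypothetical and only illustrate the double-checkpointing concern; they are not the PR's code:

import jax
import jax.numpy as jnp

def layer(x, w):
    return jnp.tanh(x @ w)

# Inner remat, analogous to per-layer gradient checkpointing inside the model:
inner = jax.checkpoint(layer)

# Outer remat wrapping the whole forward, analogous to jax.checkpoint around
# _model_forward. During the backward pass the outer boundary recomputes the
# forward, and the inner boundary then recomputes layer's intermediates again
# inside that recomputation: extra compute, no additional memory savings.
def forward(w, x):
    return jnp.sum(inner(x, w) ** 2)

forward_remat = jax.checkpoint(forward)

w = jnp.ones((8, 8))
x = jnp.ones((4, 8))
grads = jax.grad(forward_remat)(w, x)  # gradient w.r.t. w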

Impact: Significant unnecessary GPU memory consumption during training for Qwen3 models, potentially causing OOM on memory-constrained setups. The training results are still correct since the returned KV cache is unused by the caller.

(Refers to lines 274-278)

