diff --git a/.buildkite/generate_pipeline.py b/.buildkite/generate_pipeline.py
index 066d4427b94..5eadd4c9974 100644
--- a/.buildkite/generate_pipeline.py
+++ b/.buildkite/generate_pipeline.py
@@ -146,6 +146,7 @@ def _parse_args(args: Optional[str] = None):
parser.add_argument('--grpc', action="store_true")
parser.add_argument('--env-file')
parser.add_argument('--plugin-yaml')
+ parser.add_argument('--submodule-base-branch')
parser.add_argument('--dependency', nargs='?', const='', default='all')
parsed_args, _ = parser.parse_known_args(args_list)
@@ -190,6 +191,11 @@ def _parse_args(args: Optional[str] = None):
extra_args.append('--grpc')
if parsed_args.env_file:
extra_args.append(f'--env-file {parsed_args.env_file}')
+ if parsed_args.plugin_yaml:
+ extra_args.append(f'--plugin-yaml {parsed_args.plugin_yaml}')
+ if parsed_args.submodule_base_branch:
+ extra_args.append(
+ f'--submodule-base-branch {parsed_args.submodule_base_branch}')
if parsed_args.dependency != 'all':
space = ' ' if parsed_args.dependency else ''
extra_args.append(f'--dependency{space}{parsed_args.dependency}')
@@ -198,8 +204,9 @@ def _parse_args(args: Optional[str] = None):
def _extract_marked_tests(
- file_path: str, args: str
-) -> Dict[str, Tuple[List[str], List[str], List[Optional[str]]]]:
+ file_path: str, args: str
+) -> Dict[str, Tuple[List[str], List[str], List[Optional[str]], List[str],
+ List[bool]]]:
"""Extract test functions and filter clouds using pytest.mark
from a Python test file.
@@ -212,6 +219,10 @@ def _extract_marked_tests(
and run for hours. This makes it hard to visualize the test results and
rerun failures. Additionally, the parallelism would be controlled by pytest
instead of the buildkite job queue.
+
+ Returns:
+ Dict mapping function_name to tuple of:
+ (clouds, queues, params, extra_args, no_auto_retry_flags)
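+
+        For example (illustrative values only), a test marked
+        no_auto_retry that runs on AWS and GCP could map to:
+        (['aws', 'gcp'], ['aws_queue', 'gcp_queue'], [None, None],
+        ['', ''], [True, True])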
"""
# Args are already in the format pytest expects (cloud names like --lambda)
cmd = f'pytest {file_path} --collect-only {args}'
@@ -259,6 +270,7 @@ def _extract_marked_tests(
run_on_cloud_kube_backend = ('resource_heavy' in marks and
'kubernetes' in default_clouds_to_run)
benchmark_test = 'benchmark' in marks
+ no_auto_retry = 'no_auto_retry' in marks
for mark in marks:
if mark not in PYTEST_TO_CLOUD_KEYWORD:
@@ -302,20 +314,19 @@ def _extract_marked_tests(
for cloud in final_clouds_to_include
], param_list, [
extra_args for _ in range(len(final_clouds_to_include))
- ])
+ ], [no_auto_retry for _ in range(len(final_clouds_to_include))])
return function_cloud_map
-def _generate_pipeline(test_file: str,
- args: str,
- auto_retry: bool = False) -> Dict[str, Any]:
+def _generate_pipeline(test_file: str, args: str) -> Dict[str, Any]:
"""Generate a Buildkite pipeline from test files."""
steps = []
generated_steps_set = set()
function_cloud_map = _extract_marked_tests(test_file, args)
for test_function, clouds_queues_param in function_cloud_map.items():
- for cloud, queue, param, extra_args in zip(*clouds_queues_param):
+ for cloud, queue, param, extra_args, no_auto_retry in zip(
+ *clouds_queues_param):
label = f'{test_function} on {cloud}'
command = f'pytest {test_file}::{test_function} --{cloud}'
if param:
@@ -328,6 +339,7 @@ def _generate_pipeline(test_file: str,
continue
if 'PYTHON_VERSION' in os.environ:
command = f'PYTHONPATH="$PWD:$PYTHONPATH" {command}'
+
step = {
'label': label,
'command': command,
@@ -338,7 +350,15 @@ def _generate_pipeline(test_file: str,
'queue': queue
}
}
- if auto_retry:
+ if no_auto_retry:
+ # Disable automatic retries but allow manual retries.
+ step['retry'] = {
+ 'automatic': False,
+ 'manual': {
+ 'allowed': True
+ }
+ }
+ else:
step['retry'] = {
# Automatically retry 2 times on any failure by default.
'automatic': True
@@ -391,7 +411,7 @@ def _convert_release(test_files: List[str], args: str, trigger_command: str):
output_file_pipelines = []
for test_file in test_files:
print(f'Converting {test_file} to {yaml_file_path}')
- pipeline = _generate_pipeline(test_file, args, auto_retry=True)
+ pipeline = _generate_pipeline(test_file, args)
output_file_pipelines.append(pipeline)
print(f'Converted {test_file} to {yaml_file_path}\n\n')
# Enable all clouds by default for release pipeline.
@@ -462,11 +482,10 @@ def _convert_quick_tests_core(test_files: List[str], args: str,
branch != 'master'):
continue
pipeline = _generate_pipeline(test_file,
- args + f' --base-branch {branch}',
- auto_retry=True)
+ args + f' --base-branch {branch}')
output_file_pipelines.append(pipeline)
else:
- pipeline = _generate_pipeline(test_file, args, auto_retry=True)
+ pipeline = _generate_pipeline(test_file, args)
output_file_pipelines.append(pipeline)
print(f'Converted {test_file} to {yaml_file_path}\n\n')
_dump_pipeline_to_file(yaml_file_path,
diff --git a/.buildkite/test_buildkite_pipeline_generation.py b/.buildkite/test_buildkite_pipeline_generation.py
index 7c68064d18a..b5ce058c1f9 100644
--- a/.buildkite/test_buildkite_pipeline_generation.py
+++ b/.buildkite/test_buildkite_pipeline_generation.py
@@ -128,6 +128,61 @@ def _extract_test_names_from_pipeline(pipeline_path):
return test_names
+def _extract_steps_from_pipeline(pipeline_path):
+ """Extract all steps from a pipeline YAML file."""
+ with open(pipeline_path, 'r') as f:
+ pipeline = yaml.safe_load(f)
+
+ all_steps = []
+ for group in pipeline['steps']:
+ if 'steps' in group:
+ all_steps.extend(group['steps'])
+ else:
+ all_steps.append(group)
+ return all_steps
+
+
+def test_no_auto_retry_marker():
+ """Test that no_auto_retry marker works correctly.
+
+ This test uses the actual test_kubernetes_container_status_unknown_status_refresh
+ test which has the marker applied.
+ """
+ # Generate pipeline for the specific test
+ env = dict(os.environ)
+ env['PYTHONPATH'] = f"{pathlib.Path.cwd()}/tests:{env.get('PYTHONPATH', '')}"
+
+ subprocess.run([
+ 'python', '.buildkite/generate_pipeline.py', '--args', '--kubernetes',
+ '--file_pattern', 'test_cluster_job'
+ ],
+ env=env,
+ check=True)
+
+ # Check the generated pipeline
+ pipeline_path = pathlib.Path('.buildkite/pipeline_smoke_tests_release.yaml')
+ steps = _extract_steps_from_pipeline(pipeline_path)
+
+ # Find steps for test_kubernetes_container_status_unknown_status_refresh
+ target_steps = [
+ s for s in steps
+ if 'test_kubernetes_container_status_unknown_status_refresh' in s.get(
+ 'label', '')
+ ]
+
+ # Should have exactly 1 step
+ assert len(target_steps) == 1, \
+ f"Expected 1 step, got {len(target_steps)}"
+
+ # Verify no_auto_retry is applied
+ step = target_steps[0]
+ retry = step.get('retry', {})
+ assert retry.get('automatic') is False, \
+ f"no_auto_retry step should have automatic=False: {retry}"
+ assert retry.get('manual', {}).get('allowed') is True, \
+ f"no_auto_retry step should allow manual retry: {retry}"
+
+
@pytest.mark.parametrize('args', [
'',
'--aws',
diff --git a/.cursor/worktrees.json b/.cursor/worktrees.json
new file mode 100644
index 00000000000..aaf47d23c99
--- /dev/null
+++ b/.cursor/worktrees.json
@@ -0,0 +1,8 @@
+{
+ "setup-worktree": [
+ "uv venv --seed --python 3.11",
+ "uv pip install -e \".[all]\" --prerelease=allow",
+ "uv pip install -r requirements-dev.txt",
+ "npm --prefix sky/dashboard install && npm --prefix sky/dashboard run build"
+ ]
+}
diff --git a/.github/workflows/compile-protos-check.yml b/.github/workflows/compile-protos-check.yml
index 5601f995d30..2b895d1d93f 100644
--- a/.github/workflows/compile-protos-check.yml
+++ b/.github/workflows/compile-protos-check.yml
@@ -18,7 +18,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ["3.8"]
+ python-version: ["3.9"]
steps:
- uses: actions/checkout@v3
- name: Install the latest version of uv
diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
index 3c85ed81252..18b67937150 100644
--- a/.github/workflows/format.yml
+++ b/.github/workflows/format.yml
@@ -18,7 +18,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ["3.8"]
+ python-version: ["3.9"]
steps:
- uses: actions/checkout@v3
- name: Install the latest version of uv
diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml
index 6df98401fcb..9e198c0890c 100644
--- a/.github/workflows/mypy.yml
+++ b/.github/workflows/mypy.yml
@@ -18,7 +18,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ["3.8"]
+ python-version: ["3.9"]
steps:
- uses: actions/checkout@v3
- name: Install the latest version of uv
diff --git a/.github/workflows/nightly-build.yml b/.github/workflows/nightly-build.yml
index 3008e16ceb4..c979cb371e0 100644
--- a/.github/workflows/nightly-build.yml
+++ b/.github/workflows/nightly-build.yml
@@ -230,6 +230,20 @@ jobs:
secrets:
BUILDKITE_TOKEN: ${{ secrets.BUILDKITE_TOKEN }}
+ smoke-tests-kubernetes-jobs-consolidation:
+ needs: [gate-tests, nightly-build-pypi]
+ if: ${{ needs.gate-tests.outputs.run_tests == 'true' }}
+ uses: ./.github/workflows/buildkite-trigger-wait.yml
+ with:
+ commit: ${{ github.sha }}
+ branch: ${{ github.ref_name }}
+ message: "nightly-build-pypi --kubernetes --jobs-consolidation --no-resource-heavy"
+ pipeline: "smoke-tests"
+ build_env_vars: '{"ARGS": "--kubernetes --jobs-consolidation --no-resource-heavy"}'
+ timeout_minutes: 60
+ secrets:
+ BUILDKITE_TOKEN: ${{ secrets.BUILDKITE_TOKEN }}
+
smoke-tests-shared-gke-api-server:
needs: [gate-tests, nightly-build-pypi]
if: ${{ needs.gate-tests.outputs.run_tests == 'true' }}
@@ -273,12 +287,10 @@ jobs:
# BUILDKITE_TOKEN: ${{ secrets.BUILDKITE_TOKEN }}
publish-and-validate-both:
- # needs: [gate-tests, nightly-build-pypi, smoke-tests-aws, smoke-tests-kubernetes-resource-heavy, smoke-tests-kubernetes-no-resource-heavy, smoke-tests-kubernetes-no-resource-heavy-limit-deps, smoke-tests-remote-server-kubernetes, smoke-tests-shared-gke-api-server, smoke-tests-lambda-job-queue, smoke-tests-runpod-minimal, backward-compat-test-nightly, backward-compat-test-stable]
- needs: [gate-tests, nightly-build-pypi, smoke-tests-aws, smoke-tests-kubernetes-resource-heavy, smoke-tests-kubernetes-no-resource-heavy, smoke-tests-kubernetes-no-resource-heavy-limit-deps, smoke-tests-remote-server-kubernetes, smoke-tests-shared-gke-api-server, smoke-tests-lambda-job-queue, backward-compat-test-nightly, backward-compat-test-stable]
+ needs: [gate-tests, nightly-build-pypi, smoke-tests-aws, smoke-tests-kubernetes-resource-heavy, smoke-tests-kubernetes-no-resource-heavy, smoke-tests-kubernetes-no-resource-heavy-limit-deps, smoke-tests-remote-server-kubernetes, smoke-tests-kubernetes-jobs-consolidation, smoke-tests-shared-gke-api-server, smoke-tests-lambda-job-queue, backward-compat-test-nightly, backward-compat-test-stable]
# Allow publish/validate for manual dispatch or the original nightly cron; skip for the 5PM PT preflight
# Use always() so this job evaluates even if some test jobs were skipped when skip_buildkite is selected
- # if: ${{ always() && (github.event_name == 'workflow_dispatch' || (github.event_name == 'schedule' && github.event.schedule == '35 8 * * *')) && needs.nightly-build-pypi.result == 'success' && (needs.gate-tests.outputs.publish_without_tests == 'true' || (needs.gate-tests.outputs.run_tests == 'true' && needs.smoke-tests-aws.result == 'success' && needs.smoke-tests-kubernetes-resource-heavy.result == 'success' && needs.smoke-tests-kubernetes-no-resource-heavy.result == 'success' && needs.smoke-tests-kubernetes-no-resource-heavy-limit-deps.result == 'success' && needs.smoke-tests-remote-server-kubernetes.result == 'success' && needs.smoke-tests-shared-gke-api-server.result == 'success' && needs.smoke-tests-lambda-job-queue.result == 'success' && needs.smoke-tests-runpod-minimal.result == 'success' && needs.backward-compat-test-nightly.result == 'success' && needs.backward-compat-test-stable.result == 'success')) }}
- if: ${{ always() && (github.event_name == 'workflow_dispatch' || (github.event_name == 'schedule' && github.event.schedule == '35 8 * * *')) && needs.nightly-build-pypi.result == 'success' && (needs.gate-tests.outputs.publish_without_tests == 'true' || (needs.gate-tests.outputs.run_tests == 'true' && needs.smoke-tests-aws.result == 'success' && needs.smoke-tests-kubernetes-resource-heavy.result == 'success' && needs.smoke-tests-kubernetes-no-resource-heavy.result == 'success' && needs.smoke-tests-kubernetes-no-resource-heavy-limit-deps.result == 'success' && needs.smoke-tests-remote-server-kubernetes.result == 'success' && needs.smoke-tests-shared-gke-api-server.result == 'success' && needs.smoke-tests-lambda-job-queue.result == 'success' && needs.backward-compat-test-nightly.result == 'success' && needs.backward-compat-test-stable.result == 'success')) }}
+ if: ${{ always() && (github.event_name == 'workflow_dispatch' || (github.event_name == 'schedule' && github.event.schedule == '35 8 * * *')) && needs.nightly-build-pypi.result == 'success' && (needs.gate-tests.outputs.publish_without_tests == 'true' || (needs.gate-tests.outputs.run_tests == 'true' && needs.smoke-tests-aws.result == 'success' && needs.smoke-tests-kubernetes-resource-heavy.result == 'success' && needs.smoke-tests-kubernetes-no-resource-heavy.result == 'success' && needs.smoke-tests-kubernetes-no-resource-heavy-limit-deps.result == 'success' && needs.smoke-tests-remote-server-kubernetes.result == 'success' && needs.smoke-tests-kubernetes-jobs-consolidation.result == 'success' && needs.smoke-tests-shared-gke-api-server.result == 'success' && needs.smoke-tests-lambda-job-queue.result == 'success' && needs.backward-compat-test-nightly.result == 'success' && needs.backward-compat-test-stable.result == 'success')) }}
uses: ./.github/workflows/publish-and-validate-both.yml
with:
package_name: skypilot-nightly
@@ -297,8 +309,7 @@ jobs:
summary:
runs-on: ubuntu-latest
- needs: [check-date, nightly-build-pypi, smoke-tests-aws, smoke-tests-kubernetes-resource-heavy, smoke-tests-kubernetes-no-resource-heavy, smoke-tests-kubernetes-no-resource-heavy-limit-deps, smoke-tests-remote-server-kubernetes, smoke-tests-shared-gke-api-server, smoke-tests-lambda-job-queue, backward-compat-test-nightly, backward-compat-test-stable]
- # needs: [check-date, nightly-build-pypi, smoke-tests-aws, smoke-tests-kubernetes-resource-heavy, smoke-tests-kubernetes-no-resource-heavy, smoke-tests-kubernetes-no-resource-heavy-limit-deps, smoke-tests-remote-server-kubernetes, smoke-tests-shared-gke-api-server, smoke-tests-lambda-job-queue, smoke-tests-runpod-minimal, backward-compat-test-nightly, backward-compat-test-stable]
+ needs: [check-date, nightly-build-pypi, smoke-tests-aws, smoke-tests-kubernetes-resource-heavy, smoke-tests-kubernetes-no-resource-heavy, smoke-tests-kubernetes-no-resource-heavy-limit-deps, smoke-tests-remote-server-kubernetes, smoke-tests-kubernetes-jobs-consolidation, smoke-tests-shared-gke-api-server, smoke-tests-lambda-job-queue, backward-compat-test-nightly, backward-compat-test-stable]
if: always()
steps:
- name: Summary
@@ -333,6 +344,11 @@ jobs:
- [Smoke Tests Remote Server Kubernetes](https://buildkite.com/skypilot-1/smoke-tests/builds/${{ needs.smoke-tests-remote-server-kubernetes.outputs.build_number }}) - $([ "${{ needs.smoke-tests-remote-server-kubernetes.outputs.build_status }}" == "success" ] && echo "✅ Success" || echo "❌ Failed")
EOF
fi
+ if [ "${{ needs.smoke-tests-kubernetes-jobs-consolidation.result }}" != "skipped" ] && [ -n "${{ needs.smoke-tests-kubernetes-jobs-consolidation.outputs.build_number }}" ]; then
+            cat <<EOF >> "$GITHUB_STEP_SUMMARY"
+ - [Smoke Tests Kubernetes (Jobs Consolidation)](https://buildkite.com/skypilot-1/smoke-tests/builds/${{ needs.smoke-tests-kubernetes-jobs-consolidation.outputs.build_number }}) - $([ "${{ needs.smoke-tests-kubernetes-jobs-consolidation.outputs.build_status }}" == "success" ] && echo "✅ Success" || echo "❌ Failed")
+ EOF
+ fi
if [ "${{ needs.smoke-tests-shared-gke-api-server.result }}" != "skipped" ] && [ -n "${{ needs.smoke-tests-shared-gke-api-server.outputs.build_number }}" ]; then
cat <<EOF >> "$GITHUB_STEP_SUMMARY"
- [Smoke Tests Shared GKE API Server](https://buildkite.com/skypilot-1/nightly-build-shared-gke-api-server/builds/${{ needs.smoke-tests-shared-gke-api-server.outputs.build_number }}) - $([ "${{ needs.smoke-tests-shared-gke-api-server.outputs.build_status }}" == "success" ] && echo "✅ Success" || echo "❌ Failed")
@@ -361,8 +377,7 @@ jobs:
notify-slack-failure:
runs-on: ubuntu-latest
- needs: [check-date, nightly-build-pypi, smoke-tests-aws, smoke-tests-kubernetes-resource-heavy, smoke-tests-kubernetes-no-resource-heavy, smoke-tests-kubernetes-no-resource-heavy-limit-deps, smoke-tests-remote-server-kubernetes, smoke-tests-shared-gke-api-server, smoke-tests-lambda-job-queue, backward-compat-test-nightly, backward-compat-test-stable, publish-and-validate-both, trigger-docker-and-helm-release]
- # needs: [check-date, nightly-build-pypi, smoke-tests-aws, smoke-tests-kubernetes-resource-heavy, smoke-tests-kubernetes-no-resource-heavy, smoke-tests-kubernetes-no-resource-heavy-limit-deps, smoke-tests-remote-server-kubernetes, smoke-tests-shared-gke-api-server, smoke-tests-lambda-job-queue, smoke-tests-runpod-minimal, backward-compat-test-nightly, backward-compat-test-stable, publish-and-validate-both, trigger-docker-and-helm-release]
+ needs: [check-date, nightly-build-pypi, smoke-tests-aws, smoke-tests-kubernetes-resource-heavy, smoke-tests-kubernetes-no-resource-heavy, smoke-tests-kubernetes-no-resource-heavy-limit-deps, smoke-tests-remote-server-kubernetes, smoke-tests-kubernetes-jobs-consolidation, smoke-tests-shared-gke-api-server, smoke-tests-lambda-job-queue, backward-compat-test-nightly, backward-compat-test-stable, publish-and-validate-both, trigger-docker-and-helm-release]
# Only run this job if any of the previous jobs failed
if: failure()
steps:
@@ -374,8 +389,7 @@ jobs:
COMMIT_URL="${{ github.server_url }}/${{ github.repository }}/commit/${COMMIT_SHA}"
SHORT_SHA=$(echo "$COMMIT_SHA" | cut -c1-7)
BUILDKITE_MSG=""
- if [[ "${{ needs.smoke-tests-aws.result }}" == "failure" || "${{ needs.smoke-tests-kubernetes-resource-heavy.result }}" == "failure" || "${{ needs.smoke-tests-kubernetes-no-resource-heavy.result }}" == "failure" || "${{ needs.smoke-tests-kubernetes-no-resource-heavy-limit-deps.result }}" == "failure" || "${{ needs.smoke-tests-remote-server-kubernetes.result }}" == "failure" || "${{ needs.smoke-tests-shared-gke-api-server.result }}" == "failure" || "${{ needs.smoke-tests-lambda-job-queue.result }}" == "failure" || "${{ needs.backward-compat-test-nightly.result }}" == "failure" || "${{ needs.backward-compat-test-stable.result }}" == "failure" ]]; then
- # if [[ "${{ needs.smoke-tests-aws.result }}" == "failure" || "${{ needs.smoke-tests-kubernetes-resource-heavy.result }}" == "failure" || "${{ needs.smoke-tests-kubernetes-no-resource-heavy.result }}" == "failure" || "${{ needs.smoke-tests-kubernetes-no-resource-heavy-limit-deps.result }}" == "failure" || "${{ needs.smoke-tests-remote-server-kubernetes.result }}" == "failure" || "${{ needs.smoke-tests-shared-gke-api-server.result }}" == "failure" || "${{ needs.smoke-tests-lambda-job-queue.result }}" == "failure" || "${{ needs.smoke-tests-runpod-minimal.result }}" == "failure" || "${{ needs.backward-compat-test-nightly.result }}" == "failure" || "${{ needs.backward-compat-test-stable.result }}" == "failure" ]]; then
+ if [[ "${{ needs.smoke-tests-aws.result }}" == "failure" || "${{ needs.smoke-tests-kubernetes-resource-heavy.result }}" == "failure" || "${{ needs.smoke-tests-kubernetes-no-resource-heavy.result }}" == "failure" || "${{ needs.smoke-tests-kubernetes-no-resource-heavy-limit-deps.result }}" == "failure" || "${{ needs.smoke-tests-remote-server-kubernetes.result }}" == "failure" || "${{ needs.smoke-tests-kubernetes-jobs-consolidation.result }}" == "failure" || "${{ needs.smoke-tests-shared-gke-api-server.result }}" == "failure" || "${{ needs.smoke-tests-lambda-job-queue.result }}" == "failure" || "${{ needs.backward-compat-test-nightly.result }}" == "failure" || "${{ needs.backward-compat-test-stable.result }}" == "failure" ]]; then
if [[ "${{ needs.smoke-tests-aws.result }}" == "failure" ]]; then
BUILDKITE_MSG=""
fi
@@ -403,6 +417,12 @@ jobs:
fi
BUILDKITE_MSG="${BUILDKITE_MSG} "
fi
+ if [[ "${{ needs.smoke-tests-kubernetes-jobs-consolidation.result }}" == "failure" ]]; then
+ if [[ ! -z "$BUILDKITE_MSG" ]]; then
+ BUILDKITE_MSG="${BUILDKITE_MSG} and"
+ fi
+ BUILDKITE_MSG="${BUILDKITE_MSG} "
+ fi
if [[ "${{ needs.smoke-tests-shared-gke-api-server.result }}" == "failure" ]]; then
if [[ ! -z "$BUILDKITE_MSG" ]]; then
BUILDKITE_MSG="${BUILDKITE_MSG} and"
diff --git a/.github/workflows/publish-helm.yml b/.github/workflows/publish-helm.yml
index 005bbe51aa0..a6611187266 100644
--- a/.github/workflows/publish-helm.yml
+++ b/.github/workflows/publish-helm.yml
@@ -115,11 +115,6 @@ jobs:
line=$(grep -n "^-----*$" src/README.md | cut -d: -f1 | head -n 1)
tail -n +$line src/README.md >> src/charts/skypilot/README.md
- # Update the version in the external-metrics chart (prometheus server)
- # todo(rohan): update name the way we do for the main skypilot chart?
- sed -i "s/^version:.*$/version: ${semversion}/" src/charts/external-metrics/Chart.yaml
- sed -i "s/^appVersion:.*$/appVersion: ${version}/" src/charts/external-metrics/Chart.yaml
-
- name: Update docker image in charts
if: inputs.version != ''
run: |
diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml
index 7d2847ab501..43c85f77fc4 100644
--- a/.github/workflows/pylint.yml
+++ b/.github/workflows/pylint.yml
@@ -18,7 +18,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ["3.8"]
+ python-version: ["3.9"]
steps:
- uses: actions/checkout@v3
- name: Install the latest version of uv
diff --git a/.github/workflows/pytest-optimizer.yml b/.github/workflows/pytest-optimizer.yml
index 3f8667b67d7..578b0cb0da0 100644
--- a/.github/workflows/pytest-optimizer.yml
+++ b/.github/workflows/pytest-optimizer.yml
@@ -19,13 +19,16 @@ jobs:
python-version: ["3.9"]
test-path:
- "tests/test_optimizer_dryruns.py -k \"partial\""
- - "tests/test_optimizer_dryruns.py -k \"not partial\""
+ - "tests/test_optimizer_dryruns.py -k \"not partial and not accelerator_memory and not accelerator_manufacturer\""
+ - "tests/test_optimizer_dryruns.py -k \"accelerator_memory or accelerator_manufacturer\""
- tests/test_optimizer_random_dag.py
include:
- test-path: "tests/test_optimizer_dryruns.py -k \"partial\""
test-name: "Optimizer Dryruns Part 1"
- - test-path: "tests/test_optimizer_dryruns.py -k \"not partial\""
+ - test-path: "tests/test_optimizer_dryruns.py -k \"not partial and not accelerator_memory and not accelerator_manufacturer\""
test-name: "Optimizer Dryruns Part 2"
+ - test-path: "tests/test_optimizer_dryruns.py -k \"accelerator_memory or accelerator_manufacturer\""
+ test-name: "Optimizer Dryruns Part 3"
- test-path: tests/test_optimizer_random_dag.py
test-name: "Optimizer Random DAG Tests"
runs-on: ubuntu-latest
@@ -38,4 +41,4 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
test-path: ${{ matrix.test-path }}
- test-name: ${{ matrix.test-name }}
\ No newline at end of file
+ test-name: ${{ matrix.test-name }}
diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index 271e967729a..422db1a7f9c 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -21,7 +21,8 @@ jobs:
# Group them based on running time to save CI time and resources
- tests/unit_tests
- tests/test_cli.py
- - tests/test_jobs_and_serve.py tests/test_yaml_parser.py tests/test_global_user_state.py tests/test_config.py tests/test_jobs.py tests/test_list_accelerators.py tests/test_wheels.py tests/test_api.py tests/test_storage.py tests/test_api_compatibility.py
+ - tests/test_jobs.py tests/test_jobs_and_serve.py tests/test_list_accelerators.py tests/test_api.py
+ - tests/test_config.py tests/test_wheels.py tests/test_yaml_parser.py tests/test_global_user_state.py tests/test_storage.py tests/test_api_compatibility.py tests/test_infra_k8s_alias.py
- tests/test_no_parellel.py
- tests/test_ssh_proxy_lag.py
include:
@@ -31,8 +32,10 @@ jobs:
test-name: "Unit Tests"
- test-path: tests/test_cli.py
test-name: "CLI Tests"
- - test-path: tests/test_jobs_and_serve.py tests/test_yaml_parser.py tests/test_global_user_state.py tests/test_config.py tests/test_jobs.py tests/test_list_accelerators.py tests/test_wheels.py tests/test_api.py tests/test_storage.py tests/test_api_compatibility.py tests/test_infra_k8s_alias.py
- test-name: "Jobs, Serve, Wheels, API, Config, Optimizer & Storage Tests"
+ - test-path: tests/test_jobs.py tests/test_jobs_and_serve.py tests/test_list_accelerators.py tests/test_api.py
+ test-name: "Jobs & API Tests"
+ - test-path: tests/test_config.py tests/test_wheels.py tests/test_yaml_parser.py tests/test_global_user_state.py tests/test_storage.py tests/test_api_compatibility.py tests/test_infra_k8s_alias.py
+ test-name: "Config, Storage & Compatibility Tests"
- test-path: tests/test_no_parellel.py
test-name: "No Parallel Tests"
- test-path: tests/test_ssh_proxy_lag.py
diff --git a/.github/workflows/release-build.yml b/.github/workflows/release-build.yml
index 80819f5c53a..9342d563a34 100644
--- a/.github/workflows/release-build.yml
+++ b/.github/workflows/release-build.yml
@@ -190,7 +190,7 @@ jobs:
smoke-tests:
needs: release-build
if: |
- always() &&
+ always() &&
needs.release-build.result == 'success' &&
github.event.inputs.skip_smoke_tests != 'true'
uses: ./.github/workflows/buildkite-trigger-wait.yml
@@ -208,7 +208,7 @@ jobs:
quicktest-core:
needs: release-build
if: |
- always() &&
+ always() &&
needs.release-build.result == 'success' &&
github.event.inputs.skip_smoke_tests != 'true'
uses: ./.github/workflows/buildkite-trigger-wait.yml
@@ -227,7 +227,7 @@ jobs:
quicktest-core-previous-minor:
needs: release-build
if: |
- always() &&
+ always() &&
needs.release-build.result == 'success' &&
github.event.inputs.skip_smoke_tests != 'true'
uses: ./.github/workflows/buildkite-trigger-wait.yml
@@ -246,7 +246,7 @@ jobs:
smoke-tests-remote-server-kubernetes:
needs: release-build
if: |
- always() &&
+ always() &&
needs.release-build.result == 'success' &&
github.event.inputs.skip_smoke_tests != 'true'
uses: ./.github/workflows/buildkite-trigger-wait.yml
@@ -263,10 +263,29 @@ jobs:
secrets:
BUILDKITE_TOKEN: ${{ secrets.BUILDKITE_TOKEN }}
+ smoke-tests-kubernetes-jobs-consolidation:
+ needs: release-build
+ if: |
+ always() &&
+ needs.release-build.result == 'success' &&
+ github.event.inputs.skip_smoke_tests != 'true'
+ uses: ./.github/workflows/buildkite-trigger-wait.yml
+ with:
+ commit: ${{ needs.release-build.outputs.new_commit_sha }}
+ branch: ${{ needs.release-build.outputs.test_branch }}
+ message: "Release ${{ needs.release-build.outputs.release_version }} --kubernetes --jobs-consolidation --no-resource-heavy"
+ pipeline: "smoke-tests"
+ build_env_vars: '{"ARGS": "--kubernetes --jobs-consolidation --no-resource-heavy"}'
+ timeout_minutes: 60
+ wait: true
+ fail_on_buildkite_failure: true
+ secrets:
+ BUILDKITE_TOKEN: ${{ secrets.BUILDKITE_TOKEN }}
+
release-tests:
needs: release-build
if: |
- always() &&
+ always() &&
needs.release-build.result == 'success' &&
github.event.inputs.skip_smoke_tests != 'true'
uses: ./.github/workflows/buildkite-trigger-wait.yml
@@ -281,7 +300,7 @@ jobs:
BUILDKITE_TOKEN: ${{ secrets.BUILDKITE_TOKEN }}
create-pr:
- needs: [release-build, smoke-tests, quicktest-core, quicktest-core-previous-minor, smoke-tests-remote-server-kubernetes, release-tests]
+ needs: [release-build, smoke-tests, quicktest-core, quicktest-core-previous-minor, smoke-tests-remote-server-kubernetes, smoke-tests-kubernetes-jobs-consolidation, release-tests]
if: always() && needs.release-build.result == 'success'
runs-on: ubuntu-latest
steps:
@@ -314,21 +333,21 @@ jobs:
if [ "$SKIP_SMOKE_TESTS" == "true" ]; then
if [ "$IS_RC_PROMOTION" == "true" ]; then
PR_BODY="## Promote RC to Stable Release ${RELEASE_VERSION}
-
- **Source:** \`$SOURCE_BRANCH\` (RC version: $RC_VERSION)
+
+ **Source:** \`$SOURCE_BRANCH\` (RC version: $RC_VERSION)
**Target:** Stable release \`${RELEASE_VERSION}\`
-
+
⚠️ **Smoke tests were SKIPPED** - This release is being promoted from a tested RC.
-
+
### Pre-release Testing
This version was previously tested as release candidate \`$RC_VERSION\` and deemed stable by early adopters.
-
+
### Changes in this PR
- Updated \`sky/__init__.py\`: \`$RC_VERSION\` → \`${RELEASE_VERSION}\`
- Updated \`charts/skypilot/values.yaml\`: Docker image tag \`$RC_VERSION\` → \`${RELEASE_VERSION}\`"
else
PR_BODY="Release ${RELEASE_VERSION}
-
+
⚠️ **Smoke tests were SKIPPED** - Please ensure manual testing was performed."
fi
else
@@ -340,6 +359,7 @@ jobs:
- [Quicktest Core](https://buildkite.com/skypilot-1/quicktest-core/builds/${{ needs.quicktest-core.outputs.build_number }}) - $([ "${{ needs.quicktest-core.outputs.build_status }}" == "success" ] && echo "✅ Success" || echo "❌ Failed")
- [Quicktest Core (vs Previous Minor)](https://buildkite.com/skypilot-1/quicktest-core/builds/${{ needs.quicktest-core-previous-minor.outputs.build_number }}) - $([ "${{ needs.quicktest-core-previous-minor.outputs.build_status }}" == "success" ] && echo "✅ Success" || echo "❌ Failed")
- [Smoke Tests Remote Server Kubernetes](https://buildkite.com/skypilot-1/smoke-tests/builds/${{ needs.smoke-tests-remote-server-kubernetes.outputs.build_number }}) - $([ "${{ needs.smoke-tests-remote-server-kubernetes.outputs.build_status }}" == "success" ] && echo "✅ Success" || echo "❌ Failed")
+ - [Smoke Tests Kubernetes (Jobs Consolidation)](https://buildkite.com/skypilot-1/smoke-tests/builds/${{ needs.smoke-tests-kubernetes-jobs-consolidation.outputs.build_number }}) - $([ "${{ needs.smoke-tests-kubernetes-jobs-consolidation.outputs.build_status }}" == "success" ] && echo "✅ Success" || echo "❌ Failed")
- [Release Tests](https://buildkite.com/skypilot-1/release/builds/${{ needs.release-tests.outputs.build_number }}) - ⏳ (not waiting for completion)
*Release Tests may take up to 24 hours to complete and might fail due to resource constraints.*"
@@ -384,6 +404,7 @@ jobs:
- [Quicktest Core](https://buildkite.com/skypilot-1/quicktest-core/builds/${{ needs.quicktest-core.outputs.build_number }}) - $([ "${{ needs.quicktest-core.outputs.build_status }}" == "success" ] && echo "✅ Success" || echo "❌ Failed")
- [Quicktest Core (vs Previous Minor)](https://buildkite.com/skypilot-1/quicktest-core/builds/${{ needs.quicktest-core-previous-minor.outputs.build_number }}) - $([ "${{ needs.quicktest-core-previous-minor.outputs.build_status }}" == "success" ] && echo "✅ Success" || echo "❌ Failed")
- [Smoke Tests Remote Server Kubernetes](https://buildkite.com/skypilot-1/smoke-tests/builds/${{ needs.smoke-tests-remote-server-kubernetes.outputs.build_number }}) - $([ "${{ needs.smoke-tests-remote-server-kubernetes.outputs.build_status }}" == "success" ] && echo "✅ Success" || echo "❌ Failed")
+ - [Smoke Tests Kubernetes (Jobs Consolidation)](https://buildkite.com/skypilot-1/smoke-tests/builds/${{ needs.smoke-tests-kubernetes-jobs-consolidation.outputs.build_number }}) - $([ "${{ needs.smoke-tests-kubernetes-jobs-consolidation.outputs.build_status }}" == "success" ] && echo "✅ Success" || echo "❌ Failed")
- [Release Tests](https://buildkite.com/skypilot-1/release/builds/${{ needs.release-tests.outputs.build_number }}) - ⏳ (not waiting for completion)
*Release Tests may take up to 24 hours to complete and might fail due to resource constraints.*
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 501eca9c287..96f4dd9fb30 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -33,16 +33,16 @@ repos:
files: "^sky/skylet/providers/ibm/.*" # Only match IBM-specific directory
- repo: https://github.com/pre-commit/mirrors-mypy
- rev: v1.14.1 # Match the version from requirements
+ rev: v1.19.1 # Match the version from requirements
hooks:
- id: mypy
- args:
- # From tests/mypy_files.txt
+ args: # Match tests/mypy_files.txt - check sky and examples/admin_policy/example_policy
- "sky"
+ - "examples/admin_policy/example_policy"
+ - "--exclude"
+ - "sky/backends/monkey_patches"
- "--exclude"
- - "sky/benchmark|sky/callbacks|sky/backends/monkey_patches"
- - "--cache-dir"
- - "/dev/null"
+ - "examples/admin_policy/example_policy/build"
- "--check-untyped-defs"
pass_filenames: false
additional_dependencies:
@@ -96,7 +96,7 @@ repos:
- id: dashboard-format
name: dashboard format
- entry: bash -c 'cd sky/dashboard && npm run format'
+ entry: bash -c 'cd sky/dashboard && npm run format -- --log-level warn'
language: node
language_version: 24.12.0
files: ^sky/dashboard/
diff --git a/AGENTS.md b/AGENTS.md
index 43b52af0c3b..bf0e90bc959 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -105,7 +105,7 @@ From `requirements-dev.txt`:
- yapf==0.32.0
- pylint==2.14.5
- black==22.10.0
-- mypy==1.14.1
+- mypy==1.19.1
- isort==5.12.0
- pylint-quotes==0.2.3
@@ -315,6 +315,32 @@ sky api start
sky api status
```
+### Dashboard Development
+
+**For local API server development**, rebuild the dashboard before restarting:
+
+```bash
+# Install dependencies (first time or after package.json changes)
+npm --prefix sky/dashboard install
+
+# Rebuild the dashboard
+npm --prefix sky/dashboard run build
+
+# Then restart the API server
+sky api stop
+sky api start
+```
+
+**For a remote API server (Docker/Kubernetes)**, the Dockerfile builds the dashboard automatically; no manual build is needed before `docker build`.
+
+The dashboard is a Next.js application. For development with hot reloading:
+
+```bash
+# Run dashboard in development mode (separate from API server)
+cd sky/dashboard
+npm run dev
+```
+
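+To sanity-check that the restarted API server serves the rebuilt dashboard, a minimal probe (assuming the default local API server port 46580):
+
+```bash
+curl -sI http://localhost:46580/dashboard | head -n 1
+```
+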
### Mocking Remote API Server Locally
To test remote API server behavior locally:
@@ -345,7 +371,7 @@ helm dependency build ./charts/skypilot
DOCKER_IMAGE=my-repo/skypilot:v1
docker buildx build --push --platform linux/amd64 -t $DOCKER_IMAGE -f Dockerfile .
-# Deploy
+# Deploy (NEW installation)
NAMESPACE=skypilot
RELEASE_NAME=skypilot
helm upgrade --install $RELEASE_NAME ./charts/skypilot --devel \
@@ -354,6 +380,33 @@ helm upgrade --install $RELEASE_NAME ./charts/skypilot --devel \
--set apiService.image=$DOCKER_IMAGE
```
+#### Upgrading Existing Deployments
+
+**CRITICAL:** Always use `--reuse-values` to preserve database/credential config:
+
+```bash
+# Upgrade existing deployment (keeps PostgreSQL, auth, etc.)
+helm upgrade skypilot ./charts/skypilot -n skypilot --reuse-values \
+ --set apiService.image=$DOCKER_IMAGE
+
+# Check current values / rollback if needed
+helm get values skypilot -n skypilot
+helm rollback skypilot -n skypilot
+```
+
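+To watch an upgrade roll out, a sketch (the chart names the deployment `<release>-api-server`, so `skypilot-api-server` for the release above):
+
+```bash
+kubectl rollout status deployment/skypilot-api-server -n skypilot
+```
+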
+#### PostgreSQL Backend
+
+```bash
+# Create connection secret
+kubectl create secret generic db-uri -n skypilot \
+ --from-literal=uri="postgresql://user:pass@host:5432/db"
+
+# Deploy with PostgreSQL
+helm upgrade --install skypilot ./charts/skypilot -n skypilot \
+ --set apiService.dbConnectionSecretName=db-uri \
+ --set storage.enabled=false
+```
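+
+To confirm the server connected to the external database rather than local SQLite, one illustrative check (log wording may differ across versions):
+
+```bash
+kubectl logs deployment/skypilot-api-server -n skypilot | grep -i postgres
+```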
+
## Critical Code Paths (Handle with Care)
The following modules contain complex, stateful logic that requires careful review when modifying:
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index ede9764b958..ee67a6affae 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -21,7 +21,7 @@ Follow the steps below to set up a local development environment for contributin
#### Create a virtual environment
To avoid package conflicts, create and activate a clean virtual environment using [uv](https://docs.astral.sh/uv/):
```bash
-# SkyPilot requires python 3.8-3.11.
+# SkyPilot requires Python 3.9-3.11.
# --seed is required to ensure pip is installed (needed for building wheels)
uv venv --seed --python 3.11
source .venv/bin/activate
@@ -100,6 +100,53 @@ py-spy top -- python -m sky.cli status # Get a live top view
py-spy -h # For more options
```
+#### Testing WSL features on a Windows VM (Azure)
+
+To test features that require Windows Subsystem for Linux (WSL), such as the automatic Windows SSH config setup, you can create a Windows VM on Azure:
+
+```bash
+# Create resource group
+az group create --name wsl-test-vm --location eastus2
+
+# Create Windows 11 VM with WSL-compatible settings
+az vm create \
+ --resource-group wsl-test-vm \
+ --name win11-wsl-test \
+ --image MicrosoftWindowsDesktop:windows-11:win11-24h2-pro:latest \
+ --size Standard_D4s_v3 \
+ --admin-username skyuser \
+ --admin-password 'YourPassword123!' \
+ --public-ip-sku Standard
+
+# Enable WSL features on the VM
+az vm run-command invoke \
+ --resource-group wsl-test-vm \
+ --name win11-wsl-test \
+ --command-id RunPowerShellScript \
+ --scripts "
+ dism.exe /online /enable-feature /featurename:Microsoft-Windows-Subsystem-Linux /all /norestart
+ dism.exe /online /enable-feature /featurename:VirtualMachinePlatform /all /norestart
+ "
+
+# Restart VM to apply WSL features
+az vm restart --resource-group wsl-test-vm --name win11-wsl-test
+
+# Get VM public IP for RDP connection
+az vm show --resource-group wsl-test-vm --name win11-wsl-test --show-details --query publicIps -o tsv
+```
+
+Connect via RDP, then in PowerShell (as Admin):
+```powershell
+wsl --install -d Ubuntu-22.04
+```
+
+After restart, set up Ubuntu and install SkyPilot to test WSL-specific features.
+
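+A minimal in-WSL setup sketch (adjust the fork, branch, and extras to what you are testing):
+
+```bash
+sudo apt-get update && sudo apt-get install -y git python3-pip python3-venv
+git clone https://github.com/skypilot-org/skypilot.git
+cd skypilot
+pip install -e ".[aws]"
+```
+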
+**Cleanup:**
+```bash
+az group delete --name wsl-test-vm --yes --no-wait
+```
+
#### Testing in a container
It is often useful to test your changes in a clean environment set up from scratch. Using a container is a good way to do this.
diff --git a/Dockerfile b/Dockerfile
index 73c6645c28a..1eb285efa6d 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -70,6 +70,7 @@ ARG NEXT_BASE_PATH=/dashboard
# Install system packages
RUN apt-get update -y && \
+ apt-get upgrade -y && \
apt-get install --no-install-recommends -y \
git gcc rsync sudo patch openssh-server \
pciutils nano fuse socat netcat-openbsd curl tini autossh jq logrotate && \
diff --git a/README.md b/README.md
index 981703b4494..9db0abca330 100644
--- a/README.md
+++ b/README.md
@@ -50,6 +50,7 @@ SkyPilot gives **AI teams** a simple interface to run jobs on any infra.
:fire: *News* :fire:
- [Dec 2025] **SkyPilot v0.11** released: Multi-Cloud Pools, Fast Managed Jobs, Enterprise-Readiness at Large Scale, Programmability. [**Release notes**](https://github.com/skypilot-org/skypilot/releases/tag/v0.11.0)
- [Dec 2025] **SkyPilot Pools** released: Run batch inference and other jobs on a managed pool of warm workers (across clouds or clusters). [**blog**](https://blog.skypilot.co/skypilot-pools-deepseek-ocr/), [**docs**](https://docs.skypilot.co/en/latest/examples/pools.html)
+- [Dec 2025] Train **an agent to use Google Search** as a tool with RL on your Kubernetes or clouds: [**blog**](https://blog.skypilot.co/verl-tool-calling/), [**example**](./llm/verl/)
- [Nov 2025] Serve **Kimi K2 Thinking** with reasoning capabilities on your Kubernetes or clouds: [**example**](./llm/kimi-k2-thinking/)
- [Oct 2025] Run **RL training for LLMs** with SkyRL on your Kubernetes or clouds: [**example**](./llm/skyrl/)
- [Oct 2025] Train and serve [Andrej Karpathy's](https://x.com/karpathy/status/1977755427569111362) **nanochat** - the best ChatGPT that $100 can buy: [**example**](./llm/nanochat)
diff --git a/charts/external-metrics/.gitignore b/charts/external-metrics/.gitignore
deleted file mode 100644
index 2946e34f050..00000000000
--- a/charts/external-metrics/.gitignore
+++ /dev/null
@@ -1,2 +0,0 @@
-Chart.lock
-charts/
diff --git a/charts/external-metrics/Chart.yaml b/charts/external-metrics/Chart.yaml
deleted file mode 100644
index ed5106c27bd..00000000000
--- a/charts/external-metrics/Chart.yaml
+++ /dev/null
@@ -1,11 +0,0 @@
-apiVersion: v2
-name: skypilot-prometheus-server
-description: A Helm chart for deploying Prometheus Server
-type: application
-version: 0.0.0
-appVersion: "0.0"
-dependencies:
- - name: prometheus
- version: 27.20.0
- repository: https://prometheus-community.github.io/helm-charts
- condition: prometheus.enabled
diff --git a/charts/external-metrics/values.yaml b/charts/external-metrics/values.yaml
deleted file mode 100644
index 4835025d3ff..00000000000
--- a/charts/external-metrics/values.yaml
+++ /dev/null
@@ -1,28 +0,0 @@
-# Set configuration for Prometheus helm chart
-prometheus:
- enabled: true
- # Refer to https://github.com/prometheus-community/helm-charts/blob/main/charts/prometheus/values.yaml for available values.
- # Keep the installation minimal by default. If you want to monitor more resources other than the API server,
- # it is recommended to install and manage prometheus separately.
- # SkyPilot API server will be automatically discovered by the prometheus if it runs with the default kubernetes discovery configuration.
- server:
- persistentVolume:
- enabled: true
- size: 50Gi
- retention: "1000d"
- # The Prometheus documentations recommends setting the retention size to be 80-85% of the persistent volume size.
- # ref: https://prometheus.io/docs/prometheus/latest/storage/#right-sizing-retention-size
- # 43GB is roughly 80% of the 50Gi persistent volume size. We use Gi for the PV size and GB for the retention size
- # because these are the units specified by the Prometheus chart schema for each respective field.
- retentionSize: "43GB"
- kube-state-metrics:
- enabled: true
- # TODO (kyuds): remove skypilot-cluster label in v0.12.0; deprecated in favor of skypilot-cluster-name.
- metricLabelsAllowlist:
- - pods=[skypilot-cluster,skypilot-cluster-name]
- prometheus-node-exporter:
- enabled: false
- prometheus-pushgateway:
- enabled: false
- alertmanager:
- enabled: false
diff --git a/charts/skypilot/templates/NOTES.txt b/charts/skypilot/templates/NOTES.txt
index 4e6b4e8de54..69d5e64178f 100644
--- a/charts/skypilot/templates/NOTES.txt
+++ b/charts/skypilot/templates/NOTES.txt
@@ -3,3 +3,4 @@
{{- end }}
{{- include "skypilot.checkUpgradeConfig" . }}
{{- include "skypilot.validateOAuthConfig" . }}
+{{- include "skypilot.validateExternalProxyConfig" . }}
diff --git a/charts/skypilot/templates/_helpers.tpl b/charts/skypilot/templates/_helpers.tpl
index f4339c7cf71..55a0807fa4d 100644
--- a/charts/skypilot/templates/_helpers.tpl
+++ b/charts/skypilot/templates/_helpers.tpl
@@ -175,3 +175,18 @@ false
{{- fail "Error\nauth.oauth.enabled cannot be used together with ingress OAuth2 proxy authentication (ingress.oauth2-proxy.enabled). These authentication methods are mutually exclusive. Please:\n1. Disable auth.oauth.enabled, OR\n2. Set ingress.oauth2-proxy.enabled to false\nThen try again." -}}
{{- end -}}
{{- end -}}
+
+{{/* Validate the external proxy config */}}
+{{- define "skypilot.validateExternalProxyConfig" -}}
+{{- $externalProxyEnabled := .Values.auth.externalProxy.enabled -}}
+{{- $authOAuthEnabled := .Values.auth.oauth.enabled -}}
+{{- $ingressOAuthEnabled := include "skypilot.ingressOAuthEnabled" . | trim | eq "true" -}}
+
+{{- if and $externalProxyEnabled $authOAuthEnabled -}}
+ {{- fail "Error\nauth.externalProxy.enabled cannot be used together with auth.oauth.enabled. These authentication methods are mutually exclusive. Please:\n1. Disable auth.externalProxy.enabled, OR\n2. Set auth.oauth.enabled to false\nThen try again." -}}
+{{- end -}}
+
+{{- if and $externalProxyEnabled $ingressOAuthEnabled -}}
+ {{- fail "Error\nauth.externalProxy.enabled cannot be used together with ingress.oauth2-proxy.enabled. These authentication methods are mutually exclusive. Please:\n1. Disable auth.externalProxy.enabled, OR\n2. Set ingress.oauth2-proxy.enabled to false\nThen try again." -}}
+{{- end -}}
+{{- end -}}
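+
+{{/* Illustrative values that pass this validation (field names from values.schema.json):
+     auth:
+       externalProxy:
+         enabled: true
+         headerName: "X-Auth-Request-Email"
+         headerFormat: "plaintext"
+         jwtIdentityClaim: "sub"
+       oauth:
+         enabled: false */}}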
diff --git a/charts/skypilot/templates/api-deployment.yaml b/charts/skypilot/templates/api-deployment.yaml
index 8197ca29b6f..32987835615 100644
--- a/charts/skypilot/templates/api-deployment.yaml
+++ b/charts/skypilot/templates/api-deployment.yaml
@@ -11,8 +11,8 @@ spec:
{{- if and (not .Values.apiService.dbConnectionSecretName) (not .Values.apiService.dbConnectionString) }}
{{- fail "External database must be configured via .apiService.dbConnectionSecretName or .apiService.dbConnectionString when using RollingUpdate strategy" }}
{{- end }}
- {{- if .Values.storage.enabled }}
- {{- fail "Local storage is not supported when using RollingUpdate strategy. Use recreate upgrade strategy or set storage.enabled to false." }}
+ {{- if and .Values.storage.enabled (ne .Values.storage.accessMode "ReadWriteMany") }}
+ {{- fail "Local storage with ReadWriteOnce access mode is not supported when using RollingUpdate strategy. Either use Recreate upgrade strategy, set storage.enabled to false, or use ReadWriteMany access mode with a compatible storage class (e.g., NFS-backed storage like Google Filestore)." }}
{{- end }}
strategy:
type: RollingUpdate
@@ -82,7 +82,7 @@ spec:
value: {{ .Values.apiService.skypilotDev | quote }}
- name: SKYPILOT_RELEASE_NAME
value: {{ $fullName | quote }}
- {{- if include "skypilot.enableBasicAuthInAPIServer" . | trim | eq "true" }}
+ {{- if and (eq (include "skypilot.enableBasicAuthInAPIServer" . | trim) "true") (ne (include "skypilot.initialBasicAuthSecretName" . | trim) "") }}
- name: SKYPILOT_INITIAL_BASIC_AUTH
valueFrom:
secretKeyRef:
@@ -136,10 +136,18 @@ spec:
- name: SKYPILOT_ROLLING_UPDATE_ENABLED
value: "true"
{{- end }}
+ {{- if .Values.storage.enabled }}
+ - name: SKYPILOT_API_SERVER_STORAGE_ENABLED
+ value: "true"
+ {{- end }}
{{- if .Values.apiService.metrics.enabled }}
- name: SKY_API_SERVER_METRICS_ENABLED
value: "true"
{{- end }}
+ {{- if .Values.auth.disableBasicAuthMiddleware }}
+ - name: SKYPILOT_DISABLE_BASIC_AUTH_MIDDLEWARE
+ value: "true"
+ {{- end }}
{{- if .Values.auth.oauth.enabled }}
- name: SKYPILOT_AUTH_OAUTH2_PROXY_ENABLED
value: "true"
@@ -256,16 +264,42 @@ spec:
periodSeconds: 5
initialDelaySeconds: 5
volumeMounts:
+ {{- if and .Values.storage.enabled (eq .Values.apiService.upgradeStrategy "RollingUpdate") }}
+ # For RollingUpdate with storage enabled, use emptyDir for ~/.sky to avoid
+ # running SQLite on NFS. Only persist the clients directory for file mounts.
+ # An ephemeral volume is still required since we have to share ~/.sky between
+ # containers.
+ - name: sky-ephemeral
+ mountPath: /root/.sky
+ - name: state-volume
+ mountPath: /root/.sky/api_server/clients
+ subPath: api_server/clients
+ {{- else }}
- name: state-volume
mountPath: /root/.sky
subPath: .sky
+ {{- end }}
{{- if .Values.storage.enabled }}
- name: state-volume
mountPath: /root/.ssh # To preserve the SSH keys for the user when using the API server
subPath: .ssh
+ # Mount only the specific subdirectories needed for managed job logs, not the entire sky_logs folder
+ # This avoids persisting transient cluster logs (sky-*) and api_server logs
+ - name: state-volume
+ mountPath: /root/sky_logs/jobs_controller # Controller logs for `sky jobs logs --controller`
+ subPath: sky_logs/jobs_controller
+ - name: state-volume
+ mountPath: /root/sky_logs/managed_jobs # Task execution logs for managed jobs
+ subPath: sky_logs/managed_jobs
{{- end }}
- name: skypilot-config
mountPath: /var/skypilot/config
+ {{- if or .Values.auth.externalProxy.enabled (eq (include "skypilot.ingressOAuthEnabled" .) "true") }}
+ - name: skypilot-server-config
+ mountPath: /root/.sky/.server.yaml
+ subPath: server.yaml
+ readOnly: true
+ {{- end }}
{{- if .Values.apiService.sshNodePools }}
- name: skypilot-ssh-node-pools
mountPath: /var/skypilot/ssh_node_pool
@@ -374,9 +408,14 @@ spec:
sleep 60;
done
volumeMounts:
+ {{- if and .Values.storage.enabled (eq .Values.apiService.upgradeStrategy "RollingUpdate") }}
+ - name: sky-ephemeral
+ mountPath: /root/.sky
+ {{- else }}
- name: state-volume
mountPath: /root/.sky
subPath: .sky
+ {{- end }}
{{- end }}
{{- with .Values.apiService.sidecarContainers }}
{{- toYaml . | nindent 6 }}
@@ -660,6 +699,13 @@ spec:
- name: state-volume
persistentVolumeClaim:
claimName: {{ $fullName }}-state
+ {{- if eq .Values.apiService.upgradeStrategy "RollingUpdate" }}
+ # When using RollingUpdate with storage enabled, use a separate emptyDir
+ # for ~/.sky to avoid running SQLite on NFS. Only specific subdirectories
+ # like api_server/clients are persisted to the PVC.
+ - name: sky-ephemeral
+ emptyDir: {}
+ {{- end }}
{{- else }}
- name: state-volume
emptyDir: {}
@@ -719,6 +765,11 @@ spec:
- name: skypilot-config
configMap:
name: {{ $fullName }}-config
+ {{- if or .Values.auth.externalProxy.enabled (eq (include "skypilot.ingressOAuthEnabled" .) "true") }}
+ - name: skypilot-server-config
+ configMap:
+ name: {{ $fullName }}-server-config
+ {{- end }}
{{- if .Values.apiService.sshNodePools }}
- name: skypilot-ssh-node-pools
secret:
diff --git a/charts/skypilot/templates/server-config.yaml b/charts/skypilot/templates/server-config.yaml
new file mode 100644
index 00000000000..ab0e981ad96
--- /dev/null
+++ b/charts/skypilot/templates/server-config.yaml
@@ -0,0 +1,26 @@
+{{- $externalProxyEnabled := .Values.auth.externalProxy.enabled -}}
+{{- $ingressOAuthEnabled := include "skypilot.ingressOAuthEnabled" . | trim | eq "true" -}}
+{{- if or $externalProxyEnabled $ingressOAuthEnabled -}}
+{{- $fullName := include "skypilot.fullname" . -}}
+apiVersion: v1
+kind: ConfigMap
+metadata:
+ name: {{ $fullName }}-server-config
+ namespace: {{ .Release.Namespace }}
+data:
+ server.yaml: |-
+ auth:
+ external_proxy:
+ {{- if $externalProxyEnabled }}
+ enabled: true
+ header_name: {{ .Values.auth.externalProxy.headerName | quote }}
+ header_format: {{ .Values.auth.externalProxy.headerFormat | quote }}
+ jwt_identity_claim: {{ .Values.auth.externalProxy.jwtIdentityClaim | quote }}
+ {{- else }}
+ # Enabled for ingress.oauth2-proxy compatibility
+ enabled: true
+ header_name: "X-Auth-Request-Email"
+ header_format: "plaintext"
+ jwt_identity_claim: "sub"
+ {{- end }}
+{{- end }}
diff --git a/charts/skypilot/tests/deployment_test.yaml b/charts/skypilot/tests/deployment_test.yaml
index 4e1ab79a26e..30f8e8c609a 100644
--- a/charts/skypilot/tests/deployment_test.yaml
+++ b/charts/skypilot/tests/deployment_test.yaml
@@ -148,6 +148,25 @@ tests:
valueFrom:
fieldRef:
fieldPath: metadata.uid
+ # Verify sky-ephemeral is NOT added when storage.enabled=false (backward compatibility)
+ - notContains:
+ path: spec.template.spec.volumes
+ content:
+ name: sky-ephemeral
+ emptyDir: {}
+ # Verify state-volume uses emptyDir (not PVC)
+ - contains:
+ path: spec.template.spec.volumes
+ content:
+ name: state-volume
+ emptyDir: {}
+ # Verify ~/.sky is mounted from state-volume with subPath
+ - contains:
+ path: spec.template.spec.containers[0].volumeMounts
+ content:
+ name: state-volume
+ mountPath: /root/.sky
+ subPath: .sky
- it: should use RollingUpdate strategy when configured with external database via connection string
set:
@@ -179,14 +198,162 @@ tests:
- failedTemplate:
errorMessage: "External database must be configured via .apiService.dbConnectionSecretName or .apiService.dbConnectionString when using RollingUpdate strategy"
- - it: should fail RollingUpdate strategy with local storage enabled
+ - it: should fail RollingUpdate strategy with local storage enabled using ReadWriteOnce
set:
apiService.upgradeStrategy: RollingUpdate
apiService.dbConnectionSecretName: test-db-secret
storage.enabled: true
asserts:
- failedTemplate:
- errorMessage: "Local storage is not supported when using RollingUpdate strategy. Use recreate upgrade strategy or set storage.enabled to false."
+ errorMessage: "Local storage with ReadWriteOnce access mode is not supported when using RollingUpdate strategy. Either use Recreate upgrade strategy, set storage.enabled to false, or use ReadWriteMany access mode with a compatible storage class (e.g., NFS-backed storage like Google Filestore)."
+
+ - it: should allow RollingUpdate strategy with local storage enabled using ReadWriteMany
+ set:
+ apiService.upgradeStrategy: RollingUpdate
+ apiService.dbConnectionSecretName: test-db-secret
+ storage.enabled: true
+ storage.accessMode: ReadWriteMany
+ asserts:
+ - equal:
+ path: spec.strategy.type
+ value: RollingUpdate
+ - equal:
+ path: spec.strategy.rollingUpdate.maxSurge
+ value: 1
+ - equal:
+ path: spec.strategy.rollingUpdate.maxUnavailable
+ value: 0
+
+ - it: should use emptyDir for ~/.sky when RollingUpdate with storage enabled to avoid SQLite on NFS
+ set:
+ apiService.upgradeStrategy: RollingUpdate
+ apiService.dbConnectionSecretName: test-db-secret
+ storage.enabled: true
+ storage.accessMode: ReadWriteMany
+ asserts:
+ # Should have sky-ephemeral emptyDir volume
+ - contains:
+ path: spec.template.spec.volumes
+ content:
+ name: sky-ephemeral
+ emptyDir: {}
+ # Should mount sky-ephemeral at ~/.sky
+ - contains:
+ path: spec.template.spec.containers[0].volumeMounts
+ content:
+ name: sky-ephemeral
+ mountPath: /root/.sky
+ # Should NOT mount state-volume at ~/.sky with subPath .sky
+ - notContains:
+ path: spec.template.spec.containers[0].volumeMounts
+ content:
+ name: state-volume
+ mountPath: /root/.sky
+ subPath: .sky
+
+ - it: should persist api_server/clients directory when RollingUpdate with storage enabled
+ set:
+ apiService.upgradeStrategy: RollingUpdate
+ apiService.dbConnectionSecretName: test-db-secret
+ storage.enabled: true
+ storage.accessMode: ReadWriteMany
+ asserts:
+ # Should mount api_server/clients from PVC
+ - contains:
+ path: spec.template.spec.containers[0].volumeMounts
+ content:
+ name: state-volume
+ mountPath: /root/.sky/api_server/clients
+ subPath: api_server/clients
+
+ - it: should use state-volume for ~/.sky when Recreate with storage enabled
+ set:
+ apiService.upgradeStrategy: Recreate
+ storage.enabled: true
+ asserts:
+ # Should NOT have sky-ephemeral volume
+ - notContains:
+ path: spec.template.spec.volumes
+ content:
+ name: sky-ephemeral
+ emptyDir: {}
+ # Should mount state-volume at ~/.sky with subPath
+ - contains:
+ path: spec.template.spec.containers[0].volumeMounts
+ content:
+ name: state-volume
+ mountPath: /root/.sky
+ subPath: .sky
+ # Should NOT have separate clients mount
+ - notContains:
+ path: spec.template.spec.containers[0].volumeMounts
+ content:
+ name: state-volume
+ mountPath: /root/.sky/api_server/clients
+ subPath: api_server/clients
+
+ - it: should use sky-ephemeral for logrotate sidecar when RollingUpdate with storage enabled
+ set:
+ apiService.upgradeStrategy: RollingUpdate
+ apiService.dbConnectionSecretName: test-db-secret
+ storage.enabled: true
+ storage.accessMode: ReadWriteMany
+ apiService.logs.retention.enabled: true
+ asserts:
+ # Logrotate sidecar should mount sky-ephemeral at ~/.sky
+ - contains:
+ path: spec.template.spec.containers[1].volumeMounts
+ content:
+ name: sky-ephemeral
+ mountPath: /root/.sky
+
+ - it: should use state-volume for logrotate sidecar when Recreate with storage enabled
+ set:
+ apiService.upgradeStrategy: Recreate
+ storage.enabled: true
+ apiService.logs.retention.enabled: true
+ asserts:
+ # Logrotate sidecar should mount state-volume at ~/.sky with subPath
+ - contains:
+ path: spec.template.spec.containers[1].volumeMounts
+ content:
+ name: state-volume
+ mountPath: /root/.sky
+ subPath: .sky
+
+ - it: should mount managed job log directories when storage is enabled
+ set:
+ storage.enabled: true
+ asserts:
+ - contains:
+ path: spec.template.spec.containers[0].volumeMounts
+ content:
+ name: state-volume
+ mountPath: /root/sky_logs/jobs_controller
+ subPath: sky_logs/jobs_controller
+ - contains:
+ path: spec.template.spec.containers[0].volumeMounts
+ content:
+ name: state-volume
+ mountPath: /root/sky_logs/managed_jobs
+ subPath: sky_logs/managed_jobs
+
+ - it: should not mount managed job log directories when storage is disabled
+ set:
+ storage.enabled: false
+ asserts:
+ - notContains:
+ path: spec.template.spec.containers[0].volumeMounts
+ content:
+ name: state-volume
+ mountPath: /root/sky_logs/jobs_controller
+ subPath: sky_logs/jobs_controller
+ - notContains:
+ path: spec.template.spec.containers[0].volumeMounts
+ content:
+ name: state-volume
+ mountPath: /root/sky_logs/managed_jobs
+ subPath: sky_logs/managed_jobs
- it: should honor fullnameOverride for deployment names and labels
set:
@@ -606,3 +773,90 @@ tests:
content:
name: setup-coreweave-credentials
image: registry.example.com/custom/berkeleyskypilot/skypilot-nightly:latest
+
+ # Test cases for SKYPILOT_INITIAL_BASIC_AUTH environment variable
+ - it: should set SKYPILOT_INITIAL_BASIC_AUTH when basic auth enabled with initialBasicAuthSecret
+ set:
+ apiService.enableUserManagement: true
+ apiService.initialBasicAuthSecret: my-auth-secret
+ asserts:
+ - contains:
+ path: spec.template.spec.containers[0].env
+ content:
+ name: SKYPILOT_INITIAL_BASIC_AUTH
+ valueFrom:
+ secretKeyRef:
+ name: my-auth-secret
+ key: auth
+
+ - it: should set SKYPILOT_INITIAL_BASIC_AUTH when basic auth enabled with initialBasicAuthCredentials
+ set:
+ apiService.enableUserManagement: true
+ apiService.initialBasicAuthCredentials: "user:password"
+ asserts:
+ - contains:
+ path: spec.template.spec.containers[0].env
+ content:
+ name: SKYPILOT_INITIAL_BASIC_AUTH
+ valueFrom:
+ secretKeyRef:
+ name: RELEASE-NAME-initial-basic-auth
+ key: auth
+
+ - it: should not set SKYPILOT_INITIAL_BASIC_AUTH when basic auth enabled but no initial credentials
+ set:
+ apiService.enableUserManagement: true
+ # Neither initialBasicAuthSecret nor initialBasicAuthCredentials set
+ asserts:
+ - notContains:
+ path: spec.template.spec.containers[0].env
+ content:
+ name: SKYPILOT_INITIAL_BASIC_AUTH
+
+ - it: should not set SKYPILOT_INITIAL_BASIC_AUTH when basic auth disabled
+ set:
+ apiService.enableUserManagement: false
+ apiService.initialBasicAuthSecret: my-auth-secret
+ asserts:
+ - notContains:
+ path: spec.template.spec.containers[0].env
+ content:
+ name: SKYPILOT_INITIAL_BASIC_AUTH
+
+ - it: should not set SKYPILOT_INITIAL_BASIC_AUTH when oauth2-proxy is enabled
+ set:
+ ingress.oauth2-proxy.enabled: true
+ apiService.enableUserManagement: true
+ apiService.initialBasicAuthSecret: my-auth-secret
+ asserts:
+ - notContains:
+ path: spec.template.spec.containers[0].env
+ content:
+ name: SKYPILOT_INITIAL_BASIC_AUTH
+
+ # Test cases for SKYPILOT_DISABLE_BASIC_AUTH_MIDDLEWARE environment variable
+ - it: should set SKYPILOT_DISABLE_BASIC_AUTH_MIDDLEWARE when auth.disableBasicAuthMiddleware is true
+ set:
+ auth.disableBasicAuthMiddleware: true
+ asserts:
+ - contains:
+ path: spec.template.spec.containers[0].env
+ content:
+ name: SKYPILOT_DISABLE_BASIC_AUTH_MIDDLEWARE
+ value: "true"
+
+ - it: should not set SKYPILOT_DISABLE_BASIC_AUTH_MIDDLEWARE when auth.disableBasicAuthMiddleware is false
+ set:
+ auth.disableBasicAuthMiddleware: false
+ asserts:
+ - notContains:
+ path: spec.template.spec.containers[0].env
+ content:
+ name: SKYPILOT_DISABLE_BASIC_AUTH_MIDDLEWARE
+
+ - it: should not set SKYPILOT_DISABLE_BASIC_AUTH_MIDDLEWARE by default
+ asserts:
+ - notContains:
+ path: spec.template.spec.containers[0].env
+ content:
+ name: SKYPILOT_DISABLE_BASIC_AUTH_MIDDLEWARE
diff --git a/charts/skypilot/values.schema.json b/charts/skypilot/values.schema.json
index 6ae7a2685f4..2d7e9fbaffb 100644
--- a/charts/skypilot/values.schema.json
+++ b/charts/skypilot/values.schema.json
@@ -172,6 +172,26 @@
"null"
],
"properties": {
+ "externalProxy": {
+ "type": [
+ "object",
+ "null"
+ ],
+ "properties": {
+ "enabled": {
+ "type": "boolean"
+ },
+ "headerFormat": {
+ "type": "string"
+ },
+ "headerName": {
+ "type": "string"
+ },
+ "jwtIdentityClaim": {
+ "type": "string"
+ }
+ }
+ },
"oauth": {
"type": [
"object",
diff --git a/charts/skypilot/values.yaml b/charts/skypilot/values.yaml
index 2cf87dbf1f4..9d4a4a2b980 100644
--- a/charts/skypilot/values.yaml
+++ b/charts/skypilot/values.yaml
@@ -36,6 +36,9 @@ apiService:
# - Recreate: delete the old pod first and create a new one (has downtime).
# - RollingUpdate: [EXPERIMENTAL] create a new pod first, wait for it to be ready, then delete the old one (zero downtime).
# Default to Recreate. When set to RollingUpdate, an external database must be configured via .apiService.dbConnectionSecretName or .apiService.dbConnectionString.
+ # For persistent storage with RollingUpdate, use storage.accessMode=ReadWriteMany with an RWX-capable storage class.
+ # If storage.enabled=false with RollingUpdate, file mounts and logs will be lost on pod restart; consider configuring
+ # 'jobs.bucket' in the SkyPilot config to persist file mounts to cloud storage.
upgradeStrategy: Recreate
# Deprecated: use other auth methods instead.
# Refer to https://docs.skypilot.co/en/latest/reference/auth.html for more details.
@@ -257,14 +260,53 @@ auth:
# @schema type: [boolean, null]
enabled: null
+ # Proxy authentication configuration.
+ # Use this when deploying behind an external authentication proxy
+ # (e.g., AWS ALB with Cognito, Azure Front Door, custom ingress auth).
+ # When enabled, the API server trusts the identity header from the proxy.
+ # This is mutually exclusive with auth.oauth and ingress.oauth2-proxy.
+ # @schema type: [object, null]
+ externalProxy:
+ # Enable proxy authentication.
+ # @schema type: [boolean]
+ enabled: false
+ # Header name containing the user identity.
+ # @schema type: [string]
+ headerName: 'X-Auth-Request-Email'
+ # Header format: 'plaintext' or 'jwt'.
+ # Use 'jwt' for headers that contain JWT tokens.
+ # Use 'plaintext' for headers that contain plain identity strings.
+ # @schema type: [string]
+ headerFormat: 'plaintext'
+ # JWT claim to extract identity from (only used when headerFormat is 'jwt').
+ # @schema type: [string]
+ jwtIdentityClaim: 'sub'
+
storage:
# Enable/disable persistent storage
# With this enabled, SkyPilot will use a PV to persist the internal data like states, logs, lock files, catalog, etc.
+ # Persisted data includes:
+  #   - Managed job logs (accessible via `sky jobs logs <job_id>` and `sky jobs logs --controller <job_id>`)
+ # - File mounts uploaded during managed job submission
+ # - API server state and configuration
+ # Note: Transient cluster logs (sky-*) and api_server logs are NOT persisted to minimize storage usage.
# Refer to https://docs.skypilot.co/en/latest/reference/architecture/state.html for more details.
+ #
+ # IMPORTANT: When using RollingUpdate upgrade strategy:
+ # - ReadWriteOnce (RWO): NOT supported - the PVC cannot be mounted by both old and new pods during rolling update.
+ # - ReadWriteMany (RWX): Supported - requires an RWX-capable storage class (e.g., NFS-backed storage like Google Filestore,
+ # AWS EFS, Azure Files, or an NFS provisioner). Both pods can mount the same PVC during the rolling update.
+ # - storage.enabled=false: Supported - but file mounts and logs will be lost on pod restart. Consider configuring
+ # 'jobs.bucket' in the SkyPilot config to use cloud storage for file mounts.
enabled: true
# Storage class name - leave empty to use cluster default
+ # For RWX storage with RollingUpdate, use a storage class that supports ReadWriteMany access mode:
+ # - GKE: Create a Filestore-backed storage class (https://cloud.google.com/filestore/docs/accessing-fileshares)
+ # - EKS: Use EFS CSI driver (https://docs.aws.amazon.com/eks/latest/userguide/efs-csi.html)
+ # - AKS: Use Azure Files (https://docs.microsoft.com/azure/aks/azure-files-dynamic-pv)
storageClassName: ""
# Access modes - ReadWriteOnce or ReadWriteMany depending on what is supported by the storage class
+ # When using RollingUpdate upgrade strategy, ReadWriteMany is required for persistent storage.
accessMode: ReadWriteOnce
# Storage size
size: 10Gi
diff --git a/docs/source/_static/custom.js b/docs/source/_static/custom.js
index 89d53b38256..9f7c75d3d64 100644
--- a/docs/source/_static/custom.js
+++ b/docs/source/_static/custom.js
@@ -35,16 +35,14 @@ document.addEventListener('DOMContentLoaded', () => {
// New items:
const newItems = [
{ selector: '.toctree-l2 > a', text: 'HTTPS Encryption' },
- { selector: '.toctree-l1 > a', text: 'Examples' },
- { selector: '.toctree-l1 > a', text: 'Workspaces: Isolating Teams' },
{ selector: '.toctree-l1 > a', text: 'External Logging Storage' },
- { selector: '.toctree-l1 > a', text: 'Authentication and RBAC' },
{ selector: '.toctree-l1 > a', text: 'Volumes' },
{ selector: '.toctree-l2 > a', text: 'Upgrading API Server' },
{ selector: '.toctree-l1 > a', text: 'High Availability Controller' },
{ selector: '.toctree-l2 > a', text: 'High Availability Controller' },
{ selector: '.toctree-l3 > a', text: 'Advanced: High Availability Controller' },
{ selector: '.toctree-l1 > a', text: 'Using a Pool of Workers' },
+ { selector: '.toctree-l1 > a', text: 'Job Groups' },
{ selector: '.toctree-l1 > a', text: 'Using Slurm' },
];
newItems.forEach(({ selector, text }) => {
diff --git a/docs/source/cloud-setup/cloud-permissions/aws.rst b/docs/source/cloud-setup/cloud-permissions/aws.rst
index 9f268b41e8f..3b5d3071774 100644
--- a/docs/source/cloud-setup/cloud-permissions/aws.rst
+++ b/docs/source/cloud-setup/cloud-permissions/aws.rst
@@ -463,11 +463,44 @@ These are the minimal policy rules required by SkyPilot:
{
"Effect": "Allow",
"Action": [
- "s3:*"
+ "s3:GetObject",
+ "s3:PutObject",
+ "s3:DeleteObject"
],
+ "Resource": "arn:aws:s3:::*/*"
+ },
+ {
+ "Effect": "Allow",
+ "Action": [
+ "s3:ListBucket",
+ "s3:GetBucketLocation"
+ ],
+ "Resource": "arn:aws:s3:::*"
+ },
+ {
+ "Effect": "Allow",
+ "Action": "s3:ListAllMyBuckets",
"Resource": "*"
}
+**Optional**: If you also want to allow SkyPilot to create and delete S3 buckets (for ``sky storage`` commands), add these additional permissions:
+
+.. code-block:: json
+
+ {
+ "Effect": "Allow",
+ "Action": [
+ "s3:CreateBucket",
+ "s3:DeleteBucket",
+ "s3:PutBucketTagging"
+ ],
+ "Resource": "arn:aws:s3:::*"
+ }
+
+.. tip::
+
+ If you are using EKS and want to set up S3 access with IAM roles, see :ref:`aws-eks-iam-roles`.
+
**Once you have added all needed policies, click Next** and follow the instructions to finish creating the policy. You can give the policy a descriptive name, such as ``minimal-skypilot-policy``.
diff --git a/docs/source/cloud-setup/cloud-permissions/kubernetes.rst b/docs/source/cloud-setup/cloud-permissions/kubernetes.rst
index c58bd28a5d2..31fa303d830 100644
--- a/docs/source/cloud-setup/cloud-permissions/kubernetes.rst
+++ b/docs/source/cloud-setup/cloud-permissions/kubernetes.rst
@@ -266,7 +266,6 @@ To create a service account that has all necessary permissions for SkyPilot (inc
apiVersion: rbac.authorization.k8s.io/v1
metadata:
name: sky-sa-cluster-role # Can be changed if needed
- namespace: default # Change to your namespace if using a different one.
labels:
parent: skypilot
rules:
@@ -291,7 +290,6 @@ To create a service account that has all necessary permissions for SkyPilot (inc
kind: ClusterRoleBinding
metadata:
name: sky-sa-cluster-role-binding # Can be changed if needed
- namespace: default # Change to your namespace if using a different one.
labels:
parent: skypilot
subjects:
@@ -300,7 +298,7 @@ To create a service account that has all necessary permissions for SkyPilot (inc
namespace: default # Change to your namespace if using a different one.
roleRef:
kind: ClusterRole
- name: sky-sa-cluster-role # Use the same name as the cluster role at line 43
+ name: sky-sa-cluster-role # Use the same name as the cluster role at line 56
apiGroup: rbac.authorization.k8s.io
---
# Optional: If using object store mounting, create the skypilot-system namespace
diff --git a/docs/source/docs/index.rst b/docs/source/docs/index.rst
index 63cac22835a..f0fe1a8d7d4 100644
--- a/docs/source/docs/index.rst
+++ b/docs/source/docs/index.rst
@@ -306,6 +306,7 @@ Read the research:
Many Parallel Jobs <../running-jobs/many-jobs>
Model Training Guide <../reference/training-guide>
Using a Pool of Workers <../examples/pools>
+ Job Groups <../examples/job-groups>
.. toctree::
:hidden:
diff --git a/docs/source/examples/interactive-development.rst b/docs/source/examples/interactive-development.rst
index 316d601f71f..1015c553816 100644
--- a/docs/source/examples/interactive-development.rst
+++ b/docs/source/examples/interactive-development.rst
@@ -115,6 +115,13 @@ For more details, please refer to the `VSCode documentation `
diff --git a/docs/source/examples/job-groups.rst b/docs/source/examples/job-groups.rst
new file mode 100644
--- /dev/null
+++ b/docs/source/examples/job-groups.rst
+.. _job-groups:
+
+Job Groups
+==========
+
+Unlike :ref:`managed jobs <managed-jobs>` which run tasks sequentially (pipelines),
+Job Groups launch all tasks simultaneously, enabling complex distributed architectures.
+
+.. figure:: ../images/job-groups-dashboard.png
+ :width: 100%
+ :align: center
+ :alt: Job Groups in SkyPilot Dashboard
+
+ A Job Group with 4 tasks (data-server, rollout-server, reward-server, ppo-trainer)
+ running in parallel on Kubernetes. Each task has different resource requirements
+ and can be monitored independently through the dashboard.
+
+Overview
+--------
+
+**Key Features:**
+
+- **Parallel execution**: Launch multiple tasks simultaneously, each running independently
+- **Heterogeneous resources**: Different resource requirements per task (e.g., GPUs for training, CPUs for data serving)
+- **Automatic service discovery**: Tasks discover each other and communicate via hostnames
+- **Independent recovery**: Each task recovers from preemptions without affecting other tasks
+
+**When to Use Job Groups:**
+
+Job Groups are ideal for workloads where multiple components with different requirements need to run together and communicate. Common use cases include:
+
+- **RL post-training**: Separate tasks for trainer, reward modeling, rollout server, and data serving
+- **Parallel train-eval**: Training and evaluation running in parallel with shared storage
+
+.. tip::
+
+ Use Job Groups when your workload has **heterogeneous tasks** that need to run
+ **in parallel** and **communicate with each other**. For homogeneous multi-node
+   training within a single task, use :ref:`distributed jobs <dist-jobs>` instead.
+ For sequential task execution, use :ref:`managed job pipelines `.
+
+.. contents:: Contents
+ :local:
+ :backlinks: none
+
+
+Creating a job group
+--------------------
+
+A Job Group is defined using a multi-document YAML file. The first document is the
+**header** that defines the group's properties, followed by individual task definitions:
+
+.. code-block:: yaml
+
+ # job-group.yaml
+ ---
+ # Header: Job Group configuration
+ name: my-job-group
+ execution: parallel # Required: indicates this is a Job Group
+ ---
+ # Task 1: Trainer
+ name: trainer
+ resources:
+ accelerators: A100:1
+ run: |
+ python train.py
+ ---
+ # Task 2: Evaluator
+ name: evaluator
+ resources:
+ accelerators: A100:1
+ run: |
+ python evaluate.py
+
+Launch the Job Group with:
+
+.. code-block:: console
+
+ $ sky jobs launch job-group.yaml
+
+Header fields
+~~~~~~~~~~~~~
+
+The header document supports the following fields:
+
+.. list-table::
+ :widths: 20 20 60
+ :header-rows: 1
+
+ * - Field
+ - Default
+ - Description
+ * - ``name``
+ - Required
+ - Name of the Job Group
+ * - ``execution``
+ - Required
+ - Must be ``parallel`` to indicate this is a Job Group
+ * - ``primary_tasks``
+ - None
+ - List of task names that are "primary". Tasks not in this list are
+ "auxiliary" - long-running services (e.g., data servers, replay buffers)
+ that wait for a signal to terminate. When all primary tasks complete,
+ auxiliary tasks are terminated. If not set, all tasks are primary.
+ * - ``termination_delay``
+ - None
+ - Delay before terminating auxiliary tasks when primary tasks complete,
+ allowing them to finish pending work (e.g., flushing data). Can be a
+ string (e.g., ``"30s"``, ``"5m"``) or a dict with per-task delays
+ (e.g., ``{"default": "30s", "replay-buffer": "1m"}``).
+
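+For instance, a header combining ``primary_tasks`` with the dict form of
+``termination_delay`` (a sketch based on the field descriptions above; the task
+names are illustrative) could look like:
+
+.. code-block:: yaml
+
+    ---
+    name: my-job-group
+    execution: parallel
+    primary_tasks: [trainer]
+    termination_delay:
+      default: 30s
+      replay-buffer: 1m
+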
+Each task document after the header follows the standard :ref:`SkyPilot task YAML format <yaml-spec>`.
+
+.. note::
+
+ Every task in a Job Group **must have a unique name**. The name is used for
+ service discovery and log viewing.
+
+
+Service discovery
+-----------------
+
+Tasks in a Job Group can discover each other using hostnames. SkyPilot automatically
+configures networking so that tasks can communicate.
+
+Hostname format
+~~~~~~~~~~~~~~~
+
+Each task's head node is accessible via the hostname:
+
+.. code-block:: text
+
+ {task_name}-0.{job_group_name}
+
+For multi-node tasks, worker nodes use:
+
+.. code-block:: text
+
+ {task_name}-{node_index}.{job_group_name}
+
+For example, in a Job Group named ``rlhf-experiment`` with a 2-node ``trainer`` task:
+
+- ``trainer-0.rlhf-experiment`` - Head node (rank 0)
+- ``trainer-1.rlhf-experiment`` - Worker node (rank 1)
+
+Environment variables
+~~~~~~~~~~~~~~~~~~~~~
+
+SkyPilot injects the following environment variables into all tasks:
+
+.. list-table::
+ :widths: 40 60
+ :header-rows: 1
+
+ * - Variable
+ - Description
+ * - ``SKYPILOT_JOBGROUP_NAME``
+ - Name of the Job Group
+
+Example usage in a task:
+
+.. code-block:: bash
+
+ # Access the trainer task from the evaluator using the hostname
+ curl http://trainer-0.${SKYPILOT_JOBGROUP_NAME}:8000/status
+
+
+Viewing logs
+------------
+
+View logs for a specific task within a Job Group:
+
+.. code-block:: console
+
+ # View logs for a specific task by name
+ $ sky jobs logs trainer
+
+ # View logs for a specific task by task ID
+ $ sky jobs logs 0
+
+ # View all task logs (default)
+ $ sky jobs logs
+
+When viewing logs for a multi-task job, SkyPilot displays a hint:
+
+.. code-block:: console
+
+ Hint: This job has 3 tasks. Use 'sky jobs logs 42 TASK' to view logs
+ for a specific task (TASK can be task ID or name).
+
+
+Examples
+--------
+
+Parallel train-eval with shared storage
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+This example runs training and evaluation in parallel, sharing checkpoints via
+a Kubernetes PVC volume:
+
+.. figure:: ../images/job-groups-train-eval-architecture.png
+ :width: 80%
+ :align: center
+ :alt: Parallel Train-Eval Architecture with Job Groups
+
+ Parallel training and evaluation with shared storage. The trainer saves checkpoints
+ to a shared volume while the evaluator monitors and evaluates new checkpoints on-the-fly.
+
+.. code-block:: yaml
+
+ ---
+ name: train-eval
+ execution: parallel
+ ---
+ name: trainer
+ resources:
+ accelerators: A100:1
+ volumes:
+ /checkpoints: my-checkpoint-volume
+ run: |
+ python train.py --checkpoint-dir /checkpoints
+ ---
+ name: evaluator
+ resources:
+ accelerators: A100:1
+ volumes:
+ /checkpoints: my-checkpoint-volume
+ run: |
+ python evaluate.py --checkpoint-dir /checkpoints
+
+See the full example at ``llm/train-eval-jobgroup/`` in the SkyPilot repository.
+
+RL post-training architecture
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+This example demonstrates a distributed RL post-training architecture with 5 tasks:
+
+.. code-block:: yaml
+
+ ---
+ name: rlhf-training
+ execution: parallel
+ ---
+ name: data-server
+ resources:
+ cpus: 4+
+ run: |
+ python data_server.py
+ ---
+ name: rollout-server
+ num_nodes: 2
+ resources:
+ accelerators: A100:1
+ run: |
+ python rollout_server.py
+ ---
+ name: reward-server
+ resources:
+ cpus: 8+
+ run: |
+ python reward_server.py
+ ---
+ name: replay-buffer
+ resources:
+ cpus: 4+
+ memory: 32+
+ run: |
+ python replay_buffer.py
+ ---
+ name: ppo-trainer
+ num_nodes: 2
+ resources:
+ accelerators: A100:1
+ run: |
+ python ppo_trainer.py \
+ --data-server data-server-0.${SKYPILOT_JOBGROUP_NAME}:8000 \
+ --rollout-server rollout-server-0.${SKYPILOT_JOBGROUP_NAME}:8001 \
+ --reward-server reward-server-0.${SKYPILOT_JOBGROUP_NAME}:8002
+
+See the full RL post-training example at ``llm/rl-post-training-jobgroup/`` in the SkyPilot repository.
+
+Primary and auxiliary tasks
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In many distributed workloads, you have a main task (e.g., trainer) and supporting
+services (e.g., data servers, replay buffers) that run indefinitely until the main
+task signals completion. These supporting services are "auxiliary tasks" - they
+don't have a natural termination point and need to be told when to shut down.
+
+Use ``primary_tasks`` to designate which tasks drive the job's lifecycle. Auxiliary
+tasks (those not listed) will be automatically terminated when all primary tasks
+complete:
+
+.. code-block:: yaml
+
+ ---
+ name: train-with-services
+ execution: parallel
+ primary_tasks: [trainer] # Only trainer is primary
+ termination_delay: 30s # Give services 30s to finish after trainer completes
+ ---
+ name: trainer
+ resources:
+ accelerators: A100:1
+ run: |
+ python train.py # Primary task: job completes when this finishes
+ ---
+ name: data-server
+ resources:
+ cpus: 4+
+ run: |
+ python data_server.py # Auxiliary: terminated 30s after trainer completes
+
+When the trainer task finishes, the data-server (auxiliary) task will receive a
+termination signal after the 30-second delay, allowing it to flush pending data
+or perform cleanup.
+
+
+Current limitations
+-------------------
+
+- **Co-location**: All tasks in a Job Group run on the same infrastructure
+ (same Kubernetes cluster or cloud zone).
+
+- **Networking**: Service discovery (hostname-based communication between tasks)
+ currently only works on Kubernetes. On other clouds, tasks can run in parallel
+ but cannot communicate with each other using the hostname format.
+
+.. note::
+
+ Job Groups require ``execution: parallel`` in the header. For sequential task
+ execution, use :ref:`managed job pipelines ` instead (omit the
+ ``execution`` field or set it to ``serial``).
+
+
+.. seealso::
+
+ :ref:`managed-jobs` for single tasks or sequential pipelines.
+
+ :ref:`dist-jobs` for multi-node distributed training within a single task.
diff --git a/docs/source/examples/managed-jobs.rst b/docs/source/examples/managed-jobs.rst
index 61e3d71267e..82e70db2b85 100644
--- a/docs/source/examples/managed-jobs.rst
+++ b/docs/source/examples/managed-jobs.rst
@@ -7,6 +7,13 @@ Managed Jobs
This feature is great for scaling out: running a single job for long durations, or running many jobs in parallel.
+.. seealso::
+
+   :doc:`pools` for running batch inference workloads across multiple infrastructures.
+
+ :ref:`job-groups` for running multiple heterogeneous tasks in parallel that
+ can communicate with each other.
+
SkyPilot supports **managed jobs** (:code:`sky jobs`), which can automatically retry failures, recover from spot instance preemptions, and clean up when done.
To start a managed job, use :code:`sky jobs launch`:
@@ -403,8 +410,12 @@ A pipeline is a managed job that contains a sequence of tasks running one after
This is useful for running a sequence of tasks that depend on each other, e.g., training a model and then running inference on it.
Different tasks can have different resource requirements to use appropriate per-task resources, which saves costs, while keeping the burden of managing the tasks off the user.
+.. seealso::
+
+ :ref:`job-groups` for running multiple tasks **in parallel** instead of sequentially.
+
.. note::
- In other words, a managed job is either a single task or a pipeline of tasks. All managed jobs are submitted by :code:`sky jobs launch`.
+    In other words, a managed job is either a single task, a pipeline (sequential tasks), or a :ref:`job group <job-groups>` (parallel tasks). All managed jobs are submitted by :code:`sky jobs launch`.
To run a pipeline, specify the sequence of tasks in a YAML file. Here is an example:
@@ -461,6 +472,13 @@ second task has name :code:`eval`. The tasks are separated by a line with three
dashes :code:`---`. Each task has its own :code:`resources`, :code:`setup`, and
:code:`run` sections. Tasks are executed sequentially. If a task fails, later tasks are skipped.
+.. tip::
+
+ To explicitly indicate a pipeline (sequential execution), you can add
+ :code:`execution: serial` to the header. This is optional since pipelines
+ are the default when :code:`execution` is omitted. Use :code:`execution: parallel`
+   for :ref:`job groups <job-groups>` instead.
+
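+For example, the optional header document (a sketch; the pipeline name is
+illustrative) would look like:
+
+.. code-block:: yaml
+
+    ---
+    name: my-pipeline
+    execution: serial  # Optional: sequential execution is the default.
+    ---
+    # ...task definitions follow, each in its own YAML document...
+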
To pass data between the tasks, use a shared file mount. In this example, the :code:`train` task writes its output to the :code:`/checkpoint` file mount, which the :code:`eval` task is then able to read from.
To submit the pipeline, the same command :code:`sky jobs launch` is used. The pipeline will be automatically launched and monitored by SkyPilot. You can check the status of the pipeline with :code:`sky jobs queue` or :code:`sky dashboard`.
diff --git a/docs/source/examples/models/index.rst b/docs/source/examples/models/index.rst
index b318907cfe9..ab837535265 100644
--- a/docs/source/examples/models/index.rst
+++ b/docs/source/examples/models/index.rst
@@ -17,7 +17,7 @@ Models
CodeLlama
Pixtral
Mixtral
- Mistral 7B
+ Mistral 7B
Qwen 3
Kimi K2
Kimi K2 Thinking
diff --git a/docs/source/examples/performance/index.rst b/docs/source/examples/performance/index.rst
index 81cc46c9ee6..baa805389e4 100644
--- a/docs/source/examples/performance/index.rst
+++ b/docs/source/examples/performance/index.rst
@@ -8,3 +8,4 @@ AI Performance
GCP/GKE GPUDirect
Coreweave with InfiniBand
Nebius with InfiniBand
+ Together AI with InfiniBand
diff --git a/docs/source/examples/performance/together_infiniband.md b/docs/source/examples/performance/together_infiniband.md
new file mode 120000
index 00000000000..606573580ae
--- /dev/null
+++ b/docs/source/examples/performance/together_infiniband.md
@@ -0,0 +1 @@
+../../generated-examples/together_infiniband.md
\ No newline at end of file
diff --git a/docs/source/getting-started/installation.rst b/docs/source/getting-started/installation.rst
index 9a54228311b..1ca456d06b2 100644
--- a/docs/source/getting-started/installation.rst
+++ b/docs/source/getting-started/installation.rst
@@ -21,10 +21,9 @@ SkyPilot supports installation with ``uv`` or ``pip``.
.. code-block:: shell
# Create a virtual environment with pip pre-installed (required for SkyPilot)
- # SkyPilot requires 3.7 <= python <= 3.13.
+ # SkyPilot requires 3.9 <= python <= 3.13.
uv venv --seed --python 3.10
source .venv/bin/activate # Use WSL on Windows
-
uv pip install skypilot
# install dependencies for the clouds you want to use
@@ -34,14 +33,14 @@ SkyPilot supports installation with ``uv`` or ``pip``.
The ``--seed`` flag is **required** as it ensures ``pip`` is installed in the virtual environment.
SkyPilot needs ``pip`` to build wheels for remote cluster setup.
-
+
.. tab-item:: uv tool
:sync: uv-tool-tab
.. code-block:: shell
# Install as a globally available tool with pip included
- # SkyPilot requires 3.7 <= python <= 3.13.
+ # SkyPilot requires 3.9 <= python <= 3.13.
uv tool install --with pip skypilot
# install dependencies for the clouds you want to use
@@ -67,6 +66,7 @@ SkyPilot supports installation with ``uv`` or ``pip``.
# install dependencies for the clouds you want to use
pip install "skypilot[kubernetes,aws,gcp]"
+
.. dropdown:: Install SkyPilot from nightly build or source
SkyPilot provides nightly builds and source code for the latest features and for development.
@@ -131,7 +131,7 @@ SkyPilot supports installation with ``uv`` or ``pip``.
git clone https://github.com/skypilot-org/skypilot.git
cd skypilot
- pip install -e .
+ pip install -e .
Alternatively, we also provide a :ref:`Docker image ` as a quick way to try out SkyPilot.
@@ -431,7 +431,7 @@ Install the necessary dependencies for AWS.
:sync: pip-tab
.. code-block:: shell
-
+
# SkyPilot requires 3.7 <= python <= 3.13.
# From stable release
pip install "skypilot[aws]"
@@ -586,7 +586,7 @@ Install the necessary dependencies for Azure.
:sync: pip-tab
.. code-block:: shell
-
+
# SkyPilot requires 3.7 <= python <= 3.13.
# From stable release
pip install "skypilot[azure]"
@@ -641,7 +641,7 @@ CoreWeave
:sync: pip-tab
.. code-block:: shell
-
+
# SkyPilot requires 3.7 <= python <= 3.13.
# From stable release
pip install "skypilot[coreweave]"
@@ -737,7 +737,7 @@ Install the necessary dependencies for Nebius.
:sync: pip-tab
.. code-block:: shell
-
+
# Nebius requires 3.10 <= python <= 3.13.
# From stable release
pip install "skypilot[nebius]"
@@ -818,7 +818,7 @@ Install the necessary dependencies for RunPod
:sync: pip-tab
.. code-block:: shell
-
+
# SkyPilot requires 3.7 <= python <= 3.13.
# From stable release
pip install "skypilot[runpod]"
@@ -868,7 +868,7 @@ Install the necessary dependencies for OCI.
:sync: pip-tab
.. code-block:: shell
-
+
# SkyPilot requires 3.7 <= python <= 3.13.
# From stable release
pip install "skypilot[oci]"
@@ -939,7 +939,7 @@ Install the necessary dependencies for Lambda Cloud.
:sync: pip-tab
.. code-block:: shell
-
+
# SkyPilot requires 3.7 <= python <= 3.13.
# From stable release
pip install "skypilot[lambda]"
@@ -989,7 +989,7 @@ Together AI |community-badge|
:sync: pip-tab
.. code-block:: shell
-
+
# SkyPilot requires 3.7 <= python <= 3.13.
# From stable release
pip install "skypilot[kubernetes]"
@@ -1042,7 +1042,7 @@ Install the necessary dependencies for Paperspace.
:sync: pip-tab
.. code-block:: shell
-
+
# SkyPilot requires 3.7 <= python <= 3.13.
# From stable release
pip install "skypilot[paperspace]"
@@ -1092,7 +1092,7 @@ Install the necessary dependencies for Vast.
:sync: pip-tab
.. code-block:: shell
-
+
# SkyPilot requires 3.7 <= python <= 3.13.
# From stable release
pip install "skypilot[vast]"
@@ -1143,7 +1143,7 @@ Install the necessary dependencies for Fluidstack.
:sync: pip-tab
.. code-block:: shell
-
+
# SkyPilot requires 3.7 <= python <= 3.13.
# From stable release
pip install "skypilot[fluidstack]"
@@ -1193,7 +1193,7 @@ Cudo Compute |community-badge|
:sync: pip-tab
.. code-block:: shell
-
+
# SkyPilot requires 3.7 <= python <= 3.13.
# From stable release
pip install "skypilot[cudo]"
@@ -1255,7 +1255,7 @@ Install the necessary dependencies for Shadeform.
:sync: pip-tab
.. code-block:: shell
-
+
# SkyPilot requires 3.7 <= python <= 3.13.
# From stable release
pip install "skypilot[shadeform]"
@@ -1308,7 +1308,7 @@ Install the necessary dependencies for IBM.
:sync: pip-tab
.. code-block:: shell
-
+
# IBM requires 3.7 <= python <= 3.11.
# From stable release
pip install "skypilot[ibm]"
@@ -1388,7 +1388,7 @@ Install the necessary dependencies for SCP.
:sync: pip-tab
.. code-block:: shell
-
+
# SCP requires 3.7 <= python <= 3.11.
# From stable release
pip install "skypilot[scp]"
@@ -1446,7 +1446,7 @@ Install the necessary dependencies for VMware vSphere.
:sync: pip-tab
.. code-block:: shell
-
+
# SkyPilot requires 3.7 <= python <= 3.13.
# From stable release
pip install "skypilot[vsphere]"
@@ -1524,7 +1524,7 @@ Install the necessary dependencies for Cloudflare R2.
:sync: pip-tab
.. code-block:: shell
-
+
# SkyPilot requires 3.7 <= python <= 3.13.
# From stable release
pip install "skypilot[cloudflare]"
@@ -1566,7 +1566,7 @@ Next, get your `Account ID `__ makes it easy to find global compute resources and train state-of-the-art models through distributed training across clusters. To configure Prime Intellect access:
+`Prime Intellect `__ makes it easy to find global compute resources and train state-of-the-art models through distributed training across clusters. To configure Prime Intellect access:
Install the necessary dependencies for Prime Intellect.
@@ -1597,7 +1597,7 @@ Install the necessary dependencies for Prime Intellect.
:sync: pip-tab
.. code-block:: shell
-
+
# SkyPilot requires 3.7 <= python <= 3.13.
# From stable release
pip install "skypilot[primeintellect]"
@@ -1653,7 +1653,7 @@ Seeweb |community-badge|
:sync: pip-tab
.. code-block:: shell
-
+
# Seeweb requires 3.10 <= python <= 3.13.
# From stable release
pip install "skypilot[seeweb]"
@@ -1715,4 +1715,4 @@ Finally, you can stop the container with:
See more details about the dev container image
``berkeleyskypilot/skypilot-nightly`` `here
-`_.
\ No newline at end of file
+`_.
diff --git a/docs/source/images/dashboard-clusters.png b/docs/source/images/dashboard-clusters.png
index 7de83651f30..d3708661261 100644
Binary files a/docs/source/images/dashboard-clusters.png and b/docs/source/images/dashboard-clusters.png differ
diff --git a/docs/source/images/dashboard-managed-jobs.png b/docs/source/images/dashboard-managed-jobs.png
index 9513f786479..52aff598bcb 100644
Binary files a/docs/source/images/dashboard-managed-jobs.png and b/docs/source/images/dashboard-managed-jobs.png differ
diff --git a/docs/source/images/job-groups-dashboard.png b/docs/source/images/job-groups-dashboard.png
new file mode 100644
index 00000000000..759b4b7aa6a
Binary files /dev/null and b/docs/source/images/job-groups-dashboard.png differ
diff --git a/docs/source/images/job-groups-rl-architecture.jpg b/docs/source/images/job-groups-rl-architecture.jpg
new file mode 100644
index 00000000000..f608e172ddb
Binary files /dev/null and b/docs/source/images/job-groups-rl-architecture.jpg differ
diff --git a/docs/source/images/job-groups-train-eval-architecture.png b/docs/source/images/job-groups-train-eval-architecture.png
new file mode 100644
index 00000000000..206a916fd1a
Binary files /dev/null and b/docs/source/images/job-groups-train-eval-architecture.png differ
diff --git a/docs/source/images/metrics/deploy-prom-operator.png b/docs/source/images/metrics/deploy-prom-operator.png
deleted file mode 100644
index 2be5229b386..00000000000
Binary files a/docs/source/images/metrics/deploy-prom-operator.png and /dev/null differ
diff --git a/docs/source/images/metrics/search-prom-operator.png b/docs/source/images/metrics/search-prom-operator.png
deleted file mode 100644
index 3260ab3b13b..00000000000
Binary files a/docs/source/images/metrics/search-prom-operator.png and /dev/null differ
diff --git a/docs/source/images/metrics/status-prom-operator.png b/docs/source/images/metrics/status-prom-operator.png
deleted file mode 100644
index 688b6e33da6..00000000000
Binary files a/docs/source/images/metrics/status-prom-operator.png and /dev/null differ
diff --git a/docs/source/images/slurm-cluster-details-page.png b/docs/source/images/slurm-cluster-details-page.png
new file mode 100644
index 00000000000..fbed25568af
Binary files /dev/null and b/docs/source/images/slurm-cluster-details-page.png differ
diff --git a/docs/source/images/slurm-infra-page.png b/docs/source/images/slurm-infra-page.png
new file mode 100644
index 00000000000..17114ab9a89
Binary files /dev/null and b/docs/source/images/slurm-infra-page.png differ
diff --git a/docs/source/reference/api-server/examples/api-server-gpu-metrics-setup.rst b/docs/source/reference/api-server/examples/api-server-gpu-metrics-setup.rst
index e8fe887ac37..2050d25ee9d 100644
--- a/docs/source/reference/api-server/examples/api-server-gpu-metrics-setup.rst
+++ b/docs/source/reference/api-server/examples/api-server-gpu-metrics-setup.rst
@@ -146,21 +146,37 @@ Prometheus setup
In the cluster where you deploy the API server, Prometheus is installed automatically as part of :ref:`api-server-setup-dcgm-metrics-scraping`.
-For other Kubernetes clusters (external clusters), deploy Prometheus manually. SkyPilot also requires a Service ``skypilot-prometheus-server`` in the ``skypilot`` namespace to scrape metrics from external clusters.
+For other Kubernetes clusters (external clusters), deploy Prometheus manually. SkyPilot requires a Service named ``skypilot-prometheus-server`` in the ``skypilot`` namespace to scrape metrics from external clusters.
-If you use the `Prometheus operator `_, e.g., the `kube-prometheus-stack `_, install it in the ``skypilot`` namespace, then create the ``skypilot-prometheus-server`` Service in the same namespace.
+First, create a ``prometheus-values.yaml`` file with the following configuration:
+
+.. literalinclude:: ../../../../../examples/metrics/prometheus-values.yaml
+ :language: yaml
+
+Then install Prometheus using ``skypilot-prometheus`` as the release name (this creates the required ``skypilot-prometheus-server`` service):
.. code-block:: bash
- kubectl create -f https://raw.githubusercontent.com/skypilot-org/skypilot/refs/heads/master/examples/metrics/skypilot_prometheus_server_service.yaml -n skypilot
+ helm repo add prometheus-community https://prometheus-community.github.io/helm-charts
+ helm repo update
+ helm upgrade --install skypilot-prometheus prometheus-community/prometheus \
+ --namespace skypilot \
+ --create-namespace \
+ -f prometheus-values.yaml
-Alternatively, install the SkyPilot Prometheus server chart; it will create the ``skypilot-prometheus-server`` Service automatically:
+Verify the service was created:
.. code-block:: bash
- helm upgrade --install skypilot skypilot/skypilot-prometheus-server --devel \
- --namespace skypilot \
- --create-namespace
+ kubectl get svc skypilot-prometheus-server -n skypilot
+
+Refer to the `Prometheus helm chart values `_ for additional configuration options.
+
+.. note::
+
+ Do not use the Prometheus Operator (kube-prometheus-stack) for GPU metrics.
+ The Prometheus Operator adds an ``exported_`` prefix to pod and namespace labels,
+ which breaks the PromQL queries used by SkyPilot.
If you are using the Nebius Kubernetes cluster, refer to :ref:`api-server-gpu-metrics-setup-nebius` for how to setup the GPU metrics.
diff --git a/docs/source/reference/api-server/examples/example-deploy-gke-nebius-okta.rst b/docs/source/reference/api-server/examples/example-deploy-gke-nebius-okta.rst
index bb3c023c709..1c75e202b6e 100644
--- a/docs/source/reference/api-server/examples/example-deploy-gke-nebius-okta.rst
+++ b/docs/source/reference/api-server/examples/example-deploy-gke-nebius-okta.rst
@@ -428,38 +428,29 @@ Setup GPU metrics in Nebius Kubernetes cluster
If you are using Nebius Kubernetes cluster, you can setup GPU metrics in the cluster to get real-time GPU metrics in the SkyPilot dashboard.
-1. Install the Prometheus operator.
+1. Install Prometheus.
-On Nebius console, in the detail page of the Nebius Kubernetes cluster, go to ``Applications`` -> Search for ``Prometheus Operator`` -> ``Deploy`` -> Enter ``skypilot`` for the ``Namespace`` -> ``Deploy application``.
+First, create a ``prometheus-values.yaml`` file with the following configuration:
-.. image:: ../../../images/metrics/search-prom-operator.png
- :alt: Search for Prometheus Operator
- :align: center
- :width: 60%
-
-.. image:: ../../../images/metrics/deploy-prom-operator.png
- :alt: Deploy Prometheus Operator
- :align: center
- :width: 60%
+.. literalinclude:: ../../../../../examples/metrics/prometheus-values.yaml
+ :language: yaml
-Wait for the Prometheus operator to be installed, the status badge will become ``Deployed``.
-
-.. image:: ../../../images/metrics/status-prom-operator.png
- :alt: Status of Prometheus Operator
- :align: center
- :width: 60%
-
-You can also check the Pod status to verify the installation.
+Then install Prometheus using ``skypilot-prometheus`` as the release name:
.. code-block:: bash
- kubectl get pods -n skypilot
+ helm repo add prometheus-community https://prometheus-community.github.io/helm-charts
+ helm repo update
+ helm upgrade --install skypilot-prometheus prometheus-community/prometheus \
+ --namespace skypilot \
+ --create-namespace \
+ -f prometheus-values.yaml
-By default, the CPU and memory metrics exported by node exporter do not include the ``node`` label, which is required for the SkyPilot dashboard to display the metrics. You can add the ``node`` label to the metrics by applying the following config to the node exporter service monitor resource:
+Verify the ``skypilot-prometheus-server`` service was created:
.. code-block:: bash
- kubectl apply -f https://raw.githubusercontent.com/skypilot-org/skypilot/refs/heads/master/examples/metrics/kube_prometheus_node_exporter_service_monitor.yaml -n skypilot
+ kubectl get svc skypilot-prometheus-server -n skypilot
2. Install the Nvidia Device Plugin.
@@ -490,21 +481,7 @@ You can also check the Pod status to verify the installation.
The dcgm exporter will be installed automatically.
-3. Create the Prometheus service for SkyPilot API server to retrieve the GPU metrics:
-
- .. code-block:: bash
-
- kubectl create -f https://raw.githubusercontent.com/skypilot-org/skypilot/refs/heads/master/examples/metrics/skypilot_prometheus_server_service.yaml -n skypilot
-
-Confirm that the service endpoint is created by running the following command:
-
-.. code-block:: bash
-
- kubectl get endpoints skypilot-prometheus-server -n skypilot
- NAME ENDPOINTS AGE
- skypilot-prometheus-server 10.24.20.128:9090 62s
-
-4. If you are using multiple Kubernetes clusters, you will need to add the context names to ``allowed_contexts`` in the SkyPilot config.
+3. If you are using multiple Kubernetes clusters, you will need to add the context names to ``allowed_contexts`` in the SkyPilot config.
An example config file that allows using the hosting Kubernetes cluster and two additional Kubernetes clusters is shown below:
diff --git a/docs/source/reference/api-server/helm-values-spec.rst b/docs/source/reference/api-server/helm-values-spec.rst
index b87246f76d9..4c9daa9604e 100644
--- a/docs/source/reference/api-server/helm-values-spec.rst
+++ b/docs/source/reference/api-server/helm-values-spec.rst
@@ -27,10 +27,6 @@ Values
Below is the available helm value keys and the default value of each key:
-..
- Omitted values:
- * storage.accessMode: accessMode other than ReadWriteOnce is not tested yet.
-
.. parsed-literal::
:ref:`global `:
@@ -100,6 +96,11 @@ Below is the available helm value keys and the default value of each key:
:ref:`cookie-expire `: null
:ref:`serviceAccount `:
:ref:`enabled `: null
+ :ref:`externalProxy <helm-values-auth-externalProxy>`:
+   :ref:`enabled <helm-values-auth-externalProxy-enabled>`: false
+   :ref:`headerName <helm-values-auth-externalProxy-headerName>`: 'X-Auth-Request-Email'
+   :ref:`headerFormat <helm-values-auth-externalProxy-headerFormat>`: 'plaintext'
+   :ref:`jwtIdentityClaim <helm-values-auth-externalProxy-jwtIdentityClaim>`: 'sub'
:ref:`storage <helm-values-storage>`:
  :ref:`enabled <helm-values-storage-enabled>`: true
@@ -423,6 +424,11 @@ Upgrade strategy for the API server deployment. Available options are:
When set to ``RollingUpdate``, an external database must be configured via :ref:`apiService.dbConnectionSecretName ` or :ref:`apiService.dbConnectionString `.
+For persistent storage with RollingUpdate:
+
+- If :ref:`storage.enabled=true <helm-values-storage-enabled>`, set :ref:`storage.accessMode <helm-values-storage-accessMode>` to ``ReadWriteMany`` with an RWX-capable storage class (e.g., NFS-backed storage). This sets the ``SKYPILOT_API_SERVER_STORAGE_ENABLED`` environment variable, ensuring managed job logs and file mounts persist across rolling updates.
+- If ``storage.enabled=false``, file mounts and logs will be lost on pod restart. Consider configuring ``jobs.bucket`` in the SkyPilot config to persist file mounts to cloud storage.
+
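+A minimal values combination for RollingUpdate with persistent storage (a
+sketch; the secret and storage class names are placeholders):
+
+.. code-block:: yaml
+
+    apiService:
+      upgradeStrategy: RollingUpdate
+      dbConnectionSecretName: my-db-secret  # External database is required for RollingUpdate.
+    storage:
+      enabled: true
+      accessMode: ReadWriteMany
+      storageClassName: my-rwx-storage-class  # Must support ReadWriteMany.
+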
Default: ``"Recreate"``
.. code-block:: yaml
@@ -1130,6 +1136,95 @@ Default: ``null``
serviceAccount:
enabled: true
+.. _helm-values-auth-externalProxy:
+
+``auth.externalProxy``
+^^^^^^^^^^^^^^^^^^^^^^
+
+Configuration for trusting an external authentication proxy in front of the API server. Use this when your infrastructure has a reverse proxy or load balancer that handles authentication (e.g., AWS ALB with Cognito, Azure Front Door with Azure AD, or a custom ingress controller with authentication middleware).
+
+When enabled, the API server extracts user identity from the HTTP header set by the proxy. The proxy is trusted to have already authenticated the user.
+
+This is mutually exclusive with :ref:`auth.oauth <helm-values-auth-oauth>` and :ref:`ingress.oauth2-proxy <helm-values-ingress-oauth2-proxy>`.
+
+Default: see the YAML below.
+
+.. code-block:: yaml
+
+ auth:
+ externalProxy:
+ enabled: false
+ headerName: 'X-Auth-Request-Email'
+      headerFormat: 'plaintext'
+      jwtIdentityClaim: 'sub'
+
+.. _helm-values-auth-externalProxy-enabled:
+
+``auth.externalProxy.enabled``
+''''''''''''''''''''''''''''''
+
+Enable external proxy authentication. When enabled, the API server will extract user identity from the header specified by ``headerName``.
+
+Default: ``false``
+
+.. code-block:: yaml
+
+ auth:
+ externalProxy:
+ enabled: true
+
+.. _helm-values-auth-externalProxy-headerName:
+
+``auth.externalProxy.headerName``
+'''''''''''''''''''''''''''''''''
+
+The HTTP header name containing the user identity.
+
+Default: ``'X-Auth-Request-Email'``
+
+.. code-block:: yaml
+
+ auth:
+ externalProxy:
+ headerName: 'X-WEBAUTH-USER'
+
+.. _helm-values-auth-externalProxy-headerFormat:
+
+``auth.externalProxy.headerFormat``
+'''''''''''''''''''''''''''''''''''
+
+The format of the header value. Available options:
+
+- ``plaintext``: The header value is the user identity directly (e.g., ``user@example.com``)
+- ``jwt``: The header value is a JWT token from which the identity should be extracted using ``jwtIdentityClaim``
+
+Use ``jwt`` when integrating with load balancers that pass JWT tokens.
+
+Default: ``'plaintext'``
+
+.. code-block:: yaml
+
+ auth:
+ externalProxy:
+ headerFormat: 'jwt'
+
+.. _helm-values-auth-externalProxy-jwtIdentityClaim:
+
+``auth.externalProxy.jwtIdentityClaim``
+'''''''''''''''''''''''''''''''''''''''
+
+The JWT claim to extract the user identity from when ``headerFormat`` is ``jwt``.
+
+Only used when ``headerFormat`` is ``jwt``.
+
+Default: ``'sub'``
+
+.. code-block:: yaml
+
+ auth:
+ externalProxy:
+ headerFormat: 'jwt'
+ jwtIdentityClaim: 'email'
+
.. _helm-values-storage:
@@ -1143,6 +1238,19 @@ Default: ``null``
Enable persistent storage for the API server, setting this to ``false`` is prone to data loss and should only be used for testing.
+When enabled, SkyPilot creates a PersistentVolumeClaim (PVC) to persist:
+
+- **Managed job logs**: Accessible via ``sky jobs logs <job_id>`` and ``sky jobs logs --controller <job_id>``
+- **File mounts**: Local files uploaded during managed job submission
+
+.. note::
+
+ Setting ``storage.enabled=true`` sets the environment variable ``SKYPILOT_API_SERVER_STORAGE_ENABLED=true`` on the API server pod. This ensures that managed job logs and file mounts persist across API server restarts and rolling updates.
+
+ Transient logs (api_server logs, sky-* cluster logs) are NOT persisted to minimize storage usage.
+
+For RollingUpdate upgrade strategy, see :ref:`apiService.upgradeStrategy <helm-values-apiService-upgradeStrategy>` for storage access mode requirements.
+
Default: ``true``
.. code-block:: yaml
@@ -1169,15 +1277,32 @@ Default: ``""``
``storage.accessMode``
^^^^^^^^^^^^^^^^^^^^^^
-Access mode for the persistent storage volume. Can be set to ``ReadWriteOnce`` or ``ReadWriteMany`` depending on what is supported by the storage class.
+Access mode for the persistent storage volume. Available options:
+
+- ``ReadWriteOnce`` (RWO): The volume can be mounted as read-write by a single node. This is the default and works with most storage classes. Compatible with ``Recreate`` upgrade strategy. **Not compatible with RollingUpdate upgrade strategy** since the PVC cannot be mounted by both old and new pods simultaneously during rolling updates.
+
+- ``ReadWriteMany`` (RWX): The volume can be mounted as read-write by multiple nodes. Compatible with both ``Recreate`` and ``RollingUpdate`` upgrade strategies. Requires an RWX-capable storage class such as:
+
+ - GKE: Filestore-backed storage class
+ - EKS: EFS CSI driver
+ - AKS: Azure Files
+ - On-prem: NFS provisioner
+
+For more details on upgrade strategies, see :ref:`apiService.upgradeStrategy <helm-values-apiService-upgradeStrategy>`.
Default: ``ReadWriteOnce``
.. code-block:: yaml
+ # For Recreate upgrade strategy (default), ReadWriteOnce is sufficient
storage:
accessMode: ReadWriteOnce
+ # For RollingUpdate upgrade strategy with persistent storage, use ReadWriteMany
+ storage:
+ accessMode: ReadWriteMany
+    storageClassName: ""  # Set to an RWX-capable storage class; the cluster default may not support RWX.
+
.. _helm-values-storage-size:
``storage.size``
diff --git a/docs/source/reference/auto-stop.rst b/docs/source/reference/auto-stop.rst
index 0c27f05927b..dec39e47209 100644
--- a/docs/source/reference/auto-stop.rst
+++ b/docs/source/reference/auto-stop.rst
@@ -139,3 +139,146 @@ Alternatively, pass the ``--wait-for`` flag to either ``sky autostop`` or ``sky
# Hard time limit: Stop after 10 minutes, regardless of running jobs or SSH sessions.
sky autostop mycluster -i 10 --wait-for none
+
+.. _auto-stop-hooks:
+
+Autostop hooks
+~~~~~~~~~~~~~~
+
+To execute a script before autostopping, specify a hook in the autostop configuration.
+The hook script runs on the remote cluster before the cluster is stopped or torn down.
+This is useful for tasks like committing code, saving checkpoints, or performing cleanup operations.
+
+.. code-block:: yaml
+
+ resources:
+ autostop:
+ idle_minutes: 10
+ hook: |
+ cd my-code-base
+ git add .
+ git commit -m "Commit my code"
+ git push
+ hook_timeout: 300
+
+The hook script runs on the cluster and has access to the cluster's filesystem and environment variables.
+If the hook script fails (non-zero exit code), the autostop process will still continue,
+but a warning will be logged.
+
+**Hook Timeout**
+
+By default, autostop hooks have a **1-hour (3600 seconds) timeout**. If your hook
+takes longer than this, it will be killed and autostop will proceed. To
+customize the timeout in your YAML configuration:
+
+.. code-block:: yaml
+
+ resources:
+ autostop:
+ idle_minutes: 10
+ hook: |
+ # Long-running backup operation
+ tar -czf backup.tar.gz /large/dataset
+ aws s3 cp backup.tar.gz s3://my-bucket/
+ hook_timeout: 7200 # 2 hours in seconds
+
+**Important Notes:**
+
+- If the hook times out, autostop will proceed after logging a warning
+- The minimum timeout is 1 second
+- Hook execution keeps the cluster from terminating while it runs, occupying its resources; keep this in mind when setting ``idle_minutes``
+
+Common use cases for autostop hooks:
+
+.. dropdown:: Committing and pushing code changes
+
+ .. code-block:: yaml
+
+ resources:
+ autostop:
+ idle_minutes: 10
+ hook: |
+ cd my-code-base
+ git add .
+ git commit -m "Auto-commit before shutdown"
+ git push
+
+.. dropdown:: Saving model checkpoints to persistent storage
+
+ .. code-block:: yaml
+
+ resources:
+ autostop:
+ idle_minutes: 10
+ hook: |
+ # Save checkpoints to a mounted volume or cloud storage
+ cp -r /workspace/checkpoints/* /mnt/persistent-storage/checkpoints/
+ # Or upload to S3
+ aws s3 sync /workspace/checkpoints/ s3://my-bucket/checkpoints/
+
+.. dropdown:: Uploading logs or results to cloud storage
+
+ .. code-block:: yaml
+
+ resources:
+ autostop:
+ idle_minutes: 10
+ hook: |
+ # Upload logs to S3
+ aws s3 sync /workspace/logs/ s3://my-bucket/logs/$(date +%Y%m%d)/
+ # Or upload to GCS
+ gcloud storage cp -r /workspace/results/ gs://my-bucket/results/$(date +%Y%m%d)/
+
+.. dropdown:: Syncing W&B runs before shutdown
+
+ .. code-block:: yaml
+
+ resources:
+ autostop:
+ idle_minutes: 10
+ hook: |
+ # Sync W&B runs to the cloud before shutdown
+ # Sync all runs in the wandb directory
+ wandb sync ./wandb
+ # Or sync a specific run
+ # wandb sync ./wandb/run-20250813_124246-n67z9ude
+
+.. dropdown:: Sending notifications about the cluster shutdown
+
+ .. code-block:: yaml
+
+ resources:
+ autostop:
+ idle_minutes: 10
+ hook: |
+ # Send email notification
+ echo "Cluster shutting down after idle period" | \
+ mail -s "Cluster Autostop" user@example.com
+ # Or send Slack notification via webhook
+ curl -X POST -H 'Content-type: application/json' \
+ --data '{"text":"Cluster shutting down after idle period"}' \
+ https://hooks.slack.com/services/YOUR/WEBHOOK/URL
+
+.. dropdown:: Triggering downstream workflows
+
+ .. code-block:: yaml
+
+ resources:
+ autostop:
+ idle_minutes: 10
+ hook: |
+ # Trigger an evaluation pipeline in Airflow
+ curl -X POST https://airflow.example.com/api/v1/dags/model_eval/dag_runs \
+ -H "Content-Type: application/json" \
+ -d '{"conf": {"model_path": "s3://my-bucket/models/v1"}}'
+
+.. dropdown:: Pushing model to Hugging Face Hub
+
+ .. code-block:: yaml
+
+ resources:
+ autostop:
+ idle_minutes: 10
+ hook: |
+ # Upload the trained model to Hugging Face Hub
+ huggingface-cli upload my-org/my-model /workspace/model-output .
diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst
index c182d8427e0..eafde9ced37 100644
--- a/docs/source/reference/config.rst
+++ b/docs/source/reference/config.rst
@@ -63,6 +63,7 @@ Below is the configuration syntax and some example values. See detailed explanat
:ref:`provision `:
:ref:`ssh_timeout `: 10
+    :ref:`install_conda <config-yaml-provision-install-conda>`: false
:ref:`kubernetes `:
:ref:`ports `: loadbalancer
@@ -122,6 +123,9 @@ Below is the configuration syntax and some example values. See detailed explanat
map-migrated: my-value
Owner: user-unique-name
:ref:`vpc_name `: skypilot-vpc
+    :ref:`vpc_names <config-yaml-aws-vpc-names>`:
+ - skypilot-vpc-1
+ - skypilot-vpc-2
:ref:`use_internal_ips `: true
:ref:`use_ssm `: true
:ref:`ssh_proxy_command `: ssh -W %h:%p user@host
@@ -615,6 +619,31 @@ determines how long to wait for the connection to be established.
Default: ``10``.
+.. _config-yaml-provision-install-conda:
+
+``provision.install_conda``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Whether to install conda on the remote cluster (optional).
+
+SkyPilot clusters come with conda preinstalled for convenience.
+When set to ``false``, SkyPilot will not install conda on the cluster.
+
+Default: ``true``.
+
+Example:
+
+.. code-block:: yaml
+
+ provision:
+ install_conda: false
+
+.. note::
+
+ Default SkyPilot images often come with conda preinstalled.
+ To fully avoid installing conda, use a custom Docker image that does not have conda preinstalled
+ along with ``install_conda: false``.
+
.. _config-yaml-aws:
``aws``
@@ -670,6 +699,24 @@ Regions without a VPC with this name will not be used to launch nodes.
Default: ``null`` (use the default VPC in each region).
+Deprecated: use ``aws.vpc_names`` instead.
+
+.. _config-yaml-aws-vpc-names:
+
+``aws.vpc_names``
+~~~~~~~~~~~~~~~~~
+
+VPCs to use in each region (optional).
+
+If this is set, SkyPilot will try each listed VPC during failover, in regions
+that contain at least one of the listed VPCs (the provisioner automatically
+discovers such regions). Regions without any matching VPC will not be used
+to launch nodes.
+
+The value can be either a ``string`` (a single VPC) or a ``list`` (multiple
+target VPCs), as shown in the example below.
+
+Default: ``null`` (use the default VPC in each region).
+
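+Example (the VPC names are placeholders):
+
+.. code-block:: yaml
+
+    aws:
+      # A single VPC as a string:
+      # vpc_names: skypilot-vpc
+      # Or a list of VPCs to try during failover:
+      vpc_names:
+        - skypilot-vpc-1
+        - skypilot-vpc-2
+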
.. _config-yaml-aws-use-internal-ips:
``aws.use_internal_ips``
@@ -1364,6 +1411,7 @@ Example:
myannotation: myvalue
provision_timeout: 10
autoscaler: gke
+ set_pod_resource_limits: true # or a multiplier like 1.5
pod_config:
metadata:
labels:
@@ -1471,6 +1519,47 @@ Example:
post_provision_runcmd:
- echo "hello world!"
+.. _config-yaml-kubernetes-set-pod-resource-limits:
+
+``kubernetes.set_pod_resource_limits``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Set pod CPU/memory limits relative to requests (optional).
+
+This is useful for Kubernetes clusters that require resource limits to be set
+(e.g., for LimitRange enforcement, resource quotas, or cluster policies).
+
+Can be one of:
+
+- ``false`` (default): Do not set CPU/memory limits (only requests are set).
+- ``true``: Set limits equal to requests (multiplier of 1).
+- A positive number: Set limits to requests multiplied by this value (e.g., ``1.5`` for 50% headroom).
+
+Default: ``false``.
+
+Example:
+
+.. code-block:: yaml
+
+ kubernetes:
+ # Set limits equal to requests
+ set_pod_resource_limits: true
+
+.. code-block:: yaml
+
+ kubernetes:
+ # Set limits to 1.5x requests (50% headroom)
+ set_pod_resource_limits: 1.5
+
+This can also be configured per-context using ``context_configs``:
+
+.. code-block:: yaml
+
+ kubernetes:
+ context_configs:
+ prod-cluster:
+ set_pod_resource_limits: 2.0
+
.. _config-yaml-kubernetes-context-configs:
``kubernetes.context_configs``
diff --git a/docs/source/reference/kubernetes/kubernetes-getting-started.rst b/docs/source/reference/kubernetes/kubernetes-getting-started.rst
index 6083119d6a3..f64b79c9acd 100644
--- a/docs/source/reference/kubernetes/kubernetes-getting-started.rst
+++ b/docs/source/reference/kubernetes/kubernetes-getting-started.rst
@@ -142,32 +142,26 @@ Once your cluster administrator has :ref:`setup a Kubernetes cluster `, you can view resources from all users with :code:`sky status -u`:
.. code-block:: console
- $ sky status --k8s
- Kubernetes cluster state (context: mycluster)
- SkyPilot clusters
- USER NAME LAUNCHED INFRA RESOURCES STATUS
- alice infer-svc-1 23 hrs ago Kubernetes 1x(gpus=L4:1, ...) UP
- alice sky-jobs-controller-80b50983 2 days ago Kubernetes 1x(cpus=4, mem=4, ...) UP
- alice sky-serve-controller-80b50983 23 hrs ago Kubernetes 1x(cpus=4, mem=4, ...) UP
- bob dev 1 day ago Kubernetes 1x(gpus=H100:1, ...) UP
- bob multinode-dev 1 day ago Kubernetes 2x(cpus=2, mem=2, ...) UP
- bob sky-jobs-controller-2ea485ea 2 days ago Kubernetes 1x(cpus=4, mem=4, ...) UP
-
- Managed jobs
- In progress tasks: 1 STARTING
- USER ID TASK NAME REQUESTED SUBMITTED TOT. DURATION JOB DURATION #RECOVERIES STATUS
- alice 1 - eval 1x[CPU:1+] 2 days ago 49s 8s 0 SUCCEEDED
- bob 4 - pretrain 1x[H100:4] 1 day ago 1h 1m 11s 1h 14s 0 SUCCEEDED
- bob 3 - bigjob 1x[CPU:16] 1 day ago 1d 21h 11m 4s - 0 STARTING
- bob 2 - failjob 1x[CPU:1+] 1 day ago 54s 9s 0 FAILED
- bob 1 - shortjob 1x[CPU:1+] 2 days ago 1h 1m 19s 1h 16s 0 SUCCEEDED
+ $ sky status -u
+ Clusters
+ NAME USER WORKSPACE INFRA RESOURCES STATUS AUTOSTOP LAUNCHED
+ mycluster alice@example.com prod Kubernetes (k8s-context1) 1x(cpus=2, mem=4, ...) UP - 10 mins ago
+ dev alice@example.com ml-team Kubernetes (k8s-context2) 1x(gpus=H100:1, cpus=4, mem=16, ...) UP 10m 1 hr ago
+ training bob@example.com ml-team Kubernetes (k8s-context1) 1x(gpus=L4:4, cpus=8, mem=32, ...) UP - 2 hrs ago
You can also inspect the real-time GPU usage on the cluster with :code:`sky show-gpus --infra k8s`.
@@ -298,6 +292,23 @@ To use images from private repositories (e.g., Private DockerHub, Amazon ECR, Go
--docker-server=nvcr.io
+
+
+.. _kubernetes-using-volumes:
+
+Mounting NFS and other volumes
+------------------------------
+
+SkyPilot supports mounting various types of volumes to your pods on Kubernetes:
+
+* :ref:`Persistent volumes `: Independently managed volumes with lifecycle separate from clusters, ideal for long-term data storage and sharing datasets across clusters. These are backed by Kubernetes PVCs on block storage (e.g., AWS EBS, GCP Persistent Disk) or distributed file systems (e.g., JuiceFS, Nebius shared file system, AWS EFS, GCP Filestore).
+
+* :ref:`Ephemeral volumes `: Automatically created and deleted with your cluster, suitable for temporary storage and caches that are cluster-specific. Also backed by Kubernetes PVCs.
+
+* :ref:`Other volume types `: Mount hostPath, NFS, and other Kubernetes volume types by overriding SkyPilot's ``pod_config``.
+
+For detailed information on configuring and using volumes, see :ref:`Volumes on Kubernetes `.
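+
+For example, once a persistent volume exists, mounting it in a task YAML takes a single line (a minimal sketch; the volume name ``my-volume`` is illustrative):
+
+.. code-block:: yaml
+
+   volumes:
+     # Mount the volume named `my-volume` at /mnt/data on all nodes
+     /mnt/data: my-volume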
+
Opening ports
-------------
@@ -382,21 +393,6 @@ For example, to set custom environment variables and use GPUDirect RDMA, you can
pod_config:
...
-.. _kubernetes-using-volumes:
-
-Mounting volumes
-------------------------------
-
-SkyPilot supports mounting various types of volumes to your pods on Kubernetes:
-
-* **Persistent volumes**: Independently managed volumes with lifecycle separate from clusters, ideal for long-term data storage and sharing datasets across clusters. These can be backed by block storage (e.g., AWS EBS, GCP Persistent Disk) or distributed file systems (e.g., JuiceFS, Nebius shared file system, AWS EFS, GCP Filestore).
-
-* **Ephemeral volumes**: Automatically created and deleted with your cluster, suitable for temporary storage and caches that are cluster-specific.
-
-* **Other volume types**: You can also mount hostPath, NFS, etc. as needed.
-
-For detailed information on configuring and using volumes, see :ref:`volumes-on-kubernetes`.
-
FAQs
----
diff --git a/docs/source/reference/kubernetes/kubernetes-setup.rst b/docs/source/reference/kubernetes/kubernetes-setup.rst
index 9ddd52bae28..d8b5cbad687 100644
--- a/docs/source/reference/kubernetes/kubernetes-setup.rst
+++ b/docs/source/reference/kubernetes/kubernetes-setup.rst
@@ -257,18 +257,18 @@ The following setup steps are optional and can be performed based on your specif
.. _kubernetes-setup-volumes:
-Set up volumes
-^^^^^^^^^^^^^^^
+Set up NFS and other volumes
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
SkyPilot supports mounting various types of volumes to your pods on Kubernetes:
-* **Persistent volumes**: Independently managed volumes with lifecycle separate from clusters, ideal for long-term data storage and sharing datasets across clusters. These can be backed by block storage (e.g., AWS EBS, GCP Persistent Disk) or distributed file systems (e.g., JuiceFS, Nebius shared file system, AWS EFS, GCP Filestore).
+* :ref:`Persistent volumes `: Independently managed volumes with lifecycle separate from clusters, ideal for long-term data storage and sharing datasets across clusters. These are backed by Kubernetes PVCs on block storage (e.g., AWS EBS, GCP Persistent Disk) or distributed file systems (e.g., JuiceFS, Nebius shared file system, AWS EFS, GCP Filestore).
-* **Ephemeral volumes**: Automatically created and deleted with your cluster, suitable for temporary storage and caches that are cluster-specific.
+* :ref:`Ephemeral volumes `: Automatically created and deleted with your cluster, suitable for temporary storage and caches that are cluster-specific. Also backed by Kubernetes PVCs.
-* **Other volume types**: You can also mount hostPath, NFS, etc. as needed.
+* :ref:`Other volume types `: Mount hostPath, NFS, and other Kubernetes volume types by overriding SkyPilot's ``pod_config``.
-For detailed information on configuring and using volumes, see :ref:`volumes-on-kubernetes`.
+For detailed information on configuring and using volumes, see :ref:`Volumes on Kubernetes `.
.. _kubernetes-setup-priority:
@@ -394,33 +394,27 @@ Below, we provide tips on how to monitor SkyPilot resources on your Kubernetes c
List SkyPilot resources across all users
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-We provide a convenience command, :code:`sky status --k8s`, to view the status of all SkyPilot resources in the cluster.
+When using the :ref:`SkyPilot API server `, you can use the :ref:`SkyPilot dashboard ` to view SkyPilot resources across all users:
-Unlike :code:`sky status` which lists only the SkyPilot resources launched by the current user,
-:code:`sky status --k8s` lists all SkyPilot resources in the cluster across all users.
-.. code-block:: console
+.. image:: ../../images/dashboard-clusters.png
+ :width: 100%
+ :align: center
+ :alt: SkyPilot Dashboard
- $ sky status --k8s
- Kubernetes cluster state (context: mycluster)
- SkyPilot clusters
- USER NAME LAUNCHED RESOURCES STATUS
- alice infer-svc-1 23 hrs ago 1x Kubernetes(cpus=1, mem=1, {'L4': 1}) UP
- alice sky-jobs-controller-80b50983 2 days ago 1x Kubernetes(cpus=4, mem=4) UP
- alice sky-serve-controller-80b50983 23 hrs ago 1x Kubernetes(cpus=4, mem=4) UP
- bob dev 1 day ago 1x Kubernetes(cpus=2, mem=8, {'H100': 1}) UP
- bob multinode-dev 1 day ago 2x Kubernetes(cpus=2, mem=2) UP
- bob sky-jobs-controller-2ea485ea 2 days ago 1x Kubernetes(cpus=4, mem=4) UP
-
- Managed jobs
- In progress tasks: 1 STARTING
- USER ID TASK NAME RESOURCES SUBMITTED TOT. DURATION JOB DURATION #RECOVERIES STATUS
- alice 1 - eval 1x[CPU:1+] 2 days ago 49s 8s 0 SUCCEEDED
- bob 4 - pretrain 1x[H100:4] 1 day ago 1h 1m 11s 1h 14s 0 SUCCEEDED
- bob 3 - bigjob 1x[CPU:16] 1 day ago 1d 21h 11m 4s - 0 STARTING
- bob 2 - failjob 1x[CPU:1+] 1 day ago 54s 9s 0 FAILED
- bob 1 - shortjob 1x[CPU:1+] 2 days ago 1h 1m 19s 1h 16s 0 SUCCEEDED
+|
+
+Or run :code:`sky status -u`:
+
+.. code-block:: console
+
+ $ sky status -u
+ Clusters
+ NAME USER WORKSPACE INFRA RESOURCES STATUS AUTOSTOP LAUNCHED
+ training-multinode alice@skypilot.co ml-team Kubernetes (nebius) 2x(gpus=H100:8, cpus=200, mem=800, ...) RUNNING 60m 5d ago
+ dev-alice alice@skypilot.co research-private Kubernetes (coreweave) 1x(gpus=H200:1, cpus=8, mem=32, ...) RUNNING - 6d ago
+ inference mike@skypilot.co default AWS (us-west-2) 1x(gpus=L4:1, g6.2xlarge, ...) RUNNING 30m 4d ago
+ dev-bob bob@skypilot.co default GCP (us-west1) 1x(cpus=4, mem=15, n1-standard-4, ...) STOPPED - 6d ago
.. _kubernetes-observability-dashboard:
@@ -434,6 +428,7 @@ SkyPilot resources on your cluster.
:align: center
:alt: Kubernetes dashboard
+|
As a demo, we provide a sample Kubernetes dashboard deployment manifest that you can deploy with:
diff --git a/docs/source/reference/kubernetes/kubernetes-troubleshooting.rst b/docs/source/reference/kubernetes/kubernetes-troubleshooting.rst
index d7e2aeb62c5..54b5761c194 100644
--- a/docs/source/reference/kubernetes/kubernetes-troubleshooting.rst
+++ b/docs/source/reference/kubernetes/kubernetes-troubleshooting.rst
@@ -5,7 +5,7 @@ Kubernetes Troubleshooting
If you're unable to run SkyPilot tasks on your Kubernetes cluster, this guide will help you debug common issues.
-If this guide does not help resolve your issue, please reach out to us on `Slack `_ or `GitHub `_.
+If this guide does not help resolve your issue, please reach out to us on `Slack `_ or `GitHub `_.
.. _kubernetes-troubleshooting-basic:
diff --git a/docs/source/reference/slurm/slurm-getting-started.rst b/docs/source/reference/slurm/slurm-getting-started.rst
index 2e72503301a..b616bfdbbac 100644
--- a/docs/source/reference/slurm/slurm-getting-started.rst
+++ b/docs/source/reference/slurm/slurm-getting-started.rst
@@ -74,7 +74,9 @@ Create the configuration file:
.. note::
- ``HostName``, ``User``, and ``IdentityFile`` are required fields.
+ ``HostName`` and ``User`` are required fields. ``IdentityFile`` is optional;
+ if not specified, SSH will use keys from ssh-agent or default key locations
+ (e.g., ``~/.ssh/id_rsa``, ``~/.ssh/id_ed25519``).
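+
+ For example, a minimal sketch of an entry that relies on ssh-agent or default keys (the host alias and address are illustrative):
+
+ .. code-block:: text
+
+    Host mycluster
+      HostName 10.0.0.1
+      User ubuntu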
Verify your SSH connection works by running:
@@ -178,6 +180,63 @@ SkyPilot will translate this to the appropriate ``--gres=gpu:`` directive for Sl
Common names include ``H100``, ``H200``, ``L4`` etc.
+Viewing GPU availability
+------------------------
+
+SkyPilot provides a unified dashboard to monitor GPU availability and utilization across **all** your Slurm clusters.
+
+To open the dashboard:
+
+.. code-block:: console
+
+ $ sky dashboard
+
+Navigate to the **Infra** tab to see the real-time GPU availability across all your Slurm clusters:
+
+.. image:: /images/slurm-infra-page.png
+ :alt: SkyPilot Dashboard showing Slurm GPU availability overview
+ :width: 100%
+
+|
+
+Click on a cluster name to see detailed GPU availability per node:
+
+.. image:: /images/slurm-cluster-details-page.png
+ :alt: SkyPilot Dashboard showing Slurm cluster GPU details
+ :width: 100%
+
+|
+
+You can also view GPU availability from the CLI:
+
+.. code-block:: console
+
+ $ sky show-gpus --infra slurm
+ Slurm GPUs
+ GPU UTILIZATION
+ L40S 3 of 8 free
+ GH200 1 of 2 free
+ H100 8 of 8 free
+
+ Slurm Cluster: mycluster1
+ GPU REQUESTABLE_QTY_PER_NODE UTILIZATION
+ L40S 1, 2, 4 3 of 8 free
+
+ Slurm Cluster: mycluster2
+ GPU REQUESTABLE_QTY_PER_NODE UTILIZATION
+ GH200 1 1 of 2 free
+
+ Slurm Cluster: mycluster3
+ GPU REQUESTABLE_QTY_PER_NODE UTILIZATION
+ H100 1, 2, 4, 8 8 of 8 free
+
+ Slurm per node GPU availability
+ CLUSTER NODE PARTITION STATE GPU UTILIZATION
+ mycluster1 ip-10-3-132-97 dev*,gpus mix L40S 1 of 4 free
+ mycluster1 ip-10-3-168-59 dev*,gpus mix L40S 2 of 4 free
+ ...
+
+
Shared filesystem (NFS)
-----------------------
diff --git a/docs/source/reference/volumes.rst b/docs/source/reference/volumes.rst
index e9c50b6b811..6ddb9257ca7 100644
--- a/docs/source/reference/volumes.rst
+++ b/docs/source/reference/volumes.rst
@@ -11,70 +11,13 @@ Benefits of using volumes:
* **Data persistence**: Volumes can persist data independently of task life cycles, making it easy to share data between different tasks (e.g., datasets, caches) or preserve results.
* **Size control**: You can set volume size limits to manage costs and limit storage usage.
-SkyPilot supports creating and managing volumes directly through the ``sky`` CLI and the web dashboard.
+Volumes are currently supported on Kubernetes clusters and RunPod.
-Supported volume types:
-- Kubernetes: `Persistent Volume Claims (PVCs) `_
-
- - Tested storage backends: AWS EBS, GCP Persistent Disk, Nebius network SSD, JuiceFS, Nebius shared file system, GCP Filestore
-
-- RunPod: `Network Volumes `_
-
-With SSH node pools, you can mount host volumes or directories into SkyPilot clusters and managed jobs. See :ref:`SSH node pools ` for details.
-
-.. _volumes-on-kubernetes:
-
-Volumes on Kubernetes
----------------------
-
-In Kubernetes clusters, PVCs (Persistent Volume Claims) request and bind to PV (Persistent Volume) resources. These persistent volumes can be backed by various storage backends, including **block storage** solutions (AWS EBS, GCP Persistent Disk) and **distributed file systems** (JuiceFS, Nebius shared file system, AWS EFS, GCP Filestore), etc.
-
-SkyPilot supports two types of volumes on Kubernetes:
-
-1. **Persistent volumes**: Managed independently through CLI commands with lifecycle separate from clusters
-2. **Ephemeral volumes**: Bound to cluster lifecycle, automatically created and deleted with the cluster
-
-.. list-table::
- :widths: 30 35 35
- :header-rows: 1
-
- * - Feature
- - Persistent Volumes
- - Ephemeral Volumes
- * - Lifecycle
- - Independent (manually managed)
- - Bound to cluster
- * - Creation
- - ``sky volumes apply``
- - Automatic (in task YAML)
- * - Deletion
- - ``sky volumes delete``
- - Automatic (with cluster)
- * - Sharing across clusters
- - Yes
- - No (cluster-specific)
- * - Use case
- - Long-term data, shared datasets
- - Temporary storage, caches
-
-In addition to the above, you can also mount PVCs, NFS or hostPath with Kubernetes configs. See :ref:`advanced-mount-pvc-with-kubernetes-configs` and :ref:`advanced-mount-nfs-hostpath-with-kubernetes-configs` for details.
-
-Persistent volumes
-~~~~~~~~~~~~~~~~~~
-
-Persistent volumes are created and managed independently using the following commands:
-
-- ``sky volumes apply``: Create a new volume
-- ``sky volumes ls``: List all volumes
-- ``sky volumes delete``: Delete a volume
-
-.. note::
-
- Volumes are shared across users on a SkyPilot API server. A user can mount volumes created by other users. This is useful for sharing caches and data across users.
+.. _volumes-quickstart:
Quickstart
-^^^^^^^^^^
+----------
1. Prepare a volume YAML file:
@@ -83,17 +26,12 @@ Quickstart
# volume.yaml
name: new-pvc
type: k8s-pvc
- infra: kubernetes # or k8s or k8s/context
+ infra: k8s # or `k8s/context` or `runpod`
size: 10Gi
- # If the PVC already exists, set `use_existing` to true and
- # set the `name` to the existing PVC name
+
+ # Optional: To use an existing PVC on k8s instead of creating a new one, set to `true` and set `name` to the existing PVC name.
# use_existing: true
- labels:
- key: value
- config:
- namespace: default # optional
- storage_class_name: csi-mounted-fs-path-sc # optional
- access_mode: ReadWriteMany # optional
+
2. Create the volume with ``sky volumes apply volume.yaml``:
@@ -114,34 +52,23 @@ Quickstart
run: |
echo "Hello, World!" > /mnt/data/hello.txt
-.. note::
+.. tip::
- - For multi-node clusters, volumes are mounted to all nodes. You must configure ``config.access_mode`` to ``ReadWriteMany`` and use a ``storage_class_name`` that supports the ``ReadWriteMany`` access mode. Otherwise, SkyPilot will fail to launch the cluster.
- - If you want to mount a volume to all the cluster or jobs by default, you can use the admin policy to inject the volume path into the task YAML. See :ref:`add-volumes-policy` for details.
+ For temporary or cache data that should only last for the lifetime of a SkyPilot cluster, use :ref:`ephemeral volumes `.
.. _volumes-on-kubernetes-manage:
Managing volumes
-^^^^^^^^^^^^^^^^
+----------------
List all volumes with ``sky volumes ls``:
.. code-block:: console
- $ sky volumes ls
- NAME TYPE INFRA SIZE USER WORKSPACE AGE STATUS LAST_USE USED_BY
- new-pvc k8s-pvc Kubernetes/nebius-mk8s-vol 1Gi alice default 8m IN_USE
-
+ $ sky volumes ls -v
+ NAME TYPE INFRA SIZE USER WORKSPACE AGE STATUS LAST_USE USED_BY NAME_ON_CLOUD STORAGE_CLASS ACCESS_MODE
+ new-pvc k8s-pvc Kubernetes/nebius-mk8s-vol 1Gi alice default 8m IN_USE 2025-06-24 10:18:32 training new-pvc-73ec42f2-5c6c4e csi-mounted-fs-path-sc ReadWriteMany
-.. tip::
-
- Use ``-v`` to view detailed information about a volume.
-
- .. code-block:: console
-
- $ sky volumes ls -v
- NAME TYPE INFRA SIZE USER WORKSPACE AGE STATUS LAST_USE USED_BY NAME_ON_CLOUD STORAGE_CLASS ACCESS_MODE
- new-pvc k8s-pvc Kubernetes/nebius-mk8s-vol 1Gi alice default 8m IN_USE 2025-06-24 10:18:32 training new-pvc-73ec42f2-5c6c4e csi-mounted-fs-path-sc ReadWriteMany
Delete a volume with ``sky volumes delete``:
@@ -161,183 +88,100 @@ You can also check the volumes in the SkyPilot dashboard.
:align: center
:width: 80%
-Filesystem volume examples
-^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-This section demonstrates how to configure and use distributed filesystems as SkyPilot volumes. We'll cover options like `JuiceFS `_ (a cloud-native distributed filesystem) and `Nebius shared file system `_ (a high-performance shared storage solution).
-
-
-.. tab-set::
-
- .. tab-item:: JuiceFS
- :sync: juicefs-tab
-
- To use JuiceFS as a SkyPilot volume:
-
- 1. **Install the JuiceFS CSI driver** on your Kubernetes cluster. Follow the official `installation guide `_ for detailed instructions.
-
- 2. **Verify the driver installation** - Confirm that the JuiceFS CSI Driver pods are running:
-
- .. code-block:: console
-
- $ kubectl -n kube-system get pod -l app.kubernetes.io/name=juicefs-csi-driver
- NAME READY STATUS RESTARTS AGE
- juicefs-csi-controller-0 2/2 Running 0 10m
- juicefs-csi-node-8rd96 3/3 Running 0 10m
-
- 3. **Set up JuiceFS storage and create a SkyPilot volume** - You can use either dynamic provisioning (with a StorageClass) or static provisioning (with a pre-created PV):
- .. tab-set::
-
- .. tab-item:: Dynamic Provisioning (StorageClass)
- :sync: dynamic-tab
-
- Create a StorageClass for dynamic provisioning. Refer to the `JuiceFS StorageClass guide `_ for details.
-
- .. code-block:: console
-
- $ kubectl get storageclass juicefs-sc
- NAME PROVISIONER RECLAIMPOLICY VOLUMEBINDINGMODE ALLOWVOLUMEEXPANSION AGE
- juicefs-sc csi.juicefs.com Retain Immediate false 10m
-
- Create a SkyPilot volume YAML referencing the StorageClass:
-
- .. code-block:: yaml
-
- # juicefs-volume.yaml
- name: juicefs-volume
- type: k8s-pvc
- infra: k8s
- size: 100Gi
- config:
- storage_class_name: juicefs-sc
- access_mode: ReadWriteMany
-
- .. code-block:: console
-
- $ sky volumes apply juicefs-volume.yaml
-
- .. tab-item:: Static Provisioning (PV)
- :sync: static-tab
-
- Create a PersistentVolume and PVC manually. Refer to the `JuiceFS static provisioning guide `_ for details.
-
- .. code-block:: console
-
- $ kubectl get pv juicefs-pv
- NAME CAPACITY ACCESS MODES RECLAIM POLICY STATUS CLAIM STORAGECLASS AGE
- juicefs-pv 100Gi RWX Retain Bound default/juicefs-pvc 10m
-
- $ kubectl get pvc juicefs-pvc
- NAME STATUS VOLUME CAPACITY ACCESS MODES STORAGECLASS AGE
- juicefs-pvc Bound juicefs-pv 100Gi RWX 10m
-
- Create a SkyPilot volume YAML with ``use_existing: true`` to reference the existing PVC:
-
- .. code-block:: yaml
-
- # juicefs-volume.yaml
- name: juicefs-volume
- type: k8s-pvc
- infra: k8s
- use_existing: true
- config:
- access_mode: ReadWriteMany
-
- .. code-block:: console
-
- $ sky volumes apply juicefs-volume.yaml
-
- 4. **Mount the volume to SkyPilot task** in your SkyPilot YAML:
-
- .. code-block:: yaml
-
- # task.yaml
- num_nodes: 2
-
- volumes:
- # Mount the JuiceFS volume to /mnt/data across all nodes
- /mnt/data: juicefs-volume
+.. _volumes-on-kubernetes:
- run: |
- # Verify the volume is mounted and accessible
- df -h /mnt/data
- ls -la /mnt/data
+Volumes on Kubernetes
+---------------------
- .. code-block:: console
+In Kubernetes clusters, SkyPilot volumes map to `PVCs (Persistent Volume Claims) `_.
- # Launch the cluster with the JuiceFS volume
- $ sky launch -c juicefs-cluster task.yaml
+PVCs can be backed by various storage backends, including **block storage** solutions (AWS EBS, GCP Persistent Disk) and **distributed file systems** (JuiceFS, Nebius shared file system, AWS EFS, GCP Filestore).
- .. tab-item:: Nebius shared file system
- :sync: nebius-tab
+SkyPilot volumes come in two types:
- To use Nebius shared file system as a SkyPilot volume:
+1. :ref:`Persistent volumes `: Managed through ``sky volumes`` CLI commands with lifecycle separate from SkyPilot clusters.
+2. :ref:`Ephemeral volumes `: Bound to the SkyPilot cluster lifecycle; automatically created by ``sky launch`` and deleted by ``sky down``.
- 1. **Set up the Nebius filesystem infrastructure** by following the official documentation:
+.. list-table::
+ :widths: 35 35 30
+ :header-rows: 1
- - `Create a shared filesystem `_
- - `Create a node group and mount the filesystem `_
- - `Install the CSI driver `_
+ * - Feature
+ - :ref:`Persistent volumes `
+ - :ref:`Ephemeral volumes `
+ * - Lifecycle
+ - Independent (managed via ``sky volumes``)
+ - Bound to SkyPilot cluster
+ * - Creation
+ - ``sky volumes apply``
+ - Automatic (in task YAML)
+ * - Deletion
+ - ``sky volumes delete``
+ - Automatic (with cluster)
+ * - Sharing across SkyPilot clusters
+ - Yes
+ - No (cluster-specific)
+ * - Use case
+ - Long-term data, code, shared datasets
+ - Temporary storage, caches
- 2. **Verify the storage class** - Confirm that the ``csi-mounted-fs-path-sc`` storage class has been created:
+.. tip::
- .. code-block:: console
+ For advanced use cases, you can also mount PVCs, NFS, or hostPath volumes by overriding SkyPilot's pod configs.
+ See :ref:`advanced-mount-pvc-with-kubernetes-configs` for details.
- $ kubectl get storageclass
- NAME PROVISIONER RECLAIMPOLICY VOLUMEBINDINGMODE ALLOWVOLUMEEXPANSION AGE
- csi-mounted-fs-path-sc mounted-fs-path.csi.nebius.ai Delete WaitForFirstConsumer false 10m
+.. _persistent-volumes:
- 3. **Create a SkyPilot volume for Nebius file system** with a volume YAML:
+Persistent volumes
+~~~~~~~~~~~~~~~~~~
- .. code-block:: yaml
+Persistent volumes are created and managed independently using the ``sky volumes`` CLI commands described in the :ref:`Quickstart ` and :ref:`Managing volumes ` sections above.
- # nebius-volume.yaml
- name: nebius-pvc
- type: k8s-pvc
- infra: k8s
- size: 100Gi
- config:
- storage_class_name: csi-mounted-fs-path-sc
- access_mode: ReadWriteMany
+.. note::
- .. code-block:: console
+ Persistent volumes are shared across users on a SkyPilot API server. A user can mount volumes created by other users. This is useful for sharing caches and data across users.
- $ sky volumes apply nebius-volume.yaml
+**Volume YAML configuration options:**
- 4. **Mount the volume to SkyPilot task** in your SkyPilot YAML:
+.. code-block:: yaml
- .. code-block:: yaml
+ # volume.yaml
+ name: my-volume
+ type: k8s-pvc
+ infra: k8s # or k8s/
+ size: 10Gi
- # task.yaml
- num_nodes: 2
+ # Optional: To use an existing PVC instead of creating a new one, set to `true` and set `name` to the existing PVC name.
+ use_existing: true
- volumes:
- # Mount the Nebius shared filesystem to /mnt/data across all nodes
- /mnt/data: nebius-pvc
+ # Optional: add labels to the PVC
+ labels:
+ key: value
- run: |
- # Verify the volume is mounted and accessible
- df -h /mnt/data
- ls -la /mnt/data
+ # Optional: additional configuration
+ config:
+ namespace: default
+ storage_class_name: csi-mounted-fs-path-sc
+ access_mode: ReadWriteMany # Required for multi-node clusters
- .. code-block:: console
+.. note::
- # Launch the cluster with the Nebius volume
- $ sky launch -c nebius-cluster task.yaml
+ - For multi-node clusters, volumes are mounted to all nodes. You must set ``config.access_mode`` to ``ReadWriteMany`` and use a ``storage_class_name`` that supports this access mode. Otherwise, SkyPilot will fail to launch the cluster.
+ - To mount a volume to all clusters or jobs by default, use the admin policy to inject the volume path into the task YAML. See :ref:`add-volumes-policy` for details.
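+
+For example, here is a minimal sketch of a multi-node task mounting a persistent volume (the volume name ``my-volume`` is illustrative; it must already have been created with ``sky volumes apply``):
+
+.. code-block:: yaml
+
+   # task.yaml
+   num_nodes: 2
+
+   volumes:
+     # Mounted at /mnt/data on every node; requires ReadWriteMany
+     /mnt/data: my-volume
+
+   run: |
+     df -h /mnt/data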
+.. _ephemeral-volumes:
Ephemeral volumes
~~~~~~~~~~~~~~~~~
-Unlike persistent volumes that are managed independently, ephemeral volumes are automatically created when a cluster is launched and deleted when the cluster is terminated. This makes them ideal for temporary storage needs such as caches, intermediate results, or any data that should only exist for the duration of a cluster's lifetime.
-
-**Key characteristics:**
+Unlike persistent volumes, which are managed independently via the ``sky volumes`` CLI commands, ephemeral volumes are created automatically when a cluster is launched via ``sky launch`` and deleted when the cluster is terminated via ``sky down`` or autodown.
- **Automatic lifecycle management**: No need to manually create or delete volumes
- **Cluster-bound**: Created with the cluster and deleted when the cluster is terminated
- **Simplified usage**: Defined directly in the task YAML with the cluster configuration
-- **Currently Kubernetes-only**: Only supported on Kubernetes clusters
+- **Ideal for temporary storage**: caches, intermediate results, or any data that should only exist for the duration of a cluster's lifetime
+
To use an ephemeral volume, simply specify the ``size`` field in the volumes section of your task YAML:
@@ -385,93 +229,18 @@ When you terminate the cluster, the ephemeral volumes are automatically deleted:
# Cluster and its ephemeral volumes are deleted
.. _advanced-mount-pvc-with-kubernetes-configs:
-
-Advanced: Mount PVCs with Kubernetes configs
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-Using SkyPilot volumes allows you to mount different volumes to different tasks. SkyPilot also offers an advanced way to mount a Kubernetes PVC with the detailed Kubernetes configs. This allows you to:
-
-1. Mount a PVC with additional configurations that is not supported by SkyPilot volumes.
-
-2. Specify a global (per Kubernetes context) PVC to be mounted on all SkyPilot clusters.
-
-Mount a PVC with additional configuration
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-To mount a PVC with additional configuration, you can set the ``kubernetes.pod_config`` in the :ref:`advanced config `:
-
-.. code-block:: yaml
-
- kubernetes:
- pod_config:
- spec:
- securityContext:
- fsGroup: 1000
- fsGroupChangePolicy: OnRootMismatch
- containers:
- - volumeMounts:
- - mountPath: /mnt/data
- name: my-pvc
- volumes:
- - name: my-pvc
- persistentVolumeClaim:
- claimName: my-pvc
-
-.. note::
-
- The ``kubernetes.pod_config`` in the advanced config applies to every cluster launched on Kubernetes. To mount different PVCs per cluster, set the ``kubernetes.pod_config`` in the task YAML file as described in the :ref:`per-task configuration `. Refer to Kubernetes `volume mounts `_ and `volumes `_ documentation for more details.
-
-Mount a PVC to all clusters in each context
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-If you want to mount different PVCs for different Kubernetes contexts, you can set the ``allowed_contexts`` and ``context_configs`` in the :ref:`advanced config `.
-
-.. code-block:: yaml
-
- kubernetes:
- allowed_contexts:
- - context1
- - context2
- context_configs:
- context1:
- pod_config:
- spec:
- securityContext:
- fsGroup: 1000
- fsGroupChangePolicy: OnRootMismatch
- containers:
- - volumeMounts:
- - mountPath: /mnt/data
- name: my-pvc
- volumes:
- - name: my-pvc
- persistentVolumeClaim:
- claimName: pvc1
- context2:
- pod_config:
- spec:
- securityContext:
- fsGroup: 1000
- fsGroupChangePolicy: OnRootMismatch
- containers:
- - volumeMounts:
- - mountPath: /mnt/data
- name: my-pvc
- volumes:
- - name: my-pvc
- persistentVolumeClaim:
- claimName: pvc2
-
.. _advanced-mount-nfs-hostpath-with-kubernetes-configs:
-Advanced: Mount NFS or hostPath with Kubernetes configs
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Advanced: Use Kubernetes configs to mount PVCs, NFS, or hostPath
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-`Kubernetes volumes `_ can be attached to your SkyPilot pods using the :ref:`pod_config ` field. This is useful for accessing shared storage such as NFS or local high-performance storage like NVMe drives.
+In addition to using SkyPilot volumes, you can also mount `Kubernetes volumes `_ (PVCs, NFS, hostPath) by overriding SkyPilot's :ref:`pod_config `. This is useful for:
-Volume mounting can be done directly in the task YAML on a per-task basis, or globally for all tasks in `SkyPilot config `_.
+1. Mounting a PVC with additional configurations not supported by SkyPilot volumes (e.g., ``fsGroup``, ``fsGroupChangePolicy``).
+2. Specifying a global (per Kubernetes context) volume to be mounted on all SkyPilot clusters.
+3. Accessing shared storage such as NFS or local high-performance storage like NVMe drives.
-Examples:
+Volume mounting can be done directly in the task YAML on a per-task basis, or globally for all tasks in `SkyPilot config `_.
.. tab-set::
@@ -612,8 +381,102 @@ Examples:
path: /path/on/host/nvme
type: Directory
+ .. tab-item:: PVC
+ :name: kubernetes-volumes-pvc
+
+ Mount a PVC with additional configurations like ``fsGroup`` and ``fsGroupChangePolicy``.
+
+ **Per-task configuration:**
+
+ .. code-block:: yaml
+
+ # task.yaml
+ run: |
+ echo "Hello, world!" > /mnt/data/hello.txt
+ ls -la /mnt/data
+ config:
+ kubernetes:
+ pod_config:
+ spec:
+ securityContext:
+ fsGroup: 1000
+ fsGroupChangePolicy: OnRootMismatch
+ containers:
+ - volumeMounts:
+ - mountPath: /mnt/data
+ name: my-pvc
+ volumes:
+ - name: my-pvc
+ persistentVolumeClaim:
+ claimName: my-pvc
+
+ **Global configuration:**
+
+ .. code-block:: yaml
+
+ # SkyPilot config
+ kubernetes:
+ pod_config:
+ spec:
+ securityContext:
+ fsGroup: 1000
+ fsGroupChangePolicy: OnRootMismatch
+ containers:
+ - volumeMounts:
+ - mountPath: /mnt/data
+ name: my-pvc
+ volumes:
+ - name: my-pvc
+ persistentVolumeClaim:
+ claimName: my-pvc
+
+ **Mount different PVCs per context:**
+
+ If you want to mount different PVCs for different Kubernetes contexts, you can set the ``allowed_contexts`` and ``context_configs`` in the :ref:`advanced config `.
+
+ .. code-block:: yaml
+
+ # SkyPilot config
+ kubernetes:
+ allowed_contexts:
+ - context1
+ - context2
+ context_configs:
+ context1:
+ pod_config:
+ spec:
+ securityContext:
+ fsGroup: 1000
+ fsGroupChangePolicy: OnRootMismatch
+ containers:
+ - volumeMounts:
+ - mountPath: /mnt/data
+ name: my-pvc
+ volumes:
+ - name: my-pvc
+ persistentVolumeClaim:
+ claimName: pvc1
+ context2:
+ pod_config:
+ spec:
+ securityContext:
+ fsGroup: 1000
+ fsGroupChangePolicy: OnRootMismatch
+ containers:
+ - volumeMounts:
+ - mountPath: /mnt/data
+ name: my-pvc
+ volumes:
+ - name: my-pvc
+ persistentVolumeClaim:
+ claimName: pvc2
+
+ .. note::
+
+ The ``kubernetes.pod_config`` in the advanced config applies to every cluster launched on Kubernetes. To mount different PVCs per cluster, set the ``kubernetes.pod_config`` in the task YAML file as described in the :ref:`per-task configuration `. Refer to Kubernetes `volume mounts `_ and `volumes `_ documentation for more details.
+
.. tab-item:: Nebius shared filesystem
- :name: kubernetes-volumes-nebius-shared-filesystem
+ :name: primitives-volumes-nebius-vm-hostpath
When creating a node group on the Nebius console, attach your desired shared file system to the node group (``Create Node Group`` -> ``Attach shared filesystem``):
@@ -666,12 +529,183 @@ Examples:
path: /mnt/ # e.g. /mnt/filesystem-d0
type: Directory
+
.. note::
When using `hostPath volumes `_, the specified paths must already exist on the Kubernetes node where the pod is scheduled.
For NFS mounts using hostPath, ensure the NFS mount is already configured on all Kubernetes nodes.
+Advanced: Installing additional storage backends
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+SkyPilot volumes work with any Kubernetes StorageClass already available in your cluster. If your cluster doesn't have a StorageClass that meets your needs, you can optionally install one.
+
+Below are example configurations for setting up shared filesystems like JuiceFS or Nebius Shared Filesystem as SkyPilot volumes. Any storage backend that provides a Kubernetes StorageClass will work.
+
+.. dropdown:: Installing additional storage backends - JuiceFS, Nebius Shared Filesystem
+ :animate: fade-in
+
+ .. tab-set::
+
+ .. tab-item:: JuiceFS
+ :sync: juicefs-tab
+
+ To use `JuiceFS `_ as a SkyPilot volume:
+
+ 1. **Install the JuiceFS CSI driver** on your Kubernetes cluster. Follow the official `installation guide `_ for detailed instructions.
+
+ 2. **Verify the driver installation** - Confirm that the JuiceFS CSI Driver pods are running:
+
+ .. code-block:: console
+
+ $ kubectl -n kube-system get pod -l app.kubernetes.io/name=juicefs-csi-driver
+ NAME READY STATUS RESTARTS AGE
+ juicefs-csi-controller-0 2/2 Running 0 10m
+ juicefs-csi-node-8rd96 3/3 Running 0 10m
+
+ 3. **Set up JuiceFS storage and create a SkyPilot volume** - You can use either dynamic provisioning (with a StorageClass) or static provisioning (with a pre-created PV):
+
+ .. tab-set::
+
+ .. tab-item:: Dynamic Provisioning (StorageClass)
+ :sync: dynamic-tab
+
+ Create a StorageClass for dynamic provisioning. Refer to the `JuiceFS StorageClass guide `_ for details.
+
+ .. code-block:: console
+
+ $ kubectl get storageclass juicefs-sc
+ NAME PROVISIONER RECLAIMPOLICY VOLUMEBINDINGMODE ALLOWVOLUMEEXPANSION AGE
+ juicefs-sc csi.juicefs.com Retain Immediate false 10m
+
+ Create a SkyPilot volume YAML referencing the StorageClass:
+
+ .. code-block:: yaml
+
+ # juicefs-volume.yaml
+ name: juicefs-volume
+ type: k8s-pvc
+ infra: k8s
+ size: 100Gi
+ config:
+ storage_class_name: juicefs-sc
+ access_mode: ReadWriteMany
+
+ .. code-block:: console
+
+ $ sky volumes apply juicefs-volume.yaml
+
+ .. tab-item:: Static Provisioning (PV)
+ :sync: static-tab
+
+ Create a PersistentVolume and PVC manually. Refer to the `JuiceFS static provisioning guide `_ for details.
+
+ .. code-block:: console
+
+ $ kubectl get pv juicefs-pv
+ NAME CAPACITY ACCESS MODES RECLAIM POLICY STATUS CLAIM STORAGECLASS AGE
+ juicefs-pv 100Gi RWX Retain Bound default/juicefs-pvc 10m
+
+ $ kubectl get pvc juicefs-pvc
+ NAME STATUS VOLUME CAPACITY ACCESS MODES STORAGECLASS AGE
+ juicefs-pvc Bound juicefs-pv 100Gi RWX 10m
+
+ Create a SkyPilot volume YAML with ``use_existing: true`` to reference the existing PVC:
+
+ .. code-block:: yaml
+
+ # juicefs-volume.yaml
+ name: juicefs-volume
+ type: k8s-pvc
+ infra: k8s
+ use_existing: true
+ config:
+ access_mode: ReadWriteMany
+
+ .. code-block:: console
+
+ $ sky volumes apply juicefs-volume.yaml
+
+ 4. **Mount the volume to SkyPilot task** in your SkyPilot YAML:
+
+ .. code-block:: yaml
+
+ # task.yaml
+ num_nodes: 2
+
+ volumes:
+ # Mount the JuiceFS volume to /mnt/data across all nodes
+ /mnt/data: juicefs-volume
+
+ run: |
+ # Verify the volume is mounted and accessible
+ df -h /mnt/data
+ ls -la /mnt/data
+
+ .. code-block:: console
+
+ # Launch the cluster with the JuiceFS volume
+ $ sky launch -c juicefs-cluster task.yaml
+
+ .. tab-item:: Nebius shared file system
+ :sync: nebius-tab
+
+ The following steps use the `Nebius shared file system `_ as a SkyPilot volume via the CSI driver. For a simpler setup, we recommend using the :ref:`hostPath-based method ` described above, which mounts the filesystem directly from the host without requiring a CSI driver.
+
+ 1. **Set up the Nebius filesystem infrastructure** by following the official documentation:
+
+ - `Create a shared filesystem `_
+ - `Create a node group and mount the filesystem `_
+ - `Install the CSI driver `_
+
+ 2. **Verify the storage class** - Confirm that the ``csi-mounted-fs-path-sc`` storage class has been created:
+
+ .. code-block:: console
+
+ $ kubectl get storageclass
+ NAME PROVISIONER RECLAIMPOLICY VOLUMEBINDINGMODE ALLOWVOLUMEEXPANSION AGE
+ csi-mounted-fs-path-sc mounted-fs-path.csi.nebius.ai Delete WaitForFirstConsumer false 10m
+
+ 3. **Create a SkyPilot volume for Nebius file system** with a volume YAML:
+
+ .. code-block:: yaml
+
+ # nebius-volume.yaml
+ name: nebius-pvc
+ type: k8s-pvc
+ infra: k8s
+ size: 100Gi
+ config:
+ storage_class_name: csi-mounted-fs-path-sc
+ access_mode: ReadWriteMany
+
+ .. code-block:: console
+
+ $ sky volumes apply nebius-volume.yaml
+
+ 4. **Mount the volume to SkyPilot task** in your SkyPilot YAML:
+
+ .. code-block:: yaml
+
+ # task.yaml
+ num_nodes: 2
+
+ volumes:
+ # Mount the Nebius shared filesystem to /mnt/data across all nodes
+ /mnt/data: nebius-pvc
+
+ run: |
+ # Verify the volume is mounted and accessible
+ df -h /mnt/data
+ ls -la /mnt/data
+
+ .. code-block:: console
+
+ # Launch the cluster with the Nebius volume
+ $ sky launch -c nebius-cluster task.yaml
+
+
.. _volumes-on-runpod:
Volumes on RunPod
@@ -728,3 +762,10 @@ Managing volumes
~~~~~~~~~~~~~~~~
Same as Kubernetes volumes, refer to :ref:`volumes-on-kubernetes-manage` for more details.
+
+.. _ssh-node-pool-volumes:
+
+Volumes on SSH node pools
+-------------------------
+
+With SSH node pools, you can mount host volumes or directories into SkyPilot clusters and managed jobs. See :ref:`Volumes on SSH node pools ` for details.
diff --git a/docs/source/reference/yaml-spec.rst b/docs/source/reference/yaml-spec.rst
index 181f1c0d6a9..9ab1e9132d9 100644
--- a/docs/source/reference/yaml-spec.rst
+++ b/docs/source/reference/yaml-spec.rst
@@ -49,6 +49,12 @@ Below is the configuration syntax and some example values. See details under ea
:ref:`autostop `:
idle_minutes: 10
wait_for: none
+ :ref:`hook `: |
+ cd my-code-base
+ git add .
+ git commit -m "Auto-commit before shutdown"
+ git push
+ hook_timeout: 300
:ref:`any_of `:
- infra: aws/us-west-2
@@ -270,6 +276,12 @@ Format:
- ``jobs_and_ssh`` (default): Wait for in‑progress jobs and SSH connections to finish
- ``jobs``: Only wait for in‑progress jobs
- ``none``: Wait for nothing; autostop right after ``idle_minutes``
+ - ``hook``: Optional script to execute before autostop. The script runs on the remote cluster before stopping or tearing down. If the hook fails, autostop will still proceed but a warning will be logged.
+
+ See :ref:`Autostop hooks ` for detailed explanation and examples.
+
+ - ``hook_timeout``: Timeout in seconds for hook execution (default: 3600 = 1 hour, minimum: 1).
+ If the hook exceeds this timeout, it will be terminated and autostop continues.
```` can be one of:
- ``m``: minutes (default if not specified)
@@ -317,6 +329,20 @@ OR
idle_minutes: 10
wait_for: none # Stop after 10 minutes, regardless of running jobs or SSH connections
+OR
+
+.. code-block:: yaml
+
+ resources:
+ autostop:
+ idle_minutes: 10
+ hook: |
+ cd my-code-base
+ git add .
+ git commit -m "Auto-commit before shutdown"
+ git push
+ hook_timeout: 300
+
.. _yaml-spec-resources-accelerators:
@@ -912,7 +938,7 @@ We can also specify the exit codes that should always trigger recovery, regardle
We can specify multiple exit codes:
-.. code-block:: yaml
+.. code-block:: yaml
resources:
job_recovery:
diff --git a/docs/source/running-jobs/environment-variables.rst b/docs/source/running-jobs/environment-variables.rst
index 74433cc50b9..cc3d217c01b 100644
--- a/docs/source/running-jobs/environment-variables.rst
+++ b/docs/source/running-jobs/environment-variables.rst
@@ -188,8 +188,8 @@ Environment variables for ``setup``
3.4.5.6
* - ``SKYPILOT_SETUP_NUM_GPUS_PER_NODE``
- Number of GPUs per node in the cluster.
-
- Note that GPUs may not be available at this stage. Do not assume
+
+ Note that GPUs may not be available at this stage. Do not assume
GPUs are available during setup.
- 1
@@ -214,6 +214,9 @@ Environment variables for ``setup``
)['cloud']
- {"cluster_name": "my-cluster-name", "cloud": "GCP", "region": "us-central1", "zone": "us-central1-a"}
+ * - ``SKYPILOT_USER``
+ - The username of the user who launched the job.
+ - alice
* - ``SKYPILOT_SERVE_REPLICA_ID``
- The ID of a replica within the service (starting from 1). Available only for a :ref:`service `'s replica task.
- 1
@@ -270,6 +273,9 @@ Environment variables for ``run``
os.environ['SKYPILOT_CLUSTER_INFO']
)['cloud']
- {"cluster_name": "my-cluster-name", "cloud": "GCP", "region": "us-central1", "zone": "us-central1-a"}
+ * - ``SKYPILOT_USER``
+ - The username of the user who launched the job.
+ - alice
* - ``SKYPILOT_SERVE_REPLICA_ID``
- The ID of a replica within the service (starting from 1). Available only for a :ref:`service `'s replica task.
- - 1
\ No newline at end of file
+ - 1
diff --git a/examples/airflow/README.md b/examples/airflow/README.md
index 69faf5728fc..92998eb5c5b 100644
--- a/examples/airflow/README.md
+++ b/examples/airflow/README.md
@@ -11,7 +11,7 @@ This example uses a remote SkyPilot API Server to manage shared state across inv
-**💡 Tip:** SkyPilot also supports defining and running pipelines without Airflow. Check out [Jobs Pipelines](https://skypilot.readthedocs.io/en/latest/examples/managed-jobs.html#job-pipelines) for more information.
+**💡 Tip:** SkyPilot also supports defining and running pipelines without Airflow. Check out [Jobs Pipelines](https://docs.skypilot.co/en/latest/examples/managed-jobs.html#job-pipelines) for more information.
## Why use SkyPilot with Airflow?
In AI workflows, **the transition from development to production is hard**.
@@ -28,7 +28,7 @@ production Airflow cluster. Behind the scenes, SkyPilot handles environment setu
Here's how you can use SkyPilot to take your dev workflows to production in Airflow:
1. **Define and test your workflow as SkyPilot tasks**.
- - Use `sky launch` and [Sky VSCode integration](https://skypilot.readthedocs.io/en/latest/examples/interactive-development.html#dev-vscode) to run, debug and iterate on your code.
+ - Use `sky launch` and [Sky VSCode integration](https://docs.skypilot.co/en/latest/examples/interactive-development.html#dev-vscode) to run, debug and iterate on your code.
2. **Orchestrate SkyPilot tasks in Airflow** by invoking `sky launch` on their YAMLs as a task in the Airflow DAG.
- Airflow does the scheduling, logging, and monitoring, while SkyPilot handles the infra setup and task execution.
@@ -78,7 +78,7 @@ The train and eval step can be run in a similar way:
sky launch -c train --env DATA_BUCKET_NAME= --env DATA_BUCKET_STORE_TYPE=s3 train.yaml
```
-Hint: You can use `ssh` and VSCode to [interactively develop](https://skypilot.readthedocs.io/en/latest/examples/interactive-development.html) and debug the tasks.
+Hint: You can use `ssh` and VSCode to [interactively develop](https://docs.skypilot.co/en/latest/examples/interactive-development.html) and debug the tasks.
Note: `eval` can be optionally run on the same cluster as `train` with `sky exec`.
diff --git a/examples/aws_efa/README.md b/examples/aws_efa/README.md
index 1a23e306a2d..747c7f79749 100644
--- a/examples/aws_efa/README.md
+++ b/examples/aws_efa/README.md
@@ -6,23 +6,15 @@ Elastic Fabric Adapter (EFA) is an AWS alternative to Nvidia infiniband that ena
### TL;DR: enable EFA with SkyPilot
-You can enable EFA on AWS HyperPod/EKS clusters with an simple additional setting in your SkyPilot YAML:
+You can enable EFA on AWS HyperPod/EKS clusters by simply adding ``network_tier: best`` to your resources specification:
```yaml
-config:
- kubernetes:
- pod_config:
- spec:
- containers:
- - resources:
- limits:
- vpc.amazonaws.com/efa: 4
- requests:
- vpc.amazonaws.com/efa: 4
+resources:
+ infra: k8s
+ accelerators: A100:8
+ network_tier: best
```
-
-
### Enable EFA with HyperPod/EKS
* On HyperPod (backed by EKS), EFA is enabled by default, and you don't need to do anything.
@@ -40,42 +32,15 @@ hyperpod-i-0da69b9076c7ff6a4 ml.p4d.24xlarge 8 4
...
```
-### Access HyperPod and run distributed job with SkyPilot
-
-To access HyperPod and run distributed job with SkyPilot, see the SkyPilot [HyperPod example](https://github.com/skypilot-org/skypilot/blob/master/examples/hyperpod-eks).
-
-#### Adding EFA configurations in SkyPilot YAML
-
-To enable EFA in SkyPilot YAML, you can specify the following section in the SkyPilot YAML:
-
-```yaml
-config:
- kubernetes:
- pod_config:
- spec:
- containers:
- - resources:
- limits:
- vpc.amazonaws.com/efa: 4
- requests:
- vpc.amazonaws.com/efa: 4
-```
-
-This section is important for EFA integration:
-
-- `config.kubernetes.pod_config`: Provides Kubernetes-specific pod configuration
-- `spec.containers[0].resources`: Defines resource requirements
- - `limits.vpc.amazonaws.com/efa: 4`: Limits the Pod to use 4 EFA devices
- - `requests.vpc.amazonaws.com/efa: 4`: Requests 4 EFA devices for the Pod
-
-
-The `vpc.amazonaws.com/efa` resource type is exposed by the AWS EFA device plugin in Kubernetes.
+The `vpc.amazonaws.com/efa` resource is exposed by the AWS EFA device plugin in Kubernetes.
To see how many EFA are available for each instance types that have EFA, see the [Network cards](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/using-eni.html#network-cards) list in the Amazon EC2 User Guide.
Check the following table for the GPU and EFA count mapping for AWS instance types:
| Instance Type | GPU Type | #EFA |
|---------------|----------|------|
+| p6-b300.48xlarge | B300:8 | 16 |
+| p6-b200.48xlarge | B200:8 | 8 |
| p4d.24xlarge | A100:8 | 4 |
| p4de.24xlarge | A100:8 | 4 |
| p5.48xlarge | H100:8 | 32 |
@@ -100,15 +65,13 @@ Check the following table for the GPU and EFA count mapping for AWS instance typ
| g6e.16xlarge | L40S:1 | 1 |
| g6e.24xlarge | L40S:4 | 2 |
| g6e.48xlarge | L40S:8 | 4 |
-
-
-Update the EFA number in the [`nccl_efa.yaml`](https://github.com/skypilot-org/skypilot/blob/master/examples/aws_efa/nccl_efa.yaml) for the GPUs you use.
+| gr6.8xlarge | L4:1 | 1 |
### Running NCCL test with EFA using SkyPilot
Check the [`nccl_efa.yaml`](https://github.com/skypilot-org/skypilot/blob/master/examples/aws_efa/nccl_efa.yaml) for the complete SkyPilot cluster yaml configurations.
-The `image_id` provides the environment setup for [NCCL](https://developer.nvidia.com/nccl) (NVIDIA Collective Communications Library) and EFA (Elastic Fabric Adapter).
+The image [public.ecr.aws/hpc-cloud/nccl-tests:latest](https://github.com/aws-samples/awsome-distributed-training/blob/main/micro-benchmarks/nccl-tests/nccl-tests.Dockerfile) provides the environment setup for [NCCL](https://developer.nvidia.com/nccl) (NVIDIA Collective Communications Library) and EFA (Elastic Fabric Adapter).
To run the NCCL test with EFA support:
@@ -123,10 +86,7 @@ SkyPilot will:
4. Output performance metrics showing the benefits of EFA for distributed training
> **NOTE:**
-> We can turn off EFA with `nccl_efa.yaml` by passing an env:
-> ```bash
-> sky launch -c efa --env USE_EFA=false nccl_efa.yaml
-> ```
+> We can turn off EFA in `nccl_efa.yaml` by commenting out `network_tier: best`.
#### Benchmark results
@@ -178,7 +138,7 @@ EFA provides much higher throughput than the traditional TCP transport. Enabling
## Using EFA on AWS VM
-For the instance types listed in the GPU and EFA count mapping table in the [Adding EFA configurations in SkyPilot YAML](#adding-efa-configurations-in-skypilot-yaml) section, the EFA can be enabled by setting `resources.network_tier: best` in the task YAML.
+For the instance types listed in the GPU and EFA count mapping table in the [Enable EFA with HyperPod/EKS](#enable-efa-with-hyperpodeks) section, the EFA can be enabled by setting `resources.network_tier: best` in the task YAML.
```yaml
resources:
diff --git a/examples/aws_efa/nccl_efa.yaml b/examples/aws_efa/nccl_efa.yaml
index 809bd0c4b9b..ab2937e597c 100644
--- a/examples/aws_efa/nccl_efa.yaml
+++ b/examples/aws_efa/nccl_efa.yaml
@@ -5,14 +5,11 @@ name: nccl-efa-eks
resources:
infra: k8s
accelerators: A100:8
- cpus: 90+
image_id: docker:public.ecr.aws/hpc-cloud/nccl-tests:latest
+ network_tier: best
num_nodes: 2
-envs:
- USE_EFA: "true"
-
run: |
if [ "${SKYPILOT_NODE_RANK}" == "0" ]; then
echo "Head node"
@@ -28,22 +25,6 @@ run: |
nodes=${nodes::-1}
echo "All nodes: ${nodes}"
- # Set environment variables
- export PATH=$PATH:/usr/local/cuda-12.2/bin:/opt/amazon/efa/bin:/usr/bin
- export LD_LIBRARY_PATH=/usr/local/cuda-12.2/lib64:/opt/amazon/openmpi/lib:/opt/nccl/build/lib:/opt/amazon/efa/lib:/opt/aws-ofi-nccl/install/lib:/usr/local/nvidia/lib:$LD_LIBRARY_PATH
- export NCCL_HOME=/opt/nccl
- export CUDA_HOME=/usr/local/cuda-12.2
- export NCCL_DEBUG=INFO
- export NCCL_BUFFSIZE=8388608
- export NCCL_P2P_NET_CHUNKSIZE=524288
- export NCCL_TUNER_PLUGIN=/opt/aws-ofi-nccl/install/lib/libnccl-ofi-tuner.so
-
- if [ "${USE_EFA}" == "true" ]; then
- export FI_PROVIDER="efa"
- else
- export FI_PROVIDER=""
- fi
-
/opt/amazon/openmpi/bin/mpirun \
--allow-run-as-root \
--tag-output \
@@ -51,13 +32,9 @@ run: |
-np $NP \
-N $SKYPILOT_NUM_GPUS_PER_NODE \
--bind-to none \
- -x FI_PROVIDER \
-x PATH \
-x LD_LIBRARY_PATH \
-x NCCL_DEBUG=INFO \
- -x NCCL_BUFFSIZE \
- -x NCCL_P2P_NET_CHUNKSIZE \
- -x NCCL_TUNER_PLUGIN \
--mca pml ^cm,ucx \
--mca btl tcp,self \
--mca btl_tcp_if_exclude lo,docker0,veth_def_agent \
@@ -72,14 +49,3 @@ run: |
else
echo "Worker nodes"
fi
-
-config:
- kubernetes:
- pod_config:
- spec:
- containers:
- - resources:
- limits:
- vpc.amazonaws.com/efa: 4
- requests:
- vpc.amazonaws.com/efa: 4
diff --git a/examples/distributed_ray_train/ray_train.yaml b/examples/distributed_ray_train/ray_train.yaml
index 0ba202b884d..9a9a9314bff 100644
--- a/examples/distributed_ray_train/ray_train.yaml
+++ b/examples/distributed_ray_train/ray_train.yaml
@@ -5,6 +5,9 @@
resources:
accelerators: L4:2
memory: 64+
+ # On Slurm, it is recommended to use a Docker image to avoid permission
+ # issues with /tmp: https://github.com/ray-project/ray/issues/3899
+ # image_id: docker:rayproject/ray:nightly-py39-gpu
num_nodes: 2
diff --git a/examples/hyperpod-eks/README.md b/examples/hyperpod-eks/README.md
index 8c11be05272..951b97fe32d 100644
--- a/examples/hyperpod-eks/README.md
+++ b/examples/hyperpod-eks/README.md
@@ -5,7 +5,7 @@ This example shows how to run SkyPilot on AWS SageMaker HyperPod with EKS.
## Prerequisites
- An existing SageMaker HyperPod with EKS (or you can create one with AWS [doc](https://catalog.workshops.aws/sagemaker-hyperpod-eks/en-US/00-setup/own-account/01-workshop-infra-script))
-- SkyPilot installed: [installation doc](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html)
+- SkyPilot installed: [installation doc](https://docs.skypilot.co/en/latest/getting-started/installation.html)
```bash
pip install skypilot-nightly[kubernetes]
```
diff --git a/examples/managed_spot_queued_resource.yaml b/examples/managed_spot_queued_resource.yaml
new file mode 100644
index 00000000000..9a6347a5872
--- /dev/null
+++ b/examples/managed_spot_queued_resource.yaml
@@ -0,0 +1,27 @@
+name: minimal
+
+resources:
+ use_spot: true
+ infra: gcp/us-central1/us-central1-a
+ accelerators: tpu-v5litepod-16:1
+ accelerator_args:
+ runtime_version: v2-alpha-tpuv5-lite
+ gcp_queued_resource: true
+
+setup: |
+ echo "running setup"
+ pip install tqdm
+ pip install jax[tpu]
+
+run: |
+ conda env list
+ python -u - << EOF
+ import time
+ import tqdm
+ import jax
+ print(jax.devices())
+
+ for i in tqdm.trange(240):
+ time.sleep(1)
+
+ EOF
diff --git a/examples/metrics/kube_prometheus_node_exporter_service_monitor.yaml b/examples/metrics/kube_prometheus_node_exporter_service_monitor.yaml
deleted file mode 100644
index 2298eb24440..00000000000
--- a/examples/metrics/kube_prometheus_node_exporter_service_monitor.yaml
+++ /dev/null
@@ -1,26 +0,0 @@
-apiVersion: monitoring.coreos.com/v1
-kind: ServiceMonitor
-metadata:
- annotations:
- meta.helm.sh/release-name: kube-prometheus
- meta.helm.sh/release-namespace: skypilot
- labels:
- app.kubernetes.io/instance: kube-prometheus
- app.kubernetes.io/managed-by: Helm
- app.kubernetes.io/name: node-exporter
- name: kube-prometheus-node-exporter
- namespace: skypilot
-spec:
- endpoints:
- - port: metrics
- relabelings:
- - sourceLabels: [__meta_kubernetes_pod_node_name]
- targetLabel: node
- jobLabel: jobLabel
- namespaceSelector:
- matchNames:
- - skypilot
- selector:
- matchLabels:
- app.kubernetes.io/instance: kube-prometheus
- app.kubernetes.io/name: node-exporter
diff --git a/examples/metrics/prometheus-values.yaml b/examples/metrics/prometheus-values.yaml
new file mode 100644
index 00000000000..f2299785f0e
--- /dev/null
+++ b/examples/metrics/prometheus-values.yaml
@@ -0,0 +1,16 @@
+server:
+ persistentVolume:
+ enabled: true
+ size: 50Gi
+ retention: "1000d"
+ retentionSize: "43GB"
+kube-state-metrics:
+ enabled: true
+ metricLabelsAllowlist:
+ - pods=[skypilot-cluster,skypilot-cluster-name]
+prometheus-node-exporter:
+ enabled: false
+prometheus-pushgateway:
+ enabled: false
+alertmanager:
+ enabled: false
diff --git a/examples/metrics/skypilot_prometheus_server_service.yaml b/examples/metrics/skypilot_prometheus_server_service.yaml
deleted file mode 100644
index 1af9d7712a8..00000000000
--- a/examples/metrics/skypilot_prometheus_server_service.yaml
+++ /dev/null
@@ -1,22 +0,0 @@
-apiVersion: v1
-kind: Service
-metadata:
- labels:
- app.kubernetes.io/component: prometheus
- name: skypilot-prometheus-server
- namespace: skypilot
-spec:
- internalTrafficPolicy: Cluster
- ipFamilies:
- - IPv4
- ipFamilyPolicy: SingleStack
- ports:
- - name: http
- port: 80
- protocol: TCP
- targetPort: 9090
- selector:
- app.kubernetes.io/component: prometheus
- app.kubernetes.io/name: prometheus
- sessionAffinity: None
- type: ClusterIP
diff --git a/examples/plugin/README.md b/examples/plugin/README.md
index f1b1e5c0f4d..dbe1d655673 100644
--- a/examples/plugin/README.md
+++ b/examples/plugin/README.md
@@ -1,9 +1,69 @@
# Example Plugins for SkyPilot API Server
-Usage:
+## Usage
```bash
$ pip install .
$ cp plugins.yaml ~/.sky/plugins.yaml
$ sky api stop; sky api start
```
+
+## Remote Controller Support
+
+Plugins can be automatically deployed to remote controllers (jobs controller, serve
+controller) by creating a separate `remote_plugins.yaml` file that specifies which
+plugins should be uploaded to controllers.
+
+### Setup
+
+1. Create `~/.sky/plugins.yaml` for API server plugins with `controller_wheel_path`:
+
+```yaml
+controller_wheel_path: dist
+
+plugins:
+- class: example_plugin.ExamplePlugin
+```
+
+2. Create `~/.sky/remote_plugins.yaml` for remote controller plugins:
+
+```yaml
+plugins:
+- class: example_plugin.ExamplePatchPlugin
+```
+
+When `remote_plugins.yaml` exists and contains plugins:
+1. All `.whl` files found in the directory specified in `controller_wheel_path` (in `plugins.yaml`) are uploaded to remote clusters via file mounts
+2. The wheels are installed in the SkyPilot runtime environment on the cluster
+3. The `remote_plugins.yaml` config is uploaded to the cluster (as `plugins.yaml`)
+
+This allows your plugins to run on both the API server (if specified in `plugins.yaml`) and on job/serve controllers (if specified in `remote_plugins.yaml`).
+
+**Note:** You must build the wheel files yourself before configuring them in `plugins.yaml`. All `.whl` files in the specified directory will be uploaded. For example:
+```bash
+python -m build # or python setup.py bdist_wheel
+# This typically creates wheel files in the dist/ directory
+```
+
+### Configuration
+
+The `plugins.yaml` schema supports the following top-level fields:
+
+- `controller_wheel_path` (optional): Path to a directory containing prebuilt plugin wheel files (.whl). All `.whl` files in this directory will be uploaded to controllers. If no `.whl` files are found in the directory, nothing will be uploaded.
+
+Both `plugins.yaml` and `remote_plugins.yaml` support the following fields per plugin entry:
+
+- `class` (required): The Python class path of the plugin (e.g., `module.ClassName`)
+- `parameters` (optional): Dictionary of parameters to pass to the plugin constructor
+
+### Environment Variables
+
+You can customize the paths to these configuration files using environment variables:
+
+- `SKYPILOT_SERVER_PLUGINS_CONFIG`: Path to `plugins.yaml` (default: `~/.sky/plugins.yaml`)
+- `SKYPILOT_SERVER_REMOTE_PLUGINS_CONFIG`: Path to `remote_plugins.yaml` (default: `~/.sky/remote_plugins.yaml`)
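+
+For example, to load these configs from a custom location (the paths below are illustrative), set the variables in the environment where the API server runs, then restart it:
+
+```bash
+export SKYPILOT_SERVER_PLUGINS_CONFIG=/etc/skypilot/plugins.yaml
+export SKYPILOT_SERVER_REMOTE_PLUGINS_CONFIG=/etc/skypilot/remote_plugins.yaml
+sky api stop; sky api start
+```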
diff --git a/examples/plugin/plugins.yaml b/examples/plugin/plugins.yaml
index 31b807daf29..2b0a58923e2 100644
--- a/examples/plugin/plugins.yaml
+++ b/examples/plugin/plugins.yaml
@@ -1,3 +1,8 @@
+# Path to a directory containing prebuilt wheel files (.whl) that will be uploaded
+# to remote clusters (jobs controller, serve controller).
+# All .whl files in this directory will be uploaded and installed.
+controller_wheel_path: dist
+
plugins:
- class: example_plugin.ExamplePlugin
- class: example_plugin.ExampleParameterizedPlugin
diff --git a/examples/plugin/remote_plugin.yaml b/examples/plugin/remote_plugin.yaml
new file mode 100644
index 00000000000..a072d06c14d
--- /dev/null
+++ b/examples/plugin/remote_plugin.yaml
@@ -0,0 +1,5 @@
+# Plugins specified here will be uploaded to remote controllers.
+# These plugins will be available on both the API server (if also in plugins.yaml)
+# and on remote controllers (jobs controller, serve controller).
+plugins:
+- class: example_plugin.ExamplePatchPlugin
diff --git a/examples/ray_basic/ray.yaml b/examples/ray_basic/ray.yaml
index 4d5041c7468..0c4fd6c07fc 100644
--- a/examples/ray_basic/ray.yaml
+++ b/examples/ray_basic/ray.yaml
@@ -6,6 +6,9 @@
resources:
cpus: 2+
+ # On Slurm, it is recommended to use a Docker image to avoid permission
+ # issues with /tmp: https://github.com/ray-project/ray/issues/3899
+ # image_id: docker:rayproject/ray:nightly-py39-cpu
num_nodes: 2
diff --git a/examples/redisvl-vector-search/README.md b/examples/redisvl-vector-search/README.md
index f5361613723..a7d1c8ab900 100644
--- a/examples/redisvl-vector-search/README.md
+++ b/examples/redisvl-vector-search/README.md
@@ -1,6 +1,6 @@
# RedisVL + SkyPilot: Vector Search at Scale
-Distributed vector search over [1M research papers](https://www.kaggle.com/datasets/nechbamohammed/research-papers-dataset) using [RedisVL](https://docs.redisvl.com/en/latest/) and [SkyPilot](https://skypilot.readthedocs.io/en/latest/).
+Distributed vector search over [1M research papers](https://www.kaggle.com/datasets/nechbamohammed/research-papers-dataset) using [RedisVL](https://docs.redisvl.com/en/latest/) and [SkyPilot](https://docs.skypilot.co/en/latest/).
📖 [Read the full blog post](https://blog.skypilot.co/redisvl-skypilot/).
diff --git a/examples/serve/nvidia-dynamo/README.md b/examples/serve/nvidia-dynamo/README.md
index ddf5b5e5cab..933025c503e 100644
--- a/examples/serve/nvidia-dynamo/README.md
+++ b/examples/serve/nvidia-dynamo/README.md
@@ -19,9 +19,17 @@ NVIDIA Dynamo is a high-performance inference framework designed for serving gen
- **Disaggregated Prefill & Decode**: Separates inference phases for optimal resource utilization
- **Dynamic GPU Scheduling**: Intelligent workload distribution across available GPUs
- **LLM-Aware Request Routing**: Smart routing based on model characteristics and cache states
-- **Accelerated Data Transfer**: High-performance data movement between nodes
+- **Accelerated Data Transfer**: High-performance data movement between nodes via NIXL
- **KV Cache Offloading**: Multi-tiered memory management for efficient cache utilization
+## Container Image
+
+These examples use the official NVIDIA Dynamo container images from NGC:
+- `nvcr.io/nvidia/ai-dynamo/sglang-runtime:0.7.1` - SGLang backend (used in these examples)
+- `nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.7.1` - vLLM backend (alternative)
+
+The NGC images are freely accessible and include all necessary dependencies (NATS, etcd, NIXL, etc.).
+
## Launching Nvidia Dynamo with SkyPilot
### Single-Node Example (`nvidia-dynamo.sky.yaml`)
@@ -29,6 +37,7 @@ NVIDIA Dynamo is a high-performance inference framework designed for serving gen
- ✅ **OpenAI-Compatible API**: Drop-in replacement for OpenAI endpoints
- ✅ **Basic Load Balancing**: Round-robin request distribution
- ✅ **Auto-Discovery**: Dynamic worker registration
+- ✅ **No etcd Required**: Uses file-based KV store for single-node simplicity
### Multi-Node Example (`nvidia-dynamo-multinode.sky.yaml`)
- ✅ **KV-Aware Routing**: Intelligent cache-based request routing (`--router-mode kv`)
@@ -36,6 +45,7 @@ NVIDIA Dynamo is a high-performance inference framework designed for serving gen
- ✅ **Data Parallel Attention**: DP=2 across nodes (`--enable-dp-attention`)
- ✅ **Tensor Parallelism**: TP=8 per node for large model support
- ✅ **Disaggregated Transfer**: NIXL backend for KV cache transfers
+- ✅ **Centralized Services**: NATS and etcd run on head node, workers connect automatically
**Model**: `Qwen/Qwen3-8B` (8B parameter reasoning model)
diff --git a/examples/serve/nvidia-dynamo/nvidia-dynamo-multinode.sky.yaml b/examples/serve/nvidia-dynamo/nvidia-dynamo-multinode.sky.yaml
index 9d10ebf6d92..d07c5ab5063 100644
--- a/examples/serve/nvidia-dynamo/nvidia-dynamo-multinode.sky.yaml
+++ b/examples/serve/nvidia-dynamo/nvidia-dynamo-multinode.sky.yaml
@@ -2,7 +2,7 @@
#
# Usage:
#
-# sky launch -c dynamo-multi nvidia-dynamo-multinode.sky.yaml
+# sky launch -c dynamo-multi nvidia-dynamo-multinode.sky.yaml
#
# This config uses 2 nodes with 8x H100 GPUs each for disaggregated serving.
# Optionally override the model:
@@ -10,8 +10,10 @@
# sky launch -c dynamo-multi nvidia-dynamo-multinode.sky.yaml --env MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct --env HF_TOKEN
resources:
- accelerators: H100:8
+ accelerators: {H100:8, H200:8}
ports: 8080
+ # Use the official NVIDIA Dynamo SGLang runtime image from NGC
+ image_id: docker:nvcr.io/nvidia/ai-dynamo/sglang-runtime:0.7.1
num_nodes: 2
@@ -20,16 +22,7 @@ envs:
DIST_INIT_PORT: 29500
HF_TOKEN: "" # needed if a model is gated in HF Hub. Pass the value with `--env HF_TOKEN`
-setup: |
- sudo usermod -aG docker $USER
- sudo chmod 666 /var/run/docker.sock
- uv pip install "ai-dynamo[sglang]==0.5.0" accelerate --system --prerelease=allow
- uv pip install "sglang[all]==0.5.2" --system --prerelease=allow
- curl -fsSL -o docker-compose.yml https://raw.githubusercontent.com/ai-dynamo/dynamo/v0.5.0/deploy/docker-compose.yml
- docker compose -f docker-compose.yml up -d
-
run: |
- export GLOO_SOCKET_IFNAME=$(ip -o -4 route show to default | awk '{print $5}')
HEAD_IP=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
TOTAL_GPUS=$((SKYPILOT_NUM_NODES * SKYPILOT_NUM_GPUS_PER_NODE))
@@ -38,11 +31,31 @@ run: |
TP_SIZE=$((TOTAL_GPUS / 2))
DP_SIZE=2
+ # Get the network interface for GLOO
+ export GLOO_SOCKET_IFNAME=$(ip -o -4 route show to default | awk '{print $5}')
+
if [ "${SKYPILOT_NODE_RANK}" == "0" ]; then
+ # Head node: Start NATS and etcd services
+ echo "Starting NATS and etcd on head node..."
+ nats-server -js &
+ etcd --listen-client-urls http://0.0.0.0:2379 \
+ --advertise-client-urls http://${HEAD_IP}:2379 \
+ --data-dir /tmp/etcd &
+ sleep 3
+
# Start frontend with KV-aware routing enabled
python -m dynamo.frontend --router-mode kv --http-port 8080 &
+ else
+ # Worker nodes: Wait for head node services to be ready
+ echo "Waiting for head node services..."
+ sleep 5
fi
+ # Set connection endpoints for NATS and etcd (all nodes connect to head)
+ export NATS_SERVER=nats://${HEAD_IP}:4222
+ export ETCD_ENDPOINTS=http://${HEAD_IP}:2379
+
+ # All nodes run SGLang workers
python -m dynamo.sglang \
--model-path $MODEL_NAME \
--tp $TP_SIZE \
@@ -57,4 +70,15 @@ run: |
--mem-fraction-static 0.82 \
--disaggregation-transfer-backend nixl \
--disaggregation-bootstrap-port 30001 \
- --page-size 16
\ No newline at end of file
+ --page-size 16
+
+# Kubernetes-specific configuration
+config:
+ kubernetes:
+ pod_config:
+ spec:
+ containers:
+ - securityContext:
+ # Run as root to allow SkyPilot to install necessary packages
+ runAsUser: 0
+ runAsGroup: 0
diff --git a/examples/serve/nvidia-dynamo/nvidia-dynamo.sky.yaml b/examples/serve/nvidia-dynamo/nvidia-dynamo.sky.yaml
index 44c0bb425d3..3bb455f581a 100644
--- a/examples/serve/nvidia-dynamo/nvidia-dynamo.sky.yaml
+++ b/examples/serve/nvidia-dynamo/nvidia-dynamo.sky.yaml
@@ -2,28 +2,40 @@
#
# Usage:
#
-# sky launch -c dynamo nvidia-dynamo.sky.yaml
+# sky launch -c dynamo nvidia-dynamo.sky.yaml
#
# Optionally override the model:
#
-# sky launch -c dynamo nvidia-dynamo.sky.yaml --env MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct --env HF_TOKEN
+# sky launch -c dynamo nvidia-dynamo.sky.yaml --env MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct --env HF_TOKEN
resources:
- accelerators: H100:1
+ accelerators: {H100:1, H200:1}
ports: 8080
+ # Use the official NVIDIA Dynamo SGLang runtime image from NGC
+ image_id: docker:nvcr.io/nvidia/ai-dynamo/sglang-runtime:0.7.1
envs:
MODEL_NAME: Qwen/Qwen3-8B
HF_TOKEN: "" # needed if a model is gated in HF Hub. Pass the value with `--env HF_TOKEN`
-setup: |
- sudo usermod -aG docker $USER
- sudo chmod 666 /var/run/docker.sock
+run: |
+ # Start NATS server with JetStream enabled (required for Dynamo messaging)
+ nats-server -js &
+ sleep 2
- uv pip install "ai-dynamo[sglang]==0.4.1" accelerate --system --prerelease=allow
- curl -fsSL -o docker-compose.yml https://raw.githubusercontent.com/ai-dynamo/dynamo/release/0.4.1/deploy/docker-compose.yml
- docker compose -f docker-compose.yml up -d
+ # Start the Dynamo frontend (HTTP server + router)
+ python -m dynamo.frontend --http-port 8080 --store-kv file &
-run: |
- python -m dynamo.frontend &
- python -m dynamo.sglang --model $MODEL_NAME
\ No newline at end of file
+ # Start the SGLang worker
+ python -m dynamo.sglang --model $MODEL_NAME --store-kv file
+
+# Kubernetes-specific configuration
+config:
+ kubernetes:
+ pod_config:
+ spec:
+ containers:
+ - securityContext:
+ # Run as root to allow SkyPilot to install necessary packages
+ runAsUser: 0
+ runAsGroup: 0
diff --git a/examples/streamlit/README.md b/examples/streamlit/README.md
index 29fbf080144..29a9885a037 100644
--- a/examples/streamlit/README.md
+++ b/examples/streamlit/README.md
@@ -76,5 +76,5 @@ resources:
## Learn more
-- [SkyPilot Documentation](https://skypilot.readthedocs.io/)
+- [SkyPilot Documentation](https://docs.skypilot.co/)
- [Streamlit Documentation](https://docs.streamlit.io/)
diff --git a/examples/together_infiniband/README.md b/examples/together_infiniband/README.md
new file mode 100644
index 00000000000..2d60464cfc7
--- /dev/null
+++ b/examples/together_infiniband/README.md
@@ -0,0 +1,54 @@
+# Using InfiniBand in Together AI with SkyPilot
+
+SkyPilot provides the `network_tier: best` configuration option that automatically enables InfiniBand support on Together AI Kubernetes clusters. This eliminates the need for manual configuration of security contexts and environment variables.
+
+## InfiniBand on Together AI Kubernetes clusters
+
+Simply add `network_tier: best` to your resources specification:
+
+```yaml
+resources:
+ infra: k8s
+ accelerators: H100:8
+ network_tier: best
+```
+
+This enables InfiniBand for inter-GPU communication, and SkyPilot automatically sets up the required environment variables for you.
+
+## Running NCCL test using SkyPilot
+
+See [`nccl_network_tier.yaml`](https://github.com/skypilot-org/skypilot/blob/master/examples/together_infiniband/nccl_network_tier.yaml) for the complete SkyPilot cluster YAML configuration.
+
+The `image_id` provides the environment setup for [NCCL](https://developer.nvidia.com/nccl) (NVIDIA Collective Communications Library).
+
+To run the NCCL test with InfiniBand support:
+
+```bash
+sky launch -c infiniband nccl_network_tier.yaml
+```
+
+SkyPilot will:
+1. Schedule the job on the Kubernetes cluster with required GPU nodes
+2. Launch Pods and execute the NCCL performance test
+3. Output performance metrics showing the benefits of InfiniBand for distributed training
+
+An example result is shown below:
+
+```
+# out-of-place in-place
+# size count type redop root time algbw busbw #wrong time algbw busbw #wrong
+# (B) (elements) (us) (GB/s) (GB/s) (us) (GB/s) (GB/s)
+ 536870912 134217728 float sum -1 2407.5 222.99 418.12 0 2380.3 225.55 422.90 0
+ 1073741824 268435456 float sum -1 4524.3 237.33 444.99 0 4531.6 236.95 444.28 0
+ 2147483648 536870912 float sum -1 8787.5 244.38 458.21 0 8780.7 244.57 458.56 0
+ 4294967296 1073741824 float sum -1 17327 247.88 464.77 0 17328 247.86 464.74 0
+ 8589934592 2147483648 float sum -1 34462 249.26 467.36 0 34482 249.11 467.08 0
+# Out of bounds values : 0 OK
+# Avg bus bandwidth : 451.101
+```
+
+> **NOTE:** To run NCCL tests without InfiniBand, you can launch a cluster with `nccl_no_ib.yaml`:
+>
+> ```bash
+> sky launch -c no-infiniband nccl_no_ib.yaml
+> ```
diff --git a/examples/together_infiniband/nccl_network_tier.yaml b/examples/together_infiniband/nccl_network_tier.yaml
new file mode 100644
index 00000000000..075c7673edb
--- /dev/null
+++ b/examples/together_infiniband/nccl_network_tier.yaml
@@ -0,0 +1,52 @@
+# This example is used to test the NCCL performance with
+# InfiniBand on Together AI Kubernetes cluster.
+name: nccl-network-tier
+
+resources:
+ infra: k8s
+ accelerators: H100:8
+ image_id: docker:nvcr.io/nvidia/pytorch:24.07-py3
+ network_tier: best
+
+num_nodes: 2
+
+run: |
+ if [ "${SKYPILOT_NODE_RANK}" == "0" ]; then
+ echo "Head node"
+
+    # Total number of processes: NP should equal the total number of GPUs in the cluster
+ NP=$(($SKYPILOT_NUM_GPUS_PER_NODE * $SKYPILOT_NUM_NODES))
+
+ # Append :${SKYPILOT_NUM_GPUS_PER_NODE} to each IP as slots
+ nodes=""
+ for ip in $SKYPILOT_NODE_IPS; do
+ nodes="${nodes}${ip}:${SKYPILOT_NUM_GPUS_PER_NODE},"
+ done
+ nodes=${nodes::-1}
+ echo "All nodes: ${nodes}"
+
+ mpirun \
+ --allow-run-as-root \
+ --tag-output \
+ -H $nodes \
+ -np $NP \
+ -N $SKYPILOT_NUM_GPUS_PER_NODE \
+ --bind-to none \
+ -x PATH \
+ -x LD_LIBRARY_PATH \
+ -x NCCL_DEBUG=INFO \
+ -x NCCL_IB_HCA \
+ -x UCX_NET_DEVICES \
+ -x SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \
+ -x NCCL_COLLNET_ENABLE=0 \
+ /usr/local/bin/all_reduce_perf_mpi \
+ -b 512M \
+ -e 8G \
+ -f 2 \
+ -g 1 \
+ -c 1 \
+ -w 5 \
+ -n 10
+ else
+ echo "Worker nodes"
+ fi
diff --git a/examples/together_infiniband/nccl_no_ib.yaml b/examples/together_infiniband/nccl_no_ib.yaml
new file mode 100644
index 00000000000..9f0bc40d773
--- /dev/null
+++ b/examples/together_infiniband/nccl_no_ib.yaml
@@ -0,0 +1,54 @@
+# This example is used to test the NCCL performance without
+# InfiniBand on Together AI Kubernetes cluster.
+name: nccl-no-ib
+
+resources:
+ infra: k8s
+ accelerators: H100:8
+ image_id: docker:nvcr.io/nvidia/pytorch:24.07-py3
+
+num_nodes: 2
+
+run: |
+ if [ "${SKYPILOT_NODE_RANK}" == "0" ]; then
+ echo "Head node"
+
+    # Total number of processes: NP should equal the total number of GPUs in the cluster
+ NP=$(($SKYPILOT_NUM_GPUS_PER_NODE * $SKYPILOT_NUM_NODES))
+
+ # Append :${SKYPILOT_NUM_GPUS_PER_NODE} to each IP as slots
+ nodes=""
+ for ip in $SKYPILOT_NODE_IPS; do
+ nodes="${nodes}${ip}:${SKYPILOT_NUM_GPUS_PER_NODE},"
+ done
+ nodes=${nodes::-1}
+ echo "All nodes: ${nodes}"
+
+ export NCCL_IB_HCA=""
+ export UCX_NET_DEVICES="eth0"
+
+ mpirun \
+ --allow-run-as-root \
+ --tag-output \
+ -H $nodes \
+ -np $NP \
+ -N $SKYPILOT_NUM_GPUS_PER_NODE \
+ --bind-to none \
+ -x PATH \
+ -x LD_LIBRARY_PATH \
+ -x NCCL_DEBUG=INFO \
+ -x NCCL_IB_HCA \
+ -x UCX_NET_DEVICES \
+ -x SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \
+ -x NCCL_COLLNET_ENABLE=0 \
+ /usr/local/bin/all_reduce_perf_mpi \
+ -b 512M \
+ -e 8G \
+ -f 2 \
+ -g 1 \
+ -c 1 \
+ -w 5 \
+ -n 10
+ else
+ echo "Worker nodes"
+ fi
diff --git a/format.sh b/format.sh
index 62ac9889900..ee3f67f45fe 100755
--- a/format.sh
+++ b/format.sh
@@ -129,10 +129,7 @@ isort --profile black -l 88 -m 3 "sky/skylet/providers/ibm"
# TODO(zhwu): When more of the codebase is typed properly, the mypy flags
# should be set to do a more stringent check.
echo 'SkyPilot mypy:'
-# Workaround for mypy 1.14.1 cache serialization bug that causes
-# "AssertionError: Internal error: unresolved placeholder type None"
-# Using --cache-dir=/dev/null disables cache writing to avoid the error
-mypy $(cat tests/mypy_files.txt) --cache-dir=/dev/null
+mypy $(cat tests/mypy_files.txt)
# Run Pylint
echo 'Sky Pylint:'
@@ -159,9 +156,9 @@ if ! npm -v || ! node -v; then
# Don't fail the script if npm or node is not installed
# because it's not required for all users
else
- npm --prefix sky/dashboard install
+ output=$(npm --prefix sky/dashboard install 2>&1) || { echo "$output"; exit 1; }
npm --prefix sky/dashboard run lint
- npm --prefix sky/dashboard run format
+ npm --prefix sky/dashboard run format -- --log-level warn
echo "SkyPilot Dashboard linting and formatting: Done"
echo
fi
diff --git a/llm/rl-post-training-jobgroup/README.md b/llm/rl-post-training-jobgroup/README.md
new file mode 100644
index 00000000000..eba5c84e7b8
--- /dev/null
+++ b/llm/rl-post-training-jobgroup/README.md
@@ -0,0 +1,175 @@
+# RL Post-Training with Job Groups
+
+This example demonstrates a distributed RL post-training architecture using SkyPilot job groups. It trains an LLM on mathematical reasoning tasks using GRPO (Group Relative Policy Optimization) with verifiable rewards.
+
+## Architecture
+
+The example consists of 5 task types that communicate over HTTP, with built-in load balancing for scaling inference:
+
+### Components
+
+1. **data-server** (auxiliary): FastAPI server that serves math prompts from the GSM8K dataset. Provides batches of problems with ground truth answers.
+
+2. **rollout-server** (auxiliary, x2): SGLang inference servers with native load balancing:
+ - Using `num_nodes: 2` creates two GPU instances for higher throughput
+   - Head node (rank 0) runs an SGLang server (port 30001) and the SGLang router (port 30000)
+ - SGLang router provides cache-aware load balancing for optimal KV cache reuse
+
+3. **reward-server** (auxiliary): Verifies mathematical answers by comparing model outputs against ground truth. Returns binary rewards (1.0 for correct, 0.0 for incorrect).
+
+4. **replay-buffer** (auxiliary): Stores experience tuples (prompt, response, reward) for sampling during training. Supports priority-based sampling where high-reward experiences are sampled more frequently.
+
+5. **ppo-trainer** (primary): Multi-node training orchestrator that implements GRPO. Coordinates with all other services to fetch prompts, generate responses, compute rewards, store experiences, and update the policy.
+
+### Primary/Auxiliary Tasks
+
+The ppo-trainer is designated as the **primary task**. When training completes:
+- All auxiliary services (data-server, rollout-server, reward-server, replay-buffer) are automatically terminated after a 10-second grace period (`termination_delay: 10s`)
+- This ensures GPU and CPU resources are released promptly once training finishes
+- Without this feature, auxiliary services would run indefinitely
+
+## Usage
+
+### Prerequisites
+
+- SkyPilot configured with a Kubernetes cluster
+- GPU nodes available (H100 recommended for optimal performance)
+
+### Launch Training
+
+```bash
+sky jobs launch llm/rl-post-training-jobgroup/rlhf-math-jobgroup.yaml
+```
+
+### Monitor Training
+
+```bash
+# Check job status
+sky jobs queue
+
+# View logs for specific components
+sky jobs logs data-server
+sky jobs logs rollout-server
+sky jobs logs reward-server
+sky jobs logs replay-buffer
+sky jobs logs ppo-trainer
+```
+
+Or use the SkyPilot dashboard to monitor jobs.
+
+## Configuration
+
+### Environment Variables
+
+| Variable | Default | Description |
+|----------|---------|-------------|
+| `MODEL_NAME` | `Qwen/Qwen2.5-0.5B-Instruct` | Model to train |
+| `NUM_EPOCHS` | `3` | Number of training epochs |
+| `BATCH_SIZE` | `4` | Training batch size |
+
+### Customizing Resources
+
+Edit the YAML to adjust resources per component:
+
+```yaml
+# For larger models, increase GPU memory
+resources:
+ accelerators: H100:1 # or A100:1
+ memory: 64+
+```
+
+## Service Discovery
+
+Components discover each other using job group DNS names:
+
+- `data-server-0.${SKYPILOT_JOBGROUP_NAME}:8000`
+- `rollout-server-0.${SKYPILOT_JOBGROUP_NAME}:30000` (SGLang router endpoint)
+- `rollout-server-0.${SKYPILOT_JOBGROUP_NAME}:30001` (SGLang backend 1)
+- `rollout-server-1.${SKYPILOT_JOBGROUP_NAME}:30001` (SGLang backend 2)
+- `reward-server-0.${SKYPILOT_JOBGROUP_NAME}:8002`
+- `replay-buffer-0.${SKYPILOT_JOBGROUP_NAME}:8003`
+
+This allows components to communicate without hardcoded IP addresses.
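+
+For example, a client can build these endpoints from the `SKYPILOT_JOBGROUP_NAME`
+environment variable, as the trainer does. A minimal sketch using `httpx`, with
+the endpoint and port taken from the list above:
+
+```python
+import os
+
+import httpx
+
+group = os.environ['SKYPILOT_JOBGROUP_NAME']
+data_server = f'http://data-server-0.{group}:8000'
+
+with httpx.Client(timeout=30.0) as client:
+    # Check the data server is healthy, then fetch a batch of prompts.
+    assert client.get(f'{data_server}/health').status_code == 200
+    batch = client.get(f'{data_server}/prompts',
+                       params={'batch_size': 4}).json()
+    print(len(batch['prompts']), 'prompts fetched')
+```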
+
+## Load Balancing
+
+The example uses [SGLang's native router](https://docs.sglang.ai/advanced_features/router.html) for load balancing:
+
+1. **Multiple rollout servers**: Using `num_nodes: 2` creates two SGLang instances
+2. **Head node router**: The head node (rank 0) runs an SGLang server (port 30001) and the SGLang router (port 30000)
+3. **Automatic discovery**: The router builds worker URLs from `SKYPILOT_NUM_NODES` and `SKYPILOT_JOBGROUP_NAME`
+4. **Transparent to clients**: The trainer only needs to know the head node endpoint
+
+### Scaling to More Servers
+
+To scale up, simply increase `num_nodes`:
+
+```yaml
+name: rollout-server
+num_nodes: 4 # Scale to 4 servers
+```
+
+The router on the head node automatically discovers all workers using:
+```bash
+for i in $(seq 0 $((SKYPILOT_NUM_NODES - 1))); do
+ WORKER_URLS="${WORKER_URLS} http://rollout-server-${i}.${SKYPILOT_JOBGROUP_NAME}:30001"
+done
+```
+
+### SGLang Router Features
+
+SGLang's native router (`sglang_router`) provides:
+- **Cache-aware routing**: Routes requests to maximize KV cache reuse
+- **Health checking**: Automatic failover away from unhealthy workers
+- **OpenAI-compatible API**: Transparent passthrough for clients
+- **Rust implementation**: Built for high performance
+
+## GRPO Algorithm
+
+GRPO (Group Relative Policy Optimization) is a simplified variant of PPO that:
+- Doesn't require a critic/value model
+- Uses group-relative advantages (compares rewards within a batch)
+- Works well with verifiable rewards (math, code)
+
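+The group-relative advantage is just a per-batch normalization of rewards, as
+implemented in `trainer.py` (minimal sketch):
+
+```python
+import torch
+
+# Binary verifiable rewards for one group of sampled responses.
+rewards = torch.tensor([1.0, 0.0, 1.0, 1.0])
+advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
+# The policy loss then weights response log-probs by these advantages:
+#   loss = -(response_log_probs * advantages).mean()
+```
+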
+The training loop:
+1. Fetch batch of prompts from data-server
+2. Generate responses using rollout-server
+3. Compute rewards using reward-server
+4. Store experiences in replay-buffer
+5. Calculate group-relative advantages
+6. Update policy with clipped surrogate loss
+7. Sample from replay-buffer for additional updates (experience replay)
+
+## Extending This Example
+
+### Using a Reward Model
+
+Replace the reward-server with a neural reward model:
+
+```python
+# In reward_server.py, load a reward model
+from transformers import AutoModelForSequenceClassification
+model = AutoModelForSequenceClassification.from_pretrained("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2")
+```
+
+### Scaling Up
+
+For larger models:
+1. Increase SGLang tensor parallelism
+2. Use multiple GPUs per trainer node
+3. Enable gradient checkpointing
+
+### Adding a Critic
+
+For full PPO, add a critic-server component that estimates value functions.
+
+## References
+
+- [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF) - Distributed RLHF framework
+- [VeRL](https://github.com/volcengine/verl) - Hybrid flow RLHF framework
+- [GRPO Paper](https://arxiv.org/abs/2402.03300) - Group Relative Policy Optimization
+- [GSM8K Dataset](https://huggingface.co/datasets/openai/gsm8k) - Math reasoning benchmark
diff --git a/llm/rl-post-training-jobgroup/code/data_server.py b/llm/rl-post-training-jobgroup/code/data_server.py
new file mode 100644
index 00000000000..a6d58ae876c
--- /dev/null
+++ b/llm/rl-post-training-jobgroup/code/data_server.py
@@ -0,0 +1,185 @@
+#!/usr/bin/env python3
+"""Data server for RLHF training - serves math prompts from GSM8K dataset.
+
+This server provides batches of math problems with their ground truth answers
+for training LLMs on mathematical reasoning tasks.
+
+Usage:
+ python data_server.py --port 8000
+"""
+
+import argparse
+import random
+from typing import List, Optional
+
+from fastapi import FastAPI
+from fastapi import Query
+from pydantic import BaseModel
+import uvicorn
+
+app = FastAPI(title="RLHF Data Server",
+ description="Serves math prompts for training")
+
+# Global state
+prompts_data: List[dict] = []
+current_index: int = 0
+
+
+class Prompt(BaseModel):
+ """A single prompt with its ground truth answer."""
+ id: int
+ prompt: str
+ ground_truth: str
+
+
+class PromptBatch(BaseModel):
+ """A batch of prompts."""
+ prompts: List[Prompt]
+ total_available: int
+
+
+def load_dataset():
+ """Load GSM8K dataset from HuggingFace."""
+ global prompts_data
+
+ try:
+ from datasets import load_dataset
+ print("Loading GSM8K dataset...")
+ dataset = load_dataset("openai/gsm8k", "main", split="train")
+
+ prompts_data = []
+ for i, item in enumerate(dataset):
+ # Extract the numerical answer from the solution
+ # GSM8K format: solution ends with "#### "
+ solution = item["answer"]
+ answer_marker = "####"
+ if answer_marker in solution:
+ ground_truth = solution.split(answer_marker)[-1].strip()
+ else:
+ ground_truth = solution.strip()
+
+ # Format prompt for instruction-following model
+ prompt = f"""Solve the following math problem step by step. End your solution with the final numerical answer.
+
+Problem: {item["question"]}
+
+Solution:"""
+
+ prompts_data.append({
+ "id": i,
+ "prompt": prompt,
+ "ground_truth": ground_truth
+ })
+
+ # Shuffle for training
+ random.shuffle(prompts_data)
+ print(f"Loaded {len(prompts_data)} prompts from GSM8K")
+
+ except Exception as e:
+ print(f"Error loading dataset: {e}")
+ # Fallback to simple math problems for testing
+ prompts_data = [
+ {
+ "id": 0,
+ "prompt": "What is 2 + 2?",
+ "ground_truth": "4"
+ },
+ {
+ "id": 1,
+ "prompt": "What is 10 * 5?",
+ "ground_truth": "50"
+ },
+ {
+ "id": 2,
+ "prompt": "What is 100 / 4?",
+ "ground_truth": "25"
+ },
+ {
+ "id": 3,
+ "prompt": "What is 7 + 8?",
+ "ground_truth": "15"
+ },
+ {
+ "id": 4,
+ "prompt": "What is 9 * 9?",
+ "ground_truth": "81"
+ },
+ ]
+ print(f"Using {len(prompts_data)} fallback prompts")
+
+
+@app.on_event("startup")
+async def startup_event():
+ """Load dataset on startup."""
+ load_dataset()
+
+
+@app.get("/health")
+async def health():
+ """Health check endpoint."""
+ return {"status": "healthy", "prompts_loaded": len(prompts_data)}
+
+
+@app.get("/prompts", response_model=PromptBatch)
+async def get_prompts(
+ batch_size: int = Query(default=8,
+ ge=1,
+ le=256,
+ description="Number of prompts to return"),
+ shuffle: bool = Query(default=True,
+ description="Whether to shuffle prompts")):
+ """Get a batch of prompts for training."""
+ global current_index
+
+ if not prompts_data:
+ return PromptBatch(prompts=[], total_available=0)
+
+ # Get batch of prompts
+ if shuffle:
+ batch = random.sample(prompts_data, min(batch_size, len(prompts_data)))
+ else:
+ # Sequential access with wraparound
+ batch = []
+ for _ in range(batch_size):
+ batch.append(prompts_data[current_index])
+ current_index = (current_index + 1) % len(prompts_data)
+
+ prompts = [Prompt(**p) for p in batch]
+ return PromptBatch(prompts=prompts, total_available=len(prompts_data))
+
+
+@app.get("/prompt/{prompt_id}", response_model=Optional[Prompt])
+async def get_prompt_by_id(prompt_id: int):
+ """Get a specific prompt by ID."""
+ for p in prompts_data:
+ if p["id"] == prompt_id:
+ return Prompt(**p)
+ return None
+
+
+@app.post("/reset")
+async def reset_index():
+ """Reset the sequential index to the beginning."""
+ global current_index
+ current_index = 0
+ return {"status": "reset", "index": current_index}
+
+
+def main():
+ parser = argparse.ArgumentParser(description="RLHF Data Server")
+ parser.add_argument("--port",
+ type=int,
+ default=8000,
+ help="Port to run server on")
+ parser.add_argument("--host",
+ type=str,
+ default="0.0.0.0",
+ help="Host to bind to")
+ args = parser.parse_args()
+
+ print(f"Starting data server on {args.host}:{args.port}")
+ uvicorn.run(app, host=args.host, port=args.port)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llm/rl-post-training-jobgroup/code/replay_buffer.py b/llm/rl-post-training-jobgroup/code/replay_buffer.py
new file mode 100644
index 00000000000..d4c1422e60b
--- /dev/null
+++ b/llm/rl-post-training-jobgroup/code/replay_buffer.py
@@ -0,0 +1,254 @@
+"""Replay Buffer Server for RLHF Training.
+
+This server provides a centralized experience replay buffer that stores
+(prompt, response, reward) tuples and allows sampling for training.
+
+Features:
+- Thread-safe storage with configurable capacity
+- Uniform random sampling
+- Priority-based sampling (optional)
+- Statistics tracking
+
+API Endpoints:
+- POST /add: Add experiences to the buffer
+- POST /sample: Sample a batch of experiences
+- GET /stats: Get buffer statistics
+- POST /clear: Clear the buffer
+- GET /health: Health check
+"""
+
+import argparse
+from collections import deque
+from dataclasses import dataclass
+from dataclasses import field
+import random
+import threading
+import time
+from typing import List, Optional
+
+from fastapi import FastAPI
+from fastapi import HTTPException
+from pydantic import BaseModel
+import uvicorn
+
+
+@dataclass
+class Experience:
+ """A single experience tuple."""
+ prompt: str
+ response: str
+ reward: float
+ ground_truth: Optional[str] = None
+ timestamp: float = field(default_factory=time.time)
+ priority: float = 1.0
+
+
+class AddExperienceRequest(BaseModel):
+ """Request to add experiences to the buffer."""
+ experiences: List[dict]
+
+
+class SampleRequest(BaseModel):
+ """Request to sample from the buffer."""
+ batch_size: int = 4
+ prioritized: bool = False
+
+
+class ReplayBuffer:
+ """Thread-safe replay buffer with priority sampling support."""
+
+ def __init__(self, capacity: int = 10000):
+ self.capacity = capacity
+ self.buffer: deque = deque(maxlen=capacity)
+ self.lock = threading.Lock()
+ self.total_added = 0
+ self.total_sampled = 0
+
+ def add(self, experiences: List[Experience]) -> int:
+ """Add experiences to the buffer."""
+ with self.lock:
+ for exp in experiences:
+ self.buffer.append(exp)
+ self.total_added += 1
+ return len(experiences)
+
+ def sample(self,
+ batch_size: int,
+ prioritized: bool = False) -> List[Experience]:
+ """Sample a batch of experiences."""
+ with self.lock:
+ if len(self.buffer) == 0:
+ return []
+
+ actual_size = min(batch_size, len(self.buffer))
+
+ if prioritized:
+ # Priority-based sampling (higher reward = higher priority)
+ priorities = [exp.priority for exp in self.buffer]
+ total_priority = sum(priorities)
+ if total_priority > 0:
+ probs = [p / total_priority for p in priorities]
+ indices = random.choices(range(len(self.buffer)),
+ weights=probs,
+ k=actual_size)
+ else:
+ indices = random.sample(range(len(self.buffer)),
+ actual_size)
+ else:
+ # Uniform random sampling
+ indices = random.sample(range(len(self.buffer)), actual_size)
+
+ samples = [self.buffer[i] for i in indices]
+ self.total_sampled += len(samples)
+ return samples
+
+ def clear(self):
+ """Clear the buffer."""
+ with self.lock:
+ self.buffer.clear()
+
+ def stats(self) -> dict:
+ """Get buffer statistics."""
+ with self.lock:
+ rewards = [exp.reward for exp in self.buffer]
+ return {
+ "size": len(self.buffer),
+ "capacity": self.capacity,
+ "total_added": self.total_added,
+ "total_sampled": self.total_sampled,
+ "avg_reward": sum(rewards) / len(rewards) if rewards else 0,
+ "min_reward": min(rewards) if rewards else 0,
+ "max_reward": max(rewards) if rewards else 0,
+ "positive_ratio": sum(1 for r in rewards if r > 0) /
+ len(rewards) if rewards else 0,
+ }
+
+
+# Initialize FastAPI app
+app = FastAPI(title="Replay Buffer Server",
+ description="Experience replay buffer for RLHF training")
+
+# Global buffer instance
+buffer: Optional[ReplayBuffer] = None
+
+
+@app.on_event("startup")
+async def startup():
+    """Initialize the replay buffer on startup (if not already created)."""
+    global buffer
+    if buffer is None:
+        buffer = ReplayBuffer(capacity=10000)
+    print(f"Replay buffer initialized with capacity {buffer.capacity}")
+
+
+@app.get("/health")
+async def health():
+ """Health check endpoint."""
+ return {
+ "status": "healthy",
+ "buffer_size": len(buffer.buffer) if buffer else 0
+ }
+
+
+@app.post("/add")
+async def add_experiences(request: AddExperienceRequest):
+ """Add experiences to the replay buffer.
+
+ Each experience should have:
+ - prompt: The input prompt
+ - response: The model's response
+ - reward: The reward score
+ - ground_truth (optional): The correct answer
+ """
+ if buffer is None:
+ raise HTTPException(status_code=503, detail="Buffer not initialized")
+
+ experiences = []
+ for exp_dict in request.experiences:
+ exp = Experience(
+ prompt=exp_dict.get("prompt", ""),
+ response=exp_dict.get("response", ""),
+ reward=exp_dict.get("reward", 0.0),
+ ground_truth=exp_dict.get("ground_truth"),
+ priority=abs(exp_dict.get("reward", 0.0)) +
+ 0.1 # Higher reward = higher priority
+ )
+ experiences.append(exp)
+
+ added = buffer.add(experiences)
+
+ return {"added": added, "buffer_size": len(buffer.buffer)}
+
+
+@app.post("/sample")
+async def sample_experiences(request: SampleRequest):
+ """Sample a batch of experiences from the buffer.
+
+ Args:
+ batch_size: Number of experiences to sample
+ prioritized: If True, use priority-based sampling
+ """
+ if buffer is None:
+ raise HTTPException(status_code=503, detail="Buffer not initialized")
+
+ if len(buffer.buffer) == 0:
+ return {"experiences": [], "message": "Buffer is empty"}
+
+ samples = buffer.sample(request.batch_size, request.prioritized)
+
+ return {
+ "experiences": [{
+ "prompt": exp.prompt,
+ "response": exp.response,
+ "reward": exp.reward,
+ "ground_truth": exp.ground_truth,
+ "timestamp": exp.timestamp
+ } for exp in samples],
+ "sampled": len(samples),
+ "buffer_size": len(buffer.buffer)
+ }
+
+
+@app.get("/stats")
+async def get_stats():
+ """Get buffer statistics."""
+ if buffer is None:
+ raise HTTPException(status_code=503, detail="Buffer not initialized")
+
+ return buffer.stats()
+
+
+@app.post("/clear")
+async def clear_buffer():
+ """Clear the replay buffer."""
+ if buffer is None:
+ raise HTTPException(status_code=503, detail="Buffer not initialized")
+
+ buffer.clear()
+ return {"status": "cleared"}
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Replay Buffer Server")
+ parser.add_argument("--port", type=int, default=8003, help="Port to run on")
+ parser.add_argument("--host",
+ type=str,
+ default="0.0.0.0",
+ help="Host to bind to")
+ parser.add_argument("--capacity",
+ type=int,
+ default=10000,
+ help="Buffer capacity")
+ args = parser.parse_args()
+
+    # Create the buffer with the requested capacity before the app starts.
+ global buffer
+ buffer = ReplayBuffer(capacity=args.capacity)
+
+ print(f"Starting Replay Buffer Server on {args.host}:{args.port}")
+ print(f"Buffer capacity: {args.capacity}")
+
+ uvicorn.run(app, host=args.host, port=args.port)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llm/rl-post-training-jobgroup/code/reward_server.py b/llm/rl-post-training-jobgroup/code/reward_server.py
new file mode 100644
index 00000000000..52cf93395a6
--- /dev/null
+++ b/llm/rl-post-training-jobgroup/code/reward_server.py
@@ -0,0 +1,170 @@
+#!/usr/bin/env python3
+"""Reward server for RLHF training - verifies math answers.
+
+This server computes rewards by comparing generated answers against ground truth.
+Uses simple string/numeric matching for math problems.
+
+Usage:
+ python reward_server.py --port 8002
+"""
+
+import argparse
+import re
+from typing import List, Optional
+
+from fastapi import FastAPI
+from pydantic import BaseModel
+import uvicorn
+
+app = FastAPI(title="RLHF Reward Server",
+ description="Computes rewards for math responses")
+
+
+class RewardRequest(BaseModel):
+ """Request for computing reward for a single response."""
+ prompt: str
+ response: str
+ ground_truth: str
+
+
+class RewardResponse(BaseModel):
+ """Reward computation result."""
+ reward: float
+ extracted_answer: Optional[str]
+ ground_truth: str
+ correct: bool
+
+
+class BatchRewardRequest(BaseModel):
+ """Request for computing rewards for multiple responses."""
+ items: List[RewardRequest]
+
+
+class BatchRewardResponse(BaseModel):
+ """Batch reward computation results."""
+ rewards: List[RewardResponse]
+ mean_reward: float
+ accuracy: float
+
+
+def extract_answer(response: str) -> Optional[str]:
+ """Extract the final numerical answer from a response.
+
+ Tries multiple patterns commonly used in math solutions:
+ 1. "#### " (GSM8K format)
+ 2. "The answer is "
+ 3. "= " at the end
+ 4. Last number in the response
+ """
+ response = response.strip()
+
+ # Pattern 1: GSM8K format "#### "
+ match = re.search(r'####\s*([+-]?\d+(?:,\d{3})*(?:\.\d+)?)', response)
+ if match:
+ return match.group(1).replace(',', '')
+
+ # Pattern 2: "The answer is "
+ match = re.search(
+ r'[Tt]he\s+(?:final\s+)?answer\s+is[:\s]*([+-]?\d+(?:,\d{3})*(?:\.\d+)?)',
+ response)
+ if match:
+ return match.group(1).replace(',', '')
+
+ # Pattern 3: "= " at the end of a line
+ match = re.search(r'=\s*([+-]?\d+(?:,\d{3})*(?:\.\d+)?)\s*$', response,
+ re.MULTILINE)
+ if match:
+ return match.group(1).replace(',', '')
+
+ # Pattern 4: Last number in the response
+ numbers = re.findall(r'([+-]?\d+(?:,\d{3})*(?:\.\d+)?)', response)
+ if numbers:
+ return numbers[-1].replace(',', '')
+
+ return None
+
+
+def normalize_answer(answer: Optional[str]) -> str:
+ """Normalize an answer for comparison."""
+ if answer is None:
+ return ""
+ # Remove commas, whitespace, and convert to lowercase
+ answer = answer.replace(',', '').strip().lower()
+ # Try to parse as number and format consistently
+ try:
+ num = float(answer)
+ # If it's a whole number, return as int
+ if num == int(num):
+ return str(int(num))
+ return str(num)
+ except ValueError:
+ return answer
+
+
+def compute_reward(prompt: str, response: str,
+ ground_truth: str) -> RewardResponse:
+ """Compute reward by comparing extracted answer to ground truth."""
+ extracted = extract_answer(response)
+ normalized_extracted = normalize_answer(extracted)
+ normalized_truth = normalize_answer(ground_truth)
+
+ # Check if answers match
+ correct = normalized_extracted == normalized_truth
+
+ # Binary reward: 1.0 for correct, 0.0 for incorrect
+ reward = 1.0 if correct else 0.0
+
+ return RewardResponse(reward=reward,
+ extracted_answer=extracted,
+ ground_truth=ground_truth,
+ correct=correct)
+
+
+@app.get("/health")
+async def health():
+ """Health check endpoint."""
+ return {"status": "healthy"}
+
+
+@app.post("/reward", response_model=RewardResponse)
+async def get_reward(request: RewardRequest):
+ """Compute reward for a single response."""
+ return compute_reward(request.prompt, request.response,
+ request.ground_truth)
+
+
+@app.post("/batch_reward", response_model=BatchRewardResponse)
+async def get_batch_reward(request: BatchRewardRequest):
+ """Compute rewards for a batch of responses."""
+ rewards = [
+ compute_reward(item.prompt, item.response, item.ground_truth)
+ for item in request.items
+ ]
+
+ total_reward = sum(r.reward for r in rewards)
+ correct_count = sum(1 for r in rewards if r.correct)
+
+ return BatchRewardResponse(
+ rewards=rewards,
+ mean_reward=total_reward / len(rewards) if rewards else 0.0,
+ accuracy=correct_count / len(rewards) if rewards else 0.0)
+
+
+def main():
+ parser = argparse.ArgumentParser(description="RLHF Reward Server")
+ parser.add_argument("--port",
+ type=int,
+ default=8002,
+ help="Port to run server on")
+ parser.add_argument("--host",
+ type=str,
+ default="0.0.0.0",
+ help="Host to bind to")
+ args = parser.parse_args()
+
+ print(f"Starting reward server on {args.host}:{args.port}")
+ uvicorn.run(app, host=args.host, port=args.port)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llm/rl-post-training-jobgroup/code/trainer.py b/llm/rl-post-training-jobgroup/code/trainer.py
new file mode 100644
index 00000000000..66bf8f5d04a
--- /dev/null
+++ b/llm/rl-post-training-jobgroup/code/trainer.py
@@ -0,0 +1,454 @@
+#!/usr/bin/env python3
+"""GRPO Trainer for RLHF math training.
+
+This trainer orchestrates the RLHF pipeline by:
+1. Fetching prompts from data-server
+2. Generating responses via rollout-server (SGLang)
+3. Computing rewards via reward-server
+4. Storing experiences in replay-buffer
+5. Updating the policy using GRPO (Group Relative Policy Optimization)
+
+GRPO is a simplified variant of PPO that doesn't require a critic model,
+making it popular for math/code tasks with verifiable rewards.
+
+Usage:
+ python trainer.py \
+ --data-server localhost:8000 \
+ --rollout-server localhost:8001 \
+ --reward-server localhost:8002 \
+ --replay-buffer localhost:8003 \
+ --num-epochs 3
+"""
+
+import argparse
+from dataclasses import dataclass
+import time
+from typing import List, Optional
+
+from accelerate import Accelerator
+import httpx
+import torch
+from transformers import AutoModelForCausalLM
+from transformers import AutoTokenizer
+
+
+@dataclass
+class TrainingConfig:
+ """Training configuration."""
+ data_server: str
+ rollout_server: str
+ reward_server: str
+ replay_buffer: Optional[str] = None
+ model_name: str = "Qwen/Qwen2.5-0.5B-Instruct"
+ batch_size: int = 4
+ num_epochs: int = 3
+ learning_rate: float = 1e-6
+ max_new_tokens: int = 512
+ temperature: float = 0.7
+ num_samples_per_prompt: int = 4 # For GRPO, generate multiple samples
+ kl_coef: float = 0.01
+ clip_range: float = 0.2
+ use_replay_buffer: bool = True # Whether to use replay buffer for training
+
+
+class RLHFTrainer:
+ """GRPO trainer that coordinates with external services."""
+
+ def __init__(self, config: TrainingConfig):
+ self.config = config
+ self.accelerator = Accelerator()
+
+ # HTTP clients for services
+ self.http_client = httpx.Client(timeout=120.0)
+
+ # Load model and tokenizer
+ if self.accelerator.is_main_process:
+ print(f"Loading model: {config.model_name}")
+
+ self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
+ if self.tokenizer.pad_token is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
+ self.model = AutoModelForCausalLM.from_pretrained(
+ config.model_name, torch_dtype=torch.bfloat16, device_map="auto")
+
+ # Optimizer
+ self.optimizer = torch.optim.AdamW(self.model.parameters(),
+ lr=config.learning_rate)
+
+ # Prepare with accelerator
+ self.model, self.optimizer = self.accelerator.prepare(
+ self.model, self.optimizer)
+
+ # Statistics
+ self.total_steps = 0
+ self.total_rewards = 0.0
+
+ def wait_for_services(self,
+ max_retries: int = 30,
+ retry_interval: int = 10):
+ """Wait for all services to be available."""
+ services = [
+ ("data-server", f"http://{self.config.data_server}/health"),
+ ("rollout-server", f"http://{self.config.rollout_server}/health"),
+ ("reward-server", f"http://{self.config.reward_server}/health"),
+ ]
+ if self.config.replay_buffer:
+ services.append(
+ ("replay-buffer", f"http://{self.config.replay_buffer}/health"))
+
+ for name, url in services:
+ if self.accelerator.is_main_process:
+ print(f"Waiting for {name} at {url}...")
+
+ for attempt in range(max_retries):
+ try:
+ response = self.http_client.get(url)
+ if response.status_code == 200:
+ if self.accelerator.is_main_process:
+ print(f" {name} is ready!")
+ break
+                except Exception:
+                    pass  # Service not reachable yet; retry below.
+
+                if attempt < max_retries - 1:
+                    time.sleep(retry_interval)
+                else:
+                    raise RuntimeError(
+                        f"Service {name} not available after "
+                        f"{max_retries} retries")
+
+ def fetch_prompts(self, batch_size: int) -> List[dict]:
+ """Fetch a batch of prompts from data server."""
+ url = f"http://{self.config.data_server}/prompts"
+ response = self.http_client.get(url, params={"batch_size": batch_size})
+ response.raise_for_status()
+ data = response.json()
+ return data["prompts"]
+
+ def generate_responses(self, prompts: List[str]) -> List[str]:
+ """Generate responses using the rollout server (SGLang)."""
+ url = f"http://{self.config.rollout_server}/v1/completions"
+
+ responses = []
+ for prompt in prompts:
+ payload = {
+ "model": self.config.model_name,
+ "prompt": prompt,
+ "max_tokens": self.config.max_new_tokens,
+ "temperature": self.config.temperature,
+ "n": 1,
+ }
+ try:
+ response = self.http_client.post(url, json=payload)
+ response.raise_for_status()
+ data = response.json()
+ text = data["choices"][0]["text"]
+ responses.append(text)
+ except Exception as e:
+ print(f"Error generating response: {e}")
+ responses.append("")
+
+ return responses
+
+ def compute_rewards(self, prompts: List[str], responses: List[str],
+ ground_truths: List[str]) -> List[float]:
+ """Compute rewards using the reward server."""
+ url = f"http://{self.config.reward_server}/batch_reward"
+
+ items = [{
+ "prompt": p,
+ "response": r,
+ "ground_truth": gt
+ } for p, r, gt in zip(prompts, responses, ground_truths)]
+
+ response = self.http_client.post(url, json={"items": items})
+ response.raise_for_status()
+ data = response.json()
+
+ return [r["reward"] for r in data["rewards"]]
+
+ def store_experiences(self, prompts: List[str], responses: List[str],
+ rewards: List[float], ground_truths: List[str]):
+ """Store experiences in the replay buffer."""
+ if not self.config.replay_buffer:
+ return
+
+ url = f"http://{self.config.replay_buffer}/add"
+ experiences = [{
+ "prompt": p,
+ "response": r,
+ "reward": rw,
+ "ground_truth": gt
+ } for p, r, rw, gt in zip(prompts, responses, rewards, ground_truths)]
+
+ try:
+ response = self.http_client.post(url,
+ json={"experiences": experiences})
+ response.raise_for_status()
+ except Exception as e:
+ print(f"Warning: Failed to store experiences in replay buffer: {e}")
+
+ def sample_from_replay_buffer(self,
+ batch_size: int) -> Optional[List[dict]]:
+ """Sample experiences from the replay buffer."""
+ if not self.config.replay_buffer:
+ return None
+
+ url = f"http://{self.config.replay_buffer}/sample"
+ try:
+ response = self.http_client.post(url,
+ json={
+ "batch_size": batch_size,
+ "prioritized": True
+ })
+ response.raise_for_status()
+ data = response.json()
+ if data["experiences"]:
+ return data["experiences"]
+ except Exception as e:
+ print(f"Warning: Failed to sample from replay buffer: {e}")
+ return None
+
+ def get_replay_buffer_stats(self) -> Optional[dict]:
+ """Get replay buffer statistics."""
+ if not self.config.replay_buffer:
+ return None
+
+ url = f"http://{self.config.replay_buffer}/stats"
+ try:
+ response = self.http_client.get(url)
+ response.raise_for_status()
+ return response.json()
+        except Exception:
+            return None
+
+ def compute_grpo_loss(self, prompts: List[str], responses: List[str],
+ rewards: List[float]) -> torch.Tensor:
+ """Compute GRPO loss for policy update.
+
+ GRPO uses group-relative advantages: for each prompt, we compare
+ the reward of each response to the mean reward of all responses
+ for that prompt.
+ """
+ # Tokenize prompts and responses together
+ full_texts = [p + r for p, r in zip(prompts, responses)]
+ encodings = self.tokenizer(full_texts,
+ return_tensors="pt",
+ padding=True,
+ truncation=True,
+ max_length=1024).to(self.accelerator.device)
+
+ # Get prompt lengths for masking
+ prompt_encodings = self.tokenizer(prompts,
+ return_tensors="pt",
+ padding=True,
+ truncation=True,
+ max_length=512)
+ prompt_lengths = prompt_encodings.attention_mask.sum(dim=1)
+
+ # Forward pass
+ outputs = self.model(**encodings, labels=encodings.input_ids)
+
+ # Compute per-token log probabilities
+ logits = outputs.logits[:, :-1, :]
+ labels = encodings.input_ids[:, 1:]
+ log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
+ token_log_probs = torch.gather(log_probs, 2,
+ labels.unsqueeze(-1)).squeeze(-1)
+
+ # Mask out prompt tokens (only count response tokens)
+ response_mask = torch.zeros_like(token_log_probs)
+ for i, plen in enumerate(prompt_lengths):
+ response_mask[i, plen - 1:] = 1.0
+ response_mask = response_mask * encodings.attention_mask[:, 1:]
+
+ # Sum log probs for each response
+ response_log_probs = (token_log_probs * response_mask).sum(dim=1)
+
+ # Convert rewards to tensor and compute advantages
+ rewards_tensor = torch.tensor(rewards,
+ device=self.accelerator.device,
+ dtype=torch.float32)
+
+ # GRPO: normalize rewards within batch (group-relative)
+ mean_reward = rewards_tensor.mean()
+ std_reward = rewards_tensor.std() + 1e-8
+ advantages = (rewards_tensor - mean_reward) / std_reward
+
+ # Policy gradient loss (negative because we maximize reward)
+ loss = -(response_log_probs * advantages).mean()
+
+ return loss
+
+ def train_step(self) -> dict:
+ """Execute one training step."""
+ self.model.train()
+
+ # 1. Fetch prompts
+ prompt_data = self.fetch_prompts(self.config.batch_size)
+ prompts = [p["prompt"] for p in prompt_data]
+ ground_truths = [p["ground_truth"] for p in prompt_data]
+
+ # 2. Generate responses
+ responses = self.generate_responses(prompts)
+
+ # 3. Compute rewards
+ rewards = self.compute_rewards(prompts, responses, ground_truths)
+
+ # 4. Store experiences in replay buffer
+ self.store_experiences(prompts, responses, rewards, ground_truths)
+
+ # 5. Compute loss and update with fresh experiences
+ loss = self.compute_grpo_loss(prompts, responses, rewards)
+
+ self.optimizer.zero_grad()
+ self.accelerator.backward(loss)
+ self.optimizer.step()
+
+ # 6. Optionally do additional update with replay buffer samples
+ replay_loss = None
+ if self.config.use_replay_buffer and self.config.replay_buffer:
+ replay_experiences = self.sample_from_replay_buffer(
+ self.config.batch_size)
+ if replay_experiences and len(replay_experiences) >= 2:
+ replay_prompts = [e["prompt"] for e in replay_experiences]
+ replay_responses = [e["response"] for e in replay_experiences]
+ replay_rewards = [e["reward"] for e in replay_experiences]
+
+ replay_loss = self.compute_grpo_loss(replay_prompts,
+ replay_responses,
+ replay_rewards)
+ self.optimizer.zero_grad()
+ self.accelerator.backward(replay_loss)
+ self.optimizer.step()
+
+ # Update statistics
+ self.total_steps += 1
+ mean_reward = sum(rewards) / len(rewards)
+ self.total_rewards += mean_reward
+
+ result = {
+ "loss": loss.item(),
+ "mean_reward": mean_reward,
+ "accuracy": sum(1 for r in rewards if r > 0) / len(rewards),
+ "num_samples": len(prompts)
+ }
+ if replay_loss is not None:
+ result["replay_loss"] = replay_loss.item()
+ return result
+
+ def train(self):
+ """Run the full training loop."""
+ if self.accelerator.is_main_process:
+ print("=" * 60)
+ print("GRPO Training for Math")
+ print("=" * 60)
+ print(f"Model: {self.config.model_name}")
+ print(f"Batch size: {self.config.batch_size}")
+ print(f"Epochs: {self.config.num_epochs}")
+ print(f"Learning rate: {self.config.learning_rate}")
+ print("=" * 60)
+
+ # Wait for services
+ self.wait_for_services()
+
+ # Training loop
+ steps_per_epoch = 100 # Configurable
+ for epoch in range(self.config.num_epochs):
+ epoch_rewards = []
+ epoch_losses = []
+
+ for step in range(steps_per_epoch):
+ metrics = self.train_step()
+ epoch_rewards.append(metrics["mean_reward"])
+ epoch_losses.append(metrics["loss"])
+
+ if self.accelerator.is_main_process and step % 10 == 0:
+ print(f"Epoch {epoch+1}/{self.config.num_epochs} | "
+ f"Step {step+1}/{steps_per_epoch} | "
+ f"Loss: {metrics['loss']:.4f} | "
+ f"Reward: {metrics['mean_reward']:.4f} | "
+ f"Accuracy: {metrics['accuracy']:.2%}")
+
+ # Epoch summary
+ if self.accelerator.is_main_process:
+ mean_epoch_reward = sum(epoch_rewards) / len(epoch_rewards)
+ mean_epoch_loss = sum(epoch_losses) / len(epoch_losses)
+ print(f"\n=== Epoch {epoch+1} Complete ===")
+ print(f"Mean Reward: {mean_epoch_reward:.4f}")
+ print(f"Mean Loss: {mean_epoch_loss:.4f}")
+
+ # Print replay buffer stats
+ buffer_stats = self.get_replay_buffer_stats()
+ if buffer_stats:
+ print(
+ f"Replay Buffer: {buffer_stats['size']}/{buffer_stats['capacity']} "
+ f"(avg_reward: {buffer_stats['avg_reward']:.4f}, "
+ f"positive_ratio: {buffer_stats['positive_ratio']:.2%})"
+ )
+ print()
+
+ if self.accelerator.is_main_process:
+ print("=" * 60)
+ print("Training Complete!")
+ print(f"Total steps: {self.total_steps}")
+ print(
+ f"Average reward: {self.total_rewards / self.total_steps:.4f}")
+ print("=" * 60)
+
+
+def main():
+ parser = argparse.ArgumentParser(description="GRPO Trainer for RLHF")
+ parser.add_argument("--data-server",
+ type=str,
+ required=True,
+ help="Data server address (host:port)")
+ parser.add_argument("--rollout-server",
+ type=str,
+ required=True,
+ help="Rollout server address (host:port)")
+ parser.add_argument("--reward-server",
+ type=str,
+ required=True,
+ help="Reward server address (host:port)")
+ parser.add_argument("--replay-buffer",
+ type=str,
+ default=None,
+ help="Replay buffer address (host:port)")
+ parser.add_argument("--model",
+ type=str,
+ default="Qwen/Qwen2.5-0.5B-Instruct",
+ help="Model name or path")
+ parser.add_argument("--batch-size",
+ type=int,
+ default=4,
+ help="Training batch size")
+ parser.add_argument("--num-epochs",
+ type=int,
+ default=3,
+ help="Number of training epochs")
+ parser.add_argument("--learning-rate",
+ type=float,
+ default=1e-6,
+ help="Learning rate")
+ args = parser.parse_args()
+
+ config = TrainingConfig(
+ data_server=args.data_server,
+ rollout_server=args.rollout_server,
+ reward_server=args.reward_server,
+ replay_buffer=args.replay_buffer,
+ model_name=args.model,
+ batch_size=args.batch_size,
+ num_epochs=args.num_epochs,
+ learning_rate=args.learning_rate,
+ )
+
+ trainer = RLHFTrainer(config)
+ trainer.train()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llm/rl-post-training-jobgroup/rlhf-math-jobgroup-cpu.yaml b/llm/rl-post-training-jobgroup/rlhf-math-jobgroup-cpu.yaml
new file mode 100644
index 00000000000..732b53a7dc0
--- /dev/null
+++ b/llm/rl-post-training-jobgroup/rlhf-math-jobgroup-cpu.yaml
@@ -0,0 +1,184 @@
+# RLHF Math Training with Job Groups - CPU Test Version
+#
+# This is a simplified CPU-only version for testing the job group functionality.
+# It demonstrates the service connectivity without requiring GPUs.
+#
+# Primary/Auxiliary Behavior:
+# The test-client is the primary task. When tests complete, all auxiliary
+# services are terminated after a 5-second grace period.
+#
+# Usage:
+# sky jobs launch llm/rl-post-training-jobgroup/rlhf-math-jobgroup-cpu.yaml
+---
+name: rlhf-math-cpu
+execution: parallel
+primary_tasks: [test-client]
+termination_delay: 5s
+
+---
+# Data Server: Serves math prompts from GSM8K dataset
+name: data-server
+resources:
+ cpus: 2
+ memory: 4+
+ infra: kubernetes
+
+file_mounts:
+ /code: llm/rl-post-training-jobgroup/code
+
+setup: |
+ pip install fastapi uvicorn datasets
+
+run: |
+ echo "Starting data server..."
+ echo "JobGroup: ${SKYPILOT_JOBGROUP_NAME}"
+ echo "This server provides math prompts at http://data-server-0.${SKYPILOT_JOBGROUP_NAME}:8000"
+
+ cd /code
+ python data_server.py --port 8000
+
+---
+# Reward Server: Verifies math answers against ground truth
+name: reward-server
+resources:
+ cpus: 2
+ memory: 4+
+ infra: kubernetes
+
+file_mounts:
+ /code: llm/rl-post-training-jobgroup/code
+
+setup: |
+ pip install fastapi uvicorn
+
+run: |
+ echo "Starting reward server..."
+ echo "JobGroup: ${SKYPILOT_JOBGROUP_NAME}"
+ echo "Reward API at http://reward-server-0.${SKYPILOT_JOBGROUP_NAME}:8002"
+
+ cd /code
+ python reward_server.py --port 8002
+
+---
+# Replay Buffer: Stores experience tuples for training
+name: replay-buffer
+resources:
+ cpus: 2
+ memory: 4+
+ infra: kubernetes
+
+file_mounts:
+ /code: llm/rl-post-training-jobgroup/code
+
+setup: |
+ pip install fastapi uvicorn
+
+run: |
+ echo "Starting replay buffer server..."
+ echo "JobGroup: ${SKYPILOT_JOBGROUP_NAME}"
+ echo "Replay Buffer API at http://replay-buffer-0.${SKYPILOT_JOBGROUP_NAME}:8003"
+
+ cd /code
+ python replay_buffer.py --port 8003 --capacity 1000
+
+---
+# Test Client: Verifies connectivity between services
+name: test-client
+resources:
+ cpus: 2
+ memory: 4+
+ infra: kubernetes
+
+setup: |
+ pip install httpx
+
+run: |
+ echo "Starting test client..."
+ echo "JobGroup: ${SKYPILOT_JOBGROUP_NAME}"
+
+ # Service discovery via job group DNS
+ DATA_SERVER="data-server-0.${SKYPILOT_JOBGROUP_NAME}:8000"
+ REWARD_SERVER="reward-server-0.${SKYPILOT_JOBGROUP_NAME}:8002"
+ REPLAY_BUFFER="replay-buffer-0.${SKYPILOT_JOBGROUP_NAME}:8003"
+
+ echo "Data server: ${DATA_SERVER}"
+ echo "Reward server: ${REWARD_SERVER}"
+ echo "Replay buffer: ${REPLAY_BUFFER}"
+
+ # Wait for services to be ready
+ echo "Waiting for services to be available..."
+ sleep 30
+
+ # Test data server
+ echo "Testing data server..."
+ for i in {1..5}; do
+ if curl -s "http://${DATA_SERVER}/health" | grep -q "healthy"; then
+ echo "Data server is healthy!"
+ break
+ fi
+ echo "Waiting for data server... attempt $i"
+ sleep 5
+ done
+
+ # Fetch some prompts
+ echo "Fetching prompts..."
+ PROMPTS=$(curl -s "http://${DATA_SERVER}/prompts?batch_size=2")
+ echo "Prompts: ${PROMPTS}"
+
+ # Test reward server
+ echo "Testing reward server..."
+ for i in {1..5}; do
+ if curl -s "http://${REWARD_SERVER}/health" | grep -q "healthy"; then
+ echo "Reward server is healthy!"
+ break
+ fi
+ echo "Waiting for reward server... attempt $i"
+ sleep 5
+ done
+
+ # Test reward computation
+ echo "Testing reward computation..."
+ REWARD=$(curl -s -X POST "http://${REWARD_SERVER}/reward" \
+ -H "Content-Type: application/json" \
+ -d '{"prompt": "What is 2+2?", "response": "The answer is 4", "ground_truth": "4"}')
+ echo "Reward response: ${REWARD}"
+
+ # Test replay buffer
+ echo "Testing replay buffer..."
+ for i in {1..5}; do
+ if curl -s "http://${REPLAY_BUFFER}/health" | grep -q "healthy"; then
+ echo "Replay buffer is healthy!"
+ break
+ fi
+ echo "Waiting for replay buffer... attempt $i"
+ sleep 5
+ done
+
+ # Add experience to replay buffer
+ echo "Adding experience to replay buffer..."
+ ADD_RESULT=$(curl -s -X POST "http://${REPLAY_BUFFER}/add" \
+ -H "Content-Type: application/json" \
+ -d '{"experiences": [{"prompt": "What is 2+2?", "response": "The answer is 4", "reward": 1.0, "ground_truth": "4"}]}')
+ echo "Add result: ${ADD_RESULT}"
+
+ # Get replay buffer stats
+ echo "Getting replay buffer stats..."
+ STATS=$(curl -s "http://${REPLAY_BUFFER}/stats")
+ echo "Stats: ${STATS}"
+
+ # Sample from replay buffer
+ echo "Sampling from replay buffer..."
+ SAMPLE=$(curl -s -X POST "http://${REPLAY_BUFFER}/sample" \
+ -H "Content-Type: application/json" \
+ -d '{"batch_size": 1}')
+ echo "Sample: ${SAMPLE}"
+
+ echo ""
+ echo "=========================================="
+ echo "All services are working correctly!"
+ echo "=========================================="
+ echo ""
+ echo "Job group connectivity test complete."
+
+ # Keep running to allow inspection
+ sleep 300
diff --git a/llm/rl-post-training-jobgroup/rlhf-math-jobgroup.yaml b/llm/rl-post-training-jobgroup/rlhf-math-jobgroup.yaml
new file mode 100644
index 00000000000..be405b409bd
--- /dev/null
+++ b/llm/rl-post-training-jobgroup/rlhf-math-jobgroup.yaml
@@ -0,0 +1,226 @@
+# RLHF Math Training with Job Groups
+#
+# This example demonstrates a distributed RLHF architecture using SkyPilot job groups.
+# It trains an LLM on mathematical reasoning using GRPO (Group Relative Policy Optimization)
+# with verifiable rewards.
+#
+# Architecture:
+# - data-server (auxiliary): Serves GSM8K math prompts
+# - rollout-server (auxiliary, x2): SGLang instances + SGLang router
+# - reward-server (auxiliary): Verifies math answers against ground truth
+# - replay-buffer (auxiliary): Stores experience tuples for sampling
+# - ppo-trainer (primary): Orchestrates GRPO training across multiple nodes
+#
+# Primary/Auxiliary Behavior:
+# The ppo-trainer is the primary task. When training completes, all auxiliary
+# services (data-server, rollout-server, reward-server, replay-buffer) are
+# terminated after a 10-second grace period to ensure clean shutdown.
+#
+# Load Balancing:
+# The head node runs SGLang's native router (sglang_router) which provides
+# cache-aware load balancing across all SGLang instances for optimal KV cache reuse.
+# The trainer connects to the router endpoint on port 30000.
+#
+# Usage:
+# sky jobs launch llm/rl-post-training-jobgroup/rlhf-math-jobgroup.yaml
+#
+# The components communicate over the job group network using DNS names:
+# - data-server-0.${SKYPILOT_JOBGROUP_NAME}:8000
+# - rollout-server-0.${SKYPILOT_JOBGROUP_NAME}:30000 (SGLang router endpoint)
+# - rollout-server-0.${SKYPILOT_JOBGROUP_NAME}:30001 (SGLang backend 1)
+# - rollout-server-1.${SKYPILOT_JOBGROUP_NAME}:30001 (SGLang backend 2)
+# - reward-server-0.${SKYPILOT_JOBGROUP_NAME}:8002
+# - replay-buffer-0.${SKYPILOT_JOBGROUP_NAME}:8003
+---
+name: rlhf-math
+execution: parallel
+primary_tasks: [ppo-trainer]
+termination_delay: 10s
+
+---
+# Data Server: Serves math prompts from GSM8K dataset
+name: data-server
+resources:
+ cpus: 4
+ memory: 16+
+ infra: kubernetes
+
+file_mounts:
+ /code: llm/rl-post-training-jobgroup/code
+
+setup: |
+ uv pip install fastapi uvicorn datasets --system
+
+run: |
+ echo "Starting data server..."
+ echo "JobGroup: ${SKYPILOT_JOBGROUP_NAME}"
+ echo "This server provides math prompts at http://data-server-0.${SKYPILOT_JOBGROUP_NAME}:8000"
+
+ cd /code
+ python data_server.py --port 8000
+
+---
+# Rollout Servers: Multiple SGLang instances with SGLang router on head node
+# Using num_nodes=2 to create rollout-server-0 and rollout-server-1
+# Head node (rank 0) runs both SGLang server and SGLang router for load balancing
+name: rollout-server
+num_nodes: 2
+resources:
+ accelerators: H100:1
+ memory: 32+
+ infra: kubernetes
+
+envs:
+ MODEL_NAME: Qwen/Qwen2.5-0.5B-Instruct
+
+setup: |
+ # Install system dependencies (libnuma is required by SGLang kernel)
+ sudo apt-get update && sudo apt-get install -y libnuma-dev
+ uv pip install "sglang[all]" sglang-router --system
+
+run: |
+ echo "Starting rollout server with SGLang..."
+ echo "JobGroup: ${SKYPILOT_JOBGROUP_NAME}"
+ echo "Node rank: ${SKYPILOT_NODE_RANK} / ${SKYPILOT_NUM_NODES}"
+ echo "Model: ${MODEL_NAME}"
+
+ # Start SGLang server in background
+ python -m sglang.launch_server \
+ --model ${MODEL_NAME} \
+ --host 0.0.0.0 \
+ --port 30001 &
+ SGLANG_PID=$!
+
+ # On head node, also run the SGLang router for load balancing
+ if [ "${SKYPILOT_NODE_RANK}" == "0" ]; then
+ echo "Head node: starting SGLang router..."
+
+ # Build worker URL list for all rollout servers
+ WORKER_URLS=""
+ for i in $(seq 0 $((SKYPILOT_NUM_NODES - 1))); do
+ WORKER_URLS="${WORKER_URLS} http://rollout-server-${i}.${SKYPILOT_JOBGROUP_NAME}:30001"
+ done
+
+ echo "Load balancing across:${WORKER_URLS}"
+ echo "Router API available at http://rollout-server-0.${SKYPILOT_JOBGROUP_NAME}:30000/v1"
+
+ # Wait for SGLang backends to start
+ sleep 60
+
+ python -m sglang_router.launch_router \
+ --worker-urls ${WORKER_URLS} \
+ --host 0.0.0.0 \
+ --port 30000 \
+ --policy cache_aware &
+ ROUTER_PID=$!
+
+ # Wait for both processes
+ wait $SGLANG_PID $ROUTER_PID
+ else
+ # Worker nodes just run SGLang server
+ wait $SGLANG_PID
+ fi
+
+---
+# Reward Server: Verifies math answers against ground truth
+name: reward-server
+resources:
+ cpus: 4
+ memory: 8+
+ infra: kubernetes
+
+file_mounts:
+ /code: llm/rl-post-training-jobgroup/code
+
+setup: |
+ uv pip install fastapi uvicorn --system
+
+run: |
+ echo "Starting reward server..."
+ echo "JobGroup: ${SKYPILOT_JOBGROUP_NAME}"
+ echo "Reward API at http://reward-server-0.${SKYPILOT_JOBGROUP_NAME}:8002"
+
+ cd /code
+ python reward_server.py --port 8002
+
+---
+# Replay Buffer: Stores experience tuples for training
+name: replay-buffer
+resources:
+ cpus: 4
+ memory: 16+
+ infra: kubernetes
+
+file_mounts:
+ /code: llm/rl-post-training-jobgroup/code
+
+setup: |
+ uv pip install fastapi uvicorn --system
+
+run: |
+ echo "Starting replay buffer server..."
+ echo "JobGroup: ${SKYPILOT_JOBGROUP_NAME}"
+ echo "Replay Buffer API at http://replay-buffer-0.${SKYPILOT_JOBGROUP_NAME}:8003"
+
+ cd /code
+ python replay_buffer.py --port 8003 --capacity 10000
+
+---
+# PPO Trainer: Multi-node GRPO training
+name: ppo-trainer
+resources:
+ accelerators: H100:1
+ memory: 32+
+ infra: kubernetes
+num_nodes: 2
+
+envs:
+ MODEL_NAME: Qwen/Qwen2.5-0.5B-Instruct
+ NUM_EPOCHS: 3
+ BATCH_SIZE: 4
+
+file_mounts:
+ /code: llm/rl-post-training-jobgroup/code
+
+setup: |
+ uv pip install torch transformers accelerate httpx --system
+
+run: |
+ echo "Starting GRPO trainer..."
+ echo "JobGroup: ${SKYPILOT_JOBGROUP_NAME}"
+ echo "Node rank: ${SKYPILOT_NODE_RANK} / ${SKYPILOT_NUM_NODES}"
+
+ # Service discovery via job group DNS
+ # The rollout head node provides load balancing across all SGLang instances
+ DATA_SERVER="data-server-0.${SKYPILOT_JOBGROUP_NAME}:8000"
+ ROLLOUT_SERVER="rollout-server-0.${SKYPILOT_JOBGROUP_NAME}:30000"
+ REWARD_SERVER="reward-server-0.${SKYPILOT_JOBGROUP_NAME}:8002"
+ REPLAY_BUFFER="replay-buffer-0.${SKYPILOT_JOBGROUP_NAME}:8003"
+
+ echo "Data server: ${DATA_SERVER}"
+ echo "Rollout server (load balanced): ${ROLLOUT_SERVER}"
+ echo "Reward server: ${REWARD_SERVER}"
+ echo "Replay buffer: ${REPLAY_BUFFER}"
+
+ # Wait for services to be ready
+ echo "Waiting for services to be available..."
+ sleep 30
+
+ # Only run training on rank 0 (coordinator)
+ if [ "${SKYPILOT_NODE_RANK}" == "0" ]; then
+ echo "Starting training on coordinator node..."
+ cd /code
+ python trainer.py \
+ --data-server ${DATA_SERVER} \
+ --rollout-server ${ROLLOUT_SERVER} \
+ --reward-server ${REWARD_SERVER} \
+ --replay-buffer ${REPLAY_BUFFER} \
+ --model ${MODEL_NAME} \
+ --batch-size ${BATCH_SIZE} \
+ --num-epochs ${NUM_EPOCHS}
+ else
+ echo "Worker node ${SKYPILOT_NODE_RANK} ready for distributed training"
+ # In a full implementation, worker nodes would join distributed training
+ # For this demo, they just wait
+ sleep infinity
+ fi
diff --git a/llm/train-eval-jobgroup/README.md b/llm/train-eval-jobgroup/README.md
new file mode 100644
index 00000000000..7d0cd9ffa16
--- /dev/null
+++ b/llm/train-eval-jobgroup/README.md
@@ -0,0 +1,172 @@
+# Parallel Training and Evaluation with Shared Volume
+
+This example demonstrates SkyPilot job groups with parallel training and evaluation tasks that share a Kubernetes PVC volume for checkpoints. The evaluator monitors the checkpoint directory and evaluates models "on the fly" as training produces them.
+
+## Architecture
+
+### Components
+
+1. **trainer**: Trains ResNet-18 on CIFAR-10, saves checkpoints every N epochs to shared storage
+2. **evaluator**: Watches the checkpoint directory, evaluates new checkpoints as they appear, reports test accuracy
+
+### Graceful Completion
+
+Both tasks complete naturally without forced termination:
+- When training finishes, the trainer writes a `training_complete` marker file to the shared volume
+- The evaluator detects this marker, finishes evaluating any remaining checkpoints, and exits gracefully (sketched below)
+- This pattern avoids the need for `primary_tasks` and `termination_delay` settings
+
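+A minimal sketch of the marker protocol, mirroring what `trainer.py` and `evaluator.py` in this example do (the helper names here are illustrative; only the marker file's existence matters, not its contents):
+
+```python
+import json
+import os
+import time
+
+MARKER = '/checkpoints/training_complete'
+
+# Trainer side: write the marker once, after the final epoch.
+def signal_complete(final_epoch: int, final_loss: float) -> None:
+    with open(MARKER, 'w') as f:
+        json.dump({'final_epoch': final_epoch,
+                   'final_loss': final_loss,
+                   'timestamp': time.time()}, f)
+
+# Evaluator side: keep polling until the marker exists and no work remains.
+def evaluation_loop(evaluate_new_checkpoints) -> None:
+    while True:
+        evaluate_new_checkpoints()
+        if os.path.exists(MARKER):
+            # Drain checkpoints written just before the marker, then exit.
+            evaluate_new_checkpoints()
+            break
+        time.sleep(5)
+```
+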
+## Usage
+
+### Create the Shared Volume
+
+First, create the shared volume that both tasks will use:
+
+```bash
+sky volume apply llm/train-eval-jobgroup/train-eval-ckpts-volume.yaml
+```
+
+### Launch the Job Group
+
+```bash
+sky jobs launch llm/train-eval-jobgroup/train-eval-jobgroup.yaml
+```
+
+### Monitor Training
+
+```bash
+# Check job status
+sky jobs queue
+
+# View trainer logs (training progress)
+sky jobs logs --task trainer
+
+# View evaluator logs (accuracy reports)
+sky jobs logs --task evaluator
+```
+
+### Expected Output
+
+**Trainer logs:**
+```
+Starting trainer...
+Loading CIFAR-10 dataset...
+Epoch 1/10 | Loss: 1.8234 | LR: 0.099511 | Time: 45.2s
+Epoch 2/10 | Loss: 1.2456 | LR: 0.095106 | Time: 44.8s
+Saved checkpoint: /checkpoints/checkpoint_epoch_2.pt
+...
+```
+
+**Evaluator logs:**
+```
+Starting evaluator...
+Watching for checkpoints...
+Epoch 2 | Train Loss: 1.2456 | Test Accuracy: 52.34%
+Epoch 4 | Train Loss: 0.8123 | Test Accuracy: 68.91%
+Epoch 6 | Train Loss: 0.5234 | Test Accuracy: 75.23%
+...
+```
+
+## Configuration
+
+### Environment Variables
+
+| Variable | Default | Description |
+|----------|---------|-------------|
+| `NUM_EPOCHS` | `10` | Number of training epochs |
+| `SAVE_EVERY` | `2` | Save checkpoint every N epochs |
+| `CHECKPOINT_DIR` | `/checkpoints` | Shared checkpoint directory |
+
+### Customizing Resources
+
+Edit the YAML to adjust resources:
+
+```yaml
+resources:
+ accelerators: H100:1 # or A100:1 for faster training
+ memory: 32+
+```
+
+## How It Works
+
+### Shared Volume
+
+Both tasks mount the same SkyPilot volume at `/checkpoints`:
+
+```yaml
+volumes:
+ /checkpoints: train-eval-ckpts
+```
+
+This mounts a shared Kubernetes PVC that both tasks can access. The volume itself must be created with `sky volume apply` before the job group is launched.
+
+### Checkpoint Format
+
+The trainer saves checkpoints with:
+- Model state dict
+- Optimizer state dict
+- Epoch number
+- Training loss
+- Timestamp
+
+Files are named `checkpoint_epoch_N.pt` and a `latest.json` file tracks the most recent checkpoint.
+
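+For instance, resuming from the newest checkpoint looks like this (a sketch; it assumes `trainer.py` is importable from the same `code/` directory so the model architecture matches):
+
+```python
+import json
+import os
+
+import torch
+
+from trainer import get_model  # same code/ directory as this example
+
+ckpt_dir = '/checkpoints'
+with open(os.path.join(ckpt_dir, 'latest.json')) as f:
+    latest = json.load(f)  # keys: checkpoint, epoch, train_loss, timestamp
+
+model = get_model()
+ckpt = torch.load(os.path.join(ckpt_dir, latest['checkpoint']),
+                  map_location='cpu', weights_only=False)
+model.load_state_dict(ckpt['model_state_dict'])
+print(f"Resumed from epoch {ckpt['epoch']} (loss {ckpt['train_loss']:.4f})")
+```
+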
+### Evaluator Polling
+
+The evaluator uses simple filesystem polling to detect new checkpoints:
+1. Scans for `checkpoint_epoch_*.pt` files every 5 seconds
+2. Loads new checkpoints and evaluates on CIFAR-10 test set
+3. Reports accuracy and tracks results
+4. Exits when training completes
+
+## Key Features Demonstrated
+
+1. **Parallel Execution**: Training and evaluation run simultaneously
+2. **Shared Storage**: Tasks communicate through a shared filesystem
+3. **On-the-fly Evaluation**: No need to wait for training to finish
+4. **Simple Communication**: Filesystem-based, no network services needed
+
+## Comparison with RLHF Example
+
+| Feature | Train-Eval | RLHF |
+|---------|------------|------|
+| Communication | Shared filesystem | HTTP APIs |
+| Complexity | Simple | Complex |
+| Components | 2 tasks | 5 tasks |
+| Use case | Checkpointing | Service mesh |
+
+This example is intentionally simpler to demonstrate job groups without the complexity of network services.
+
+## Extending This Example
+
+### Adding More Evaluators
+
+You can run multiple evaluators for different metrics:
+
+```yaml
+---
+name: evaluator-accuracy
+# ... evaluates accuracy
+
+---
+name: evaluator-perplexity
+# ... evaluates perplexity
+```
+
+### Distributed Training
+
+Add `num_nodes` for multi-node training:
+
+```yaml
+name: trainer
+num_nodes: 2
+# ... use torch.distributed
+```
+
+### Early Stopping
+
+The evaluator could signal the trainer to stop early by writing a `stop.txt` file that the trainer checks.
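+
+A minimal sketch of that signal (`stop.txt` is a suggested convention, not something the current scripts implement):
+
+```python
+import os
+
+STOP_FILE = '/checkpoints/stop.txt'
+
+# Evaluator side: request a stop, e.g. when accuracy plateaus.
+def request_stop(reason: str) -> None:
+    with open(STOP_FILE, 'w') as f:
+        f.write(reason)
+
+# Trainer side: check at the top of each epoch and break out early.
+def should_stop() -> bool:
+    return os.path.exists(STOP_FILE)
+```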
diff --git a/llm/train-eval-jobgroup/code/evaluator.py b/llm/train-eval-jobgroup/code/evaluator.py
new file mode 100644
index 00000000000..732d8cb16da
--- /dev/null
+++ b/llm/train-eval-jobgroup/code/evaluator.py
@@ -0,0 +1,219 @@
+#!/usr/bin/env python3
+"""Evaluator script that watches for new checkpoints and evaluates them.
+
+This script monitors a checkpoint directory and evaluates new checkpoints
+as they appear, reporting accuracy on the CIFAR-10 test set.
+
+Usage:
+ python evaluator.py --checkpoint-dir /checkpoints
+"""
+
+import argparse
+import glob
+import os
+import time
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+
+
+def get_model():
+ """Create a ResNet-18 model for CIFAR-10."""
+ model = torchvision.models.resnet18(weights=None)
+ # Modify for CIFAR-10 (32x32 images, 10 classes)
+ model.conv1 = nn.Conv2d(3,
+ 64,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+ model.maxpool = nn.Identity()
+ model.fc = nn.Linear(model.fc.in_features, 10)
+ return model
+
+
+def get_test_dataloader(batch_size=128):
+ """Create test dataloader for CIFAR-10."""
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465),
+ (0.2023, 0.1994, 0.2010)),
+ ])
+
+ testset = torchvision.datasets.CIFAR10(root='./data',
+ train=False,
+ download=True,
+ transform=transform_test)
+ testloader = DataLoader(testset,
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=2)
+
+ return testloader
+
+
+def evaluate(model, testloader, device):
+ """Evaluate model on test set and return accuracy."""
+ model.eval()
+ correct = 0
+ total = 0
+
+ with torch.no_grad():
+ for inputs, targets in testloader:
+ inputs, targets = inputs.to(device), targets.to(device)
+ outputs = model(inputs)
+ _, predicted = outputs.max(1)
+ total += targets.size(0)
+ correct += predicted.eq(targets).sum().item()
+
+ accuracy = 100.0 * correct / total
+ return accuracy
+
+
+def get_checkpoint_files(checkpoint_dir):
+ """Get list of checkpoint files in directory."""
+ pattern = os.path.join(checkpoint_dir, 'checkpoint_epoch_*.pt')
+ return set(glob.glob(pattern))
+
+
+def load_checkpoint(checkpoint_path, model, device):
+ """Load checkpoint and return metadata."""
+ checkpoint = torch.load(checkpoint_path,
+ map_location=device,
+ weights_only=False)
+ model.load_state_dict(checkpoint['model_state_dict'])
+ return {
+ 'epoch': checkpoint['epoch'],
+ 'train_loss': checkpoint['train_loss'],
+ 'timestamp': checkpoint.get('timestamp', 0),
+ }
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description='Evaluate checkpoints as they appear')
+ parser.add_argument('--checkpoint-dir',
+ type=str,
+ required=True,
+ help='Directory to watch for checkpoints')
+ parser.add_argument('--poll-interval',
+ type=int,
+ default=5,
+ help='Seconds between polling for new checkpoints')
+ parser.add_argument('--batch-size',
+ type=int,
+ default=128,
+ help='Evaluation batch size')
+ args = parser.parse_args()
+
+ # Setup device
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+
+ # Create model
+ model = get_model().to(device)
+
+ # Get test data
+ print("Loading CIFAR-10 test dataset...")
+ testloader = get_test_dataloader(args.batch_size)
+ print(f"Test samples: {len(testloader.dataset)}")
+
+ print("\n" + "=" * 60)
+ print("Checkpoint Evaluator")
+ print("=" * 60)
+ print(f"Watching directory: {args.checkpoint_dir}")
+ print(f"Poll interval: {args.poll_interval} seconds")
+ print("=" * 60 + "\n")
+
+ # Track evaluated checkpoints
+ evaluated_checkpoints = set()
+ results = []
+
+ print("Waiting for checkpoints...")
+ print("-" * 60)
+
+ training_complete = False
+ complete_marker_path = os.path.join(args.checkpoint_dir,
+ 'training_complete')
+
+ while True:
+ # Get current checkpoint files
+ current_checkpoints = get_checkpoint_files(args.checkpoint_dir)
+
+ # Find new checkpoints
+ new_checkpoints = current_checkpoints - evaluated_checkpoints
+
+ if new_checkpoints:
+ # Sort by epoch number
+ sorted_checkpoints = sorted(
+ new_checkpoints,
+ key=lambda x: int(
+ os.path.basename(x).split('_')[-1].replace('.pt', '')))
+
+ for checkpoint_path in sorted_checkpoints:
+ try:
+ # Load and evaluate
+ metadata = load_checkpoint(checkpoint_path, model, device)
+ accuracy = evaluate(model, testloader, device)
+
+ result = {
+ 'checkpoint': os.path.basename(checkpoint_path),
+ 'epoch': metadata['epoch'],
+ 'train_loss': metadata['train_loss'],
+ 'test_accuracy': accuracy,
+ }
+ results.append(result)
+
+ print(f"Epoch {metadata['epoch']:3d} | "
+ f"Train Loss: {metadata['train_loss']:.4f} | "
+ f"Test Accuracy: {accuracy:.2f}%")
+
+ evaluated_checkpoints.add(checkpoint_path)
+
+ except Exception as e:
+ print(f"Error evaluating {checkpoint_path}: {e}")
+ # Don't mark as evaluated, will retry next poll
+ continue
+
+ # Check if training is complete (look for training_complete marker)
+ if os.path.exists(complete_marker_path):
+ if not training_complete:
+ print("\nDetected training completion marker.")
+ training_complete = True
+
+ # Evaluate any remaining checkpoints
+ remaining = get_checkpoint_files(
+ args.checkpoint_dir) - evaluated_checkpoints
+ if not remaining:
+ print("All checkpoints evaluated. Exiting.")
+ break
+
+ time.sleep(args.poll_interval)
+
+ # Final summary
+ print("\n" + "=" * 60)
+ print("Evaluation Complete!")
+ print("=" * 60)
+
+ if results:
+ print("\nResults Summary:")
+ print("-" * 60)
+ print(f"{'Epoch':>6} | {'Train Loss':>12} | {'Test Accuracy':>14}")
+ print("-" * 60)
+ for r in results:
+ print(f"{r['epoch']:>6} | {r['train_loss']:>12.4f} | "
+ f"{r['test_accuracy']:>13.2f}%")
+ print("-" * 60)
+
+ best = max(results, key=lambda x: x['test_accuracy'])
+ print(f"\nBest: Epoch {best['epoch']} with "
+ f"{best['test_accuracy']:.2f}% accuracy")
+
+ print("=" * 60)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/llm/train-eval-jobgroup/code/trainer.py b/llm/train-eval-jobgroup/code/trainer.py
new file mode 100644
index 00000000000..77381e18051
--- /dev/null
+++ b/llm/train-eval-jobgroup/code/trainer.py
@@ -0,0 +1,216 @@
+#!/usr/bin/env python3
+"""Trainer script for ResNet-18 on CIFAR-10.
+
+This script trains a ResNet-18 model on CIFAR-10 and saves checkpoints
+periodically to a shared directory that the evaluator can access.
+
+Usage:
+ python trainer.py --checkpoint-dir /checkpoints --num-epochs 10 --save-every 2
+"""
+
+import argparse
+import json
+import os
+import time
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+
+
+def get_model():
+ """Create a ResNet-18 model for CIFAR-10."""
+ model = torchvision.models.resnet18(weights=None)
+ # Modify for CIFAR-10 (32x32 images, 10 classes)
+ model.conv1 = nn.Conv2d(3,
+ 64,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+ model.maxpool = nn.Identity()
+ model.fc = nn.Linear(model.fc.in_features, 10)
+ return model
+
+
+def get_dataloaders(batch_size=128):
+ """Create training and test dataloaders for CIFAR-10."""
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465),
+ (0.2023, 0.1994, 0.2010)),
+ ])
+
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465),
+ (0.2023, 0.1994, 0.2010)),
+ ])
+
+ trainset = torchvision.datasets.CIFAR10(root='./data',
+ train=True,
+ download=True,
+ transform=transform_train)
+ trainloader = DataLoader(trainset,
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=2)
+
+ testset = torchvision.datasets.CIFAR10(root='./data',
+ train=False,
+ download=True,
+ transform=transform_test)
+ testloader = DataLoader(testset,
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=2)
+
+ return trainloader, testloader
+
+
+def save_checkpoint(model, optimizer, epoch, train_loss, checkpoint_dir):
+ """Save a training checkpoint."""
+ os.makedirs(checkpoint_dir, exist_ok=True)
+
+ checkpoint_path = os.path.join(checkpoint_dir,
+ f'checkpoint_epoch_{epoch}.pt')
+ checkpoint = {
+ 'epoch': epoch,
+ 'model_state_dict': model.state_dict(),
+ 'optimizer_state_dict': optimizer.state_dict(),
+ 'train_loss': train_loss,
+ 'timestamp': time.time(),
+ }
+ torch.save(checkpoint, checkpoint_path)
+ print(f"Saved checkpoint: {checkpoint_path}")
+
+ # Update latest.json to point to this checkpoint
+ latest_path = os.path.join(checkpoint_dir, 'latest.json')
+ with open(latest_path, 'w') as f:
+ json.dump(
+ {
+ 'checkpoint': f'checkpoint_epoch_{epoch}.pt',
+ 'epoch': epoch,
+ 'train_loss': train_loss,
+ 'timestamp': time.time(),
+ },
+ f,
+ indent=2)
+
+
+def train_epoch(model, trainloader, criterion, optimizer, device):
+ """Train for one epoch and return average loss."""
+ model.train()
+ running_loss = 0.0
+ total_batches = 0
+
+ for inputs, targets in trainloader:
+ inputs, targets = inputs.to(device), targets.to(device)
+
+ optimizer.zero_grad()
+ outputs = model(inputs)
+ loss = criterion(outputs, targets)
+ loss.backward()
+ optimizer.step()
+
+ running_loss += loss.item()
+ total_batches += 1
+
+ return running_loss / total_batches
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Train ResNet-18 on CIFAR-10')
+ parser.add_argument('--checkpoint-dir',
+ type=str,
+ required=True,
+ help='Directory to save checkpoints')
+ parser.add_argument('--num-epochs',
+ type=int,
+ default=10,
+ help='Number of training epochs')
+ parser.add_argument('--save-every',
+ type=int,
+ default=2,
+ help='Save checkpoint every N epochs')
+ parser.add_argument('--batch-size',
+ type=int,
+ default=128,
+ help='Training batch size')
+ parser.add_argument('--learning-rate',
+ type=float,
+ default=0.1,
+ help='Initial learning rate')
+ args = parser.parse_args()
+
+ # Setup device
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+
+ # Create model, criterion, optimizer
+ model = get_model().to(device)
+ criterion = nn.CrossEntropyLoss()
+ optimizer = optim.SGD(model.parameters(),
+ lr=args.learning_rate,
+ momentum=0.9,
+ weight_decay=5e-4)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
+ T_max=args.num_epochs)
+
+ # Get data
+ print("Loading CIFAR-10 dataset...")
+ trainloader, _ = get_dataloaders(args.batch_size)
+ print(f"Training samples: {len(trainloader.dataset)}")
+
+ # Training loop
+ print("\n" + "=" * 60)
+ print("Starting Training")
+ print("=" * 60)
+ print(f"Epochs: {args.num_epochs}")
+ print(f"Batch size: {args.batch_size}")
+ print(f"Checkpoint directory: {args.checkpoint_dir}")
+ print(f"Saving every {args.save_every} epochs")
+ print("=" * 60 + "\n")
+
+ for epoch in range(1, args.num_epochs + 1):
+ start_time = time.time()
+ train_loss = train_epoch(model, trainloader, criterion, optimizer,
+ device)
+ scheduler.step()
+ epoch_time = time.time() - start_time
+
+ print(f"Epoch {epoch}/{args.num_epochs} | "
+ f"Loss: {train_loss:.4f} | "
+ f"LR: {scheduler.get_last_lr()[0]:.6f} | "
+ f"Time: {epoch_time:.1f}s")
+
+ # Save checkpoint
+ if epoch % args.save_every == 0 or epoch == args.num_epochs:
+ save_checkpoint(model, optimizer, epoch, train_loss,
+ args.checkpoint_dir)
+
+ # Write training complete marker for evaluator
+ complete_marker = os.path.join(args.checkpoint_dir, 'training_complete')
+ with open(complete_marker, 'w') as f:
+ json.dump(
+ {
+ 'final_epoch': args.num_epochs,
+ 'final_loss': train_loss,
+ 'timestamp': time.time(),
+ },
+ f,
+ indent=2)
+ print(f"Wrote completion marker: {complete_marker}")
+
+ print("\n" + "=" * 60)
+ print("Training Complete!")
+ print("=" * 60)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/llm/train-eval-jobgroup/train-eval-ckpts-volume.yaml b/llm/train-eval-jobgroup/train-eval-ckpts-volume.yaml
new file mode 100644
index 00000000000..c72fc577fdc
--- /dev/null
+++ b/llm/train-eval-jobgroup/train-eval-ckpts-volume.yaml
@@ -0,0 +1,16 @@
+# Volume definition for train-eval-jobgroup shared checkpoint storage
+#
+# This volume is used by both the trainer and evaluator tasks to share
+# checkpoint files. Create this volume before launching the job group:
+#
+# sky volume apply llm/train-eval-jobgroup/train-eval-ckpts-volume.yaml
+#
+# Then launch the job group:
+# sky jobs launch llm/train-eval-jobgroup/train-eval-jobgroup.yaml
+
+name: train-eval-ckpts
+type: k8s-pvc
+size: 10Gi
+infra: kubernetes
+config:
+ access_mode: ReadWriteMany
diff --git a/llm/train-eval-jobgroup/train-eval-jobgroup.yaml b/llm/train-eval-jobgroup/train-eval-jobgroup.yaml
new file mode 100644
index 00000000000..861f1ed1459
--- /dev/null
+++ b/llm/train-eval-jobgroup/train-eval-jobgroup.yaml
@@ -0,0 +1,90 @@
+# Parallel Training and Evaluation with Shared Volume
+#
+# This example demonstrates a job group with parallel training and evaluation
+# tasks that share a Kubernetes volume for checkpoints. The evaluator monitors
+# the checkpoint directory and evaluates models as training produces them.
+#
+# Architecture:
+# - trainer: Trains ResNet-18 on CIFAR-10, saves checkpoints to shared volume
+# - evaluator: Watches for checkpoints, evaluates and reports accuracy
+#
+# Completion Behavior:
+# When training completes, the trainer writes a "training_complete" marker
+# file to the shared volume. The evaluator detects this marker, finishes
+# evaluating any remaining checkpoints, and exits gracefully. Both tasks
+# complete naturally without forced termination.
+#
+# Usage:
+# # First, create the shared volume:
+# sky volume apply llm/train-eval-jobgroup/train-eval-ckpts-volume.yaml
+#
+# # Then launch the job group:
+# sky jobs launch llm/train-eval-jobgroup/train-eval-jobgroup.yaml
+#
+# The components share storage via a Kubernetes PVC:
+# /checkpoints - Shared volume for checkpoint files
+---
+name: train-eval
+execution: parallel
+
+---
+# Trainer: Trains ResNet-18 on CIFAR-10 and saves checkpoints
+name: trainer
+resources:
+ accelerators: H100:1
+ memory: 16+
+ infra: kubernetes
+
+file_mounts:
+ /code: llm/train-eval-jobgroup/code
+
+volumes:
+ /checkpoints: train-eval-ckpts
+
+envs:
+ CHECKPOINT_DIR: /checkpoints
+ NUM_EPOCHS: 10
+ SAVE_EVERY: 2
+
+setup: |
+ uv pip install torch torchvision --system
+
+run: |
+ echo "Starting trainer..."
+ echo "JobGroup: ${SKYPILOT_JOBGROUP_NAME}"
+ echo "Checkpoints will be saved to ${CHECKPOINT_DIR}"
+
+ cd /code
+ python trainer.py \
+ --checkpoint-dir ${CHECKPOINT_DIR} \
+ --num-epochs ${NUM_EPOCHS} \
+ --save-every ${SAVE_EVERY}
+
+---
+# Evaluator: Watches for checkpoints and evaluates them
+name: evaluator
+resources:
+ accelerators: H100:1
+ memory: 16+
+ infra: kubernetes
+
+file_mounts:
+ /code: llm/train-eval-jobgroup/code
+
+volumes:
+ /checkpoints: train-eval-ckpts
+
+envs:
+ CHECKPOINT_DIR: /checkpoints
+
+setup: |
+ uv pip install torch torchvision --system
+
+run: |
+ echo "Starting evaluator..."
+ echo "JobGroup: ${SKYPILOT_JOBGROUP_NAME}"
+ echo "Watching for checkpoints in ${CHECKPOINT_DIR}"
+
+ cd /code
+ python evaluator.py \
+ --checkpoint-dir ${CHECKPOINT_DIR}
diff --git a/llm/verl/README.md b/llm/verl/README.md
index 0d2ec660b97..e227368b038 100644
--- a/llm/verl/README.md
+++ b/llm/verl/README.md
@@ -3,6 +3,8 @@
[Verl](https://github.com/volcengine/verl) is the most popular open-source reinforcement learning framework for LLMs, supporting PPO, GRPO, and other algorithms.
+Also see [`search-tooling/`](https://github.com/skypilot-org/skypilot/tree/master/llm/verl/search-tooling) and this [blog](https://blog.skypilot.co/verl-tool-calling/) for tool-augmented “search” workflows (Search-R1 style), including Google Search–backed inference and a Wikipedia FAISS retrieval service used for inference and training.
+
## Why SkyPilot + Verl?
SkyPilot makes RL training **easy and cost-effective**:
@@ -47,81 +49,8 @@ sky status --endpoint 8280 verl
Ray dashboard showing real-time monitoring of distributed training across multiple nodes
-## Key Features
-
-The example trains Qwen2.5-0.5B-Instruct on the GSM8K dataset using PPO:
-- **Multi-node distributed training** with automatic Ray cluster setup
-- **Checkpoint persistence** to cloud storage for fault tolerance
-- **Customizable models and datasets** via environment variables
-
-## Optional: Enable W&B for Training Visualization
-
-To track training curves and metrics in Weights & Biases:
-```bash
-# 1. Set your W&B API key locally
-export WANDB_API_KEY=your-api-key
-
-# 2. Launch with the secret flag
-sky launch -c verl llm/verl/multinode.yaml --secret WANDB_API_KEY
-
-# 3. Edit multinode.yaml to enable W&B logger (see comments in the file)
-```
-
-## Advanced Usage
-
-### 💰 Use Spot Instances for 3x Cost Savings
-
-```bash
-sky jobs launch -n verl-job llm/verl/multinode.yaml
-```
-Training automatically resumes from checkpoints if preempted.
-
-### 🚀 Continue Experiments on the Same Cluster
-
-```bash
-# Run additional training epochs
-sky exec verl llm/verl/multinode.yaml --env TOTAL_EPOCHS=10
-
-# The YAML automatically detects and reuses the existing Ray cluster
-```
-
-### 📈 Scale to More Nodes
-
-```bash
-sky launch -c verl llm/verl/multinode.yaml --num-nodes 4
-```
-
-### 🔧 Customize Training Configuration
-
-Modify parameters directly:
-```bash
-sky launch -c verl llm/verl/multinode.yaml \
- --env MODEL_NAME=meta-llama/Llama-2-7b-hf \
- --env ACTOR_LR=5e-6 \
- --env CRITIC_LR=1e-5
-```
-
-Train a larger model:
-```bash
-sky launch -c verl llm/verl/multinode.yaml \
- --env MODEL_NAME=Qwen/Qwen2.5-7B-Instruct \
- --gpus A100-80GB:8 --num-nodes 4
-```
-
-## Understanding the Setup
-
-1. **Head node**: Prepares data, starts Ray head, submits training job
-2. **Worker nodes**: Join Ray cluster for distributed training
-3. **Smart resumption**: Ray cluster is reused if already running, avoiding restart overhead
-
-## Troubleshooting
-
-- **OOM errors**: Reduce batch sizes or `gpu_memory_utilization`
-- **Connection issues**: Ensure ports 6385 (Ray) and 8280 (dashboard) are not blocked
-- **First run is slow**: Model download happens once, subsequent runs are faster
-
## Learn More
- [Verl Documentation](https://verl.readthedocs.io/)
- [Verl GitHub Repository](https://github.com/volcengine/verl)
-- [SkyPilot Ray Setup Guide](https://docs.skypilot.co/en/latest/running-jobs/distributed-jobs.html#executing-a-distributed-ray-program)
\ No newline at end of file
+- [SkyPilot Ray Setup Guide](https://docs.skypilot.co/en/latest/running-jobs/distributed-jobs.html#executing-a-distributed-ray-program)
diff --git a/llm/verl/search-tooling/README.md b/llm/verl/search-tooling/README.md
new file mode 100644
index 00000000000..41f8b322e0e
--- /dev/null
+++ b/llm/verl/search-tooling/README.md
@@ -0,0 +1,37 @@
+# Search tooling for VERL
+
+This folder contains SkyPilot YAMLs for training and inference with tool-augmented “search” workflows (Search-R1 style), using either:
+- a **Google Search** backend, or
+- a **Wikipedia retrieval service** (FAISS index).
+
+See this [blog](https://blog.skypilot.co/verl-tool-calling/) for how these YAMLs are used to train an RL agent that can use Google Search.
+
+## Inference (Google Search backend)
+
+```bash
+sky launch -c verl-infer-google llm/verl/search-tooling/verl-search-interaction-google-search.yaml \
+ --env MODEL_PATH=/checkpoints/hf_model \
+ --env GOOGLE_API_KEY=your_key_here \
+ --env GOOGLE_CSE_ID=your_cse_id_here \
+ -y
+```
+
+## Inference (local Wikipedia retrieval on the same node)
+
+```bash
+sky launch -c verl-infer llm/verl/search-tooling/verl-search-interaction-infer.yaml \
+ --env MODEL_PATH=/checkpoints/hf_model \
+ -y
+```
+
+## Retrieval service (CPU-only, for reuse across jobs)
+
+```bash
+sky serve up -n retrieval llm/verl/search-tooling/verl-search-interaction-retrieval.yaml --cpus 32+ --memory 256+ -y
+sky serve status retrieval --endpoint 8000
+```
+
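+Once the endpoint is up, retrieval is a plain JSON POST to `/retrieve`. A sketch in Python (the request payload mirrors the connectivity check in `verl-search-interaction-rl-trainer.yaml`; the exact response schema is whatever the VERL retrieval server returns):
+
+```python
+import requests
+
+# Replace with the endpoint printed by `sky serve status retrieval --endpoint 8000`.
+endpoint = 'http://RETRIEVAL_IP:8000'
+
+resp = requests.post(f'{endpoint}/retrieve',
+                     json={'queries': ['who wrote the iliad'],
+                           'topk': 3,
+                           'return_scores': False})
+resp.raise_for_status()
+print(resp.json())
+```
+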
+## Training
+
+- Single-node training with retrieval running on the same node: `llm/verl/search-tooling/verl-search-interaction.yaml`
+- Training that points to an external retrieval service: `llm/verl/search-tooling/verl-search-interaction-rl-trainer.yaml`
diff --git a/llm/verl/search-tooling/verl-search-interaction-google-search.yaml b/llm/verl/search-tooling/verl-search-interaction-google-search.yaml
new file mode 100644
index 00000000000..e51cc6db2eb
--- /dev/null
+++ b/llm/verl/search-tooling/verl-search-interaction-google-search.yaml
@@ -0,0 +1,140 @@
+# Search Tool Interaction Inference (Google Search backend)
+#
+# This example demonstrates inference using Search-R1 with a search/retrieval tool.
+# The model uses a Google Search–backed tool for answering questions that require external knowledge.
+# Both the Google search server and inference run on the same node.
+#
+# Usage:
+# sky launch -c verl-infer-google llm/verl/search-tooling/verl-search-interaction-google-search.yaml \
+# --env MODEL_PATH=/checkpoints/hf_model \
+# --env GOOGLE_API_KEY=your_key_here \
+# --env GOOGLE_CSE_ID=your_cse_id_here \
+# -y
+#
+# Requirements:
+# - Single GPU for inference
+# - Valid Google Programmable Search Engine (CSE) + API key
+
+resources:
+ accelerators: H100:1
+ memory: 128+
+ ports:
+ - 8000 # Google search server
+
+num_nodes: 1
+
+envs:
+ MODEL_PATH: "" # Optional: Path to model checkpoint (defaults to base model)
+ GOOGLE_API_KEY: "" # Required: Google API key
+ GOOGLE_CSE_ID: "" # Required: Google Custom Search Engine ID
+ CHECKPOINT_BUCKET_NAME: verl-search-interaction-checkpoints
+
+file_mounts:
+ /checkpoints:
+ name: ${CHECKPOINT_BUCKET_NAME}
+ mode: MOUNT
+
+setup: |
+ set -e
+
+ echo "=== Search Tool Inference Setup (Google Search) ==="
+
+ # System dependencies
+ echo "Installing system dependencies..."
+ sudo apt update && sudo apt install -y iproute2 git
+
+ # Python environment
+ echo "Setting up Python virtual environment..."
+ uv venv --python 3.10 --seed
+ source .venv/bin/activate
+
+ echo "Installing PyTorch..."
+ uv pip install "torch==2.8.*" torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
+
+ # Clone VERL repository (if infer.py relies on its code / configs)
+ echo "Cloning VERL repository..."
+ rm -rf verl
+ git clone https://github.com/volcengine/verl.git
+ cd verl
+ git checkout v0.6.0
+
+ echo "Installing VERL + SGLang dependencies..."
+ uv pip install -v -e .
+ uv pip install wheel
+ uv pip install packaging
+ uv pip install -r ./requirements_sglang.txt
+ uv pip install "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
+
+ cd ..
+
+ # Clone Search-R1 for inference
+ echo "Cloning Search-R1 repository..."
+ rm -rf Search-R1
+ git clone https://github.com/PeterGriffinJin/Search-R1.git
+
+ # Install additional inference dependencies
+ cd Search-R1
+ if [ -f requirements.txt ]; then
+ echo "Installing Search-R1 requirements..."
+ uv pip install -r requirements.txt
+ fi
+
+ # Ensure Google API client is available (if not already pulled in)
+ uv pip install google-api-python-client
+
+ cd ..
+
+ echo "✓ Inference setup complete!"
+
+run: |
+ set -e
+
+ echo "=== Search Tool Inference (Google Search backend) ==="
+
+ # Activate environment
+ source .venv/bin/activate
+
+ # Sanity check env vars
+ if [ -z "$GOOGLE_API_KEY" ] || [ -z "$GOOGLE_CSE_ID" ]; then
+ echo "ERROR: GOOGLE_API_KEY and GOOGLE_CSE_ID must be set via --env."
+ exit 1
+ fi
+
+ echo "Using GOOGLE_API_KEY: (set)"
+ echo "Using GOOGLE_CSE_ID: (set)"
+
+ # Start Google search server in background
+ cd ~/sky_workdir/Search-R1
+ echo "Starting Google search server on port 8000..."
+ python search_r1/search/google_search_server.py \
+ --api_key "$GOOGLE_API_KEY" \
+ --cse_id "$GOOGLE_CSE_ID" \
+ > google_search_server.log 2>&1 &
+
+ RETRIEVAL_PID=$!
+ echo "Google search server PID: $RETRIEVAL_PID"
+
+ # Give the server a moment to start
+ sleep 10
+
+ # (Optional) basic health check if the server exposes one
+ # curl -f http://127.0.0.1:8000/health || echo "Healthcheck failed (continuing anyway)"
+
+ # Run inference
+ echo "Running infer.py..."
+ if [ -n "$MODEL_PATH" ]; then
+ # If your infer.py supports a flag, use it; otherwise it may read MODEL_PATH from env.
+ python infer.py --model_path "$MODEL_PATH" || python infer.py
+ else
+ python infer.py
+ fi
+
+ echo "✓ Inference finished"
+
+ # Clean up search server (SkyPilot will tear down the node afterwards anyway)
+ if ps -p $RETRIEVAL_PID > /dev/null 2>&1; then
+ echo "Stopping Google search server..."
+ kill $RETRIEVAL_PID || true
+ fi
+
+ echo "=== Done ==="
diff --git a/llm/verl/search-tooling/verl-search-interaction-infer.yaml b/llm/verl/search-tooling/verl-search-interaction-infer.yaml
new file mode 100644
index 00000000000..cdc88566e60
--- /dev/null
+++ b/llm/verl/search-tooling/verl-search-interaction-infer.yaml
@@ -0,0 +1,122 @@
+# Search Tool Interaction Inference
+#
+# This example demonstrates inference using Search-R1 with a search/retrieval tool.
+# The model uses a search tool for answering questions that require external knowledge.
+# Both retrieval service and inference run on the same node.
+#
+# Usage:
+# sky launch -c verl-infer llm/verl/search-tooling/verl-search-interaction-infer.yaml --env MODEL_PATH=/checkpoints/hf_model -y
+#
+# Requirements:
+# - Single GPU for inference
+# - Sufficient memory for retrieval index
+
+resources:
+ accelerators: H100:1
+ memory: 128+
+ ports:
+ - 8000 # Retrieval service
+
+num_nodes: 1
+
+envs:
+ MODEL_PATH: "" # Optional: Path to model checkpoint (defaults to base model)
+ RETRIEVAL_TOPK: 3
+ RETRIEVER_NAME: e5
+ RETRIEVER_MODEL: intfloat/e5-base-v2
+ CHECKPOINT_BUCKET_NAME: verl-search-interaction-checkpoints
+
+file_mounts:
+ /checkpoints:
+ name: ${CHECKPOINT_BUCKET_NAME}
+ mode: MOUNT
+
+setup: |
+ set -e
+
+ echo "=== Search Tool Inference Setup ==="
+
+ # System dependencies
+ echo "Installing system dependencies..."
+ sudo apt update && sudo apt install -y iproute2
+
+ # Python environment
+ echo "Setting up Python virtual environment..."
+ uv venv --python 3.10 --seed
+ source .venv/bin/activate
+
+ # Clone VERL repository first: the editable install and
+ # requirements_sglang.txt below live inside it
+ echo "Cloning VERL repository..."
+ rm -rf verl
+ git clone https://github.com/volcengine/verl.git
+ cd verl
+ git checkout v0.6.0
+
+ # Install dependencies
+ echo "Installing PyTorch and dependencies..."
+ uv pip install "torch==2.8.*" torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
+ uv pip install -v -e .
+ uv pip install wheel
+ uv pip install packaging
+ uv pip install -r ./requirements_sglang.txt
+ uv pip install "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
+ cd ..
+
+ # Download Wikipedia corpus and FAISS index
+ echo "Downloading Wikipedia corpus and FAISS index..."
+ export save_path=~/dataset
+ mkdir -p $save_path
+
+ huggingface-cli download maknee/wiki-18-subsets wiki-18-100k.jsonl.gz --repo-type=dataset --local-dir $save_path
+ huggingface-cli download maknee/wiki-18-subsets e5_Flat-100k.index --repo-type=dataset --local-dir $save_path
+
+ # Move files to expected locations
+ mv $save_path/wiki-18-100k.jsonl.gz $save_path/wiki-18.jsonl.gz
+ mv $save_path/e5_Flat-100k.index $save_path/e5_Flat.index
+
+ # Decompress the JSONL file
+ gzip -d $save_path/wiki-18.jsonl.gz -f
+
+ # Clone Search-R1 for inference
+ echo "Cloning Search-R1 repository..."
+ rm -rf Search-R1
+ git clone https://github.com/PeterGriffinJin/Search-R1/
+
+ # Install additional inference dependencies if needed
+ cd Search-R1
+ if [ -f requirements.txt ]; then
+ uv pip install -r requirements.txt
+ fi
+ cd ..
+
+ echo "✓ Inference setup complete!"
+
+run: |
+ set -e
+
+ echo "=== Search Tool Inference ==="
+
+ # Activate environment
+ source .venv/bin/activate
+
+ # Set up paths
+ save_path=~/dataset
+ index_file=$save_path/e5_Flat.index
+ corpus_file=$save_path/wiki-18.jsonl
+
+ # Start retrieval server in background
+ echo "Starting retrieval server on port 8000..."
+ cd verl
+ python examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py \
+ --index_path $index_file \
+ --corpus_path $corpus_file \
+ --topk $RETRIEVAL_TOPK \
+ --retriever_name $RETRIEVER_NAME \
+ --retriever_model $RETRIEVER_MODEL &
+
+ RETRIEVAL_PID=$!
+ sleep 10
+
+ # Run inference
+ cd ~/sky_workdir/Search-R1
+ python infer.py
diff --git a/llm/verl/search-tooling/verl-search-interaction-retrieval.yaml b/llm/verl/search-tooling/verl-search-interaction-retrieval.yaml
new file mode 100644
index 00000000000..dce01c9a268
--- /dev/null
+++ b/llm/verl/search-tooling/verl-search-interaction-retrieval.yaml
@@ -0,0 +1,112 @@
+# Search Tool Retrieval Service
+#
+# This service provides Wikipedia retrieval capabilities using FAISS indexing.
+# It runs on CPU nodes and exposes a retrieval API on port 8000.
+#
+# Usage:
+# sky launch -c retrieval llm/verl/search-tooling/verl-search-interaction-retrieval.yaml --cpus 32+ --memory 256+ -y
+#
+# Get endpoint:
+# sky status retrieval --endpoint 8000
+#
+# OR with sky serve
+# sky serve up -n retrieval llm/verl/search-tooling/verl-search-interaction-retrieval.yaml --cpus 32+ --memory 256+ -y
+#
+# Get endpoint:
+# sky serve status retrieval --endpoint 8000
+
+service:
+ readiness_probe: /
+ replicas: 3
+
+resources:
+ cpus: 32+
+ memory: 256+
+ use_spot: false
+ ports:
+ - 8000 # Retrieval service API
+
+num_nodes: 1
+
+envs:
+ RETRIEVAL_TOPK: 3
+ RETRIEVER_NAME: e5
+ RETRIEVER_MODEL: intfloat/e5-base-v2
+
+setup: |
+ set -e
+
+ echo "=== Retrieval Service Setup ==="
+
+ # System dependencies
+ echo "Installing system dependencies..."
+ sudo apt update && sudo apt install -y iproute2
+
+ # Python environment
+ echo "Setting up Python virtual environment..."
+ uv venv --python 3.10 --seed
+ source .venv/bin/activate
+
+ # Install retrieval service dependencies
+ echo "Installing retrieval service dependencies..."
+ uv pip install "torch==2.8.*" torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+ uv pip install transformers datasets huggingface_hub
+ uv pip install faiss-cpu
+ uv pip install uvicorn fastapi uvloop==0.21.0
+
+ # Download Wikipedia corpus and FAISS index
+ echo "Downloading Wikipedia corpus and FAISS index..."
+ export save_path=~/dataset
+ mkdir -p $save_path
+
+ huggingface-cli download maknee/wiki-18-subsets wiki-18-100k.jsonl.gz --repo-type=dataset --local-dir $save_path
+ huggingface-cli download maknee/wiki-18-subsets e5_Flat-100k.index --repo-type=dataset --local-dir $save_path
+
+ # Move files to expected locations
+ mv $save_path/wiki-18-100k.jsonl.gz $save_path/wiki-18.jsonl.gz
+ mv $save_path/e5_Flat-100k.index $save_path/e5_Flat.index
+
+ # Decompress the JSONL file
+ gzip -d $save_path/wiki-18.jsonl.gz -f
+
+ # Clone VERL repository for retrieval server code
+ echo "Cloning repositories..."
+ git clone https://github.com/volcengine/verl.git
+ cd verl
+ git checkout v0.6.0
+
+ # Patch retrieval server for CPU-only usage (comment out CUDA calls)
+ echo "Patching retrieval server for CPU-only usage..."
+ sed -i 's/^\(\s*\)\(model\.cuda()\)/\1# \2 # Commented out for CPU-only deployment/' \
+ examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py
+ sed -i 's/^\(\s*\)\(inputs = {k: v\.cuda() for k, v in inputs\.items()}\)/\1# \2 # Commented out for CPU-only deployment/' \
+ examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py
+
+ cd ..
+
+ echo "✓ Retrieval service setup complete!"
+
+run: |
+ set -e
+
+ echo "=== Starting Retrieval Service ==="
+
+ # Activate environment
+ source .venv/bin/activate
+
+ # Set up paths
+ save_path=~/dataset
+ index_file=$save_path/e5_Flat.index
+ corpus_file=$save_path/wiki-18.jsonl
+
+ # Start retrieval server
+ echo "Starting retrieval server on port 8000..."
+ cd verl
+ python examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py \
+ --index_path $index_file \
+ --corpus_path $corpus_file \
+ --topk $RETRIEVAL_TOPK \
+ --retriever_name $RETRIEVER_NAME \
+ --retriever_model $RETRIEVER_MODEL &
+
+ echo "✓ Retrieval service running on port 8000"
diff --git a/llm/verl/search-tooling/verl-search-interaction-rl-trainer.yaml b/llm/verl/search-tooling/verl-search-interaction-rl-trainer.yaml
new file mode 100644
index 00000000000..574976b66e7
--- /dev/null
+++ b/llm/verl/search-tooling/verl-search-interaction-rl-trainer.yaml
@@ -0,0 +1,314 @@
+# Search Tool Interaction Training with VERL (RL Trainer)
+#
+# This example demonstrates multi-turn tool interaction training using VERL with a search/retrieval tool.
+# The model learns to use a search tool for answering questions that require external knowledge.
+#
+# Requires a separate retrieval service running (see verl-search-interaction-retrieval.yaml)
+#
+# Based on: https://verl.readthedocs.io/en/v0.5.x/sglang_multiturn/search_tool_example.html
+#
+# Usage:
+# # 1. Launch retrieval service first
+# sky launch -c retrieval llm/verl/search-tooling/verl-search-interaction-retrieval.yaml --cpus 32+ --memory 256+ -y
+#
+# # 2. Get retrieval service endpoint
+# RETRIEVAL_IP=$(sky status retrieval --endpoint 8000)
+#
+# # 3. Launch training (without WandB)
+# sky launch -c verl-train llm/verl/search-tooling/verl-search-interaction-rl-trainer.yaml --env RETRIEVAL_SERVICE_URL=http://$RETRIEVAL_IP --env DATASET_SIZE=small --env TOTAL_EPOCHS=1 -y
+#
+# # Or with WandB logging (optional)
+# sky launch -c verl-train llm/verl/search-tooling/verl-search-interaction-rl-trainer.yaml --env RETRIEVAL_SERVICE_URL=http://$RETRIEVAL_IP --env DATASET_SIZE=small --env TOTAL_EPOCHS=1 --secret WANDB_API_KEY -y
+#
+# Requirements:
+# - Docker with SYS_PTRACE capability (for PyTorch multiprocessing CUDA tensor sharing)
+# - H100 GPUs (can be adjusted for other accelerators)
+# - Running retrieval service at RETRIEVAL_SERVICE_URL
+
+resources:
+ accelerators: H100:1
+ memory: 128+
+ image_id: docker:verlai/verl:app-verl0.6-transformers4.56.1-sglang0.5.2-mcore0.13.0-te2.2
+ ports:
+ - 8265 # Ray dashboard
+ - 9090 # Model serving
+
+num_nodes: 1
+
+config:
+ docker:
+ run_options:
+ - --cap-add=SYS_PTRACE # Required for PyTorch CUDA tensor sharing between Ray workers
+ - --ipc=host
+ - --shm-size=16g
+
+envs:
+ RETRIEVAL_SERVICE_URL: "" # Required: URL of the retrieval service (e.g., http://retrieval-ip:8000)
+ DATASET_SIZE: small # Options: small (1000 train, 200 test), medium (10k train, 2k test), full
+ TOTAL_EPOCHS: 1
+ TOTAL_STEPS: 10
+ TRAIN_BATCH_SIZE: 512
+ VAL_BATCH_SIZE: 256
+ SAVE_FREQ: 5 # Save checkpoints every 5 steps
+ TEST_FREQ: 5 # Test every 5 steps
+ MODEL_NAME: Qwen/Qwen2.5-3B-Instruct
+ WANDB_PROJECT_NAME: search_r1_like_async_rl
+ WANDB_EXPERIMENT_NAME: qwen2.5-3b-it_rm-searchR1-like-sgl-multiturn
+ CHECKPOINT_BUCKET_NAME: nebius://verl-search-interaction-checkpoints
+
+file_mounts:
+ /checkpoints:
+ source: ${CHECKPOINT_BUCKET_NAME}
+ mode: MOUNT_CACHED
+
+secrets:
+ WANDB_API_KEY: "" # Optional: Set to enable WandB logging. If not set, only console logging will be used.
+
+setup: |
+ rm -f ~/.pip/pip.conf
+ rm -f ~/.config/pip/pip.conf
+
+ set -e
+
+ echo "=== VERL Search Tool Interaction Training Setup ==="
+
+ # Validate required environment variables
+ if [ -z "$RETRIEVAL_SERVICE_URL" ]; then
+ echo "ERROR: RETRIEVAL_SERVICE_URL environment variable is required"
+ echo "Example: --env RETRIEVAL_SERVICE_URL=http://retrieval-ip:8000"
+ exit 1
+ fi
+
+ # Python environment
+ echo "Setting up Python virtual environment..."
+ uv venv --python 3.10 --seed
+ source .venv/bin/activate
+
+ # Clone VERL repository
+ echo "Cloning VERL repository..."
+ rm -rf verl
+ git clone https://github.com/volcengine/verl.git
+ cd verl
+ git checkout v0.6.0
+
+ # Core dependencies
+ echo "Installing PyTorch and VERL..."
+ uv pip install "torch==2.8.*" torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
+ uv pip install -v -e .
+ uv pip install "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
+ uv pip install wheel
+ uv pip install packaging
+ uv pip install -r ./requirements_sglang.txt
+
+ # Install uvloop (required version)
+ uv pip install uvloop==0.21.0
+
+ # Data preparation
+ echo "Preparing search R1 dataset..."
+ python3 examples/data_preprocess/preprocess_search_r1_dataset.py
+
+ # Clone Search-R1 for additional utilities
+ git clone https://github.com/PeterGriffinJin/Search-R1/
+
+ # Update tool config to use external retrieval service
+ echo "Configuring external retrieval service..."
+ TOOL_CONFIG="examples/sglang_multiturn/config/tool_config/search_tool_config.yaml"
+
+ # Backup original config
+ cp $TOOL_CONFIG ${TOOL_CONFIG}.bak
+
+ # Update retrieval URL and num_workers in the config
+ sed -i 's/num_workers: *120/num_workers: 8/' $TOOL_CONFIG
+ sed -i "s|http://127\.0\.0\.1:8000/retrieve|$RETRIEVAL_SERVICE_URL/retrieve|g" $TOOL_CONFIG
+ sed -i "s|http://localhost:8000|$RETRIEVAL_SERVICE_URL|g" $TOOL_CONFIG
+
+ echo "✓ Setup complete!"
+ echo "Dataset location: ~/data/searchR1_processed_direct/"
+ echo "VERL repository: $(pwd)"
+ echo "Retrieval service: $RETRIEVAL_SERVICE_URL"
+
+run: |
+ set -e
+
+ echo "=== VERL Search Tool Interaction Training ==="
+ sudo apt update && sudo apt install -y iproute2
+
+ # Validate retrieval service
+ if [ -z "$RETRIEVAL_SERVICE_URL" ]; then
+ echo "ERROR: RETRIEVAL_SERVICE_URL environment variable is required"
+ exit 1
+ fi
+
+ echo "Testing connection to retrieval service at $RETRIEVAL_SERVICE_URL..."
+ # Give it a few retries in case the service is still starting
+ max_retries=30
+ retry_count=0
+ while [ $retry_count -lt $max_retries ]; do
+ # Test the /retrieve endpoint with a sample query
+ # '|| true' keeps 'set -e' from aborting while the service is still starting
+ test_response=$(curl -s -X POST "${RETRIEVAL_SERVICE_URL}/retrieve" \
+ -H "Content-Type: application/json" \
+ -d '{"queries": ["test query"], "topk": 1, "return_scores": false}' \
+ -w "\n%{http_code}" 2>&1) || true
+
+ http_code=$(echo "$test_response" | tail -n1)
+
+ if [ "$http_code" = "200" ]; then
+ echo "✓ Successfully connected to retrieval service"
+ echo "✓ /retrieve endpoint is responding correctly"
+ break
+ fi
+ retry_count=$((retry_count+1))
+ if [ $retry_count -eq $max_retries ]; then
+ echo "WARNING: Could not connect to retrieval service at $RETRIEVAL_SERVICE_URL"
+ echo "Make sure the retrieval service is running and accessible"
+ echo "Last response code: $http_code"
+ fi
+ sleep 5
+ done
+
+ # Multi-node setup
+ HEAD_IP=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
+ NUM_NODES=$SKYPILOT_NUM_NODES
+ NUM_GPUS_PER_NODE=$SKYPILOT_NUM_GPUS_PER_NODE
+
+ # Network configuration for distributed training
+ NETWORK_INTERFACE=$(ip route get 8.8.8.8 | grep -oP 'dev \K\S+')
+ export GLOO_SOCKET_IFNAME=$NETWORK_INTERFACE
+ export NCCL_SOCKET_IFNAME=$NETWORK_INTERFACE
+
+ # PyTorch multiprocessing configuration
+ export TORCH_MULTIPROCESSING_SHARING_STRATEGY=file_system
+
+ # Activate environment
+ source .venv/bin/activate
+
+ # Set up paths
+ cd verl
+ PROJECT_DIR="$(pwd)"
+ export PYTHONPATH="$PROJECT_DIR:$PYTHONPATH"
+
+ # WandB login (optional)
+ if [ -n "$WANDB_API_KEY" ]; then
+ echo "Logging into Weights & Biases..."
+ python3 -c "import wandb; wandb.login(relogin=True, key='$WANDB_API_KEY')"
+ fi
+
+ if [ "$SKYPILOT_NODE_RANK" == "0" ]; then
+ echo "Starting Ray head node on port 6379..."
+ ps aux | grep ray | grep 6379 &> /dev/null || ray start --head --disable-usage-stats --port=6379 --dashboard-host=0.0.0.0 --dashboard-port=8265
+
+ # Wait for all nodes to connect
+ echo "Waiting for $NUM_NODES nodes to connect..."
+ retry_count=0
+ max_retries=30
+ while [ $retry_count -lt $max_retries ]; do
+ # grep -c already prints 0 on no match; '|| true' avoids both a 'set -e'
+ # abort and the duplicate "0" that '|| echo 0' would append
+ connected_nodes=$(ray status 2>/dev/null | grep -c "node_" || true)
+ if [ "$connected_nodes" -ge "$NUM_NODES" ]; then
+ echo "✓ All $NUM_NODES nodes connected"
+ break
+ fi
+ retry_count=$((retry_count+1))
+ sleep 10
+ done
+
+ # Display Ray cluster status
+ echo "Ray cluster status:"
+ ray status
+
+ echo "Starting search tool interaction training..."
+ cd $PROJECT_DIR
+
+ # Increase file descriptor limit
+ ulimit -n 65535
+
+ # Set up configuration paths
+ CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config"
+ TRAIN_DATA="$HOME/data/searchR1_processed_direct/train.parquet"
+ VAL_DATA="$HOME/data/searchR1_processed_direct/test.parquet"
+ TOOL_CONFIG="$CONFIG_PATH/tool_config/search_tool_config.yaml"
+
+ # Configure logging based on WANDB_API_KEY availability
+ if [ -n "$WANDB_API_KEY" ]; then
+ LOGGER_CONFIG='["console","wandb"]'
+ WANDB_ARGS="trainer.project_name=$WANDB_PROJECT_NAME trainer.experiment_name=$WANDB_EXPERIMENT_NAME"
+ echo "✓ WandB logging enabled"
+ else
+ LOGGER_CONFIG='["console"]'
+ WANDB_ARGS=""
+ echo "ℹ WandB logging disabled (no API key provided)"
+ fi
+
+ # Training with search tool
+ python3 -m verl.trainer.main_ppo \
+ --config-path="$CONFIG_PATH" \
+ --config-name='search_multiturn_grpo' \
+ algorithm.adv_estimator=grpo \
+ data.train_batch_size=$TRAIN_BATCH_SIZE \
+ data.val_batch_size=$VAL_BATCH_SIZE \
+ data.max_prompt_length=4096 \
+ data.max_response_length=3000 \
+ data.filter_overlong_prompts=True \
+ data.truncation='error' \
+ data.return_raw_chat=True \
+ actor_rollout_ref.model.path=$MODEL_NAME \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.285 \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.actor.ppo_mini_batch_size=16 \
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
+ actor_rollout_ref.actor.use_kl_loss=True \
+ actor_rollout_ref.actor.kl_loss_coef=0.001 \
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=True \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
+ actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
+ actor_rollout_ref.rollout.max_model_len=15000 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+ actor_rollout_ref.rollout.name=sglang \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
+ actor_rollout_ref.rollout.n=5 \
+ actor_rollout_ref.rollout.multi_turn.max_assistant_turns=2 \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
+ algorithm.use_kl_in_reward=False \
+ trainer.critic_warmup=0 \
+ trainer.val_before_train=False \
+ trainer.logger="$LOGGER_CONFIG" \
+ $WANDB_ARGS \
+ trainer.n_gpus_per_node=$NUM_GPUS_PER_NODE \
+ trainer.nnodes=$NUM_NODES \
+ trainer.save_freq=$SAVE_FREQ \
+ trainer.test_freq=$TEST_FREQ \
+ data.train_files="$TRAIN_DATA" \
+ data.val_files="$VAL_DATA" \
+ actor_rollout_ref.rollout.multi_turn.tool_config_path="$TOOL_CONFIG" \
+ trainer.total_epochs=$TOTAL_EPOCHS \
+ trainer.total_training_steps=$TOTAL_STEPS \
+ trainer.default_local_dir=/checkpoints
+
+ echo "✓ Training complete!"
+
+ # Model checkpoint merging
+ echo "Merging model checkpoints..."
+ LATEST_STEP=$(cat /checkpoints/latest_checkpointed_iteration.txt)
+ CHECKPOINT_DIR="/checkpoints/global_step_${LATEST_STEP}/actor"
+
+ python -m verl.model_merger merge \
+ --backend fsdp \
+ --tie-word-embedding \
+ --local_dir ${CHECKPOINT_DIR} \
+ --target_dir /checkpoints/hf_model
+
+ echo "✓ Model saved to /checkpoints/hf_model"
+ echo "Training artifacts saved to cloud bucket: ${CHECKPOINT_BUCKET_NAME}"
+
+ else
+ # Worker node setup
+ echo "Worker node (rank $SKYPILOT_NODE_RANK) connecting to head at $HEAD_IP:6379..."
+ sleep 15
+ ps aux | grep ray | grep $HEAD_IP:6379 &> /dev/null || ray start --address $HEAD_IP:6379 --disable-usage-stats
+ echo "✓ Worker node connected"
+ sleep infinity
+ fi
diff --git a/llm/verl/search-tooling/verl-search-interaction.yaml b/llm/verl/search-tooling/verl-search-interaction.yaml
new file mode 100644
index 00000000000..44b10c5b1d0
--- /dev/null
+++ b/llm/verl/search-tooling/verl-search-interaction.yaml
@@ -0,0 +1,351 @@
+# Search Tool Interaction Training with VERL
+#
+# This example demonstrates multi-turn tool interaction training using VERL with a search/retrieval tool.
+# The model learns to use a search tool for answering questions that require external knowledge.
+#
+# Based on: https://verl.readthedocs.io/en/v0.5.x/sglang_multiturn/search_tool_example.html
+#
+# Usage:
+# # Without WandB logging
+# sky launch -c verl-search llm/verl/search-tooling/verl-search-interaction.yaml --env DATASET_SIZE=small --env TOTAL_EPOCHS=1 -y
+#
+# # Or with WandB logging (optional)
+# sky launch -c verl-search llm/verl/search-tooling/verl-search-interaction.yaml --secret WANDB_API_KEY --env DATASET_SIZE=small --env TOTAL_EPOCHS=1 -y
+#
+# Requirements:
+# - Docker with SYS_PTRACE capability (for PyTorch multiprocessing CUDA tensor sharing)
+# - Single H100 or equivalent GPU (can be adjusted for other accelerators)
+
+resources:
+ accelerators: H100:1
+ memory: 128+
+ image_id: docker:verlai/verl:app-verl0.6-transformers4.56.1-sglang0.5.2-mcore0.13.0-te2.2
+ ports:
+ - 8265 # Ray dashboard
+ - 8000 # Retrieval service
+
+num_nodes: 1
+
+config:
+ docker:
+ run_options:
+ - --cap-add=SYS_PTRACE # Required for PyTorch CUDA tensor sharing between Ray workers
+ - --ipc=host
+ - --shm-size=16g
+
+envs:
+ DATASET_SIZE: small # Options: small (1000 train, 200 test), medium (10k train, 2k test), full
+ TOTAL_EPOCHS: 1
+ TOTAL_STEPS: 10
+ TRAIN_BATCH_SIZE: 512
+ VAL_BATCH_SIZE: 256
+ SAVE_FREQ: 5 # Save checkpoints every 5 steps
+ TEST_FREQ: 5 # Test every 5 steps
+ MODEL_NAME: Qwen/Qwen2.5-3B-Instruct
+ WANDB_PROJECT_NAME: search_r1_like_async_rl
+ WANDB_EXPERIMENT_NAME: qwen2.5-3b-it_rm-searchR1-like-sgl-multiturn
+ CHECKPOINT_BUCKET_NAME: verl-search-interaction-checkpoints
+
+file_mounts:
+ /checkpoints:
+ name: ${CHECKPOINT_BUCKET_NAME}
+ mode: MOUNT
+
+secrets:
+ WANDB_API_KEY: "" # Optional: Set to enable WandB logging. If not set, only console logging will be used.
+
+setup: |
+ rm -f ~/.pip/pip.conf
+ rm -f ~/.config/pip/pip.conf
+
+ set -e
+
+ echo "=== VERL Search Tool Interaction Setup ==="
+
+ # System dependencies
+ echo "Installing system dependencies..."
+ sudo apt update && sudo apt install -y iproute2
+
+ # Python environment
+ echo "Setting up Python virtual environment..."
+ uv venv --python 3.10 --seed
+ source .venv/bin/activate
+
+ # Clone VERL repository
+ echo "Cloning VERL repository..."
+ rm -rf verl
+ git clone https://github.com/volcengine/verl.git
+ cd verl
+ git checkout v0.6.0
+
+ # Core dependencies
+ echo "Installing PyTorch and VERL..."
+ uv pip install "torch==2.8.*" torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
+ uv pip install "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
+ uv pip install -v -e .
+ uv pip install wheel
+ uv pip install packaging
+ uv pip install -r ./requirements_sglang.txt
+
+ # Search/retrieval specific dependencies
+ echo "Installing retrieval service dependencies..."
+ uv pip install faiss-gpu-cu12
+  # Pin uvloop to work around https://github.com/volcengine/verl/issues/3806
+  uv pip install uvloop==0.21.0
+
+ # Download Wikipedia corpus and FAISS index
+ echo "Downloading Wikipedia corpus and FAISS index..."
+ export save_path=~/dataset
+ mkdir -p $save_path
+
+ huggingface-cli download maknee/wiki-18-subsets wiki-18-100k.jsonl.gz --repo-type=dataset --local-dir $save_path
+ huggingface-cli download maknee/wiki-18-subsets e5_Flat-100k.index --repo-type=dataset --local-dir $save_path
+
+ # Move files to expected locations
+ mv $save_path/wiki-18-100k.jsonl.gz $save_path/wiki-18.jsonl.gz
+ mv $save_path/e5_Flat-100k.index $save_path/e5_Flat.index
+
+ # Decompress the JSONL file
+ gzip -d $save_path/wiki-18.jsonl.gz -f
+
+ # Data preparation
+ echo "Preparing search R1 dataset..."
+ python3 examples/data_preprocess/preprocess_search_r1_dataset.py
+
+  # Clone the Search-R1 reference repository
+  git clone https://github.com/PeterGriffinJin/Search-R1/
+
+  echo "✓ Setup complete!"
+
+run: |
+ set -e
+
+ echo "=== VERL Search Tool Interaction Training ==="
+
+ # Multi-node setup
+ HEAD_IP=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
+ NUM_NODES=$SKYPILOT_NUM_NODES
+ NUM_GPUS_PER_NODE=$SKYPILOT_NUM_GPUS_PER_NODE
+
+ # Network configuration for distributed training
+ NETWORK_INTERFACE=$(ip route get 8.8.8.8 | grep -oP 'dev \K\S+')
+ export GLOO_SOCKET_IFNAME=$NETWORK_INTERFACE
+ export NCCL_SOCKET_IFNAME=$NETWORK_INTERFACE
+
+ # PyTorch multiprocessing configuration
+ export TORCH_MULTIPROCESSING_SHARING_STRATEGY=file_system
+
+ # Activate environment
+ source .venv/bin/activate
+
+ # Set up paths
+ cd verl
+ PROJECT_DIR="$(pwd)"
+ export PYTHONPATH="$PROJECT_DIR:$PYTHONPATH"
+
+ # Start retrieval service
+ echo "Starting retrieval server..."
+ # conda activate retriever
+ save_path=~/dataset
+ index_file=$save_path/e5_Flat.index
+ corpus_file=$save_path/wiki-18.jsonl
+ retriever_name=e5
+ retriever_path=intfloat/e5-base-v2
+
+ python examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py \
+ --index_path $index_file \
+ --corpus_path $corpus_file \
+ --topk 3 \
+ --retriever_name $retriever_name \
+ --retriever_model $retriever_path &
+
+ RETRIEVAL_PID=$!
+ sleep 10
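+
+  # (Optional) Sanity-check the retrieval endpoint once the server is up.
+  # The /retrieve route and payload below follow the Search-R1 dense
+  # retriever server; adjust them if the upstream API differs.
+  # curl -s -X POST http://localhost:8000/retrieve \
+  #   -H 'Content-Type: application/json' \
+  #   -d '{"queries": ["Who wrote Hamlet?"], "topk": 3, "return_scores": true}'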
+
+ # WandB login (optional)
+ if [ -n "$WANDB_API_KEY" ]; then
+ echo "Logging into Weights & Biases..."
+ python3 -c "import wandb; wandb.login(relogin=True, key='$WANDB_API_KEY')"
+ fi
+
+ if [ "$SKYPILOT_NODE_RANK" == "0" ]; then
+ echo "Starting Ray head node on port 6379..."
+ ps aux | grep ray | grep 6379 &> /dev/null || ray start --head --disable-usage-stats --port=6379 --dashboard-host=0.0.0.0 --dashboard-port=8265
+
+ # Wait for all nodes to connect
+ echo "Waiting for $NUM_NODES nodes to connect..."
+ retry_count=0
+ max_retries=30
+ while [ $retry_count -lt $max_retries ]; do
+      # Note: grep -c already prints 0 when nothing matches; '|| true' just
+      # guards against the pipeline's nonzero exit status (a trailing
+      # '|| echo "0"' would print a second 0 and break the comparison).
+      connected_nodes=$(ray status 2>/dev/null | grep -c "node_" || true)
+ if [ "$connected_nodes" -ge "$NUM_NODES" ]; then
+ echo "✓ All $NUM_NODES nodes connected"
+ break
+ fi
+ retry_count=$((retry_count+1))
+ sleep 10
+ done
+
+ # Display Ray cluster status
+ echo "Ray cluster status:"
+ ray status
+
+ echo "Starting search tool interaction training..."
+ cd $PROJECT_DIR
+
+ # Increase file descriptor limit
+ ulimit -n 65535
+
+ # Set up configuration paths
+ CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config"
+ TRAIN_DATA="$HOME/data/searchR1_processed_direct/train.parquet"
+ VAL_DATA="$HOME/data/searchR1_processed_direct/test.parquet"
+ TOOL_CONFIG="$CONFIG_PATH/tool_config/search_tool_config.yaml"
+
+ # Configure logging based on WANDB_API_KEY availability
+ if [ -n "$WANDB_API_KEY" ]; then
+ LOGGER_CONFIG='["console","wandb"]'
+ WANDB_ARGS="trainer.project_name=$WANDB_PROJECT_NAME trainer.experiment_name=$WANDB_EXPERIMENT_NAME"
+ echo "✓ WandB logging enabled"
+ else
+ LOGGER_CONFIG='["console"]'
+ WANDB_ARGS=""
+ echo "ℹ WandB logging disabled (no API key provided)"
+ fi
+
+ # Training with search tool
+ python3 -m verl.trainer.main_ppo \
+ --config-path="$CONFIG_PATH" \
+ --config-name='search_multiturn_grpo' \
+ algorithm.adv_estimator=grpo \
+ data.train_batch_size=$TRAIN_BATCH_SIZE \
+ data.val_batch_size=$VAL_BATCH_SIZE \
+ data.max_prompt_length=4096 \
+ data.max_response_length=3000 \
+ data.filter_overlong_prompts=True \
+ data.truncation='error' \
+ data.return_raw_chat=True \
+ actor_rollout_ref.model.path=$MODEL_NAME \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.285 \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.actor.ppo_mini_batch_size=16 \
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
+ actor_rollout_ref.actor.use_kl_loss=True \
+ actor_rollout_ref.actor.kl_loss_coef=0.001 \
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=True \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
+ actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
+ actor_rollout_ref.rollout.max_model_len=15000 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+ actor_rollout_ref.rollout.name=sglang \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
+ actor_rollout_ref.rollout.n=5 \
+ actor_rollout_ref.rollout.multi_turn.max_assistant_turns=2 \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
+ algorithm.use_kl_in_reward=False \
+ trainer.critic_warmup=0 \
+ trainer.val_before_train=False \
+ trainer.logger="$LOGGER_CONFIG" \
+ $WANDB_ARGS \
+ trainer.n_gpus_per_node=$NUM_GPUS_PER_NODE \
+ trainer.nnodes=$NUM_NODES \
+ trainer.save_freq=$SAVE_FREQ \
+ trainer.test_freq=$TEST_FREQ \
+ data.train_files="$TRAIN_DATA" \
+ data.val_files="$VAL_DATA" \
+ actor_rollout_ref.rollout.multi_turn.tool_config_path="$TOOL_CONFIG" \
+ trainer.total_epochs=$TOTAL_EPOCHS \
+ trainer.total_training_steps=$TOTAL_STEPS \
+ trainer.default_local_dir=/checkpoints
+
+ echo "✓ Training complete!"
+
+ # Model checkpoint merging
+ echo "Merging model checkpoints..."
+ LATEST_STEP=$(cat /checkpoints/latest_checkpointed_iteration.txt)
+ CHECKPOINT_DIR="/checkpoints/global_step_${LATEST_STEP}/actor"
+
+ python -m verl.model_merger merge \
+ --backend fsdp \
+ --tie-word-embedding \
+ --local_dir ${CHECKPOINT_DIR} \
+ --target_dir /checkpoints/hf_model
+
+ echo "✓ Model saved to /checkpoints/hf_model"
+ echo "Training artifacts saved to cloud bucket: ${CHECKPOINT_BUCKET_NAME}"
+
+    # Clean up the retrieval service now that training is finished
+ if [ -n "$RETRIEVAL_PID" ]; then
+ echo "Stopping retrieval service..."
+ kill $RETRIEVAL_PID 2>/dev/null || true
+ sleep 5
+ fi
+
+ else
+ # Worker node setup
+ echo "Worker node (rank $SKYPILOT_NODE_RANK) connecting to head at $HEAD_IP:6379..."
+ sleep 15
+ ps aux | grep ray | grep $HEAD_IP:6379 &> /dev/null || ray start --address $HEAD_IP:6379 --disable-usage-stats
+ echo "✓ Worker node connected"
+ sleep infinity
+ fi
diff --git a/pyproject.toml b/pyproject.toml
index 9a509234dc4..704fa38cac1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,7 @@ addopts = "-s -n 16 -q --tb=short --dist loadgroup --disable-warnings"
asyncio_default_fixture_loop_scope = "function"
[tool.mypy]
-python_version = "3.8"
+python_version = "3.9"
follow_imports = "skip"
ignore_missing_imports = true
allow_redefinition = true
diff --git a/requirements-dev.txt b/requirements-dev.txt
index e0dd4be1763..90c8239537e 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -16,7 +16,7 @@ isort==5.12.0
# type checking
# match the version with .pre-commit-config.yaml
-mypy==1.14.1
+mypy==1.19.1
types-PyYAML
types-paramiko
# 2.31 requires urllib3>2, which is incompatible with IBM and
diff --git a/sky/__init__.py b/sky/__init__.py
index 2aec6b62c86..5faff53c6e5 100644
--- a/sky/__init__.py
+++ b/sky/__init__.py
@@ -155,6 +155,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]):
Hyperbolic = clouds.Hyperbolic
Shadeform = clouds.Shadeform
Seeweb = clouds.Seeweb
+Yotta = clouds.Yotta
__all__ = [
'__version__',
@@ -180,6 +181,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]):
'Hyperbolic',
'Shadeform',
'Seeweb',
+ 'Yotta',
'Optimizer',
'OptimizeTarget',
'backends',
diff --git a/sky/adaptors/kubernetes.py b/sky/adaptors/kubernetes.py
index 2ecae9e26fb..a4a13f834ea 100644
--- a/sky/adaptors/kubernetes.py
+++ b/sky/adaptors/kubernetes.py
@@ -1,8 +1,15 @@
-"""Kubernetes adaptors"""
+"""Kubernetes adaptors
+
+Thread safety notes:
+
+The API functions (core_api, batch_api, etc.) return cached clients that are
+created with context-specific ApiClient instances.
+"""
import functools
import logging
import os
import platform
+import typing
from typing import Any, Callable, Optional, Set
from sky import sky_logging
@@ -11,16 +18,17 @@
from sky.utils import common_utils
from sky.utils import ux_utils
-_IMPORT_ERROR_MESSAGE = ('Failed to import dependencies for Kubernetes. '
- 'Try running: pip install "skypilot[kubernetes]"')
-kubernetes = common.LazyImport('kubernetes',
- import_error_message=_IMPORT_ERROR_MESSAGE)
-models = common.LazyImport('kubernetes.client.models',
- import_error_message=_IMPORT_ERROR_MESSAGE)
-urllib3 = common.LazyImport('urllib3',
- import_error_message=_IMPORT_ERROR_MESSAGE)
-dateutil_parser = common.LazyImport('dateutil.parser',
- import_error_message=_IMPORT_ERROR_MESSAGE)
+if typing.TYPE_CHECKING:
+ import kubernetes
+ import urllib3
+ import urllib3.exceptions
+else:
+ _IMPORT_ERROR_MESSAGE = ('Failed to import dependencies for Kubernetes. '
+ 'Try running: pip install "skypilot[kubernetes]"')
+ kubernetes = common.LazyImport('kubernetes',
+ import_error_message=_IMPORT_ERROR_MESSAGE)
+ urllib3 = common.LazyImport('urllib3',
+ import_error_message=_IMPORT_ERROR_MESSAGE)
# Timeout to use for API calls
API_TIMEOUT = 5
@@ -86,13 +94,33 @@ def _get_config_file() -> str:
return os.environ.get('KUBECONFIG', '~/.kube/config')
-def _load_config(context: Optional[str] = None):
+def _get_api_client(context: Optional[str] = None) -> Any:
+ """Get an ApiClient for the given context without modifying global config.
+
+ This is fully thread-safe because it creates isolated Configuration
+ objects for each client rather than modifying the global
+ kubernetes.client.configuration.
+
+ Args:
+ context: The Kubernetes context to use. If None, tries in-cluster config
+ first, then falls back to kubeconfig current-context.
+
+ Returns:
+ A kubernetes.client.ApiClient configured for the specified context.
+
+ Raises:
+ ValueError: If the configuration cannot be loaded.
+ """
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
- def _load_config_from_kubeconfig(context: Optional[str] = None):
+ def _get_api_client_from_kubeconfig(context: Optional[str] = None) -> Any:
+ """Load kubeconfig, return ApiClient without modifying global state."""
try:
- kubernetes.config.load_kube_config(config_file=_get_config_file(),
- context=context)
+ # new_client_from_config returns an ApiClient configured for the
+ # specified context WITHOUT modifying the global configuration.
+ # This is the key to thread-safety.
+ return kubernetes.config.new_client_from_config(
+ config_file=_get_config_file(), context=context)
except kubernetes.config.config_exception.ConfigException as e:
suffix = common_utils.format_exception(e, use_bracket=True)
context_name = '(current-context)' if context is None else context
@@ -143,20 +171,27 @@ def _load_config_from_kubeconfig(context: Optional[str] = None):
if context == in_cluster_context_name() or context is None:
try:
# Load in-cluster config if running in a pod and context is None.
- # Kubernetes set environment variables for service discovery do not
- # show up in SkyPilot tasks. For now, we work around by using
- # DNS name instead of environment variables.
- # See issue: https://github.com/skypilot-org/skypilot/issues/2287
- # Only set if not already present (preserving existing values)
+ # Use InClusterConfigLoader with an explicit Configuration object
+ # to avoid modifying global state (thread-safe).
+ #
+ # Workaround: Kubernetes service discovery environment variables
+ # may not show up in SkyPilot tasks. We set them to DNS names as
+ # a fallback. See: github.com/skypilot-org/skypilot/issues/2287
if 'KUBERNETES_SERVICE_HOST' not in os.environ:
os.environ['KUBERNETES_SERVICE_HOST'] = 'kubernetes.default.svc'
if 'KUBERNETES_SERVICE_PORT' not in os.environ:
os.environ['KUBERNETES_SERVICE_PORT'] = '443'
- kubernetes.config.load_incluster_config()
+
+ config = kubernetes.client.Configuration()
+ kubernetes.config.load_incluster_config(config)
+ return kubernetes.client.ApiClient(configuration=config)
except kubernetes.config.config_exception.ConfigException:
- _load_config_from_kubeconfig()
- else:
- _load_config_from_kubeconfig(context)
+ if context == in_cluster_context_name():
+ # Explicitly requested in-cluster context but not in a cluster
+ raise
+ # Otherwise, if context is None, fall through to kubeconfig
+
+ return _get_api_client_from_kubeconfig(context)
def list_kube_config_contexts():
@@ -219,88 +254,83 @@ def wrapper(*args, **kwargs):
@annotations.lru_cache(scope='request')
@wrap_kubernetes_client
def core_api(context: Optional[str] = None):
- _load_config(context)
- return kubernetes.client.CoreV1Api()
+ return kubernetes.client.CoreV1Api(api_client=_get_api_client(context))
@_api_logging_decorator('urllib3', logging.ERROR)
@annotations.lru_cache(scope='request')
@wrap_kubernetes_client
def storage_api(context: Optional[str] = None):
- _load_config(context)
- return kubernetes.client.StorageV1Api()
+ return kubernetes.client.StorageV1Api(api_client=_get_api_client(context))
@_api_logging_decorator('urllib3', logging.ERROR)
@annotations.lru_cache(scope='request')
@wrap_kubernetes_client
def auth_api(context: Optional[str] = None):
- _load_config(context)
- return kubernetes.client.RbacAuthorizationV1Api()
+ return kubernetes.client.RbacAuthorizationV1Api(
+ api_client=_get_api_client(context))
@_api_logging_decorator('urllib3', logging.ERROR)
@annotations.lru_cache(scope='request')
@wrap_kubernetes_client
def networking_api(context: Optional[str] = None):
- _load_config(context)
- return kubernetes.client.NetworkingV1Api()
+ return kubernetes.client.NetworkingV1Api(
+ api_client=_get_api_client(context))
@_api_logging_decorator('urllib3', logging.ERROR)
@annotations.lru_cache(scope='request')
@wrap_kubernetes_client
def custom_objects_api(context: Optional[str] = None):
- _load_config(context)
- return kubernetes.client.CustomObjectsApi()
+ return kubernetes.client.CustomObjectsApi(
+ api_client=_get_api_client(context))
@_api_logging_decorator('urllib3', logging.ERROR)
@annotations.lru_cache(scope='global')
@wrap_kubernetes_client
def node_api(context: Optional[str] = None):
- _load_config(context)
- return kubernetes.client.NodeV1Api()
+ return kubernetes.client.NodeV1Api(api_client=_get_api_client(context))
@_api_logging_decorator('urllib3', logging.ERROR)
@annotations.lru_cache(scope='request')
@wrap_kubernetes_client
def apps_api(context: Optional[str] = None):
- _load_config(context)
- return kubernetes.client.AppsV1Api()
+ return kubernetes.client.AppsV1Api(api_client=_get_api_client(context))
@_api_logging_decorator('urllib3', logging.ERROR)
@annotations.lru_cache(scope='request')
@wrap_kubernetes_client
def batch_api(context: Optional[str] = None):
- _load_config(context)
- return kubernetes.client.BatchV1Api()
+ return kubernetes.client.BatchV1Api(api_client=_get_api_client(context))
@_api_logging_decorator('urllib3', logging.ERROR)
@annotations.lru_cache(scope='request')
@wrap_kubernetes_client
def api_client(context: Optional[str] = None):
- _load_config(context)
- return kubernetes.client.ApiClient()
+ return _get_api_client(context)
@_api_logging_decorator('urllib3', logging.ERROR)
@annotations.lru_cache(scope='request')
@wrap_kubernetes_client
def custom_resources_api(context: Optional[str] = None):
- _load_config(context)
- return kubernetes.client.CustomObjectsApi()
+ return kubernetes.client.CustomObjectsApi(
+ api_client=_get_api_client(context))
@_api_logging_decorator('urllib3', logging.ERROR)
@annotations.lru_cache(scope='request')
@wrap_kubernetes_client
def watch(context: Optional[str] = None):
- _load_config(context)
- return kubernetes.watch.Watch()
+ w = kubernetes.watch.Watch()
+ w._api_client = _get_api_client(context) # pylint: disable=protected-access
+ return w
def api_exception():
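The refactor above removes the shared-global-config pattern. To make the new contract concrete, here is a minimal standalone sketch (the context names are hypothetical) of how `new_client_from_config` keeps per-context clients isolated:

    from kubernetes import client, config

    # new_client_from_config returns an ApiClient bound to the given
    # context without mutating kubernetes.client.configuration, so two
    # threads can safely talk to two clusters concurrently.
    client_a = config.new_client_from_config(context='ctx-a')
    client_b = config.new_client_from_config(context='ctx-b')

    core_a = client.CoreV1Api(api_client=client_a)
    core_b = client.CoreV1Api(api_client=client_b)

    # Each call is routed by its own client, regardless of interleaving.
    print([n.metadata.name for n in core_a.list_node().items])
    print([n.metadata.name for n in core_b.list_node().items])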
diff --git a/sky/adaptors/oci.py b/sky/adaptors/oci.py
index cbf4f9354b3..8fa6539b066 100644
--- a/sky/adaptors/oci.py
+++ b/sky/adaptors/oci.py
@@ -73,18 +73,24 @@ def service_exception():
def with_oci_env(f):
+ """Wraps a function to return a single shell command string (joined by '&&')
+ that ensures OCI CLI is available before running the actual OCI
+ command returned by `f`.
+ """
@functools.wraps(f)
def wrapper(*args, **kwargs):
- # pylint: disable=line-too-long
+ oci_venv_dir = '"$HOME/sky-oci-cli-env"'
enter_env_cmds = [
- 'conda info --envs | grep "sky-oci-cli-env" || conda create -n sky-oci-cli-env python=3.10 -y',
- '. $(conda info --base 2> /dev/null)/etc/profile.d/conda.sh > /dev/null 2>&1 || true',
- 'conda activate sky-oci-cli-env', 'pip install oci-cli',
- 'export OCI_CLI_SUPPRESS_FILE_PERMISSIONS_WARNING=True'
+ # Create the venv if missing
+ (f'[ -d {oci_venv_dir} ] || '
+ f'uv venv --seed {oci_venv_dir} --python 3.10'),
+ f'source {oci_venv_dir}/bin/activate',
+ 'uv pip install oci-cli',
+ 'export OCI_CLI_SUPPRESS_FILE_PERMISSIONS_WARNING=True',
]
operation_cmd = [f(*args, **kwargs)]
- leave_env_cmds = ['conda deactivate']
+ leave_env_cmds = ['deactivate']
return ' && '.join(enter_env_cmds + operation_cmd + leave_env_cmds)
return wrapper
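For reviewers, a quick sketch of what the rewritten decorator produces (the `list_instances` function is hypothetical; only `with_oci_env` comes from this module):

    from sky.adaptors.oci import with_oci_env

    @with_oci_env
    def list_instances(compartment_id: str) -> str:
        # The wrapped function only returns the OCI CLI command itself.
        return f'oci compute instance list --compartment-id {compartment_id}'

    # The wrapper prepends venv creation/activation and 'uv pip install
    # oci-cli', appends 'deactivate', and joins everything with ' && '.
    print(list_instances('ocid1.compartment.oc1..example'))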
diff --git a/sky/adaptors/slurm.py b/sky/adaptors/slurm.py
index 69b05c3bcfe..e937b101b63 100644
--- a/sky/adaptors/slurm.py
+++ b/sky/adaptors/slurm.py
@@ -129,6 +129,8 @@ def __init__(
ssh_proxy_command=ssh_proxy_command,
ssh_proxy_jump=ssh_proxy_jump,
enable_interactive_auth=True,
+ # Allow ssh-agent and default key fallback for Slurm.
+ disable_identities_only=True,
)
def _run_slurm_cmd(self, cmd: str) -> Tuple[int, str, str]:
@@ -625,3 +627,22 @@ def get_partitions(self) -> List[str]:
at the end of the name.
"""
return [partition.name for partition in self.get_partitions_info()]
+
+ def get_proctrack_type(self) -> Optional[str]:
+ """Get the ProctrackType from Slurm configuration.
+
+ Returns:
+ The proctrack type (e.g., 'cgroup', 'linuxproc', 'pgid'),
+ or None if it cannot be determined.
+ """
+ cmd = 'scontrol show config | grep -i "^ProctrackType"'
+ rc, stdout, stderr = self._run_slurm_cmd(cmd)
+ if rc != 0:
+ logger.warning(f'Failed to get ProctrackType: {stderr}')
+ return None
+
+ # Parse output like "ProctrackType = proctrack/cgroup"
+ match = re.search(r'ProctrackType\s*=\s*proctrack/(\w+)', stdout)
+ if match:
+ return match.group(1)
+ return None
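A minimal self-contained check of the parsing logic, using a sample line in `scontrol show config` output format:

    import re

    sample = 'ProctrackType           = proctrack/cgroup'
    match = re.search(r'ProctrackType\s*=\s*proctrack/(\w+)', sample)
    assert match is not None and match.group(1) == 'cgroup'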
diff --git a/sky/adaptors/yotta.py b/sky/adaptors/yotta.py
new file mode 100644
index 00000000000..59c3808e58c
--- /dev/null
+++ b/sky/adaptors/yotta.py
@@ -0,0 +1 @@
+"""Yotta cloud adaptor."""
diff --git a/sky/authentication.py b/sky/authentication.py
index a2e14947a12..fac32878eef 100644
--- a/sky/authentication.py
+++ b/sky/authentication.py
@@ -28,6 +28,7 @@
import uuid
import colorama
+import filelock
from sky import clouds
from sky import exceptions
@@ -228,9 +229,14 @@ def setup_lambda_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
with open(public_key_path, 'r', encoding='utf-8') as f:
public_key = f.read().strip()
prefix = f'sky-key-{common_utils.get_user_hash()}'
- name, exists = lambda_client.get_unique_ssh_key_name(prefix, public_key)
- if not exists:
- lambda_client.register_ssh_key(name, public_key)
+
+ lock_path = os.path.expanduser(
+ '~/.sky/locks/lambda-cloud-ssh-key-registration.lock')
+ os.makedirs(os.path.dirname(lock_path), exist_ok=True)
+ with filelock.FileLock(lock_path):
+ name, exists = lambda_client.get_unique_ssh_key_name(prefix, public_key)
+ if not exists:
+ lambda_client.register_ssh_key(name, public_key)
config['auth']['remote_key_name'] = name
return config
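The lock added above closes a check-then-register race between concurrent launches on the same machine. The general pattern, as a minimal sketch with stand-in lookup/register steps (the paths are hypothetical):

    import os

    import filelock

    lock_path = os.path.expanduser('~/.sky/locks/example-registration.lock')
    os.makedirs(os.path.dirname(lock_path), exist_ok=True)

    # Both the existence check and the registration must happen under the
    # same lock; otherwise two processes can each observe "missing" and
    # register duplicate keys.
    with filelock.FileLock(lock_path):
        registered = os.path.exists('/tmp/registered-key')  # stand-in lookup
        if not registered:
            open('/tmp/registered-key', 'w').close()  # stand-in register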
diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py
index 97abbd4d04c..b6fad8a9f79 100644
--- a/sky/backends/backend_utils.py
+++ b/sky/backends/backend_utils.py
@@ -139,7 +139,7 @@
# Time that must elapse since the last status check before we should re-check if
# the cluster has been terminated or autostopped.
-_CLUSTER_STATUS_CACHE_DURATION_SECONDS = 2
+CLUSTER_STATUS_CACHE_DURATION_SECONDS = 2
CLUSTER_FILE_MOUNTS_LOCK_TIMEOUT_SECONDS = 10
WORKSPACE_LOCK_TIMEOUT_SECONDS = 10
@@ -653,6 +653,7 @@ def write_cluster_config(
dryrun: bool = False,
keep_launch_fields_in_existing_config: bool = True,
volume_mounts: Optional[List['volume_utils.VolumeMount']] = None,
+ cloud_specific_failover_overrides: Optional[Dict[str, Any]] = None,
) -> Dict[str, str]:
"""Fills in cluster configuration templates and writes them out.
@@ -726,7 +727,8 @@ def write_cluster_config(
cloud=str(cloud).lower(),
region=region.name,
keys=('remote_identity',),
- default_value=None)
+ default_value=None,
+ override_configs=to_provision.cluster_config_overrides)
remote_identity = schemas.get_default_remote_identity(str(cloud).lower())
if isinstance(remote_identity_config, str):
remote_identity = remote_identity_config
@@ -899,6 +901,9 @@ def write_cluster_config(
if to_provision.labels:
labels.update(to_provision.labels)
+ install_conda = skypilot_config.get_nested(('provision', 'install_conda'),
+ True)
+
# We disable conda auto-activation if the user has specified a docker image
# to use, which is likely to already have a conda environment activated.
conda_auto_activate = ('true' if to_provision.extract_docker_image() is None
@@ -949,114 +954,122 @@ def write_cluster_config(
# Use a tmp file path to avoid incomplete YAML file being re-used in the
# future.
tmp_yaml_path = yaml_path + '.tmp'
- common_utils.fill_template(
- cluster_config_template,
- dict(
- resources_vars,
- **{
- 'cluster_name_on_cloud': cluster_name_on_cloud,
- 'num_nodes': num_nodes,
- 'disk_size': to_provision.disk_size,
- # If the current code is run by controller, propagate the real
- # calling user which should've been passed in as the
- # SKYPILOT_USER env var (see
- # controller_utils.shared_controller_vars_to_fill().
- 'user': common_utils.get_cleaned_username(
- os.environ.get(constants.USER_ENV_VAR, '')),
-
- # Networking configs
- 'use_internal_ips': skypilot_config.get_effective_region_config(
- cloud=str(cloud).lower(),
- region=region.name,
- keys=('use_internal_ips',),
- default_value=False),
- 'ssh_proxy_command': ssh_proxy_command,
- 'vpc_name': skypilot_config.get_effective_region_config(
- cloud=str(cloud).lower(),
- region=region.name,
- keys=('vpc_name',),
- default_value=None),
- # User-supplied labels.
- 'labels': labels,
- # User-supplied remote_identity
- 'remote_identity': remote_identity,
- # The reservation pools that specified by the user. This is
- # currently only used by AWS and GCP.
- 'specific_reservations': specific_reservations,
-
- # Conda setup
- # We should not use `.format`, as it contains '{}' as the bash
- # syntax.
- 'conda_installation_commands':
- constants.CONDA_INSTALLATION_COMMANDS.replace(
- '{conda_auto_activate}',
- conda_auto_activate).replace('{is_custom_docker}',
- is_custom_docker),
- # Currently only used by Slurm. For other clouds, it is
- # already part of ray_skypilot_installation_commands
- 'setup_sky_dirs_commands': constants.SETUP_SKY_DIRS_COMMANDS,
- 'ray_skypilot_installation_commands':
- (constants.RAY_SKYPILOT_INSTALLATION_COMMANDS.replace(
- '{sky_wheel_hash}',
- wheel_hash).replace('{cloud}',
- str(cloud).lower())),
- 'skypilot_wheel_installation_commands':
- constants.SKYPILOT_WHEEL_INSTALLATION_COMMANDS.replace(
- '{sky_wheel_hash}',
- wheel_hash).replace('{cloud}',
- str(cloud).lower()),
- 'copy_skypilot_templates_commands':
- constants.COPY_SKYPILOT_TEMPLATES_COMMANDS,
- # Port of Ray (GCS server).
- # Ray's default port 6379 is conflicted with Redis.
- 'ray_port': constants.SKY_REMOTE_RAY_PORT,
- 'ray_dashboard_port': constants.SKY_REMOTE_RAY_DASHBOARD_PORT,
- 'ray_temp_dir': constants.SKY_REMOTE_RAY_TEMPDIR,
- 'dump_port_command': instance_setup.DUMP_RAY_PORTS,
- # Sky-internal constants.
- 'sky_ray_cmd': constants.SKY_RAY_CMD,
- # pip install needs to have python env activated to make sure
- # installed packages are within the env path.
- 'sky_pip_cmd': f'{constants.SKY_PIP_CMD}',
- # Activate the SkyPilot runtime environment when starting ray
- # cluster, so that ray autoscaler can access cloud SDK and CLIs
- # on remote
- 'sky_activate_python_env':
- constants.ACTIVATE_SKY_REMOTE_PYTHON_ENV,
- 'ray_version': constants.SKY_REMOTE_RAY_VERSION,
- # Command for waiting ray cluster to be ready on head.
- 'ray_head_wait_initialized_command':
- instance_setup.RAY_HEAD_WAIT_INITIALIZED_COMMAND,
-
- # Cloud credentials for cloud storage.
- 'credentials': credentials,
- # Sky remote utils.
- 'sky_remote_path': SKY_REMOTE_PATH,
- 'sky_local_path': str(local_wheel_path),
- # Add yaml file path to the template variables.
- 'sky_ray_yaml_remote_path':
- cluster_utils.SKY_CLUSTER_YAML_REMOTE_PATH,
- 'sky_ray_yaml_local_path': tmp_yaml_path,
- 'sky_version': str(version.parse(sky.__version__)),
- 'sky_wheel_hash': wheel_hash,
- 'ssh_max_sessions_config':
- constants.SET_SSH_MAX_SESSIONS_CONFIG_CMD,
- # Authentication (optional).
- **auth_config,
-
- # Controller specific configs
- 'is_remote_controller': is_remote_controller,
- 'high_availability': high_availability_specified,
-
- # Volume mounts
- 'volume_mounts': volume_mount_vars,
- 'ephemeral_volume_mounts': ephemeral_volume_mount_vars,
-
- # runcmd to run before any of the SkyPilot runtime setup commands.
- # This is currently only used by AWS and Kubernetes.
- 'runcmd': runcmd,
- }),
- output_path=tmp_yaml_path)
+ variables = dict(
+ resources_vars,
+ **{
+ 'cluster_name_on_cloud': cluster_name_on_cloud,
+ 'num_nodes': num_nodes,
+ 'disk_size': to_provision.disk_size,
+ # If the current code is run by controller, propagate the real
+ # calling user which should've been passed in as the
+ # SKYPILOT_USER env var (see
+ # controller_utils.shared_controller_vars_to_fill().
+ 'user': common_utils.get_cleaned_username(
+ os.environ.get(constants.USER_ENV_VAR, '')),
+
+ # Networking configs
+ 'use_internal_ips': skypilot_config.get_effective_region_config(
+ cloud=str(cloud).lower(),
+ region=region.name,
+ keys=('use_internal_ips',),
+ default_value=False),
+ 'ssh_proxy_command': ssh_proxy_command,
+ # TODO (kyuds): for backwards compatibility. If `vpc_names`
+ # is set, this will be overridden. We can remove this after
+ # v0.13.0 if all clouds that currently support `vpc_name`
+ # migrates to `vpc_names` (ie: gcp)
+ 'vpc_name': skypilot_config.get_effective_region_config(
+ cloud=str(cloud).lower(),
+ region=region.name,
+ keys=('vpc_name',),
+ default_value=None),
+ # User-supplied labels.
+ 'labels': labels,
+ # User-supplied remote_identity
+ 'remote_identity': remote_identity,
+ # The reservation pools that specified by the user. This is
+ # currently only used by AWS and GCP.
+ 'specific_reservations': specific_reservations,
+
+ # Conda setup
+ # We should not use `.format`, as it contains '{}' as the bash
+ # syntax.
+ 'conda_installation_commands':
+ constants.CONDA_INSTALLATION_COMMANDS.replace(
+ '{conda_auto_activate}', conda_auto_activate).replace(
+ '{is_custom_docker}', is_custom_docker)
+ if install_conda else '',
+ # UV setup
+ 'uv_installation_commands': constants.UV_INSTALLATION_COMMANDS,
+ # Currently only used by Slurm. For other clouds, it is
+ # already part of ray_skypilot_installation_commands
+ 'setup_sky_dirs_commands': constants.SETUP_SKY_DIRS_COMMANDS,
+ 'ray_skypilot_installation_commands':
+ (constants.RAY_SKYPILOT_INSTALLATION_COMMANDS.replace(
+ '{sky_wheel_hash}',
+ wheel_hash).replace('{cloud}',
+ str(cloud).lower())),
+ 'skypilot_wheel_installation_commands':
+ constants.SKYPILOT_WHEEL_INSTALLATION_COMMANDS.replace(
+ '{sky_wheel_hash}',
+ wheel_hash).replace('{cloud}',
+ str(cloud).lower()),
+ 'copy_skypilot_templates_commands':
+ constants.COPY_SKYPILOT_TEMPLATES_COMMANDS,
+ # Port of Ray (GCS server).
+ # Ray's default port 6379 is conflicted with Redis.
+ 'ray_port': constants.SKY_REMOTE_RAY_PORT,
+ 'ray_dashboard_port': constants.SKY_REMOTE_RAY_DASHBOARD_PORT,
+ 'ray_temp_dir': constants.SKY_REMOTE_RAY_TEMPDIR,
+ 'dump_port_command': instance_setup.DUMP_RAY_PORTS,
+ # Sky-internal constants.
+ 'sky_ray_cmd': constants.SKY_RAY_CMD,
+ # pip install needs to have python env activated to make sure
+ # installed packages are within the env path.
+ 'sky_pip_cmd': f'{constants.SKY_PIP_CMD}',
+ # Activate the SkyPilot runtime environment when starting ray
+ # cluster, so that ray autoscaler can access cloud SDK and CLIs
+ # on remote
+ 'sky_activate_python_env': constants.ACTIVATE_SKY_REMOTE_PYTHON_ENV,
+ 'ray_version': constants.SKY_REMOTE_RAY_VERSION,
+ # Command for waiting ray cluster to be ready on head.
+ 'ray_head_wait_initialized_command':
+ instance_setup.RAY_HEAD_WAIT_INITIALIZED_COMMAND,
+
+ # Cloud credentials for cloud storage.
+ 'credentials': credentials,
+ # Sky remote utils.
+ 'sky_remote_path': SKY_REMOTE_PATH,
+ 'sky_local_path': str(local_wheel_path),
+ # Add yaml file path to the template variables.
+ 'sky_ray_yaml_remote_path':
+ cluster_utils.SKY_CLUSTER_YAML_REMOTE_PATH,
+ 'sky_ray_yaml_local_path': tmp_yaml_path,
+ 'sky_version': str(version.parse(sky.__version__)),
+ 'sky_wheel_hash': wheel_hash,
+ 'ssh_max_sessions_config':
+ constants.SET_SSH_MAX_SESSIONS_CONFIG_CMD,
+ # Authentication (optional).
+ **auth_config,
+
+ # Controller specific configs
+ 'is_remote_controller': is_remote_controller,
+ 'high_availability': high_availability_specified,
+
+ # Volume mounts
+ 'volume_mounts': volume_mount_vars,
+ 'ephemeral_volume_mounts': ephemeral_volume_mount_vars,
+
+ # runcmd to run before any of the SkyPilot runtime setup commands.
+ # This is currently only used by AWS and Kubernetes.
+ 'runcmd': runcmd,
+ },
+ )
+ if cloud_specific_failover_overrides is not None:
+ variables.update(cloud_specific_failover_overrides)
+ common_utils.fill_template(cluster_config_template,
+ variables,
+ output_path=tmp_yaml_path)
config_dict['cluster_name'] = cluster_name
config_dict['ray'] = yaml_path
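The point of unrolling the `fill_template` call into an explicit `variables` dict is that per-attempt failover overrides can now be layered on top before rendering. A minimal sketch of the precedence using jinja2 directly (the variable names are hypothetical):

    import jinja2

    template = jinja2.Template('sg: {{ security_group }}, vpc: {{ vpc_name }}')
    variables = {'security_group': 'default-sg', 'vpc_name': 'main'}

    # Overrides win because they are applied last, via dict.update.
    variables.update({'security_group': 'fallback-sg'})

    assert template.render(**variables) == 'sg: fallback-sg, vpc: main'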
@@ -1173,6 +1186,7 @@ def _add_auth_to_cluster_config(cloud: clouds.Cloud, tmp_yaml_path: str):
clouds.Azure,
clouds.DO,
clouds.Nebius,
+ clouds.Yotta,
)):
config = auth.configure_ssh_info(config)
elif isinstance(cloud, clouds.GCP):
@@ -1509,23 +1523,6 @@ def wait_until_ray_cluster_ready(
return True, docker_user # success
-def _get_ssh_control_name(config: Dict[str, Any]) -> str:
- ssh_provider_module = config['provider']['module']
- ssh_control_name = config.get('cluster_name',
- command_runner.DEFAULT_SSH_CONTROL_NAME)
- if 'slurm' in ssh_provider_module:
- # For Slurm, multiple SkyPilot clusters may share the same underlying
- # Slurm login node. By using a fixed ssh_control_name ('__default__'),
- # we ensure that all connections to the same login node reuse the same
- # SSH ControlMaster process, avoiding repeated SSH handshakes.
- #
- # The %C token in ControlPath (see ssh_options_list) ensures that
- # connections to different login nodes use different sockets, avoiding
- # collisions between different Slurm clusters.
- ssh_control_name = command_runner.DEFAULT_SSH_CONTROL_NAME
- return ssh_control_name
-
-
def ssh_credential_from_yaml(
cluster_yaml: Optional[str],
docker_user: Optional[str] = None,
@@ -1546,7 +1543,7 @@ def ssh_credential_from_yaml(
if ssh_user is None:
ssh_user = auth_section['ssh_user'].strip()
ssh_private_key_path = auth_section.get('ssh_private_key')
- ssh_control_name = _get_ssh_control_name(config)
+ ssh_control_name = config.get('cluster_name', '__default__')
ssh_proxy_command = auth_section.get('ssh_proxy_command')
# Update the ssh_user placeholder in proxy command, if required
@@ -1600,7 +1597,7 @@ def ssh_credentials_from_handles(
if ssh_user is None:
ssh_user = auth_section['ssh_user'].strip()
ssh_private_key_path = auth_section.get('ssh_private_key')
- ssh_control_name = _get_ssh_control_name(config)
+ ssh_control_name = config.get('cluster_name', '__default__')
ssh_proxy_command = auth_section.get('ssh_proxy_command')
# Update the ssh_user placeholder in proxy command, if required
@@ -2431,6 +2428,42 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool:
exc_info=e)
return False
+ def _handle_autostopping_cluster(
+ print_newline: bool = False) -> Optional[Dict[str, Any]]:
+ """Handle cluster that is autostopping/autodowning.
+
+ Sets the cluster status to AUTOSTOPPING and returns the cluster record.
+
+ Args:
+ print_newline: Whether to print a newline before logging (for UX).
+
+ Returns:
+ Cluster record if autostopping, None otherwise.
+ """
+ # The cluster is autostopping - set to AUTOSTOPPING status
+ if print_newline:
+ ux_utils.console_newline()
+ operation_str = 'autodowning' if record.get('to_down',
+ False) else 'autostopping'
+ logger.info(f'Cluster {cluster_name!r} is {operation_str}.')
+
+ # Set cluster to AUTOSTOPPING status
+ record['status'] = status_lib.ClusterStatus.AUTOSTOPPING
+ global_user_state.add_cluster_event(
+ cluster_name,
+ status_lib.ClusterStatus.AUTOSTOPPING,
+ f'Cluster is {operation_str}.',
+ global_user_state.ClusterEventType.STATUS_CHANGE,
+ nop_if_duplicate=True)
+ # Use set_cluster_status() to directly update the status in DB
+ # instead of add_or_update_cluster() which only supports INIT/UP
+ global_user_state.set_cluster_status(
+ cluster_name, status_lib.ClusterStatus.AUTOSTOPPING)
+ return global_user_state.get_cluster_from_name(
+ cluster_name,
+ include_user_info=include_user_info,
+ summary_response=summary_response)
+
# Determining if the cluster is healthy (UP):
#
# For non-spot clusters: If ray status shows all nodes are healthy, it is
@@ -2452,6 +2485,13 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool:
# NOTE: all_nodes_up calculation is fast due to calling cloud CLI;
# run_ray_status_to_check_all_nodes_up() is slow due to calling `ray get
# head-ip/worker-ips`.
+
+ # Check if the cluster is in the process of autostopping
+ backend = get_backend_from_handle(handle)
+ if isinstance(backend, backends.CloudVmRayBackend):
+ if backend.is_definitely_autostopping(handle, stream_logs=False):
+ return _handle_autostopping_cluster(print_newline=False)
+
record['status'] = status_lib.ClusterStatus.UP
# Add cluster event for instance status check.
global_user_state.add_cluster_event(
@@ -2586,12 +2626,24 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool:
backend = get_backend_from_handle(handle)
if isinstance(backend, backends.CloudVmRayBackend):
- if is_head_node_alive:
+            # Check autostopping first, before the head-node-alive check.
+            # This ensures we detect AUTOSTOPPING even when Ray becomes
+            # unhealthy during hook execution, or when the nodes are only
+            # partially autostopped.
+ is_autostopping = backend.is_definitely_autostopping(
+ handle, stream_logs=False)
+
+ if is_autostopping:
+ logger.debug(
+ f'The cluster {cluster_name!r} is abnormal '
+ f'({init_reason}) but is definitely autostopping. '
+ 'Returning AUTOSTOPPING status.')
+ return _handle_autostopping_cluster(print_newline=True)
+ elif is_head_node_alive:
logger.debug(
f'Skipping autostop reset for cluster {cluster_name!r} '
'because the head node is alive.')
- elif not backend.is_definitely_autostopping(handle,
- stream_logs=False):
+ elif not is_autostopping:
# Friendly hint.
autostop = record['autostop']
maybe_down_str = ' --down' if record['to_down'] else ''
@@ -2642,13 +2694,6 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool:
f'abnormal state. To fix, try running: {reset}{bright}sky '
f'start -f -i {autostop}{maybe_down_str} {cluster_name}'
f'{reset}')
- else:
- ux_utils.console_newline()
- operation_str = 'autodowning' if record[
- 'to_down'] else 'autostopping'
- logger.info(
- f'Cluster {cluster_name!r} is {operation_str}. Setting to '
- 'INIT status; try refresh again in a while.')
# If the user starts part of a STOPPED cluster, we still need a status
# to represent the abnormal status. For spot cluster, it can also
@@ -2734,10 +2779,13 @@ def _must_refresh_cluster_status(
use_spot = record['handle'].launched_resources.use_spot
has_autostop = (record['status'] != status_lib.ClusterStatus.STOPPED and
record['autostop'] >= 0)
+    # If the cluster is AUTOSTOPPING, always refresh to check whether it
+    # has transitioned to STOPPED.
+ is_autostopping = record['status'] == status_lib.ClusterStatus.AUTOSTOPPING
recently_refreshed = (record['status_updated_at'] is not None and
time.time() - record['status_updated_at'] <
- _CLUSTER_STATUS_CACHE_DURATION_SECONDS)
- is_stale = (use_spot or has_autostop) and not recently_refreshed
+ CLUSTER_STATUS_CACHE_DURATION_SECONDS)
+ is_stale = (use_spot or has_autostop or
+ is_autostopping) and not recently_refreshed
return force_refresh_for_cluster or is_stale
@@ -2764,7 +2812,7 @@ def refresh_cluster_record(
following conditions will be refreshed regardless of whether the argument
is specified:
- the latest available status update is more than
- _CLUSTER_STATUS_CACHE_DURATION_SECONDS old, and one of:
+ CLUSTER_STATUS_CACHE_DURATION_SECONDS old, and one of:
1. the cluster is a spot cluster, or
2. cluster autostop is set and the cluster is not STOPPED.
cluster_lock_already_held: Whether the caller is already holding the
@@ -3021,7 +3069,8 @@ def check_cluster_available(
f'cluster {cluster_name!r}. It is only supported by backend: '
f'{backends.CloudVmRayBackend.NAME}.'
f'{reset}')
- if cluster_status != status_lib.ClusterStatus.UP:
+ if cluster_status not in (status_lib.ClusterStatus.UP,
+ status_lib.ClusterStatus.AUTOSTOPPING):
with ux_utils.print_exception_no_traceback():
hint_for_init = ''
if cluster_status == status_lib.ClusterStatus.INIT:
@@ -3033,7 +3082,8 @@ def check_cluster_available(
f'{colorama.Fore.YELLOW}{operation.capitalize()}: skipped for '
f'cluster {cluster_name!r} (status: {cluster_status.value}). '
'It is only allowed for '
- f'{status_lib.ClusterStatus.UP.value} clusters.'
+ f'{status_lib.ClusterStatus.UP.value} and '
+ f'{status_lib.ClusterStatus.AUTOSTOPPING.value} clusters.'
f'{hint_for_init}'
f'{reset}',
cluster_status=cluster_status,
@@ -3174,7 +3224,9 @@ def is_controller_accessible(
if not runner.check_connection():
error_msg = controller.value.connection_error_hint
else:
- assert controller_status == status_lib.ClusterStatus.UP, handle
+ assert controller_status in (
+ status_lib.ClusterStatus.UP,
+ status_lib.ClusterStatus.AUTOSTOPPING), handle
if error_msg is not None:
if exit_if_not_accessible:
@@ -3802,8 +3854,9 @@ def get_endpoints(cluster: str,
f'Cluster {cluster!r} not found.', cluster_status=None)
assert len(cluster_records) == 1, cluster_records
cluster_record = cluster_records[0]
- if (not skip_status_check and
- cluster_record['status'] != status_lib.ClusterStatus.UP):
+ if (not skip_status_check and cluster_record['status']
+ not in (status_lib.ClusterStatus.UP,
+ status_lib.ClusterStatus.AUTOSTOPPING)):
with ux_utils.print_exception_no_traceback():
raise exceptions.ClusterNotUpError(
f'Cluster {cluster_record["name"]!r} '
diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py
index 3eeb16e70c4..927afb4331e 100644
--- a/sky/backends/cloud_vm_ray_backend.py
+++ b/sky/backends/cloud_vm_ray_backend.py
@@ -44,14 +44,17 @@
from sky.backends import wheel_utils
from sky.clouds import cloud as sky_cloud
from sky.clouds.utils import gcp_utils
+from sky.dag import DEFAULT_EXECUTION
from sky.data import data_utils
from sky.data import storage as storage_lib
from sky.provision import common as provision_common
+from sky.provision import constants as provision_constants
from sky.provision import instance_setup
from sky.provision import metadata_utils
from sky.provision import provisioner
from sky.provision.kubernetes import config as config_lib
from sky.provision.kubernetes import utils as kubernetes_utils
+from sky.provision.slurm import utils as slurm_utils
from sky.serve import constants as serve_constants
from sky.server.requests import requests as requests_lib
from sky.skylet import autostop_lib
@@ -184,6 +187,7 @@
_MAX_GET_ZONE_RETRY = 3
_JOB_ID_PATTERN = re.compile(r'Job ID: ([0-9]+)')
+_JOB_IDS_PATTERN = re.compile(r'Job IDs: ([0-9,]+)')
_LOG_DIR_PATTERN = re.compile(r'Log Dir: ([^ ]+)')
# Path to the monkey-patched ray up script.
@@ -285,7 +289,8 @@ def _get_cluster_config_template(cloud):
clouds.Fluidstack: 'fluidstack-ray.yml.j2',
clouds.Nebius: 'nebius-ray.yml.j2',
clouds.Hyperbolic: 'hyperbolic-ray.yml.j2',
- clouds.Seeweb: 'seeweb-ray.yml.j2'
+ clouds.Seeweb: 'seeweb-ray.yml.j2',
+ clouds.Yotta: 'yotta-ray.yml.j2',
}
return cloud_to_template[type(cloud)]
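The `_retry_zones` hunk further below wraps config generation in a new per-cloud loop over `yield_cloud_specific_failover_overrides(region=...)`. A sketch of the contract that loop implies, assuming the base implementation yields a single empty dict (i.e. one attempt, no overrides) and a cloud may yield further override sets to retry the same region with different template variables (the override key here is hypothetical):

    from typing import Any, Dict, Iterator, Optional

    def yield_cloud_specific_failover_overrides(
            region: Optional[str] = None) -> Iterator[Dict[str, Any]]:
        """Base contract: one attempt with no template overrides."""
        yield {}

    def yield_example_cloud_overrides(
            region: Optional[str] = None) -> Iterator[Dict[str, Any]]:
        """Example cloud: try the default config first, then retry the
        same region with a fallback network setting."""
        yield {}
        yield {'network_tier': 'standard'}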
@@ -932,7 +937,7 @@ def _insufficient_resources_msg(
message += (f'{to_provision.cloud} for {requested_resources}. ')
return message
- def _retry_zones(
+ def _retry_zones( # pylint: disable=line-too-long
self,
to_provision: resources_lib.Resources,
num_nodes: int,
@@ -1044,334 +1049,346 @@ def _retry_zones(
else:
zone_str = ','.join(z.name for z in zones)
zone_str = f' ({zone_str})'
- try:
- config_dict = backend_utils.write_cluster_config(
- to_provision,
- num_nodes,
- _get_cluster_config_template(to_provision.cloud),
- cluster_name,
- self._local_wheel_path,
- self._wheel_hash,
- region=region,
- zones=zones,
- dryrun=dryrun,
- keep_launch_fields_in_existing_config=cluster_exists,
- volume_mounts=volume_mounts,
- )
- except exceptions.ResourcesUnavailableError as e:
- # Failed due to catalog issue, e.g. image not found, or
- # GPUs are requested in a Kubernetes cluster but the cluster
- # does not have nodes labeled with GPU types.
- logger.info(f'{e}')
- continue
- except exceptions.InvalidCloudCredentials as e:
- # Failed due to invalid cloud credentials.
- logger.warning(f'{common_utils.format_exception(e)}')
- # We should block the entire cloud for invalid cloud credentials
- _add_to_blocked_resources(
- self._blocked_resources,
- to_provision.copy(region=None, zone=None))
- raise exceptions.ResourcesUnavailableError(
- f'Failed to provision on cloud {to_provision.cloud} due to '
- f'invalid cloud credentials: '
- f'{common_utils.format_exception(e)}')
- except exceptions.InvalidCloudConfigs as e:
- # Failed due to invalid user configs in ~/.sky/config.yaml.
- logger.warning(f'{common_utils.format_exception(e)}')
- # We should block the entire cloud if the user config is
- # invalid.
- _add_to_blocked_resources(
- self._blocked_resources,
- to_provision.copy(region=None, zone=None))
- raise exceptions.ResourcesUnavailableError(
- f'Failed to provision on cloud {to_provision.cloud} due to '
- f'invalid cloud config: {common_utils.format_exception(e)}')
- if ('config_hash' in config_dict and
- skip_if_config_hash_matches == config_dict['config_hash']):
- logger.debug('Skipping provisioning of cluster with matching '
- 'config hash.')
- config_dict['provisioning_skipped'] = True
- return config_dict
- config_dict['provisioning_skipped'] = False
+ for failover_overrides in to_provision.cloud.yield_cloud_specific_failover_overrides(
+ region=to_provision.region):
+ try:
+ config_dict = backend_utils.write_cluster_config(
+ to_provision,
+ num_nodes,
+ _get_cluster_config_template(to_provision.cloud),
+ cluster_name,
+ self._local_wheel_path,
+ self._wheel_hash,
+ region=region,
+ zones=zones,
+ dryrun=dryrun,
+ keep_launch_fields_in_existing_config=cluster_exists,
+ volume_mounts=volume_mounts,
+ cloud_specific_failover_overrides=failover_overrides,
+ )
+ except exceptions.ResourcesUnavailableError as e:
+ # Failed due to catalog issue, e.g. image not found, or
+ # GPUs are requested in a Kubernetes cluster but the cluster
+ # does not have nodes labeled with GPU types.
+ logger.info(f'{e}')
+ continue
+ except exceptions.InvalidCloudCredentials as e:
+ # Failed due to invalid cloud credentials.
+ logger.warning(f'{common_utils.format_exception(e)}')
+ # We should block the entire cloud for invalid cloud credentials
+ _add_to_blocked_resources(
+ self._blocked_resources,
+ to_provision.copy(region=None, zone=None))
+ raise exceptions.ResourcesUnavailableError(
+ f'Failed to provision on cloud {to_provision.cloud} due to '
+ f'invalid cloud credentials: '
+ f'{common_utils.format_exception(e)}')
+ except exceptions.InvalidCloudConfigs as e:
+ # Failed due to invalid user configs in ~/.sky/config.yaml.
+ logger.warning(f'{common_utils.format_exception(e)}')
+ # We should block the entire cloud if the user config is
+ # invalid.
+ _add_to_blocked_resources(
+ self._blocked_resources,
+ to_provision.copy(region=None, zone=None))
+ raise exceptions.ResourcesUnavailableError(
+ f'Failed to provision on cloud {to_provision.cloud} due to '
+ f'invalid cloud config: {common_utils.format_exception(e)}'
+ )
- if dryrun:
- return config_dict
+ if ('config_hash' in config_dict and skip_if_config_hash_matches
+ == config_dict['config_hash']):
+ logger.debug(
+ 'Skipping provisioning of cluster with matching '
+ 'config hash.')
+ config_dict['provisioning_skipped'] = True
+ return config_dict
+ config_dict['provisioning_skipped'] = False
- cluster_config_file = config_dict['ray']
+ if dryrun:
+ return config_dict
- launched_resources = to_provision.copy(region=region.name)
- if zones and len(zones) == 1:
- launched_resources = launched_resources.copy(zone=zones[0].name)
-
- prev_cluster_ips, prev_ssh_ports, prev_cluster_info = (None, None,
- None)
- if prev_handle is not None:
- prev_cluster_ips = prev_handle.stable_internal_external_ips
- prev_ssh_ports = prev_handle.stable_ssh_ports
- prev_cluster_info = prev_handle.cached_cluster_info
- # Record early, so if anything goes wrong, 'sky status' will show
- # the cluster name and users can appropriately 'sky down'. It also
- # means a second 'sky launch -c ' will attempt to reuse.
- handle = CloudVmRayResourceHandle(
- cluster_name=cluster_name,
- # Backward compatibility will be guaranteed by the underlying
- # backend_utils.write_cluster_config, which gets the cluster
- # name on cloud from the ray yaml file, if the previous cluster
- # exists.
- cluster_name_on_cloud=config_dict['cluster_name_on_cloud'],
- cluster_yaml=cluster_config_file,
- launched_nodes=num_nodes,
- # OK for this to be shown in CLI as status == INIT.
- launched_resources=launched_resources,
- # Use the previous cluster's IPs and ports if available to
- # optimize the case where the cluster is restarted, i.e., no
- # need to query IPs and ports from the cloud provider.
- stable_internal_external_ips=prev_cluster_ips,
- stable_ssh_ports=prev_ssh_ports,
- cluster_info=prev_cluster_info,
- )
- usage_lib.messages.usage.update_final_cluster_status(
- status_lib.ClusterStatus.INIT)
+ cluster_config_file = config_dict['ray']
+
+ launched_resources = to_provision.copy(region=region.name)
+ if zones and len(zones) == 1:
+ launched_resources = launched_resources.copy(
+ zone=zones[0].name)
+
+ prev_cluster_ips, prev_ssh_ports, prev_cluster_info = (None,
+ None,
+ None)
+ if prev_handle is not None:
+ prev_cluster_ips = prev_handle.stable_internal_external_ips
+ prev_ssh_ports = prev_handle.stable_ssh_ports
+ prev_cluster_info = prev_handle.cached_cluster_info
+ # Record early, so if anything goes wrong, 'sky status' will show
+ # the cluster name and users can appropriately 'sky down'. It also
+ # means a second 'sky launch -c ' will attempt to reuse.
+ handle = CloudVmRayResourceHandle(
+ cluster_name=cluster_name,
+ # Backward compatibility will be guaranteed by the underlying
+ # backend_utils.write_cluster_config, which gets the cluster
+ # name on cloud from the ray yaml file, if the previous cluster
+ # exists.
+ cluster_name_on_cloud=config_dict['cluster_name_on_cloud'],
+ cluster_yaml=cluster_config_file,
+ launched_nodes=num_nodes,
+ # OK for this to be shown in CLI as status == INIT.
+ launched_resources=launched_resources,
+ # Use the previous cluster's IPs and ports if available to
+ # optimize the case where the cluster is restarted, i.e., no
+ # need to query IPs and ports from the cloud provider.
+ stable_internal_external_ips=prev_cluster_ips,
+ stable_ssh_ports=prev_ssh_ports,
+ cluster_info=prev_cluster_info,
+ )
+ usage_lib.messages.usage.update_final_cluster_status(
+ status_lib.ClusterStatus.INIT)
- # This sets the status to INIT (even for a normal, UP cluster).
- global_user_state.add_or_update_cluster(
- cluster_name,
- cluster_handle=handle,
- requested_resources=requested_resources,
- ready=False,
- is_managed=self._is_managed,
- provision_log_path=log_abs_path,
- )
+ # This sets the status to INIT (even for a normal, UP cluster).
+ global_user_state.add_or_update_cluster(
+ cluster_name,
+ cluster_handle=handle,
+ requested_resources=requested_resources,
+ ready=False,
+ is_managed=self._is_managed,
+ provision_log_path=log_abs_path,
+ )
- # Add cluster event for actual provisioning start.
- global_user_state.add_cluster_event(
- cluster_name, status_lib.ClusterStatus.INIT,
- f'Provisioning on {to_provision.cloud.display_name()} ' +
- f'in {to_provision.region}',
- global_user_state.ClusterEventType.STATUS_CHANGE)
+ # Add cluster event for actual provisioning start.
+ global_user_state.add_cluster_event(
+ cluster_name, status_lib.ClusterStatus.INIT,
+ f'Provisioning on {to_provision.cloud.display_name()} ' +
+ f'in {to_provision.region}',
+ global_user_state.ClusterEventType.STATUS_CHANGE)
- global_user_state.set_owner_identity_for_cluster(
- cluster_name, cloud_user_identity)
-
- if (to_provision.cloud.PROVISIONER_VERSION ==
- clouds.ProvisionerVersion.SKYPILOT):
- # TODO (suquark): Gradually move the other clouds to
- # the new provisioner once they are ready.
- assert to_provision.region == region.name, (to_provision,
- region)
- num_nodes = handle.launched_nodes
- # Some clouds, like RunPod, only support exposing ports during
- # launch. For those clouds, we pass the ports to open in the
- # `bulk_provision` to expose the ports during provisioning.
- # If the `bulk_provision` is to apply on an existing cluster,
- # it should be ignored by the underlying provisioner impl
- # as it will only apply to newly-created instances.
- ports_to_open_on_launch = (
- list(resources_utils.port_ranges_to_set(to_provision.ports))
- if to_provision.cloud.OPEN_PORTS_VERSION <=
- clouds.OpenPortsVersion.LAUNCH_ONLY else None)
- try:
- controller = controller_utils.Controllers.from_name(
- cluster_name)
- controller_str = ('' if controller is None else
- f' {controller.value.name}')
- if isinstance(to_provision.cloud, clouds.Kubernetes):
- suffix = '.'
- if region.name.startswith('ssh-'):
- ssh_node_pool_name = common_utils.removeprefix(
- region.name, 'ssh-')
- suffix = f' ({ssh_node_pool_name})'
- logger.info(
- ux_utils.starting_message(
- f'Launching{controller_str} on '
- f'{to_provision.cloud}{suffix}'))
- else:
- logger.info(
- ux_utils.starting_message(
- f'Launching{controller_str} on '
- f'{to_provision.cloud} '
- f'{region.name}{colorama.Style.RESET_ALL}'
- f'{zone_str}.'))
- assert handle.cluster_yaml is not None
- provision_record = provisioner.bulk_provision(
- to_provision.cloud,
- region,
- zones,
- resources_utils.ClusterName(
- cluster_name, handle.cluster_name_on_cloud),
- num_nodes=num_nodes,
- cluster_yaml=handle.cluster_yaml,
- prev_cluster_ever_up=prev_cluster_ever_up,
- log_dir=self.log_dir,
- ports_to_open_on_launch=ports_to_open_on_launch)
- # NOTE: We will handle the logic of '_ensure_cluster_ray_started' #pylint: disable=line-too-long
- # in 'provision_utils.post_provision_runtime_setup()' in the
- # caller.
- resources_vars = (
- to_provision.cloud.make_deploy_resources_variables(
- to_provision,
+ global_user_state.set_owner_identity_for_cluster(
+ cluster_name, cloud_user_identity)
+
+ if (to_provision.cloud.PROVISIONER_VERSION ==
+ clouds.ProvisionerVersion.SKYPILOT):
+ # TODO (suquark): Gradually move the other clouds to
+ # the new provisioner once they are ready.
+ assert to_provision.region == region.name, (to_provision,
+ region)
+ num_nodes = handle.launched_nodes
+ # Some clouds, like RunPod, only support exposing ports during
+ # launch. For those clouds, we pass the ports to open in the
+ # `bulk_provision` to expose the ports during provisioning.
+ # If the `bulk_provision` is to apply on an existing cluster,
+ # it should be ignored by the underlying provisioner impl
+ # as it will only apply to newly-created instances.
+ ports_to_open_on_launch = (
+ list(
+ resources_utils.port_ranges_to_set(
+ to_provision.ports))
+ if to_provision.cloud.OPEN_PORTS_VERSION <=
+ clouds.OpenPortsVersion.LAUNCH_ONLY else None)
+ try:
+ controller = controller_utils.Controllers.from_name(
+ cluster_name)
+ controller_str = ('' if controller is None else
+ f' {controller.value.name}')
+ if isinstance(to_provision.cloud, clouds.Kubernetes):
+ suffix = '.'
+ if region.name.startswith('ssh-'):
+ ssh_node_pool_name = common_utils.removeprefix(
+ region.name, 'ssh-')
+ suffix = f' ({ssh_node_pool_name})'
+ logger.info(
+ ux_utils.starting_message(
+ f'Launching{controller_str} on '
+ f'{to_provision.cloud}{suffix}'))
+ else:
+ logger.info(
+ ux_utils.starting_message(
+ f'Launching{controller_str} on '
+ f'{to_provision.cloud} '
+ f'{region.name}{colorama.Style.RESET_ALL}'
+ f'{zone_str}.'))
+ assert handle.cluster_yaml is not None
+ provision_record = provisioner.bulk_provision(
+ to_provision.cloud,
+ region,
+ zones,
resources_utils.ClusterName(
cluster_name, handle.cluster_name_on_cloud),
- region, zones, num_nodes))
- config_dict['provision_record'] = provision_record
- config_dict['resources_vars'] = resources_vars
- config_dict['handle'] = handle
- return config_dict
- except provision_common.StopFailoverError:
- with ux_utils.print_exception_no_traceback():
- raise
- except exceptions.InconsistentHighAvailabilityError:
- # No teardown happens for this error.
- with ux_utils.print_exception_no_traceback():
- raise
- except config_lib.KubernetesError as e:
- if e.insufficent_resources:
- insufficient_resources = e.insufficent_resources
- # NOTE: We try to cleanup the cluster even if the previous
- # cluster does not exist. Also we are fast at
- # cleaning up clusters now if there is no existing node.
- CloudVmRayBackend().post_teardown_cleanup(
- handle,
- terminate=not prev_cluster_ever_up,
- remove_from_db=False,
- failover=True,
- )
- # TODO(suquark): other clouds may have different zone
- # blocking strategy. See '_update_blocklist_on_error'
- # for details.
- FailoverCloudErrorHandlerV2.update_blocklist_on_error(
- self._blocked_resources, to_provision, region, zones, e)
- continue
- except Exception as e: # pylint: disable=broad-except
- # NOTE: We try to cleanup the cluster even if the previous
- # cluster does not exist. Also we are fast at
- # cleaning up clusters now if there is no existing node..
- CloudVmRayBackend().post_teardown_cleanup(
- handle,
- terminate=not prev_cluster_ever_up,
- remove_from_db=False,
- failover=True)
- # TODO(suquark): other clouds may have different zone
- # blocking strategy. See '_update_blocklist_on_error'
- # for details.
- FailoverCloudErrorHandlerV2.update_blocklist_on_error(
- self._blocked_resources, to_provision, region, zones, e)
- continue
- # NOTE: The code below in the loop should not be reachable
- # with the new provisioner.
+ num_nodes=num_nodes,
+ cluster_yaml=handle.cluster_yaml,
+ prev_cluster_ever_up=prev_cluster_ever_up,
+ log_dir=self.log_dir,
+ ports_to_open_on_launch=ports_to_open_on_launch)
+ # NOTE: We will handle the logic of '_ensure_cluster_ray_started'
+ # in 'provision_utils.post_provision_runtime_setup()' in the
+ # caller.
+ resources_vars = (
+ to_provision.cloud.make_deploy_resources_variables(
+ to_provision,
+ resources_utils.ClusterName(
+ cluster_name, handle.cluster_name_on_cloud),
+ region, zones, num_nodes))
+ config_dict['provision_record'] = provision_record
+ config_dict['resources_vars'] = resources_vars
+ config_dict['handle'] = handle
+ return config_dict
+ except provision_common.StopFailoverError:
+ with ux_utils.print_exception_no_traceback():
+ raise
+ except exceptions.InconsistentHighAvailabilityError:
+ # No teardown happens for this error.
+ with ux_utils.print_exception_no_traceback():
+ raise
+ except config_lib.KubernetesError as e:
+ if e.insufficent_resources:
+ insufficient_resources = e.insufficent_resources
+                        # NOTE: We try to clean up the cluster even if the
+                        # previous cluster does not exist. Cleaning up is also
+                        # fast now when there is no existing node.
+ CloudVmRayBackend().post_teardown_cleanup(
+ handle,
+ terminate=not prev_cluster_ever_up,
+ remove_from_db=False,
+ failover=True,
+ )
+ # TODO(suquark): other clouds may have different zone
+ # blocking strategy. See '_update_blocklist_on_error'
+ # for details.
+ FailoverCloudErrorHandlerV2.update_blocklist_on_error(
+ self._blocked_resources, to_provision, region,
+ zones, e)
+ continue
+ except Exception as e: # pylint: disable=broad-except
+                        # NOTE: We try to clean up the cluster even if the
+                        # previous cluster does not exist. Cleaning up is also
+                        # fast now when there is no existing node.
+ CloudVmRayBackend().post_teardown_cleanup(
+ handle,
+ terminate=not prev_cluster_ever_up,
+ remove_from_db=False,
+ failover=True)
+ # TODO(suquark): other clouds may have different zone
+ # blocking strategy. See '_update_blocklist_on_error'
+ # for details.
+ FailoverCloudErrorHandlerV2.update_blocklist_on_error(
+ self._blocked_resources, to_provision, region,
+ zones, e)
+ continue
+ # NOTE: The code below in the loop should not be reachable
+ # with the new provisioner.
- logging_info = {
- 'cluster_name': cluster_name,
- 'region_name': region.name,
- 'zone_str': zone_str,
- }
+ logging_info = {
+ 'cluster_name': cluster_name,
+ 'region_name': region.name,
+ 'zone_str': zone_str,
+ }
- status, stdout, stderr, head_internal_ip, head_external_ip = (
- self._gang_schedule_ray_up(to_provision.cloud,
- cluster_config_file, handle,
- log_abs_path, stream_logs,
- logging_info, to_provision.use_spot))
+ status, stdout, stderr, head_internal_ip, head_external_ip = (
+ self._gang_schedule_ray_up(to_provision.cloud,
+ cluster_config_file, handle,
+ log_abs_path, stream_logs,
+ logging_info,
+ to_provision.use_spot))
+
+ if status == GangSchedulingStatus.CLUSTER_READY:
+                # We must query the IPs from the cloud provider once the
+                # provisioning is done, to make sure the cluster IPs are
+                # up-to-date.
+                # Stale IPs may be caused by the node being restarted
+                # manually or by the cloud provider.
+ # Optimize the case where the cluster's head IPs can be parsed
+ # from the output of 'ray up'.
+ if handle.launched_nodes == 1:
+ handle.update_cluster_ips(
+ max_attempts=_FETCH_IP_MAX_ATTEMPTS,
+ internal_ips=[head_internal_ip],
+ external_ips=[head_external_ip])
+ else:
+ handle.update_cluster_ips(
+ max_attempts=_FETCH_IP_MAX_ATTEMPTS)
+ handle.update_ssh_ports(max_attempts=_FETCH_IP_MAX_ATTEMPTS)
+ if cluster_exists:
+ # Guard against the case where there's an existing cluster
+ # with ray runtime messed up (e.g., manually killed) by (1)
+ # querying ray status (2) restarting ray if needed.
+ #
+ # The above 'ray up' will not restart it automatically due
+                    # to the 'ray up --no-restart' flag.
+ #
+ # NOTE: this is performance sensitive and has been observed
+ # to take 9s. Only do this for existing clusters, not
+ # freshly launched ones (which should have ray runtime
+ # started).
+ self._ensure_cluster_ray_started(handle, log_abs_path)
- if status == GangSchedulingStatus.CLUSTER_READY:
- # We must query the IPs from the cloud provider, when the
- # provisioning is done, to make sure the cluster IPs are
- # up-to-date.
- # The staled IPs may be caused by the node being restarted
- # manually or by the cloud provider.
- # Optimize the case where the cluster's head IPs can be parsed
- # from the output of 'ray up'.
- if handle.launched_nodes == 1:
- handle.update_cluster_ips(
- max_attempts=_FETCH_IP_MAX_ATTEMPTS,
- internal_ips=[head_internal_ip],
- external_ips=[head_external_ip])
+ config_dict['handle'] = handle
+ logger.info(
+ ux_utils.finishing_message(
+ f'Cluster launched: {cluster_name!r}.',
+ log_path,
+ cluster_name=cluster_name))
+ return config_dict
+
+ # The cluster is not ready. We must perform error recording and/or
+ # cleanup.
+
+ # If cluster was ever up, stop it; otherwise terminate.
+ terminate_or_stop = not prev_cluster_ever_up
+ definitely_no_nodes_launched = False
+ if status == GangSchedulingStatus.HEAD_FAILED:
+ # ray up failed for the head node.
+ definitely_no_nodes_launched = (
+ FailoverCloudErrorHandlerV1.update_blocklist_on_error(
+ self._blocked_resources, to_provision, region,
+ zones, stdout, stderr))
else:
- handle.update_cluster_ips(
- max_attempts=_FETCH_IP_MAX_ATTEMPTS)
- handle.update_ssh_ports(max_attempts=_FETCH_IP_MAX_ATTEMPTS)
- if cluster_exists:
- # Guard against the case where there's an existing cluster
- # with ray runtime messed up (e.g., manually killed) by (1)
- # querying ray status (2) restarting ray if needed.
- #
- # The above 'ray up' will not restart it automatically due
- # to 'ray up # --no-restart' flag.
- #
- # NOTE: this is performance sensitive and has been observed
- # to take 9s. Only do this for existing clusters, not
- # freshly launched ones (which should have ray runtime
- # started).
- self._ensure_cluster_ray_started(handle, log_abs_path)
-
- config_dict['handle'] = handle
- logger.info(
- ux_utils.finishing_message(
- f'Cluster launched: {cluster_name!r}.',
- log_path,
- cluster_name=cluster_name))
- return config_dict
-
- # The cluster is not ready. We must perform error recording and/or
- # cleanup.
-
- # If cluster was ever up, stop it; otherwise terminate.
- terminate_or_stop = not prev_cluster_ever_up
- definitely_no_nodes_launched = False
- if status == GangSchedulingStatus.HEAD_FAILED:
- # ray up failed for the head node.
- definitely_no_nodes_launched = (
- FailoverCloudErrorHandlerV1.update_blocklist_on_error(
- self._blocked_resources, to_provision, region, zones,
- stdout, stderr))
- else:
- # gang scheduling failed.
- assert status == GangSchedulingStatus.GANG_FAILED, status
- # The stdout/stderr of ray up is not useful here, since
- # head node is successfully provisioned.
- definitely_no_nodes_launched = (
- FailoverCloudErrorHandlerV1.update_blocklist_on_error(
- self._blocked_resources,
- to_provision,
- region,
- zones=zones,
- stdout=None,
- stderr=None))
- # GANG_FAILED means head is up, workers failed.
- assert definitely_no_nodes_launched is False, (
- definitely_no_nodes_launched)
-
- # Only log the errors for GANG_FAILED, since HEAD_FAILED may
- # not have created any resources (it can happen however) and
- # HEAD_FAILED can happen in "normal" failover cases.
- logger.error('*** Failed provisioning the cluster. ***')
- terminate_str = ('Terminating'
- if terminate_or_stop else 'Stopping')
- logger.error(f'*** {terminate_str} the failed cluster. ***')
-
- # If these conditions hold, it *should* be safe to skip the cleanup
- # action. This is a UX optimization.
- #
- # We want to skip mainly for VPC/subnets errors thrown during node
- # provider bootstrapping: if users encountered "No VPC with name
-            # 'xxx' is found in <region>.", then going ahead to down the
- # non-existent cluster will itself print out a (caught, harmless)
- # error with the same message. This was found to be
- # confusing. Thus we skip termination.
- skip_cleanup = not cluster_exists and definitely_no_nodes_launched
- if skip_cleanup:
- continue
+ # gang scheduling failed.
+ assert status == GangSchedulingStatus.GANG_FAILED, status
+ # The stdout/stderr of ray up is not useful here, since
+ # head node is successfully provisioned.
+ definitely_no_nodes_launched = (
+ FailoverCloudErrorHandlerV1.update_blocklist_on_error(
+ self._blocked_resources,
+ to_provision,
+ region,
+ zones=zones,
+ stdout=None,
+ stderr=None))
+ # GANG_FAILED means head is up, workers failed.
+ assert definitely_no_nodes_launched is False, (
+ definitely_no_nodes_launched)
+
+            # Only log the errors for GANG_FAILED, since HEAD_FAILED may not
+            # have created any resources (though it sometimes does) and can
+            # happen in "normal" failover cases.
+ logger.error('*** Failed provisioning the cluster. ***')
+ terminate_str = ('Terminating'
+ if terminate_or_stop else 'Stopping')
+ logger.error(f'*** {terminate_str} the failed cluster. ***')
+
+ # If these conditions hold, it *should* be safe to skip the cleanup
+ # action. This is a UX optimization.
+ #
+ # We want to skip mainly for VPC/subnets errors thrown during node
+ # provider bootstrapping: if users encountered "No VPC with name
+            # 'xxx' is found in <region>.", then going ahead to down the
+ # non-existent cluster will itself print out a (caught, harmless)
+ # error with the same message. This was found to be
+ # confusing. Thus we skip termination.
+ skip_cleanup = not cluster_exists and definitely_no_nodes_launched
+ if skip_cleanup:
+ continue
- # There may exist partial nodes (e.g., head node) so we must
- # terminate or stop before moving on to other regions.
- #
- # NOTE: even HEAD_FAILED could've left a live head node there,
- # so we must terminate/stop here too. E.g., node is up, and ray
- # autoscaler proceeds to setup commands, which may fail:
- # ERR updater.py:138 -- New status: update-failed
- CloudVmRayBackend().teardown_no_lock(handle,
- terminate=terminate_or_stop,
- remove_from_db=False)
+ # There may exist partial nodes (e.g., head node) so we must
+ # terminate or stop before moving on to other regions.
+ #
+ # NOTE: even HEAD_FAILED could've left a live head node there,
+ # so we must terminate/stop here too. E.g., node is up, and ray
+ # autoscaler proceeds to setup commands, which may fail:
+ # ERR updater.py:138 -- New status: update-failed
+ CloudVmRayBackend().teardown_no_lock(
+ handle, terminate=terminate_or_stop, remove_from_db=False)
message = self._insufficient_resources_msg(to_provision,
requested_resources,
@@ -2673,6 +2690,13 @@ def add_job(
) -> 'jobsv1_pb2.AddJobResponse':
return self._jobs_stub.AddJob(request, timeout=timeout)
+ def set_job_info_without_job_id(
+ self,
+ request: 'jobsv1_pb2.SetJobInfoWithoutJobIdRequest',
+ timeout: Optional[float] = constants.SKYLET_GRPC_TIMEOUT_SECONDS
+ ) -> 'jobsv1_pb2.SetJobInfoWithoutJobIdResponse':
+ return self._jobs_stub.SetJobInfoWithoutJobId(request, timeout=timeout)
+
def queue_job(
self,
request: 'jobsv1_pb2.QueueJobRequest',
@@ -3072,6 +3096,18 @@ def _maybe_clear_external_cluster_failures(
f'{cluster_name!r}: {", ".join(failure_details)}'
f'{colorama.Style.RESET_ALL}')
+ def check_skylet_running(self, handle: CloudVmRayResourceHandle):
+        # For backward compatibility and robustness, skylet is checked and
+        # restarted if necessary.
+ logger.debug('Checking if skylet is running on the head node.')
+ with rich_utils.safe_status(
+ ux_utils.spinner_message('Preparing SkyPilot runtime')):
+ # We need to source bashrc for skylet to make sure the autostop
+ # event can access the path to the cloud CLIs.
+ self.run_on_head(handle,
+ instance_setup.MAYBE_SKYLET_RESTART_CMD,
+ source_bashrc=True)
+
def _locked_provision(
self,
lock_id: str,
@@ -3326,14 +3362,7 @@ def _get_zone(runner):
# For backward compatibility and robustness of skylet, it is checked
# and restarted if necessary.
- logger.debug('Checking if skylet is running on the head node.')
- with rich_utils.safe_status(
- ux_utils.spinner_message('Preparing SkyPilot runtime')):
- # We need to source bashrc for skylet to make sure the autostop
- # event can access the path to the cloud CLIs.
- self.run_on_head(handle,
- instance_setup.MAYBE_SKYLET_RESTART_CMD,
- source_bashrc=True)
+ self.check_skylet_running(handle)
self._update_after_cluster_provisioned(
handle, to_provision_config.prev_handle, task,
@@ -3800,10 +3829,10 @@ def _dump_code_to_file(codegen: str,
# We choose to sync code + exec, because the alternative of
# 'ray submit' may not work as it may use system python
# (python2) to execute the script. Happens for AWS.
- head_runner.rsync(source=fp.name,
- target=script_path,
- up=True,
- stream_logs=False)
+ head_runner.rsync_driver(source=fp.name,
+ target=script_path,
+ up=True,
+ stream_logs=False)
mkdir_code = f'mkdir -p {remote_log_dir} && touch {remote_log_path}'
encoded_script = shlex.quote(codegen)
@@ -3852,20 +3881,33 @@ def _dump_code_to_file(codegen: str,
for task_id, task in enumerate(managed_job_dag.tasks):
resources_str = backend_utils.get_task_resources_str(
task, is_managed_job=True)
- managed_job_tasks.append(
- jobsv1_pb2.ManagedJobTask(
- task_id=task_id,
- name=task.name,
- resources_str=resources_str,
- metadata_json=task.metadata_json))
-
+ managed_job_task = jobsv1_pb2.ManagedJobTask(
+ task_id=task_id,
+ name=task.name,
+ resources_str=resources_str,
+ metadata_json=task.metadata_json)
+ # Only set is_primary_in_job_group for job groups
+ if managed_job_dag.is_job_group():
+                        # If primary_tasks is None, all tasks are
+                        # primary.
+ managed_job_task.is_primary_in_job_group = (
+ managed_job_dag.primary_tasks is None or
+ task.name in managed_job_dag.primary_tasks)
+ managed_job_tasks.append(managed_job_task)
+
+ # Execution mode: 'parallel' for job groups, 'serial' for
+ # pipelines and single jobs
+ execution = (managed_job_dag.execution.value
+ if managed_job_dag.execution else
+ DEFAULT_EXECUTION.value)
managed_job_info = jobsv1_pb2.ManagedJobInfo(
name=managed_job_dag.name,
pool=managed_job_dag.pool,
workspace=workspace,
entrypoint=entrypoint,
tasks=managed_job_tasks,
- user_id=managed_job_user_id)
+ user_id=managed_job_user_id,
+ execution=execution)
if backend_utils.is_command_length_over_limit(codegen):
_dump_code_to_file(codegen)
@@ -3893,29 +3935,6 @@ def _dump_code_to_file(codegen: str,
_dump_code_to_file(codegen)
job_submit_cmd = f'{mkdir_code} && {code}'
- def _maybe_add_managed_job_code(job_submit_cmd: str) -> str:
- if managed_job_dag is not None:
- # Add the managed job to job queue database.
- managed_job_codegen = managed_jobs.ManagedJobCodeGen()
- managed_job_code = managed_job_codegen.set_pending(
- job_id,
- managed_job_dag,
- skypilot_config.get_active_workspace(
- force_user_workspace=True),
- entrypoint=common_utils.get_current_command(),
- user_hash=managed_job_user_id)
- # Set the managed job to PENDING state to make sure that
- # this managed job appears in the `sky jobs queue`, even
- # if it needs to wait to be submitted.
- # We cannot set the managed job to PENDING state in the
- # job template (jobs-controller.yaml.j2), as it may need
- # to wait for the run commands to be scheduled on the job
- # controller in high-load cases.
- job_submit_cmd += ' && ' + managed_job_code
- return job_submit_cmd
-
- job_submit_cmd = _maybe_add_managed_job_code(job_submit_cmd)
-
# For Slurm, run in background so that SSH returns immediately.
# This is needed because we add the wait_for_job code above which
# makes the command block until the job completes.
@@ -3940,7 +3959,6 @@ def _maybe_add_managed_job_code(job_submit_cmd: str) -> str:
f'Output: {output}')
_dump_code_to_file(codegen)
job_submit_cmd = f'{mkdir_code} && {code}'
- job_submit_cmd = _maybe_add_managed_job_code(job_submit_cmd)
# See comment above for why run_in_background=is_slurm.
returncode, stdout, stderr = self.run_on_head(
handle,
@@ -4026,6 +4044,97 @@ def _add_job(self, handle: CloudVmRayResourceHandle,
f'Returncode: {returncode}') from e
return job_id, log_dir
+ def set_job_info_without_job_id(
+ self,
+ handle: CloudVmRayResourceHandle,
+ name: str,
+ workspace: str,
+ entrypoint: str,
+ pool: Optional[str],
+ pool_hash: Optional[str],
+ user_hash: Optional[str],
+ task_ids: List[int],
+ task_names: List[str],
+ resources_str: str,
+ metadata_jsons: List[str],
+ is_primary_in_job_groups: List[bool],
+ num_jobs: int = 1,
+ execution: str = DEFAULT_EXECUTION.value) -> List[int]:
+ """Set job info without creating entries in the jobs table.
+
+        This creates entries in job_info_table and spot_table only, which
+        prevents autostop from being blocked by jobs stuck in INIT status.
+ """
+ use_legacy = not handle.is_grpc_enabled_with_flag
+
+ if not use_legacy:
+ try:
+ request = jobsv1_pb2.SetJobInfoWithoutJobIdRequest(
+ name=name,
+ workspace=workspace,
+ entrypoint=entrypoint,
+ pool=pool,
+ pool_hash=pool_hash,
+ user_hash=user_hash,
+ task_ids=task_ids,
+ task_names=task_names,
+ resources_str=resources_str,
+ metadata_jsons=metadata_jsons,
+ num_jobs=num_jobs,
+ execution=execution,
+ is_primary_in_job_groups=is_primary_in_job_groups)
+ response = backend_utils.invoke_skylet_with_retries(
+ lambda: SkyletClient(handle.get_grpc_channel()
+ ).set_job_info_without_job_id(request))
+ return list(response.job_ids)
+ except exceptions.SkyletMethodNotImplementedError:
+ use_legacy = True
+
+ if use_legacy:
+ code = job_lib.JobLibCodeGen.set_job_info_without_job_id(
+ name=name,
+ workspace=workspace,
+ entrypoint=entrypoint,
+ pool=pool,
+ pool_hash=pool_hash,
+ user_hash=user_hash,
+ task_ids=task_ids,
+ task_names=task_names,
+ resources_str=resources_str,
+ metadata_jsons=metadata_jsons,
+ is_primary_in_job_groups=is_primary_in_job_groups,
+ num_jobs=num_jobs,
+ execution=execution)
+ returncode, result_str, stderr = self.run_on_head(
+ handle,
+ code,
+ stream_logs=False,
+ require_outputs=True,
+ separate_stderr=True)
+ backend_utils.check_stale_runtime_on_remote(returncode, stderr,
+ handle.cluster_name)
+ subprocess_utils.handle_returncode(returncode, code,
+ 'Failed to fetch job id.',
+ stderr)
+ try:
+ # Parse job IDs from output
+ job_ids_match = _JOB_IDS_PATTERN.search(result_str)
+ if job_ids_match:
+ job_ids = [
+ int(x.strip())
+ for x in job_ids_match.group(1).split(',')
+ ]
+ return job_ids
+ else:
+ raise ValueError(
+ f'Failed to parse job ids from: {result_str}')
+ except ValueError as e:
+ logger.error(stderr)
+ raise ValueError(f'Failed to parse job id: {result_str}; '
+ f'Returncode: {returncode}') from e
+ return []
+
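
The method above follows the backend's gRPC-first, legacy-fallback shape: try the skylet RPC and, if the remote skylet predates it, fall back to generated code run over SSH. A minimal sketch of the pattern; `call_with_fallback` and its arguments are illustrative names, not part of this diff:

    from typing import Callable, TypeVar

    from sky import exceptions

    T = TypeVar('T')


    def call_with_fallback(grpc_call: Callable[[], T],
                           legacy_call: Callable[[], T],
                           grpc_enabled: bool) -> T:
        """Prefer the gRPC path; fall back to codegen for older skylets."""
        if grpc_enabled:
            try:
                return grpc_call()
            except exceptions.SkyletMethodNotImplementedError:
                # Remote skylet does not implement this RPC yet; fall through.
                pass
        return legacy_call()
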
def _execute(
self,
handle: CloudVmRayResourceHandle,
@@ -4356,7 +4465,7 @@ def _rsync_down(args) -> None:
(runner, local_log_dir, remote_log_dir) = args
try:
os.makedirs(os.path.expanduser(local_log_dir), exist_ok=True)
- runner.rsync(
+ runner.rsync_driver(
# Require a `/` at the end to make sure the parent dir
# are not created locally. We do not add additional '*' as
# kubernetes's rsync does not work with an ending '*'.
@@ -4457,18 +4566,67 @@ def tail_logs(
final = e.code
return final
+ def tail_autostop_logs(self,
+ handle: CloudVmRayResourceHandle,
+ follow: bool = True,
+ tail: int = 0) -> int:
+ """Tail the autostop hook logs.
+
+ Args:
+ handle: The handle to the cluster.
+ follow: Whether to follow the logs.
+ tail: The number of lines to display from the end of the
+ log file. If 0, print all lines.
+
+ Returns:
+ The exit code of the tail command.
+ """
+ # Construct tail command for the autostop hook log
+ log_path = f'~/{constants.AUTOSTOP_HOOK_LOG_FILE}'
+ tail_cmd_parts = ['tail']
+ if tail > 0:
+ tail_cmd_parts.extend(['-n', str(tail)])
+ if follow:
+ tail_cmd_parts.append('-f')
+ tail_cmd_parts.append(log_path)
+
+ # Add fallback to show helpful message if file doesn't exist
+ tail_cmd = ' '.join(tail_cmd_parts)
+ error_msg = (f'Autostop hook log file not found at {log_path}. '
+ f'The autostop hook may not have been executed yet.')
+ cmd = (f'if [ -f {log_path} ]; then {tail_cmd}; '
+ f'else echo "{error_msg}"; exit 1; fi')
+
+ # With the stdin=subprocess.DEVNULL, the ctrl-c will not directly
+ # kill the process, so we need to handle it manually here.
+ if threading.current_thread() is threading.main_thread():
+ signal.signal(signal.SIGINT, backend_utils.interrupt_handler)
+ signal.signal(signal.SIGTSTP, backend_utils.stop_handler)
+ try:
+ returncode = self.run_on_head(
+ handle,
+ cmd,
+ stream_logs=True,
+ # Allocate a pseudo-terminal to disable output buffering.
+ ssh_mode=command_runner.SshMode.INTERACTIVE,
+ )
+ except SystemExit as e:
+ returncode = e.code
+ return returncode
+
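
For illustration, with `tail=100` and `follow=True` the remote command assembled above has the shape below; the placeholder value of `log_path` is an assumption, as the real path comes from `constants.AUTOSTOP_HOOK_LOG_FILE`:

    log_path = '~/<AUTOSTOP_HOOK_LOG_FILE>'  # real value taken from constants
    tail_cmd = f'tail -n 100 -f {log_path}'
    cmd = (f'if [ -f {log_path} ]; then {tail_cmd}; '
           f'else echo "Autostop hook log file not found at {log_path}. '
           f'The autostop hook may not have been executed yet."; exit 1; fi')
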
def tail_managed_job_logs(self,
handle: CloudVmRayResourceHandle,
job_id: Optional[int] = None,
job_name: Optional[str] = None,
controller: bool = False,
follow: bool = True,
- tail: Optional[int] = None) -> int:
+ tail: Optional[int] = None,
+ task: Optional[Union[str, int]] = None) -> int:
# if job_name is not None, job_id should be None
assert job_name is None or job_id is None, (job_name, job_id)
# TODO(kevin): Migrate stream_logs to gRPC
code = managed_jobs.ManagedJobCodeGen.stream_logs(
- job_name, job_id, follow, controller, tail)
+ job_name, job_id, follow, controller, tail, task)
# With the stdin=subprocess.DEVNULL, the ctrl-c will not directly
# kill the process, so we need to handle it manually here.
@@ -4925,7 +5083,7 @@ def teardown_no_lock(self,
# configurations (such as VPC not found). So it's safe & good UX
# to not print a failure message.
elif ('TPU must be specified.' not in stderr and
- 'SKYPILOT_ERROR_NO_NODES_LAUNCHED: ' not in stderr):
+ provision_constants.ERROR_NO_NODES_LAUNCHED not in stderr):
raise RuntimeError(
_TEARDOWN_FAILURE_MESSAGE.format(
extra_reason='',
@@ -5039,6 +5197,21 @@ def post_teardown_cleanup(self,
else:
raise
+ # Clean up all cluster resources (e.g., Kubernetes services).
+ # This is a no-op for most clouds, but Kubernetes needs it to
+ # clean up orphaned services when pods are deleted externally.
+ try:
+ provision_lib.cleanup_cluster_resources(
+ repr(cloud), cluster_name_on_cloud, config['provider'])
+ except Exception as e: # pylint: disable=broad-except
+ if purge:
+ msg = common_utils.format_exception(e, use_bracket=True)
+ logger.warning(
+ f'Failed to cleanup cluster resources. Skipping '
+ f'since purge is set. Details: {msg}')
+ else:
+ raise
+
if ports_cleaned_up and custom_multi_network_cleaned_up:
try:
self.remove_cluster_config(handle)
@@ -5144,7 +5317,9 @@ def set_autostop(self,
idle_minutes_to_autostop: Optional[int],
wait_for: Optional[autostop_lib.AutostopWaitFor],
down: bool = False,
- stream_logs: bool = True) -> None:
+ stream_logs: bool = True,
+ hook: Optional[str] = None,
+ hook_timeout: Optional[int] = None) -> None:
# The core.autostop() function should have already checked that the
# cloud and resources support requested autostop.
if idle_minutes_to_autostop is not None:
@@ -5197,11 +5372,17 @@ def set_autostop(self,
autostopv1_pb2.AUTOSTOP_WAIT_FOR_UNSPECIFIED,
down=down,
)
+ if hook:
+ request.hook = hook
+ if hook_timeout is not None:
+ request.hook_timeout = hook_timeout
+
backend_utils.invoke_skylet_with_retries(lambda: SkyletClient(
handle.get_grpc_channel()).set_autostop(request))
else:
code = autostop_lib.AutostopCodeGen.set_autostop(
- idle_minutes_to_autostop, self.NAME, wait_for, down)
+ idle_minutes_to_autostop, self.NAME, wait_for, down, hook,
+ hook_timeout)
returncode, _, stderr = self.run_on_head(
handle, code, require_outputs=True, stream_logs=stream_logs)
subprocess_utils.handle_returncode(returncode,
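
A hypothetical call-site sketch for the new `hook`/`hook_timeout` parameters above; the handle and argument values are assumptions for illustration, not taken from this diff:

    # Autostop after 30 idle minutes; first run a user-supplied hook,
    # giving it at most 600 seconds to finish.
    backend.set_autostop(handle,
                         idle_minutes_to_autostop=30,
                         wait_for=None,
                         down=False,
                         hook='bash ~/on_autostop.sh',
                         hook_timeout=600)
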
@@ -5233,13 +5414,16 @@ def is_definitely_autostopping(self,
# The head node of the cluster is not UP or in an abnormal state.
# We cannot check if the cluster is autostopping.
return False
+
+ is_autostopping = False
+
if handle.is_grpc_enabled_with_flag:
try:
request = autostopv1_pb2.IsAutostoppingRequest()
response = backend_utils.invoke_skylet_with_retries(
lambda: SkyletClient(handle.get_grpc_channel()
).is_autostopping(request))
- return response.is_autostopping
+ is_autostopping = response.is_autostopping
except Exception as e: # pylint: disable=broad-except
# The cluster may have been terminated, causing the gRPC call
# to timeout and fail.
@@ -5250,11 +5434,14 @@ def is_definitely_autostopping(self,
returncode, stdout, stderr = self.run_on_head(
handle, code, require_outputs=True, stream_logs=stream_logs)
if returncode == 0:
- return message_utils.decode_payload(stdout)
- logger.debug('Failed to check if cluster is autostopping with '
- f'{returncode}: {stdout+stderr}\n'
- f'Command: {code}')
- return False
+ is_autostopping = message_utils.decode_payload(stdout)
+ else:
+ logger.debug('Failed to check if cluster is autostopping with '
+ f'{returncode}: {stdout+stderr}\n'
+ f'Command: {code}')
+ return False
+
+ return is_autostopping
# TODO(zhwu): Refactor this to a CommandRunner class, so different backends
# can support its own command runner.
@@ -5320,7 +5507,7 @@ def run_on_head(
if under_remote_workdir:
cmd = f'cd {SKY_REMOTE_WORKDIR} && {cmd}'
- return head_runner.run(
+ return head_runner.run_driver(
cmd,
port_forward=port_forward,
log_path=log_path,
@@ -5966,7 +6153,8 @@ def _skypilot_predefined_env_vars(
'cloud': str(handle.launched_resources.cloud),
'region': handle.launched_resources.region,
'zone': handle.launched_resources.zone,
- })
+ }),
+ constants.USER_ENV_VAR: common_utils.get_current_user_name(),
}
def _get_task_env_vars(self, task: task_lib.Task, job_id: int,
@@ -6003,7 +6191,16 @@ def _get_task_codegen_class(
slurm_job_id = head_instance.tags.get('job_id')
assert (slurm_job_id
is not None), ('job_id tag not found in head instance')
- return task_codegen.SlurmCodeGen(slurm_job_id=slurm_job_id)
+ container_image = handle.launched_resources.extract_docker_image()
+ container_name = None
+ if container_image is not None:
+ container_name = slurm_utils.pyxis_container_name(
+ handle.cluster_name_on_cloud)
+
+ return task_codegen.SlurmCodeGen(
+ slurm_job_id,
+ container_name,
+ )
else:
return task_codegen.RayCodeGen()
diff --git a/sky/backends/docker_utils.py b/sky/backends/docker_utils.py
index 1da4fdbf873..4198568aa48 100644
--- a/sky/backends/docker_utils.py
+++ b/sky/backends/docker_utils.py
@@ -201,7 +201,7 @@ def push_dockerimage(local_tag, remote_name):
def make_bash_from_multiline(codegen: str) -> str:
"""Makes a bash script from a multi-line string of commands.
- Automatically includes conda setup prefixes.
+ Automatically includes conda setup prefixes if conda is installed.
Args:
codegen: str: multiline commands to be converted to a shell script
diff --git a/sky/backends/task_codegen.py b/sky/backends/task_codegen.py
index e188dc9bc17..01b2d1b2ab5 100644
--- a/sky/backends/task_codegen.py
+++ b/sky/backends/task_codegen.py
@@ -5,6 +5,7 @@
import json
import math
import os
+import shlex
import textwrap
from typing import Dict, List, Optional, Tuple
@@ -130,7 +131,8 @@ def _add_constants(self) -> None:
CANCELLED_RETURN_CODE = 137
"""))
- def _get_rclone_flush_script(self) -> str:
+ @staticmethod
+ def get_rclone_flush_script() -> str:
"""Generate rclone flush script for cached storage mounts.
This script blocks job completion until all storage mounted with
@@ -612,7 +614,7 @@ def _add_ray_task(self,
options_str = ', '.join(options)
logger.debug('Added Task with options: '
f'{options_str}')
- rclone_flush_script = self._get_rclone_flush_script()
+ rclone_flush_script = self.get_rclone_flush_script()
unset_ray_env_vars = ' && '.join(
[f'unset {var}' for var in UNSET_RAY_ENV_VARS])
self._code += [
@@ -664,14 +666,20 @@ def add_epilogue(self) -> None:
class SlurmCodeGen(TaskCodeGen):
"""Code generator for task execution on Slurm using native srun."""
- def __init__(self, slurm_job_id: str):
- """Initialize SlurmCodeGen
+ def __init__(
+ self,
+ slurm_job_id: str,
+ container_name: Optional[str],
+ ):
+ """Initialize SlurmCodeGen.
Args:
slurm_job_id: The Slurm job ID, i.e. SLURM_JOB_ID
+ container_name: pyxis container name, or None
"""
super().__init__()
self._slurm_job_id = slurm_job_id
+ self._container_name = container_name
def add_prologue(self, job_id: int) -> None:
assert not self._has_prologue, 'add_prologue() called twice?'
@@ -805,10 +813,18 @@ def add_task(
for k, v in env_vars.items())
sky_env_vars_dict_str = '\n'.join(sky_env_vars_dict_str)
- rclone_flush_script = self._get_rclone_flush_script()
+ rclone_flush_script = self.get_rclone_flush_script()
streaming_msg = self._get_job_started_msg()
has_setup_cmd = self._setup_cmd is not None
+ container_flags = ''
+ if self._container_name is not None:
+ # --container-remap-root must be passed on every srun to get
+            # the correct $HOME.
+ container_flags = (
+ ' --container-remap-root'
+ f' --container-name={shlex.quote(self._container_name)}:exec')
+
self._code += [
sky_env_vars_dict_str,
textwrap.dedent(f"""\
@@ -886,19 +902,36 @@ def build_task_runner_cmd(user_script, extra_flags, log_dir, env_vars_dict,
# allocation. See:
# https://support.schedmd.com/show_bug.cgi?id=14298
# https://github.com/huggingface/datatrove/issues/248
+ cmd_parts = []
+ # Only unset SKY_RUNTIME_DIR for container runs. For non-container
+ # runs, we want to inherit the node-local SKY_RUNTIME_DIR set by
+ # SlurmCommandRunner to avoid SQLite WAL issues on shared filesystems.
+ if {True if container_flags else False}:
+ cmd_parts.append('unset SKY_RUNTIME_DIR;')
+ cmd_parts.extend([
+ constants.SKY_SLURM_PYTHON_CMD,
+ '-m sky.skylet.executor.slurm',
+ runner_args,
+ ])
+ bash_cmd = shlex.quote(' '.join(cmd_parts))
srun_cmd = (
"unset $(env | awk -F= '/^SLURM_/ {{print $1}}') && "
f'srun --export=ALL --quiet --unbuffered --kill-on-bad-exit --jobid={self._slurm_job_id} '
- f'--job-name=sky-{self.job_id}{{job_suffix}} --ntasks-per-node=1 {{extra_flags}} '
- f'{{constants.SKY_SLURM_PYTHON_CMD}} -m sky.skylet.executor.slurm {{runner_args}}'
+ f'--job-name=sky-{self.job_id}{{job_suffix}} --ntasks-per-node=1{container_flags} {{extra_flags}} '
+ f'/bin/bash -c {{bash_cmd}}'
)
- return srun_cmd, script_path
+
+ def cleanup():
+ if script_path is not None:
+ os.remove(script_path)
+
+ return srun_cmd, cleanup
def run_thread_func():
# This blocks until Slurm allocates resources (--exclusive)
# --mem=0 to match RayCodeGen's behavior where we don't explicitly request memory.
run_flags = f'--nodes={num_nodes} --cpus-per-task={task_cpu_demand} --mem=0 {{gpu_arg}} --exclusive'
- srun_cmd, task_script_path = build_task_runner_cmd(
+ srun_cmd, cleanup = build_task_runner_cmd(
script, run_flags, {log_dir!r}, sky_env_vars_dict,
task_name={task_name!r},
alloc_signal=alloc_signal_file,
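
Returning a cleanup callable instead of a bare script path (as the refactor above does) keeps the temp file's lifetime owned by the code that created it. A minimal, self-contained sketch of the pattern; names are illustrative:

    import os
    import tempfile
    from typing import Callable, Tuple


    def make_script_cmd(body: str) -> Tuple[str, Callable[[], None]]:
        """Write `body` to a temp script; return the command plus cleanup."""
        with tempfile.NamedTemporaryFile('w', suffix='.sh',
                                         delete=False) as f:
            f.write(body)
            path = f.name

        def cleanup() -> None:
            # Call once after the spawned process has exited.
            os.remove(path)

        return f'/bin/bash {path}', cleanup
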
@@ -913,8 +946,7 @@ def run_thread_func():
print(line, end='', flush=True)
proc.wait()
- if task_script_path is not None:
- os.remove(task_script_path)
+ cleanup()
return {{'return_code': proc.returncode, 'pid': proc.pid}}
run_thread_result = {{'result': None}}
@@ -955,7 +987,7 @@ def run_thread_wrapper():
# --overlap as we have already secured allocation with the srun for the run section,
# and otherwise this srun would get blocked and deadlock.
setup_flags = f'--overlap --nodes={self._setup_num_nodes}'
- setup_srun, setup_script_path = build_task_runner_cmd(
+ setup_srun, setup_cleanup = build_task_runner_cmd(
{self._setup_cmd!r}, setup_flags, {self._setup_log_dir!r}, {self._setup_envs!r},
is_setup=True
)
@@ -969,8 +1001,7 @@ def run_thread_wrapper():
print(line, end='', flush=True)
setup_proc.wait()
- if setup_script_path is not None:
- os.remove(setup_script_path)
+ setup_cleanup()
setup_returncode = setup_proc.returncode
if setup_returncode != 0:
diff --git a/sky/catalog/__init__.py b/sky/catalog/__init__.py
index 4180bf057f4..038bd57e52d 100644
--- a/sky/catalog/__init__.py
+++ b/sky/catalog/__init__.py
@@ -335,6 +335,7 @@ def get_common_gpus() -> List[str]:
'H200',
'L4',
'L40S',
+ 'RTX5090',
'T4',
'V100',
'V100-32GB',
diff --git a/sky/catalog/common.py b/sky/catalog/common.py
index c284e72e3bc..9be2c23a124 100644
--- a/sky/catalog/common.py
+++ b/sky/catalog/common.py
@@ -3,6 +3,7 @@
import difflib
import hashlib
import os
+import tempfile
import time
import typing
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
@@ -243,9 +244,19 @@ def _update_catalog():
raise e
else:
# Download successful, save the catalog to a local file.
+ # Use atomic write (write to temp file, then rename) to
+ # avoid race conditions when multiple processes read/write
+ # the catalog file concurrently during parallel test
+ # execution.
os.makedirs(os.path.dirname(catalog_path), exist_ok=True)
- with open(catalog_path, 'w', encoding='utf-8') as f:
+ with tempfile.NamedTemporaryFile(
+ mode='w',
+ dir=os.path.dirname(catalog_path),
+ delete=False,
+ encoding='utf-8') as f:
f.write(r.text)
+ tmp_path = f.name
+ os.rename(tmp_path, catalog_path)
with open(meta_path + '.md5', 'w', encoding='utf-8') as f:
f.write(hashlib.md5(r.text.encode()).hexdigest())
logger.debug(f'Updated {cloud} catalog {filename}.')
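
The write-then-rename above relies on `os.rename` being atomic within a single filesystem, so concurrent readers observe either the old or the new catalog, never a torn write. A standalone sketch of the same pattern:

    import os
    import tempfile


    def atomic_write_text(path: str, text: str) -> None:
        """Atomically replace `path` with `text`."""
        os.makedirs(os.path.dirname(path), exist_ok=True)
        # The temp file must be created in the destination directory:
        # rename() is only atomic within one filesystem.
        with tempfile.NamedTemporaryFile('w',
                                         dir=os.path.dirname(path),
                                         delete=False,
                                         encoding='utf-8') as f:
            f.write(text)
            tmp_path = f.name
        os.rename(tmp_path, path)
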
diff --git a/sky/catalog/data_fetchers/fetch_aws.py b/sky/catalog/data_fetchers/fetch_aws.py
index 483639e717e..02da872c38b 100644
--- a/sky/catalog/data_fetchers/fetch_aws.py
+++ b/sky/catalog/data_fetchers/fetch_aws.py
@@ -13,13 +13,14 @@
import textwrap
import traceback
import typing
-from typing import List, Optional, Set, Tuple, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
import numpy as np
from sky import exceptions
from sky.adaptors import aws
from sky.adaptors import common as adaptors_common
+from sky.skylet import constants
from sky.utils import log_utils
from sky.utils import ux_utils
@@ -67,8 +68,21 @@
# The following columns will be included in the final catalog.
USEFUL_COLUMNS = [
- 'InstanceType', 'AcceleratorName', 'AcceleratorCount', 'vCPUs', 'MemoryGiB',
- 'GpuInfo', 'Price', 'SpotPrice', 'Region', 'AvailabilityZone', 'Arch'
+ 'InstanceType',
+ 'AcceleratorName',
+ 'AcceleratorCount',
+ 'vCPUs',
+ 'MemoryGiB',
+ 'GpuInfo',
+ 'Price',
+ 'SpotPrice',
+ 'Region',
+ 'AvailabilityZone',
+ 'Arch',
+ 'LocalDiskType',
+ 'NVMeSupported',
+ 'LocalDiskSize',
+ 'LocalDiskCount',
]
# NOTE: the hard-coded us-east-1 URL is not a typo. AWS pricing endpoint is
@@ -269,23 +283,49 @@ def get_memory_gib(row) -> float:
return row['MemoryInfo']['SizeInMiB'] / 1024
return float(row['Memory'].split(' GiB')[0])
+ def get_local_disk_info(row) -> Dict[str, Any]:
+ info: Dict[str, Any] = {}
+ local_disk_supported = row['InstanceStorageSupported']
+ info['LocalDiskType'] = None
+ info['NVMeSupported'] = False
+ info['LocalDiskSize'] = None
+ info['LocalDiskCount'] = None
+
+ if local_disk_supported:
+ raw_info = row['InstanceStorageInfo']
+ info['NVMeSupported'] = raw_info['NvmeSupport'] == 'required'
+            # This is always 1. AWS likely made this a list to
+            # accommodate future changes.
+ assert len(raw_info['Disks']) == 1, (
+ f'Instance type {row["InstanceType"]} has '
+ f'{len(raw_info["Disks"])} disk entries, expected 1.')
+ disk_info = raw_info['Disks'][0]
+ assert disk_info['Type'] in constants.LOCAL_DISK_TYPES, (
+ f'Instance type {row["InstanceType"]} has unknown '
+ f'disk type {disk_info["Type"]}.')
+ info['LocalDiskType'] = disk_info['Type']
+ info['LocalDiskSize'] = disk_info['SizeInGB']
+ info['LocalDiskCount'] = disk_info['Count']
+ return info
+
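
A hypothetical input row, with field names following the EC2 DescribeInstanceTypes response this fetcher consumes, and the columns `get_local_disk_info` would derive from it:

    # Hypothetical row for an instance type with one 475 GB NVMe SSD.
    row = {
        'InstanceType': 'i3.large',
        'InstanceStorageSupported': True,
        'InstanceStorageInfo': {
            'NvmeSupport': 'required',
            'Disks': [{'Type': 'ssd', 'SizeInGB': 475, 'Count': 1}],
        },
    }
    # get_local_disk_info(row) ->
    # {'LocalDiskType': 'ssd', 'NVMeSupported': True,
    #  'LocalDiskSize': 475, 'LocalDiskCount': 1}
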
def get_additional_columns(row) -> pd.Series:
acc_name, acc_count = get_acc_info(row)
- # AWS p3dn.24xlarge offers a different V100 GPU.
+ # AWS instance type workarounds for incorrect/missing GPU info.
# See https://aws.amazon.com/blogs/compute/optimizing-deep-learning-on-p3-and-p3dn-with-efa/ # pylint: disable=line-too-long
if row['InstanceType'] == 'p3dn.24xlarge':
acc_name = 'V100-32GB'
- if row['InstanceType'] == 'p4de.24xlarge':
+ elif row['InstanceType'] == 'p4de.24xlarge':
acc_name = 'A100-80GB'
acc_count = 8
- if row['InstanceType'] == 'p5en.48xlarge':
+ elif row['InstanceType'] in ('p5e.48xlarge', 'p5en.48xlarge'):
# TODO(andyl): Check if this workaround still needed after
# v0.10.0 released. Currently, the acc_name returned by the
# AWS API is 'NVIDIA', which is incorrect. See #4652.
+ # Both p5e.48xlarge and p5en.48xlarge have 8x H200 GPUs.
acc_name = 'H200'
acc_count = 8
- if (row['InstanceType'].startswith('g6f') or
- row['InstanceType'].startswith('gr6f')):
+ elif (row['InstanceType'].startswith('g6f') or
+ row['InstanceType'].startswith('gr6f')):
# These instance actually have only fractional GPUs, but the API
# returns Count: 1 or Count: 0 under GpuInfo. We need to
# directly check the GPU memory to get the actual fraction of
@@ -297,14 +337,18 @@ def get_additional_columns(row) -> pd.Series:
fraction = row['GpuInfo']['Gpus'][0]['MemoryInfo'][
'SizeInMiB'] / L4_GPU_MEMORY
acc_count = round(fraction, 3)
- if row['InstanceType'] == 'p5.4xlarge':
+ elif row['InstanceType'] == 'p5.4xlarge':
acc_count = 1
+ elif row['InstanceType'].startswith('g7e'):
+            # Rename "RTX PRO Server 6000" to "RTXPRO6000" for consistency.
+ acc_name = 'RTXPRO6000'
return pd.Series({
'AcceleratorName': acc_name,
'AcceleratorCount': acc_count,
'vCPUs': get_vcpus(row),
'MemoryGiB': get_memory_gib(row),
'Arch': get_arch(row),
+ **get_local_disk_info(row)
})
# The AWS API may not have all the instance types in the pricing table,
diff --git a/sky/catalog/kubernetes_catalog.py b/sky/catalog/kubernetes_catalog.py
index 1ed95b70a1a..4fa29face85 100644
--- a/sky/catalog/kubernetes_catalog.py
+++ b/sky/catalog/kubernetes_catalog.py
@@ -206,6 +206,13 @@ def _list_accelerators(
for node in nodes:
# Check if node is ready
node_is_ready = node.is_ready()
+ node_is_cordoned = node.is_cordoned()
+ node_taints = node.get_taints(
+ exclude_cordon=True,
+ exclude_not_ready=True,
+ exclude_effects=['PreferNoSchedule'],
+ exclude_keys=kubernetes_utils.get_handled_taint_keys())
+ node_is_tainted = len(node_taints) > 0
for key in keys:
if key in node.metadata.labels:
@@ -268,8 +275,9 @@ def _list_accelerators(
total_accelerators_available[accelerator_name] = (
total_accelerators_available.get(accelerator_name, 0))
- # Skip availability counting for not-ready nodes
- if not node_is_ready:
+ # Skip availability counting for not-ready, cordoned,
+ # or tainted nodes
+ if not node_is_ready or node_is_cordoned or node_is_tainted:
continue
if error_on_get_allocated_gpu_qty_by_node:
diff --git a/sky/catalog/yotta_catalog.py b/sky/catalog/yotta_catalog.py
new file mode 100644
index 00000000000..db317b221ec
--- /dev/null
+++ b/sky/catalog/yotta_catalog.py
@@ -0,0 +1,98 @@
+""" Yotta | Catalog
+This module loads the service catalog file and can be used to
+query instance types and pricing information for Yotta.
+"""
+
+import typing
+from typing import Dict, List, Optional, Tuple, Union
+
+from sky.catalog import common
+from sky.utils import ux_utils
+
+if typing.TYPE_CHECKING:
+ from sky.clouds import cloud
+
+_df = common.read_catalog('yotta/vms.csv')
+
+
+def instance_type_exists(instance_type: str) -> bool:
+ return common.instance_type_exists_impl(_df, instance_type)
+
+
+def validate_region_zone(
+ region: Optional[str],
+ zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
+ if zone is not None:
+ with ux_utils.print_exception_no_traceback():
+ raise ValueError('Yotta does not support zones.')
+ return common.validate_region_zone_impl('yotta', _df, region, zone)
+
+
+def get_hourly_cost(instance_type: str,
+ use_spot: bool = False,
+ region: Optional[str] = None,
+ zone: Optional[str] = None) -> float:
+ """Returns the cost, or the cheapest cost among all zones for spot."""
+ return common.get_hourly_cost_impl(_df, instance_type, use_spot, region,
+ zone)
+
+
+def get_vcpus_mem_from_instance_type(
+ instance_type: str) -> Tuple[Optional[float], Optional[float]]:
+ return common.get_vcpus_mem_from_instance_type_impl(_df, instance_type)
+
+
+def get_default_instance_type(cpus: Optional[str] = None,
+ memory: Optional[str] = None,
+ disk_tier: Optional[str] = None,
+ region: Optional[str] = None,
+ zone: Optional[str] = None) -> Optional[str]:
+ del disk_tier, region, zone # Unused.
+ # NOTE: After expanding catalog to multiple entries, you may
+ # want to specify a default instance type or family.
+ return common.get_instance_type_for_cpus_mem_impl(_df, cpus, memory)
+
+
+def get_accelerators_from_instance_type(
+ instance_type: str) -> Optional[Dict[str, Union[int, float]]]:
+ return common.get_accelerators_from_instance_type_impl(_df, instance_type)
+
+
+def get_instance_type_for_accelerator(
+ acc_name: str,
+ acc_count: int,
+ cpus: Optional[str] = None,
+ memory: Optional[str] = None,
+ use_spot: bool = False,
+ region: Optional[str] = None,
+ zone: Optional[str] = None) -> Tuple[Optional[List[str]], List[str]]:
+ """Returns a list of instance types that have the given accelerator."""
+ return common.get_instance_type_for_accelerator_impl(df=_df,
+ acc_name=acc_name,
+ acc_count=acc_count,
+ cpus=cpus,
+ memory=memory,
+ use_spot=use_spot,
+ region=region,
+ zone=zone)
+
+
+def get_region_zones_for_instance_type(instance_type: str,
+ use_spot: bool) -> List['cloud.Region']:
+ df = _df[_df['InstanceType'] == instance_type]
+ return common.get_region_zones(df, use_spot)
+
+
+def list_accelerators(
+ gpus_only: bool,
+ name_filter: Optional[str],
+ region_filter: Optional[str],
+ quantity_filter: Optional[int],
+ case_sensitive: bool = True,
+ all_regions: bool = False,
+ require_price: bool = True) -> Dict[str, List[common.InstanceTypeInfo]]:
+ """Returns all instance types in Yotta offering GPUs."""
+ del require_price # Unused.
+ return common.list_accelerators_impl('Yotta', _df, gpus_only, name_filter,
+ region_filter, quantity_filter,
+ case_sensitive, all_regions)
diff --git a/sky/client/cli/command.py b/sky/client/cli/command.py
index 0a81f2b50ed..312238f9b39 100644
--- a/sky/client/cli/command.py
+++ b/sky/client/cli/command.py
@@ -28,10 +28,12 @@
import fnmatch
import os
import pathlib
+import re
import shlex
import shutil
import subprocess
import sys
+import tempfile
import time
import traceback
import typing
@@ -68,6 +70,7 @@
from sky.schemas.api import responses
from sky.server import common as server_common
from sky.server import constants as server_constants
+from sky.server.requests import payloads
from sky.server.requests import requests
from sky.skylet import autostop_lib
from sky.skylet import constants
@@ -123,7 +126,8 @@
]
_DEFAULT_MANAGED_JOB_FIELDS_TO_GET = [
'job_id', 'task_id', 'workspace', 'job_name', 'task_name', 'resources',
- 'submitted_at', 'end_at', 'job_duration', 'recovery_count', 'status', 'pool'
+ 'submitted_at', 'end_at', 'job_duration', 'recovery_count', 'status',
+ 'pool', 'is_primary_in_job_group'
]
_VERBOSE_MANAGED_JOB_FIELDS_TO_GET = _DEFAULT_MANAGED_JOB_FIELDS_TO_GET + [
'current_cluster_name', 'job_id_on_pool_cluster', 'start_at', 'infra',
@@ -317,14 +321,21 @@ def _async_call_or_wait(request_id: server_common.RequestId[T],
f'{colorama.Style.RESET_ALL}\n')
-def _merge_env_vars(env_dict: Optional[Dict[str, str]],
- env_list: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
- """Merges all values from env_list into env_dict."""
- if not env_dict:
- return env_list
- for (key, value) in env_list:
- env_dict[key] = value
- return list(env_dict.items())
+def _merge_cli_and_file_vars(
+ env_dicts: List[Optional[Dict[str, str]]],
+ env_list: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
+    """Merges all values from env_list and env_dicts.
+
+    Priority: env_list has the highest priority, and an env_dict with a
+    higher index takes priority over one with a lower index."""
+ final_env_dict = {}
+ for env_dict in env_dicts:
+ if env_dict is None:
+ continue
+ for k, v in env_dict.items():
+ final_env_dict[k] = v
+ for k, v in env_list:
+ final_env_dict[k] = v
+ return list(final_env_dict.items())
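
For example, with the precedence described in the docstring (a later dict overrides an earlier one, and the CLI list overrides both):

    merged = _merge_cli_and_file_vars(
        env_dicts=[{'A': '1', 'B': '1'}, {'B': '2'}],
        env_list=[('A', '3')])
    assert merged == [('A', '3'), ('B', '2')]
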
def _complete_cluster_name(ctx: click.Context, param: click.Parameter,
@@ -709,6 +720,27 @@ def _check_yaml(entrypoint: str) -> Tuple[bool, Optional[Dict[str, Any]]]:
return is_yaml, result
+def _check_recipe_reference(entrypoint: str) -> Tuple[bool, Optional[str]]:
+ """Check if entrypoint is a recipe reference like 'recipes:my-recipe'.
+
+ Args:
+ entrypoint: The entrypoint string to check.
+
+ Returns:
+ Tuple of (is_recipe, recipe_name). If is_recipe is True, recipe_name
+ contains the name of the recipe to fetch from the Recipe Hub.
+ """
+ # Pattern matches 'recipes:'
+    # Pattern matches 'recipes:<recipe-name>'.
+ # and dashes, and must end with an alphanumeric character.
+ pattern = re.compile(r'^recipes:(' + constants.RECIPE_NAME_VALID_REGEX +
+ r')$')
+ match = pattern.match(entrypoint)
+ if match:
+ return True, match.group(1)
+ return False, None
+
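
Expected outcomes of the check above, assuming `constants.RECIPE_NAME_VALID_REGEX` accepts names like `my-recipe` (the exact character class lives in that constant):

    assert _check_recipe_reference('recipes:my-recipe') == (True, 'my-recipe')
    assert _check_recipe_reference('task.yaml') == (False, None)
    assert _check_recipe_reference('recipes:') == (False, None)
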
+
def _pop_and_ignore_fields_in_override_params(
params: Dict[str, Any], field_to_ignore: List[str]) -> None:
"""Pops and ignores fields in override params.
@@ -728,6 +760,55 @@ def _pop_and_ignore_fields_in_override_params(
fg='yellow')
+def _get_recipe_yaml(entrypoint: str) -> Optional[str]:
+ """Checks if entrypoint is a recipe reference and returns the recipe YAML.
+
+ Fetches the recipe content from the API server.
+
+ Args:
+ entrypoint: The entrypoint string to check.
+
+ Returns:
+        Path to a temporary YAML file containing the recipe content if
+        entrypoint is a recipe reference. Otherwise, None.
+ """
+ is_recipe, recipe_name = _check_recipe_reference(entrypoint)
+ if is_recipe:
+ assert recipe_name is not None # For mypy
+ click.secho('Recipe to run: ', fg='cyan', nl=False)
+ click.secho(recipe_name)
+ try:
+ # Make API request to fetch recipe from server
+ body = payloads.RecipeGetBody(recipe_name=recipe_name)
+ response = server_common.make_authenticated_request(
+ 'POST', '/recipes/get', json=body.model_dump())
+ request_id: server_common.RequestId[Optional[Dict[
+ str, Any]]] = server_common.get_request_id(response)
+ recipe = sdk.get(request_id)
+ except requests_lib.exceptions.ConnectionError as e:
+ raise click.UsageError(
+ f'Failed to connect to API server to fetch recipe '
+ f'{recipe_name!r}: {e}') from e
+ except Exception as e:
+ # Handle errors from the API server (e.g., recipe not found)
+ raise click.UsageError(str(e)) from e
+
+ if recipe is None:
+ raise click.UsageError(f'Recipe not found: {recipe_name}')
+
+ content = recipe.get('content')
+ if content is None:
+ raise click.UsageError(f'Recipe {recipe_name!r} has no content')
+
+ # Write to temp file and treat as YAML
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml',
+ delete=False) as f:
+ f.write(content)
+ return f.name
+ else:
+ logger.debug(f'Not a recipe reference: {entrypoint}')
+ return None
+
+
def _make_task_or_dag_from_entrypoint_with_overrides(
entrypoint: Tuple[str, ...],
*,
@@ -766,7 +847,14 @@ def _make_task_or_dag_from_entrypoint_with_overrides(
raise click.UsageError('Cannot specify both --git-url and --workdir')
entrypoint = ' '.join(entrypoint)
+
+ # Check if entrypoint is a recipe reference (recipes:)
+ recipe_yaml = _get_recipe_yaml(entrypoint)
+ if recipe_yaml is not None:
+ entrypoint = recipe_yaml
+
is_yaml, _ = _check_yaml(entrypoint)
+
entrypoint: Optional[str]
if is_yaml:
# Treat entrypoint as a yaml.
@@ -801,6 +889,20 @@ def _make_task_or_dag_from_entrypoint_with_overrides(
if is_yaml:
assert entrypoint is not None
usage_lib.messages.usage.update_user_task_yaml(entrypoint)
+
+ # Check if this is a JobGroup YAML
+ if dag_utils.is_job_group_yaml(entrypoint):
+ click.secho('Detected JobGroup YAML', fg='cyan')
+ dag = dag_utils.load_job_group_from_yaml(entrypoint,
+ env_overrides=env,
+ secrets_overrides=secret)
+ if override_params:
+ click.secho(
+ f'WARNING: override params {override_params} are ignored '
+ 'for JobGroup YAML.',
+ fg='yellow')
+ return dag
+
dag = dag_utils.load_chain_dag_from_yaml(entrypoint,
env_overrides=env,
secret_overrides=secret)
@@ -1058,6 +1160,7 @@ def launch(
image_id: Optional[str],
env_file: Optional[Dict[str, str]],
env: List[Tuple[str, str]],
+ secret_file: Optional[Dict[str, str]],
secret: List[Tuple[str, str]],
disk_size: Optional[int],
disk_tier: Optional[str],
@@ -1093,7 +1196,8 @@ def launch(
# job can take up resources on the API server. When there are a lot of
# `launch` submitted asynchronously, the log tailing may overwhelm the API
# server, if the jobs are long running.
- env = _merge_env_vars(env_file, env)
+ env = _merge_cli_and_file_vars([env_file], env)
+ secret = _merge_cli_and_file_vars([env_file, secret_file], secret)
controller_utils.check_cluster_name_not_controller(
cluster, operation_str='Launching tasks on it')
if backend_name is None:
@@ -1247,6 +1351,7 @@ def exec(
image_id: Optional[str],
env_file: Optional[Dict[str, str]],
env: List[Tuple[str, str]],
+ secret_file: Optional[Dict[str, str]],
secret: List[Tuple[str, str]],
cpus: Optional[str],
memory: Optional[str],
@@ -1327,7 +1432,8 @@ def exec(
raise click.UsageError('Missing argument \'[ENTRYPOINT]...\'')
assert cluster is not None, (cluster, cluster_option, entrypoint)
- env = _merge_env_vars(env_file, env)
+ env = _merge_cli_and_file_vars([env_file], env)
+ secret = _merge_cli_and_file_vars([env_file, secret_file], secret)
controller_utils.check_cluster_name_not_controller(
cluster, operation_str='Executing task on it')
@@ -1568,12 +1674,15 @@ def _handle_services_request(
# print the original error.
pass
if not msg:
- msg = (f'Failed to fetch {noun} statuses due to connection issues. '
- 'Please try again later. Details: '
- f'{common_utils.format_exception(e, use_bracket=True)}')
- except Exception as e: # pylint: disable=broad-except
- msg = (f'Failed to fetch {noun} statuses: '
- f'{common_utils.format_exception(e, use_bracket=True)}')
+ # This is an actual error (connection issues), not a normal state.
+ # Format the error message and raise a new exception.
+ # Use 'from None' to suppress the exception chain and only show
+ # the formatted message.
+ error_msg = (
+ f'Failed to fetch {noun} statuses due to connection issues. '
+ 'Please try again later. Details: '
+ f'{common_utils.format_exception(e, use_bracket=True)}')
+ raise RuntimeError(error_msg) from None
else:
if show_endpoint:
if len(service_records) != 1:
@@ -1630,7 +1739,8 @@ def _show_endpoint(query_clusters: Optional[List[str]],
('endpoint port' if show_single_endpoint else 'endpoints')))
cluster_record = cluster_records[0]
- if cluster_record['status'] != status_lib.ClusterStatus.UP:
+ if cluster_record['status'] not in (status_lib.ClusterStatus.UP,
+ status_lib.ClusterStatus.AUTOSTOPPING):
with ux_utils.print_exception_no_traceback():
raise RuntimeError(f'Cluster {cluster_record["name"]!r} '
'is not in UP status.')
@@ -1807,6 +1917,7 @@ def status(verbose: bool, refresh: bool, ip: bool, endpoints: bool,
# Do not show job queue if user specifies clusters, and if user
# specifies --ip or --endpoint(s).
show_managed_jobs = show_managed_jobs and not any([clusters, ip, endpoints])
+ show_pools = show_pools and not any([clusters, ip, endpoints])
show_endpoints = endpoints or endpoint is not None
show_single_endpoint = endpoint is not None
show_services = show_services and not any([clusters, ip, endpoints])
@@ -2022,6 +2133,11 @@ def submit_enabled_clouds():
sdk.api_cancel(pool_status_request_id, silent=True)
num_pools = -1
msg = 'KeyboardInterrupt'
+ except Exception as e: # pylint: disable=broad-except
+ # For internal calls, handle exceptions gracefully by
+ # printing the error message instead of crashing.
+ num_pools = None
+ msg = str(e)
if num_pools is not None:
if num_pools > 0:
click.echo(f'\n{colorama.Fore.CYAN}{colorama.Style.BRIGHT}'
@@ -2050,6 +2166,11 @@ def submit_enabled_clouds():
sdk.api_cancel(service_status_request_id, silent=True)
num_services = -1
msg = 'KeyboardInterrupt'
+ except Exception as e: # pylint: disable=broad-except
+ # For internal calls, handle exceptions gracefully by
+ # printing the error message instead of crashing.
+ num_services = None
+ msg = str(e)
click.echo(msg)
if num_services is not None:
hints.append(
@@ -2247,6 +2368,10 @@ def _get_job_queue(cluster):
is_flag=True,
default=False,
help='Stream the cluster provisioning logs (provision.log).')
+@click.option('--autostop',
+ is_flag=True,
+ default=False,
+ help='Stream the autostop hook logs from the cluster.')
@click.option('--worker',
'-w',
default=None,
@@ -2290,6 +2415,7 @@ def logs(
cluster: str,
job_ids: Tuple[str, ...],
provision: bool,
+ autostop: bool, # pylint: disable=redefined-outer-name
worker: Optional[int],
sync_down: bool,
status: bool, # pylint: disable=redefined-outer-name
@@ -2319,6 +2445,9 @@ def logs(
4. If the job fails or fetching the logs fails, the command will exit with
a non-zero return code.
+
+ 5. If ``--autostop`` is specified, stream the autostop hook logs from the
+ cluster. This shows the output of the autostop hook script.
"""
if worker is not None:
if not provision:
@@ -2327,11 +2456,20 @@ def logs(
if worker < 1:
raise click.UsageError('--worker must be a positive integer.')
+ if provision and autostop:
+ raise click.UsageError(
+ '--provision and --autostop cannot be used together.')
+
if provision and (sync_down or status or job_ids):
raise click.UsageError(
'--provision cannot be combined with job log options '
'(--sync-down/--status/job IDs).')
+ if autostop and (sync_down or status or job_ids or worker is not None):
+ raise click.UsageError(
+ '--autostop cannot be combined with job log options '
+ '(--sync-down/--status/--worker/job IDs).')
+
if sync_down and status:
raise click.UsageError(
'Both --sync_down and --status are specified '
@@ -2352,6 +2490,13 @@ def logs(
follow=follow,
tail=tail))
+ if autostop:
+ # Stream autostop hook logs
+ sys.exit(
+ sdk.tail_autostop_logs(cluster_name=cluster,
+ follow=follow,
+ tail=tail))
+
if sync_down:
with rich_utils.client_status(
ux_utils.spinner_message('Downloading logs')):
@@ -2550,13 +2695,15 @@ def cancel(
@flags.all_option('Stop all existing clusters.')
@flags.all_users_option('Stop all existing clusters for all users.')
@flags.yes_option()
-@_add_click_options(flags.COMMON_OPTIONS)
+@_add_click_options(flags.GRACEFUL_OPTIONS + flags.COMMON_OPTIONS)
@usage_lib.entrypoint
def stop(
clusters: List[str],
all: bool, # pylint: disable=redefined-builtin
all_users: bool,
yes: bool,
+ graceful: bool,
+ graceful_timeout: Optional[int],
async_call: bool,
):
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
@@ -2593,6 +2740,8 @@ def stop(
all_users=all_users,
down=False,
no_confirm=yes,
+ graceful=graceful,
+ graceful_timeout=graceful_timeout,
async_call=async_call)
@@ -2970,7 +3119,7 @@ def start(
' in certain manual troubleshooting scenarios; with it set, it is the'
' user\'s responsibility to ensure there are no leaked instances and '
'related resources.'))
-@_add_click_options(flags.COMMON_OPTIONS)
+@_add_click_options(flags.GRACEFUL_OPTIONS + flags.COMMON_OPTIONS)
@usage_lib.entrypoint
def down(
clusters: List[str],
@@ -2978,6 +3127,8 @@ def down(
all_users: bool,
yes: bool,
purge: bool,
+ graceful: bool,
+ graceful_timeout: Optional[int],
async_call: bool,
):
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
@@ -3014,6 +3165,8 @@ def down(
down=True,
no_confirm=yes,
purge=purge,
+ graceful=graceful,
+ graceful_timeout=graceful_timeout,
async_call=async_call)
@@ -3175,6 +3328,8 @@ def _down_or_stop_clusters(
down: bool = False, # pylint: disable=redefined-outer-name
no_confirm: bool = True,
purge: bool = False,
+ graceful: bool = False,
+ graceful_timeout: Optional[int] = None,
idle_minutes_to_autostop: Optional[int] = None,
wait_for: Optional[autostop_lib.AutostopWaitFor] = None,
async_call: bool = False) -> None:
@@ -3192,6 +3347,10 @@ def _down_or_stop_clusters(
down: If True, tear down the clusters.
no_confirm: If True, skip the confirmation prompt.
purge: If True, forcefully remove the clusters from the cluster table.
+        graceful: If True, cancel the user's task, but block until
+            MOUNT_CACHED uploads finish.
+ graceful_timeout: If not None, sets a timeout for the graceful option
+ above (in seconds).
idle_minutes_to_autostop: The number of minutes to wait before
automatically stopping the cluster.
wait_for: Determines the condition for resetting the idleness timer.
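As a concrete reference, a minimal sketch of driving the new graceful path through the Python SDK (cluster name is a placeholder; mirrors `sky stop --graceful --graceful-timeout 600`):

    # Stop a cluster but let MOUNT_CACHED uploads drain first, giving up
    # after 10 minutes.
    from sky.client import sdk

    request_id = sdk.stop('my-cluster', graceful=True, graceful_timeout=600)
    sdk.stream_and_get(request_id)  # block until the stop request completes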
@@ -3375,9 +3534,15 @@ def _down_or_stop(name: str):
else:
try:
if down:
- request_id = sdk.down(name, purge=purge)
+ request_id = sdk.down(name,
+ purge=purge,
+ graceful=graceful,
+ graceful_timeout=graceful_timeout)
else:
- request_id = sdk.stop(name, purge=purge)
+ request_id = sdk.stop(name,
+ purge=purge,
+ graceful=graceful,
+ graceful_timeout=graceful_timeout)
request_ids.append(request_id)
progress.stop()
_async_call_or_wait(
@@ -3734,7 +3899,10 @@ def _count_not_ready_gpus(
continue
node_is_ready = getattr(node_info, 'is_ready', True)
- if not node_is_ready:
+ node_is_cordoned = getattr(node_info, 'is_cordoned', False)
+ node_taints = getattr(node_info, 'taints', None) or []
+ node_is_tainted = len(node_taints) > 0
+ if not node_is_ready or node_is_cordoned or node_is_tainted:
not_ready_counts[accelerator_type] += accelerator_count
return not_ready_counts
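The getattr defaults above keep the check backward compatible with older servers whose node_info objects lack the new fields; the same predicate as a standalone sketch:

    # Missing is_ready defaults to healthy; missing is_cordoned/taints
    # default to not cordoned/untainted, matching the old behavior.
    def node_is_schedulable(node_info) -> bool:
        is_ready = getattr(node_info, 'is_ready', True)
        is_cordoned = getattr(node_info, 'is_cordoned', False)
        taints = getattr(node_info, 'taints', None) or []
        return is_ready and not is_cordoned and not taints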
@@ -3890,7 +4058,7 @@ def _format_kubernetes_node_info_combined(
context_title_str: str = 'CONTEXT') -> str:
node_table = log_utils.create_table([
context_title_str, 'NODE', 'vCPU', 'Memory (GB)', 'GPU',
- 'GPU UTILIZATION'
+ 'GPU UTILIZATION', 'NODE STATUS'
])
no_permissions_str = ''
@@ -3949,15 +4117,41 @@ def _format_kubernetes_node_info_combined(
utilization_str = (
f'{available} of '
f'{node_info.total["accelerator_count"]} free')
+
+ # Build node status string
+ status_info = []
# Check if node is ready (defaults to True for backward
# compatibility with older server versions)
node_is_ready = getattr(node_info, 'is_ready', True)
if not node_is_ready:
- utilization_str += ' (Node NotReady)'
-
+ status_info.append('NotReady')
+ node_is_cordoned = getattr(node_info, 'is_cordoned', False)
+ if node_is_cordoned:
+ status_info.append('Cordoned')
+ # Add taint info grouped by effect
+ taints = getattr(node_info, 'taints', None)
+ if taints:
+ # Group taints by effect: 'NoSchedule Taint [key1, key2],
+ # NoExecute Taint [key3]'
+ taints_by_effect: Dict[str, List[str]] = {}
+ for taint in taints:
+ effect = taint['effect']
+ key = taint['key']
+ if effect not in taints_by_effect:
+ taints_by_effect[effect] = []
+ taints_by_effect[effect].append(key)
+ taints_strs = []
+ for effect, keys in taints_by_effect.items():
+ taints_strs.append(
+ f'{effect} Taint [{", ".join(keys)}]')
+ if taints_strs:
+ status_info.append(', '.join(taints_strs))
+
+ status_str = ', '.join(
+ status_info) if status_info else 'Healthy'
node_table.add_row([
context_name, node_name, cpu_str, memory_str, acc_type,
- utilization_str
+ utilization_str, status_str
])
k8s_per_node_acc_message = (f'{cloud_str} per-node GPU availability')
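A self-contained sketch of the taint-grouping logic above, with sample data:

    # Group taint keys by effect, producing strings like
    # 'NoSchedule Taint [gpu, maintenance], NoExecute Taint [evict]'.
    from collections import defaultdict

    taints = [{'effect': 'NoSchedule', 'key': 'gpu'},
              {'effect': 'NoSchedule', 'key': 'maintenance'},
              {'effect': 'NoExecute', 'key': 'evict'}]
    by_effect = defaultdict(list)
    for taint in taints:
        by_effect[taint['effect']].append(taint['key'])
    print(', '.join(f'{effect} Taint [{", ".join(keys)}]'
                    for effect, keys in by_effect.items()))
    # -> NoSchedule Taint [gpu, maintenance], NoExecute Taint [evict]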
@@ -4005,6 +4199,30 @@ def _format_slurm_node_info(slurm_cluster_names: List[str]) -> str:
f'{colorama.Style.RESET_ALL}\n'
f'{node_table.get_string()}')
+ def _get_labeled_zero_gpu_hint(
+ all_nodes_info: List[Tuple[str,
+ 'models.KubernetesNodesInfo']]) -> str:
+ """Returns a hint if any nodes have GPU labels but 0 GPU resources."""
+ # Collect nodes with GPU labels but 0 GPU resources
+ labeled_zero_gpu_nodes = [
+ (context, node_name)
+ for context, nodes_info in all_nodes_info
+ for node_name, node_info in nodes_info.node_info_dict.items()
+ if (node_info.accelerator_type is not None and
+ node_info.total.get('accelerator_count', 0) == 0)
+ ]
+
+ if not labeled_zero_gpu_nodes:
+ return ''
+
+ num_affected_nodes = len(labeled_zero_gpu_nodes)
+ node_list = ', '.join(
+ f'{ctx}/{name}' for ctx, name in labeled_zero_gpu_nodes[:3])
+        ellipsis_str = '...' if len(labeled_zero_gpu_nodes) > 3 else ''
+        return (f'Note: Some Kubernetes nodes have GPU labels but report 0 '
+                f'GPU resources. Please check the node labels and '
+                f'configuration. {num_affected_nodes} affected node(s): '
+                f'{node_list}{ellipsis_str}')
+
def _format_kubernetes_realtime_gpu(
total_table: Optional['prettytable.PrettyTable'],
k8s_realtime_infos: List[Tuple[str, 'prettytable.PrettyTable']],
@@ -4070,6 +4288,11 @@ def _possibly_show_k8s_like_realtime(
show_node_info=True,
is_ssh=is_ssh)
+ # Check for nodes with GPU labels but 0 GPU resources
+ labeled_zero_hint = _get_labeled_zero_gpu_hint(all_nodes_info)
+ if labeled_zero_hint:
+ k8s_messages += labeled_zero_hint
+
if kubernetes_autoscaling:
k8s_messages += ('\n' +
kubernetes_utils.KUBERNETES_AUTOSCALER_NOTE)
@@ -4078,6 +4301,8 @@ def _possibly_show_k8s_like_realtime(
if not ssh_is_enabled:
yield ('SSH Node Pools are not enabled. To fix, run: '
'sky check ssh ')
+ if k8s_messages and print_section_titles:
+ yield '\n\n'
yield k8s_messages
return True, print_section_titles, ''
else:
@@ -4085,6 +4310,8 @@ def _possibly_show_k8s_like_realtime(
if not kubernetes_is_enabled:
yield ('Kubernetes is not enabled. To fix, run: '
'sky check kubernetes ')
+ if k8s_messages and print_section_titles:
+ yield '\n\n'
yield k8s_messages
return True, print_section_titles, ''
return False, print_section_titles, k8s_messages
@@ -4112,6 +4339,11 @@ def _possibly_show_k8s_like_realtime_for_acc(
all_nodes_info,
show_node_info=False,
is_ssh=is_ssh)
+
+ # Check for nodes with GPU labels but 0 GPU resources
+ labeled_zero_hint = _get_labeled_zero_gpu_hint(all_nodes_info)
+ if labeled_zero_hint:
+ k8s_messages += labeled_zero_hint
except ValueError as e:
# In the case of a specific accelerator, show the error message
# immediately (e.g., "Resources H100 not found ...")
@@ -4195,6 +4427,8 @@ def _output() -> Generator[str, None, None]:
stop_iter = stop_iter or stop_iter_one
print_section_titles = (print_section_titles or
print_section_titles_one)
+ if k8s_messages and k8s_messages_one:
+ k8s_messages += '\n'
k8s_messages += k8s_messages_one
prev_print_section_titles = print_section_titles_one
if stop_iter:
@@ -4381,11 +4615,8 @@ def _output() -> Generator[str, None, None]:
min_spot_price=('spot_price',
'min'))
df = df.merge(min_price_df, on='cloud')
- # Sort within each cloud by price.
- df = df.groupby('cloud', group_keys=False).apply(
- lambda x: x.sort_values(by=['price', 'spot_price']))
- # Sort across groups (clouds).
- df = df.sort_values(by=['min_price', 'min_spot_price'])
+ df = df.sort_values(
+ by=['min_price', 'min_spot_price', 'price', 'spot_price'])
df = df.drop(columns=['min_price', 'min_spot_price'])
sorted_dataclasses = [
catalog_common.InstanceTypeInfo(*row)
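The replacement above collapses a groupby-then-sort into one multi-key sort: min_price/min_spot_price are constant within a cloud, so (assuming distinct per-cloud minima) sorting on them first keeps each cloud's rows contiguous and ordered across clouds, while price/spot_price order rows within a cloud. A toy check:

    import pandas as pd

    df = pd.DataFrame({
        'cloud': ['aws', 'aws', 'gcp', 'gcp'],
        'price': [3.0, 1.0, 2.0, 4.0],
        'spot_price': [1.5, 0.5, 1.0, 2.0],
    })
    min_price_df = df.groupby('cloud').agg(
        min_price=('price', 'min'), min_spot_price=('spot_price', 'min'))
    df = df.merge(min_price_df, on='cloud')
    df = df.sort_values(
        by=['min_price', 'min_spot_price', 'price', 'spot_price'])
    print(df['price'].tolist())  # [1.0, 3.0, 2.0, 4.0]: aws rows before gcp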
@@ -4644,6 +4875,12 @@ def volumes_apply(
volume_config_dict: Dict[str, Any] = {}
if entrypoint is not None and len(entrypoint) > 0:
entrypoint_str = ' '.join(entrypoint)
+
+ # Check if the entrypoint is a recipe reference
+ recipe_yaml = _get_recipe_yaml(entrypoint_str)
+ if recipe_yaml is not None:
+ entrypoint_str = recipe_yaml
+
is_yaml, yaml_config, yaml_file_provided, invalid_reason = (
_check_yaml_only(entrypoint_str))
if not is_yaml:
@@ -4717,10 +4954,18 @@ def _build_volume_override_config(
is_flag=True,
required=False,
help='Show all information in full.')
+@click.option('--refresh',
+ '-r',
+ default=False,
+ is_flag=True,
+ required=False,
+              help='Refresh volume state from cloud APIs before listing. '
+              'Without this flag, cached data is returned, which is '
+              'refreshed periodically by the background daemon.')
@usage_lib.entrypoint
-def volumes_ls(verbose: bool):
+def volumes_ls(verbose: bool, refresh: bool):
"""List volumes managed by SkyPilot."""
- request_id = volumes_sdk.ls()
+ request_id = volumes_sdk.ls(refresh=refresh)
all_volumes = sdk.stream_and_get(request_id)
volume_table = table_utils.format_volume_table(all_volumes,
show_all=verbose)
@@ -4881,6 +5126,7 @@ def jobs_launch(
job_recovery: Optional[str],
env_file: Optional[Dict[str, str]],
env: List[Tuple[str, str]],
+ secret_file: Optional[Dict[str, str]],
secret: List[Tuple[str, str]],
disk_size: Optional[int],
disk_tier: Optional[str],
@@ -4917,7 +5163,8 @@ def jobs_launch(
raise click.UsageError('Cannot specify both --name and --cluster. '
'Use one of the flags as they are alias.')
name = cluster
- env = _merge_env_vars(env_file, env)
+ env = _merge_cli_and_file_vars([env_file], env)
+ secret = _merge_cli_and_file_vars([env_file, secret_file], secret)
cloud, region, zone = _handle_infra_cloud_region_zone_options(
infra, cloud, region, zone)
task_or_dag = _make_task_or_dag_from_entrypoint_with_overrides(
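The _merge_cli_and_file_vars helper itself is not shown in this diff; a plausible sketch of the documented precedence (later file dicts override earlier ones, so secret_file beats env_file, and explicit CLI pairs beat all files):

    # Hypothetical approximation of the private helper; the real
    # implementation lives elsewhere in the codebase.
    from typing import Dict, List, Optional, Tuple

    def merge_cli_and_file_vars(
            file_dicts: List[Optional[Dict[str, str]]],
            cli_vars: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
        merged: Dict[str, str] = {}
        for file_dict in file_dicts:
            if file_dict:
                merged.update(file_dict)  # later files win
        merged.update(dict(cli_vars))  # CLI flags win over any file
        return list(merged.items())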
@@ -5291,10 +5538,31 @@ def jobs_cancel(
required=False,
help='Download logs for all jobs shown in the queue.')
@click.argument('job_id', required=False, type=int)
+@click.argument('task', required=False, type=str, default=None)
@usage_lib.entrypoint
def jobs_logs(name: Optional[str], job_id: Optional[int], follow: bool,
- controller: bool, refresh: bool, sync_down: bool):
- """Tail or sync down the log of a managed job."""
+ controller: bool, refresh: bool, sync_down: bool,
+ task: Optional[str]):
+ """Tail or sync down the log of a managed job.
+
+ TASK can be a task ID (integer) or task name. Numeric values are treated
+ as task IDs. If not specified, logs for all tasks are shown.
+
+
+
+ \b
+ # View logs for job ID 1, task 0
+ sky jobs logs 1 0
+
+ \b
+ # View logs for job named 'my-job', task 'train'
+ sky jobs logs -n my-job train
+
+ \b
+ # View logs for job named 'my-job', task 'eval'
+ sky jobs logs -n my-job eval
+ """
try:
if sync_down:
with rich_utils.client_status(
@@ -5311,11 +5579,17 @@ def jobs_logs(name: Optional[str], job_id: Optional[int], follow: bool,
logger.info(f'{fore.CYAN}Job {job} logs{controller_str}: '
f'{log_local_path}{style.RESET_ALL}')
else:
+ # Parse task argument: if numeric, treat as task ID (int),
+ # otherwise treat as task name (str)
+ parsed_task: Optional[Union[str, int]] = None
+ if task is not None:
+ parsed_task = int(task) if task.isdigit() else task
returncode = managed_jobs.tail_logs(name=name,
job_id=job_id,
follow=follow,
controller=controller,
- refresh=refresh)
+ refresh=refresh,
+ task=parsed_task)
sys.exit(returncode)
except exceptions.ClusterNotUpError:
with ux_utils.print_exception_no_traceback():
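On the digit check above: str.isdigit() is False for strings like '-1' or '1.5', so only plain non-negative integers become task IDs; anything else passes through as a task name:

    def parse_task(task: str):
        # '0' -> 0 (task ID); 'train' -> 'train' (task name);
        # '-1' -> '-1' (name: isdigit() rejects the minus sign).
        return int(task) if task.isdigit() else task

    assert parse_task('0') == 0
    assert parse_task('train') == 'train'
    assert parse_task('-1') == '-1'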
@@ -5384,6 +5658,7 @@ def jobs_pool_apply(
image_id: Optional[str],
env_file: Optional[Dict[str, str]],
env: List[Tuple[str, str]],
+ secret_file: Optional[Dict[str, str]],
secret: List[Tuple[str, str]],
gpus: Optional[str],
instance_type: Optional[str],
@@ -5417,6 +5692,12 @@ def jobs_pool_apply(
'Cannot specify both --workers and POOL_YAML. Please use one of '
'them.')
+ if pool_yaml is not None and len(pool_yaml) > 0:
+ recipe_yaml = _get_recipe_yaml(pool_yaml[0])
+ if recipe_yaml is not None:
+ click.secho('Recipe to run: ', fg='cyan', nl=False)
+ pool_yaml = (recipe_yaml,)
+
if pool_yaml is None or len(pool_yaml) == 0:
if pool is None:
raise click.UsageError(
@@ -5444,6 +5725,7 @@ def jobs_pool_apply(
image_id=image_id,
env_file=env_file,
env=env,
+ secret_file=secret_file,
secret=secret,
disk_size=disk_size,
disk_tier=disk_tier,
@@ -5885,7 +6167,8 @@ def _generate_task_with_service(
image_id: Optional[str],
env_file: Optional[Dict[str, str]],
env: List[Tuple[str, str]],
- secret: Optional[List[Tuple[str, str]]],
+ secret_file: Optional[Dict[str, str]],
+ secret: List[Tuple[str, str]],
gpus: Optional[str],
instance_type: Optional[str],
ports: Optional[Tuple[str]],
@@ -5904,7 +6187,8 @@ def _generate_task_with_service(
yaml_name = 'SERVICE_YAML' if not pool else 'POOL_YAML'
if not is_yaml:
raise click.UsageError(f'{yaml_name} must be a valid YAML file.')
- env = _merge_env_vars(env_file, env)
+ env = _merge_cli_and_file_vars([env_file], env)
+ secret = _merge_cli_and_file_vars([env_file, secret_file], secret)
# We keep nargs=-1 in service_yaml argument to reuse this function.
task = _make_task_or_dag_from_entrypoint_with_overrides(
service_yaml_args,
@@ -6042,6 +6326,7 @@ def serve_up(
image_id: Optional[str],
env_file: Optional[Dict[str, str]],
env: List[Tuple[str, str]],
+ secret_file: Optional[Dict[str, str]],
secret: List[Tuple[str, str]],
gpus: Optional[str],
instance_type: Optional[str],
@@ -6105,6 +6390,7 @@ def serve_up(
image_id=image_id,
env_file=env_file,
env=env,
+ secret_file=secret_file,
secret=secret,
disk_size=disk_size,
disk_tier=disk_tier,
@@ -6156,12 +6442,12 @@ def serve_up(
@timeline.event
@usage_lib.entrypoint
def serve_update(
- service_name: str, service_yaml: Tuple[str,
- ...], workdir: Optional[str],
- infra: Optional[str], cloud: Optional[str], region: Optional[str],
- zone: Optional[str], num_nodes: Optional[int], use_spot: Optional[bool],
- image_id: Optional[str], env_file: Optional[Dict[str, str]],
- env: List[Tuple[str, str]], secret: List[Tuple[str, str]],
+ service_name: str, service_yaml: Tuple[str, ...],
+ workdir: Optional[str], infra: Optional[str], cloud: Optional[str],
+ region: Optional[str], zone: Optional[str], num_nodes: Optional[int],
+ use_spot: Optional[bool], image_id: Optional[str],
+ env_file: Optional[Dict[str, str]], env: List[Tuple[str, str]],
+ secret_file: Optional[Dict[str, str]], secret: List[Tuple[str, str]],
gpus: Optional[str], instance_type: Optional[str], ports: Tuple[str],
cpus: Optional[str], memory: Optional[str], disk_size: Optional[int],
disk_tier: Optional[str], network_tier: Optional[str], mode: str,
@@ -6215,6 +6501,7 @@ def serve_update(
image_id=image_id,
env_file=env_file,
env=env,
+ secret_file=secret_file,
secret=secret,
disk_size=disk_size,
disk_tier=disk_tier,
diff --git a/sky/client/cli/flags.py b/sky/client/cli/flags.py
index 531c47281d6..d2012af9cb5 100644
--- a/sky/client/cli/flags.py
+++ b/sky/client/cli/flags.py
@@ -52,6 +52,21 @@ def _parse_secret_var(secret_var: str) -> Tuple[str, str]:
help=('Run the command asynchronously.'))
]
+GRACEFUL_OPTIONS = [
+ click.option(
+ '--graceful',
+ is_flag=True,
+ default=False,
+        help=('Wait for MOUNT_CACHED uploads to complete before '
+              'stopping/terminating. Running jobs are cancelled first.')),
+ click.option('--graceful-timeout',
+ type=int,
+ default=None,
+                 help=('Timeout in seconds for the `--graceful` flag. When '
+                       'not set, waits indefinitely for MOUNT_CACHED '
+                       'uploads to finish.')),
+]
+
TASK_OPTIONS = [
click.option(
'--workdir',
@@ -155,7 +170,11 @@ def _parse_secret_var(secret_var: str) -> Tuple[str, str]:
node.
If any values from ``--env-file`` conflict with values set by
- ``--env``, the ``--env`` value will be preferred."""),
+ ``--env``, the ``--env`` value will be preferred.
+
+    Values from ``--env-file`` are also loaded as secrets, with lower
+    precedence than ``--secret`` or ``--secret-file``.
+ """),
click.option(
'--env',
required=False,
@@ -176,6 +195,16 @@ def _parse_secret_var(secret_var: str) -> Tuple[str, str]:
3. ``--env MY_ENV3``: set ``$MY_ENV3`` on the cluster to be the
same value of ``$MY_ENV3`` in the local environment.""",
),
+ click.option(
+ '--secret-file',
+ required=False,
+ type=dotenv.dotenv_values,
+ help="""\
+ Path to a dotenv file with secret variables to set on the remote node.
+
+ If any values from ``--secret-file`` conflict with values set by
+ ``--secret``, the ``--secret`` value will be preferred.""",
+ ),
click.option(
'--secret',
required=False,
diff --git a/sky/client/cli/table_utils.py b/sky/client/cli/table_utils.py
index dd9fa4876ee..fe07d7e907b 100644
--- a/sky/client/cli/table_utils.py
+++ b/sky/client/cli/table_utils.py
@@ -204,15 +204,23 @@ def format(self) -> str:
class PVCVolumeTable(VolumeTable):
"""The PVC volume table."""
+ def __init__(self,
+ volumes: List[responses.VolumeRecord],
+ show_all: bool = False):
+ # Check if any volume has an error before creating the table
+ self._has_errors = any(row.get('error_message') for row in volumes)
+ super().__init__(volumes, show_all)
+
def _create_table(self, show_all: bool = False) -> prettytable.PrettyTable:
"""Create the PVC volume table."""
# If show_all is False, show the table with the columns:
# NAME, TYPE, INFRA, SIZE, USER, WORKSPACE,
# AGE, STATUS, LAST_USE, USED_BY, IS_EPHEMERAL
+ # (+ MESSAGE if any volume is not ready)
# If show_all is True, show the table with the columns:
# NAME, TYPE, INFRA, SIZE, USER, WORKSPACE,
# AGE, STATUS, LAST_USE, USED_BY, IS_EPHEMERAL, NAME_ON_CLOUD
- # STORAGE_CLASS, ACCESS_MODE
+ # STORAGE_CLASS, ACCESS_MODE, MESSAGE
columns = _BASIC_COLUMNS + [
'IS_EPHEMERAL',
@@ -222,7 +230,11 @@ def _create_table(self, show_all: bool = False) -> prettytable.PrettyTable:
'NAME_ON_CLOUD',
'STORAGE_CLASS',
'ACCESS_MODE',
+ 'MESSAGE',
]
+ elif self._has_errors:
+ # Show MESSAGE column even without show_all if there are issues
+ columns = columns + ['MESSAGE']
table = log_utils.create_table(columns)
return table
@@ -239,6 +251,17 @@ def _add_rows(self,
table_row.append(
row.get('config', {}).get('storage_class_name', '-'))
table_row.append(row.get('config', {}).get('access_mode', ''))
+ # Add error message
+ error_msg = row.get('error_message', '')
+ table_row.append(error_msg if error_msg else '-')
+ elif self._has_errors:
+ # Show error message even without show_all if there are errors
+ error_msg = row.get('error_message', '')
+ # Truncate error message for display
+ if error_msg:
+ error_msg = common_utils.truncate_long_string(
+ error_msg, constants.ERROR_MESSAGE_TRUNC_LENGTH)
+ table_row.append(error_msg if error_msg else '-')
self.table.add_row(table_row)
diff --git a/sky/client/oauth.py b/sky/client/oauth.py
index 3afc1f2366e..da1d3006935 100644
--- a/sky/client/oauth.py
+++ b/sky/client/oauth.py
@@ -5,7 +5,7 @@
import time
from typing import Dict, Optional
-AUTH_TIMEOUT = 300 # 5 minutes
+from sky.server import constants as server_constants
class _AuthCallbackHandler(BaseHTTPRequestHandler):
@@ -44,10 +44,12 @@ def log_message(self, *args): # pylint: disable=unused-argument
pass
-def start_local_auth_server(port: int,
- token_store: Dict[str, Optional[str]],
- remote_endpoint: str,
- timeout: int = AUTH_TIMEOUT) -> HTTPServer:
+def start_local_auth_server(
+ port: int,
+ token_store: Dict[str, Optional[str]],
+ remote_endpoint: str,
+ timeout: int = server_constants.AUTH_SESSION_TIMEOUT_SECONDS
+) -> HTTPServer:
"""Start a local HTTP server to handle OAuth callback.
Args:
diff --git a/sky/client/sdk.py b/sky/client/sdk.py
index fe858966930..007cf847a14 100644
--- a/sky/client/sdk.py
+++ b/sky/client/sdk.py
@@ -36,6 +36,7 @@
from sky.jobs import utils as managed_job_utils
from sky.schemas.api import responses
from sky.server import common as server_common
+from sky.server import constants as server_constants
from sky.server import rest
from sky.server import versions
from sky.server.requests import payloads
@@ -65,6 +66,7 @@
import binascii
import io
import pathlib
+ import secrets
import time
import webbrowser
@@ -82,6 +84,8 @@
base64 = adaptors_common.LazyImport('base64')
binascii = adaptors_common.LazyImport('binascii')
pathlib = adaptors_common.LazyImport('pathlib')
+ requests = adaptors_common.LazyImport('requests')
+ secrets = adaptors_common.LazyImport('secrets')
time = adaptors_common.LazyImport('time')
# only used in dashboard() and api_login()
webbrowser = adaptors_common.LazyImport('webbrowser')
@@ -374,7 +378,7 @@ def optimize(
for a task.
exceptions.NoCloudAccessError: if no public clouds are enabled.
"""
- dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag)
+ dag_str = dag_utils.dump_dag_to_yaml_str(dag)
body = payloads.OptimizeBody(dag=dag_str,
minimize=minimize,
@@ -434,7 +438,7 @@ def validate(
task.expand_and_validate_workdir()
if not workdir_only:
task.expand_and_validate_file_mounts()
- dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag)
+ dag_str = dag_utils.dump_dag_to_yaml_str(dag)
body = payloads.ValidateBody(dag=dag_str,
request_options=admin_policy_request_options)
response = server_common.make_authenticated_request(
@@ -732,7 +736,7 @@ def _launch(
dag = client_common.upload_mounts_to_api_server(dag)
- dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag)
+ dag_str = dag_utils.dump_dag_to_yaml_str(dag)
body = payloads.LaunchBody(
task=dag_str,
@@ -823,7 +827,7 @@ def exec( # pylint: disable=redefined-builtin
dag = dag_utils.convert_entrypoint_to_dag(task)
validate(dag, workdir_only=True)
dag = client_common.upload_mounts_to_api_server(dag, workdir_only=True)
- dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag)
+ dag_str = dag_utils.dump_dag_to_yaml_str(dag)
body = payloads.ExecBody(
task=dag_str,
cluster_name=cluster_name,
@@ -1017,6 +1021,44 @@ def tail_provision_logs(cluster_name: str,
return 0
+@usage_lib.entrypoint
+@server_common.check_server_healthy_or_start
+@annotations.client_api
+def tail_autostop_logs(cluster_name: str,
+ follow: bool = True,
+ tail: int = 0) -> int:
+ """Tails the autostop hook logs (autostop_hook.log) for a cluster.
+
+ Args:
+ cluster_name: name of the cluster.
+ follow: whether to follow the logs.
+ tail: number of lines to display from the end of the log file.
+
+ Returns:
+ Exit code 0 on streaming success; non-zero on failure.
+
+ Request Raises:
+ ValueError: if arguments are invalid or the cluster is not supported.
+ sky.exceptions.ClusterDoesNotExist: if the cluster does not exist.
+ sky.exceptions.ClusterNotUpError: if the cluster is not UP.
+ sky.exceptions.NotSupportedError: if the cluster is not based on
+ CloudVmRayBackend.
+ sky.exceptions.ClusterOwnerIdentityMismatchError: if the current user is
+ not the same as the user who created the cluster.
+ sky.exceptions.CloudUserIdentityError: if we fail to get the current
+ user identity.
+ """
+ body = payloads.AutostopLogsBody(cluster_name=cluster_name,
+ follow=follow,
+ tail=tail)
+
+ response = server_common.make_authenticated_request(
+ 'POST', '/autostop_logs', json=json.loads(body.model_dump_json()))
+ request_id: server_common.RequestId[int] = server_common.get_request_id(
+ response)
+ return stream_and_get(request_id)
+
+
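An illustrative call of the new client API, mirroring what `sky logs --autostop` does (cluster name is a placeholder):

    import sys
    from sky.client import sdk

    # Print the last 100 lines of autostop_hook.log and exit with the
    # streaming return code.
    sys.exit(sdk.tail_autostop_logs(cluster_name='my-cluster',
                                    follow=False,
                                    tail=100))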
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
@annotations.client_api
@@ -1153,8 +1195,12 @@ def start(
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
@annotations.client_api
-def down(cluster_name: str,
- purge: bool = False) -> server_common.RequestId[None]:
+def down(
+ cluster_name: str,
+ purge: bool = False,
+ graceful: bool = False,
+ graceful_timeout: Optional[int] = None,
+) -> server_common.RequestId[None]:
"""Tears down a cluster.
Tearing down a cluster will delete all associated resources (all billing
@@ -1169,6 +1215,10 @@ def down(cluster_name: str,
troubleshooting scenarios; with it set, it is the user's
responsibility to ensure there are no leaked instances and related
resources.
+        graceful: Cancel the user's task but block until MOUNT_CACHED data
+            is fully uploaded. This helps preserve user data integrity.
+ graceful_timeout: If not None, sets a timeout for the graceful option
+ above (in seconds).
Returns:
The request ID of the down request.
@@ -1184,9 +1234,15 @@ def down(cluster_name: str,
jobs controller.
"""
+ version = versions.get_remote_api_version()
+ if graceful and version is not None and version < 32:
+ logger.warning('`--graceful` is ignored because the server does '
+ 'not support it yet.')
body = payloads.StopOrDownBody(
cluster_name=cluster_name,
purge=purge,
+ graceful=graceful,
+ graceful_timeout=graceful_timeout,
)
response = server_common.make_authenticated_request(
'POST', '/down', json=json.loads(body.model_dump_json()), timeout=5)
@@ -1196,8 +1252,12 @@ def down(cluster_name: str,
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
@annotations.client_api
-def stop(cluster_name: str,
- purge: bool = False) -> server_common.RequestId[None]:
+def stop(
+ cluster_name: str,
+ purge: bool = False,
+ graceful: bool = False,
+ graceful_timeout: Optional[int] = None,
+) -> server_common.RequestId[None]:
"""Stops a cluster.
Data on attached disks is not lost when a cluster is stopped. Billing for
@@ -1230,9 +1290,15 @@ def stop(cluster_name: str,
cluster, or a TPU VM Pod cluster, or the managed jobs controller.
"""
+ version = versions.get_remote_api_version()
+ if graceful and version is not None and version < 32:
+ logger.warning('`--graceful` is ignored because the server does '
+ 'not support it yet.')
body = payloads.StopOrDownBody(
cluster_name=cluster_name,
purge=purge,
+ graceful=graceful,
+ graceful_timeout=graceful_timeout,
)
response = server_common.make_authenticated_request(
'POST', '/stop', json=json.loads(body.model_dump_json()), timeout=5)
@@ -1243,10 +1309,12 @@ def stop(cluster_name: str,
@server_common.check_server_healthy_or_start
@annotations.client_api
def autostop(
- cluster_name: str,
- idle_minutes: int,
- wait_for: Optional[autostop_lib.AutostopWaitFor] = None,
- down: bool = False, # pylint: disable=redefined-outer-name
+ cluster_name: str,
+ idle_minutes: int,
+ wait_for: Optional[autostop_lib.AutostopWaitFor] = None,
+ down: bool = False, # pylint: disable=redefined-outer-name
+ hook: Optional[str] = None,
+ hook_timeout: Optional[int] = None,
) -> server_common.RequestId[None]:
"""Schedules an autostop/autodown for a cluster.
@@ -1287,6 +1355,13 @@ def autostop(
3. "none" - Wait for nothing; autostop right after ``idle_minutes``.
down: if true, use autodown (tear down the cluster; non-restartable),
rather than autostop (restartable).
+ hook: optional script to execute on the remote cluster before autostop.
+ The script runs before the cluster is stopped or torn down. If the
+ hook fails, autostop will still proceed but a warning will be
+ logged.
+ hook_timeout: timeout in seconds for hook execution. If None, uses
+ DEFAULT_AUTOSTOP_HOOK_TIMEOUT_SECONDS (3600 = 1 hour). The hook will
+ be terminated if it exceeds this timeout.
Returns:
The request ID of the autostop request.
@@ -1295,6 +1370,7 @@ def autostop(
None
Request Raises:
+ ValueError: if arguments are invalid.
sky.exceptions.ClusterDoesNotExist: if the cluster does not exist.
sky.exceptions.ClusterNotUpError: if the cluster is not UP.
sky.exceptions.NotSupportedError: if the cluster is not based on
@@ -1304,17 +1380,28 @@ def autostop(
sky.exceptions.CloudUserIdentityError: if we fail to get the current
user identity.
"""
+ if hook_timeout is not None and hook is None:
+ raise ValueError('hook_timeout can only be set if hook is set.')
+
remote_api_version = versions.get_remote_api_version()
if wait_for is not None and (remote_api_version is None or
remote_api_version < 13):
logger.warning('wait_for is not supported in your API server. '
'Please upgrade to a newer API server to use it.')
+ # Hook support requires API version 28 or higher
+ if hook is not None and (remote_api_version is None or
+ remote_api_version < 28):
+ logger.warning('Autostop hook is not supported in your API server. '
+ 'Please upgrade to a newer API server to use it.')
+
body = payloads.AutostopBody(
cluster_name=cluster_name,
idle_minutes=idle_minutes,
wait_for=wait_for,
down=down,
+ hook=hook,
+ hook_timeout=hook_timeout,
)
response = server_common.make_authenticated_request(
'POST', '/autostop', json=json.loads(body.model_dump_json()), timeout=5)
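An illustrative use of the new hook parameters (the hook value and timeout are placeholders; the exact hook format is defined by the server-side payload):

    from sky.client import sdk

    # Schedule autostop after 30 idle minutes; run a user-provided hook
    # first and kill it if it takes longer than 10 minutes.
    request_id = sdk.autostop('my-cluster',
                              idle_minutes=30,
                              hook='~/flush_cache.sh',
                              hook_timeout=600)
    sdk.stream_and_get(request_id)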
@@ -2479,6 +2566,117 @@ def _check_endpoint_in_env_var(is_login: bool) -> None:
'clear the environment variable.')
+def _try_polling_auth(endpoint: str) -> Optional[str]:
+ """Try the polling-based authentication flow."""
+ try:
+ # Generate code verifier (random secret) and challenge (hash)
+ code_verifier = common_utils.base64_url_encode(secrets.token_bytes(32))
+ code_challenge = common_utils.compute_code_challenge(code_verifier)
+
+ # Open browser to authorization page
+ auth_url = f'{endpoint}/auth/authorize?code_challenge={code_challenge}'
+ if not webbrowser.open(auth_url):
+ logger.debug('Failed to open browser.')
+ return None
+
+ click.echo(f'{colorama.Fore.GREEN}Browser opened at {auth_url}'
+ f'{colorama.Style.RESET_ALL}\n'
+ f'Please click "Authorize" to complete login.\n'
+ f'{colorama.Style.DIM}Press ctrl+c to fall back to legacy '
+ f'auth method.{colorama.Style.RESET_ALL}')
+
+ # Poll for token
+ start_time = time.time()
+ while time.time(
+ ) - start_time < server_constants.AUTH_SESSION_TIMEOUT_SECONDS:
+ time.sleep(1)
+ resp = requests.get(f'{endpoint}/api/v1/auth/token',
+ params={'code_verifier': code_verifier},
+ timeout=10)
+
+ if resp.status_code == 200:
+ data = resp.json()
+ if 'token' in data:
+ return data['token']
+ elif resp.status_code != 404:
+ # 404 means user hasn't clicked Authorize yet, keep polling
+ logger.debug(f'Poll failed: {resp.status_code}')
+ return None
+
+ click.echo(f'{colorama.Fore.YELLOW}Authentication timed out.'
+ f'{colorama.Style.RESET_ALL}')
+ return None
+
+ except KeyboardInterrupt:
+ click.echo(f'\n{colorama.Style.DIM}Interrupted.'
+ f'{colorama.Style.RESET_ALL}')
+ return None
+ except Exception as e: # pylint: disable=broad-except
+ logger.debug(f'Polling auth failed: {e}')
+ return None
+
+
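The verifier/challenge pair above follows the PKCE pattern: the client keeps a random secret and sends only its hash, then proves possession when polling for the token. A likely-equivalent sketch of the two common_utils helpers, assuming the standard S256 transform (SHA-256 plus unpadded base64url):

    import base64
    import hashlib
    import secrets

    def base64_url_encode(data: bytes) -> str:
        return base64.urlsafe_b64encode(data).rstrip(b'=').decode()

    def compute_code_challenge(code_verifier: str) -> str:
        return base64_url_encode(
            hashlib.sha256(code_verifier.encode()).digest())

    code_verifier = base64_url_encode(secrets.token_bytes(32))
    code_challenge = compute_code_challenge(code_verifier)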
+def _try_localhost_callback_auth(endpoint: str) -> Optional[str]:
+ """Try the localhost callback authentication flow (legacy)."""
+ server: Optional[oauth_lib.HTTPServer] = None
+ try:
+ callback_port = common_utils.find_free_port(8000)
+ token_container: Dict[str, Optional[str]] = {'token': None}
+ server = oauth_lib.start_local_auth_server(callback_port,
+ token_container, endpoint)
+
+ token_url = f'{endpoint}/token?local_port={callback_port}'
+ if not webbrowser.open(token_url):
+ return None
+
+ click.echo(f'{colorama.Fore.GREEN}Browser opened at {token_url}'
+ f'{colorama.Style.RESET_ALL}\n'
+ f'{colorama.Style.DIM}Press ctrl+c to enter token manually.'
+ f'{colorama.Style.RESET_ALL}')
+
+ start_time = time.time()
+ while (token_container['token'] is None and time.time() - start_time <
+ server_constants.AUTH_SESSION_TIMEOUT_SECONDS):
+ time.sleep(1)
+
+ if token_container['token'] is None:
+ click.echo(f'{colorama.Fore.YELLOW}Authentication timed out.'
+ f'{colorama.Style.RESET_ALL}')
+ return None
+ return token_container['token']
+
+ except KeyboardInterrupt:
+ click.echo(f'\n{colorama.Style.DIM}Interrupted.'
+ f'{colorama.Style.RESET_ALL}')
+ return None
+ except Exception as e: # pylint: disable=broad-except
+ logger.debug(f'Localhost callback failed: {e}')
+ return None
+ finally:
+ if server is not None:
+ try:
+ server.server_close()
+ except Exception: # pylint: disable=broad-except
+ pass
+
+
+def _try_manual_token_entry(endpoint: str) -> Optional[str]:
+ """Fall back to manual token entry."""
+ try:
+ token_url = f'{endpoint}/token'
+ click.echo(
+ f'Visit this URL to get the token:\n\n'
+ f'{colorama.Style.BRIGHT}{token_url}{colorama.Style.RESET_ALL}\n')
+ return click.prompt('Paste the token') or None
+ except (KeyboardInterrupt, click.Abort):
+ click.echo(
+ f'\n{colorama.Style.DIM}Cancelled.{colorama.Style.RESET_ALL}')
+ return None
+ except Exception as e: # pylint: disable=broad-except
+ logger.debug(f'Manual token entry failed: {e}')
+ return None
+
+
@usage_lib.entrypoint
@annotations.client_api
def api_login(endpoint: Optional[str] = None,
@@ -2581,59 +2779,26 @@ def _set_user_hash(user_hash: Optional[str]) -> None:
if server_status == server_common.ApiServerStatus.NEEDS_AUTH or relogin:
# We detected an auth proxy, so go through the auth proxy cookie flow.
token: Optional[str] = None
- server: Optional[oauth_lib.HTTPServer] = None
- try:
- callback_port = common_utils.find_free_port(8000)
-
- token_container: Dict[str, Optional[str]] = {'token': None}
- logger.debug('Starting local authentication server...')
- server = oauth_lib.start_local_auth_server(callback_port,
- token_container,
- endpoint)
-
- token_url = (f'{endpoint}/token?local_port={callback_port}')
- if webbrowser.open(token_url):
- click.echo(f'{colorama.Fore.GREEN}A web browser has been '
- f'opened at {token_url}. Please continue the login '
- f'in the web browser.{colorama.Style.RESET_ALL}\n'
- f'{colorama.Style.DIM}To manually copy the token, '
- f'press ctrl+c.{colorama.Style.RESET_ALL}')
- else:
- raise ValueError('Failed to open browser.')
- start_time = time.time()
+ # Try methods in order:
+ # 1. New polling-based flow - only on servers >= API v30
+ # 2. Old localhost callback flow
+ # 3. Manual token entry
+ remote_api_version = versions.get_remote_api_version()
+ if remote_api_version is not None and remote_api_version >= 30:
+ token = _try_polling_auth(endpoint)
- while (token_container['token'] is None and
- time.time() - start_time < oauth_lib.AUTH_TIMEOUT):
- time.sleep(1)
+ if token is None:
+ # Polling auth not available or failed, try localhost callback
+ token = _try_localhost_callback_auth(endpoint)
- if token_container['token'] is None:
- click.echo(f'{colorama.Fore.YELLOW}Authentication timed out '
- f'after {oauth_lib.AUTH_TIMEOUT} seconds.')
- else:
- token = token_container['token']
-
- except (Exception, KeyboardInterrupt) as e: # pylint: disable=broad-except
- logger.debug(f'Automatic authentication failed: {e}, '
- 'falling back to manual token entry.')
- if isinstance(e, KeyboardInterrupt):
- click.echo(f'\n{colorama.Style.DIM}Interrupted. Press ctrl+c '
- f'again to exit.{colorama.Style.RESET_ALL}')
- # Fall back to manual token entry
- token_url = f'{endpoint}/token'
- click.echo('Authentication is needed. Please visit this URL '
- f'to set up the token:{colorama.Style.BRIGHT}\n\n'
- f'{token_url}\n{colorama.Style.RESET_ALL}')
- token = click.prompt('Paste the token')
- finally:
- if server is not None:
- try:
- server.server_close()
- except Exception: # pylint: disable=broad-except
- pass
- if not token:
- with ux_utils.print_exception_no_traceback():
- raise ValueError('Authentication failed.')
+ if token is None:
+ # All automatic methods failed, fall back to manual entry
+ token = _try_manual_token_entry(endpoint)
+
+ if not token:
+ with ux_utils.print_exception_no_traceback():
+ raise ValueError('Authentication failed.')
# Parse the token.
# b64decode will ignore invalid characters, but does some length and
@@ -2642,7 +2807,6 @@ def _set_user_hash(user_hash: Optional[str]) -> None:
data = base64.b64decode(token)
except binascii.Error as e:
raise ValueError(f'Malformed token: {token}') from e
- logger.debug(f'Token data: {data!r}')
try:
json_data = json.loads(data)
except (json.JSONDecodeError, UnicodeDecodeError) as e:
diff --git a/sky/client/sdk_async.py b/sky/client/sdk_async.py
index abac1db20ff..bf962e68339 100644
--- a/sky/client/sdk_async.py
+++ b/sky/client/sdk_async.py
@@ -10,6 +10,7 @@
statuses = await sky.get(request_id)
"""
+import asyncio
import dataclasses
import logging
import typing
@@ -33,7 +34,6 @@
from sky.usage import usage_lib
from sky.utils import annotations
from sky.utils import common
-from sky.utils import context_utils
from sky.utils import env_options
from sky.utils import rich_utils
from sky.utils import ux_utils
@@ -280,8 +280,8 @@ async def check(
stream_logs: Optional[StreamConfig] = DEFAULT_STREAM_CONFIG
) -> Dict[str, List[str]]:
"""Async version of check() that checks the credentials to enable clouds."""
- request_id = await context_utils.to_thread(sdk.check, infra_list, verbose,
- workspace)
+ request_id = await asyncio.to_thread(sdk.check, infra_list, verbose,
+ workspace)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
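The swap from context_utils.to_thread to asyncio.to_thread (stdlib, Python 3.9+) keeps the same shape throughout this file: the blocking SDK call runs in a worker thread so the event loop stays responsive. A minimal standalone example:

    import asyncio
    import time

    def blocking_call(x: int) -> int:
        time.sleep(0.1)  # stands in for a blocking SDK request
        return x * 2

    async def main() -> None:
        # The event loop keeps running while the call executes in a thread.
        result = await asyncio.to_thread(blocking_call, 21)
        print(result)  # 42

    asyncio.run(main())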
@@ -296,8 +296,7 @@ async def enabled_clouds(
stream_logs: Optional[StreamConfig] = DEFAULT_STREAM_CONFIG
) -> List[str]:
"""Async version of enabled_clouds() that gets the enabled clouds."""
- request_id = await context_utils.to_thread(sdk.enabled_clouds, workspace,
- expand)
+ request_id = await asyncio.to_thread(sdk.enabled_clouds, workspace, expand)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -319,11 +318,10 @@ async def list_accelerators(
) -> Dict[str, List[catalog.common.InstanceTypeInfo]]:
"""Async version of list_accelerators() that lists the names of all
accelerators offered by Sky."""
- request_id = await context_utils.to_thread(sdk.list_accelerators, gpus_only,
- name_filter, region_filter,
- quantity_filter, clouds,
- all_regions, require_price,
- case_sensitive)
+ request_id = await asyncio.to_thread(sdk.list_accelerators, gpus_only,
+ name_filter, region_filter,
+ quantity_filter, clouds, all_regions,
+ require_price, case_sensitive)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -342,10 +340,9 @@ async def list_accelerator_counts(
) -> Dict[str, List[int]]:
"""Async version of list_accelerator_counts() that lists all accelerators
offered by Sky and available counts."""
- request_id = await context_utils.to_thread(sdk.list_accelerator_counts,
- gpus_only, name_filter,
- region_filter, quantity_filter,
- clouds)
+ request_id = await asyncio.to_thread(sdk.list_accelerator_counts, gpus_only,
+ name_filter, region_filter,
+ quantity_filter, clouds)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -363,8 +360,8 @@ async def optimize(
) -> 'sky.Dag':
"""Async version of optimize() that finds the best execution plan for the
given DAG."""
- request_id = await context_utils.to_thread(sdk.optimize, dag, minimize,
- admin_policy_request_options)
+ request_id = await asyncio.to_thread(sdk.optimize, dag, minimize,
+ admin_policy_request_options)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -377,7 +374,7 @@ async def workspaces(
stream_logs: Optional[StreamConfig] = DEFAULT_STREAM_CONFIG
) -> Dict[str, Any]:
"""Async version of workspaces() that gets the workspaces."""
- request_id = await context_utils.to_thread(sdk.workspaces)
+ request_id = await asyncio.to_thread(sdk.workspaces)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -408,7 +405,7 @@ async def launch(
stream_logs: Optional[StreamConfig] = DEFAULT_STREAM_CONFIG,
) -> Tuple[Optional[int], Optional['backends.ResourceHandle']]:
"""Async version of launch() that launches a cluster or task."""
- request_id = await context_utils.to_thread(
+ request_id = await asyncio.to_thread(
sdk.launch, task, cluster_name, retry_until_up,
idle_minutes_to_autostop, wait_for, dryrun, down, backend,
optimize_target, no_setup, clone_disk_from, fast, _need_confirmation,
@@ -431,8 +428,8 @@ async def exec( # pylint: disable=redefined-builtin
stream_logs: Optional[StreamConfig] = DEFAULT_STREAM_CONFIG,
) -> Tuple[Optional[int], Optional['backends.ResourceHandle']]:
"""Async version of exec() that executes a task on an existing cluster."""
- request_id = await context_utils.to_thread(sdk.exec, task, cluster_name,
- dryrun, down, backend)
+ request_id = await asyncio.to_thread(sdk.exec, task, cluster_name, dryrun,
+ down, backend)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -447,8 +444,8 @@ async def tail_logs(cluster_name: str,
tail: int = 0,
output_stream: Optional['io.TextIOBase'] = None) -> int:
"""Async version of tail_logs() that tails the logs of a job."""
- return await context_utils.to_thread(sdk.tail_logs, cluster_name, job_id,
- follow, tail, output_stream)
+ return await asyncio.to_thread(sdk.tail_logs, cluster_name, job_id, follow,
+ tail, output_stream)
@usage_lib.entrypoint
@@ -456,8 +453,7 @@ async def tail_logs(cluster_name: str,
async def download_logs(cluster_name: str,
job_ids: Optional[List[str]]) -> Dict[str, str]:
"""Async version of download_logs() that downloads the logs of jobs."""
- return await context_utils.to_thread(sdk.download_logs, cluster_name,
- job_ids)
+ return await asyncio.to_thread(sdk.download_logs, cluster_name, job_ids)
@usage_lib.entrypoint
@@ -472,10 +468,9 @@ async def start(
stream_logs: Optional[StreamConfig] = DEFAULT_STREAM_CONFIG,
) -> 'backends.CloudVmRayResourceHandle':
"""Async version of start() that restarts a cluster."""
- request_id = await context_utils.to_thread(sdk.start, cluster_name,
- idle_minutes_to_autostop,
- wait_for, retry_until_up, down,
- force)
+ request_id = await asyncio.to_thread(sdk.start, cluster_name,
+ idle_minutes_to_autostop, wait_for,
+ retry_until_up, down, force)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -487,9 +482,12 @@ async def start(
async def down(
cluster_name: str,
purge: bool = False,
+ graceful: bool = False,
+ graceful_timeout: Optional[int] = None,
stream_logs: Optional[StreamConfig] = DEFAULT_STREAM_CONFIG) -> None:
"""Async version of down() that tears down a cluster."""
- request_id = await context_utils.to_thread(sdk.down, cluster_name, purge)
+ request_id = await asyncio.to_thread(sdk.down, cluster_name, purge,
+ graceful, graceful_timeout)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -501,9 +499,12 @@ async def down(
async def stop(
cluster_name: str,
purge: bool = False,
+ graceful: bool = False,
+ graceful_timeout: Optional[int] = None,
stream_logs: Optional[StreamConfig] = DEFAULT_STREAM_CONFIG) -> None:
"""Async version of stop() that stops a cluster."""
- request_id = await context_utils.to_thread(sdk.stop, cluster_name, purge)
+ request_id = await asyncio.to_thread(sdk.stop, cluster_name, purge,
+ graceful, graceful_timeout)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -521,8 +522,8 @@ async def autostop(
) -> None:
"""Async version of autostop() that schedules an autostop/autodown for a
cluster."""
- request_id = await context_utils.to_thread(sdk.autostop, cluster_name,
- idle_minutes, wait_for, down)
+ request_id = await asyncio.to_thread(sdk.autostop, cluster_name,
+ idle_minutes, wait_for, down)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -538,8 +539,8 @@ async def queue(
stream_logs: Optional[StreamConfig] = DEFAULT_STREAM_CONFIG
) -> List[responses.ClusterJobRecord]:
"""Async version of queue() that gets the job queue of a cluster."""
- request_id = await context_utils.to_thread(sdk.queue, cluster_name,
- skip_finished, all_users)
+ request_id = await asyncio.to_thread(sdk.queue, cluster_name, skip_finished,
+ all_users)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -555,8 +556,7 @@ async def job_status(
) -> Dict[Optional[int], Optional['job_lib.JobStatus']]:
"""Async version of job_status() that gets the status of jobs on a
cluster."""
- request_id = await context_utils.to_thread(sdk.job_status, cluster_name,
- job_ids)
+ request_id = await asyncio.to_thread(sdk.job_status, cluster_name, job_ids)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -574,9 +574,9 @@ async def cancel(
_try_cancel_if_cluster_is_init: bool = False,
stream_logs: Optional[StreamConfig] = DEFAULT_STREAM_CONFIG) -> None:
"""Async version of cancel() that cancels jobs on a cluster."""
- request_id = await context_utils.to_thread(sdk.cancel, cluster_name, all,
- all_users, job_ids,
- _try_cancel_if_cluster_is_init)
+ request_id = await asyncio.to_thread(sdk.cancel, cluster_name, all,
+ all_users, job_ids,
+ _try_cancel_if_cluster_is_init)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -594,7 +594,7 @@ async def status(
_include_credentials: bool = False,
) -> List[Dict[str, Any]]:
"""Async version of status() that gets cluster statuses."""
- request_id = await context_utils.to_thread(
+ request_id = await asyncio.to_thread(
sdk.status,
cluster_names,
refresh,
@@ -615,7 +615,7 @@ async def endpoints(
) -> Dict[int, str]:
"""Async version of endpoints() that gets the endpoint for a given cluster
and port number."""
- request_id = await context_utils.to_thread(sdk.endpoints, cluster, port)
+ request_id = await asyncio.to_thread(sdk.endpoints, cluster, port)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -628,7 +628,7 @@ async def cost_report(
stream_logs: Optional[StreamConfig] = DEFAULT_STREAM_CONFIG
) -> List[Dict[str, Any]]:
"""Async version of cost_report() that gets all cluster cost reports."""
- request_id = await context_utils.to_thread(sdk.cost_report)
+ request_id = await asyncio.to_thread(sdk.cost_report)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -641,7 +641,7 @@ async def storage_ls(
stream_logs: Optional[StreamConfig] = DEFAULT_STREAM_CONFIG
) -> List[Dict[str, Any]]:
"""Async version of storage_ls() that gets the storages."""
- request_id = await context_utils.to_thread(sdk.storage_ls)
+ request_id = await asyncio.to_thread(sdk.storage_ls)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -654,7 +654,7 @@ async def storage_delete(
name: str,
stream_logs: Optional[StreamConfig] = DEFAULT_STREAM_CONFIG) -> None:
"""Async version of storage_delete() that deletes a storage."""
- request_id = await context_utils.to_thread(sdk.storage_delete, name)
+ request_id = await asyncio.to_thread(sdk.storage_delete, name)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -670,8 +670,7 @@ async def local_up(
stream_logs: Optional[StreamConfig] = DEFAULT_STREAM_CONFIG) -> None:
"""Async version of local_up() that launches a Kubernetes cluster on
local machines."""
- request_id = await context_utils.to_thread(sdk.local_up, gpus, name,
- port_start)
+ request_id = await asyncio.to_thread(sdk.local_up, gpus, name, port_start)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -685,7 +684,7 @@ async def local_down(
stream_logs: Optional[StreamConfig] = DEFAULT_STREAM_CONFIG) -> None:
"""Async version of local_down() that tears down the Kubernetes cluster
started by local_up."""
- request_id = await context_utils.to_thread(sdk.local_down, name)
+ request_id = await asyncio.to_thread(sdk.local_down, name)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -699,7 +698,7 @@ async def ssh_up(
stream_logs: Optional[StreamConfig] = DEFAULT_STREAM_CONFIG) -> None:
"""Async version of ssh_up() that deploys the SSH Node Pools defined in
~/.sky/ssh_targets.yaml."""
- request_id = await context_utils.to_thread(sdk.ssh_up, infra)
+ request_id = await asyncio.to_thread(sdk.ssh_up, infra)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -713,7 +712,7 @@ async def ssh_down(
stream_logs: Optional[StreamConfig] = DEFAULT_STREAM_CONFIG) -> None:
"""Async version of ssh_down() that tears down a Kubernetes cluster on SSH
targets."""
- request_id = await context_utils.to_thread(sdk.ssh_down, infra)
+ request_id = await asyncio.to_thread(sdk.ssh_down, infra)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -731,7 +730,7 @@ async def realtime_kubernetes_gpu_availability(
) -> List[Tuple[str, List['models.RealtimeGpuAvailability']]]:
"""Async version of realtime_kubernetes_gpu_availability() that gets the
real-time Kubernetes GPU availability."""
- request_id = await context_utils.to_thread(
+ request_id = await asyncio.to_thread(
sdk.realtime_kubernetes_gpu_availability, context, name_filter,
quantity_filter, is_ssh)
if stream_logs is not None:
@@ -748,8 +747,7 @@ async def kubernetes_node_info(
) -> 'models.KubernetesNodesInfo':
"""Async version of kubernetes_node_info() that gets the resource
information for all the nodes in the cluster."""
- request_id = await context_utils.to_thread(sdk.kubernetes_node_info,
- context)
+ request_id = await asyncio.to_thread(sdk.kubernetes_node_info, context)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -765,7 +763,7 @@ async def status_kubernetes(
List[Dict[str, Any]], Optional[str]]:
"""Async version of status_kubernetes() that gets all SkyPilot clusters
and jobs in the Kubernetes cluster."""
- request_id = await context_utils.to_thread(sdk.status_kubernetes)
+ request_id = await asyncio.to_thread(sdk.status_kubernetes)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -781,8 +779,8 @@ async def api_cancel(
stream_logs: Optional[StreamConfig] = DEFAULT_STREAM_CONFIG
) -> List[str]:
"""Async version of api_cancel() that aborts a request or all requests."""
- request_id = await context_utils.to_thread(sdk.api_cancel, request_ids,
- all_users, silent)
+ request_id = await asyncio.to_thread(sdk.api_cancel, request_ids, all_users,
+ silent)
if stream_logs is not None:
return await _stream_and_get(request_id, stream_logs)
else:
@@ -794,15 +792,14 @@ async def api_cancel(
async def api_status(request_ids: Optional[List[str]] = None,
all_status: bool = False) -> List[payloads.RequestPayload]:
"""Async version of api_status() that lists all requests."""
- return await context_utils.to_thread(sdk.api_status, request_ids,
- all_status)
+ return await asyncio.to_thread(sdk.api_status, request_ids, all_status)
@usage_lib.entrypoint
@annotations.client_api
async def dashboard(starting_page: Optional[str] = None) -> None:
"""Async version of dashboard() that starts the dashboard for SkyPilot."""
- return await context_utils.to_thread(sdk.dashboard, starting_page)
+ return await asyncio.to_thread(sdk.dashboard, starting_page)
@usage_lib.entrypoint
@@ -810,14 +807,14 @@ async def dashboard(starting_page: Optional[str] = None) -> None:
async def api_info() -> responses.APIHealthResponse:
"""Async version of api_info() that gets the server's status, commit and
version."""
- return await context_utils.to_thread(sdk.api_info)
+ return await asyncio.to_thread(sdk.api_info)
@usage_lib.entrypoint
@annotations.client_api
async def api_stop() -> None:
"""Async version of api_stop() that stops the API server."""
- return await context_utils.to_thread(sdk.api_stop)
+ return await asyncio.to_thread(sdk.api_stop)
@usage_lib.entrypoint
@@ -825,7 +822,7 @@ async def api_stop() -> None:
async def api_server_logs(follow: bool = True,
tail: Optional[int] = None) -> None:
"""Async version of api_server_logs() that streams the API server logs."""
- return await context_utils.to_thread(sdk.api_server_logs, follow, tail)
+ return await asyncio.to_thread(sdk.api_server_logs, follow, tail)
@usage_lib.entrypoint
@@ -833,4 +830,4 @@ async def api_server_logs(follow: bool = True,
async def api_login(endpoint: Optional[str] = None,
get_token: bool = False) -> None:
"""Async version of api_login() that logs into a SkyPilot API server."""
- return await context_utils.to_thread(sdk.api_login, endpoint, get_token)
+ return await asyncio.to_thread(sdk.api_login, endpoint, get_token)
diff --git a/sky/clouds/__init__.py b/sky/clouds/__init__.py
index 7efdd8bae7f..a9a3229459b 100644
--- a/sky/clouds/__init__.py
+++ b/sky/clouds/__init__.py
@@ -35,6 +35,7 @@
from sky.clouds.ssh import SSH
from sky.clouds.vast import Vast
from sky.clouds.vsphere import Vsphere
+from sky.clouds.yotta import Yotta
__all__ = [
'IBM',
@@ -66,6 +67,7 @@
'Nebius',
'Hyperbolic',
'Seeweb',
+ 'Yotta',
# Utility functions
'cloud_in_iterable',
]
diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py
index 4323ed8ac64..af24f1537a1 100644
--- a/sky/clouds/aws.py
+++ b/sky/clouds/aws.py
@@ -9,8 +9,8 @@
import subprocess
import time
import typing
-from typing import (Any, Callable, Dict, Iterator, List, Literal, Optional, Set,
- Tuple, TypeVar, Union)
+from typing import (Any, Callable, Dict, Iterable, Iterator, List, Literal,
+ Optional, Set, Tuple, TypeVar, Union)
import colorama
from typing_extensions import ParamSpec
@@ -85,6 +85,7 @@
'p5e.',
'p5en.',
'p6-b200.',
+ 'p6-b300.',
]
# Docker run options for EFA.
@@ -1654,3 +1655,18 @@ def is_label_valid(cls, label_key: str,
if not key_valid or not value_valid:
return False, error_msg
return True, None
+
+ @classmethod
+ def yield_cloud_specific_failover_overrides(cls,
+ region: Optional[str] = None
+ ) -> Iterable[Dict[str, Any]]:
+ vpc_names = skypilot_config.get_effective_region_config(
+ cloud='aws', region=region, keys=('vpc_names',), default_value=None)
+ if vpc_names:
+ if isinstance(vpc_names, str):
+ vpc_names = [vpc_names]
+ for vpc_name in vpc_names:
+ yield {'vpc_name': vpc_name}
+ else:
+ yield {}
+ return
diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py
index 8fbca890789..83dd8acd64c 100644
--- a/sky/clouds/cloud.py
+++ b/sky/clouds/cloud.py
@@ -978,6 +978,21 @@ def display_name(cls) -> str:
"""Name of the cloud used in messages displayed to the user."""
return cls.canonical_name()
+ # === Misc Failovers ===
+
+ @classmethod
+ def yield_cloud_specific_failover_overrides(cls,
+ region: Optional[str] = None
+ ) -> Iterable[Dict[str, Any]]:
+        """Some clouds have configurations that require failover across
+        dimensions other than region/zone (e.g., multiple VPCs). This yields
+        override dicts for the cluster config; see AWS for an example."""
+ del region # unused
+ yield {}
+ return
+
+ # === End of Misc Failovers ===
+
def __repr__(self):
return self._REPR
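A hedged sketch of how failover machinery might consume this generator; provision_with() is a stand-in, not an actual SkyPilot function:

    # Try each override dict in turn (e.g. {'vpc_name': 'vpc-a'} on AWS,
    # or the single empty dict from the base class) until one succeeds.
    def provision_with_failover(cloud, region, provision_with):
        last_error = None
        for overrides in cloud.yield_cloud_specific_failover_overrides(
                region):
            try:
                return provision_with(overrides)
            except Exception as e:  # broad catch: sketch only
                last_error = e
        raise last_error if last_error else RuntimeError(
            'no overrides yielded')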
diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py
index 648c9b8ea4b..1b5a62a1017 100644
--- a/sky/clouds/gcp.py
+++ b/sky/clouds/gcp.py
@@ -542,6 +542,8 @@ def make_deploy_resources_variables(
'runtime_version']
resources_vars['tpu_node_name'] = r.accelerator_args.get(
'tpu_name')
+ resources_vars['gcp_queued_resource'] = r.accelerator_args.get(
+ 'gcp_queued_resource')
# TPU VMs require privileged mode for docker containers to
# access TPU devices.
resources_vars['docker_run_options'] = ['--privileged']
diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py
index 23345800557..fd40ecfc864 100644
--- a/sky/clouds/kubernetes.py
+++ b/sky/clouds/kubernetes.py
@@ -1,10 +1,11 @@
"""Kubernetes."""
import concurrent.futures
+import math
import os
import re
import subprocess
import tempfile
-from typing import Dict, Iterator, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union
import colorama
@@ -44,6 +45,8 @@
# addons/fuse-proxy/README.md for more details.
_FUSERMOUNT_SHARED_DIR = '/var/run/fusermount'
+AWS_EFA_RESOURCE_KEY = 'vpc.amazonaws.com/efa'
+
@registry.CLOUD_REGISTRY.register(aliases=['k8s'])
class Kubernetes(clouds.Cloud):
@@ -604,7 +607,8 @@ def _get_image_id(resources: 'resources_lib.Resources') -> str:
cloud='kubernetes',
region=context,
keys=('remote_identity',),
- default_value=schemas.get_default_remote_identity('kubernetes'))
+ default_value=schemas.get_default_remote_identity('kubernetes'),
+ override_configs=resources.cluster_config_overrides)
if isinstance(remote_identity, dict):
# If remote_identity is a dict, use the service account for the
@@ -620,13 +624,16 @@ def _get_image_id(resources: 'resources_lib.Resources') -> str:
lc = schemas.RemoteIdentityOptions.LOCAL_CREDENTIALS.value
sa = schemas.RemoteIdentityOptions.SERVICE_ACCOUNT.value
+ no_upload = schemas.RemoteIdentityOptions.NO_UPLOAD.value
- if k8s_service_account_name == lc or k8s_service_account_name == sa:
+ if k8s_service_account_name in (lc, sa, no_upload):
# Use the default service account if remote identity is not set.
# For LOCAL_CREDENTIALS, this is for in-cluster authentication
# which needs a serviceaccount (specifically for SSH node pools
# which uses in-cluster authentication internally, and we would
# like to support exec-auth when the user is also using SSH infra)
+ # For NO_UPLOAD, we don't upload credentials but still need a
+ # service account for pod creation.
k8s_service_account_name = (
kubernetes_utils.DEFAULT_SERVICE_ACCOUNT_NAME)
@@ -637,8 +644,18 @@ def _get_image_id(resources: 'resources_lib.Resources') -> str:
if resources.use_spot:
spot_label_key, spot_label_value = kubernetes_utils.get_spot_label()
- network_type, machine_type = self._detect_network_type(
- context, resources.network_tier)
+ network_type, metadata = self._detect_network_type(
+ context, resources.network_tier, k8s_acc_label_key,
+ k8s_resource_key, acc_count)
+
+ k8s_efa_count = None
+ if network_type == KubernetesHighPerformanceNetworkType.AWS_EFA:
+ if metadata and 'efa_count' in metadata:
+ k8s_efa_count = metadata['efa_count']
+ else:
+ logger.warning(
+ f'No EFA interfaces detected on AWS nodes with '
+                        f'accelerator {k8s_acc_label_key}; skipping EFA '
+                        f'setup.')
# Check if this cluster supports high performance networking and
# configure appropriate settings for different cluster types
@@ -673,6 +690,19 @@ def _get_image_id(resources: 'resources_lib.Resources') -> str:
keys=('high_availability', 'storage_class_name'),
default_value=None))
+ # Get the config for setting pod CPU/memory limits relative to requests.
+ # This is useful for clusters that require limits to be set (e.g., for
+ # LimitRange enforcement or resource quotas).
+ # Can be: False (default, no limits), True (limits = requests),
+ # or a number (limits = requests * multiplier).
+ set_pod_resource_limits_config = (
+ skypilot_config.get_effective_workspace_region_config(
+ cloud='kubernetes',
+ region=context,
+ keys=('set_pod_resource_limits',),
+ default_value=False,
+ override_configs=resources.cluster_config_overrides))
+
k8s_kueue_local_queue_name = (
skypilot_config.get_effective_workspace_region_config(
# TODO(kyuds): Support SSH node pools as well.
@@ -732,6 +762,8 @@ def _get_image_id(resources: 'resources_lib.Resources') -> str:
'memory': str(mem),
'accelerator_count': str(acc_count),
'timeout': str(timeout),
+ 'k8s_efa_count': str(k8s_efa_count)
+ if k8s_efa_count is not None else None,
'k8s_port_mode': port_mode.value,
'k8s_acc_label_key': k8s_acc_label_key,
'k8s_acc_label_values': k8s_acc_label_values,
@@ -781,6 +813,17 @@ def _get_image_id(resources: 'resources_lib.Resources') -> str:
'k8s_network_type': network_type.value,
}
+ # Calculate CPU/memory limits if set_pod_resource_limits is configured.
+ # Convert config: False -> no limits, True -> multiplier 1.0,
+ # number -> that multiplier
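+        # Illustrative example (hypothetical values): with
+        # `set_pod_resource_limits: 1.5` and a request of 4 CPUs / 16 GB
+        # memory, the rendered limits become cpu=6.0 and memory=24.0.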
+ if set_pod_resource_limits_config is not False:
+ if set_pod_resource_limits_config is True:
+ multiplier = 1.0
+ else:
+ multiplier = float(set_pod_resource_limits_config)
+ deploy_vars['k8s_cpu_limit'] = round(cpus * multiplier, 3)
+ deploy_vars['k8s_memory_limit'] = round(mem * multiplier, 3)
+
# Add kubecontext if it is set. It may be None if SkyPilot is running
# inside a pod with in-cluster auth.
if context is not None:
@@ -797,7 +840,8 @@ def _get_image_id(resources: 'resources_lib.Resources') -> str:
rdma_enabled = (network_type ==
KubernetesHighPerformanceNetworkType.GCP_GPUDIRECT_RDMA)
deploy_vars['k8s_enable_gpudirect_rdma'] = rdma_enabled
- if rdma_enabled and machine_type.startswith('a4'):
+ if (rdma_enabled and metadata and 'instance_type' in metadata and
+ metadata['instance_type'].startswith('a4')):
deploy_vars['k8s_enable_gpudirect_rdma_a4'] = True
else:
deploy_vars['k8s_enable_gpudirect_rdma_a4'] = False
@@ -1151,22 +1195,31 @@ def expand_infras(cls) -> List[str]:
def _detect_network_type(
cls,
context: str,
- network_tier: Optional['resources_utils.NetworkTier'] = None
- ) -> Tuple[KubernetesHighPerformanceNetworkType, str]:
+ network_tier: Optional['resources_utils.NetworkTier'] = None,
+ k8s_acc_label_key: Optional[str] = None,
+ k8s_resource_key: Optional[str] = None,
+ acc_count: Optional[int] = None,
+ ) -> Tuple[KubernetesHighPerformanceNetworkType, Optional[Dict[str, Any]]]:
"""Detect the type of Kubernetes network based on node labels.
Args:
context: The Kubernetes context to check.
network_tier: The network tier requested. If None or not BEST,
returns NONE (no high-performance networking).
+ k8s_acc_label_key: The key of the Kubernetes accelerator label.
+        k8s_resource_key: The Kubernetes resource key for the requested
+            accelerator, used to read allocatable GPU counts from nodes.
+ acc_count: The number of accelerators requested.
Returns:
- A tuple of the detected network type and the instance type.
+ A tuple of (network_type, metadata).
+ - network_type: The detected high-performance network type
+ - metadata: Optional dict with cloud-specific info
+ (e.g., {'instance_type': str, 'efa_count': int})
"""
# If network_tier is None or not BEST, return NONE
if (network_tier is None or
network_tier != resources_utils.NetworkTier.BEST):
- return KubernetesHighPerformanceNetworkType.NONE, ''
+ return KubernetesHighPerformanceNetworkType.NONE, None
try:
nodes = kubernetes_utils.get_kubernetes_nodes(context=context)
@@ -1176,11 +1229,49 @@ def _detect_network_type(
for label_key, _ in node.metadata.labels.items():
if label_key.startswith('nebius.com/'):
return (KubernetesHighPerformanceNetworkType.NEBIUS,
- '')
+ None)
if label_key.startswith('ib.coreweave.cloud/'):
return (
KubernetesHighPerformanceNetworkType.COREWEAVE,
- '')
+ None)
+ if label_key.startswith('node-role.together.ai/'):
+ return (
+ KubernetesHighPerformanceNetworkType.TOGETHER,
+ None)
+ if label_key.startswith('k8s.io/cloud-provider-aws'):
+ network_type = (
+ KubernetesHighPerformanceNetworkType.AWS_EFA)
+ metadata: Optional[Dict[str, Any]] = None
+ # Only check for AWS EFA count if GPU is specified
+ if (not k8s_acc_label_key or not k8s_resource_key or
+ not acc_count):
+ return (network_type, metadata)
+ if (k8s_acc_label_key not in node.metadata.labels or
+ k8s_resource_key
+ not in node.status.allocatable or
+ int(node.status.
+ allocatable[k8s_resource_key]) <
+ acc_count):
+ continue
+ # Calculate EFA count proportionally
+ if AWS_EFA_RESOURCE_KEY in node.status.allocatable:
+ node_gpu_count = int(
+ node.status.allocatable[k8s_resource_key])
+ node_efa_count = int(
+ node.status.
+ allocatable[AWS_EFA_RESOURCE_KEY])
+ if node_efa_count > 0:
+ # Proportional allocation:
+ # user_gpu / node_gpu * node_efa
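+                            # Illustrative: requesting 4 GPUs on a
+                            # node with 8 GPUs and 4 EFA interfaces
+                            # yields floor(4 / 8 * 4) = 2, clamped
+                            # to [1, node_efa_count] below.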
+ calculated_efa = math.floor(acc_count /
+ node_gpu_count *
+ node_efa_count)
+ efa_count = max(
+ 1, min(calculated_efa, node_efa_count))
+ metadata = {'efa_count': efa_count}
+ return (network_type, metadata)
+ # No EFA available, but it's an AWS node
+ return (network_type, metadata)
# Check for GKE clusters with specific GPUDirect variants
machine_family = node.metadata.labels.get(
@@ -1196,26 +1287,36 @@ def _detect_network_type(
# variant
if 'a3-highgpu-8g' in instance_type:
return (
- KubernetesHighPerformanceNetworkType.GCP_TCPX,
- 'a3-highgpu-8g')
+ KubernetesHighPerformanceNetworkType.GCP_TCPX, {
+ 'instance_type': 'a3-highgpu-8g'
+ })
elif 'a3-edgegpu-8g' in instance_type:
return (
- KubernetesHighPerformanceNetworkType.GCP_TCPX,
- 'a3-edgegpu-8g')
+ KubernetesHighPerformanceNetworkType.GCP_TCPX, {
+ 'instance_type': 'a3-edgegpu-8g'
+ })
elif 'a3-megagpu-8g' in instance_type:
return (
KubernetesHighPerformanceNetworkType.GCP_TCPXO,
- 'a3-megagpu-8g')
+ {
+ 'instance_type': 'a3-megagpu-8g'
+ })
elif 'a4-highgpu-8g' in instance_type:
return (KubernetesHighPerformanceNetworkType.
- GCP_GPUDIRECT_RDMA, 'a4-highgpu-8g')
+ GCP_GPUDIRECT_RDMA, {
+ 'instance_type': 'a4-highgpu-8g'
+ })
elif 'a3-ultragpu-8g' in instance_type:
return (KubernetesHighPerformanceNetworkType.
- GCP_GPUDIRECT_RDMA, 'a3-ultragpu-8g')
+ GCP_GPUDIRECT_RDMA, {
+ 'instance_type': 'a3-ultragpu-8g'
+ })
# Generic A3/A4 detection as fallback
elif machine_family == 'a4':
return (KubernetesHighPerformanceNetworkType.
- GCP_GPUDIRECT_RDMA, 'a4')
+ GCP_GPUDIRECT_RDMA, {
+ 'instance_type': 'a4'
+ })
# Fallback: Check for GPU Direct TCPX capable instance
# types with high-perf GPUs
@@ -1229,8 +1330,9 @@ def _detect_network_type(
if is_gpu_direct_tcpx_instance and has_high_perf_gpu:
# Default to TCPX if we can't determine the specific
# variant
- return (KubernetesHighPerformanceNetworkType.GCP_TCPX,
- instance_type)
+ return (KubernetesHighPerformanceNetworkType.GCP_TCPX, {
+ 'instance_type': instance_type
+ })
except exceptions.KubeAPIUnreachableError:
# If we can't reach the cluster, assume no high perf networking
@@ -1246,26 +1348,31 @@ def _detect_network_type(
default_value=None)
if (autoscaler_type !=
kubernetes_enums.KubernetesAutoscalerType.GKE.value):
- return KubernetesHighPerformanceNetworkType.NONE, ''
+ return KubernetesHighPerformanceNetworkType.NONE, None
autoscaler = kubernetes_utils.get_autoscaler(
kubernetes_enums.KubernetesAutoscalerType(autoscaler_type))
logger.debug(f'{context} has autoscaler of type: {autoscaler_type}')
machine_types = autoscaler.get_available_machine_types(context)
# Check if any machine type supports high perf networking for GKE.
if 'a3-highgpu-8g' in machine_types:
- return (KubernetesHighPerformanceNetworkType.GCP_TCPX,
- 'a3-highgpu-8g')
+ return (KubernetesHighPerformanceNetworkType.GCP_TCPX, {
+ 'instance_type': 'a3-highgpu-8g'
+ })
elif 'a3-edgegpu-8g' in machine_types:
- return (KubernetesHighPerformanceNetworkType.GCP_TCPX,
- 'a3-edgegpu-8g')
+ return (KubernetesHighPerformanceNetworkType.GCP_TCPX, {
+ 'instance_type': 'a3-edgegpu-8g'
+ })
elif 'a3-megagpu-8g' in machine_types:
- return (KubernetesHighPerformanceNetworkType.GCP_TCPXO,
- 'a3-megagpu-8g')
+ return (KubernetesHighPerformanceNetworkType.GCP_TCPXO, {
+ 'instance_type': 'a3-megagpu-8g'
+ })
elif 'a4-highgpu-8g' in machine_types:
- return (KubernetesHighPerformanceNetworkType.GCP_GPUDIRECT_RDMA,
- 'a4-highgpu-8g')
+ return (KubernetesHighPerformanceNetworkType.GCP_GPUDIRECT_RDMA, {
+ 'instance_type': 'a4-highgpu-8g'
+ })
elif 'a3-ultragpu-8g' in machine_types:
- return (KubernetesHighPerformanceNetworkType.GCP_GPUDIRECT_RDMA,
- 'a3-ultragpu-8g')
+ return (KubernetesHighPerformanceNetworkType.GCP_GPUDIRECT_RDMA, {
+ 'instance_type': 'a3-ultragpu-8g'
+ })
- return KubernetesHighPerformanceNetworkType.NONE, ''
+ return KubernetesHighPerformanceNetworkType.NONE, None
diff --git a/sky/clouds/slurm.py b/sky/clouds/slurm.py
index 7af4398655d..cab850d8a08 100644
--- a/sky/clouds/slurm.py
+++ b/sky/clouds/slurm.py
@@ -47,10 +47,6 @@ class Slurm(clouds.Cloud):
'controllers is not '
'well tested with '
'Slurm.',
- clouds.CloudImplementationFeatures.IMAGE_ID: 'Specifying image ID is '
- 'not supported in Slurm.',
- clouds.CloudImplementationFeatures.DOCKER_IMAGE: 'Docker image is not '
- 'supported in Slurm.',
}
_MAX_CLUSTER_NAME_LEN_LIMIT = 120
_regions: List[clouds.Region] = []
@@ -65,7 +61,6 @@ class Slurm(clouds.Cloud):
STATUS_VERSION = clouds.StatusVersion.SKYPILOT
_SSH_CONFIG_KEY_MAPPING = {
- 'identityfile': 'IdentityFile',
'user': 'User',
'hostname': 'HostName',
}
@@ -366,6 +361,8 @@ def make_deploy_resources_variables(
if acc_type:
acc_type = slurm_utils.get_gres_gpu_type(cluster, acc_type)
+ image_id = resources.extract_docker_image()
+
deploy_vars = {
'instance_type': resources.instance_type,
'custom_resources': custom_resources,
@@ -383,11 +380,12 @@ def make_deploy_resources_variables(
'slurm_proxy_jump': ssh_config_dict.get('proxyjump', None),
# TODO(jwj): Solve naming collision with 'ssh_private_key'.
# Please refer to slurm-ray.yml.j2 'ssh' and 'auth' sections.
- 'slurm_private_key': ssh_config_dict['identityfile'][0],
+ 'slurm_private_key': slurm_utils.get_identity_file(ssh_config_dict),
'slurm_sshd_host_key_filename':
(slurm_utils.SLURM_SSHD_HOST_KEY_FILENAME),
'slurm_cluster_name_env_var':
(constants.SKY_CLUSTER_NAME_ENV_VAR_KEY),
+ 'image_id': image_id,
}
return deploy_vars
@@ -509,7 +507,7 @@ def _check_compute_credentials(
ssh_config_dict['hostname'],
int(ssh_config_dict.get('port', 22)),
ssh_config_dict['user'],
- ssh_config_dict['identityfile'][0],
+ slurm_utils.get_identity_file(ssh_config_dict),
ssh_proxy_command=ssh_config_dict.get('proxycommand', None),
ssh_proxy_jump=ssh_config_dict.get('proxyjump', None))
info = client.info()
diff --git a/sky/clouds/vast.py b/sky/clouds/vast.py
index abf4b542c5c..cbd5ca9585d 100644
--- a/sky/clouds/vast.py
+++ b/sky/clouds/vast.py
@@ -309,7 +309,7 @@ def _check_compute_credentials(
' $ pip install vastai\n'
' $ mkdir -p ~/.config/vastai\n'
f' $ echo [key] > {_CREDENTIAL_PATH}\n'
- ' For more information, see https://skypilot.readthedocs.io/en/latest/getting-started/installation.html#vast' # pylint: disable=line-too-long
+ ' For more information, see https://docs.skypilot.co/en/latest/getting-started/installation.html#vast' # pylint: disable=line-too-long
)
return True, None
diff --git a/sky/clouds/yotta.py b/sky/clouds/yotta.py
new file mode 100644
index 00000000000..ef87ab111da
--- /dev/null
+++ b/sky/clouds/yotta.py
@@ -0,0 +1,327 @@
+""" Yotta Cloud. """
+
+import os
+import typing
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+
+from sky import catalog
+from sky import clouds
+from sky.provision.yotta.yotta_utils import CREDENTIAL_FILE
+from sky.provision.yotta.yotta_utils import yotta_client
+from sky.utils import registry
+from sky.utils import resources_utils
+
+if typing.TYPE_CHECKING:
+ from sky import resources as resources_lib
+ from sky.utils import volume as volume_lib
+
+_CLOUD = 'yotta'
+_BASE_IMAGE = (
+ 'yottalabsai/pytorch:2.9.0-py3.11-cuda12.8.1-cudnn-devel-ubuntu22.04')
+
+
+@registry.CLOUD_REGISTRY.register
+class Yotta(clouds.Cloud):
+    """Yotta GPU Cloud.
+
+    _REPR: The string representation for the Yotta GPU cloud object.
+    """
+ _REPR = 'Yotta'
+ _CLOUD_UNSUPPORTED_FEATURES = {
+ clouds.CloudImplementationFeatures.STOP: 'Stopping not supported.',
+ clouds.CloudImplementationFeatures.MULTI_NODE:
+            ('Multi-node not supported yet, as the interconnection among '
+             'nodes is non-trivial on Yotta.'),
+ clouds.CloudImplementationFeatures.CLONE_DISK_FROM_CLUSTER:
+ ('Disk cloning not supported yet on Yotta.'),
+ clouds.CloudImplementationFeatures.SPOT_INSTANCE:
+ ('Spot instances not supported yet on Yotta.'),
+ clouds.CloudImplementationFeatures.CUSTOM_DISK_TIER:
+ ('Customizing disk tier is not supported yet on Yotta.'),
+ clouds.CloudImplementationFeatures.CUSTOM_NETWORK_TIER:
+ ('Custom network tier is not supported yet on Yotta.'),
+ clouds.CloudImplementationFeatures.STORAGE_MOUNTING:
+ ('Mounting object stores is not supported on Yotta. To read data '
+ 'from object stores on Yotta, use `mode: COPY` to copy the data '
+ 'to local disk.'),
+ clouds.CloudImplementationFeatures.HOST_CONTROLLERS:
+ ('Host controllers not supported yet on Yotta.'),
+ clouds.CloudImplementationFeatures.HIGH_AVAILABILITY_CONTROLLERS:
+ ('High availability controllers are not supported yet on Yotta.'),
+ clouds.CloudImplementationFeatures.AUTO_TERMINATE:
+ ('Auto-termination not supported yet on Yotta.'),
+ clouds.CloudImplementationFeatures.AUTOSTOP:
+ ('Auto-stop not supported yet on Yotta.'),
+ clouds.CloudImplementationFeatures.AUTODOWN:
+ ('Auto-down not supported yet on Yotta.'),
+ clouds.CloudImplementationFeatures.CUSTOM_MULTI_NETWORK:
+ ('Customized multiple network interfaces are not supported yet on '
+ 'Yotta.'),
+ }
+
+ _MAX_CLUSTER_NAME_LEN_LIMIT = 255
+ _regions: List[clouds.Region] = []
+
+ PROVISIONER_VERSION = clouds.ProvisionerVersion.SKYPILOT
+ STATUS_VERSION = clouds.StatusVersion.SKYPILOT
+ OPEN_PORTS_VERSION = clouds.OpenPortsVersion.LAUNCH_ONLY
+
+ @classmethod
+ def _unsupported_features_for_resources(
+ cls,
+ resources: 'resources_lib.Resources',
+ region: Optional[str] = None,
+ ) -> Dict[clouds.CloudImplementationFeatures, str]:
+ """The features not supported based on the resources provided.
+
+ This method is used by check_features_are_supported() to check if the
+ cloud implementation supports all the requested features.
+
+ Returns:
+ A dict of {feature: reason} for the features not supported by the
+ cloud implementation.
+ """
+ del resources # unused
+ return cls._CLOUD_UNSUPPORTED_FEATURES
+
+ @classmethod
+ def _max_cluster_name_length(cls) -> Optional[int]:
+ return cls._MAX_CLUSTER_NAME_LEN_LIMIT
+
+ @classmethod
+ def regions_with_offering(
+ cls,
+ instance_type: str,
+ accelerators: Optional[Dict[str, int]],
+ use_spot: bool,
+ region: Optional[str],
+ zone: Optional[str],
+ resources: Optional['resources_lib.Resources'] = None,
+ ) -> List[clouds.Region]:
+ del accelerators # unused
+ regions = catalog.get_region_zones_for_instance_type(
+ instance_type, use_spot, _CLOUD)
+
+ if region is not None:
+ regions = [r for r in regions if r.name == region]
+
+ if zone is not None:
+ for r in regions:
+ assert r.zones is not None, r
+ r.set_zones([z for z in r.zones if z.name == zone])
+ regions = [r for r in regions if r.zones]
+ return regions
+
+ @classmethod
+ def get_vcpus_mem_from_instance_type(
+ cls,
+ instance_type: str,
+ ) -> Tuple[Optional[float], Optional[float]]:
+ return catalog.get_vcpus_mem_from_instance_type(instance_type,
+ clouds=_CLOUD)
+
+ @classmethod
+ def zones_provision_loop(
+ cls,
+ *,
+ region: str,
+ num_nodes: int,
+ instance_type: str,
+ accelerators: Optional[Dict[str, int]] = None,
+ use_spot: bool = False,
+ ) -> Iterator[Optional[List['clouds.Zone']]]:
+ del num_nodes # unused
+ regions = cls.regions_with_offering(instance_type,
+ accelerators,
+ use_spot,
+ region=region,
+ zone=None)
+ for r in regions:
+ assert r
+ yield r.zones
+
+ def instance_type_to_hourly_cost(self,
+ instance_type: str,
+ use_spot: bool,
+ region: Optional[str] = None,
+ zone: Optional[str] = None) -> float:
+ return catalog.get_hourly_cost(instance_type,
+ use_spot=use_spot,
+ region=region,
+ zone=zone,
+ clouds=_CLOUD)
+
+ def accelerators_to_hourly_cost(self,
+ accelerators: Dict[str, int],
+ use_spot: bool,
+ region: Optional[str] = None,
+ zone: Optional[str] = None) -> float:
+ """Returns the hourly cost of the accelerators, in dollars/hour."""
+ del accelerators, use_spot, region, zone # unused
+ return 0.0 # Yotta includes accelerators in the hourly cost.
+
+ def get_egress_cost(self, num_gigabytes: float) -> float:
+ return 0.0
+
+ @classmethod
+ def get_default_instance_type(cls,
+ cpus: Optional[str] = None,
+ memory: Optional[str] = None,
+ disk_tier: Optional[
+ resources_utils.DiskTier] = None,
+ region: Optional[str] = None,
+ zone: Optional[str] = None) -> Optional[str]:
+ """Returns the default instance type for Yotta."""
+ return catalog.get_default_instance_type(cpus=cpus,
+ memory=memory,
+ disk_tier=disk_tier,
+ region=region,
+ zone=zone,
+ clouds=_CLOUD)
+
+ @classmethod
+ def get_accelerators_from_instance_type(
+ cls, instance_type: str) -> Optional[Dict[str, Union[int, float]]]:
+ return catalog.get_accelerators_from_instance_type(instance_type,
+ clouds=_CLOUD)
+
+ @classmethod
+ def get_zone_shell_cmd(cls) -> Optional[str]:
+ return None
+
+ def make_deploy_resources_variables(
+ self,
+ resources: 'resources_lib.Resources',
+ cluster_name: resources_utils.ClusterName,
+ region: 'clouds.Region',
+ zones: Optional[List['clouds.Zone']],
+ num_nodes: int,
+ dryrun: bool = False,
+ volume_mounts: Optional[List['volume_lib.VolumeMount']] = None,
+ ) -> Dict[str, Any]:
+ del dryrun, cluster_name, zones, num_nodes # unused
+ resources = resources.assert_launchable()
+ acc_dict = self.get_accelerators_from_instance_type(
+ resources.instance_type)
+ custom_resources = resources_utils.make_ray_custom_resources_str(
+ acc_dict)
+
+ if resources.image_id is None:
+ image_id: Optional[str] = _BASE_IMAGE
+ elif resources.extract_docker_image() is not None:
+ image_id = resources.extract_docker_image()
+ else:
+ image_id = resources.image_id[resources.region]
+
+ instance_type = resources.instance_type
+ use_spot = resources.use_spot
+ hourly_cost = self.instance_type_to_hourly_cost(
+ instance_type=instance_type, use_spot=use_spot)
+
+ return {
+ 'instance_type': instance_type,
+ 'custom_resources': custom_resources,
+ 'region': region.name,
+ 'image_id': image_id,
+ 'use_spot': use_spot,
+ 'bid_per_gpu': str(hourly_cost),
+ 'docker_login_config': resources.docker_login_config,
+ }
+
+ def _get_feasible_launchable_resources(
+ self, resources: 'resources_lib.Resources'
+ ) -> 'resources_utils.FeasibleResources':
+ """Returns a list of feasible resources for the given resources."""
+ if resources.instance_type is not None:
+ assert resources.is_launchable(), resources
+ resources = resources.copy(accelerators=None)
+ return resources_utils.FeasibleResources([resources], [], None)
+
+ def _make(instance_list):
+ resource_list = []
+ for instance_type in instance_list:
+ r = resources.copy(
+ cloud=Yotta(),
+ instance_type=instance_type,
+ accelerators=None,
+ cpus=None,
+ )
+ resource_list.append(r)
+ return resource_list
+
+ # Currently, handle a filter on accelerators only.
+ accelerators = resources.accelerators
+ if accelerators is None:
+ # Return a default instance type
+ default_instance_type = Yotta.get_default_instance_type(
+ cpus=resources.cpus,
+ memory=resources.memory,
+ disk_tier=resources.disk_tier,
+ region=resources.region,
+ zone=resources.zone)
+ if default_instance_type is None:
+ # TODO: Add hints to all return values in this method to help
+ # users understand why the resources are not launchable.
+ return resources_utils.FeasibleResources([], [], None)
+ else:
+ return resources_utils.FeasibleResources(
+ _make([default_instance_type]), [], None)
+
+ assert len(accelerators) == 1, resources
+ acc, acc_count = list(accelerators.items())[0]
+ (instance_list,
+ fuzzy_candidate_list) = catalog.get_instance_type_for_accelerator(
+ acc,
+ acc_count,
+ use_spot=resources.use_spot,
+ cpus=resources.cpus,
+ region=resources.region,
+ zone=resources.zone,
+ clouds=_CLOUD)
+ if instance_list is None:
+ return resources_utils.FeasibleResources([], fuzzy_candidate_list,
+ None)
+ return resources_utils.FeasibleResources(_make(instance_list),
+ fuzzy_candidate_list, None)
+
+ @classmethod
+ def _check_compute_credentials(cls) -> Tuple[bool, Optional[str]]:
+ """Checks if the user has access credentials to
+ Yotta's compute service."""
+ msg = ('Failed to access Yotta Cloud with credentials. '
+ 'To configure credentials, go to:\n '
+ ' https://console.yottalabs.ai \n '
+               'to obtain an API key, then save the contents '
+ f'to {CREDENTIAL_FILE} \n')
+ if not os.path.exists(os.path.expanduser(CREDENTIAL_FILE)):
+ return False, msg
+
+ try:
+ valid = yotta_client.check_api_key()
+ if not valid:
+ return False, msg
+ return True, None
+ except Exception as e: # pylint: disable=broad-except
+ return False, str(e)
+
+ def get_credential_file_mounts(self) -> Dict[str, str]:
+ return {CREDENTIAL_FILE: CREDENTIAL_FILE}
+
+ @classmethod
+ def get_user_identities(cls) -> Optional[List[List[str]]]:
+ # NOTE: used for very advanced SkyPilot functionality
+ # Can implement later if desired
+ return None
+
+ def instance_type_exists(self, instance_type: str) -> bool:
+ return catalog.instance_type_exists(instance_type, _CLOUD)
+
+ def validate_region_zone(self, region: Optional[str], zone: Optional[str]):
+ return catalog.validate_region_zone(region, zone, clouds=_CLOUD)
+
+ @classmethod
+ def get_image_size(cls, image_id: str, region: Optional[str]) -> float:
+ # TODO: use 0.0 for now to allow all images. We should change this to
+ # return the docker image size.
+ del image_id, region # unused
+ return 0.0
diff --git a/sky/core.py b/sky/core.py
index 58b1772cfd8..730664d6017 100644
--- a/sky/core.py
+++ b/sky/core.py
@@ -1,4 +1,5 @@
"""SDK functions for cluster/job management."""
+import shlex
import typing
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
@@ -20,6 +21,7 @@
from sky.adaptors import common as adaptors_common
from sky.backends import backend_utils
from sky.backends import cloud_vm_ray_backend
+from sky.backends import task_codegen
from sky.clouds import cloud as sky_cloud
from sky.jobs.server import core as managed_jobs_core
from sky.provision.kubernetes import constants as kubernetes_constants
@@ -89,6 +91,14 @@ def optimize(
request_name=request_names.AdminPolicyRequestName.OPTIMIZE,
request_options=request_options) as dag:
dag.resolve_and_validate_volumes()
+ # Use job group optimizer for job groups to properly handle
+ # co-location constraints and show the combined optimizer table
+ if dag.is_job_group():
+ return optimizer.Optimizer.optimize_job_group(
+ dag=dag,
+ minimize=minimize,
+ blocked_resources=blocked_resources,
+ quiet=quiet)
return optimizer.Optimizer.optimize(dag=dag,
minimize=minimize,
blocked_resources=blocked_resources,
@@ -540,6 +550,8 @@ def _start(
f'Starting cluster {cluster_name!r} with backend {backend.NAME} '
'is not supported.')
+ hook: Optional[str] = None
+ hook_timeout: Optional[int] = None
controller = controller_utils.Controllers.from_name(cluster_name)
if controller is not None:
if down or idle_minutes_to_autostop:
@@ -568,6 +580,9 @@ def _start(
controller_autostop_config.enabled):
idle_minutes_to_autostop = controller_autostop_config.idle_minutes
down = controller_autostop_config.down
+ wait_for = controller_autostop_config.wait_for
+ hook = controller_autostop_config.hook
+ hook_timeout = controller_autostop_config.hook_timeout
else:
# For non-controller clusters, restore autostop configuration from
# database if not explicitly provided.
@@ -613,7 +628,15 @@ def _start(
all_file_mounts=None,
storage_mounts=storage_mounts)
if idle_minutes_to_autostop is not None:
- backend.set_autostop(handle, idle_minutes_to_autostop, wait_for, down)
+ # For controller clusters, hook comes from controller_autostop_config.
+ # For regular clusters, hook is None so it will be inherited from the
+ # existing config on the remote cluster.
+ backend.set_autostop(handle,
+ idle_minutes_to_autostop,
+ wait_for,
+ down,
+ hook=hook,
+ hook_timeout=hook_timeout)
return handle
@@ -695,8 +718,84 @@ def _stop_not_supported_message(resources: 'resources_lib.Resources') -> str:
return message
+def _graceful_job_cancel(handle: backends.ResourceHandle,
+ backend: backends.Backend,
+ cluster_name: str,
+ timeout: Optional[int] = None,
+ terminate: bool = True) -> None:
+ """Stop jobs and flush rclone uploads on all nodes in parallel."""
+ op = 'shutdown' if terminate else 'stop'
+ if (not isinstance(handle, backends.CloudVmRayResourceHandle) or
+ not isinstance(backend, backends.CloudVmRayBackend)):
+        logger.warning(f'Graceful {op} is only available for '
+ 'CloudVmRayBackend. Skipping...')
+ return
+
+ # Kill all running jobs on the cluster
+ logger.info(f'Graceful {op} enabled. Terminating user jobs on '
+ f'{cluster_name}...')
+ try:
+ backend.cancel_jobs(handle, jobs=None, cancel_all=True)
+ except Exception as e: # pylint: disable=broad-except
+ logger.warning(f'Failed to cancel jobs: {e}')
+
+ # Flush rclone uploads on all nodes in parallel
+ logger.info('Flushing MOUNT_CACHED uploads on all nodes of '
+ f'{cluster_name!r}...')
+
+ # Get the flush script
+ flush_script = task_codegen.TaskCodeGen.get_rclone_flush_script()
+
+ # Wrap with timeout if specified
+ if timeout:
+ flush_script = f'timeout {timeout} bash -c {shlex.quote(flush_script)}'
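+        # GNU coreutils `timeout` exits with code 124 when the command
+        # times out; this is checked when collecting results below.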
+
+ runners = handle.get_command_runners()
+ node_args = [(i, runner) for i, runner in enumerate(runners)]
+ errors = []
+ logger.debug(f'Waiting for uploads on {len(runners)} node(s)...')
+
+ def run_flush_on_node(args):
+ """Run flush script on a single node."""
+ node_id, runner = args
+ try:
+ returncode, stdout, stderr = runner.run(
+ flush_script,
+ stream_logs=False,
+ require_outputs=True,
+ )
+ return (node_id, returncode, stdout, stderr)
+ except Exception as e: # pylint: disable=broad-except
+ return (node_id, -1, '', str(e))
+
+ parallel_results = subprocess_utils.run_in_parallel(
+ run_flush_on_node,
+ node_args,
+ num_threads=len(runners),
+ )
+
+ for node_id, returncode, _, stderr in parallel_results:
+ if returncode == 0:
+ logger.debug(f'Node {node_id}: uploads flushed successfully')
+ elif returncode == 124: # timeout exit code
+ logger.warning(f'Node {node_id}: flush timed out after {timeout}s')
+ errors.append(f'Node {node_id}: timeout')
+ else:
+ logger.warning(
+ f'Node {node_id}: flush failed (rc={returncode}): {stderr}')
+ errors.append(f'Node {node_id}: {stderr}')
+
+ if errors:
+ logger.warning(f'Some nodes had flush errors: {errors}')
+ else:
+ logger.debug(f'All MOUNT_CACHED uploads completed on {cluster_name!r}')
+
+
@usage_lib.entrypoint
-def down(cluster_name: str, purge: bool = False) -> None:
+def down(cluster_name: str,
+ purge: bool = False,
+ graceful: bool = False,
+ graceful_timeout: Optional[int] = None) -> None:
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
"""Tears down a cluster.
@@ -712,6 +811,10 @@ def down(cluster_name: str, purge: bool = False) -> None:
troubleshooting scenarios; with it set, it is the user's
responsibility to ensure there are no leaked instances and related
resources.
+ graceful: Cancel the user's task but block until MOUNT_CACHED data is
+            fully uploaded. This helps preserve user data integrity.
+ graceful_timeout: If not None, sets a timeout for the graceful option
+ above (in seconds).
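+
+        Example (hypothetical values): ``sky.down('my-cluster', graceful=True,
+        graceful_timeout=300)`` cancels running jobs, waits up to 5 minutes
+        for MOUNT_CACHED uploads to flush, then tears down the cluster.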
Raises:
sky.exceptions.ClusterDoesNotExist: the specified cluster does not
@@ -724,14 +827,24 @@ def down(cluster_name: str, purge: bool = False) -> None:
if handle is None:
raise exceptions.ClusterDoesNotExist(
f'Cluster {cluster_name!r} does not exist.')
+ backend = backend_utils.get_backend_from_handle(handle)
+
+ if graceful:
+ _graceful_job_cancel(handle,
+ backend,
+ cluster_name,
+ graceful_timeout,
+ terminate=True)
usage_lib.record_cluster_name_for_current_operation(cluster_name)
- backend = backend_utils.get_backend_from_handle(handle)
backend.teardown(handle, terminate=True, purge=purge)
@usage_lib.entrypoint
-def stop(cluster_name: str, purge: bool = False) -> None:
+def stop(cluster_name: str,
+ purge: bool = False,
+ graceful: bool = False,
+ graceful_timeout: Optional[int] = None) -> None:
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
"""Stops a cluster.
@@ -750,6 +863,10 @@ def stop(cluster_name: str, purge: bool = False) -> None:
certain manual troubleshooting scenarios; with it set, it is the
user's responsibility to ensure there are no leaked instances and
related resources.
+ graceful: Cancel the user's task but block until MOUNT_CACHED data is
+            fully uploaded. This helps preserve user data integrity.
+ graceful_timeout: If not None, sets a timeout for the graceful option
+ above (in seconds).
Raises:
sky.exceptions.ClusterDoesNotExist: the specified cluster does not
@@ -791,17 +908,26 @@ def stop(cluster_name: str, purge: bool = False) -> None:
' To terminate the cluster instead, run: '
f'{colorama.Style.BRIGHT}sky down {cluster_name}') from e
+ if graceful:
+ _graceful_job_cancel(handle,
+ backend,
+ cluster_name,
+ graceful_timeout,
+ terminate=False)
+
usage_lib.record_cluster_name_for_current_operation(cluster_name)
backend.teardown(handle, terminate=False, purge=purge)
@usage_lib.entrypoint
def autostop(
- cluster_name: str,
- idle_minutes: int,
- wait_for: Optional[autostop_lib.AutostopWaitFor] = autostop_lib.
- DEFAULT_AUTOSTOP_WAIT_FOR,
- down: bool = False, # pylint: disable=redefined-outer-name
+ cluster_name: str,
+ idle_minutes: int,
+ wait_for: Optional[
+ autostop_lib.AutostopWaitFor] = autostop_lib.DEFAULT_AUTOSTOP_WAIT_FOR,
+ down: bool = False, # pylint: disable=redefined-outer-name
+ hook: Optional[str] = None,
+ hook_timeout: Optional[int] = None,
) -> None:
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
"""Schedules an autostop/autodown for a cluster.
@@ -835,6 +961,12 @@ def autostop(
to a negative number cancels any autostop/autodown setting.
down: if true, use autodown (tear down the cluster; non-restartable),
rather than autostop (restartable).
+ hook: optional script to execute on the remote cluster before autostop.
+ The script runs before the cluster is stopped or torn down. If the
+ hook fails, autostop will still proceed but a warning will be logged.
+ hook_timeout: timeout in seconds for hook execution. If None, uses
+ DEFAULT_AUTOSTOP_HOOK_TIMEOUT_SECONDS (3600 = 1 hour). The hook will
+ be terminated if it exceeds this timeout.
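+
+        Example (hypothetical hook): ``sky.autostop('my-cluster',
+        idle_minutes=10, hook='bash ~/checkpoint.sh', hook_timeout=600)``
+        runs the checkpoint script before stopping the idle cluster.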
Raises:
sky.exceptions.ClusterDoesNotExist: if the cluster does not exist.
@@ -890,7 +1022,12 @@ def autostop(
f'see reason above.') from e
usage_lib.record_cluster_name_for_current_operation(cluster_name)
- backend.set_autostop(handle, idle_minutes, wait_for, down)
+ backend.set_autostop(handle,
+ idle_minutes,
+ wait_for,
+ down,
+ hook=hook,
+ hook_timeout=hook_timeout)
# ==================
@@ -1132,6 +1269,43 @@ def tail_logs(cluster_name: str,
return returnval
+@usage_lib.entrypoint
+def tail_autostop_logs(cluster_name: str,
+ follow: bool = True,
+ tail: int = 0) -> int:
+ """Tails the autostop hook logs of a cluster.
+
+ Args:
+ cluster_name: name of the cluster.
+ follow: whether to follow the logs.
+ tail: number of lines to display from the end of the log file.
+
+ Raises:
+ ValueError: if arguments are invalid or the cluster is not supported.
+ sky.exceptions.ClusterDoesNotExist: if the cluster does not exist.
+ sky.exceptions.ClusterNotUpError: if the cluster is not UP.
+ sky.exceptions.NotSupportedError: if the cluster is not based on
+ CloudVmRayBackend.
+ sky.exceptions.ClusterOwnerIdentityMismatchError: if the current user is
+ not the same as the user who created the cluster.
+ sky.exceptions.CloudUserIdentityError: if we fail to get the current
+ user identity.
+
+ Returns:
+ Return code 0 on success, non-zero on failure.
+ """
+ # Check the status of the cluster.
+ handle = backend_utils.check_cluster_available(
+ cluster_name,
+ operation='tailing autostop logs',
+ )
+ backend = backend_utils.get_backend_from_handle(handle)
+
+ usage_lib.record_cluster_name_for_current_operation(cluster_name)
+ returnval = backend.tail_autostop_logs(handle, follow=follow, tail=tail)
+ return returnval
+
+
@usage_lib.entrypoint
def download_logs(
cluster_name: str,
@@ -1342,11 +1516,13 @@ def _realtime_kubernetes_gpu_availability_single(
region_filter=context,
quantity_filter=quantity_filter,
case_sensitive=False)
- assert (set(counts.keys()) == set(capacity.keys()) == set(
- available.keys())), (f'Keys of counts ({list(counts.keys())}), '
- f'capacity ({list(capacity.keys())}), '
- f'and available ({list(available.keys())}) '
- 'must be the same.')
+
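+    # counts, capacity, and available may come back with different key
+    # sets; union the keys and backfill defaults instead of asserting
+    # that they match.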
+ all_keys = set(counts.keys()) | set(capacity.keys()) | set(
+ available.keys())
+ counts = {key: counts.get(key, []) for key in all_keys}
+ capacity = {key: capacity.get(key, 0) for key in all_keys}
+ available = {key: available.get(key, 0) for key in all_keys}
+
realtime_gpu_availability_list: List[
models.RealtimeGpuAvailability] = []
diff --git a/sky/dag.py b/sky/dag.py
index 349f3ef1b04..d736ea667f0 100644
--- a/sky/dag.py
+++ b/sky/dag.py
@@ -1,13 +1,28 @@
"""DAGs: user applications to be run."""
+import enum
import pprint
import threading
import typing
-from typing import List, Optional
+from typing import Dict, List, Optional, Union
if typing.TYPE_CHECKING:
from sky import task
+class DagExecution(enum.Enum):
+ """Execution mode for DAGs with multiple tasks.
+
+ This controls how tasks in a multi-task DAG are executed.
+ """
+ SERIAL = 'serial' # Tasks execute sequentially (pipeline)
+ PARALLEL = 'parallel' # All tasks start in parallel (job group)
+
+
+# Default execution mode for jobs without an explicit execution mode set.
+# Used for single jobs and as a fallback for pipelines.
+DEFAULT_EXECUTION = DagExecution.SERIAL
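+
+# Example: a job-group loader marks a DAG as parallel via
+# `dag.set_execution(DagExecution.PARALLEL)`, after which
+# `dag.is_job_group()` returns True.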
+
+
class Dag:
"""Dag: a user application, represented as a DAG of Tasks.
@@ -15,6 +30,11 @@ class Dag:
>>> import sky
>>> with sky.Dag() as dag:
>>> task = sky.Task(...)
+
+ For JobGroups (heterogeneous parallel workloads):
+ >>> dag = dag_utils.load_job_group_from_yaml('job_group.yaml')
+ >>> # dag.is_job_group() returns True
+ >>> # dag.tasks contains jobs to run in parallel
"""
def __init__(self) -> None:
@@ -26,6 +46,18 @@ def __init__(self) -> None:
self.policy_applied: bool = False
self.pool: Optional[str] = None
+ # Execution mode for multi-task DAGs
+ self.execution: Optional[DagExecution] = None
+
+ # Primary/auxiliary task support for job groups
+ # If set, only the named tasks are "primary"; others are "auxiliary".
+ # When all primary tasks complete, auxiliary tasks are terminated.
+ self.primary_tasks: Optional[List[str]] = None
+ # Termination delay for auxiliary tasks when primary tasks complete.
+ # Can be a string like "30s" (applies to all auxiliary tasks) or
+ # a dict like {"default": "30s", "replay-buffer": "1m"}.
+ self.termination_delay: Optional[Union[str, Dict[str, str]]] = None
+
def add(self, task: 'task.Task') -> None:
self.graph.add_node(task)
self.tasks.append(task)
@@ -56,6 +88,74 @@ def __repr__(self) -> str:
def get_graph(self):
return self.graph
+ def is_job_group(self) -> bool:
+ """Check if this DAG represents a JobGroup.
+
+ A DAG is a JobGroup if it has parallel execution mode. This is the
+ defining characteristic that distinguishes job groups from pipelines.
+ """
+ return self.execution == DagExecution.PARALLEL
+
+ def set_execution(self, execution: 'DagExecution') -> None:
+ """Configure this DAG with the given execution mode."""
+ self.execution = execution
+
+ def get_termination_delay_secs(self, task_name: str) -> int:
+ """Get termination delay in seconds for a specific task.
+
+ Args:
+ task_name: The name of the task to get the delay for.
+
+ Returns:
+ Termination delay in seconds. Returns 0 if not configured.
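+
+            Example (illustrative): with termination_delay set to
+            {'default': '30s', 'replay-buffer': '1m'}, a task named
+            'replay-buffer' gets 60 seconds and any other task gets 30.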
+ """
+ if self.termination_delay is None:
+ return 0
+
+ # Import here to avoid circular imports
+ # pylint: disable=import-outside-toplevel
+ from sky.utils import resources_utils
+
+ # Get the delay string based on format (str or dict)
+ if isinstance(self.termination_delay, str):
+ delay_str = self.termination_delay
+ else:
+ delay_str = self.termination_delay.get(
+ task_name, self.termination_delay.get('default', '0s'))
+
+ return resources_utils.parse_time_seconds(delay_str)
+
+ def is_primary_task(self, task_name: str) -> bool:
+ """Check if a task is a primary task.
+
+ Args:
+ task_name: The name of the task to check.
+
+ Returns:
+ True if the task is primary. When primary_tasks is None or empty,
+ all tasks are considered primary.
+ """
+ if self.primary_tasks is None or len(self.primary_tasks) == 0:
+ return True
+ # pylint: disable=unsupported-membership-test
+ return task_name in self.primary_tasks
+
+ def get_auxiliary_task_names(self) -> typing.List[str]:
+ """Get the names of all auxiliary (non-primary) tasks.
+
+ Returns:
+ List of auxiliary task names. Returns empty list if all tasks
+ are primary (when primary_tasks is None or empty).
+ """
+ if self.primary_tasks is None or len(self.primary_tasks) == 0:
+ return []
+ # pylint: disable=unsupported-membership-test
+ return [
+ t.name
+ for t in self.tasks
+ if t.name is not None and t.name not in self.primary_tasks
+ ]
+
def is_chain(self) -> bool:
"""Check if the DAG is a linear chain of tasks."""
diff --git a/sky/dashboard/package-lock.json b/sky/dashboard/package-lock.json
index e41998e426f..db81f1f3f14 100644
--- a/sky/dashboard/package-lock.json
+++ b/sky/dashboard/package-lock.json
@@ -8,6 +8,7 @@
"name": "dashboard",
"version": "0.1.0",
"dependencies": {
+ "@codemirror/lang-yaml": "^6.1.2",
"@emotion/react": "^11.13.0",
"@emotion/styled": "^11.13.0",
"@mui/material": "^5.16.7",
@@ -18,6 +19,7 @@
"@radix-ui/react-label": "^2.1.0",
"@radix-ui/react-select": "^2.1.1",
"@radix-ui/react-slot": "^1.1.0",
+ "@uiw/react-codemirror": "^4.25.4",
"chart.js": "^4.4.3",
"class-variance-authority": "^0.7.0",
"clsx": "^2.1.1",
@@ -1920,6 +1922,114 @@
"dev": true,
"license": "MIT"
},
+ "node_modules/@codemirror/autocomplete": {
+ "version": "6.20.0",
+ "resolved": "https://registry.npmjs.org/@codemirror/autocomplete/-/autocomplete-6.20.0.tgz",
+ "integrity": "sha512-bOwvTOIJcG5FVo5gUUupiwYh8MioPLQ4UcqbcRf7UQ98X90tCa9E1kZ3Z7tqwpZxYyOvh1YTYbmZE9RTfTp5hg==",
+ "license": "MIT",
+ "dependencies": {
+ "@codemirror/language": "^6.0.0",
+ "@codemirror/state": "^6.0.0",
+ "@codemirror/view": "^6.17.0",
+ "@lezer/common": "^1.0.0"
+ }
+ },
+ "node_modules/@codemirror/commands": {
+ "version": "6.10.1",
+ "resolved": "https://registry.npmjs.org/@codemirror/commands/-/commands-6.10.1.tgz",
+ "integrity": "sha512-uWDWFypNdQmz2y1LaNJzK7fL7TYKLeUAU0npEC685OKTF3KcQ2Vu3klIM78D7I6wGhktme0lh3CuQLv0ZCrD9Q==",
+ "license": "MIT",
+ "dependencies": {
+ "@codemirror/language": "^6.0.0",
+ "@codemirror/state": "^6.4.0",
+ "@codemirror/view": "^6.27.0",
+ "@lezer/common": "^1.1.0"
+ }
+ },
+ "node_modules/@codemirror/lang-yaml": {
+ "version": "6.1.2",
+ "resolved": "https://registry.npmjs.org/@codemirror/lang-yaml/-/lang-yaml-6.1.2.tgz",
+ "integrity": "sha512-dxrfG8w5Ce/QbT7YID7mWZFKhdhsaTNOYjOkSIMt1qmC4VQnXSDSYVHHHn8k6kJUfIhtLo8t1JJgltlxWdsITw==",
+ "license": "MIT",
+ "dependencies": {
+ "@codemirror/autocomplete": "^6.0.0",
+ "@codemirror/language": "^6.0.0",
+ "@codemirror/state": "^6.0.0",
+ "@lezer/common": "^1.2.0",
+ "@lezer/highlight": "^1.2.0",
+ "@lezer/lr": "^1.0.0",
+ "@lezer/yaml": "^1.0.0"
+ }
+ },
+ "node_modules/@codemirror/language": {
+ "version": "6.12.1",
+ "resolved": "https://registry.npmjs.org/@codemirror/language/-/language-6.12.1.tgz",
+ "integrity": "sha512-Fa6xkSiuGKc8XC8Cn96T+TQHYj4ZZ7RdFmXA3i9xe/3hLHfwPZdM+dqfX0Cp0zQklBKhVD8Yzc8LS45rkqcwpQ==",
+ "license": "MIT",
+ "dependencies": {
+ "@codemirror/state": "^6.0.0",
+ "@codemirror/view": "^6.23.0",
+ "@lezer/common": "^1.5.0",
+ "@lezer/highlight": "^1.0.0",
+ "@lezer/lr": "^1.0.0",
+ "style-mod": "^4.0.0"
+ }
+ },
+ "node_modules/@codemirror/lint": {
+ "version": "6.9.3",
+ "resolved": "https://registry.npmjs.org/@codemirror/lint/-/lint-6.9.3.tgz",
+ "integrity": "sha512-y3YkYhdnhjDBAe0VIA0c4wVoFOvnp8CnAvfLqi0TqotIv92wIlAAP7HELOpLBsKwjAX6W92rSflA6an/2zBvXw==",
+ "license": "MIT",
+ "dependencies": {
+ "@codemirror/state": "^6.0.0",
+ "@codemirror/view": "^6.35.0",
+ "crelt": "^1.0.5"
+ }
+ },
+ "node_modules/@codemirror/search": {
+ "version": "6.6.0",
+ "resolved": "https://registry.npmjs.org/@codemirror/search/-/search-6.6.0.tgz",
+ "integrity": "sha512-koFuNXcDvyyotWcgOnZGmY7LZqEOXZaaxD/j6n18TCLx2/9HieZJ5H6hs1g8FiRxBD0DNfs0nXn17g872RmYdw==",
+ "license": "MIT",
+ "dependencies": {
+ "@codemirror/state": "^6.0.0",
+ "@codemirror/view": "^6.37.0",
+ "crelt": "^1.0.5"
+ }
+ },
+ "node_modules/@codemirror/state": {
+ "version": "6.5.4",
+ "resolved": "https://registry.npmjs.org/@codemirror/state/-/state-6.5.4.tgz",
+ "integrity": "sha512-8y7xqG/hpB53l25CIoit9/ngxdfoG+fx+V3SHBrinnhOtLvKHRyAJJuHzkWrR4YXXLX8eXBsejgAAxHUOdW1yw==",
+ "license": "MIT",
+ "dependencies": {
+ "@marijn/find-cluster-break": "^1.0.0"
+ }
+ },
+ "node_modules/@codemirror/theme-one-dark": {
+ "version": "6.1.3",
+ "resolved": "https://registry.npmjs.org/@codemirror/theme-one-dark/-/theme-one-dark-6.1.3.tgz",
+ "integrity": "sha512-NzBdIvEJmx6fjeremiGp3t/okrLPYT0d9orIc7AFun8oZcRk58aejkqhv6spnz4MLAevrKNPMQYXEWMg4s+sKA==",
+ "license": "MIT",
+ "dependencies": {
+ "@codemirror/language": "^6.0.0",
+ "@codemirror/state": "^6.0.0",
+ "@codemirror/view": "^6.0.0",
+ "@lezer/highlight": "^1.0.0"
+ }
+ },
+ "node_modules/@codemirror/view": {
+ "version": "6.39.11",
+ "resolved": "https://registry.npmjs.org/@codemirror/view/-/view-6.39.11.tgz",
+ "integrity": "sha512-bWdeR8gWM87l4DB/kYSF9A+dVackzDb/V56Tq7QVrQ7rn86W0rgZFtlL3g3pem6AeGcb9NQNoy3ao4WpW4h5tQ==",
+ "license": "MIT",
+ "dependencies": {
+ "@codemirror/state": "^6.5.0",
+ "crelt": "^1.0.6",
+ "style-mod": "^4.1.0",
+ "w3c-keyname": "^2.2.4"
+ }
+ },
"node_modules/@emnapi/core": {
"version": "1.4.3",
"resolved": "https://registry.npmjs.org/@emnapi/core/-/core-1.4.3.tgz",
@@ -2893,6 +3003,47 @@
"integrity": "sha512-M5UknZPHRu3DEDWoipU6sE8PdkZ6Z/S+v4dD+Ke8IaNlpdSQah50lz1KtcFBa2vsdOnwbbnxJwVM4wty6udA5w==",
"license": "MIT"
},
+ "node_modules/@lezer/common": {
+ "version": "1.5.0",
+ "resolved": "https://registry.npmjs.org/@lezer/common/-/common-1.5.0.tgz",
+ "integrity": "sha512-PNGcolp9hr4PJdXR4ix7XtixDrClScvtSCYW3rQG106oVMOOI+jFb+0+J3mbeL/53g1Zd6s0kJzaw6Ri68GmAA==",
+ "license": "MIT"
+ },
+ "node_modules/@lezer/highlight": {
+ "version": "1.2.3",
+ "resolved": "https://registry.npmjs.org/@lezer/highlight/-/highlight-1.2.3.tgz",
+ "integrity": "sha512-qXdH7UqTvGfdVBINrgKhDsVTJTxactNNxLk7+UMwZhU13lMHaOBlJe9Vqp907ya56Y3+ed2tlqzys7jDkTmW0g==",
+ "license": "MIT",
+ "dependencies": {
+ "@lezer/common": "^1.3.0"
+ }
+ },
+ "node_modules/@lezer/lr": {
+ "version": "1.4.8",
+ "resolved": "https://registry.npmjs.org/@lezer/lr/-/lr-1.4.8.tgz",
+ "integrity": "sha512-bPWa0Pgx69ylNlMlPvBPryqeLYQjyJjqPx+Aupm5zydLIF3NE+6MMLT8Yi23Bd9cif9VS00aUebn+6fDIGBcDA==",
+ "license": "MIT",
+ "dependencies": {
+ "@lezer/common": "^1.0.0"
+ }
+ },
+ "node_modules/@lezer/yaml": {
+ "version": "1.0.3",
+ "resolved": "https://registry.npmjs.org/@lezer/yaml/-/yaml-1.0.3.tgz",
+ "integrity": "sha512-GuBLekbw9jDBDhGur82nuwkxKQ+a3W5H0GfaAthDXcAu+XdpS43VlnxA9E9hllkpSP5ellRDKjLLj7Lu9Wr6xA==",
+ "license": "MIT",
+ "dependencies": {
+ "@lezer/common": "^1.2.0",
+ "@lezer/highlight": "^1.0.0",
+ "@lezer/lr": "^1.4.0"
+ }
+ },
+ "node_modules/@marijn/find-cluster-break": {
+ "version": "1.0.2",
+ "resolved": "https://registry.npmjs.org/@marijn/find-cluster-break/-/find-cluster-break-1.0.2.tgz",
+ "integrity": "sha512-l0h88YhZFyKdXIFNfSWpyjStDjGHwZ/U7iobcK1cQQD8sejsONdQtTVU+1wVN1PBw40PiiHB1vA5S7VTfQiP9g==",
+ "license": "MIT"
+ },
"node_modules/@mui/core-downloads-tracker": {
"version": "5.17.1",
"resolved": "https://registry.npmjs.org/@mui/core-downloads-tracker/-/core-downloads-tracker-5.17.1.tgz",
@@ -5300,6 +5451,59 @@
"url": "https://opencollective.com/eslint"
}
},
+ "node_modules/@uiw/codemirror-extensions-basic-setup": {
+ "version": "4.25.4",
+ "resolved": "https://registry.npmjs.org/@uiw/codemirror-extensions-basic-setup/-/codemirror-extensions-basic-setup-4.25.4.tgz",
+ "integrity": "sha512-YzNwkm0AbPv1EXhCHYR5v0nqfemG2jEB0Z3Att4rBYqKrlG7AA9Rhjc3IyBaOzsBu18wtrp9/+uhTyu7TXSRng==",
+ "license": "MIT",
+ "dependencies": {
+ "@codemirror/autocomplete": "^6.0.0",
+ "@codemirror/commands": "^6.0.0",
+ "@codemirror/language": "^6.0.0",
+ "@codemirror/lint": "^6.0.0",
+ "@codemirror/search": "^6.0.0",
+ "@codemirror/state": "^6.0.0",
+ "@codemirror/view": "^6.0.0"
+ },
+ "funding": {
+ "url": "https://jaywcjlove.github.io/#/sponsor"
+ },
+ "peerDependencies": {
+ "@codemirror/autocomplete": ">=6.0.0",
+ "@codemirror/commands": ">=6.0.0",
+ "@codemirror/language": ">=6.0.0",
+ "@codemirror/lint": ">=6.0.0",
+ "@codemirror/search": ">=6.0.0",
+ "@codemirror/state": ">=6.0.0",
+ "@codemirror/view": ">=6.0.0"
+ }
+ },
+ "node_modules/@uiw/react-codemirror": {
+ "version": "4.25.4",
+ "resolved": "https://registry.npmjs.org/@uiw/react-codemirror/-/react-codemirror-4.25.4.tgz",
+ "integrity": "sha512-ipO067oyfUw+DVaXhQCxkB0ZD9b7RnY+ByrprSYSKCHaULvJ3sqWYC/Zen6zVQ8/XC4o5EPBfatGiX20kC7XGA==",
+ "license": "MIT",
+ "dependencies": {
+ "@babel/runtime": "^7.18.6",
+ "@codemirror/commands": "^6.1.0",
+ "@codemirror/state": "^6.1.1",
+ "@codemirror/theme-one-dark": "^6.0.0",
+ "@uiw/codemirror-extensions-basic-setup": "4.25.4",
+ "codemirror": "^6.0.0"
+ },
+ "funding": {
+ "url": "https://jaywcjlove.github.io/#/sponsor"
+ },
+ "peerDependencies": {
+ "@babel/runtime": ">=7.11.0",
+ "@codemirror/state": ">=6.0.0",
+ "@codemirror/theme-one-dark": ">=6.0.0",
+ "@codemirror/view": ">=6.0.0",
+ "codemirror": ">=6.0.0",
+ "react": ">=17.0.0",
+ "react-dom": ">=17.0.0"
+ }
+ },
"node_modules/@ungap/structured-clone": {
"version": "1.3.0",
"resolved": "https://registry.npmjs.org/@ungap/structured-clone/-/structured-clone-1.3.0.tgz",
@@ -6680,6 +6884,21 @@
"node": ">= 0.12.0"
}
},
+ "node_modules/codemirror": {
+ "version": "6.0.2",
+ "resolved": "https://registry.npmjs.org/codemirror/-/codemirror-6.0.2.tgz",
+ "integrity": "sha512-VhydHotNW5w1UGK0Qj96BwSk/Zqbp9WbnyK2W/eVMv4QyF41INRGpjUhFJY7/uDNuudSc33a/PKr4iDqRduvHw==",
+ "license": "MIT",
+ "dependencies": {
+ "@codemirror/autocomplete": "^6.0.0",
+ "@codemirror/commands": "^6.0.0",
+ "@codemirror/language": "^6.0.0",
+ "@codemirror/lint": "^6.0.0",
+ "@codemirror/search": "^6.0.0",
+ "@codemirror/state": "^6.0.0",
+ "@codemirror/view": "^6.0.0"
+ }
+ },
"node_modules/collect-v8-coverage": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/collect-v8-coverage/-/collect-v8-coverage-1.0.2.tgz",
@@ -6871,6 +7090,12 @@
"node": "^14.15.0 || ^16.10.0 || >=18.0.0"
}
},
+ "node_modules/crelt": {
+ "version": "1.0.6",
+ "resolved": "https://registry.npmjs.org/crelt/-/crelt-1.0.6.tgz",
+ "integrity": "sha512-VQ2MBenTq1fWZUH9DJNGti7kKv6EeAuYr3cLwxUWhIu1baTaXh4Ib5W2CqHVqib4/MqbYGJqiL3Zb8GJZr3l4g==",
+ "license": "MIT"
+ },
"node_modules/cross-spawn": {
"version": "7.0.6",
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz",
@@ -13873,6 +14098,12 @@
"url": "https://github.com/sponsors/sindresorhus"
}
},
+ "node_modules/style-mod": {
+ "version": "4.1.3",
+ "resolved": "https://registry.npmjs.org/style-mod/-/style-mod-4.1.3.tgz",
+ "integrity": "sha512-i/n8VsZydrugj3Iuzll8+x/00GH2vnYsk1eomD8QiRrSAeW6ItbCQDtfXCeJHd0iwiNagqjQkvpvREEPtW3IoQ==",
+ "license": "MIT"
+ },
"node_modules/styled-jsx": {
"version": "5.1.1",
"resolved": "https://registry.npmjs.org/styled-jsx/-/styled-jsx-5.1.1.tgz",
@@ -14755,6 +14986,12 @@
"node": ">= 0.8"
}
},
+ "node_modules/w3c-keyname": {
+ "version": "2.2.8",
+ "resolved": "https://registry.npmjs.org/w3c-keyname/-/w3c-keyname-2.2.8.tgz",
+ "integrity": "sha512-dpojBhNsCNN7T82Tm7k26A6G9ML3NkhDsnw9n/eoxSRlVBB4CEtIQ/KTCLI2Fwf3ataSXRhYFkQi3SlnFwPvPQ==",
+ "license": "MIT"
+ },
"node_modules/w3c-xmlserializer": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/w3c-xmlserializer/-/w3c-xmlserializer-4.0.0.tgz",
diff --git a/sky/dashboard/package.json b/sky/dashboard/package.json
index 31dbbf5bfab..05f5f7b6bfe 100644
--- a/sky/dashboard/package.json
+++ b/sky/dashboard/package.json
@@ -6,7 +6,7 @@
"dev": "node server.js",
"build": "next build",
"start": "NODE_ENV=production node server.js",
- "lint": "next lint",
+ "lint": "next lint --max-warnings 0",
"lint:fix": "next lint --fix",
"format": "prettier --write .",
"format:check": "prettier --check .",
@@ -15,6 +15,7 @@
"test:watch": "jest --watch"
},
"dependencies": {
+ "@codemirror/lang-yaml": "^6.1.2",
"@emotion/react": "^11.13.0",
"@emotion/styled": "^11.13.0",
"@mui/material": "^5.16.7",
@@ -25,6 +26,7 @@
"@radix-ui/react-label": "^2.1.0",
"@radix-ui/react-select": "^2.1.1",
"@radix-ui/react-slot": "^1.1.0",
+ "@uiw/react-codemirror": "^4.25.4",
"chart.js": "^4.4.3",
"class-variance-authority": "^0.7.0",
"clsx": "^2.1.1",
diff --git a/sky/dashboard/src/app/globals.css b/sky/dashboard/src/app/globals.css
index fbee99dc207..7822ecd45cc 100644
--- a/sky/dashboard/src/app/globals.css
+++ b/sky/dashboard/src/app/globals.css
@@ -505,3 +505,73 @@
text-align: left !important;
white-space: nowrap !important;
}
+
+/* ===== Infra Page Glassy Loading Styles ===== */
+
+@keyframes infra-shimmer {
+ 0% {
+ background-position: -200% 0;
+ }
+ 100% {
+ background-position: 200% 0;
+ }
+}
+
+/* Skeleton text placeholder - replaces CircularProgress in cells */
+.infra-skeleton-text {
+ display: inline-block;
+ height: 0.75rem;
+ background: linear-gradient(
+ 90deg,
+ hsl(var(--border)) 0%,
+ hsl(var(--secondary)) 50%,
+ hsl(var(--border)) 100%
+ );
+ background-size: 200% 100%;
+ animation: infra-shimmer 1.5s infinite;
+ border-radius: 3px;
+ min-width: 24px;
+}
+
+/* Row shimmer effect for progressive loading */
+.infra-loading-row {
+ position: relative;
+ overflow: hidden;
+}
+
+.infra-loading-row::after {
+ content: '';
+ position: absolute;
+ top: 0;
+ left: 0;
+ right: 0;
+ bottom: 0;
+ background: linear-gradient(
+ 90deg,
+ rgba(255, 255, 255, 0) 0%,
+ rgba(255, 255, 255, 0.5) 50%,
+ rgba(255, 255, 255, 0) 100%
+ );
+ background-size: 200% 100%;
+ animation: infra-shimmer 1.5s infinite;
+ pointer-events: none;
+}
+
+/* Glass overlay for table during refresh */
+.infra-table-refreshing {
+ position: relative;
+}
+
+.infra-table-refreshing::before {
+ content: '';
+ position: absolute;
+ top: 0;
+ left: 0;
+ right: 0;
+ bottom: 0;
+ background: rgba(255, 255, 255, 0.5);
+ backdrop-filter: blur(1px);
+ z-index: 5;
+ border-radius: inherit;
+ pointer-events: none;
+}
diff --git a/sky/dashboard/src/components/GPUMetricsSection.jsx b/sky/dashboard/src/components/GPUMetricsSection.jsx
new file mode 100644
index 00000000000..28c565d078d
--- /dev/null
+++ b/sky/dashboard/src/components/GPUMetricsSection.jsx
@@ -0,0 +1,212 @@
+import React, { useState } from 'react';
+import {
+ ChevronDownIcon,
+ ChevronRightIcon,
+ ExternalLinkIcon,
+} from 'lucide-react';
+import { CustomTooltip as Tooltip } from '@/components/utils';
+import { getGrafanaUrl, buildGrafanaUrl } from '@/utils/grafana';
+
+// Grafana configuration constants
+const GRAFANA_DASHBOARD_SLUG = 'skypilot-dcgm-gpu/skypilot-dcgm-gpu-metrics';
+const GRAFANA_ORG_ID = '1';
+
+// Time range presets for GPU metrics
+const TIME_RANGE_PRESETS = [
+ { label: '15m', value: '15m' },
+ { label: '1h', value: '1h' },
+ { label: '6h', value: '6h' },
+ { label: '24h', value: '24h' },
+ { label: '7d', value: '7d' },
+];
+
+// GPU panels configuration
+const GPU_PANELS = [
+ { id: '1', title: 'GPU Utilization', keyPrefix: 'gpu-util' },
+ { id: '2', title: 'GPU Memory Utilization', keyPrefix: 'gpu-memory' },
+ { id: '3', title: 'GPU Temperature', keyPrefix: 'gpu-temp' },
+ { id: '4', title: 'GPU Power Usage', keyPrefix: 'gpu-power' },
+];
+
+/**
+ * Build Grafana panel URL with filters
+ */
+const buildGrafanaMetricsUrl = (panelId, clusterNameOnCloud, timeRange) => {
+ const grafanaUrl = getGrafanaUrl();
+ const params = new URLSearchParams({
+ orgId: GRAFANA_ORG_ID,
+ from: timeRange.from,
+ to: timeRange.to,
+ timezone: 'browser',
+ 'var-cluster': clusterNameOnCloud,
+ 'var-node': '$__all',
+ 'var-gpu': '$__all',
+ theme: 'light',
+ panelId: panelId,
+ });
+ return `${grafanaUrl}/d-solo/${GRAFANA_DASHBOARD_SLUG}?${params.toString()}&__feature.dashboardSceneSolo`;
+};
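+
+// Illustrative result (hypothetical Grafana host): for panelId '1',
+// cluster 'my-cluster' and the default 1h range, this yields roughly
+// <grafana>/d-solo/skypilot-dcgm-gpu/skypilot-dcgm-gpu-metrics?orgId=1&from=now-1h&to=now&...&panelId=1&__feature.dashboardSceneSolo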
+
+/**
+ * Reusable GPU Metrics Section component
+ *
+ * @param {Object} props
+ * @param {string} props.clusterNameOnCloud - The cluster name for filtering metrics
+ * @param {string} props.displayName - The name to show in the "Showing:" text
+ * @param {number} props.refreshTrigger - Increment to trigger iframe refresh
+ * @param {string} props.storageKey - LocalStorage key for expanded state
+ * @param {React.ReactNode} props.headerExtra - Optional extra content for header (e.g., task selector)
+ * @param {string} props.noMetricsMessage - Custom message when no metrics available
+ */
+export function GPUMetricsSection({
+ clusterNameOnCloud,
+ displayName,
+ refreshTrigger = 0,
+ storageKey = 'skypilot-gpu-metrics-expanded',
+ headerExtra = null,
+ noMetricsMessage = 'No GPU metrics available.',
+}) {
+ const [timeRange, setTimeRange] = useState({ from: 'now-1h', to: 'now' });
+ const [isExpanded, setIsExpanded] = useState(() => {
+ if (typeof window !== 'undefined') {
+ const saved = localStorage.getItem(storageKey);
+ return saved === 'true';
+ }
+ return false;
+ });
+
+ const handleTimeRangePreset = (preset) => {
+ setTimeRange({
+ from: `now-${preset}`,
+ to: 'now',
+ });
+ };
+
+ const toggleExpanded = () => {
+ const newValue = !isExpanded;
+ setIsExpanded(newValue);
+ if (typeof window !== 'undefined') {
+ localStorage.setItem(storageKey, String(newValue));
+ }
+ };
+
+ const openInGrafana = () => {
+ const queryParams = new URLSearchParams({
+ orgId: GRAFANA_ORG_ID,
+ from: timeRange.from,
+ to: timeRange.to,
+ timezone: 'browser',
+ 'var-cluster': clusterNameOnCloud,
+ 'var-node': '$__all',
+ 'var-gpu': '$__all',
+ });
+ window.open(
+ buildGrafanaUrl(`/d/${GRAFANA_DASHBOARD_SLUG}?${queryParams.toString()}`),
+ '_blank'
+ );
+ };
+
+ return (
+
+
+ );
+
+ // Should not show any plugin names
+ expect(container.textContent).not.toContain('HiddenPlugin1');
+ expect(container.textContent).not.toContain('HiddenPlugin2');
+
+ // Should still show commit info
+ expect(container.textContent).toContain('Core commit');
+ expect(container.textContent).toContain('core123');
+ });
+ });
+});
diff --git a/sky/dashboard/src/components/infra.jsx b/sky/dashboard/src/components/infra.jsx
index 022e43bff14..5e728b405dc 100755
--- a/sky/dashboard/src/components/infra.jsx
+++ b/sky/dashboard/src/components/infra.jsx
@@ -7,6 +7,7 @@ import React, { useState, useEffect, useCallback } from 'react';
import { CircularProgress } from '@mui/material';
import { Layout } from '@/components/elements/layout';
import {
+ AlertTriangleIcon,
RotateCwIcon,
SearchIcon,
XIcon,
@@ -151,6 +152,16 @@ const GpuUtilizationBar = ({
);
};
+// Skeleton badge for loading cells - replaces CircularProgress size={12}
+const SkeletonBadge = () => (
+
+
+
+);
+
// Reusable component for infrastructure sections (SSH Node Pool or Kubernetes)
export function InfrastructureSection({
title,
@@ -216,6 +227,16 @@ export function InfrastructureSection({
);
}
+ // Determine if table should show refreshing state
+ // For K8s: show during loading or when contexts haven't all loaded yet
+ // For SSH/Slurm: only show during loading
+ const isTableRefreshing =
+ !isInitialLoad &&
+ (isLoading ||
+ (!(isSlurm || isSSH) &&
+ safeContexts.length > 0 &&
+ !safeContexts.every((c) => loadedContexts.has(c))));
+
// Show table if we have contexts to display, even if some data is still loading
if (safeContexts.length > 0) {
return (
@@ -243,7 +264,11 @@ export function InfrastructureSection({
-
-
- {nodesInContext.map((node, index) => {
- // Format CPU display: "X of Y free" or just "Y" if free is unknown
- let cpuDisplay = '-';
+          {/* NOTE: header markup reconstructed; original element attributes
+              were lost. */}
+          {nodesInContext.length > 0 && (
+            <div>
+              <table>
+                <thead>
+                  <tr>
+                    <th>Node</th>
+                    {!isSlurm && (
+                      <>
+                        <th>IP Address</th>
+                        <th>vCPU</th>
+                        <th>Memory (GB)</th>
+                      </>
+                    )}
+                    <th>GPU</th>
+                    <th>GPU Utilization</th>
+                    <th>Node Status</th>
+                  </tr>
+                </thead>
+                <tbody>
+ {nodesInContext.map((node, index) => {
+ // Format CPU display: "X of Y free" or just "Y" if free is unknown
+ let cpuDisplay = '-';
+ if (
+ node.cpu_count !== null &&
+ node.cpu_count !== undefined
+ ) {
+ const cpuTotal = formatCpu(node.cpu_count);
if (
- node.cpu_count !== null &&
- node.cpu_count !== undefined
+ node.cpu_free !== null &&
+ node.cpu_free !== undefined
) {
- const cpuTotal = formatCpu(node.cpu_count);
- if (
- node.cpu_free !== null &&
- node.cpu_free !== undefined
- ) {
- const cpuFree = formatCpu(node.cpu_free);
- cpuDisplay = `${cpuFree} of ${cpuTotal} free`;
- } else {
- cpuDisplay = cpuTotal;
- }
+ const cpuFree = formatCpu(node.cpu_free);
+ cpuDisplay = `${cpuFree} of ${cpuTotal} free`;
+ } else {
+ cpuDisplay = cpuTotal;
}
-
- // Format memory display: "X of Y free" or just "Y" if free is unknown
- // (GB is in column header, so don't include it in values)
- let memoryDisplay = '-';
+ }
+
+ // Format memory display: "X of Y free" or just "Y" if free is unknown
+ // (GB is in column header, so don't include it in values)
+ let memoryDisplay = '-';
+ if (
+ node.memory_gb !== null &&
+ node.memory_gb !== undefined
+ ) {
+ const memoryTotal = node.memory_gb.toFixed(1);
if (
- node.memory_gb !== null &&
- node.memory_gb !== undefined
+ node.memory_free_gb !== null &&
+ node.memory_free_gb !== undefined
) {
- const memoryTotal = node.memory_gb.toFixed(1);
- if (
- node.memory_free_gb !== null &&
- node.memory_free_gb !== undefined
- ) {
- const memoryFree = node.memory_free_gb.toFixed(1);
- memoryDisplay = `${memoryFree} of ${memoryTotal} free`;
- } else {
- memoryDisplay = memoryTotal;
+ const memoryFree = node.memory_free_gb.toFixed(1);
+ memoryDisplay = `${memoryFree} of ${memoryTotal} free`;
+ } else {
+ memoryDisplay = memoryTotal;
+ }
+ }
+
+ // Build utilization string
+ const utilizationStr = `${node.gpu_free} of ${node.gpu_total} free`;
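+                  // e.g. "2 of 8 free" for a node with 8 GPUs of which 2
+                  // are unallocated.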
+
+ // Build node status string
+ const statusInfo = [];
+
+ // Add not ready info
+ if (node.is_ready === false) {
+ statusInfo.push('NotReady');
+ }
+
+ // Add cordoned info
+ if (node.is_cordoned === true) {
+ statusInfo.push('Cordoned');
+ }
+
+ // Build taint info separately
+ const taints = node.taints || [];
+ let taintInfo = null;
+ if (taints.length > 0) {
+ const taintsByEffect = {};
+ for (const taint of taints) {
+ const effect = taint.effect;
+ const key = taint.key;
+ if (!taintsByEffect[effect]) {
+ taintsByEffect[effect] = [];
}
+ taintsByEffect[effect].push(key);
}
+ const taintStrs = Object.entries(taintsByEffect).map(
+ ([effect, keys]) =>
+ `${effect} Taint [${keys.join(', ')}]`
+ );
+ if (taintStrs.length > 0) {
+ taintInfo = taintStrs.join(', ');
+ }
+ }
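+                  // Example: taints {key: 'gpu', effect: 'NoSchedule'} and
+                  // {key: 'dedicated', effect: 'NoSchedule'} render as
+                  // "NoSchedule Taint [gpu, dedicated]".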
- const utilizationStr =
- node.is_ready === false
- ? `0 of ${node.gpu_total} free (Node NotReady)`
- : `${node.gpu_free} of ${node.gpu_total} free`;
+ const nodeStatusStr =
+ statusInfo.length > 0 || taintInfo
+ ? statusInfo.join(', ')
+ : 'Healthy';
+ const isNodeHealthy = statusInfo.length === 0 && !taintInfo;
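+                  // When only taints are present, statusInfo is empty, so
+                  // nodeStatusStr is blank and taintInfo (presumably rendered
+                  // separately) carries the detail; isNodeHealthy still marks
+                  // the node unhealthy.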
- return (
-
)}
- {/* GPU Metrics Section - only show for k8s contexts, not SSH node pools */}
+ {/* GPU Metrics Section - only show for k8s contexts, not SSH node pools or Slurm */}
{isGrafanaAvailable &&
gpusInContext &&
gpusInContext.length > 0 &&
- !isSSHContext && (
+ !isSSHContext &&
+ !isSlurm && (
<>
GPU Metrics
@@ -2049,42 +2144,52 @@ export function GPUs() {
const gpuDataPromise = forceRefresh
? getContextGPUData(context)
: dashboardCache.get(getContextGPUData, [context]);
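+      // When forceRefresh is set, the promise above bypasses dashboardCache
+      // so a manual refresh always hits the API; otherwise cached
+      // per-context results are reused.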
- gpuDataPromise.then((gpuData) => {
- // Mark this context as loaded (even if it has no GPUs)
- setLoadedContexts((prev) => new Set([...prev, context]));
-
- // Update perContextGPUs - merge in data for this context
- setPerContextGPUs((prev) => {
- // Remove any existing entries for this context, then add new ones
- const filtered = prev.filter((gpu) => gpu.context !== context);
- return [...filtered, ...gpuData.perContextGPUs];
- });
+ gpuDataPromise
+ .then((gpuData) => {
+ // Mark this context as loaded (even if it has no GPUs)
+ setLoadedContexts((prev) => new Set([...prev, context]));
+
+ // Update perContextGPUs - merge in data for this context
+ setPerContextGPUs((prev) => {
+ // Remove any existing entries for this context, then add new ones
+ const filtered = prev.filter((gpu) => gpu.context !== context);
+ return [...filtered, ...gpuData.perContextGPUs];
+ });
- // Update perNodeGPUs - merge in data for this context
- setPerNodeGPUs((prev) => {
- const filtered = prev.filter((node) => node.context !== context);
- return [...filtered, ...gpuData.perNodeGPUs];
- });
+ // Update perNodeGPUs - merge in data for this context
+ setPerNodeGPUs((prev) => {
+ const filtered = prev.filter((node) => node.context !== context);
+ return [...filtered, ...gpuData.perNodeGPUs];
+ });
- // Note: allGPUs is computed via useEffect when perContextGPUs changes
+ // Note: allGPUs is computed via useEffect when perContextGPUs changes
- // Update context errors if there was an error
- if (gpuData.error) {
+ // Update context errors if there was an error
+ if (gpuData.error) {
+ setContextErrors((prev) => ({
+ ...prev,
+ [context]: gpuData.error,
+ }));
+ }
+ })
+ .catch((error) => {
+ // Mark context as loaded even on error to prevent infinite spinner
+ setLoadedContexts((prev) => new Set([...prev, context]));
setContextErrors((prev) => ({
...prev,
- [context]: gpuData.error,
+ [context]: error.message || 'Failed to load GPU data',
}));
- }
-
- // Decrement pending count and check if ALL fetches are complete
- pendingContextCountRef.current--;
- if (
- pendingContextCountRef.current === 0 &&
- mainFetchDoneRef.current
- ) {
- setIsFetching(false); // Everything done, stop spinner
- }
- });
+ })
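+          // finally() runs whether the fetch resolved or rejected, so the
+          // pending counter always reaches zero and the spinner cannot hang.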
+ .finally(() => {
+ // Decrement pending count and check if ALL fetches are complete
+ pendingContextCountRef.current--;
+ if (
+ pendingContextCountRef.current === 0 &&
+ mainFetchDoneRef.current
+ ) {
+ setIsFetching(false); // Everything done, stop spinner
+ }
+ });
});
} catch (error) {
console.error('Error in fetchKubernetesData:', error);
@@ -2320,6 +2425,7 @@ export function GPUs() {
};
initializeData();
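+    // Run once on mount by design; the lint rule is silenced below instead
+    // of adding deps that would re-trigger initialization.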
+ // eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
// Effect for interval refresh.
@@ -2691,6 +2797,7 @@ export function GPUs() {
gpusInContext={gpusInContext}
nodesInContext={nodesInContext}
gpuMetricsRefreshTrigger={gpuMetricsRefreshTrigger}
+ isSlurm={isSlurmCluster}
/>
);
};
@@ -2732,7 +2839,14 @@ export function GPUs() {
: `No enabled clouds for workspace "${selectedWorkspace}".`}
) : (
-
+
@@ -2747,7 +2861,7 @@ export function GPUs() {
-
+
{filteredCloudInfraData.map((cloud) => {
// Use separate loading states for progressive loading
// Clusters and jobs load independently (clusters often ready first)
@@ -2756,15 +2870,13 @@ export function GPUs() {
const jobCount = cloudJobCounts[cloud.name] ?? cloud.jobs;
return (
-