Skip to content

[Test] Refine backward compat test on API version change #5276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 72 additions & 60 deletions tests/smoke_tests/smoke_tests_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,8 @@ class Test(NamedTuple):
timeout: int = DEFAULT_CMD_TIMEOUT
# Environment variables to set for each command.
env: Optional[Dict[str, str]] = None
# Skip this test if the given command succeeds (exits with 0).
skip_if_command_met: Optional[str] = None

def echo(self, message: str):
# pytest's xdist plugin captures stdout; print to stderr so that the
Expand Down Expand Up @@ -391,7 +393,7 @@ def run_one_test(test: Test) -> None:
# Update the environment variable to use the temporary file
env_dict[skypilot_config.ENV_VAR_SKYPILOT_CONFIG] = temp_config.name

for command in test.commands:
def run_one_command(command: str, raise_on_timeout: bool = False):
write(f'+ {command}\n')
flush()
proc = subprocess.Popen(
Expand All @@ -412,17 +414,30 @@ def run_one_test(test: Test) -> None:
flush()
# Kill the current process.
proc.terminate()
if raise_on_timeout:
raise e
proc.returncode = 1 # None if we don't set it.
break
return proc.returncode

if test.skip_if_command_met is not None:
write('Checking the precondition of this test...\n')
flush()
if run_one_command(test.skip_if_command_met,
raise_on_timeout=True) == 0:
test.echo('Skipping this test...')
return

if proc.returncode:
returncode = 0
for command in test.commands:
returncode = run_one_command(command)
if returncode:
break

style = colorama.Style
fore = colorama.Fore
outcome = (f'{fore.RED}Failed{style.RESET_ALL} (returned {proc.returncode})'
if proc.returncode else f'{fore.GREEN}Passed{style.RESET_ALL}')
reason = f'\nReason: {command}' if proc.returncode else ''
outcome = (f'{fore.RED}Failed{style.RESET_ALL} (returned {returncode})'
if returncode else f'{fore.GREEN}Passed{style.RESET_ALL}')
reason = f'\nReason: {command}' if returncode else ''
msg = (f'{outcome}.'
f'{reason}')
if log_to_stdout:
Expand All @@ -432,7 +447,7 @@ def run_one_test(test: Test) -> None:
test.echo(msg)
write(msg)

if (proc.returncode == 0 or
if (returncode == 0 or
pytest.terminate_on_failure) and test.teardown is not None:
subprocess_utils.run(
test.teardown,
Expand All @@ -443,13 +458,62 @@ def run_one_test(test: Test) -> None:
env=env_dict,
)

if proc.returncode:
if returncode:
if log_to_stdout:
raise Exception(f'test failed')
else:
raise Exception(f'test failed: less -r {log_file.name}')


def get_api_server_url() -> str:
    """Return the API server URL that the smoke tests should talk to.

    When the ``PYTEST_SKYPILOT_REMOTE_SERVER_TEST`` environment variable is
    present (its value is not inspected), the endpoint reported by
    ``docker_utils`` is used; otherwise the URL comes from ``server_common``.
    """
    remote_server_test = 'PYTEST_SKYPILOT_REMOTE_SERVER_TEST' in os.environ
    if not remote_server_test:
        return server_common.get_server_url()
    return docker_utils.get_api_server_endpoint_inside_docker()


def get_dashboard_cluster_status_request_id() -> str:
    """Submit a dashboard cluster-status request and return its request ID.

    The request is posted for all users. Only the request ID is returned;
    the status itself has to be retrieved separately with that ID.
    """
    status_body = payloads.StatusBody(all_users=True)
    endpoint = f'{get_api_server_url()}/internal/dashboard/status'
    # Round-trip through JSON text so the payload is plain JSON-compatible data.
    payload = json.loads(status_body.model_dump_json())
    response = requests.post(endpoint, json=payload)
    return server_common.get_request_id(response)


def get_dashboard_jobs_queue_request_id() -> str:
    """Submit a dashboard jobs-queue request and return its request ID.

    The request is posted for all users. Only the request ID is returned;
    the queue contents have to be retrieved separately with that ID.
    """
    queue_body = payloads.JobsQueueBody(all_users=True)
    endpoint = f'{get_api_server_url()}/internal/dashboard/jobs/queue'
    # Round-trip through JSON text so the payload is plain JSON-compatible data.
    payload = json.loads(queue_body.model_dump_json())
    response = requests.post(endpoint, json=payload)
    return server_common.get_request_id(response)


def get_response_from_request_id(request_id: str) -> Any:
    """Fetch the return value of a request via the dashboard API endpoint.

    Args:
        request_id: The request ID of the request to get.

    Returns:
        The value produced by ``get_return_value()`` on the decoded request.

    Raises:
        RuntimeError: If the server responds with a status code other
            than 200.
        Exception: Anything raised by the HTTP call (which uses a 15 second
            timeout), by decoding the payload, or by ``get_return_value()``
            (presumably the request's own failure -- confirm against
            ``requests_lib``).
    """
    response = requests.get(
        f'{get_api_server_url()}/internal/dashboard/api/get',
        # Let requests build the query string so the ID is URL-encoded.
        params={'request_id': request_id},
        timeout=15)
    if response.status_code != 200:
        raise RuntimeError(f'Failed to get request {request_id}: '
                           f'{response.status_code} {response.text}')
    request_task = requests_lib.Request.decode(
        requests_lib.RequestPayload(**response.json()))
    return request_task.get_return_value()


def get_aws_region_for_quota_failover() -> Optional[str]:
candidate_regions = AWS.regions_with_offering(instance_type='p3.16xlarge',
accelerators=None,
Expand Down Expand Up @@ -647,55 +711,3 @@ def _context_func(original_cmd: str, factor: float = 2):
finally:
for file in files:
os.unlink(file)


def get_api_server_url() -> str:
    """Return the API server URL that the smoke tests should talk to.

    When the ``PYTEST_SKYPILOT_REMOTE_SERVER_TEST`` environment variable is
    present (its value is not inspected), the endpoint reported by
    ``docker_utils`` is used; otherwise the URL comes from ``server_common``.
    """
    remote_server_test = 'PYTEST_SKYPILOT_REMOTE_SERVER_TEST' in os.environ
    if not remote_server_test:
        return server_common.get_server_url()
    return docker_utils.get_api_server_endpoint_inside_docker()


def get_dashboard_cluster_status_request_id() -> str:
    """Submit a dashboard cluster-status request and return its request ID.

    The request is posted for all users. Only the request ID is returned;
    the status itself has to be retrieved separately with that ID.
    """
    status_body = payloads.StatusBody(all_users=True)
    endpoint = f'{get_api_server_url()}/internal/dashboard/status'
    # Round-trip through JSON text so the payload is plain JSON-compatible data.
    payload = json.loads(status_body.model_dump_json())
    response = requests.post(endpoint, json=payload)
    return server_common.get_request_id(response)


def get_dashboard_jobs_queue_request_id() -> str:
    """Submit a dashboard jobs-queue request and return its request ID.

    The request is posted for all users. Only the request ID is returned;
    the queue contents have to be retrieved separately with that ID.
    """
    queue_body = payloads.JobsQueueBody(all_users=True)
    endpoint = f'{get_api_server_url()}/internal/dashboard/jobs/queue'
    # Round-trip through JSON text so the payload is plain JSON-compatible data.
    payload = json.loads(queue_body.model_dump_json())
    response = requests.post(endpoint, json=payload)
    return server_common.get_request_id(response)


def get_response_from_request_id(request_id: str) -> Any:
    """Fetch the return value of a request via the dashboard API endpoint.

    Args:
        request_id: The request ID of the request to get.

    Returns:
        The value produced by ``get_return_value()`` on the decoded request.

    Raises:
        RuntimeError: If the server responds with a status code other
            than 200.
        Exception: Anything raised by the HTTP call (which uses a 15 second
            timeout), by decoding the payload, or by ``get_return_value()``
            (presumably the request's own failure -- confirm against
            ``requests_lib``).
    """
    response = requests.get(
        f'{get_api_server_url()}/internal/dashboard/api/get',
        # Let requests build the query string so the ID is URL-encoded.
        params={'request_id': request_id},
        timeout=15)
    if response.status_code != 200:
        raise RuntimeError(f'Failed to get request {request_id}: '
                           f'{response.status_code} {response.text}')
    request_task = requests_lib.Request.decode(
        requests_lib.RequestPayload(**response.json()))
    return request_task.get_return_value()
26 changes: 17 additions & 9 deletions tests/smoke_tests/test_backward_compat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pathlib
import subprocess
from typing import Sequence
from typing import Optional, Sequence

import pytest
from smoke_tests import smoke_tests_utils
Expand Down Expand Up @@ -109,15 +109,19 @@ def session_setup(self, request):
yield # Optional teardown logic
self._run_cmd(f'{self.ACTIVATE_CURRENT} && sky api stop',)

def run_compatibility_test(self,
                           test_name: str,
                           commands: list,
                           teardown: str,
                           skip_if: Optional[str] = None):
    """Build a smoke test from the given commands and run it.

    Args:
        test_name: Name given to the test.
        commands: Shell commands to run in order.
        teardown: Shell command handed to the test as its teardown.
        skip_if: Optional shell command; when given, it is forwarded as the
            test's skip precondition (the test is skipped if the command
            exits with 0).
    """
    test = smoke_tests_utils.Test(
        test_name,
        commands,
        teardown=teardown,
        timeout=self.TEST_TIMEOUT,
        env=smoke_tests_utils.LOW_CONTROLLER_RESOURCE_ENV,
        # The Test tuple declares this field as ``skip_if_command_met``;
        # passing ``skip_if=`` would raise TypeError on construction.
        skip_if_command_met=skip_if,
    )
    smoke_tests_utils.run_one_test(test)

Expand Down Expand Up @@ -333,13 +337,14 @@ def test_client_server_compatibility(self, generic_cloud: str):
"""Test client server compatibility across versions"""
cluster_name = smoke_tests_utils.get_cluster_name()
job_name = f"{cluster_name}-job"
commands = [
# Check API version compatibility
# If the API version is bumped, the incompatibility is expected
# and we just skip the test.
skip_if = (
Copy link
Collaborator

@zpoint zpoint Apr 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering if we actually need to introduce the skip_if parameter. It's not a straightforward boolean parameter, and you'd need to read the source code to understand how skip_if behaves.

An alternative could be

# pseudo code
BASE_SERVER_VERSION = subprocess.popen('self.ACTIVATE_BASE && python -c "print(sky.__server_version__)"')
CURRENT_CLIENT_VERSION=subprocess.popen('{self.ACTIVATE_CURRENT} && python -c "print(sky.__client_version__)"'')
if mismatch(BASE_SERVER_VERSION, CURRENT_CLIENT_VERSION):
    pytest.skip()

...
self.run_compatibility_test(cluster_name,
                                    commands,
                                    teardown)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about making skip_if a callable that returns a boolean? I intended to use a sky command to check compatibility so that the implementation details of compatibility can be kept in a black box for the smoke test.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that sounds better that way.

f'{self.ACTIVATE_BASE} && {self.SKY_API_RESTART} && '
f'{self.ACTIVATE_CURRENT} && result="$(sky status 2>&1)" || true; '
'if echo "$result" | grep -q "SkyPilot API server is too old"; then '
' echo "$result" && exit 1; '
'fi',
'echo "$result" | grep -q -e "version mismatch" -e "too old" && '
'echo "API version bumped, skip compatibility test"')
commands = [
# managed job test
f'{self.ACTIVATE_BASE} && {self.SKY_API_RESTART} && '
f'sky jobs launch -d --cloud {generic_cloud} -y {smoke_tests_utils.LOW_RESOURCE_ARG} -n {job_name} "echo hello world; sleep 60"',
Expand Down Expand Up @@ -373,4 +378,7 @@ def test_client_server_compatibility(self, generic_cloud: str):
]

teardown = f'{self.ACTIVATE_BASE} && sky down {cluster_name} -y && sky serve down {cluster_name}* -y'
self.run_compatibility_test(cluster_name, commands, teardown)
self.run_compatibility_test(cluster_name,
commands,
teardown,
skip_if=skip_if)