Support wss:// websocket URLs for api-managed SSH proxying [WIP] #5116

Open
wants to merge 1 commit into
base: master
6 changes: 3 additions & 3 deletions sky/cli.py
@@ -162,7 +162,7 @@ def _get_cluster_records_and_set_ssh_config(
# updating skypilot.
f'\'{sys.executable} {sky.__root_dir__}/templates/'
f'websocket_proxy.py '
- f'{server_common.get_server_url().split("://")[1]} '
+ f'{server_common.get_server_url().replace("http://", "ws://").replace("https://", "wss://").split("://")[1]} '
Collaborator:

It seems that, because of the split, the URL passed to the executable is not affected by the replacement: the split discards the scheme that was just rewritten. Should we remove the split and do the replacement in sky/templates/websocket_proxy.py instead?
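A minimal sketch of the point above, using a made-up server URL (the hostname and port are assumptions for illustration only). It applies the same chain of string operations as the new line in the diff and compares the result with the original expression:

```python
# Hypothetical server URL, used only for illustration.
server_url = 'https://api.example.com:46580'

# The chain from the new line in the diff: rewrite the scheme, then split it off.
rewritten = server_url.replace('http://', 'ws://').replace('https://', 'wss://')
passed_to_proxy = rewritten.split('://')[1]

# The expression from the old line: split the scheme off directly.
original_arg = server_url.split('://')[1]

print(rewritten)
print(passed_to_proxy == original_arg)
```

Both expressions yield the same host string, so the proxy script never sees whether the server used http or https.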

Contributor Author:

I see what you mean. I left the split there, and it shouldn't be! What's the best way to set this up for local testing on macOS? I tried building the Dockerfile, but it just pulls the latest sky CLI nightly with pip...

My thinking was to make this backwards-compatible: a user who still has a hostname without a protocol in their config can keep using sky/templates/websocket_proxy.py as-is (there's a check for :// in the hostname), while any config created after this PR is merged will include either http:// or https://, which tells sky/templates/websocket_proxy.py which websocket prefix to use.

Collaborator:

Good point! For backward compatibility, I think we can handle it in websocket_proxy.py: check whether the URL includes a protocol, and default to http if it does not.
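One way this could look, as a rough sketch rather than the PR's actual code: a helper that maps http to ws and https to wss, and falls back to ws when the argument carries no scheme. The function name `build_websocket_url` and the scheme table are hypothetical; only the `/kubernetes-pod-ssh-proxy?cluster_name=` path comes from the diff.

```python
# Hypothetical helper, not part of the PR. Maps the server URL scheme to the
# matching websocket scheme; already-websocket schemes pass through unchanged.
_WS_SCHEMES = {'http': 'ws', 'https': 'wss', 'ws': 'ws', 'wss': 'wss'}


def build_websocket_url(server_url: str, cluster_name: str) -> str:
    server_url = server_url.strip('/')
    if '://' in server_url:
        scheme, rest = server_url.split('://', 1)
        if scheme.lower() not in _WS_SCHEMES:
            raise ValueError(f'Unsupported scheme: {scheme!r}')
        ws_scheme = _WS_SCHEMES[scheme.lower()]
    else:
        # Older configs store a bare host[:port]; keep the previous ws:// default.
        ws_scheme, rest = 'ws', server_url
    return (f'{ws_scheme}://{rest}/kubernetes-pod-ssh-proxy'
            f'?cluster_name={cluster_name}')


print(build_websocket_url('https://api.example.com', 'my-cluster'))
print(build_websocket_url('127.0.0.1:46580', 'my-cluster'))
```

With this shape, sky/cli.py could pass the server URL through unmodified, and the scheme decision would live in one place.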

Collaborator:

For testing, a simple approach might be to mock is_api_server_local so that it always returns False:

@annotations.lru_cache(scope='global')
def is_api_server_local():
    return get_server_url() in AVAILABLE_LOCAL_API_SERVER_URLS
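A self-contained sketch of that mocking idea with `unittest.mock`. Everything here is a stand-in: the `server_common` namespace, the URL list value, `functools.lru_cache` in place of `annotations.lru_cache`, and the `uses_websocket_proxy` caller are all hypothetical, chosen so the snippet runs without SkyPilot installed.

```python
import functools
import types
from unittest import mock

# Assumed value, for illustration only.
AVAILABLE_LOCAL_API_SERVER_URLS = ['http://127.0.0.1:46580']


def get_server_url():
    return 'http://127.0.0.1:46580'


@functools.lru_cache(maxsize=None)  # stand-in for annotations.lru_cache
def is_api_server_local():
    return get_server_url() in AVAILABLE_LOCAL_API_SERVER_URLS


# Hypothetical stand-in for the module that exposes is_api_server_local.
server_common = types.SimpleNamespace(is_api_server_local=is_api_server_local)


def uses_websocket_proxy():
    # Hypothetical caller. It looks the function up on the namespace at call
    # time, so the patch below takes effect and the cached value is bypassed.
    return not server_common.is_api_server_local()


with mock.patch.object(server_common, 'is_api_server_local',
                       return_value=False):
    patched_result = uses_websocket_proxy()

unpatched_result = uses_websocket_proxy()
print(patched_result, unpatched_result)
```

Patching the attribute on the module object, rather than clearing the cache, keeps the override scoped to the `with` block.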

Contributor Author:

What I'm really after in terms of testing advice is how to build the CLI locally. I was hoping to do it in a Dockerfile to keep dependencies etc. clean, but the Dockerfile seems to just pull the latest nightly with pip. Is there a "best practice" way to test this locally?

f'{handle.cluster_name}\'')
credentials['ssh_proxy_command'] = proxy_command
cluster_utils.SSHConfigHelper.add_cluster(
@@ -1796,7 +1796,6 @@ def status(verbose: bool, refresh: bool, ip: bool, endpoints: bool,
skip_finished=True,
all_users=all_users)
show_endpoints = endpoints or endpoint is not None
- show_single_endpoint = endpoint is not None
show_services = show_services and not any([clusters, ip, endpoints])
if show_services:
# Run the sky serve service query in parallel to speed up the
@@ -1835,7 +1834,7 @@ def status(verbose: bool, refresh: bool, ip: bool, endpoints: bool,
property='IP address' if ip else 'endpoint(s)',
flag='ip' if ip else
('endpoint port'
- if show_single_endpoint else 'endpoints')))
+ if endpoint is not None else 'endpoints')))
else:
click.echo(f'{colorama.Fore.CYAN}{colorama.Style.BRIGHT}Clusters'
f'{colorama.Style.RESET_ALL}')
@@ -2075,6 +2074,7 @@ def _get_job_queue(cluster):
fg='yellow')


@cli.command()
+ @cli.command()
@click.option(
'--sync-down',
8 changes: 6 additions & 2 deletions sky/templates/websocket_proxy.py
@@ -59,6 +59,10 @@ async def websocket_to_stdout(websocket):

if __name__ == '__main__':
server_url = sys.argv[1].strip('/')
- websocket_url = (f'ws://{server_url}/kubernetes-pod-ssh-proxy'
- f'?cluster_name={sys.argv[2]}')
+ if '://' in server_url:
+ websocket_url = (f'{server_url}/kubernetes-pod-ssh-proxy'
+ f'?cluster_name={sys.argv[2]}')
+ else:
+ websocket_url = (f'ws://{server_url}/kubernetes-pod-ssh-proxy'
+ f'?cluster_name={sys.argv[2]}')
asyncio.run(main(websocket_url))