Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion sky/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,23 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:


# ---------------------------------- RunPod ---------------------------------- #
def _runpod_key_label() -> str:
"""Per-user identifier to stamp as the registered SSH key's comment.

RunPod derives a key's display name from its comment, reading only the
first whitespace-separated token after the key material, so the label must
contain no spaces.
"""
user_hash = common_utils.get_user_hash()
try:
username = common_utils.get_cleaned_username()
except Exception: # pylint: disable=broad-except
username = ''
if username:
return f'skypilot-{username}-{user_hash[:8]}'
return f'skypilot-{user_hash[:8]}'


def setup_runpod_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
"""Sets up SSH authentication for RunPod.
- Generates a new SSH key pair if one does not exist.
Expand All @@ -347,7 +364,10 @@ def setup_runpod_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
_, public_key_path = auth_utils.get_or_generate_keys()
with open(public_key_path, 'r', encoding='UTF-8') as pub_key_file:
public_key = pub_key_file.read().strip()
runpod.runpod.cli.groups.ssh.functions.add_ssh_key(public_key)
# Add a label to the public key so that it can be identified in the RunPod
# dashboard.
labeled_key = ' '.join(public_key.split()[:2] + [_runpod_key_label()])
runpod.runpod.cli.groups.ssh.functions.add_ssh_key(labeled_key)

return configure_ssh_info(config)

Expand Down
62 changes: 62 additions & 0 deletions tests/unit_tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import os
from unittest.mock import MagicMock
from unittest.mock import patch

from google.auth import exceptions as google_exceptions
Expand Down Expand Up @@ -110,3 +111,64 @@ def test_gcp_project_metadata_parsing_malformed():
# and should return 'False' (default)
result = auth.parse_gcp_project_oslogin(malformed_project)
assert result == 'False'


def test_runpod_key_label_with_username():
with patch('sky.utils.common_utils.get_user_hash',
return_value='abcdef1234567890'), \
patch('sky.utils.common_utils.get_cleaned_username',
return_value='alice'):
assert auth._runpod_key_label() == 'skypilot-alice-abcdef12'


def test_runpod_key_label_falls_back_when_username_unresolved():
with patch('sky.utils.common_utils.get_user_hash',
return_value='abcdef1234567890'), \
patch('sky.utils.common_utils.get_cleaned_username',
side_effect=RuntimeError('no current user')):
assert auth._runpod_key_label() == 'skypilot-abcdef12'


def test_runpod_key_label_falls_back_when_username_empty():
with patch('sky.utils.common_utils.get_user_hash',
return_value='abcdef1234567890'), \
patch('sky.utils.common_utils.get_cleaned_username', return_value=''):
assert auth._runpod_key_label() == 'skypilot-abcdef12'


def _register_runpod_key(tmp_path, key_content):
"""Runs setup_runpod_authentication and returns the key string that would
be registered with RunPod."""
pub_key_path = tmp_path / 'sky-key.pub'
pub_key_path.write_text(key_content)
mock_runpod = MagicMock()
with patch('sky.authentication.auth_utils.get_or_generate_keys',
return_value=('priv', str(pub_key_path))), \
patch('sky.authentication.runpod', mock_runpod), \
patch('sky.authentication.configure_ssh_info',
side_effect=lambda config: config), \
patch('sky.utils.common_utils.get_user_hash',
return_value='abcdef1234567890'), \
patch('sky.utils.common_utils.get_cleaned_username',
return_value='alice'):
auth.setup_runpod_authentication({'auth': {}})
add_ssh_key = mock_runpod.runpod.cli.groups.ssh.functions.add_ssh_key
add_ssh_key.assert_called_once()
return add_ssh_key.call_args[0][0]


def test_setup_runpod_authentication_labels_bare_key(tmp_path):
registered = _register_runpod_key(tmp_path, 'ssh-rsa AAAAB3NzaC1yc2E\n')
parts = registered.split(' ')
assert parts[:2] == ['ssh-rsa', 'AAAAB3NzaC1yc2E'] # material untouched
assert parts[2] == 'skypilot-alice-abcdef12'
assert len(parts) == 3 # single-token comment for RunPod's parser


def test_setup_runpod_authentication_replaces_existing_comment(tmp_path):
registered = _register_runpod_key(
tmp_path, 'ssh-rsa AAAAB3NzaC1yc2E old@host trailing words\n')
parts = registered.split(' ')
assert parts[:2] == ['ssh-rsa', 'AAAAB3NzaC1yc2E']
assert parts[2] == 'skypilot-alice-abcdef12'
assert len(parts) == 3
Loading