diff --git a/src/ssh/azext_ssh/_file_permission_utils.py b/src/ssh/azext_ssh/_file_permission_utils.py new file mode 100644 index 00000000000..a6d9a0623de --- /dev/null +++ b/src/ssh/azext_ssh/_file_permission_utils.py @@ -0,0 +1,32 @@ +import win32security +import ntsecuritycon as con +import platform +import os + +def set_certificate_permissions(path): + if platform.system() != 'Windows': + os.chmod(path, 0o644) + return + + sd = win32security.GetFileSecurity( + path, + win32security.DACL_SECURITY_INFORMATION + ) + + admSid = win32security.LookupAccountName("", "Administrators")[0] + systemSid = win32security.LookupAccountName("", "SYSTEM")[0] + ownerSid = _get_owner_string(path) + + dacl = win32security.ACL() + + dacl.AddAccessAllowedAce(win32security.ACL_REVISION, con.FILE_ALL_ACCESS, admSid) + dacl.AddAccessAllowedAce(win32security.ACL_REVISION, con.FILE_ALL_ACCESS, systemSid) + dacl.AddAccessAllowedAce(win32security.ACL_REVISION, con.FILE_ALL_ACCESS, ownerSid) + + sd.SetSecurityDescriptorDacl(1, dacl, 0) + win32security.SetFileSecurity(path, win32security.DACL_SECURITY_INFORMATION, sd) + + +def _get_owner_string(path): + sd = win32security.GetFileSecurity(path, win32security.OWNER_SECURITY_INFORMATION) + return sd.GetSecurityDescriptorOwner() \ No newline at end of file diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index 34054ebd3fc..b384a645ae4 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -9,7 +9,6 @@ import tempfile import time import platform -import oschmod from knack import log from azure.cli.core import azclierror @@ -26,6 +25,7 @@ from . import constants as const from . import resource_type_utils from . import target_os_utils +from . import _file_permission_utils logger = log.get_logger(__name__) @@ -345,7 +345,7 @@ def _check_or_create_public_private_files(public_key_file, private_key_file, cre def _write_cert_file(certificate_contents, cert_file): with open(cert_file, 'w', encoding='utf-8') as f: f.write(f"ssh-rsa-cert-v01@openssh.com {certificate_contents}") - oschmod.set_mode(cert_file, 0o644) + _file_permission_utils.set_certificate_permissions(cert_file) return cert_file diff --git a/src/ssh/azext_ssh/ssh_info.py b/src/ssh/azext_ssh/ssh_info.py index ba4b986768d..c84487eb562 100644 --- a/src/ssh/azext_ssh/ssh_info.py +++ b/src/ssh/azext_ssh/ssh_info.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------------------------- import os import datetime -import oschmod from azure.cli.core.style import Style, print_styled_text @@ -12,6 +11,7 @@ from knack import log from . import file_utils from . import connectivity_utils +from . import _file_permission_utils logger = log.get_logger(__name__) @@ -192,7 +192,9 @@ def _create_relay_info_file(self): file_utils.delete_file(relay_info_path, f"{relay_info_path} already exists, and couldn't be overwritten.") file_utils.write_to_file(relay_info_path, 'w', connectivity_utils.format_relay_info_string(self.relay_info), f"Couldn't write relay information to file {relay_info_path}.", 'utf-8') - oschmod.set_mode(relay_info_path, 0o644) + #oschmod.set_mode(relay_info_path, 0o644) + _file_permission_utils.set_certificate_permissions(relay_info_path) + # pylint: disable=broad-except try: # pylint: disable=unsubscriptable-object diff --git a/src/ssh/azext_ssh/tests/latest/test_custom.py b/src/ssh/azext_ssh/tests/latest/test_custom.py index 3f1951c5bfb..5529154dece 100644 --- a/src/ssh/azext_ssh/tests/latest/test_custom.py +++ b/src/ssh/azext_ssh/tests/latest/test_custom.py @@ -301,14 +301,14 @@ def test_check_or_create_public_private_files_no_private(self, mock_join, mock_i @mock.patch('builtins.open') - @mock.patch('oschmod.set_mode') + @mock.patch('azext_ssh._file_permission_utils.set_certificate_permissions') def test_write_cert_file(self, mock_mode, mock_open): mock_file = mock.Mock() mock_open.return_value.__enter__.return_value = mock_file custom._write_cert_file("cert", "publickey-aadcert.pub") - mock_mode.assert_called_once_with("publickey-aadcert.pub", 0o644) + mock_mode.assert_called_once_with("publickey-aadcert.pub") mock_open.assert_called_once_with("publickey-aadcert.pub", 'w', encoding='utf-8') mock_file.write.assert_called_once_with("ssh-rsa-cert-v01@openssh.com cert")