Skip to content
116 changes: 88 additions & 28 deletions src/azure-cli/azure/cli/command_modules/vm/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -1419,12 +1419,12 @@ def get_vm_to_update_by_aaz(cmd, resource_group_name, vm_name):
from .operations.vm import VMShow

vm = VMShow(cli_ctx=cmd.cli_ctx)(command_args={
'resource_group': resource_group_name,
"resource_group": resource_group_name,
"vm_name": vm_name
})

# To avoid unnecessary permission check of image
storage_profile = vm.get('storageProfile', {})
storage_profile = vm.get("storageProfile", {})
storage_profile["imageReference"] = None

return vm
Expand Down Expand Up @@ -1739,6 +1739,40 @@ def set_vm(cmd, instance, lro_operation=None, no_wait=False):
return LongRunningOperation(cmd.cli_ctx)(poller)


# Notes: vm format is in snake_case
def set_vm_by_aaz(cmd, vm, no_wait=False):
from .aaz.latest.vm import Create as _VMCreate

parsed_id = _parse_rg_name(vm["id"])
vm["resource_group"] = parsed_id[0]
vm["vm_name"] = parsed_id[1]
vm["no_wait"] = no_wait

class SetVM(_VMCreate):
def _output(self, *args, **kwargs):
from azure.cli.core.aaz import AAZUndefined, has_value

# Resolve flatten conflict
# When the type field conflicts, the type in inner layer is ignored and the outer layer is applied
if has_value(self.ctx.vars.instance.resources):
for resource in self.ctx.vars.instance.resources:
if has_value(resource.type):
resource.type = AAZUndefined

result = self.deserialize_output(self.ctx.vars.instance, client_flatten=True)
if result.get('osProfile', {}).get('secrets', []):
for secret in result['osProfile']['secrets']:
for cert in secret.get('vaultCertificates', []):
if not cert.get('certificateStore'):
cert['certificateStore'] = None
return result

vm = LongRunningOperation(cmd.cli_ctx)(
SetVM(cli_ctx=cmd.cli_ctx)(command_args=vm))

return vm


def patch_vm(cmd, resource_group_name, vm_name, vm):
client = _compute_client_factory(cmd.cli_ctx)
poller = client.virtual_machines.begin_update(resource_group_name, vm_name, vm)
Expand Down Expand Up @@ -3288,51 +3322,75 @@ def get_vm_format_secret(cmd, secrets, certificate_store=None, keyvault=None, re
def add_vm_secret(cmd, resource_group_name, vm_name, keyvault, certificate, certificate_store=None):
from azure.mgmt.core.tools import parse_resource_id
from ._vm_utils import create_data_plane_keyvault_certificate_client, get_key_vault_base_url
VaultSecretGroup, SubResource, VaultCertificate = cmd.get_models(
'VaultSecretGroup', 'SubResource', 'VaultCertificate')
vm = get_vm_to_update(cmd, resource_group_name, vm_name)
from .operations.vm import convert_show_result_to_snake_case
vm = get_vm_to_update_by_aaz(cmd, resource_group_name, vm_name)
vm = convert_show_result_to_snake_case(vm)

if '://' not in certificate: # has a cert name rather a full url?
keyvault_client = create_data_plane_keyvault_certificate_client(
cmd.cli_ctx, get_key_vault_base_url(cmd.cli_ctx, parse_resource_id(keyvault)['name']))
cert_info = keyvault_client.get_certificate(certificate)
certificate = cert_info.secret_id

if not _is_linux_os(vm):
if not _is_linux_os_by_aaz(vm):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the _is_linux_os_by_aaz function is based on the snake case, so we might need to convert the vm to snake case first

certificate_store = certificate_store or 'My'
elif certificate_store:
raise CLIError('Usage error: --certificate-store is only applicable on Windows VM')
vault_cert = VaultCertificate(certificate_url=certificate, certificate_store=certificate_store)
vault_secret_group = next((x for x in vm.os_profile.secrets
if x.source_vault and x.source_vault.id.lower() == keyvault.lower()), None)
vault_cert = {
'certificate_store': certificate_store,
'certificate_url': certificate
}
vault_secret_group = next((x for x in vm.get('os_profile', {}).get('secrets', [])
if x.get('source_vault', {}).get('id', '').lower() == keyvault.lower()), None)
if vault_secret_group:
vault_secret_group.vault_certificates.append(vault_cert)
certs = vault_secret_group.get('vault_certificates', [])
certs.append(vault_cert)
vault_secret_group['vault_certificates'] = certs
else:
vault_secret_group = VaultSecretGroup(source_vault=SubResource(id=keyvault), vault_certificates=[vault_cert])
vm.os_profile.secrets.append(vault_secret_group)
vm = set_vm(cmd, vm)
return vm.os_profile.secrets
vault_secret_group = {
'source_vault': {
'id': keyvault
},
'vault_certificates': [vault_cert]
}

if not vm.get('os_profile'):
vm['os_profile'] = {'secret': []}

if not vm.get('os_profile').get('secrets'):
vm['os_profile']['secrets'] = []

vm['os_profile']['secrets'].append(vault_secret_group)

vm = set_vm_by_aaz(cmd, vm)
return vm.get('osProfile', {}).get('secrets', [])


def list_vm_secrets(cmd, resource_group_name, vm_name):
vm = get_vm(cmd, resource_group_name, vm_name)
if vm.os_profile:
return vm.os_profile.secrets
return []
vm = get_vm_by_aaz(cmd, resource_group_name, vm_name)

if vm.get('osProfile', {}).get('secrets', []):
for secret in vm['osProfile']['secrets']:
for cert in secret.get('vaultCertificates', []):
if not cert.get('certificateStore'):
cert['certificateStore'] = None

return vm.get('osProfile', {}).get('secrets', [])


def remove_vm_secret(cmd, resource_group_name, vm_name, keyvault, certificate=None):
vm = get_vm_to_update(cmd, resource_group_name, vm_name)
from .operations.vm import convert_show_result_to_snake_case
vm = get_vm_to_update_by_aaz(cmd, resource_group_name, vm_name)

# support 2 kinds of filter:
# a. if only keyvault is supplied, we delete its whole vault group.
# b. if both keyvault and certificate are supplied, we only delete the specific cert entry.

to_keep = vm.os_profile.secrets
to_keep = vm.get('osProfile', {}).get('secrets', [])
keyvault_matched = []
if keyvault:
keyvault = keyvault.lower()
keyvault_matched = [x for x in to_keep if x.source_vault and x.source_vault.id.lower() == keyvault]
keyvault_matched = [x for x in to_keep if x.get('sourceVault', {}).get('id', '').lower() == keyvault]

if keyvault and not certificate:
to_keep = [x for x in to_keep if x not in keyvault_matched]
Expand All @@ -3342,13 +3400,15 @@ def remove_vm_secret(cmd, resource_group_name, vm_name, keyvault, certificate=No
if '://' not in cert_url_pattern: # just a cert name?
cert_url_pattern = '/' + cert_url_pattern + '/'
for x in temp:
x.vault_certificates = ([v for v in x.vault_certificates
if not (v.certificate_url and cert_url_pattern in v.certificate_url.lower())])
to_keep = [x for x in to_keep if x.vault_certificates] # purge all groups w/o any cert entries

vm.os_profile.secrets = to_keep
vm = set_vm(cmd, vm)
return vm.os_profile.secrets
x['vaultCertificates'] = [v for v in x.get('vaultCertificates')
if not (v.get('certificateUrl') and
cert_url_pattern in v.get('certificateUrl', '').lower())]
to_keep = [x for x in to_keep if x.get('vaultCertificates')] # purge all groups w/o any cert entries

vm['osProfile']['secrets'] = to_keep
vm = convert_show_result_to_snake_case(vm)
vm = set_vm_by_aaz(cmd, vm)
return vm.get('osProfile', {}).get('secrets', [])
# endregion


Expand Down
2 changes: 2 additions & 0 deletions src/azure-cli/azure/cli/command_modules/vm/operations/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ def __call__(self, *args, **kwargs):

def convert_show_result_to_snake_case(result):
new_result = {}
if "id" in result:
new_result["id"] = result["id"]
if "extendedLocation" in result:
new_result["extended_location"] = result["extendedLocation"]
if "identity" in result:
Expand Down
Loading