Skip to content

Commit 04a692a

Browse files
[Extension] Support automatically installing an extension if the extension of a command is not installed (#14478)
* dynamic extension install poc * fix cli_ctx * continue run with subprocess * add config * donwload remote index * check close matches first * change config name * add option to turn off dynamic install * add option to turn off dynamic install * fix when cli_ctx is None * add no prompt msg * modify message * default to no * refactor error msg * fix style * resolve UX comments * make changes backward compatible * add test * fix shell * fix style
1 parent cec9b7d commit 04a692a

8 files changed

Lines changed: 207 additions & 53 deletions

File tree

src/azure-cli-core/azure/cli/core/__init__.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -373,18 +373,6 @@ def _get_extension_suppressions(mod_loaders):
373373
res.append(sup)
374374
return res
375375

376-
def _roughly_parse_command(args):
377-
# Roughly parse the command part: <az vm create> --name vm1
378-
# Similar to knack.invocation.CommandInvoker._rudimentary_get_command, but we don't need to bother with
379-
# positional args
380-
nouns = []
381-
for arg in args:
382-
if arg and arg[0] != '-':
383-
nouns.append(arg)
384-
else:
385-
break
386-
return ' '.join(nouns).lower()
387-
388376
# Clear the tables to make this method idempotent
389377
self.command_group_table.clear()
390378
self.command_table.clear()
@@ -404,8 +392,9 @@ def _roughly_parse_command(args):
404392
_update_command_table_from_extensions([], index_extensions)
405393

406394
logger.debug("Loaded %d groups, %d commands.", len(self.command_group_table), len(self.command_table))
395+
from azure.cli.core.util import roughly_parse_command
407396
# The index may be outdated. Make sure the command appears in the loaded command table
408-
command_str = _roughly_parse_command(args)
397+
command_str = roughly_parse_command(args)
409398
if command_str in self.command_table:
410399
logger.debug("Found a match in the command table for '%s'", command_str)
411400
return self.command_table

src/azure-cli-core/azure/cli/core/_session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,6 @@ def __len__(self):
113113
# it could be lagged behind and can be used to check whether
114114
# an upgrade of azure-cli happens
115115
VERSIONS = Session()
116+
117+
# EXT_CMD_TREE provides command to extension name mapping
118+
EXT_CMD_TREE = Session()

src/azure-cli-core/azure/cli/core/extension/operations.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ def _validate_whl_extension(ext_file):
8585
check_version_compatibility(azext_metadata)
8686

8787

88-
def _add_whl_ext(cmd, source, ext_sha256=None, pip_extra_index_urls=None, pip_proxy=None, system=None): # pylint: disable=too-many-statements
89-
cmd.cli_ctx.get_progress_controller().add(message='Analyzing')
88+
def _add_whl_ext(cli_ctx, source, ext_sha256=None, pip_extra_index_urls=None, pip_proxy=None, system=None): # pylint: disable=too-many-statements
89+
cli_ctx.get_progress_controller().add(message='Analyzing')
9090
if not source.endswith('.whl'):
9191
raise ValueError('Unknown extension type. Only Python wheels are supported.')
9292
url_parse_result = urlparse(source)
@@ -108,7 +108,7 @@ def _add_whl_ext(cmd, source, ext_sha256=None, pip_extra_index_urls=None, pip_pr
108108
logger.debug('Downloading %s to %s', source, ext_file)
109109
import requests
110110
try:
111-
cmd.cli_ctx.get_progress_controller().add(message='Downloading')
111+
cli_ctx.get_progress_controller().add(message='Downloading')
112112
_whl_download_from_url(url_parse_result, ext_file)
113113
except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError) as err:
114114
raise CLIError('Please ensure you have network connection. Error detail: {}'.format(str(err)))
@@ -130,7 +130,7 @@ def _add_whl_ext(cmd, source, ext_sha256=None, pip_extra_index_urls=None, pip_pr
130130
raise CLIError("The checksum of the extension does not match the expected value. "
131131
"Use --debug for more information.")
132132
try:
133-
cmd.cli_ctx.get_progress_controller().add(message='Validating')
133+
cli_ctx.get_progress_controller().add(message='Validating')
134134
_validate_whl_extension(ext_file)
135135
except AssertionError:
136136
logger.debug(traceback.format_exc())
@@ -140,7 +140,7 @@ def _add_whl_ext(cmd, source, ext_sha256=None, pip_extra_index_urls=None, pip_pr
140140
logger.debug('Validation successful on %s', ext_file)
141141
# Check for distro consistency
142142
check_distro_consistency()
143-
cmd.cli_ctx.get_progress_controller().add(message='Installing')
143+
cli_ctx.get_progress_controller().add(message='Installing')
144144
# Install with pip
145145
extension_path = build_extension_path(extension_name, system)
146146
pip_args = ['install', '--target', extension_path, ext_file]
@@ -206,15 +206,15 @@ def check_version_compatibility(azext_metadata):
206206
raise CLIError(min_max_msg_fmt)
207207

208208

209-
def add_extension(cmd, source=None, extension_name=None, index_url=None, yes=None, # pylint: disable=unused-argument
209+
def add_extension(cmd=None, source=None, extension_name=None, index_url=None, yes=None, # pylint: disable=unused-argument
210210
pip_extra_index_urls=None, pip_proxy=None, system=None,
211-
version=None):
211+
version=None, cli_ctx=None):
212212
ext_sha256 = None
213213

214214
version = None if version == 'latest' else version
215-
215+
cmd_cli_ctx = cli_ctx or cmd.cli_ctx
216216
if extension_name:
217-
cmd.cli_ctx.get_progress_controller().add(message='Searching')
217+
cmd_cli_ctx.get_progress_controller().add(message='Searching')
218218
ext = None
219219
try:
220220
ext = get_extension(extension_name)
@@ -236,7 +236,7 @@ def add_extension(cmd, source=None, extension_name=None, index_url=None, yes=Non
236236
err = "No matching extensions for '{}'. Use --debug for more information.".format(extension_name)
237237
raise CLIError(err)
238238

239-
extension_name = _add_whl_ext(cmd=cmd, source=source, ext_sha256=ext_sha256,
239+
extension_name = _add_whl_ext(cli_ctx=cmd_cli_ctx, source=source, ext_sha256=ext_sha256,
240240
pip_extra_index_urls=pip_extra_index_urls, pip_proxy=pip_proxy, system=system)
241241
try:
242242
ext = get_extension(extension_name)
@@ -289,8 +289,9 @@ def show_extension(extension_name):
289289
raise CLIError(e)
290290

291291

292-
def update_extension(cmd, extension_name, index_url=None, pip_extra_index_urls=None, pip_proxy=None):
292+
def update_extension(cmd=None, extension_name=None, index_url=None, pip_extra_index_urls=None, pip_proxy=None, cli_ctx=None):
293293
try:
294+
cmd_cli_ctx = cli_ctx or cmd.cli_ctx
294295
ext = get_extension(extension_name, ext_type=WheelExtension)
295296
cur_version = ext.get_version()
296297
try:
@@ -307,7 +308,7 @@ def update_extension(cmd, extension_name, index_url=None, pip_extra_index_urls=N
307308
shutil.rmtree(extension_path)
308309
# Install newer version
309310
try:
310-
_add_whl_ext(cmd=cmd, source=download_url, ext_sha256=ext_sha256,
311+
_add_whl_ext(cli_ctx=cmd_cli_ctx, source=download_url, ext_sha256=ext_sha256,
311312
pip_extra_index_urls=pip_extra_index_urls, pip_proxy=pip_proxy)
312313
logger.debug('Deleting backup of old extension at %s', backup_dir)
313314
shutil.rmtree(backup_dir)

src/azure-cli-core/azure/cli/core/parser.py

Lines changed: 152 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -280,37 +280,167 @@ def parse_known_args(self, args=None, namespace=None):
280280
self._namespace, self._raw_arguments = super().parse_known_args(args=args, namespace=namespace)
281281
return self._namespace, self._raw_arguments
282282

283-
def _check_value(self, action, value):
283+
def _get_extension_command_tree(self):
284+
from azure.cli.core._session import EXT_CMD_TREE
285+
import os
286+
VALID_SECOND = 3600 * 24 * 10
287+
# self.cli_ctx is None when self.prog is beyond 'az', such as 'az iot'.
288+
# use cli_ctx from cli_help which is not lost.
289+
cli_ctx = self.cli_ctx or (self.cli_help.cli_ctx if self.cli_help else None)
290+
if not cli_ctx:
291+
return None
292+
EXT_CMD_TREE.load(os.path.join(cli_ctx.config.config_dir, 'extensionCommandTree.json'), VALID_SECOND)
293+
if not EXT_CMD_TREE.data:
294+
import requests
295+
from azure.cli.core.util import should_disable_connection_verify
296+
try:
297+
response = requests.get(
298+
'https://azurecliextensionsync.blob.core.windows.net/cmd-index/extensionCommandTree.json',
299+
verify=(not should_disable_connection_verify()),
300+
timeout=300)
301+
except Exception as ex: # pylint: disable=broad-except
302+
logger.info("Request failed for extension command tree: %s", str(ex))
303+
return None
304+
if response.status_code == 200:
305+
EXT_CMD_TREE.data = response.json()
306+
EXT_CMD_TREE.save_with_retry()
307+
else:
308+
logger.info("Error when retrieving extension command tree. Response code: %s", response.status_code)
309+
return None
310+
return EXT_CMD_TREE
311+
312+
def _search_in_extension_commands(self, command_str):
313+
"""Search the command in an extension commands dict which mimics a prefix tree.
314+
If the value of the dict item is a string, then the key represents the end of a complete command
315+
and the value is the name of the extension that the command belongs to.
316+
An example of the dict read from extensionCommandTree.json:
317+
{
318+
"aks": {
319+
"create": "aks-preview",
320+
"update": "aks-preview",
321+
"app": {
322+
"up": "deploy-to-azure"
323+
},
324+
"use-dev-spaces": "dev-spaces"
325+
},
326+
...
327+
}
328+
"""
329+
330+
cmd_chain = self._get_extension_command_tree()
331+
for part in command_str.split():
332+
try:
333+
if isinstance(cmd_chain[part], str):
334+
return cmd_chain[part]
335+
cmd_chain = cmd_chain[part]
336+
except KeyError:
337+
return None
338+
return None
339+
340+
def _get_extension_use_dynamic_install_config(self):
341+
cli_ctx = self.cli_ctx or (self.cli_help.cli_ctx if self.cli_help else None)
342+
use_dynamic_install = cli_ctx.config.get(
343+
'extension', 'use_dynamic_install', 'no').lower() if cli_ctx else 'no'
344+
if use_dynamic_install not in ['no', 'yes_prompt', 'yes_without_prompt']:
345+
use_dynamic_install = 'no'
346+
return use_dynamic_install
347+
348+
def _check_value(self, action, value): # pylint: disable=too-many-statements, too-many-locals
284349
# Override to customize the error message when a argument is not among the available choices
285350
# converted value must be one of the choices (if specified)
286-
if action.choices is not None and value not in action.choices:
351+
if action.choices is not None and value not in action.choices: # pylint: disable=too-many-nested-blocks
352+
caused_by_extension_not_installed = False
287353
if not self.command_source:
288-
# parser has no `command_source`, value is part of command itself
289-
extensions_link = 'https://docs.microsoft.com/en-us/cli/azure/azure-cli-extensions-overview'
290-
error_msg = ("{prog}: '{value}' is not in the '{prog}' command group. See '{prog} --help'. "
291-
"If the command is from an extension, "
292-
"please make sure the corresponding extension is installed. "
293-
"To learn more about extensions, please visit "
294-
"{extensions_link}").format(prog=self.prog, value=value, extensions_link=extensions_link)
354+
candidates = difflib.get_close_matches(value, action.choices, cutoff=0.7)
355+
error_msg = None
356+
# self.cli_ctx is None when self.prog is beyond 'az', such as 'az iot'.
357+
# use cli_ctx from cli_help which is not lost.
358+
cli_ctx = self.cli_ctx or (self.cli_help.cli_ctx if self.cli_help else None)
359+
use_dynamic_install = self._get_extension_use_dynamic_install_config()
360+
if use_dynamic_install != 'no' and not candidates:
361+
# Check if the command is from an extension
362+
from azure.cli.core.util import roughly_parse_command
363+
cmd_list = self.prog.split() + self._raw_arguments
364+
command_str = roughly_parse_command(cmd_list[1:])
365+
ext_name = self._search_in_extension_commands(command_str)
366+
if ext_name:
367+
caused_by_extension_not_installed = True
368+
telemetry.set_command_details(command_str,
369+
parameters=AzCliCommandInvoker._extract_parameter_names(cmd_list), # pylint: disable=protected-access
370+
extension_name=ext_name)
371+
run_after_extension_installed = cli_ctx.config.getboolean('extension',
372+
'run_after_dynamic_install',
373+
False) if cli_ctx else False
374+
if use_dynamic_install == 'yes_without_prompt':
375+
logger.warning('The command requires the extension %s. '
376+
'It will be installed first.', ext_name)
377+
go_on = True
378+
else:
379+
from knack.prompting import prompt_y_n, NoTTYException
380+
prompt_msg = 'The command requires the extension {}. ' \
381+
'Do you want to install it now?'.format(ext_name)
382+
if run_after_extension_installed:
383+
prompt_msg = '{} The command will continue to run after the extension is installed.' \
384+
.format(prompt_msg)
385+
NO_PROMPT_CONFIG_MSG = "Run 'az config set extension.use_dynamic_install=" \
386+
"yes_without_prompt' to allow installing extensions without prompt."
387+
try:
388+
go_on = prompt_y_n(prompt_msg, default='y')
389+
if go_on:
390+
logger.warning(NO_PROMPT_CONFIG_MSG)
391+
except NoTTYException:
392+
logger.warning("The command requires the extension %s.\n "
393+
"Unable to prompt for extension install confirmation as no tty "
394+
"available. %s", ext_name, NO_PROMPT_CONFIG_MSG)
395+
go_on = False
396+
if go_on:
397+
from azure.cli.core.extension.operations import add_extension
398+
add_extension(cli_ctx=cli_ctx, extension_name=ext_name)
399+
if run_after_extension_installed:
400+
import subprocess
401+
import platform
402+
exit_code = subprocess.call(cmd_list, shell=platform.system() == 'Windows')
403+
telemetry.set_user_fault("Extension {} dynamically installed and commands will be "
404+
"rerun automatically.".format(ext_name))
405+
self.exit(exit_code)
406+
else:
407+
error_msg = 'Extension {} installed. Please rerun your command.'.format(ext_name)
408+
else:
409+
error_msg = "The command requires the extension {ext_name}. " \
410+
"To install, run 'az extension add -n {ext_name}'.".format(ext_name=ext_name)
411+
if not error_msg:
412+
# parser has no `command_source`, value is part of command itself
413+
error_msg = "{prog}: '{value}' is not in the '{prog}' command group. See '{prog} --help'." \
414+
.format(prog=self.prog, value=value)
415+
if use_dynamic_install.lower() == 'no':
416+
extensions_link = 'https://docs.microsoft.com/en-us/cli/azure/azure-cli-extensions-overview'
417+
error_msg = ("{msg} "
418+
"If the command is from an extension, "
419+
"please make sure the corresponding extension is installed. "
420+
"To learn more about extensions, please visit "
421+
"{extensions_link}").format(msg=error_msg, extensions_link=extensions_link)
295422
else:
296423
# `command_source` indicates command values have been parsed, value is an argument
297424
parameter = action.option_strings[0] if action.option_strings else action.dest
298425
error_msg = "{prog}: '{value}' is not a valid value for '{param}'. See '{prog} --help'.".format(
299426
prog=self.prog, value=value, param=parameter)
427+
candidates = difflib.get_close_matches(value, action.choices, cutoff=0.7)
428+
300429
telemetry.set_user_fault(error_msg)
301430
with CommandLoggerContext(logger):
302431
logger.error(error_msg)
303-
candidates = difflib.get_close_matches(value, action.choices, cutoff=0.7)
304-
if candidates:
305-
print_args = {
306-
's': 's' if len(candidates) > 1 else '',
307-
'verb': 'are' if len(candidates) > 1 else 'is',
308-
'value': value
309-
}
310-
self._suggestion_msg.append("\nThe most similar choice{s} to '{value}' {verb}:".format(**print_args))
311-
self._suggestion_msg.append('\n'.join(['\t' + candidate for candidate in candidates]))
312-
313-
failure_recovery_recommendations = self._get_failure_recovery_recommendations(action)
314-
self._suggestion_msg.extend(failure_recovery_recommendations)
315-
self._print_suggestion_msg(sys.stderr)
432+
if not caused_by_extension_not_installed:
433+
if candidates:
434+
print_args = {
435+
's': 's' if len(candidates) > 1 else '',
436+
'verb': 'are' if len(candidates) > 1 else 'is',
437+
'value': value
438+
}
439+
self._suggestion_msg.append("\nThe most similar choice{s} to '{value}' {verb}:"
440+
.format(**print_args))
441+
self._suggestion_msg.append('\n'.join(['\t' + candidate for candidate in candidates]))
442+
443+
failure_recovery_recommendations = self._get_failure_recovery_recommendations(action)
444+
self._suggestion_msg.extend(failure_recovery_recommendations)
445+
self._print_suggestion_msg(sys.stderr)
316446
self.exit(2)

0 commit comments

Comments
 (0)