diff --git a/src/crewai/cli/provider.py b/src/crewai/cli/provider.py index 529ca5e26c..f98c40c9ef 100644 --- a/src/crewai/cli/provider.py +++ b/src/crewai/cli/provider.py @@ -1,4 +1,6 @@ import json +import logging +import os import time from collections import defaultdict from pathlib import Path @@ -8,6 +10,8 @@ from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS +logger = logging.getLogger(__name__) + def select_choice(prompt_message, choices): """ @@ -157,22 +161,74 @@ def fetch_provider_data(cache_file): """ Fetches provider data from a specified URL and caches it to a file. + Warning: This function includes a fallback that disables SSL verification. + This should only be used in development environments or when absolutely necessary. + Production deployments should resolve SSL certificate issues properly. + Args: - cache_file (Path): The path to the cache file. Returns: - dict or None: The fetched provider data or None if the operation fails. """ + allow_insecure = os.getenv("CREW_ALLOW_INSECURE_SSL", "false").lower() == "true" + try: - response = requests.get(JSON_URL, stream=True, timeout=60) + verify = not allow_insecure + if not verify: + logger.warning( + "SSL verification disabled via CREW_ALLOW_INSECURE_SSL environment variable. " + "This is less secure and should only be used in development environments." + ) + click.secho( + "SSL verification disabled via environment variable. " + "This is less secure and should only be used in development environments.", + fg="yellow", + ) + + response = requests.get(JSON_URL, stream=True, timeout=60, verify=verify) response.raise_for_status() data = download_data(response) with open(cache_file, "w") as f: json.dump(data, f) return data + except requests.exceptions.SSLError: + if not allow_insecure: + logger.warning( + "SSL certificate verification failed. Retrying with verification disabled. " + "This is less secure but may be necessary on some systems." + ) + click.secho( + "SSL certificate verification failed. Retrying with verification disabled. " + "This is less secure but may be necessary on some systems.", + fg="yellow", + ) + try: + os.environ["CREW_TEMP_ALLOW_INSECURE"] = "true" + response = requests.get( + JSON_URL, + stream=True, + timeout=60, + verify=False, # nosec B501 + ) + os.environ.pop("CREW_TEMP_ALLOW_INSECURE", None) + + response.raise_for_status() + data = download_data(response) + with open(cache_file, "w") as f: + json.dump(data, f) + return data + except requests.RequestException as e: + logger.error(f"Error fetching provider data: {e}") + click.secho(f"Error fetching provider data: {e}", fg="red") + return None + finally: + os.environ.pop("CREW_TEMP_ALLOW_INSECURE", None) except requests.RequestException as e: + logger.error(f"Error fetching provider data: {e}") click.secho(f"Error fetching provider data: {e}", fg="red") except json.JSONDecodeError: + logger.error("Error parsing provider data. Invalid JSON format.") click.secho("Error parsing provider data. Invalid JSON format.", fg="red") return None diff --git a/tests/cli/provider_test.py b/tests/cli/provider_test.py new file mode 100644 index 0000000000..1957f5c980 --- /dev/null +++ b/tests/cli/provider_test.py @@ -0,0 +1,109 @@ +import json +import os +import tempfile +from pathlib import Path +from unittest import mock + +import pytest +import requests +from requests.exceptions import SSLError + +from crewai.cli.provider import fetch_provider_data, get_provider_data + + +class TestProviderFunctions: + @mock.patch("crewai.cli.provider.requests.get") + def test_fetch_provider_data_success(self, mock_get): + mock_response = mock.MagicMock() + mock_response.headers.get.return_value = "100" + mock_response.iter_content.return_value = [b'{"test": "data"}'] + mock_get.return_value = mock_response + + with tempfile.NamedTemporaryFile() as temp_file: + cache_file = Path(temp_file.name) + result = fetch_provider_data(cache_file) + + assert result == {"test": "data"} + mock_get.assert_called_once() + + @mock.patch("crewai.cli.provider.requests.get") + @mock.patch("crewai.cli.provider.click.secho") + def test_fetch_provider_data_ssl_error_fallback(self, mock_secho, mock_get): + mock_response = mock.MagicMock() + mock_response.headers.get.return_value = "100" + mock_response.iter_content.return_value = [b'{"test": "data"}'] + + mock_get.side_effect = [ + SSLError("certificate verify failed: unable to get local issuer certificate"), + mock_response + ] + + with tempfile.NamedTemporaryFile() as temp_file: + cache_file = Path(temp_file.name) + result = fetch_provider_data(cache_file) + + assert result == {"test": "data"} + assert mock_get.call_count == 2 + + assert mock_get.call_args_list[1][1]["verify"] is False + + mock_secho.assert_any_call( + "SSL certificate verification failed. Retrying with verification disabled. " + "This is less secure but may be necessary on some systems.", + fg="yellow" + ) + + @mock.patch("crewai.cli.provider.requests.get") + @mock.patch("crewai.cli.provider.click.secho") + @mock.patch.dict(os.environ, {"CREW_ALLOW_INSECURE_SSL": "true"}) + def test_fetch_provider_data_with_insecure_env_var(self, mock_secho, mock_get): + mock_response = mock.MagicMock() + mock_response.headers.get.return_value = "100" + mock_response.iter_content.return_value = [b'{"test": "data"}'] + mock_get.return_value = mock_response + + with tempfile.NamedTemporaryFile() as temp_file: + cache_file = Path(temp_file.name) + result = fetch_provider_data(cache_file) + + assert result == {"test": "data"} + mock_get.assert_called_once() + + assert mock_get.call_args[1]["verify"] is False + + mock_secho.assert_any_call( + "SSL verification disabled via environment variable. " + "This is less secure and should only be used in development environments.", + fg="yellow" + ) + + @mock.patch("crewai.cli.provider.requests.get") + def test_fetch_provider_data_with_empty_response(self, mock_get): + mock_response = mock.MagicMock() + mock_response.headers.get.return_value = "0" + mock_response.iter_content.return_value = [b'{}'] + mock_get.return_value = mock_response + + with tempfile.NamedTemporaryFile() as temp_file: + cache_file = Path(temp_file.name) + result = fetch_provider_data(cache_file) + + assert result == {} + mock_get.assert_called_once() + + @mock.patch("crewai.cli.provider.requests.get") + @mock.patch("crewai.cli.provider.click.secho") + def test_fetch_provider_data_request_exception(self, mock_secho, mock_get): + mock_get.side_effect = requests.RequestException("Connection error") + + with tempfile.NamedTemporaryFile() as temp_file: + cache_file = Path(temp_file.name) + result = fetch_provider_data(cache_file) + + assert result is None + mock_get.assert_called_once() + + mock_secho.assert_any_call( + "Error fetching provider data: Connection error", + fg="red" + )