Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions src/anaconda_auth/_conda/auth_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

"""

from fnmatch import fnmatch
from functools import lru_cache
from typing import Any
from typing import NamedTuple
Expand All @@ -12,7 +13,11 @@
from urllib.parse import urlparse

from conda import CondaError
from conda.base.context import context as conda_context
from conda.plugins.types import ChannelAuthBase
from pydantic import BaseModel
from pydantic import Field
from pydantic import ValidationError
from requests import PreparedRequest
from requests import Response

Expand Down Expand Up @@ -40,7 +45,73 @@ class AnacondaAuthError(CondaError):
"""


def _load_channel_settings(channel_name: str) -> dict[str, Any]:
"""Find the correct channel settings from conda's configuration."""
# TODO(mattkram): Open conda issue to see if we can pass this into the AuthHandler
# as part of the plugin protocol.

# Since the conda logic uses a url, we derive a url from the channel name
# this will not work for multi_channels like "defaults", but we restrict the
# extra fields we need to URL-based channel_settings, which should be sufficient.
url = channel_name
if not url.endswith("/"):
url += "/"

parsed_url = urlparse(url)
if not parsed_url.scheme or not parsed_url.netloc:
return {}

# The following implementation has mostly been copied from conda, with one noted exception.
# Ideally, we can receive the settings in the plugin instantiation.
# See: https://github.com/conda/conda/blob/2af8e0f7255e1d06ea0bfcb6076c7427d101feee/conda/gateways/connection/session.py#L91-L112

# We ensure here if there are duplicates defined, we choose the last one
channel_settings = {}
for settings in conda_context.channel_settings:
channel = settings.get("channel", "")
if channel == channel_name:
# First we check for exact match
channel_settings = settings
continue

parsed_setting = urlparse(channel)

# We require that the schemes must be identical to prevent downgrade attacks.
# This includes the case of a scheme-less pattern like "*", which is not allowed.
if parsed_setting.scheme != parsed_url.scheme:
continue

url_without_schema = parsed_url.netloc + parsed_url.path
pattern = parsed_setting.netloc + parsed_setting.path
if fnmatch(url_without_schema, pattern):
channel_settings = settings

return channel_settings


class AnacondaAuthHandlerExtraSettings(BaseModel):
override_auth_domain: Optional[str] = Field(default=None, alias="auth_domain")
override_credential_type: Optional[CredentialType] = Field(
default=None, alias="credential_type"
)

@classmethod
def from_channel_name(cls, channel_name: str) -> "AnacondaAuthHandlerExtraSettings":
"""Load extra settings for a channel, with validation."""
settings = _load_channel_settings(channel_name)
try:
return cls(**settings)
except ValidationError as e:
raise AnacondaAuthError(
f"""Error when loading anaconda-auth extra configuration from your condarc.\n\n{e}"""
)


class AnacondaAuthHandler(ChannelAuthBase):
def __init__(self, channel_name: str, *args: Any, **kwargs: Any):
super().__init__(channel_name, *args, **kwargs)
self._extras = AnacondaAuthHandlerExtraSettings.from_channel_name(channel_name)

def _load_token_domain(self, parsed_url: ParseResult) -> tuple[str, CredentialType]:
"""Select the appropriate domain for token lookup based on a parsed URL.

Expand All @@ -64,6 +135,12 @@ def _load_token_domain(self, parsed_url: ParseResult) -> tuple[str, CredentialTy
if config.use_unified_repo_api_key:
credential_type = CredentialType.API_KEY

if self._extras.override_auth_domain:
token_domain = self._extras.override_auth_domain

if self._extras.override_credential_type:
credential_type = self._extras.override_credential_type

return token_domain, credential_type

def _load_token_from_keyring(self, url: str) -> Optional[AccessCredential]:
Expand Down
29 changes: 20 additions & 9 deletions src/anaconda_auth/_conda/condarc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ruamel.yaml import YAML
from ruamel.yaml import YAMLError

from anaconda_auth._conda.config import CredentialType
from anaconda_cli_base import console

DEFAULT_CONDARC_PATH = Path("~/.condarc").expanduser()
Expand Down Expand Up @@ -46,19 +47,29 @@ def load(self, path: Path | None = None) -> None:
raise CondaRCError(f"Could not parse condarc: {exc}")

def update_channel_settings(
self, channel: str, auth_type: str, username: str | None = None
self,
channel: str,
auth_type: str,
username: str | None = None,
*,
auth_domain: str | None = None,
credential_type: CredentialType | None = None,
) -> None:
"""
Update the condarc file's "channel_settings" section
"""
if username is None:
updated_settings = {"channel": channel, "auth": auth_type}
else:
updated_settings = {
"channel": channel,
"auth": auth_type,
"username": username,
}
updated_settings = {
"channel": channel,
"auth": auth_type,
"username": username,
"auth_domain": auth_domain,
"credential_type": credential_type.value if credential_type else None,
}

# Filter out any None values
updated_settings = {
key: value for key, value in updated_settings.items() if value is not None
}

channel_settings = self._loaded_yaml.get("channel_settings", []) or []

Expand Down
65 changes: 65 additions & 0 deletions tests/test_conda_plugins.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum
from urllib.parse import urlparse

import pytest
Expand Down Expand Up @@ -361,3 +362,67 @@ def test_load_token_domain_user_provided_default(conda_search_path):

assert token_domain == "some-domain.com"
assert credential_type == CredentialType.API_KEY


def test_load_token_domain_user_provided_auth_domain_override(conda_search_path):
channel_url = "https://some-domain.com/repo/some-channel"
condarc = CondaRC()
condarc.update_channel_settings(
channel=channel_url + "/*",
auth_type="anaconda-auth",
auth_domain="auth.some-domain.com",
)
condarc.save()
context.reset_context()

handler = AnacondaAuthHandler(channel_name=channel_url)
url = channel_url + "/noarch/repodata.json"
token_domain, credential_type = handler._load_token_domain(parsed_url=urlparse(url))

assert token_domain == "auth.some-domain.com"
assert credential_type == CredentialType.API_KEY


@pytest.mark.parametrize(
"override_credential_type", [CredentialType.API_KEY, CredentialType.REPO_TOKEN]
)
def test_load_token_domain_user_provided_credential_type_override(
override_credential_type, conda_search_path
):
channel_url = "https://some-domain.com/repo/some-channel"
condarc = CondaRC()
condarc.update_channel_settings(
channel=channel_url + "/*",
auth_type="anaconda-auth",
credential_type=override_credential_type,
)
condarc.save()
context.reset_context()

handler = AnacondaAuthHandler(channel_name=channel_url)
url = channel_url + "/noarch/repodata.json"
token_domain, credential_type = handler._load_token_domain(parsed_url=urlparse(url))

assert token_domain == "some-domain.com"
assert credential_type == override_credential_type


def test_load_token_domain_user_provided_credential_type_override_invalid(
conda_search_path,
):
channel_url = "https://some-domain.com/repo/some-channel"
condarc = CondaRC()

class BadCredentialType(Enum):
INVALID = "invalid"

condarc.update_channel_settings(
channel=channel_url + "/*",
auth_type="anaconda-auth",
credential_type=BadCredentialType.INVALID,
)
condarc.save()
context.reset_context()

with pytest.raises(AnacondaAuthError):
AnacondaAuthHandler(channel_name=channel_url)
Loading