Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion src/bernstein/cli/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,18 @@

Provides a Click Group subclass that resolves short aliases to
full command names, plus a registry of built-in aliases.
User-defined aliases can be loaded from ``~/.bernstein/aliases.yaml``.
"""

from __future__ import annotations

import logging
from pathlib import Path

import click
import yaml

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Alias registry
Expand All @@ -25,6 +32,11 @@
"i": "overture", # init (hidden name: overture)
}

# Track which aliases are user-defined (populated at load time)
_USER_ALIASES: dict[str, str] = {}

_USER_ALIASES_PATH = Path.home() / ".bernstein" / "aliases.yaml"


def get_alias(name: str) -> str | None:
"""Return the full command name for an alias, or None.
Expand All @@ -43,6 +55,31 @@ def get_all_aliases() -> dict[str, str]:
return dict(ALIASES)


def _load_user_aliases() -> dict[str, str]:
"""Load user-defined aliases from ~/.bernstein/aliases.yaml."""
if not _USER_ALIASES_PATH.is_file():
return {}
try:
with open(_USER_ALIASES_PATH) as f:
raw: object = yaml.safe_load(f) or {}
data: dict[str, object] = raw if isinstance(raw, dict) else {}
return {str(k): str(v) for k, v in data.items() if isinstance(k, str) and isinstance(v, str)}
except Exception:
logger.debug("Failed to load user aliases from %s", _USER_ALIASES_PATH, exc_info=True)
return {}


def _merge_aliases() -> None:
"""Merge user aliases into the global registry (user overrides built-in)."""
global _USER_ALIASES
_USER_ALIASES = _load_user_aliases()
ALIASES.update(_USER_ALIASES)


# Call at module load time
_merge_aliases()


class AliasGroup(click.Group):
"""Click Group that resolves short aliases to full command names.

Expand Down Expand Up @@ -92,6 +129,7 @@ def aliases_cmd() -> None:
table = Table(title="Command Aliases", show_header=True, header_style="bold cyan")
table.add_column("Alias", style="green", width=10)
table.add_column("Command", style="white", width=20)
table.add_column("Source", style="dim", width=10)
table.add_column("Description", style="dim")

_descriptions: dict[str, str] = {
Expand All @@ -107,7 +145,8 @@ def aliases_cmd() -> None:

for alias, command in sorted(ALIASES.items()):
desc = _descriptions.get(alias, "")
table.add_row(alias, command, desc)
source = "[cyan]user[/cyan]" if alias in _USER_ALIASES else "[dim]built-in[/dim]"
table.add_row(alias, command, source, desc)

console.print(table)
console.print("\n[dim]Usage: bernstein <alias> [options][/dim]")
Loading