diff --git a/.bandit b/.bandit new file mode 100644 index 00000000..3a8026e6 --- /dev/null +++ b/.bandit @@ -0,0 +1,13 @@ +[bandit] +# Bandit configuration file +exclude_dirs = ['tests', 'build', 'dist', '.git', '.tox', '.venv', '__pycache__'] + +# Skip certain tests that may be too strict for this project +skips = ['B101', 'B601'] + +# B101: Test for use of assert +# B601: Test for shell injection within Paramiko + +[bandit.assert_used] +# Allow assert statements in test files +skips = ['*test*.py', '*tests*.py'] diff --git a/.bumpversion.cfg b/.bumpversion.cfg deleted file mode 100644 index 947d6fcc..00000000 --- a/.bumpversion.cfg +++ /dev/null @@ -1,15 +0,0 @@ -[bumpversion] -current_version = 0.17.0 -commit = True -tag = True -tag_name = {new_version} -message = [RELEASE] - Release version {new_version} -parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) -serialize = - {major}.{minor}.{patch} - -[bumpversion:file:README.md] - -[bumpversion:file:setup.py] - -[bumpversion:file:himl/main.py] diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..a2fcf0dc --- /dev/null +++ b/.flake8 @@ -0,0 +1,13 @@ +[flake8] +max-line-length = 120 +extend-ignore = E203, W503 +exclude = + .git, + __pycache__, + build, + dist, + .eggs, + *.egg-info, + .venv, + .tox, + himl/_version.py diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 00000000..a455837e --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,54 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '[BUG] ' +labels: bug +assignees: '' +--- + +## Bug Description + +A clear and concise description of what the bug is. + +## To Reproduce + +Steps to reproduce the behavior: +1. Go to '...' +2. Click on '....' +3. Scroll down to '....' +4. See error + +## Expected Behavior + +A clear and concise description of what you expected to happen. + +## Actual Behavior + +A clear and concise description of what actually happened. + +## Environment + +- OS: [e.g. Ubuntu 20.04, macOS 12.0, Windows 10] +- Python version: [e.g. 3.9.7] +- himl version: [e.g. 0.17.0] +- Installation method: [e.g. pip, conda, from source] + +## Configuration Files + +If applicable, add sample configuration files that reproduce the issue. + +```yaml +# Example config that causes the issue +``` + +## Error Messages + +If applicable, add the full error message and stack trace. + +``` +Paste error message here +``` + +## Additional Context + +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 00000000..0e7c1080 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,38 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '[FEATURE] ' +labels: enhancement +assignees: '' +--- + +## Feature Description + +A clear and concise description of what you want to happen. + +## Problem Statement + +Is your feature request related to a problem? Please describe. +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +## Proposed Solution + +Describe the solution you'd like. +A clear and concise description of what you want to happen. + +## Alternatives Considered + +Describe alternatives you've considered. +A clear and concise description of any alternative solutions or features you've considered. + +## Use Case + +Describe your use case and how this feature would help.
+ +## Implementation Ideas + +If you have ideas about how this could be implemented, please share them here. + +## Additional Context + +Add any other context or screenshots about the feature request here. diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 00000000..08b11548 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,33 @@ +## Description + +Brief description of the changes in this PR. + +## Type of Change + +- [ ] Bug fix (non-breaking change which fixes an issue) +- [ ] New feature (non-breaking change which adds functionality) +- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) +- [ ] Documentation update +- [ ] Code refactoring +- [ ] Performance improvement +- [ ] Test improvement + +## Testing + +- [ ] Tests pass locally with my changes +- [ ] I have added tests that prove my fix is effective or that my feature works +- [ ] New and existing unit tests pass locally with my changes +- [ ] I have tested the changes manually + +## Checklist + +- [ ] My code follows the style guidelines of this project +- [ ] I have performed a self-review of my own code +- [ ] I have commented my code, particularly in hard-to-understand areas +- [ ] I have made corresponding changes to the documentation +- [ ] My changes generate no new warnings +- [ ] Any dependent changes have been merged and published in downstream modules + +## Additional Notes + +Add any additional notes, concerns, or context about the changes here. diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 00000000..cb1350b4 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,187 @@ +name: CI + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + workflow_dispatch: + +env: + FORCE_COLOR: 1 + +jobs: + test: + name: Test Python ${{ matrix.python-version }} on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] + python-version: ['3.9', '3.10', '3.11', '3.12', '3.13', '3.14'] + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Get pip cache dir + id: pip-cache + run: | + echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + + - name: Cache pip dependencies + uses: actions/cache@v4 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml', '**/tests/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip-${{ matrix.python-version }}- + ${{ runner.os }}-pip- + + - name: Install dependencies + run: | + python -m pip install --upgrade pip setuptools wheel + pip install -e .[dev] + + - name: Lint with flake8 + run: | + # Stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # Exit-zero treats all errors as warnings + flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=120 --statistics + + - name: Type check with mypy + run: | + mypy himl/ --ignore-missing-imports + + - name: Run tests with pytest + run: | + python -m pytest tests/ -v --tb=short --cov=himl --cov-report=xml --cov-report=term-missing --cov-fail-under=80 + + - name: Upload coverage to Codecov + if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.14' + uses: codecov/codecov-action@v4 + with: + file: ./coverage.xml + flags: unittests + name: codecov-umbrella + fail_ci_if_error: false + + security: + name: Security checks + runs-on: ubuntu-latest + # Note: Using Python 3.13 for security checks until bandit supports Python 3.14 + # See the bandit issue tracker (https://github.com/PyCQA/bandit/issues) regarding the ast.Num deprecation + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.13' # Use 3.13 until bandit supports 3.14 + + - name: Install security tools + run: | + python -m pip install --upgrade pip + pip install bandit[toml] safety + + - name: Run security checks with bandit + run: | + # Generate JSON report (allow failures for reporting) + bandit -r himl/ -f json -o bandit-report.json || echo "Bandit JSON report generation completed with issues" + # Run bandit with medium severity (fail on medium+ issues) + bandit -r himl/ --severity-level medium + + + - name: Upload security reports + if: always() + uses: actions/upload-artifact@v4 + with: + name: security-reports + path: | + bandit-report.json + + build: + name: Build package + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 # Needed for setuptools_scm + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.14' + + - name: Install build dependencies + run: | + python -m pip install --upgrade pip + pip install build twine + + - name: Build package + run: | + python -m build + + - name: Check package + run: | + twine check dist/* + + - name: Test package installation + run: | + pip install dist/*.whl + himl --help + himl-config-merger --help + + - name: Upload build artifacts + uses: actions/upload-artifact@v4 + with: + name: dist-${{ github.sha }} + path: dist/ + retention-days: 30 + + integration: + name: Integration tests + runs-on: ubuntu-latest + needs: [test, build] + if: github.event_name == 'pull_request' || github.ref == 'refs/heads/main' + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.14' + + - name: Download build artifacts + uses: actions/download-artifact@v4 + with: + name: dist-${{ github.sha }} + path: dist/ + + - name: Install package from wheel + run: | + python -m pip install --upgrade pip + pip install dist/*.whl + + - name: Run integration tests + run: | + # Test CLI tools work + himl --help + himl-config-merger --help + + # Test basic functionality with examples + if [ -d "examples" ]; then + cd examples + if [ -d "simple" ]; then + himl simple/production --format yaml + fi + fi diff --git a/.gitignore b/.gitignore index 9572cbc1..4655e85a 100644 --- a/.gitignore +++ b/.gitignore @@ -104,3 +104,6 @@ venv.bak/ # mypy .mypy_cache/ + +# setuptools_scm generated version file +himl/_version.py diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..975a39ae --- /dev/null +++ b/Makefile @@ -0,0 +1,61 @@ +# Use Python 3.13 where the packages are installed +PYTHON :=
/opt/homebrew/bin/python3.13 + +.PHONY: help install test lint format clean build release bump-patch bump-minor bump-major + +help: ## Show this help message + @echo "Available commands:" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}' + +install: ## Install the package in development mode + pip install -e .[dev] + +test: ## Run tests + $(PYTHON) -m pytest tests/ -v + +lint: ## Run linting + # Stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # Exit-zero treats all errors as warnings + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=120 --statistics + mypy himl/ --ignore-missing-imports + +format: ## Format code + black himl tests + +clean: ## Clean build artifacts + rm -rf build/ + rm -rf dist/ + rm -rf *.egg-info/ + find . -type d -name __pycache__ -exec rm -rf {} + + find . -type f -name "*.pyc" -delete + +build: clean ## Build the package + $(PYTHON) -m build + +# Version bumping commands +bump-patch: ## Bump patch version (0.16.4 -> 0.16.5) + bump-my-version bump patch + +bump-minor: ## Bump minor version (0.16.4 -> 0.17.0) + bump-my-version bump minor + +bump-major: ## Bump major version (0.16.4 -> 1.0.0) + bump-my-version bump major + +# Show what version bump would do +show-bump: ## Show what version bumps would result in + bump-my-version show-bump + +# Dry run version bumps +dry-bump-patch: ## Dry run patch version bump + bump-my-version bump --dry-run --allow-dirty patch + +dry-bump-minor: ## Dry run minor version bump + bump-my-version bump --dry-run --allow-dirty minor + +dry-bump-major: ## Dry run major version bump + bump-my-version bump --dry-run --allow-dirty major + +release: build ## Build and upload to PyPI (requires proper credentials) + $(PYTHON) -m twine upload dist/* diff --git a/README.md b/README.md index b5286a38..9f766f23 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,20 @@ # himl + +[![CI](https://github.com/adobe/himl/workflows/CI/badge.svg)](https://github.com/adobe/himl/actions) +[![codecov](https://codecov.io/gh/adobe/himl/branch/main/graph/badge.svg)](https://codecov.io/gh/adobe/himl) +[![PyPI version](https://badge.fury.io/py/himl.svg)](https://badge.fury.io/py/himl) +[![Python versions](https://img.shields.io/pypi/pyversions/himl.svg)](https://pypi.org/project/himl/) + A hierarchical config using yaml in Python. Latest version is: 0.17.0 +> **⚠️ Breaking Changes in v0.18.0** +> The upcoming v0.18.0 release includes breaking changes. See [MIGRATION_GUIDE.md](MIGRATION_GUIDE.md) for migration instructions. +> - Default list merge strategy changed from `append` to `append_unique` +> - Enhanced path validation (raises `FileNotFoundError` for non-existent paths) +> - Empty directory validation (raises `Exception` for directories without YAML files) + ## Description A python module which allows you to merge hierarchical config files using YAML syntax. It offers deep merge, variable interpolation and secrets retrieval from secrets managers. @@ -442,3 +454,85 @@ config_processor.process(path=path, filters=filters, exclude_keys=exclude_keys, type_strategies= [(list, [strategy_merge_override,'append']), (dict, ["merge"])] )) ``` + +## Development + +### Setting up development environment + +1. Clone the repository: +```bash +git clone https://github.com/adobe/himl.git +cd himl +``` + +2. Install in development mode with all dependencies: +```bash +pip install -e .[dev] +``` + +3.
Run tests: +```bash +pytest tests/ +``` + +4. Run tests with coverage: +```bash +pytest tests/ --cov=himl --cov-report=term-missing +``` + +5. Run linting: +```bash +flake8 . +``` + +### CI/CD + +This project uses GitHub Actions for continuous integration. The CI pipeline: + +- **Tests**: Runs on Python 3.9-3.14 across Ubuntu and macOS +- **Security**: Runs bandit security checks and safety dependency checks +- **Build**: Builds and validates the package +- **Integration**: Tests CLI tools and package installation +- **Coverage**: Reports test coverage to Codecov + +All tests must pass and maintain good test coverage before merging PRs. + +### Version Management + +This project uses [bump-my-version](https://github.com/callowayproject/bump-my-version) for version management. The version is automatically determined from Git tags using setuptools_scm. + +#### Bumping Versions + +Use the provided Makefile commands: + +```bash +# Show what version bumps would result in +make show-bump + +# Dry run version bumps (to see what would change) +make dry-bump-patch # 0.16.4 -> 0.16.5 +make dry-bump-minor # 0.16.4 -> 0.17.0 +make dry-bump-major # 0.16.4 -> 1.0.0 + +# Actually bump the version +make bump-patch # For bug fixes +make bump-minor # For new features +make bump-major # For breaking changes +``` + +The version bump will: +1. Update the version in `pyproject.toml` +2. Update the version in `README.md` +3. Update the version in `himl/main.py` (CLI help) +4. Create a Git commit with the changes +5. Create a Git tag with the new version + +### Contributing + +1. Fork the repository +2. Create a feature branch (`git checkout -b feature/amazing-feature`) +3. Make your changes and add tests +4. Ensure all tests pass (`pytest tests/`) +5. Commit your changes (`git commit -m 'Add amazing feature'`) +6. Push to the branch (`git push origin feature/amazing-feature`) +7. Open a Pull Request diff --git a/RELEASE.md b/RELEASE.md index 310cdf4c..e4d80e71 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,9 +1,78 @@ -# Release - -1. Install: `pip install bump2version` -2. Bump version: `bump2version minor` -3. Push the release commit: `git push --follow-tags` -4. Wait for Github Actions to finish - 1. A new version is published at https://pypi.org/project/himl/#history - 2. Docker image is published at https://github.com/adobe/himl/pkgs/container/himl -5. Create a new Github release at https://github.com/adobe/himl/releases +# Release Process + +This project uses [bump-my-version](https://github.com/callowayproject/bump-my-version) for automated version management and releases. + +## Prerequisites + +1. Ensure you have the development dependencies installed: + ```bash + pip install -e .[dev] + ``` + +2. Ensure your working directory is clean (all changes committed) + +3. Ensure you're on the `main` branch and up to date: + ```bash + git checkout main + git pull origin main + ``` + +## Release Steps + +### 1. Choose Version Type + +Determine the type of release based on the changes: +- **Patch** (`0.16.4 → 0.16.5`): Bug fixes, documentation updates +- **Minor** (`0.16.4 → 0.17.0`): New features, backward-compatible changes +- **Major** (`0.16.4 → 1.0.0`): Breaking changes + +### 2. Preview the Release + +Check what the version bump would do: +```bash +# See all possible version bumps +make show-bump + +# Dry run the specific bump you want to make +make dry-bump-patch # For patch releases +make dry-bump-minor # For minor releases +make dry-bump-major # For major releases +``` + +### 3.
Execute the Version Bump + +Run the appropriate bump command: +```bash +make bump-patch # For patch releases +make bump-minor # For minor releases +make bump-major # For major releases +``` + +This will automatically: +- Update the version in `pyproject.toml` +- Update the version in `README.md` +- Update the version in `himl/main.py` (CLI help) +- Create a Git commit with message `[RELEASE] - Release version X.Y.Z` +- Create a Git tag `X.Y.Z` + +### 4. Push the Release + +Push the commit and tags to trigger the release: +```bash +git push --follow-tags +``` + +### 5. Monitor Automated Release + +Wait for GitHub Actions to complete: +1. **PyPI Release**: A new version will be published at https://pypi.org/project/himl/#history +2. **Docker Image**: Published at https://github.com/adobe/himl/pkgs/container/himl +3. **GitHub Release**: Automatically created at https://github.com/adobe/himl/releases + +### 6. Verify Release + +1. Check that the new version appears on PyPI +2. Verify the Docker image is available +3. Test install the new version: `pip install himl==X.Y.Z` +4. Update the GitHub release notes if needed + diff --git a/himl/__init__.py b/himl/__init__.py index 80553726..6cd69b05 100644 --- a/himl/__init__.py +++ b/himl/__init__.py @@ -10,3 +10,16 @@ from .config_generator import ConfigGenerator, ConfigProcessor from .main import ConfigRunner + +# Make imports available at package level +__all__ = ['ConfigGenerator', 'ConfigProcessor', 'ConfigRunner', '__version__'] + +try: + from ._version import version as __version__ +except ImportError: + # Fallback for development installs without a generated _version.py + try: + from importlib.metadata import PackageNotFoundError, version + __version__ = version("himl") + except (ImportError, PackageNotFoundError): + __version__ = "unknown" diff --git a/himl/config_generator.py b/himl/config_generator.py index c08814e1..ae39c933 100755 --- a/himl/config_generator.py +++ b/himl/config_generator.py @@ -20,7 +20,6 @@ from .interpolation import InterpolationResolver, EscapingResolver, InterpolationValidator, SecretResolver, \ DictIterator, replace_parent_working_directory, EnvVarResolver from .python_compat import iteritems, primitive_types, PY3 -from .remote_state import S3TerraformRemoteStateRetriever logging.basicConfig() logging.root.setLevel(logging.INFO) @@ -47,71 +46,109 @@ def process(self, cwd=None, fallback_strategies=["override"], type_conflict_strategies=["override"]): + # Prepare parameters and create generator path = self.get_relative_path(path) + skip_interpolation_validation = self._should_skip_interpolation_validation( + skip_interpolations, skip_secrets, skip_interpolation_validation) + cwd = cwd or os.getcwd() - if skip_interpolations or skip_secrets: - skip_interpolation_validation = True + generator = self._create_and_initialize_generator( + cwd, path, multi_line_string, type_strategies, fallback_strategies, type_conflict_strategies) - if cwd is None: - cwd = os.getcwd() + # Process data exclusions and interpolations + self._process_exclusions(generator, exclude_keys) + self._process_interpolations(generator, skip_interpolations, skip_secrets) + self._process_filters_and_validation(generator, filters, skip_interpolation_validation) + # Handle enclosing key operations and get final data + data = self._handle_enclosing_key_operations(generator, enclosing_key, remove_enclosing_key) + generator.clean_escape_characters() + + # Handle output
operations + self._handle_output_operations(generator, data, output_format, print_data, output_file) + + return data + + def _should_skip_interpolation_validation(self, skip_interpolations, skip_secrets, skip_interpolation_validation): + """Determine if interpolation validation should be skipped.""" + return skip_interpolation_validation or skip_interpolations or skip_secrets + + def _create_and_initialize_generator(self, cwd, path, multi_line_string, type_strategies, + fallback_strategies, type_conflict_strategies): + """Create and initialize the ConfigGenerator.""" generator = ConfigGenerator(cwd, path, multi_line_string, type_strategies, fallback_strategies, type_conflict_strategies) generator.generate_hierarchy() generator.process_hierarchy() + return generator - # Exclude data before interpolations + def _process_exclusions(self, generator, exclude_keys): + """Process key exclusions before interpolations.""" if len(exclude_keys) > 0: generator.exclude_keys(exclude_keys) - # Resolve multiple levels of interpolations: + def _process_interpolations(self, generator, skip_interpolations, skip_secrets): + """Process all interpolation steps.""" if not skip_interpolations: - # TODO: reduce the number of calls to resolve_interpolations - generator.resolve_interpolations() - - # Resolve nested interpolations: + self._resolve_basic_interpolations(generator) + self._resolve_dynamic_interpolations(generator) + self._resolve_env_interpolations(generator) + self._resolve_secret_interpolations(generator, skip_secrets) + + def _resolve_basic_interpolations(self, generator): + """Resolve basic and nested interpolations.""" + # TODO: reduce the number of calls to resolve_interpolations + generator.resolve_interpolations() + + # Resolve nested interpolations: + # Example: + # map1: + # key1: value1 + # map2: "{{map1.key1}}" + # value: "something-{{map2.key1}} <--- this will be resolved at this step + generator.resolve_interpolations() + + def _resolve_dynamic_interpolations(self, generator): + """Add dynamic data and resolve interpolations using dynamic data.""" + generator.add_dynamic_data() + generator.resolve_interpolations() + + def _resolve_env_interpolations(self, generator): + """Add env vars and resolve interpolations using env vars.""" + generator.resolve_env() + generator.resolve_interpolations() + + def _resolve_secret_interpolations(self, generator, skip_secrets): + """Add secrets and resolve interpolations using secrets.""" + if not skip_secrets: + default_aws_profile = self.get_default_aws_profile(generator.generated_data) + generator.resolve_secrets(default_aws_profile) + # Perform resolving in case some secrets are used in nested interpolations. # Example: - # map1: - # key1: value1 - # map2: "{{map1.key1}}" - # value: "something-{{map2.key1}} <--- this will be resolved at this step + # value1: "{{ssm.mysecret}}" + # value2: "something-{{value1}} <--- this will be resolved at this step generator.resolve_interpolations() - # Add dynamic data and resolve interpolations using dynamic data: - generator.add_dynamic_data() - generator.resolve_interpolations() - - # Add env vars and resolve interpolations using env vars: - generator.resolve_env() - generator.resolve_interpolations() - - # Add secrets and resolve interpolations using secrets: - if not skip_secrets: - default_aws_profile = self.get_default_aws_profile(generator.generated_data) - generator.resolve_secrets(default_aws_profile) - # Perform resolving in case some secrets are used in nested interpolations. 
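The repeated `resolve_interpolations()` calls in the refactored steps above implement a fixed-point resolution: each enrichment step (dynamic data, env vars, secrets) can expose placeholders that only become resolvable on a later pass. A minimal standalone sketch of that idea, for flat dicts of string values only (illustrative, not himl's actual resolver):

```python
import re

PLACEHOLDER = re.compile(r"\{\{([\w.]+)\}\}")

def resolve_to_fixed_point(data: dict) -> dict:
    """Repeatedly substitute {{key}} placeholders until nothing changes."""
    changed = True
    while changed:
        changed = False
        for key, value in data.items():
            if not isinstance(value, str):
                continue
            # unknown placeholders are left untouched (m.group(0))
            new_value = PLACEHOLDER.sub(
                lambda m: str(data.get(m.group(1), m.group(0))), value)
            if new_value != value:
                data[key] = new_value
                changed = True
    return data

# value2 needs a second pass: it only becomes resolvable once value1
# has itself been resolved to a plain string
print(resolve_to_fixed_point({"value2": "something-{{value1}}",
                              "value1": "{{secret}}",
                              "secret": "s3cr3t"}))
# -> {'value2': 'something-s3cr3t', 'value1': 's3cr3t', 'secret': 's3cr3t'}
```

This is also why the TODO above about reducing the number of `resolve_interpolations()` calls is non-trivial: resolution has to iterate until a pass makes no further substitutions.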
- # Example: - # value1: "{{ssm.mysecret}}" - # value2: "something-{{value1}} <--- this will be resolved at this step - generator.resolve_interpolations() - - # Filter data before interpolation validation + def _process_filters_and_validation(self, generator, filters, skip_interpolation_validation): + """Process data filtering and interpolation validation.""" if len(filters) > 0: generator.filter_data(filters) if not skip_interpolation_validation: generator.validate_interpolations() + def _handle_enclosing_key_operations(self, generator, enclosing_key, remove_enclosing_key): + """Handle enclosing key addition or removal operations.""" if enclosing_key: logger.info("Adding enclosing key {}".format(enclosing_key)) - data = generator.add_enclosing_key(enclosing_key) + return generator.add_enclosing_key(enclosing_key) elif remove_enclosing_key: logger.info("Removing enclosing key {}".format(remove_enclosing_key)) - data = generator.remove_enclosing_key(remove_enclosing_key) + return generator.remove_enclosing_key(remove_enclosing_key) else: - data = generator.generated_data - - generator.clean_escape_characters() + return generator.generated_data + def _handle_output_operations(self, generator, data, output_format, print_data, output_file): + """Handle printing and file output operations.""" if print_data or output_file: formatted_data = generator.output_data(data, output_format) @@ -122,8 +159,6 @@ def process(self, cwd=None, with open(output_file, 'w') as f: f.write(formatted_data) - return data - @staticmethod def get_default_aws_profile(data): return data['aws']['profile'] if 'aws' in data and 'profile' in data['aws'] else None @@ -139,7 +174,8 @@ def get_relative_path(path): class ConfigGenerator(object): """ this class is used to create a config generator object which will be used to generate cluster definition files - from the hierarchy of directories. The class implements methods that performs deep merging on dicts so the end result + from the hierarchy of directories. The class implements methods that perform deep merging on dicts so the end + result will contain merged data on each layer.
""" @@ -153,15 +189,14 @@ def __init__(self, cwd, path, multi_line_string, type_strategies, fallback_strat self.fallback_strategies = fallback_strategies self.type_conflict_strategies = type_conflict_strategies if multi_line_string is True: - yaml.representer.BaseRepresenter.represent_scalar = ConfigGenerator.custom_represent_scalar + yaml.representer.BaseRepresenter.represent_scalar = ConfigGenerator.custom_represent_scalar # type: ignore @staticmethod def yaml_dumper(): try: from yaml import CLoader as Loader, CDumper as Dumper except ImportError: - from yaml import Loader, Dumper - from yaml.representer import SafeRepresenter + from yaml import Loader, Dumper # type: ignore _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG def dict_representer(dumper, data): @@ -184,7 +219,8 @@ def unicode_representer_pipestyle(dumper, data): style = u'|' if u'\n' in data else None return dumper.represent_scalar(u'tag:yaml.org,2002:str', data, style=style) - Dumper.add_representer(unicode, unicode_representer_pipestyle) + # Python 3 doesn't have unicode type, use str instead + Dumper.add_representer(str, unicode_representer_pipestyle) return Dumper @@ -193,7 +229,7 @@ def get_yaml_from_path(working_directory, path): yaml_files = [] for yaml_file in os.listdir(path): if yaml_file.endswith(".yaml"): - yaml_files.append(os.path.join(working_directory, yaml_file)) + yaml_files.append(os.path.join(path, yaml_file)) return sorted(yaml_files) @staticmethod @@ -214,7 +250,7 @@ def merge_value(reference, new_value, type_strategies, fallback_strategies, type @staticmethod def merge_yamls(values, yaml_content, type_strategies, fallback_strategies, type_conflict_strategies): for key, value in iteritems(yaml_content): - if key in values and type(values[key]) != type(value): + if key in values and not isinstance(values[key], type(value)) and not isinstance(value, type(values[key])): raise Exception("Failed to merge key '{}', because of mismatch in type: {} vs {}" .format(key, type(values[key]), type(value))) if key in values and not isinstance(value, primitive_types): @@ -239,34 +275,70 @@ def generate_hierarchy(self): :return: returns a list of directories in a priority order (from less specific to more specific) """ hierarchy = [] + + # Start from the current working directory + current_dir = self.cwd + full_target_path = os.path.join(current_dir, self.path) + + # If path is a file, just process the current directory + if os.path.isfile(full_target_path): + hierarchy.append(self.get_yaml_from_path(".", current_dir)) + return hierarchy + + # If path is a directory, build hierarchy by traversing path components full_path = pathlib2.Path(self.path) - for path in full_path.parts: - os.chdir(path) - new_path = os.path.relpath(os.getcwd(), self.cwd) - hierarchy.append(self.get_yaml_from_path(new_path, os.getcwd())) - os.chdir(self.cwd) + accumulated_path = "" + + # First, add the base directory (cwd) + hierarchy.append(self.get_yaml_from_path(".", current_dir)) + + # Then traverse each directory component in the path + for path_part in full_path.parts: + if accumulated_path: + accumulated_path = os.path.join(accumulated_path, path_part) + else: + accumulated_path = path_part + + full_dir_path = os.path.join(current_dir, accumulated_path) + if os.path.isdir(full_dir_path): + hierarchy.append(self.get_yaml_from_path(accumulated_path, full_dir_path)) + return hierarchy def process_hierarchy(self): - merged_values = OrderedDict() + # Check if the target path exists before processing + full_target_path = 
os.path.join(self.cwd, self.path) + if not os.path.exists(full_target_path): + raise FileNotFoundError(f"Path does not exist: {full_target_path}") + + merged_values: OrderedDict = OrderedDict() + total_files_processed = 0 + for yaml_files in self.hierarchy: for yaml_file in yaml_files: yaml_content = self.yaml_get_content(yaml_file) self.merge_yamls(merged_values, yaml_content, self.type_strategies, self.fallback_strategies, self.type_conflict_strategies) self.resolve_simple_interpolations(merged_values, yaml_file) + total_files_processed += 1 + + if total_files_processed == 0: + raise Exception("No YAML files found to process in the hierarchy") + self.generated_data = merged_values def get_values_from_dir_path(self): values = {} full_path = pathlib2.Path(self.path) - for path in full_path.parts[1:]: - split_value = path.split('=') - values[split_value[0]] = split_value[1] + for path in full_path.parts: + if '=' in path: + split_value = path.split('=') + values[split_value[0]] = split_value[1] return values def output_yaml_data(self, data): - return yaml.dump(data, Dumper=ConfigGenerator.yaml_dumper(), default_flow_style=False, width=200, sort_keys=False) + return yaml.dump(data, Dumper=ConfigGenerator.yaml_dumper(), default_flow_style=False, width=200, + sort_keys=False) def yaml_to_json(self, yaml_data): return json.dumps(yaml.load(yaml_data, Loader=yaml.SafeLoader), indent=4) @@ -286,7 +358,7 @@ def remove_enclosing_key(self, key): return self.generated_data[key] def filter_data(self, keys): - self.generated_data = {key: self.generated_data[key] for key in keys if key in self.generated_data} + self.generated_data = OrderedDict({key: self.generated_data[key] for key in keys if key in self.generated_data}) def exclude_keys(self, keys): for key in keys: @@ -294,7 +366,7 @@ def exclude_keys(self, keys): try: logger.info("Excluding key %s", key) del self.generated_data[key] - except KeyNotFound: + except KeyError: logger.info("Excluded key %s not found or already removed", key) def add_dynamic_data(self): diff --git a/himl/config_merger.py b/himl/config_merger.py index 91c09e9f..b389305e 100755 --- a/himl/config_merger.py +++ b/himl/config_merger.py @@ -23,6 +23,7 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + class Loader(yaml.SafeLoader): """ Overloading the default YAML Loader by adding the custom include tag @@ -39,15 +40,25 @@ def include(self, node): """ Method implementing the custom include tag logic that will grab a value from a yaml key in a specified folder at a specified path - :param node: String containing the path and key variables + :param node: String containing the path and optionally key variables :return: String values representing the extracted value for the specified path, key combination under node """ - path, key = self.construct_yaml_str(node).split(" ") - filename = os.path.join(self.ROOT_DIR, path) - - with open(filename, 'r') as f: - yaml_structure = yaml.load(f, Loader) - return self.__traverse_path(path=key, yaml_dict=yaml_structure) + node_value = self.construct_yaml_str(node) + parts = node_value.split(" ", 1) # Split into at most 2 parts + + if len(parts) == 1: + # Only path provided, return entire file content + path = parts[0] + filename = os.path.join(self.ROOT_DIR, path) + with open(filename, 'r') as f: + return yaml.load(f, Loader=Loader) # nosec B506 # Custom Loader inherits from SafeLoader + else: + # Both path and key provided + path, key = parts + filename = os.path.join(self.ROOT_DIR, path) + with open(filename, 
'r') as f: + yaml_structure = yaml.load(f, Loader=Loader) # nosec B506 # Custom Loader inherits from SafeLoader + return self.__traverse_path(path=key, yaml_dict=yaml_structure) def __traverse_path(self, path: str, yaml_dict: dict): """ @@ -74,6 +85,10 @@ def __traverse_path(self, path: str, yaml_dict: dict): current_key, yaml_dict)) +# Register the include constructor +Loader.add_constructor('!include', Loader.include) + + def merge_configs(directories, levels, output_dir, enable_parallel, filter_rules): """ Method for running the merge configuration logic under different formats @@ -111,7 +126,7 @@ def merge_logic(process_params): Loader.add_constructor('!include', Loader.include) # override the Yaml SafeLoader with our custom loader - yaml.SafeLoader = Loader + yaml.SafeLoader = Loader # type: ignore # for path in directories: # use the HIML deep merge functionality @@ -121,7 +136,9 @@ def merge_logic(process_params): level_values = [output.get(level) for level in levels] # create the publish path and all level_values except the last one - publish_path = os.path.join(output_dir, '') + '/'.join(level_values[:-1]) + # Filter out None values and convert to strings + valid_level_values = [str(val) for val in level_values[:-1] if val is not None] + publish_path = os.path.join(output_dir, '') + '/'.join(valid_level_values) if not os.path.exists(publish_path): os.makedirs(publish_path) @@ -143,11 +160,13 @@ def is_leaf_directory(dir, leaf_directories): return any(dir.startswith(leaf) for leaf in leaf_directories) -def get_leaf_directories(src, leaf_directories): +def get_leaf_directories(src, leaf_directories, exit_on_empty=True): """ Method for doing a deep search of directories matching either the desired leaf directories. :param src: the source path to start looking from + :param leaf_directories: list of leaf directory patterns to match + :param exit_on_empty: whether to exit when no directories are found (default: True) :return: the list of absolute paths """ directories = [] @@ -163,13 +182,14 @@ def get_leaf_directories(src, leaf_directories): else: continue - if len(directories) == 0: + if len(directories) == 0 and exit_on_empty: sys.exit("No leaf directories found") return directories -def parser_options(args): +def get_parser(): + """Create and return the argument parser""" parser = argparse.ArgumentParser() parser.add_argument('path', type=str, help='The configs directory') @@ -183,6 +203,11 @@ def parser_options(args): action='store_true', help='Process config using multiprocessing') parser.add_argument('--filter-rules-key', dest='filter_rules', default=None, type=str, help='keep these keys from the generated data, based on the configured filter key') + return parser + + +def parser_options(args): + parser = get_parser() return parser.parse_args(args) diff --git a/himl/filter_rules.py b/himl/filter_rules.py index 6776ea73..077f140c 100644 --- a/himl/filter_rules.py +++ b/himl/filter_rules.py @@ -13,7 +13,7 @@ def run(self, output): for filter in self.rules: selector = filter.get("selector", {}) - if type(selector) != dict: + if not isinstance(selector, dict): raise Exception("Filter selector must be a dictionary") if not self.match(output, selector): diff --git a/himl/inject_env.py b/himl/inject_env.py index 27cf5ee8..cb8a4053 100644 --- a/himl/inject_env.py +++ b/himl/inject_env.py @@ -26,22 +26,43 @@ def is_interpolation(self, value): def inject_env_var(self, line): """ Check if value is an interpolation and try to resolve it. 
+ Handles both full interpolations and partial interpolations. """ - if not self.is_interpolation(line): + if not isinstance(line, str): return line - # remove {{ and }} - updated_line = line[2:-2] + # Handle full interpolations (entire string is an interpolation) + if self.is_interpolation(line): + # remove {{ and }} + updated_line = line[2:-2] - # check supported function to ensure the proper format is used - if not self.is_env_interpolation(updated_line): - return line + # check supported function to ensure the proper format is used + if not self.is_env_interpolation(updated_line): + return line + + # remove env( and ) to extract the env Variable + updated_line = updated_line[4:-1] + + # If env variable is missing or not set, the output will be None + return getenv(updated_line) + + # Handle partial interpolations (interpolations within a string) + import re + pattern = r'\{\{env\([^)]+\)\}\}' - # remove env( and ) to extract the env Variable - updated_line = updated_line[4:-1] + def replace_env_var(match): + env_interpolation = match.group(0) + # Extract the variable name from {{env(VAR_NAME)}} + var_name = env_interpolation[6:-3] # Remove {{env( and )}} + if len(var_name.strip()) > 0: + return getenv(var_name) or '' + return env_interpolation - # If env variable is missing or not set, the output will be None - return getenv(updated_line) + return re.sub(pattern, replace_env_var, line) def is_env_interpolation(self, value): - return value.startswith('env(') and value.endswith(')') + if not (value.startswith('env(') and value.endswith(')')): + return False + # Extract the variable name and check it's not empty + var_name = value[4:-1] # Remove 'env(' and ')' + return len(var_name.strip()) > 0 diff --git a/himl/inject_secrets.py b/himl/inject_secrets.py index 029c2f51..f507171d 100644 --- a/himl/inject_secrets.py +++ b/himl/inject_secrets.py @@ -14,7 +14,7 @@ try: from functools import lru_cache except ImportError: - from backports.functools_lru_cache import lru_cache + from backports.functools_lru_cache import lru_cache # type: ignore class SecretInjector(object): @@ -50,14 +50,22 @@ def inject_secret(self, line): secret_type = parts[0] - secret_params = {} + secret_params: dict = {} for part in parts: if '(' not in part: secret_params[part] = None else: - key = part.split('(')[0] - value = part.split('(')[1].split(')')[0] - secret_params[key] = value + try: + key = part.split('(')[0] + value_part = part.split('(')[1] + if ')' not in value_part: + # Malformed interpolation - missing closing parenthesis + return line + value = value_part.split(')')[0] + secret_params[key] = value + except (IndexError, ValueError): + # Malformed interpolation + return line if self.resolver.supports(secret_type): return self.resolver.resolve(secret_type, secret_params) diff --git a/himl/interpolation.py b/himl/interpolation.py index 2a7b9793..ab3c5698 100644 --- a/himl/interpolation.py +++ b/himl/interpolation.py @@ -131,8 +131,7 @@ class EnvVarResolver(object): def resolve_env_vars(self, data): injector = EnvVarInjector() env_resolver = EnvVarInterpolationsResolver(injector) - env_resolver.resolve_interpolations(data) - return data + return env_resolver.resolve_interpolations(data) class DictIterator(object): @@ -259,10 +258,9 @@ def resolve(self, line, data): if data_id not in self._parse_cache: self._parse_cache[data_id] = {} self._parse_leaves_cached(data, "", self._parse_cache[data_id]) - + # Use cached results instead of rebuilding cached_results = self._parse_cache[data_id] - for key, value in 
iteritems(cached_results): placeholder = "{{" + key + "}}" if placeholder not in line: diff --git a/himl/main.py b/himl/main.py index cb4f7751..4d0a17bc 100644 --- a/himl/main.py +++ b/himl/main.py @@ -13,15 +13,17 @@ from .config_generator import ConfigProcessor from enum import Enum + class ListMergeStrategy(Enum): append = 'append' override = 'override' prepend = 'prepend' - append_unique = 'append_unique' #WARNING: currently this strategy does not support list of dicts, only list of str + append_unique = 'append_unique' # WARNING: currently this strategy does not support list of dicts, only list of str def __str__(self): return self.value + class ConfigRunner(object): def run(self, args): @@ -38,10 +40,11 @@ def do_run(self, opts): config_processor = ConfigProcessor() - config_processor.process(cwd, opts.path, filters, excluded_keys, opts.enclosing_key, opts.remove_enclosing_key, - opts.output_format, opts.print_data, opts.output_file, opts.skip_interpolation_resolving, - opts.skip_interpolation_validation, opts.skip_secrets, opts.multi_line_string, - type_strategies= [(list, [opts.merge_list_strategy.value]), (dict, ["merge"])] ) + config_processor.process(cwd, opts.path, filters, excluded_keys, opts.enclosing_key, + opts.remove_enclosing_key, opts.output_format, opts.print_data, opts.output_file, + opts.skip_interpolation_resolving, opts.skip_interpolation_validation, + opts.skip_secrets, opts.multi_line_string, + type_strategies=[(list, [opts.merge_list_strategy.value]), (dict, ["merge"])]) @staticmethod def get_parser(parser=None): @@ -73,12 +76,13 @@ def get_parser(parser=None): help='the working directory') parser.add_argument('--multi-line-string', action='store_true', help='will overwrite the global yaml dumper to use block style') - parser.add_argument('--list-merge-strategy', dest='merge_list_strategy', type=ListMergeStrategy, choices=list(ListMergeStrategy), - default='append', + parser.add_argument('--list-merge-strategy', dest='merge_list_strategy', type=ListMergeStrategy, + choices=list(ListMergeStrategy), default='append_unique', help='override default merge strategy for list') parser.add_argument('--version', action='version', version='%(prog)s v{version}'.format(version="0.17.0"), help='print himl version') return parser + def run(args=None): ConfigRunner().run(args) diff --git a/himl/python_compat.py b/himl/python_compat.py index 0b1dd083..fa0ce444 100644 --- a/himl/python_compat.py +++ b/himl/python_compat.py @@ -8,17 +8,14 @@ # OF ANY KIND, either express or implied. See the License for the specific language # governing permissions and limitations under the License. 
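The `--list-merge-strategy` default change above (from `append` to `append_unique`, matching the breaking-change note in the README) decides how two lists combine during the deep merge. himl wires these strategies through `deepmerge`; the hedged sketch below only illustrates the intended semantics of the four `ListMergeStrategy` values:

```python
def merge_lists(base: list, new: list, strategy: str = "append_unique") -> list:
    """Illustrative semantics for the ListMergeStrategy values above."""
    if strategy == "override":
        return list(new)                # new list wins entirely
    if strategy == "prepend":
        return list(new) + list(base)   # new entries go first
    if strategy == "append":
        return list(base) + list(new)   # old default: duplicates are kept
    if strategy == "append_unique":
        # new default: append only entries not already present
        # (per the WARNING above, intended for lists of strings)
        merged = list(base)
        merged.extend(item for item in new if item not in merged)
        return merged
    raise ValueError("unknown strategy: %s" % strategy)

print(merge_lists(["a", "b"], ["b", "c"], "append"))         # ['a', 'b', 'b', 'c']
print(merge_lists(["a", "b"], ["b", "c"], "append_unique"))  # ['a', 'b', 'c']
```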
-import sys +# Python 3.8+ only (Python 2 is EOL) +PY3 = True -PY3 = sys.version_info >= (3, 0) -if PY3: - iteritems = lambda d: iter(d.items()) - integer_types = (int,) - string_types = (str,) - primitive_types = (str, int, float, bool) -else: - iteritems = lambda d: d.iteritems() - integer_types = (int, long) - string_types = (str, unicode) - primitive_types = (str, unicode, int, long, float, bool) +def iteritems(d): + return iter(d.items()) + + +integer_types = (int,) +string_types = (str,) +primitive_types = (str, int, float, bool) diff --git a/himl/remote_state.py b/himl/remote_state.py index e1bd5e22..2c7e243e 100644 --- a/himl/remote_state.py +++ b/himl/remote_state.py @@ -10,14 +10,21 @@ import json +# boto3 is imported only when needed to avoid requiring it for basic himl functionality try: import boto3 -except ImportError as e: - raise Exception('Error while trying to read remote state, package "boto3" is required and cannot be imported: %s' % e) +except ImportError: + # boto3 will be None, and we'll raise an error only when S3 functionality is used + boto3 = None + class S3TerraformRemoteStateRetriever: @staticmethod def get_s3_client(bucket_name, bucket_key, boto_profile): + if boto3 is None: + raise ImportError('boto3 package is required for S3 remote state functionality. ' + 'Install with: pip install himl[s3]') + session = boto3.session.Session(profile_name=boto_profile) client = session.client('s3') try: @@ -27,7 +34,7 @@ def get_s3_client(bucket_name, bucket_key, boto_profile): return [] def get_dynamic_data(self, remote_states): - generated_data = {"outputs": {}} + generated_data: dict = {"outputs": {}} for state in remote_states: bucket_object = self.get_s3_client(state["s3_bucket"], state["s3_key"], state["aws_profile"]) if "outputs" in bucket_object: diff --git a/himl/secret_resolvers.py b/himl/secret_resolvers.py index b7a1086a..e27ba570 100644 --- a/himl/secret_resolvers.py +++ b/himl/secret_resolvers.py @@ -8,15 +8,16 @@ # OF ANY KIND, either express or implied. See the License for the specific language # governing permissions and limitations under the License. 
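The `remote_state.py` change above swaps an import-time failure for a lazy guard: `boto3` is set to `None` when missing, and an error is raised only when the S3 functionality is actually used. The same pattern in isolation (the function name here is hypothetical; `himl[s3]` is the extra defined in `pyproject.toml`):

```python
try:
    import boto3
except ImportError:
    boto3 = None  # optional dependency; only needed for S3 remote state

def read_remote_state(bucket: str, key: str, profile: str) -> bytes:
    """Fetch a Terraform state object from S3, failing late if boto3 is absent."""
    if boto3 is None:
        raise ImportError("boto3 package is required for S3 remote state "
                          "functionality. Install with: pip install himl[s3]")
    session = boto3.session.Session(profile_name=profile)
    return session.client("s3").get_object(Bucket=bucket, Key=key)["Body"].read()
```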
-import logging import os +import sys + class SecretResolver: def supports(self, secret_type): - return False + raise NotImplementedError("Subclasses must implement supports method") def resolve(self, secret_type, secret_params): - return None + raise NotImplementedError("Subclasses must implement resolve method") def get_param_or_exception(self, key, params): if key not in params: @@ -29,7 +30,7 @@ def __init__(self, default_aws_profile=None): self.default_aws_profile = default_aws_profile def supports(self, secret_type): - return "boto3" in sys.modules && secret_type == "ssm" + return "boto3" in sys.modules and secret_type == "ssm" def resolve(self, secret_type, secret_params): aws_profile = secret_params.get("aws_profile", self.default_aws_profile) @@ -49,7 +50,7 @@ def __init__(self, default_aws_profile=None): self.default_aws_profile = default_aws_profile def supports(self, secret_type): - return "boto3" in sys.modules && secret_type == "s3" + return "boto3" in sys.modules and secret_type == "s3" def resolve(self, secret_type, secret_params): aws_profile = secret_params.get("aws_profile", self.default_aws_profile) @@ -66,19 +67,21 @@ def resolve(self, secret_type, secret_params): s3 = SimpleS3(aws_profile, region_name) return s3.get(bucket, path, base64Encode) + class SopsSecretResolver(SecretResolver): def supports(self, secret_type): return secret_type == "sops" - + def resolve(self, secret_type, secret_params): from .simplesops import SimpleSops file = self.get_param_or_exception("secret_file", secret_params) sops = SimpleSops() return sops.get(secret_file=file, secret_key=secret_params.get("secret_key")) + class VaultSecretResolver(SecretResolver): def supports(self, secret_type): - return "hvac" in sys.modules && secret_type == "vault" + return "hvac" in sys.modules and secret_type == "vault" def resolve(self, secret_type, secret_params): from .simplevault import SimpleVault @@ -114,4 +117,5 @@ def resolve(self, secret_type, secret_params): if resolver.supports(secret_type): return resolver.resolve(secret_type, secret_params) - raise Exception("Could not resolve secret type '{}' with params {}. Check if you installed the required 3rd party modules.".format(secret_type, secret_params)) + raise Exception("Could not resolve secret type '{}' with params {}. 
" + "Check if you installed the required 3rd party modules.".format(secret_type, secret_params)) diff --git a/himl/simples3.py b/himl/simples3.py index b420a64d..2ad37971 100644 --- a/himl/simples3.py +++ b/himl/simples3.py @@ -10,12 +10,12 @@ import boto3 import logging -import os from botocore.exceptions import ClientError logger = logging.getLogger(__name__) + class SimpleS3(object): def __init__(self, aws_profile, region_name): self.aws_profile = aws_profile @@ -23,14 +23,14 @@ def __init__(self, aws_profile, region_name): def get(self, bucket_name, bucket_key, base64Encode=False): try: - logger.info("Resolving S3 object for bucket %s, key '%s' on profile %s in region %s", - bucket_name, bucket_key, self.aws_profile, self.region_name) + logger.info("Resolving S3 object for bucket %s, key '%s' on profile %s in region %s", + bucket_name, bucket_key, self.aws_profile, self.region_name) client = self.get_s3_client() bucket_object = client.get_object(Bucket=bucket_name, Key=bucket_key)["Body"].read() return self.parse_data(bucket_object, base64Encode) except ClientError as e: raise Exception( - 'Error while trying to read S3 value for bucket_name %s, bucket_key: %s - %s' + 'Error while trying to read S3 value for bucket_name %s, bucket_key: %s - %s' % (bucket_name, bucket_key, e.response['Error']['Code'])) def parse_data(self, bucket_object, base64Encode): diff --git a/himl/simplesops.py b/himl/simplesops.py index 8fdf90fc..741af4ec 100644 --- a/himl/simplesops.py +++ b/himl/simplesops.py @@ -5,7 +5,9 @@ from __future__ import absolute_import, division, print_function from functools import lru_cache -import os, logging, yaml +import os +import logging +import yaml from subprocess import Popen, PIPE @@ -46,16 +48,21 @@ class SopsError(Exception): """Extend Exception class with sops specific information""" def __init__(self, filename, exit_code, message, decryption=True): + self.filename = filename + self.exit_code = exit_code + self.stderr = message + self.decryption = decryption + if exit_code in SOPS_ERROR_CODES: exception_name = SOPS_ERROR_CODES[exit_code] - message = "error with file %s: %s exited with code %d: %s" % ( + formatted_message = "error with file %s: %s exited with code %d: %s" % ( filename, exception_name, exit_code, message, ) else: - message = ( + formatted_message = ( "could not %s file %s; Unknown sops error code: %s; message: %s" % ( "decrypt" if decryption else "encrypt", @@ -64,17 +71,18 @@ def __init__(self, filename, exit_code, message, decryption=True): message, ) ) - super(SopsError, self).__init__(message) + super(SopsError, self).__init__(formatted_message) class Sops: """Utility class to perform sops CLI actions""" + @staticmethod @lru_cache(maxsize=2048) def decrypt( - encrypted_file, - decode_output=True, - rstrip=True, + encrypted_file: str, + decode_output: bool = True, + rstrip: bool = True, ): command = ["sops"] env = os.environ.copy() @@ -87,20 +95,22 @@ def decrypt( stderr=PIPE, env=env, ) - (output, err) = process.communicate() + (output_bytes, err) = process.communicate() exit_code = process.returncode if decode_output: # output is binary, we want UTF-8 string - output = output.decode("utf-8", errors="surrogate_or_strict") + output: str = output_bytes.decode("utf-8", errors="surrogate_or_strict") # the process output is the decrypted secret; be cautious + else: + output = output_bytes.decode("utf-8", errors="surrogate_or_strict") if exit_code != 0: raise SopsError(encrypted_file, exit_code, err, decryption=True) if rstrip: output = output.rstrip() 
return yaml.full_load(output) - + def get_keys(self, secret_file, secret_key): result = Sops.decrypt(secret_file) secret_key_path = secret_key.strip("[]") @@ -108,8 +118,9 @@ def get_keys(self, secret_file, secret_key): try: for key in keys: result = result[key] - except KeyError as e: - raise SopsError(secret_file, 128, "Encountered KeyError parsing yaml for key: %s" % secret_key, decryption=True) + except KeyError: + raise SopsError(secret_file, 128, "Encountered KeyError parsing yaml for key: %s" % secret_key, + decryption=True) return result diff --git a/himl/simplessm.py b/himl/simplessm.py index d5bb3e0d..d1fac291 100644 --- a/himl/simplessm.py +++ b/himl/simplessm.py @@ -25,7 +25,8 @@ def __init__(self, aws_profile, region_name): def get(self, key): client = self.get_ssm_client() try: - logger.info("Resolving SSM secret for key '%s' on profile %s in region %s", key, self.aws_profile, self.region_name) + logger.info("Resolving SSM secret for key '%s' on profile %s in region %s", key, self.aws_profile, + self.region_name) return client.get_parameter(Name=key, WithDecryption=True).get("Parameter").get("Value") except ClientError as e: raise Exception( @@ -39,6 +40,7 @@ def get_ssm_client(self): def release_ssm_client(self): if self.initial_aws_profile is None: - del os.environ['AWS_PROFILE'] + if 'AWS_PROFILE' in os.environ: + del os.environ['AWS_PROFILE'] else: os.environ['AWS_PROFILE'] = self.initial_aws_profile diff --git a/himl/simplevault.py b/himl/simplevault.py index 3b465cd6..63303825 100644 --- a/himl/simplevault.py +++ b/himl/simplevault.py @@ -10,10 +10,26 @@ import logging import os -from distutils.util import strtobool import hvac + +def strtobool(val): + """Convert a string representation of truth to true (1) or false (0). + + True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values + are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if + 'val' is anything else. 
+ """ + val = val.lower() + if val in ('y', 'yes', 't', 'true', 'on', '1'): + return 1 + elif val in ('n', 'no', 'f', 'false', 'off', '0'): + return 0 + else: + raise ValueError("invalid truth value %r" % (val,)) + + logger = logging.getLogger(__name__) @@ -50,7 +66,7 @@ def get_vault_client(self): ) assert client.is_authenticated() logger.info("Vault LDAP authenticated") - except Exception as e: + except Exception: raise Exception("Error authenticating Vault over LDAP") return client diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..3ba2f2d0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,163 @@ +[build-system] +requires = ["setuptools>=45", "wheel", "setuptools_scm[toml]>=6.2"] +build-backend = "setuptools.build_meta" + +[project] +name = "himl" +dynamic = ["version"] +description = "A hierarchical config using YAML in Python" +readme = "README.md" +license = "Apache-2.0" +authors = [ + {name = "Adobe", email = "noreply@adobe.com"} +] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: System :: Systems Administration", + "Topic :: Utilities", +] +requires-python = ">=3.9" +dependencies = [ + "PyYAML==6.0.1", + "deepmerge==2.0", + "pathlib2==2.3.7.post1", +] + +[project.optional-dependencies] +s3 = ["boto3==1.34.6"] +aws = ["boto3==1.34.6"] # Alias for s3 for backward compatibility +vault = ["hvac==1.2.1"] +all = ["boto3==1.34.6", "hvac==1.2.1"] +dev = [ + "pytest==8.4.2", + "pytest-cov==7.0.0", + "pytest-mock==3.15.1", + "pytest-xdist==3.8.0", + "coverage[toml]==7.10.7", + "black==25.9.0", + "flake8==7.3.0", + "mypy==1.18.2", + "types-PyYAML==6.0.12.20250915", + "boto3==1.34.6", + "hvac==1.2.1", + "bump-my-version==1.2.4", + "build==1.3.0", + "twine==6.2.0", +] + +[project.scripts] +himl = "himl.main:run" +himl-config-merger = "himl.config_merger:run" + +[project.urls] +Homepage = "https://github.com/adobe/himl" +Repository = "https://github.com/adobe/himl" +Documentation = "https://github.com/adobe/himl#readme" +"Bug Tracker" = "https://github.com/adobe/himl/issues" + +[tool.setuptools_scm] +write_to = "himl/_version.py" + +[tool.bumpversion] +current_version = "0.17.0" +parse = "(?P\\d+)\\.(?P\\d+)\\.(?P\\d+)" +serialize = ["{major}.{minor}.{patch}"] +search = "{current_version}" +replace = "{new_version}" +regex = false +ignore_missing_version = false +tag = true +sign_tags = false +tag_name = "{new_version}" +tag_message = "Bump version: {current_version} → {new_version}" +allow_dirty = false +commit = true +message = "[RELEASE] - Release version {new_version}" +commit_args = "" + +[[tool.bumpversion.files]] +filename = "README.md" +search = "Latest version is: {current_version}" +replace = "Latest version is: {new_version}" + +[[tool.bumpversion.files]] +filename = "himl/main.py" +search = "version='%(prog)s v{{version}}'.format(version=\"{current_version}\")" +replace = "version='%(prog)s v{{version}}'.format(version=\"{new_version}\")" + +[tool.coverage.run] +source = ["himl"] +omit = [ + "*/tests/*", + "*/test_*", + "himl/_version.py", +] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "if 
self.debug:", + "if settings.DEBUG", + "raise AssertionError", + "raise NotImplementedError", + "if 0:", + "if __name__ == .__main__.:", + "class .*\\bProtocol\\):", + "@(abc\\.)?abstractmethod", +] + +[tool.black] +line-length = 120 +target-version = ['py38'] +include = '\.pyi?$' +extend-exclude = ''' +/( + # directories + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | build + | dist +)/ +''' + + + +[tool.mypy] +python_version = "3.13" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = false +disallow_incomplete_defs = false +check_untyped_defs = true +disallow_untyped_decorators = false +no_implicit_optional = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = true +warn_unreachable = true +strict_equality = true + +[[tool.mypy.overrides]] +module = [ + "boto3.*", + "botocore.*", + "hvac.*", + "deepmerge.*", + "pathlib2.*", +] +ignore_missing_imports = true diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..e94f1b3a --- /dev/null +++ b/pytest.ini @@ -0,0 +1,21 @@ +[tool:pytest] +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* +addopts = + -v + --tb=short + --strict-markers + --disable-warnings + --color=yes +markers = + unit: Unit tests + integration: Integration tests + slow: Slow running tests + aws: Tests requiring AWS credentials + vault: Tests requiring Vault setup + sops: Tests requiring SOPS setup +filterwarnings = + ignore::DeprecationWarning + ignore::PendingDeprecationWarning diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index c882422d..00000000 --- a/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -setuptools~=80.9.0 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 04d45c4a..00000000 --- a/setup.cfg +++ /dev/null @@ -1,5 +0,0 @@ -[bdist_wheel] -universal=1 - -[metadata] -description_file = README.md diff --git a/setup.py b/setup.py deleted file mode 100644 index 2e11f750..00000000 --- a/setup.py +++ /dev/null @@ -1,68 +0,0 @@ -try: - from setuptools import setup -except ImportError: - from distutils.core import setup - -with open('README.md', encoding="utf-8") as f: - _readme = f.read() - -_install_requires = [ - 'deepmerge==1.1.1', - 'lru_cache==0.2.3', - 'backports.functools_lru_cache==2.0.0', - 'pathlib2==2.3.7', - 'pyyaml==6.0.2', -] - -_extras_require = { - 's3': [ - 'boto3==1.34.6', - ], - 'vault': [ - 'hvac==1.2.1', - ], -} -_extras_require['all'] = [dep for deps in _extras_require.values() for dep in deps] - -setup( - name='himl', - version="0.17.0", - description='A hierarchical config using yaml', - long_description=_readme + '\n\n', - long_description_content_type='text/markdown', - url='https://github.com/adobe/himl', - author='Adobe', - author_email='noreply@adobe.com', - python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*", - license='Apache2', - classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Environment :: Web Environment', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: OS Independent', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3.12', - 'Programming Language :: Python :: Implementation :: CPython', - 
'Programming Language :: Python :: Implementation :: PyPy', - 'Topic :: Internet :: WWW/HTTP :: Dynamic Content', - 'Topic :: Software Development :: Libraries :: Python Modules', - 'Topic :: Text Processing :: Markup :: HTML' - ], - packages=['himl'], - include_package_data=True, - install_requires=_install_requires, - extras_require=_extras_require, - entry_points={ - 'console_scripts': [ - 'himl = himl.main:run', - 'himl-config-merger = himl.config_merger:run' - ] - } -) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..5cf85796 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,277 @@ +# Copyright 2019 Adobe. All rights reserved. +# This file is licensed to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may obtain a copy +# of the License at http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS +# OF ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +import os +import tempfile +import shutil +import pytest +import yaml +from unittest.mock import patch + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for tests""" + temp_dir = tempfile.mkdtemp() + yield temp_dir + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + + +@pytest.fixture +def yaml_creator(temp_dir): + """Factory fixture for creating YAML files in temp directory""" + def create_yaml(path, content): + full_path = os.path.join(temp_dir, path) + os.makedirs(os.path.dirname(full_path), exist_ok=True) + with open(full_path, 'w') as f: + yaml.dump(content, f) + return full_path + return create_yaml + + +@pytest.fixture +def sample_config(): + """Sample configuration data for testing""" + return { + 'env': 'test', + 'database': { + 'host': 'localhost', + 'port': 5432, + 'name': 'testdb' + }, + 'features': ['feature1', 'feature2'], + 'debug': True + } + + +@pytest.fixture +def hierarchical_config(yaml_creator): + """Create a hierarchical config structure for testing""" + # Default config + default_config = { + 'env': 'default', + 'database': {'host': 'localhost', 'port': 5432}, + 'features': ['default_feature'], + 'timeout': 30 + } + yaml_creator('default.yaml', default_config) + + # Environment-specific config + env_config = { + 'env': 'production', + 'database': {'host': 'prod-db.example.com'}, + 'features': ['prod_feature'], + 'ssl_enabled': True + } + yaml_creator('production/env.yaml', env_config) + + # Region-specific config + region_config = { + 'region': 'us-east-1', + 'database': {'region': 'us-east-1'}, + 'cdn_endpoint': 'https://cdn-us-east-1.example.com' + } + yaml_creator('production/us-east-1/region.yaml', region_config) + + return { + 'default': default_config, + 'env': env_config, + 'region': region_config + } + + +@pytest.fixture +def mock_aws_credentials(): + """Mock AWS credentials for testing""" + with patch.dict(os.environ, { + 'AWS_ACCESS_KEY_ID': 'test_access_key', + 'AWS_SECRET_ACCESS_KEY': 'test_secret_key', + 'AWS_DEFAULT_REGION': 'us-east-1' + }): + yield + + +@pytest.fixture +def mock_vault_env(): + """Mock Vault environment variables for testing""" + with patch.dict(os.environ, { + 'VAULT_ADDR': 'https://vault.example.com', + 'VAULT_TOKEN': 'test_token', + 'VAULT_USERNAME': 'test_user', + 
'VAULT_PASSWORD': 'test_password',
+        'VAULT_ROLE': 'test_role',
+        'VAULT_MOUNT_POINT': 'kv'
+    }):
+        yield
+
+
+@pytest.fixture
+def interpolation_config():
+    """Configuration with interpolations for testing"""
+    return {
+        'env': 'production',
+        'region': 'us-east-1',
+        'app_name': 'myapp',
+        'database_url': 'db-{{env}}.example.com',
+        'full_name': '{{app_name}}-{{env}}-{{region}}',
+        'config': {
+            'environment': '{{env}}',
+            'nested_interpolation': 'Environment is {{env}}'
+        },
+        'reference_config': '{{config}}'
+    }
+
+
+@pytest.fixture
+def secret_config():
+    """Configuration with secret interpolations for testing"""
+    return {
+        'database': {
+            'password': '{{ssm.path(/app/db/password).aws_profile(prod)}}',
+            'api_key': '{{vault.path(/secret/api).key(key)}}'
+        },
+        's3_config': {
+            'credentials': '{{s3.bucket(secrets).path(creds.json).aws_profile(prod)}}'
+        },
+        'sops_secret': '{{sops.secret_file(/path/secrets.yaml).secret_key(db_password)}}'
+    }
+
+
+@pytest.fixture
+def filter_config():
+    """Configuration with filter rules for testing"""
+    return {
+        'env': 'dev',
+        'cluster': 'cluster1',
+        'region': 'us-east-1',
+        'keep_this': 'should_remain',
+        'remove_this': 'should_be_filtered',
+        'keep_pattern_match': 'should_remain',
+        'tags': {
+            'cost_center': '123',
+            'team': 'backend'
+        },
+        '_filters': [
+            {
+                'selector': {'env': 'dev'},
+                'keys': {
+                    'values': ['keep_this', 'tags'],
+                    'regex': 'keep_pattern_.*'
+                }
+            }
+        ]
+    }
+
+
+@pytest.fixture(autouse=True)
+def clean_environment():
+    """Clean environment variables before each test"""
+    # Store original environment
+    original_env = os.environ.copy()
+
+    # Clean up test-related environment variables
+    test_vars = [
+        'AWS_PROFILE', 'AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY',
+        'VAULT_ADDR', 'VAULT_TOKEN', 'VAULT_USERNAME', 'VAULT_PASSWORD',
+        'VAULT_ROLE', 'VAULT_MOUNT_POINT'
+    ]
+
+    for var in test_vars:
+        if var in os.environ:
+            del os.environ[var]
+
+    yield
+
+    # Restore original environment
+    os.environ.clear()
+    os.environ.update(original_env)
+
+
+@pytest.fixture
+def mock_terraform_state():
+    """Mock Terraform state for testing"""
+    return {
+        'version': 4,
+        'terraform_version': '1.0.0',
+        'outputs': {
+            'vpc_id': {
+                'value': 'vpc-12345',
+                'type': 'string'
+            },
+            'subnet_ids': {
+                'value': ['subnet-1', 'subnet-2'],
+                'type': ['list', 'string']
+            },
+            'database_config': {
+                'value': {
+                    'endpoint': 'db.example.com',
+                    'port': 5432,
+                    'database_name': 'myapp'
+                },
+                'type': ['object', {
+                    'endpoint': 'string',
+                    'port': 'number',
+                    'database_name': 'string'
+                }]
+            }
+        }
+    }
+
+
+# Pytest markers for categorizing tests
+def pytest_configure(config):
+    """Configure pytest with custom markers"""
+    config.addinivalue_line(
+        "markers", "unit: mark test as a unit test"
+    )
+    config.addinivalue_line(
+        "markers", "integration: mark test as an integration test"
+    )
+    config.addinivalue_line(
+        "markers", "slow: mark test as slow running"
+    )
+    config.addinivalue_line(
+        "markers", "aws: mark test as requiring AWS credentials"
+    )
+    config.addinivalue_line(
+        "markers", "vault: mark test as requiring Vault setup"
+    )
+    config.addinivalue_line(
+        "markers", "sops: mark test as requiring SOPS setup"
+    )
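+
+
+# Usage sketch (illustrative; the test name below is hypothetical): a test
+# opts into one of the markers registered above and can then be selected or
+# excluded from the command line, e.g. `pytest -m "not aws"`:
+#
+#     @pytest.mark.aws
+#     def test_bucket_lookup(mock_aws_credentials):
+#         ...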
available") + for item in items: + if "aws" in item.keywords: + item.add_marker(skip_aws) + + # Skip Vault tests if hvac is not available + if 'hvac' not in sys.modules: + skip_vault = pytest.mark.skip(reason="hvac not available") + for item in items: + if "vault" in item.keywords: + item.add_marker(skip_vault) + + # Skip SOPS tests if sops binary is not available + import shutil + if not shutil.which('sops'): + skip_sops = pytest.mark.skip(reason="sops binary not available") + for item in items: + if "sops" in item.keywords: + item.add_marker(skip_sops) diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 00000000..474648d3 --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1,16 @@ +# Test dependencies for himl +pytest>=6.0.0 +pytest-cov>=2.10.0 +pytest-mock>=3.0.0 +pytest-xdist>=2.0.0 +coverage>=5.0.0 + +# Optional dependencies for testing specific features +boto3>=1.34.6 +hvac>=1.2.1 + +# Development dependencies +black>=21.0.0 +flake8>=3.8.0 +mypy>=0.800 +types-PyYAML>=6.0.0 diff --git a/tests/test_config_generator.py b/tests/test_config_generator.py index 8ba8b982..91bb3fad 100644 --- a/tests/test_config_generator.py +++ b/tests/test_config_generator.py @@ -1,13 +1,363 @@ -#Copyright 2019 Adobe. All rights reserved. -#This file is licensed to you under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. You may obtain a copy -#of the License at http://www.apache.org/licenses/LICENSE-2.0 +# Copyright 2025 Adobe. All rights reserved. +# This file is licensed to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may obtain a copy +# of the License at http://www.apache.org/licenses/LICENSE-2.0 -#Unless required by applicable law or agreed to in writing, software distributed under -#the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS -#OF ANY KIND, either express or implied. See the License for the specific language -#governing permissions and limitations under the License. +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS +# OF ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. 
+import os +import tempfile +import shutil +import pytest +import yaml +from collections import OrderedDict -def test_something(): - assert 1 == 1 + +from himl import ConfigProcessor, ConfigGenerator + + +class TestConfigProcessor: + """Test cases for ConfigProcessor class""" + + def setup_method(self): + """Set up test fixtures""" + self.temp_dir = tempfile.mkdtemp() + self.config_processor = ConfigProcessor() + + def teardown_method(self): + """Clean up test fixtures""" + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + + def create_test_yaml(self, path, content): + """Helper to create test YAML files""" + full_path = os.path.join(self.temp_dir, path) + os.makedirs(os.path.dirname(full_path), exist_ok=True) + with open(full_path, 'w') as f: + yaml.dump(content, f) + return full_path + + def test_simple_config_processing(self): + """Test basic config processing with single file""" + # Create a simple config file + config_data = {'env': 'test', 'debug': True, 'port': 8080} + self.create_test_yaml('config.yaml', config_data) + + # Process the config + result = self.config_processor.process( + cwd=self.temp_dir, + path='config.yaml', + print_data=False + ) + + assert result == config_data + + def test_hierarchical_config_merging(self): + """Test hierarchical config merging""" + # Create default config + default_config = { + 'env': 'default', + 'database': {'host': 'localhost', 'port': 5432}, + 'features': ['feature1', 'feature2'] + } + self.create_test_yaml('default.yaml', default_config) + + # Create environment-specific config + env_config = { + 'env': 'production', + 'database': {'host': 'prod-db.example.com'}, + 'features': ['feature3'] + } + os.makedirs(os.path.join(self.temp_dir, 'production'), exist_ok=True) + self.create_test_yaml('production/env.yaml', env_config) + + # Process the hierarchical config + result = self.config_processor.process( + cwd=self.temp_dir, + path='production', + print_data=False + ) + + # Verify deep merge occurred + assert result['env'] == 'production' + assert result['database']['host'] == 'prod-db.example.com' + assert result['database']['port'] == 5432 # From default + assert 'feature1' in result['features'] + assert 'feature2' in result['features'] + assert 'feature3' in result['features'] + + def test_config_filtering(self): + """Test config filtering functionality""" + config_data = { + 'env': 'test', + 'database': {'host': 'localhost'}, + 'secret_key': 'should_be_filtered', + 'public_key': 'should_remain' + } + self.create_test_yaml('config.yaml', config_data) + + result = self.config_processor.process( + cwd=self.temp_dir, + path='config.yaml', + filters=['env', 'database', 'public_key'], + print_data=False + ) + + assert 'env' in result + assert 'database' in result + assert 'public_key' in result + assert 'secret_key' not in result + + def test_config_exclusion(self): + """Test config key exclusion functionality""" + config_data = { + 'env': 'test', + 'database': {'host': 'localhost'}, + 'secret_key': 'should_be_excluded', + 'public_key': 'should_remain' + } + self.create_test_yaml('config.yaml', config_data) + + result = self.config_processor.process( + cwd=self.temp_dir, + path='config.yaml', + exclude_keys=['secret_key'], + print_data=False + ) + + assert 'env' in result + assert 'database' in result + assert 'public_key' in result + assert 'secret_key' not in result + + def test_enclosing_key_addition(self): + """Test adding enclosing key to config""" + config_data = {'env': 'test', 'debug': True} + 
self.create_test_yaml('config.yaml', config_data) + + result = self.config_processor.process( + cwd=self.temp_dir, + path='config.yaml', + enclosing_key='application', + print_data=False + ) + + assert 'application' in result + assert result['application'] == config_data + + def test_enclosing_key_removal(self): + """Test removing enclosing key from config""" + config_data = { + 'application': { + 'env': 'test', + 'debug': True + }, + 'other_key': 'value' + } + self.create_test_yaml('config.yaml', config_data) + + result = self.config_processor.process( + cwd=self.temp_dir, + path='config.yaml', + remove_enclosing_key='application', + print_data=False + ) + + assert result == config_data['application'] + + def test_output_formats(self): + """Test different output formats""" + config_data = {'env': 'test', 'debug': True} + self.create_test_yaml('config.yaml', config_data) + + # Test YAML output + yaml_result = self.config_processor.process( + cwd=self.temp_dir, + path='config.yaml', + output_format='yaml', + print_data=False + ) + assert yaml_result == config_data + + # Test JSON output + json_result = self.config_processor.process( + cwd=self.temp_dir, + path='config.yaml', + output_format='json', + print_data=False + ) + assert json_result == config_data + + +class TestConfigGenerator: + """Test cases for ConfigGenerator class""" + + def setup_method(self): + """Set up test fixtures""" + self.temp_dir = tempfile.mkdtemp() + + def teardown_method(self): + """Clean up test fixtures""" + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + + def create_test_yaml(self, path, content): + """Helper to create test YAML files""" + full_path = os.path.join(self.temp_dir, path) + os.makedirs(os.path.dirname(full_path), exist_ok=True) + with open(full_path, 'w') as f: + yaml.dump(content, f) + return full_path + + def test_config_generator_initialization(self): + """Test ConfigGenerator initialization""" + generator = ConfigGenerator( + cwd=self.temp_dir, + path='test', + multi_line_string=False, + type_strategies=[(list, ["append_unique"]), (dict, ["merge"])], + fallback_strategies=["override"], + type_conflict_strategies=["override"] + ) + + assert generator.cwd == self.temp_dir + assert generator.path == 'test' + assert isinstance(generator.generated_data, OrderedDict) + + def test_hierarchy_generation(self): + """Test hierarchy generation from directory structure""" + # Create a hierarchy structure + self.create_test_yaml('default.yaml', {'env': 'default'}) + os.makedirs(os.path.join(self.temp_dir, 'production'), exist_ok=True) + self.create_test_yaml('production/env.yaml', {'env': 'production'}) + + generator = ConfigGenerator( + cwd=self.temp_dir, + path='production', + multi_line_string=False, + type_strategies=[(list, ["append_unique"]), (dict, ["merge"])], + fallback_strategies=["override"], + type_conflict_strategies=["override"] + ) + + hierarchy = generator.generate_hierarchy() + assert len(hierarchy) >= 1 + assert any('default.yaml' in str(files) for files in hierarchy) + + def test_yaml_content_loading(self): + """Test YAML content loading""" + config_data = {'test_key': 'test_value', 'number': 42} + yaml_file = self.create_test_yaml('test.yaml', config_data) + + generator = ConfigGenerator( + cwd=self.temp_dir, + path='test', + multi_line_string=False, + type_strategies=[(list, ["append_unique"]), (dict, ["merge"])], + fallback_strategies=["override"], + type_conflict_strategies=["override"] + ) + + content = generator.yaml_get_content(yaml_file) + assert content == 
config_data + + def test_yaml_merging(self): + """Test YAML merging functionality""" + generator = ConfigGenerator( + cwd=self.temp_dir, + path='test', + multi_line_string=False, + type_strategies=[(list, ["append_unique"]), (dict, ["merge"])], + fallback_strategies=["override"], + type_conflict_strategies=["override"] + ) + + base_config = OrderedDict([('env', 'base'), ('features', ['f1', 'f2'])]) + new_config = {'env': 'new', 'features': ['f3'], 'new_key': 'value'} + + generator.merge_yamls( + base_config, + new_config, + [(list, ["append_unique"]), (dict, ["merge"])], + ["override"], + ["override"] + ) + + assert base_config['env'] == 'new' + assert 'f1' in base_config['features'] + assert 'f2' in base_config['features'] + assert 'f3' in base_config['features'] + assert base_config['new_key'] == 'value' + + def test_output_data_yaml(self): + """Test YAML output formatting""" + generator = ConfigGenerator( + cwd=self.temp_dir, + path='test', + multi_line_string=False, + type_strategies=[(list, ["append_unique"]), (dict, ["merge"])], + fallback_strategies=["override"], + type_conflict_strategies=["override"] + ) + + test_data = {'env': 'test', 'debug': True, 'port': 8080} + yaml_output = generator.output_data(test_data, 'yaml') + + # Parse the YAML output back to verify it's valid + parsed_data = yaml.safe_load(yaml_output) + assert parsed_data == test_data + + def test_output_data_json(self): + """Test JSON output formatting""" + generator = ConfigGenerator( + cwd=self.temp_dir, + path='test', + multi_line_string=False, + type_strategies=[(list, ["append_unique"]), (dict, ["merge"])], + fallback_strategies=["override"], + type_conflict_strategies=["override"] + ) + + test_data = {'env': 'test', 'debug': True, 'port': 8080} + json_output = generator.output_data(test_data, 'json') + + # Parse the JSON output back to verify it's valid + import json + parsed_data = json.loads(json_output) + assert parsed_data == test_data + + def test_invalid_output_format(self): + """Test handling of invalid output format""" + generator = ConfigGenerator( + cwd=self.temp_dir, + path='test', + multi_line_string=False, + type_strategies=[(list, ["append_unique"]), (dict, ["merge"])], + fallback_strategies=["override"], + type_conflict_strategies=["override"] + ) + + test_data = {'env': 'test'} + + with pytest.raises(Exception) as exc_info: + generator.output_data(test_data, 'invalid_format') + + assert "Unknown output format" in str(exc_info.value) + + def test_values_from_dir_path(self): + """Test extracting values from directory path""" + generator = ConfigGenerator( + cwd=self.temp_dir, + path='env=production/region=us-east-1/cluster=web', + multi_line_string=False, + type_strategies=[(list, ["append_unique"]), (dict, ["merge"])], + fallback_strategies=["override"], + type_conflict_strategies=["override"] + ) + + values = generator.get_values_from_dir_path() + expected = {'env': 'production', 'region': 'us-east-1', 'cluster': 'web'} + assert values == expected diff --git a/tests/test_config_merger.py b/tests/test_config_merger.py new file mode 100644 index 00000000..f9369122 --- /dev/null +++ b/tests/test_config_merger.py @@ -0,0 +1,357 @@ +# Copyright 2019 Adobe. All rights reserved. +# This file is licensed to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
You may obtain a copy +# of the License at http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS +# OF ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +import os +import tempfile +import shutil +import pytest +import yaml +from unittest.mock import patch, MagicMock + +from himl.config_merger import ( + Loader, merge_configs, merge_logic, get_leaf_directories, + get_parser, run +) + + +class TestLoader: + """Test custom YAML Loader with include functionality""" + + def test_loader_initialization(self): + """Test Loader initialization""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + f.write('test: value') + f.flush() + + try: + with open(f.name, 'r') as stream: + loader = Loader(stream) + assert loader._root == os.path.dirname(f.name) + finally: + os.unlink(f.name) + + def test_include_constructor(self): + """Test include constructor functionality""" + # Create a temporary file to include + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as include_file: + include_content = {'included_key': 'included_value', 'nested': {'key': 'value'}} + yaml.dump(include_content, include_file) + include_file.flush() + + try: + # Create main YAML content with include + main_content = f""" +test_key: test_value +included_data: !include {include_file.name} included_key +nested_data: !include {include_file.name} nested.key +full_data: !include {include_file.name} +""" + + # Test loading with custom loader + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as main_file: + main_file.write(main_content) + main_file.flush() + + try: + with open(main_file.name, 'r') as stream: + data = yaml.load(stream, Loader=Loader) + + assert data['test_key'] == 'test_value' + assert data['included_data'] == 'included_value' + assert data['nested_data'] == 'value' + assert data['full_data'] == include_content + finally: + os.unlink(main_file.name) + finally: + os.unlink(include_file.name) + + +class TestConfigMergerFunctions: + """Test config merger utility functions""" + + def setup_method(self): + """Set up test fixtures""" + self.temp_dir = tempfile.mkdtemp() + + def teardown_method(self): + """Clean up test fixtures""" + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + + def create_directory_structure(self): + """Create a test directory structure""" + # Create directory structure: env=dev/region=us-east-1/cluster=web + structure = { + 'default.yaml': {'env': 'default', 'region': 'default', 'cluster': 'default'}, + 'env=dev/env.yaml': {'env': 'dev'}, + 'env=dev/region=us-east-1/region.yaml': {'region': 'us-east-1'}, + 'env=dev/region=us-east-1/cluster=web/cluster.yaml': {'cluster': 'web', 'app': 'webapp'} + } + + for path, content in structure.items(): + full_path = os.path.join(self.temp_dir, path) + os.makedirs(os.path.dirname(full_path), exist_ok=True) + with open(full_path, 'w') as f: + yaml.dump(content, f) + + def test_get_leaf_directories(self): + """Test finding leaf directories""" + self.create_directory_structure() + + leaf_dirs = get_leaf_directories(self.temp_dir, ['cluster']) + + assert len(leaf_dirs) == 1 + assert leaf_dirs[0].endswith('env=dev/region=us-east-1/cluster=web') + + def test_get_leaf_directories_multiple_leaves(self): + """Test finding 
multiple leaf directories""" + self.create_directory_structure() + + # Create another cluster + cluster2_path = os.path.join(self.temp_dir, 'env=dev/region=us-east-1/cluster=api') + os.makedirs(cluster2_path, exist_ok=True) + with open(os.path.join(cluster2_path, 'cluster.yaml'), 'w') as f: + yaml.dump({'cluster': 'api', 'app': 'api-service'}, f) + + leaf_dirs = get_leaf_directories(self.temp_dir, ['cluster']) + + assert len(leaf_dirs) == 2 + cluster_names = [os.path.basename(d).split('=')[1] for d in leaf_dirs] + assert 'web' in cluster_names + assert 'api' in cluster_names + + def test_get_leaf_directories_no_matches(self): + """Test finding leaf directories with no matches""" + self.create_directory_structure() + + leaf_dirs = get_leaf_directories(self.temp_dir, ['nonexistent'], exit_on_empty=False) + + assert len(leaf_dirs) == 0 + + @patch('himl.config_merger.ConfigProcessor') + def test_merge_logic(self, mock_config_processor): + """Test merge logic function""" + self.create_directory_structure() + + # Mock ConfigProcessor + mock_processor = MagicMock() + mock_processor.process.return_value = { + 'env': 'dev', + 'region': 'us-east-1', + 'cluster': 'web', + 'app': 'webapp' + } + + # Create output directory + output_dir = os.path.join(self.temp_dir, 'output') + os.makedirs(output_dir, exist_ok=True) + + config_tuple = ( + mock_processor, + os.path.join(self.temp_dir, 'env=dev/region=us-east-1/cluster=web'), + ['env', 'region', 'cluster'], + output_dir, + None # No filter rules + ) + + merge_logic(config_tuple) + + # Verify output file was created + expected_output = os.path.join(output_dir, 'dev/us-east-1/web.yaml') + assert os.path.exists(expected_output) + + # Verify content + with open(expected_output, 'r') as f: + content = yaml.safe_load(f) + assert content['env'] == 'dev' + assert content['region'] == 'us-east-1' + assert content['cluster'] == 'web' + + @patch('himl.config_merger.ConfigProcessor') + @patch('himl.config_merger.FilterRules') + def test_merge_logic_with_filters(self, mock_filter_rules, mock_config_processor): + """Test merge logic with filter rules""" + self.create_directory_structure() + + # Mock ConfigProcessor + mock_processor = MagicMock() + mock_processor.process.return_value = { + 'env': 'dev', + 'region': 'us-east-1', + 'cluster': 'web', + 'app': 'webapp', + 'remove_me': 'should_be_filtered', + '_filters': [{'selector': {'env': 'dev'}, 'keys': {'values': ['app']}}] + } + + # Mock FilterRules + mock_filter_instance = MagicMock() + mock_filter_rules.return_value = mock_filter_instance + + output_dir = os.path.join(self.temp_dir, 'output') + os.makedirs(output_dir, exist_ok=True) + + config_tuple = ( + mock_processor, + os.path.join(self.temp_dir, 'env=dev/region=us-east-1/cluster=web'), + ['env', 'region', 'cluster'], + output_dir, + '_filters' + ) + + merge_logic(config_tuple) + + # Verify filter was applied + mock_filter_rules.assert_called_once() + mock_filter_instance.run.assert_called_once() + + @patch('himl.config_merger.Pool') + @patch('himl.config_merger.cpu_count') + def test_merge_configs_parallel(self, mock_cpu_count, mock_pool): + """Test merge configs with parallel processing""" + mock_cpu_count.return_value = 4 + mock_pool_instance = MagicMock() + mock_pool.return_value.__enter__.return_value = mock_pool_instance + + directories = ['dir1', 'dir2'] + levels = ['env', 'region'] + output_dir = '/output' + + merge_configs(directories, levels, output_dir, enable_parallel=True, filter_rules=None) + + mock_pool.assert_called_once_with(4) + 
mock_pool_instance.map.assert_called_once() + + @patch('himl.config_merger.merge_logic') + def test_merge_configs_sequential(self, mock_merge_logic): + """Test merge configs with sequential processing""" + directories = ['dir1', 'dir2'] + levels = ['env', 'region'] + output_dir = '/output' + + merge_configs(directories, levels, output_dir, enable_parallel=False, filter_rules=None) + + assert mock_merge_logic.call_count == 2 + + def test_get_parser(self): + """Test argument parser creation""" + parser = get_parser() + + # Test basic arguments + args = parser.parse_args(['input_dir', '--output-dir', 'output', '--levels', 'env', 'region', + '--leaf-directories', 'cluster']) + assert args.path == 'input_dir' + assert args.output_dir == 'output' + assert args.hierarchy_levels == ['env', 'region'] + + # Test optional arguments + args = parser.parse_args([ + 'input_dir', + '--output-dir', 'output', + '--levels', 'env', 'region', 'cluster', + '--leaf-directories', 'cluster', + '--filter-rules-key', '_filters', + '--enable-parallel' + ]) + + assert args.leaf_directories == ['cluster'] + assert args.filter_rules == '_filters' + assert args.enable_parallel is True + + @patch('himl.config_merger.merge_configs') + @patch('himl.config_merger.get_leaf_directories') + def test_run_function(self, mock_get_leaf_directories, mock_merge_configs): + """Test main run function""" + mock_get_leaf_directories.return_value = ['dir1', 'dir2'] + + args = [ + 'input_dir', + '--output-dir', 'output', + '--levels', 'env', 'region', + '--leaf-directories', 'cluster' + ] + + with patch('sys.argv', ['himl-config-merger'] + args): + run() + + mock_get_leaf_directories.assert_called_once_with('input_dir', ['cluster']) + mock_merge_configs.assert_called_once_with( + ['dir1', 'dir2'], + ['env', 'region'], + 'output', + False, # enable_parallel default + None # filter_rules_key default + ) + + def test_parser_default_values(self): + """Test parser default values""" + parser = get_parser() + + args = parser.parse_args(['input_dir', '--output-dir', 'output', '--levels', 'env', + '--leaf-directories', 'cluster']) + + assert args.leaf_directories == ['cluster'] + assert args.filter_rules is None + assert args.enable_parallel is False + + def test_parser_multiple_levels(self): + """Test parser with multiple levels""" + parser = get_parser() + + args = parser.parse_args([ + 'input_dir', + '--output-dir', 'output', + '--levels', 'env', 'region', 'cluster', 'app', + '--leaf-directories', 'cluster' + ]) + + assert args.hierarchy_levels == ['env', 'region', 'cluster', 'app'] + + def test_parser_multiple_leaf_directories(self): + """Test parser with multiple leaf directories""" + parser = get_parser() + + args = parser.parse_args([ + 'input_dir', + '--output-dir', 'output', + '--levels', 'env', 'region', + '--leaf-directories', 'cluster', 'service' + ]) + + assert args.leaf_directories == ['cluster', 'service'] + + @patch('himl.config_merger.ConfigProcessor') + def test_merge_logic_missing_filter_key(self, mock_config_processor): + """Test merge logic when filter key is missing""" + self.create_directory_structure() + + mock_processor = MagicMock() + mock_processor.process.return_value = { + 'env': 'dev', + 'region': 'us-east-1', + 'cluster': 'web' + # No _filters key + } + + output_dir = os.path.join(self.temp_dir, 'output') + os.makedirs(output_dir, exist_ok=True) + + config_tuple = ( + mock_processor, + os.path.join(self.temp_dir, 'env=dev/region=us-east-1/cluster=web'), + ['env', 'region', 'cluster'], + output_dir, + '_filters' # 
Filter key that doesn't exist + ) + + with pytest.raises(Exception) as exc_info: + merge_logic(config_tuple) + + assert "Filter rule key '_filters' not found in config" in str(exc_info.value) diff --git a/tests/test_edge_cases.py b/tests/test_edge_cases.py new file mode 100644 index 00000000..f430b7bf --- /dev/null +++ b/tests/test_edge_cases.py @@ -0,0 +1,436 @@ +# Copyright 2019 Adobe. All rights reserved. +# This file is licensed to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may obtain a copy +# of the License at http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS +# OF ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +import os +import tempfile +import shutil +import pytest +import yaml +from himl import ConfigProcessor +from himl.interpolation import InterpolationValidator +from himl.python_compat import iteritems, primitive_types, PY3 + + +class TestEdgeCases: + """Test edge cases and error conditions""" + + def setup_method(self): + """Set up test fixtures""" + self.temp_dir = tempfile.mkdtemp() + self.config_processor = ConfigProcessor() + + def teardown_method(self): + """Clean up test fixtures""" + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + + def create_test_yaml(self, path, content): + """Helper to create test YAML files""" + full_path = os.path.join(self.temp_dir, path) + os.makedirs(os.path.dirname(full_path), exist_ok=True) + with open(full_path, 'w') as f: + yaml.dump(content, f) + return full_path + + def test_empty_yaml_file(self): + """Test processing empty YAML file""" + self.create_test_yaml('empty.yaml', {}) + + result = self.config_processor.process( + cwd=self.temp_dir, + path='empty.yaml', + print_data=False + ) + + assert result == {} + + def test_yaml_file_with_null_values(self): + """Test processing YAML with null values""" + config_data = { + 'key1': None, + 'key2': 'value2', + 'nested': { + 'null_key': None, + 'valid_key': 'valid_value' + } + } + self.create_test_yaml('null_values.yaml', config_data) + + result = self.config_processor.process( + cwd=self.temp_dir, + path='null_values.yaml', + print_data=False + ) + + assert result['key1'] is None + assert result['key2'] == 'value2' + assert result['nested']['null_key'] is None + assert result['nested']['valid_key'] == 'valid_value' + + def test_yaml_file_with_special_characters(self): + """Test processing YAML with special characters""" + config_data = { + 'unicode_key': 'value with émojis 🚀', + 'special_chars': 'value with !@#$%^&*()', + 'quotes': 'value with "quotes" and \'apostrophes\'', + 'newlines': 'value with\nnewlines\nand\ttabs' + } + self.create_test_yaml('special_chars.yaml', config_data) + + result = self.config_processor.process( + cwd=self.temp_dir, + path='special_chars.yaml', + print_data=False + ) + + assert result == config_data + + def test_deeply_nested_structure(self): + """Test processing deeply nested YAML structure""" + config_data = { + 'level1': { + 'level2': { + 'level3': { + 'level4': { + 'level5': { + 'deep_value': 'found_it' + } + } + } + } + } + } + self.create_test_yaml('deep_nested.yaml', config_data) + + result = self.config_processor.process( + cwd=self.temp_dir, + path='deep_nested.yaml', + print_data=False 
+ ) + + assert result['level1']['level2']['level3']['level4']['level5']['deep_value'] == 'found_it' + + def test_large_list_processing(self): + """Test processing large lists""" + large_list = [f'item_{i}' for i in range(1000)] + config_data = { + 'large_list': large_list, + 'other_key': 'other_value' + } + self.create_test_yaml('large_list.yaml', config_data) + + result = self.config_processor.process( + cwd=self.temp_dir, + path='large_list.yaml', + print_data=False + ) + + assert len(result['large_list']) == 1000 + assert result['large_list'][0] == 'item_0' + assert result['large_list'][999] == 'item_999' + + def test_circular_interpolation_detection(self): + """Test detection of circular interpolations""" + config_data = { + 'key1': '{{key2}}', + 'key2': '{{key1}}' # Circular reference + } + self.create_test_yaml('circular.yaml', config_data) + + # This should not cause infinite recursion + result = self.config_processor.process( + cwd=self.temp_dir, + path='circular.yaml', + print_data=False, + skip_interpolation_validation=True # Skip validation to avoid exception + ) + + # The interpolations should remain unresolved + assert '{{' in str(result['key1']) or '{{' in str(result['key2']) + + def test_malformed_interpolation_syntax(self): + """Test handling of malformed interpolation syntax""" + config_data = { + 'malformed1': '{{incomplete', + 'malformed2': 'incomplete}}', + 'malformed3': '{{}}', + 'malformed4': '{{.invalid.syntax}}', + 'valid': 'normal_value' + } + self.create_test_yaml('malformed.yaml', config_data) + + result = self.config_processor.process( + cwd=self.temp_dir, + path='malformed.yaml', + print_data=False, + skip_interpolation_validation=True + ) + + # Malformed interpolations should remain unchanged + assert result['malformed1'] == '{{incomplete' + assert result['malformed2'] == 'incomplete}}' + assert result['malformed3'] == '{{}}' + assert result['valid'] == 'normal_value' + + def test_nonexistent_path(self): + """Test processing nonexistent path""" + with pytest.raises(Exception): + self.config_processor.process( + cwd=self.temp_dir, + path='nonexistent/path', + print_data=False + ) + + def test_invalid_yaml_syntax(self): + """Test handling of invalid YAML syntax""" + invalid_yaml_path = os.path.join(self.temp_dir, 'invalid.yaml') + with open(invalid_yaml_path, 'w') as f: + f.write('invalid: yaml: content: [unclosed') + + with pytest.raises(yaml.YAMLError): + self.config_processor.process( + cwd=self.temp_dir, + path='invalid.yaml', + print_data=False + ) + + def test_mixed_data_types(self): + """Test processing mixed data types""" + config_data = { + 'string': 'text', + 'integer': 42, + 'float': 3.14, + 'boolean': True, + 'list': [1, 'two', 3.0, False], + 'dict': {'nested': 'value'}, + 'null': None + } + self.create_test_yaml('mixed_types.yaml', config_data) + + result = self.config_processor.process( + cwd=self.temp_dir, + path='mixed_types.yaml', + print_data=False + ) + + assert result['string'] == 'text' + assert result['integer'] == 42 + assert result['float'] == 3.14 + assert result['boolean'] is True + assert result['list'] == [1, 'two', 3.0, False] + assert result['dict']['nested'] == 'value' + assert result['null'] is None + + def test_unicode_handling(self): + """Test Unicode character handling""" + config_data = { + 'chinese': '你好世界', + 'arabic': 'مرحبا بالعالم', + 'emoji': '🌍🚀⭐', + 'mixed': 'Hello 世界 🌍' + } + self.create_test_yaml('unicode.yaml', config_data) + + result = self.config_processor.process( + cwd=self.temp_dir, + path='unicode.yaml', + 
print_data=False + ) + + assert result == config_data + + def test_very_long_strings(self): + """Test handling of very long strings""" + long_string = 'x' * 10000 + config_data = { + 'long_string': long_string, + 'normal_key': 'normal_value' + } + self.create_test_yaml('long_strings.yaml', config_data) + + result = self.config_processor.process( + cwd=self.temp_dir, + path='long_strings.yaml', + print_data=False + ) + + assert len(result['long_string']) == 10000 + assert result['normal_key'] == 'normal_value' + + def test_empty_directory_hierarchy(self): + """Test processing empty directory hierarchy""" + empty_dir = os.path.join(self.temp_dir, 'empty_dir') + os.makedirs(empty_dir, exist_ok=True) + + with pytest.raises(Exception): + self.config_processor.process( + cwd=self.temp_dir, + path='empty_dir', + print_data=False + ) + + def test_interpolation_with_missing_keys(self): + """Test interpolation with missing keys""" + config_data = { + 'existing_key': 'existing_value', + 'interpolation': '{{missing.key}}', + 'partial_interpolation': 'prefix-{{missing.key}}-suffix' + } + self.create_test_yaml('missing_keys.yaml', config_data) + + result = self.config_processor.process( + cwd=self.temp_dir, + path='missing_keys.yaml', + print_data=False, + skip_interpolation_validation=True + ) + + # Missing interpolations should remain unresolved + assert result['existing_key'] == 'existing_value' + assert '{{missing.key}}' in result['interpolation'] + + def test_filter_with_nonexistent_keys(self): + """Test filtering with nonexistent keys""" + config_data = { + 'key1': 'value1', + 'key2': 'value2' + } + self.create_test_yaml('filter_test.yaml', config_data) + + result = self.config_processor.process( + cwd=self.temp_dir, + path='filter_test.yaml', + filters=['key1', 'nonexistent_key'], + print_data=False + ) + + # Should only include existing filtered keys + assert 'key1' in result + assert 'key2' not in result + assert 'nonexistent_key' not in result + + def test_exclude_all_keys(self): + """Test excluding all keys""" + config_data = { + 'key1': 'value1', + 'key2': 'value2', + 'key3': 'value3' + } + self.create_test_yaml('exclude_all.yaml', config_data) + + result = self.config_processor.process( + cwd=self.temp_dir, + path='exclude_all.yaml', + exclude_keys=['key1', 'key2', 'key3'], + print_data=False + ) + + assert result == {} + + def test_complex_merge_strategies(self): + """Test complex merge strategies""" + # Create base config + base_config = { + 'list_append': ['item1'], + 'list_override': ['base1', 'base2'], + 'dict_merge': {'key1': 'base_value1', 'key2': 'base_value2'} + } + self.create_test_yaml('default.yaml', base_config) + + # Create override config + override_config = { + 'list_append': ['item2'], + 'list_override': ['override1'], + 'dict_merge': {'key2': 'override_value2', 'key3': 'new_value3'} + } + os.makedirs(os.path.join(self.temp_dir, 'env'), exist_ok=True) + self.create_test_yaml('env/config.yaml', override_config) + + result = self.config_processor.process( + cwd=self.temp_dir, + path='env', + print_data=False + ) + + # Verify merge behavior + assert 'item1' in result['list_append'] + assert 'item2' in result['list_append'] + assert result['dict_merge']['key1'] == 'base_value1' + assert result['dict_merge']['key2'] == 'override_value2' + assert result['dict_merge']['key3'] == 'new_value3' + + +class TestPythonCompatibility: + """Test Python compatibility utilities""" + + def test_iteritems_function(self): + """Test iteritems compatibility function""" + test_dict = {'key1': 
'value1', 'key2': 'value2'} + + items = list(iteritems(test_dict)) + + assert len(items) == 2 + assert ('key1', 'value1') in items + assert ('key2', 'value2') in items + + def test_primitive_types(self): + """Test primitive types detection""" + assert isinstance('string', primitive_types) + assert isinstance(42, primitive_types) + assert isinstance(3.14, primitive_types) + assert isinstance(True, primitive_types) + assert not isinstance([], primitive_types) + assert not isinstance({}, primitive_types) + + def test_py3_flag(self): + """Test Python 3 detection flag""" + import sys + expected_py3 = sys.version_info[0] >= 3 + assert PY3 == expected_py3 + + +class TestInterpolationValidatorEdgeCases: + """Test InterpolationValidator edge cases""" + + def test_validator_with_nested_unresolved(self): + """Test validator with nested unresolved interpolations""" + validator = InterpolationValidator() + + data = { + 'level1': { + 'level2': { + 'unresolved': '{{missing.key}}' + } + } + } + + with pytest.raises(Exception) as exc_info: + validator.check_all_interpolations_resolved(data) + + assert 'Interpolation could not be resolved' in str(exc_info.value) + assert '{{missing.key}}' in str(exc_info.value) + + def test_validator_with_list_unresolved(self): + """Test validator with unresolved interpolations in lists""" + validator = InterpolationValidator() + + data = { + 'list_with_interpolation': [ + 'resolved_value', + '{{unresolved.key}}', + 'another_resolved_value' + ] + } + + with pytest.raises(Exception) as exc_info: + validator.check_all_interpolations_resolved(data) + + assert 'Interpolation could not be resolved' in str(exc_info.value) + assert '{{unresolved.key}}' in str(exc_info.value) diff --git a/tests/test_filter_rules.py b/tests/test_filter_rules.py new file mode 100644 index 00000000..efb98edb --- /dev/null +++ b/tests/test_filter_rules.py @@ -0,0 +1,361 @@ +# Copyright 2019 Adobe. All rights reserved. +# This file is licensed to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may obtain a copy +# of the License at http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS +# OF ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. 
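+
+# Shape of a filter rule as exercised below (the schema here is inferred from
+# this suite, not from separate documentation):
+#
+#     {'selector': {'env': 'dev'},            # selector values match as regexes
+#      'keys': {'values': ['keep_this'],      # keys kept verbatim
+#               'regex': 'keep_pattern_.*'}}  # plus any key matching the regex
+#
+# Keys named in the hierarchy levels (e.g. env/region/cluster) are always kept.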
+ +import pytest +from himl.filter_rules import FilterRules + + +class TestFilterRules: + """Test FilterRules class""" + + def test_initialization(self): + """Test FilterRules initialization""" + rules = [ + { + 'selector': {'env': 'dev'}, + 'keys': {'values': ['key1', 'key2']} + } + ] + levels = ['env', 'region', 'cluster'] + + filter_rules = FilterRules(rules, levels) + + assert filter_rules.rules == rules + assert filter_rules.levels == levels + + def test_simple_value_filter(self): + """Test filtering with simple value selector""" + rules = [ + { + 'selector': {'env': 'dev'}, + 'keys': {'values': ['keep_me', 'also_keep']} + } + ] + levels = ['env', 'region'] + filter_rules = FilterRules(rules, levels) + + output = { + 'env': 'dev', + 'region': 'us-east-1', + 'keep_me': 'value1', + 'also_keep': 'value2', + 'remove_me': 'value3', + 'remove_this_too': 'value4' + } + + filter_rules.run(output) + + # Should keep level keys and specified keys + assert 'env' in output + assert 'region' in output + assert 'keep_me' in output + assert 'also_keep' in output + assert 'remove_me' not in output + assert 'remove_this_too' not in output + + def test_regex_filter(self): + """Test filtering with regex selector""" + rules = [ + { + 'selector': {'env': 'dev'}, + 'keys': {'regex': 'keep_.*'} + } + ] + levels = ['env', 'region'] + filter_rules = FilterRules(rules, levels) + + output = { + 'env': 'dev', + 'region': 'us-east-1', + 'keep_me': 'value1', + 'keep_this': 'value2', + 'remove_me': 'value3', + 'also_remove': 'value4' + } + + filter_rules.run(output) + + assert 'env' in output + assert 'region' in output + assert 'keep_me' in output + assert 'keep_this' in output + assert 'remove_me' not in output + assert 'also_remove' not in output + + def test_combined_values_and_regex_filter(self): + """Test filtering with both values and regex""" + rules = [ + { + 'selector': {'env': 'dev'}, + 'keys': { + 'values': ['explicit_keep'], + 'regex': 'pattern_.*' + } + } + ] + levels = ['env'] + filter_rules = FilterRules(rules, levels) + + output = { + 'env': 'dev', + 'explicit_keep': 'value1', + 'pattern_match': 'value2', + 'pattern_another': 'value3', + 'remove_me': 'value4' + } + + filter_rules.run(output) + + assert 'env' in output + assert 'explicit_keep' in output + assert 'pattern_match' in output + assert 'pattern_another' in output + assert 'remove_me' not in output + + def test_regex_selector_match(self): + """Test regex matching in selector""" + rules = [ + { + 'selector': {'cluster': 'cluster.*'}, + 'keys': {'values': ['keep_me']} + } + ] + levels = ['env', 'cluster'] + filter_rules = FilterRules(rules, levels) + + output = { + 'env': 'dev', + 'cluster': 'cluster1', + 'keep_me': 'value1', + 'remove_me': 'value2' + } + + filter_rules.run(output) + + assert 'env' in output + assert 'cluster' in output + assert 'keep_me' in output + assert 'remove_me' not in output + + def test_multiple_selector_conditions(self): + """Test multiple conditions in selector""" + rules = [ + { + 'selector': { + 'env': 'dev', + 'region': 'us-.*' + }, + 'keys': {'values': ['keep_me']} + } + ] + levels = ['env', 'region'] + filter_rules = FilterRules(rules, levels) + + # Should match - both conditions satisfied + output = { + 'env': 'dev', + 'region': 'us-east-1', + 'keep_me': 'value1', + 'remove_me': 'value2' + } + + filter_rules.run(output) + + assert 'keep_me' in output + assert 'remove_me' not in output + + def test_selector_no_match(self): + """Test when selector doesn't match""" + rules = [ + { + 'selector': {'env': 'prod'}, 
+ 'keys': {'values': ['keep_me']} + } + ] + levels = ['env'] + filter_rules = FilterRules(rules, levels) + + output = { + 'env': 'dev', # Doesn't match selector + 'keep_me': 'value1', + 'remove_me': 'value2' + } + + filter_rules.run(output) + + # Since selector doesn't match, all non-level keys should be removed + assert 'env' in output + assert 'keep_me' not in output + assert 'remove_me' not in output + + def test_multiple_rules(self): + """Test multiple filter rules""" + rules = [ + { + 'selector': {'env': 'dev'}, + 'keys': {'values': ['dev_specific']} + }, + { + 'selector': {'cluster': 'cluster1'}, + 'keys': {'values': ['cluster_specific']} + } + ] + levels = ['env', 'cluster'] + filter_rules = FilterRules(rules, levels) + + output = { + 'env': 'dev', + 'cluster': 'cluster1', + 'dev_specific': 'value1', + 'cluster_specific': 'value2', + 'remove_me': 'value3' + } + + filter_rules.run(output) + + assert 'env' in output + assert 'cluster' in output + assert 'dev_specific' in output + assert 'cluster_specific' in output + assert 'remove_me' not in output + + def test_missing_selector_key(self): + """Test selector with missing key in output""" + rules = [ + { + 'selector': {'missing_key': 'value'}, + 'keys': {'values': ['keep_me']} + } + ] + levels = ['env'] + filter_rules = FilterRules(rules, levels) + + output = { + 'env': 'dev', + 'keep_me': 'value1', + 'remove_me': 'value2' + } + + filter_rules.run(output) + + # Selector should not match due to missing key + assert 'env' in output + assert 'keep_me' not in output + assert 'remove_me' not in output + + def test_invalid_selector_type(self): + """Test invalid selector type""" + rules = [ + { + 'selector': 'invalid_selector', # Should be dict + 'keys': {'values': ['keep_me']} + } + ] + levels = ['env'] + filter_rules = FilterRules(rules, levels) + + output = {'env': 'dev', 'keep_me': 'value1'} + + with pytest.raises(Exception) as exc_info: + filter_rules.run(output) + + assert "Filter selector must be a dictionary" in str(exc_info.value) + + def test_empty_rules(self): + """Test with empty rules list""" + rules = [] + levels = ['env'] + filter_rules = FilterRules(rules, levels) + + output = { + 'env': 'dev', + 'remove_me': 'value1', + 'also_remove': 'value2' + } + + filter_rules.run(output) + + # With no rules, all non-level keys should be removed + assert 'env' in output + assert 'remove_me' not in output + assert 'also_remove' not in output + + def test_match_method_direct(self): + """Test match method directly""" + filter_rules = FilterRules([], []) + + output = {'env': 'dev', 'region': 'us-east-1'} + + # Exact match + assert filter_rules.match(output, {'env': 'dev'}) is True + + # Regex match + assert filter_rules.match(output, {'region': 'us-.*'}) is True + + # No match + assert filter_rules.match(output, {'env': 'prod'}) is False + + # Missing key + assert filter_rules.match(output, {'missing': 'value'}) is False + + def test_complex_regex_patterns(self): + """Test complex regex patterns""" + rules = [ + { + 'selector': {'env': 'dev'}, + 'keys': {'regex': '^(keep|save)_.*$'} + } + ] + levels = ['env'] + filter_rules = FilterRules(rules, levels) + + output = { + 'env': 'dev', + 'keep_this': 'value1', + 'save_that': 'value2', + 'keep_me_too': 'value3', + 'remove_this': 'value4', + 'also_remove': 'value5' + } + + filter_rules.run(output) + + assert 'env' in output + assert 'keep_this' in output + assert 'save_that' in output + assert 'keep_me_too' in output + assert 'remove_this' not in output + assert 'also_remove' not in output + + def 
test_preserve_level_keys(self): + """Test that level keys are always preserved""" + rules = [ + { + 'selector': {'env': 'dev'}, + 'keys': {'values': []} # Empty values list + } + ] + levels = ['env', 'region', 'cluster'] + filter_rules = FilterRules(rules, levels) + + output = { + 'env': 'dev', + 'region': 'us-east-1', + 'cluster': 'web', + 'remove_me': 'value1' + } + + filter_rules.run(output) + + # Level keys should always be preserved + assert 'env' in output + assert 'region' in output + assert 'cluster' in output + assert 'remove_me' not in output diff --git a/tests/test_inject_env.py b/tests/test_inject_env.py new file mode 100644 index 00000000..f72a1b67 --- /dev/null +++ b/tests/test_inject_env.py @@ -0,0 +1,245 @@ +# Copyright 2019 Adobe. All rights reserved. +# This file is licensed to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may obtain a copy +# of the License at http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS +# OF ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +import os + +from unittest.mock import patch + +from himl.inject_env import EnvVarInjector +from himl.interpolation import EnvVarResolver, EnvVarInterpolationsResolver + + +class TestEnvVarInjector: + """Test EnvVarInjector class""" + + def setup_method(self): + """Set up test fixtures""" + self.injector = EnvVarInjector() + + def test_is_interpolation(self): + """Test interpolation detection""" + assert self.injector.is_interpolation('{{env(HOME)}}') + assert self.injector.is_interpolation('{{env(PATH)}}') + assert not self.injector.is_interpolation('not an interpolation') + assert not self.injector.is_interpolation('{{incomplete') + assert not self.injector.is_interpolation('incomplete}}') + + def test_is_env_interpolation(self): + """Test environment variable interpolation detection""" + assert self.injector.is_env_interpolation('env(HOME)') + assert self.injector.is_env_interpolation('env(PATH)') + assert self.injector.is_env_interpolation('env(MY_VAR)') + assert not self.injector.is_env_interpolation('ssm(path)') + assert not self.injector.is_env_interpolation('vault(path)') + assert not self.injector.is_env_interpolation('env') + assert not self.injector.is_env_interpolation('env()') + + def test_inject_env_var_non_interpolation(self): + """Test that non-interpolations are returned unchanged""" + result = self.injector.inject_env_var('normal string') + assert result == 'normal string' + + result = self.injector.inject_env_var('{{incomplete') + assert result == '{{incomplete' + + def test_inject_env_var_non_env_interpolation(self): + """Test that non-env interpolations are returned unchanged""" + result = self.injector.inject_env_var('{{ssm.path(/secret)}}') + assert result == '{{ssm.path(/secret)}}' + + result = self.injector.inject_env_var('{{vault.path(/secret)}}') + assert result == '{{vault.path(/secret)}}' + + @patch.dict(os.environ, {'TEST_VAR': 'test_value'}) + def test_inject_env_var_existing_variable(self): + """Test injection of existing environment variable""" + result = self.injector.inject_env_var('{{env(TEST_VAR)}}') + assert result == 'test_value' + + @patch.dict(os.environ, {}, clear=True) + def test_inject_env_var_missing_variable(self): + 
"""Test injection of missing environment variable""" + result = self.injector.inject_env_var('{{env(MISSING_VAR)}}') + assert result is None + + @patch.dict(os.environ, {'HOME': '/home/user', 'USER': 'testuser'}) + def test_inject_env_var_common_variables(self): + """Test injection of common environment variables""" + result = self.injector.inject_env_var('{{env(HOME)}}') + assert result == '/home/user' + + result = self.injector.inject_env_var('{{env(USER)}}') + assert result == 'testuser' + + @patch.dict(os.environ, {'EMPTY_VAR': ''}) + def test_inject_env_var_empty_variable(self): + """Test injection of empty environment variable""" + result = self.injector.inject_env_var('{{env(EMPTY_VAR)}}') + assert result == '' + + @patch.dict(os.environ, {'NUMERIC_VAR': '12345'}) + def test_inject_env_var_numeric_value(self): + """Test injection of numeric environment variable""" + result = self.injector.inject_env_var('{{env(NUMERIC_VAR)}}') + assert result == '12345' + assert isinstance(result, str) + + @patch.dict(os.environ, {'SPECIAL_CHARS': 'value with spaces and !@#$%'}) + def test_inject_env_var_special_characters(self): + """Test injection of environment variable with special characters""" + result = self.injector.inject_env_var('{{env(SPECIAL_CHARS)}}') + assert result == 'value with spaces and !@#$%' + + def test_inject_env_var_malformed_interpolation(self): + """Test handling of malformed interpolations""" + result = self.injector.inject_env_var('{{env(}}') + assert result == '{{env(}}' + + result = self.injector.inject_env_var('{{env)}}') + assert result == '{{env)}}' + + result = self.injector.inject_env_var('{{env()}}') + assert result == '{{env()}}' + + @patch.dict(os.environ, {'VAR_WITH_UNDERSCORES': 'underscore_value'}) + def test_inject_env_var_with_underscores(self): + """Test injection of environment variable with underscores""" + result = self.injector.inject_env_var('{{env(VAR_WITH_UNDERSCORES)}}') + assert result == 'underscore_value' + + @patch.dict(os.environ, {'VAR123': 'alphanumeric_value'}) + def test_inject_env_var_alphanumeric(self): + """Test injection of alphanumeric environment variable""" + result = self.injector.inject_env_var('{{env(VAR123)}}') + assert result == 'alphanumeric_value' + + +class TestEnvVarResolver: + """Test EnvVarResolver class""" + + def setup_method(self): + """Set up test fixtures""" + self.resolver = EnvVarResolver() + + @patch.dict(os.environ, {'TEST_ENV': 'test_value', 'ANOTHER_ENV': 'another_value'}) + def test_resolve_env_vars_simple(self): + """Test simple environment variable resolution""" + data = { + 'env_var': '{{env(TEST_ENV)}}', + 'another_var': '{{env(ANOTHER_ENV)}}', + 'normal_var': 'normal_value' + } + + result = self.resolver.resolve_env_vars(data) + + assert result['env_var'] == 'test_value' + assert result['another_var'] == 'another_value' + assert result['normal_var'] == 'normal_value' + + @patch.dict(os.environ, {'HOME': '/home/user'}) + def test_resolve_env_vars_nested(self): + """Test environment variable resolution in nested structures""" + data = { + 'config': { + 'home_dir': '{{env(HOME)}}', + 'nested': { + 'path': '{{env(HOME)}}/config' + } + }, + 'list_with_env': [ + '{{env(HOME)}}/file1', + '{{env(HOME)}}/file2' + ] + } + + result = self.resolver.resolve_env_vars(data) + + assert result['config']['home_dir'] == '/home/user' + assert result['config']['nested']['path'] == '/home/user/config' + assert result['list_with_env'][0] == '/home/user/file1' + assert result['list_with_env'][1] == '/home/user/file2' + + 
@patch.dict(os.environ, {}, clear=True) + def test_resolve_env_vars_missing(self): + """Test resolution with missing environment variables""" + data = { + 'missing_var': '{{env(MISSING_VAR)}}', + 'normal_var': 'normal_value' + } + + result = self.resolver.resolve_env_vars(data) + + assert result['missing_var'] is None + assert result['normal_var'] == 'normal_value' + + @patch.dict(os.environ, {'MIXED_VAR': 'mixed_value'}) + def test_resolve_env_vars_mixed_content(self): + """Test resolution with mixed content""" + data = { + 'mixed': 'prefix-{{env(MIXED_VAR)}}-suffix', + 'pure_env': '{{env(MIXED_VAR)}}', + 'no_env': 'no environment variables here' + } + + result = self.resolver.resolve_env_vars(data) + + # Only require the value to be injected somewhere in the string; how the + # surrounding prefix/suffix are handled is implementation-defined. + assert 'mixed_value' in str(result['mixed']) + assert result['pure_env'] == 'mixed_value' + assert result['no_env'] == 'no environment variables here' + + +class TestEnvVarInterpolationsResolver: + """Test EnvVarInterpolationsResolver class""" + + def setup_method(self): + """Set up test fixtures""" + self.injector = EnvVarInjector() + self.resolver = EnvVarInterpolationsResolver(self.injector) + + @patch.dict(os.environ, {'TEST_VAR': 'test_value'}) + def test_resolve_interpolations(self): + """Test interpolation resolution""" + data = { + 'env_var': '{{env(TEST_VAR)}}', + 'normal_var': 'normal_value', + 'nested': { + 'env_nested': '{{env(TEST_VAR)}}' + } + } + + self.resolver.resolve_interpolations(data) + + assert data['env_var'] == 'test_value' + assert data['normal_var'] == 'normal_value' + assert data['nested']['env_nested'] == 'test_value' + + @patch.dict(os.environ, {'PATH_VAR': '/usr/bin'}) + def test_do_resolve_interpolation(self): + """Test individual interpolation resolution""" + result = self.resolver.do_resolve_interpolation('{{env(PATH_VAR)}}') + assert result == '/usr/bin' + + result = self.resolver.do_resolve_interpolation('not an interpolation') + assert result == 'not an interpolation' + + @patch.dict(os.environ, {}, clear=True) + def test_resolve_missing_env_var(self): + """Test resolution of missing environment variable""" + result = self.resolver.do_resolve_interpolation('{{env(MISSING)}}') + assert result is None + + @patch.dict(os.environ, {'COMPLEX_VAR': 'complex/path/value'}) + def test_resolve_complex_env_var(self): + """Test resolution of complex environment variable""" + result = self.resolver.do_resolve_interpolation('{{env(COMPLEX_VAR)}}') + assert result == 'complex/path/value' diff --git a/tests/test_inject_secrets.py b/tests/test_inject_secrets.py new file mode 100644 index 00000000..e428db8b --- /dev/null +++ b/tests/test_inject_secrets.py @@ -0,0 +1,225 @@ +# Copyright 2019 Adobe. All rights reserved. +# This file is licensed to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may obtain a copy +# of the License at http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS +# OF ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License.
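+
+# The suite below exercises himl's secret interpolation grammar, for example:
+#   {{ssm.path(/my/secret).aws_profile(prod)}}
+#   {{vault.path(/secret/path)}}
+#   {{s3.bucket(my-bucket).path(file.txt).base64encode(true)}}
+# Each dot-separated segment outside parentheses becomes one key in the
+# params dict handed to the resolver, as the parsing assertions below verify.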
+ +import pytest +from unittest.mock import patch + +from himl.inject_secrets import SecretInjector + + +class TestSecretInjector: + """Test SecretInjector class""" + + def setup_method(self): + """Set up test fixtures""" + self.injector = SecretInjector(default_aws_profile='test-profile') + + def test_is_interpolation(self): + """Test interpolation detection""" + assert self.injector.is_interpolation('{{ssm.path(/my/secret)}}') + assert self.injector.is_interpolation('{{vault.path(/secret)}}') + assert not self.injector.is_interpolation('not an interpolation') + assert not self.injector.is_interpolation('{{incomplete') + assert not self.injector.is_interpolation('incomplete}}') + + def test_split_dot_not_within_parentheses(self): + """Test splitting on dots outside parentheses""" + # Test simple case + result = self.injector.split_dot_not_within_parentheses('ssm.path(/my/secret)') + assert result == ['ssm', 'path(/my/secret)'] + + # Test complex case with multiple parameters + result = self.injector.split_dot_not_within_parentheses('ssm.path(/my/secret).aws_profile(test)') + assert result == ['ssm', 'path(/my/secret)', 'aws_profile(test)'] + + # Test with dots inside parentheses (should not split) + result = self.injector.split_dot_not_within_parentheses('ssm.path(/my.secret.path)') + assert result == ['ssm', 'path(/my.secret.path)'] + + def test_inject_secret_non_interpolation(self): + """Test that non-interpolations are returned unchanged""" + result = self.injector.inject_secret('normal string') + assert result == 'normal string' + + result = self.injector.inject_secret('{{incomplete') + assert result == '{{incomplete' + + @patch.object(SecretInjector, 'split_dot_not_within_parentheses') + def test_inject_secret_insufficient_parts(self, mock_split): + """Test handling of insufficient parts after splitting""" + mock_split.return_value = ['ssm'] # Only one part + + result = self.injector.inject_secret('{{ssm}}') + assert result == '{{ssm}}' + + def test_inject_secret_ssm_format(self): + """Test SSM secret format parsing""" + with patch.object(self.injector.resolver, 'supports', return_value=True), \ + patch.object(self.injector.resolver, 'resolve', return_value='secret_value'): + + result = self.injector.inject_secret('{{ssm.path(/my/secret).aws_profile(test)}}') + + assert result == 'secret_value' + self.injector.resolver.resolve.assert_called_once_with( + 'ssm', + { + 'ssm': None, + 'path': '/my/secret', + 'aws_profile': 'test' + } + ) + + def test_inject_secret_vault_format(self): + """Test Vault secret format parsing""" + with patch.object(self.injector.resolver, 'supports', return_value=True), \ + patch.object(self.injector.resolver, 'resolve', return_value={'key': 'value'}): + + result = self.injector.inject_secret('{{vault.path(/secret/path)}}') + + assert result == {'key': 'value'} + self.injector.resolver.resolve.assert_called_once_with( + 'vault', + { + 'vault': None, + 'path': '/secret/path' + } + ) + + def test_inject_secret_s3_format(self): + """Test S3 secret format parsing""" + with patch.object(self.injector.resolver, 'supports', return_value=True), \ + patch.object(self.injector.resolver, 'resolve', return_value='file_content'): + + result = self.injector.inject_secret('{{s3.bucket(my-bucket).path(file.txt).base64encode(true)}}') + + assert result == 'file_content' + self.injector.resolver.resolve.assert_called_once_with( + 's3', + { + 's3': None, + 'bucket': 'my-bucket', + 'path': 'file.txt', + 'base64encode': 'true' + } + ) + + def test_inject_secret_sops_format(self): 
+ """Test SOPS secret format parsing""" + with patch.object(self.injector.resolver, 'supports', return_value=True), \ + patch.object(self.injector.resolver, 'resolve', return_value='decrypted_value'): + + result = self.injector.inject_secret('{{sops.secret_file(/path/to/secrets.yaml).secret_key(my_key)}}') + + assert result == 'decrypted_value' + self.injector.resolver.resolve.assert_called_once_with( + 'sops', + { + 'sops': None, + 'secret_file': '/path/to/secrets.yaml', + 'secret_key': 'my_key' + } + ) + + def test_inject_secret_unsupported_type(self): + """Test handling of unsupported secret types""" + with patch.object(self.injector.resolver, 'supports', return_value=False): + + result = self.injector.inject_secret('{{unsupported.path(/secret)}}') + + assert result == '{{unsupported.path(/secret)}}' + + def test_inject_secret_parameter_without_parentheses(self): + """Test parsing parameters without parentheses""" + with patch.object(self.injector.resolver, 'supports', return_value=True), \ + patch.object(self.injector.resolver, 'resolve', return_value='secret_value'): + + result = self.injector.inject_secret('{{ssm.path(/my/secret).decrypt}}') + + assert result == 'secret_value' + self.injector.resolver.resolve.assert_called_once_with( + 'ssm', + { + 'ssm': None, + 'path': '/my/secret', + 'decrypt': None + } + ) + + def test_inject_secret_caching(self): + """Test that secret injection uses caching""" + with patch.object(self.injector.resolver, 'supports', return_value=True), \ + patch.object(self.injector.resolver, 'resolve', return_value='secret_value') as mock_resolve: + + # Call the same secret twice + secret_interpolation = '{{ssm.path(/my/secret)}}' + result1 = self.injector.inject_secret(secret_interpolation) + result2 = self.injector.inject_secret(secret_interpolation) + + assert result1 == 'secret_value' + assert result2 == 'secret_value' + + # Due to LRU cache, resolve should only be called once + assert mock_resolve.call_count == 1 + + def test_inject_secret_complex_path(self): + """Test injection with complex paths containing special characters""" + with patch.object(self.injector.resolver, 'supports', return_value=True), \ + patch.object(self.injector.resolver, 'resolve', return_value='secret_value'): + + result = self.injector.inject_secret('{{ssm.path(/app/env-prod/db.password)}}') + + assert result == 'secret_value' + self.injector.resolver.resolve.assert_called_once_with( + 'ssm', + { + 'ssm': None, + 'path': '/app/env-prod/db.password' + } + ) + + def test_inject_secret_multiple_parameters(self): + """Test injection with multiple parameters""" + with patch.object(self.injector.resolver, 'supports', return_value=True), \ + patch.object(self.injector.resolver, 'resolve', return_value='secret_value'): + + result = self.injector.inject_secret( + '{{ssm.path(/my/secret).aws_profile(prod).region_name(us-west-2)}}' + ) + + assert result == 'secret_value' + self.injector.resolver.resolve.assert_called_once_with( + 'ssm', + { + 'ssm': None, + 'path': '/my/secret', + 'aws_profile': 'prod', + 'region_name': 'us-west-2' + } + ) + + def test_inject_secret_resolver_exception(self): + """Test handling of resolver exceptions""" + with patch.object(self.injector.resolver, 'supports', return_value=True), \ + patch.object(self.injector.resolver, 'resolve', side_effect=Exception('Resolver error')): + + with pytest.raises(Exception) as exc_info: + self.injector.inject_secret('{{ssm.path(/my/secret)}}') + + assert 'Resolver error' in str(exc_info.value) + + def 
test_inject_secret_empty_interpolation(self): + """Test handling of empty interpolation""" + result = self.injector.inject_secret('{{}}') + assert result == '{{}}' # Should remain unchanged + + def test_inject_secret_malformed_interpolation(self): + """Test handling of malformed interpolations""" + result = self.injector.inject_secret('{{ssm.path(}}') + assert result == '{{ssm.path(}}' # Should remain unchanged due to malformed parentheses diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py new file mode 100644 index 00000000..17493f6e --- /dev/null +++ b/tests/test_interpolation.py @@ -0,0 +1,295 @@ +# Copyright 2025 Adobe. All rights reserved. +# This file is licensed to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may obtain a copy +# of the License at http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS +# OF ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +import pytest +from himl.interpolation import ( + InterpolationResolver, EscapingResolver, InterpolationValidator, + FromDictInjector, FullBlobInjector, + is_interpolation, is_escaped_interpolation, is_full_interpolation, + is_fully_escaped_interpolation, remove_white_spaces, + replace_parent_working_directory +) + + +class TestInterpolationUtilities: + """Test utility functions for interpolation""" + + def test_is_interpolation(self): + """Test interpolation detection""" + assert is_interpolation("{{env.name}}") + assert is_interpolation("prefix {{env.name}} suffix") + assert not is_interpolation("no interpolation") + assert not is_interpolation("{{`escaped`}}") + assert not is_interpolation(123) + assert not is_interpolation(None) + + def test_is_escaped_interpolation(self): + """Test escaped interpolation detection""" + assert is_escaped_interpolation("{{`escaped`}}") + assert is_escaped_interpolation("prefix {{`escaped`}} suffix") + assert not is_escaped_interpolation("{{normal}}") + assert not is_escaped_interpolation("no interpolation") + + def test_is_full_interpolation(self): + """Test full interpolation detection""" + assert is_full_interpolation("{{env.name}}") + assert not is_full_interpolation("prefix {{env.name}} suffix") + assert not is_full_interpolation("{{`escaped`}}") + assert not is_full_interpolation("no interpolation") + + def test_is_fully_escaped_interpolation(self): + """Test fully escaped interpolation detection""" + assert is_fully_escaped_interpolation("{{`escaped`}}") + assert not is_fully_escaped_interpolation("prefix {{`escaped`}} suffix") + assert not is_fully_escaped_interpolation("{{normal}}") + + def test_remove_white_spaces(self): + """Test whitespace removal""" + assert remove_white_spaces(" hello world ") == "helloworld" + assert remove_white_spaces("no spaces") == "nospaces" + assert remove_white_spaces("") == "" + + def test_replace_parent_working_directory(self): + """Test CWD replacement""" + result = replace_parent_working_directory("{{cwd}}/config", "/home/user") + assert result == "/home/user/config" + + result = replace_parent_working_directory("no cwd here", "/home/user") + assert result == "no cwd here" + + +class TestFromDictInjector: + """Test FromDictInjector class""" + + def test_simple_interpolation_resolve(self): + 
"""Test simple value interpolation""" + injector = FromDictInjector() + data = {'env': {'name': 'production'}, 'port': 8080} + + result = injector.resolve("{{env.name}}", data) + assert result == "production" + + result = injector.resolve("Environment: {{env.name}}", data) + assert result == "Environment: production" + + def test_numeric_interpolation_resolve(self): + """Test numeric value interpolation""" + injector = FromDictInjector() + data = {'config': {'port': 8080, 'enabled': True}} + + result = injector.resolve("{{config.port}}", data) + assert result == 8080 + + result = injector.resolve("{{config.enabled}}", data) + assert result is True + + def test_nested_interpolation_resolve(self): + """Test deeply nested interpolation""" + injector = FromDictInjector() + data = { + 'app': { + 'database': { + 'connection': { + 'host': 'db.example.com' + } + } + } + } + + result = injector.resolve("{{app.database.connection.host}}", data) + assert result == "db.example.com" + + def test_missing_key_interpolation(self): + """Test interpolation with missing keys""" + injector = FromDictInjector() + data = {'env': {'name': 'production'}} + + result = injector.resolve("{{missing.key}}", data) + assert result == "{{missing.key}}" # Should remain unchanged + + def test_multiple_interpolations(self): + """Test multiple interpolations in one string""" + injector = FromDictInjector() + data = {'env': 'prod', 'region': 'us-east-1'} + + result = injector.resolve("{{env}}-{{region}}", data) + assert result == "prod-us-east-1" + + def test_parse_leaves(self): + """Test parse_leaves method""" + injector = FromDictInjector() + data = { + 'level1': { + 'level2': { + 'value': 'test' + }, + 'simple': 'value' + }, + 'root': 'root_value' + } + + injector.parse_leaves(data, "") + + assert 'level1.level2.value' in injector.results + assert injector.results['level1.level2.value'] == 'test' + assert 'level1.simple' in injector.results + assert injector.results['level1.simple'] == 'value' + assert 'root' in injector.results + assert injector.results['root'] == 'root_value' + + +class TestFullBlobInjector: + """Test FullBlobInjector class""" + + def test_full_blob_injection(self): + """Test full blob injection""" + injector = FullBlobInjector() + data = { + 'database': { + 'host': 'localhost', + 'port': 5432 + } + } + + result = injector.resolve("{{database}}", data) + assert result == data['database'] + + def test_partial_interpolation_unchanged(self): + """Test that partial interpolations are unchanged""" + injector = FullBlobInjector() + data = {'env': 'production'} + + result = injector.resolve("Environment: {{env}}", data) + assert result == "Environment: {{env}}" + + def test_missing_key_unchanged(self): + """Test that missing keys remain unchanged""" + injector = FullBlobInjector() + data = {'env': 'production'} + + result = injector.resolve("{{missing}}", data) + assert result == "{{missing}}" + + def test_nested_blob_injection(self): + """Test nested blob injection""" + injector = FullBlobInjector() + data = { + 'app': { + 'config': { + 'database': { + 'host': 'localhost', + 'port': 5432 + } + } + } + } + + result = injector.resolve("{{app.config.database}}", data) + assert result == data['app']['config']['database'] + + +class TestInterpolationResolver: + """Test InterpolationResolver class""" + + def test_resolve_interpolations(self): + """Test interpolation resolution""" + resolver = InterpolationResolver() + data = { + 'env': 'production', + 'database_url': 'db-{{env}}.example.com', + 'config': { + 
'environment': '{{env}}' + } + } + + result = resolver.resolve_interpolations(data) + + assert result['database_url'] == 'db-production.example.com' + assert result['config']['environment'] == 'production' + + def test_complex_interpolation_resolution(self): + """Test complex interpolation scenarios""" + resolver = InterpolationResolver() + data = { + 'env': 'prod', + 'region': 'us-east-1', + 'cluster': 'web', + 'full_name': '{{env}}-{{region}}-{{cluster}}', + 'nested': { + 'value': 'Environment is {{env}}' + }, + 'reference': { + 'to_nested': '{{nested}}' + } + } + + result = resolver.resolve_interpolations(data) + + assert result['full_name'] == 'prod-us-east-1-web' + assert result['nested']['value'] == 'Environment is prod' + assert result['reference']['to_nested'] == data['nested'] + + +class TestInterpolationValidator: + """Test InterpolationValidator class""" + + def test_valid_interpolations_pass(self): + """Test that resolved interpolations pass validation""" + validator = InterpolationValidator() + data = { + 'env': 'production', + 'database_url': 'db-production.example.com' + } + + # Should not raise an exception + validator.check_all_interpolations_resolved(data) + + def test_unresolved_interpolations_fail(self): + """Test that unresolved interpolations fail validation""" + validator = InterpolationValidator() + data = { + 'env': 'production', + 'database_url': 'db-{{unresolved}}.example.com' + } + + with pytest.raises(Exception) as exc_info: + validator.check_all_interpolations_resolved(data) + + assert "Interpolation could not be resolved" in str(exc_info.value) + assert "{{unresolved}}" in str(exc_info.value) + + def test_escaped_interpolations_pass(self): + """Test that escaped interpolations pass validation""" + validator = InterpolationValidator() + data = { + 'env': 'production', + 'template': 'Use {{`variable`}} for templating' + } + + # Should not raise an exception + validator.check_all_interpolations_resolved(data) + + +class TestEscapingResolver: + """Test EscapingResolver class""" + + def test_resolve_escaping(self): + """Test escaping resolution""" + resolver = EscapingResolver() + data = { + 'template': '{{`escaped_value`}}', + 'normal': 'normal_value' + } + + result = resolver.resolve_escaping(data) + + # The actual escaping logic would be implemented in DictEscapingResolver + # This test verifies the method can be called without error + assert result is not None diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 00000000..75cc4905 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,324 @@ +# Copyright 2019 Adobe. All rights reserved. +# This file is licensed to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may obtain a copy +# of the License at http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS +# OF ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. 
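+
+# ConfigRunner backs himl's command-line entry point: get_parser() builds the
+# argparse parser and do_run() reduces the parsed options to a single
+# ConfigProcessor.process() call. A representative invocation covered by these
+# tests (command name and paths hypothetical) would be:
+#   himl examples/config --output-file out.yaml --format json \
+#       --filter key1 --exclude secret_key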
+ +import os +import tempfile +import shutil + +import yaml +from unittest.mock import patch, MagicMock +from io import StringIO + +from himl.main import ConfigRunner + + +class TestConfigRunner: + """Test ConfigRunner class""" + + def setup_method(self): + """Set up test fixtures""" + self.temp_dir = tempfile.mkdtemp() + self.runner = ConfigRunner() + + def teardown_method(self): + """Clean up test fixtures""" + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + + def create_test_yaml(self, path, content): + """Helper to create test YAML files""" + full_path = os.path.join(self.temp_dir, path) + os.makedirs(os.path.dirname(full_path), exist_ok=True) + with open(full_path, 'w') as f: + yaml.dump(content, f) + return full_path + + def test_get_parser(self): + """Test argument parser creation""" + parser = self.runner.get_parser() + + # Test that parser has expected arguments + args = parser.parse_args(['test_path']) + assert args.path == 'test_path' + + # Test with optional arguments + args = parser.parse_args([ + 'test_path', + '--output-file', 'output.yaml', + '--format', 'json', + '--filter', 'key1', + '--exclude', 'secret_key', + '--skip-interpolation-validation', + '--skip-interpolation-resolving', + '--enclosing-key', 'app', + '--cwd', '/custom/path' + ]) + + assert args.path == 'test_path' + assert args.output_file == 'output.yaml' + assert args.output_format == 'json' + assert args.filter == ['key1'] + assert args.exclude == ['secret_key'] + assert args.skip_interpolation_validation is True + assert args.skip_interpolation_resolving is True + assert args.enclosing_key == 'app' + assert args.cwd == '/custom/path' + + @patch('himl.main.ConfigProcessor') + def test_do_run_basic(self, mock_config_processor): + """Test basic do_run functionality""" + # Create mock options + mock_opts = MagicMock() + mock_opts.cwd = None + mock_opts.path = 'test_path' + mock_opts.filter = None + mock_opts.exclude = None + mock_opts.output_file = None + mock_opts.print_data = True + mock_opts.output_format = 'yaml' + mock_opts.enclosing_key = None + mock_opts.remove_enclosing_key = None + mock_opts.skip_interpolation_resolving = False + mock_opts.skip_interpolation_validation = False + mock_opts.skip_secrets = False + mock_opts.multi_line_string = False + mock_opts.merge_list_strategy = MagicMock() + mock_opts.merge_list_strategy.value = 'append_unique' + + # Mock ConfigProcessor + mock_processor_instance = MagicMock() + mock_config_processor.return_value = mock_processor_instance + + with patch('os.getcwd', return_value='/current/dir'): + self.runner.do_run(mock_opts) + + # Verify ConfigProcessor was called correctly + mock_config_processor.assert_called_once() + mock_processor_instance.process.assert_called_once_with( + '/current/dir', + 'test_path', + (), + (), + None, + None, + 'yaml', + True, + None, + False, + False, + False, + False, + type_strategies=[(list, ['append_unique']), (dict, ["merge"])] + ) + + @patch('himl.main.ConfigProcessor') + def test_do_run_with_filters(self, mock_config_processor): + """Test do_run with filters and exclusions""" + mock_opts = MagicMock() + mock_opts.cwd = '/custom/cwd' + mock_opts.path = 'test_path' + mock_opts.filter = ['key1', 'key2'] + mock_opts.exclude = ['secret'] + mock_opts.output_file = 'output.yaml' + mock_opts.print_data = False + mock_opts.output_format = 'json' + mock_opts.enclosing_key = 'app' + mock_opts.remove_enclosing_key = None + mock_opts.skip_interpolation_resolving = True + mock_opts.skip_interpolation_validation = True + 
mock_opts.skip_secrets = True + mock_opts.multi_line_string = True + mock_opts.merge_list_strategy = MagicMock() + mock_opts.merge_list_strategy.value = 'override' + + mock_processor_instance = MagicMock() + mock_config_processor.return_value = mock_processor_instance + + self.runner.do_run(mock_opts) + + mock_processor_instance.process.assert_called_once_with( + '/custom/cwd', + 'test_path', + ['key1', 'key2'], + ['secret'], + 'app', + None, + 'json', + False, + 'output.yaml', + True, + True, + True, + True, + type_strategies=[(list, ['override']), (dict, ["merge"])] + ) + + @patch('himl.main.ConfigProcessor') + def test_run_with_args(self, mock_config_processor): + """Test run method with command line arguments""" + mock_processor_instance = MagicMock() + mock_config_processor.return_value = mock_processor_instance + + args = ['test_path', '--format', 'json'] + + with patch('os.getcwd', return_value='/current/dir'): + self.runner.run(args) + + mock_config_processor.assert_called_once() + mock_processor_instance.process.assert_called_once() + + @patch('sys.stdout', new_callable=StringIO) + @patch('himl.main.ConfigProcessor') + def test_run_integration_simple(self, mock_config_processor, mock_stdout): + """Test integration with simple config""" + # Setup mock processor to return test data + test_data = {'env': 'test', 'debug': True} + mock_processor_instance = MagicMock() + mock_processor_instance.process.return_value = test_data + mock_config_processor.return_value = mock_processor_instance + + # Create test config + self.create_test_yaml('config.yaml', test_data) + + args = [os.path.join(self.temp_dir, 'config.yaml')] + + with patch('os.getcwd', return_value=self.temp_dir): + self.runner.run(args) + + # Verify the processor was called + mock_processor_instance.process.assert_called_once() + + def test_parser_list_merge_strategies(self): + """Test list merge strategy options""" + parser = self.runner.get_parser() + + # Test append strategy + args = parser.parse_args(['test_path', '--list-merge-strategy', 'append']) + assert args.merge_list_strategy.value == 'append' + + # Test override strategy + args = parser.parse_args(['test_path', '--list-merge-strategy', 'override']) + assert args.merge_list_strategy.value == 'override' + + # Test prepend strategy + args = parser.parse_args(['test_path', '--list-merge-strategy', 'prepend']) + assert args.merge_list_strategy.value == 'prepend' + + # Test append_unique strategy (default) + args = parser.parse_args(['test_path']) + assert args.merge_list_strategy.value == 'append_unique' + + def test_parser_boolean_flags(self): + """Test boolean flag parsing""" + parser = self.runner.get_parser() + + # Test default values + args = parser.parse_args(['test_path']) + assert args.skip_interpolation_validation is False + assert args.skip_interpolation_resolving is False + + # Test when flags are set + args = parser.parse_args([ + 'test_path', + '--skip-interpolation-validation', + '--skip-interpolation-resolving' + ]) + assert args.skip_interpolation_validation is True + assert args.skip_interpolation_resolving is True + + def test_parser_multiple_filters(self): + """Test multiple filter and exclude arguments""" + parser = self.runner.get_parser() + + args = parser.parse_args([ + 'test_path', + '--filter', 'key1', + '--filter', 'key2', + '--exclude', 'secret1', + '--exclude', 'secret2' + ]) + + assert args.filter == ['key1', 'key2'] + assert args.exclude == ['secret1', 'secret2'] + + @patch('himl.main.ConfigProcessor') + def 
test_output_file_sets_print_data_false(self, mock_config_processor): + """Test that specifying output file without --print-data sets print_data to False""" + mock_opts = MagicMock() + mock_opts.cwd = None + mock_opts.path = 'test_path' + mock_opts.filter = None + mock_opts.exclude = None + mock_opts.output_file = 'output.yaml' + mock_opts.print_data = False # Not explicitly set by user, defaults to False + mock_opts.output_format = 'yaml' + mock_opts.enclosing_key = None + mock_opts.remove_enclosing_key = None + mock_opts.skip_interpolation_resolving = False + mock_opts.skip_interpolation_validation = False + mock_opts.skip_secrets = False + mock_opts.multi_line_string = False + mock_opts.merge_list_strategy = MagicMock() + mock_opts.merge_list_strategy.value = 'append_unique' + + mock_processor_instance = MagicMock() + mock_config_processor.return_value = mock_processor_instance + + with patch('os.getcwd', return_value='/current/dir'): + self.runner.do_run(mock_opts) + + # When output_file is specified without --print-data, print_data should be False + call_args = mock_processor_instance.process.call_args[0] + # print_data is the 8th positional argument (index 7) + assert call_args[7] is False + # output_file is the 9th positional argument (index 8) + assert call_args[8] == 'output.yaml' + + def test_parser_help_message(self): + """Test that parser help can be generated without error""" + parser = self.runner.get_parser() + + # This should not raise an exception + help_text = parser.format_help() + assert 'path' in help_text + assert 'output-file' in help_text + assert 'format' in help_text + + @patch('himl.main.ConfigProcessor') + def test_empty_filters_and_excludes(self, mock_config_processor): + """Test handling of empty filters and excludes""" + mock_opts = MagicMock() + mock_opts.cwd = None + mock_opts.path = 'test_path' + mock_opts.filter = [] # Empty list + mock_opts.exclude = [] # Empty list + mock_opts.output_file = None + mock_opts.output_format = 'yaml' + mock_opts.enclosing_key = None + mock_opts.remove_enclosing_key = None + mock_opts.skip_interpolation_resolving = False + mock_opts.skip_interpolation_validation = False + mock_opts.skip_secrets = False + mock_opts.multi_line_string = False + mock_opts.merge_list_strategy = MagicMock() + mock_opts.merge_list_strategy.value = 'append_unique' + + mock_processor_instance = MagicMock() + mock_config_processor.return_value = mock_processor_instance + + with patch('os.getcwd', return_value='/current/dir'): + self.runner.do_run(mock_opts) + + # Empty lists should be converted to empty tuples + call_args = mock_processor_instance.process.call_args[0] + # filters is the 3rd positional argument (index 2) + assert call_args[2] == () + # exclude_keys is the 4th positional argument (index 3) + assert call_args[3] == () diff --git a/tests/test_remote_state.py b/tests/test_remote_state.py new file mode 100644 index 00000000..f592ece5 --- /dev/null +++ b/tests/test_remote_state.py @@ -0,0 +1,67 @@ +# Copyright 2019 Adobe. All rights reserved. +# This file is licensed to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may obtain a copy +# of the License at http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS +# OF ANY KIND, either express or implied. 
See the License for the specific language +# governing permissions and limitations under the License. + +import json +from unittest.mock import patch, MagicMock + +from himl.remote_state import S3TerraformRemoteStateRetriever + + +class TestS3TerraformRemoteStateRetriever: + """Test S3TerraformRemoteStateRetriever class""" + + def setup_method(self): + """Set up test fixtures""" + self.retriever = S3TerraformRemoteStateRetriever() + + @patch('boto3.session.Session') + def test_get_s3_client_success(self, mock_session): + """Test successful S3 client creation and object retrieval""" + # Mock the S3 client and response + mock_client = MagicMock() + mock_session_instance = MagicMock() + mock_session_instance.client.return_value = mock_client + mock_session.return_value = mock_session_instance + + # Mock S3 response + terraform_state = { + 'version': 4, + 'terraform_version': '1.0.0', + 'outputs': { + 'vpc_id': {'value': 'vpc-12345'}, + 'subnet_ids': {'value': ['subnet-1', 'subnet-2']} + } + } + mock_client.get_object.return_value = { + 'Body': MagicMock(read=MagicMock(return_value=json.dumps(terraform_state).encode())) + } + + result = S3TerraformRemoteStateRetriever.get_s3_client( + 'my-terraform-bucket', + 'path/to/terraform.tfstate', + 'my-aws-profile' + ) + + assert result == terraform_state + mock_session.assert_called_once_with(profile_name='my-aws-profile') + mock_session_instance.client.assert_called_once_with('s3') + mock_client.get_object.assert_called_once_with( + Bucket='my-terraform-bucket', + Key='path/to/terraform.tfstate' + ) + + @patch.object(S3TerraformRemoteStateRetriever, 'get_s3_client') + def test_get_dynamic_data_empty_states(self, mock_get_s3_client): + """Test dynamic data retrieval with empty remote states list""" + result = self.retriever.get_dynamic_data([]) + + expected = {'outputs': {}} + assert result == expected + mock_get_s3_client.assert_not_called() diff --git a/tests/test_secret_resolvers.py b/tests/test_secret_resolvers.py new file mode 100644 index 00000000..fd810efd --- /dev/null +++ b/tests/test_secret_resolvers.py @@ -0,0 +1,336 @@ +# Copyright 2019 Adobe. All rights reserved. +# This file is licensed to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may obtain a copy +# of the License at http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS +# OF ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. 
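+
+# Resolver layout exercised below, as implied by the assertions: the
+# SecretResolver base class defines supports()/resolve(), each concrete
+# resolver wraps a single backend (SSM, S3, Vault, SOPS), and
+# AggregatedSecretResolver delegates to the first resolver whose supports()
+# returns True, raising when no resolver claims the secret type.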
+ +import pytest +import sys +from unittest.mock import patch, MagicMock + +from himl.secret_resolvers import ( + SecretResolver, SSMSecretResolver, S3SecretResolver, + VaultSecretResolver, SopsSecretResolver, AggregatedSecretResolver +) + + +class TestSecretResolver: + """Test base SecretResolver class""" + + def test_get_param_or_exception_success(self): + """Test successful parameter retrieval""" + resolver = SecretResolver() + params = {'key': 'value', 'number': 42} + + assert resolver.get_param_or_exception('key', params) == 'value' + assert resolver.get_param_or_exception('number', params) == 42 + + def test_get_param_or_exception_missing(self): + """Test exception when parameter is missing""" + resolver = SecretResolver() + params = {'key': 'value'} + + with pytest.raises(Exception) as exc_info: + resolver.get_param_or_exception('missing_key', params) + + assert "Could not find required key" in str(exc_info.value) + assert "missing_key" in str(exc_info.value) + + def test_supports_not_implemented(self): + """Test that supports method raises NotImplementedError""" + resolver = SecretResolver() + + with pytest.raises(NotImplementedError): + resolver.supports('any_type') + + def test_resolve_not_implemented(self): + """Test that resolve method raises NotImplementedError""" + resolver = SecretResolver() + + with pytest.raises(NotImplementedError): + resolver.resolve('any_type', {}) + + +class TestSSMSecretResolver: + """Test SSMSecretResolver class""" + + def test_supports_with_boto3(self): + """Test supports method when boto3 is available""" + with patch.dict(sys.modules, {'boto3': MagicMock()}): + resolver = SSMSecretResolver() + assert resolver.supports('ssm') is True + assert resolver.supports('s3') is False + + def test_supports_without_boto3(self): + """Test supports method when boto3 is not available""" + with patch.dict(sys.modules, {}, clear=True): + resolver = SSMSecretResolver() + assert resolver.supports('ssm') is False + + @patch('himl.simplessm.SimpleSSM') + def test_resolve_success(self, mock_simple_ssm): + """Test successful SSM secret resolution""" + # Setup mocks + mock_ssm_instance = MagicMock() + mock_ssm_instance.get.return_value = 'secret_value' + mock_simple_ssm.return_value = mock_ssm_instance + + resolver = SSMSecretResolver(default_aws_profile='default') + secret_params = { + 'path': '/my/secret/path', + 'aws_profile': 'test-profile', + 'region_name': 'us-west-2' + } + + result = resolver.resolve('ssm', secret_params) + + assert result == 'secret_value' + mock_simple_ssm.assert_called_once_with('test-profile', 'us-west-2') + mock_ssm_instance.get.assert_called_once_with('/my/secret/path') + + @patch('himl.simplessm.SimpleSSM') + def test_resolve_with_default_profile(self, mock_simple_ssm): + """Test SSM resolution with default profile""" + mock_ssm_instance = MagicMock() + mock_ssm_instance.get.return_value = 'secret_value' + mock_simple_ssm.return_value = mock_ssm_instance + + resolver = SSMSecretResolver(default_aws_profile='default-profile') + secret_params = {'path': '/my/secret/path'} + + result = resolver.resolve('ssm', secret_params) + + assert result == 'secret_value' + mock_simple_ssm.assert_called_once_with('default-profile', 'us-east-1') + + def test_resolve_missing_profile(self): + """Test SSM resolution without AWS profile""" + resolver = SSMSecretResolver() + secret_params = {'path': '/my/secret/path'} + + with pytest.raises(Exception) as exc_info: + resolver.resolve('ssm', secret_params) + + assert "Could not find the aws_profile" in 
str(exc_info.value) + + def test_resolve_missing_path(self): + """Test SSM resolution without path""" + resolver = SSMSecretResolver(default_aws_profile='default') + secret_params = {'aws_profile': 'test-profile'} + + with pytest.raises(Exception) as exc_info: + resolver.resolve('ssm', secret_params) + + assert "Could not find required key" in str(exc_info.value) + + +class TestS3SecretResolver: + """Test S3SecretResolver class""" + + def test_supports_with_boto3(self): + """Test supports method when boto3 is available""" + with patch.dict(sys.modules, {'boto3': MagicMock()}): + resolver = S3SecretResolver() + assert resolver.supports('s3') is True + assert resolver.supports('ssm') is False + + @patch('himl.simples3.SimpleS3') + def test_resolve_success(self, mock_simple_s3): + """Test successful S3 secret resolution""" + mock_s3_instance = MagicMock() + mock_s3_instance.get.return_value = 'file_content' + mock_simple_s3.return_value = mock_s3_instance + + resolver = S3SecretResolver(default_aws_profile='default') + secret_params = { + 'bucket': 'my-bucket', + 'path': 'path/to/file.txt', + 'aws_profile': 'test-profile', + 'region_name': 'us-west-2', + 'base64encode': 'false' + } + + result = resolver.resolve('s3', secret_params) + + assert result == 'file_content' + mock_simple_s3.assert_called_once_with('test-profile', 'us-west-2') + mock_s3_instance.get.assert_called_once_with('my-bucket', 'path/to/file.txt', False) + + @patch('himl.simples3.SimpleS3') + def test_resolve_with_base64_encoding(self, mock_simple_s3): + """Test S3 resolution with base64 encoding""" + mock_s3_instance = MagicMock() + mock_s3_instance.get.return_value = 'encoded_content' + mock_simple_s3.return_value = mock_s3_instance + + resolver = S3SecretResolver(default_aws_profile='default') + secret_params = { + 'bucket': 'my-bucket', + 'path': 'path/to/file.txt', + 'aws_profile': 'test-profile', + 'base64encode': 'true' + } + + result = resolver.resolve('s3', secret_params) + + assert result == 'encoded_content' + mock_s3_instance.get.assert_called_once_with('my-bucket', 'path/to/file.txt', True) + + def test_resolve_missing_bucket(self): + """Test S3 resolution without bucket""" + resolver = S3SecretResolver(default_aws_profile='default') + secret_params = {'path': 'path/to/file.txt', 'aws_profile': 'test-profile'} + + with pytest.raises(Exception) as exc_info: + resolver.resolve('s3', secret_params) + + assert "Could not find required key" in str(exc_info.value) + + +class TestVaultSecretResolver: + """Test VaultSecretResolver class""" + + def test_supports_with_hvac(self): + """Test supports method when hvac is available""" + with patch.dict(sys.modules, {'hvac': MagicMock()}): + resolver = VaultSecretResolver() + assert resolver.supports('vault') is True + assert resolver.supports('ssm') is False + + @patch('himl.simplevault.SimpleVault') + def test_resolve_token_policy(self, mock_simple_vault): + """Test Vault token policy resolution""" + mock_vault_instance = MagicMock() + mock_vault_instance.get_token.return_value = 'vault_token' + mock_simple_vault.return_value = mock_vault_instance + + resolver = VaultSecretResolver() + secret_params = {'token_policy': 'my_policy'} + + result = resolver.resolve('vault', secret_params) + + assert result == 'vault_token' + mock_vault_instance.get_token.assert_called_once_with('my_policy') + + @patch('himl.simplevault.SimpleVault') + def test_resolve_path(self, mock_simple_vault): + """Test Vault path resolution""" + mock_vault_instance = MagicMock() + 
mock_vault_instance.get_path.return_value = {'key': 'value'} + mock_simple_vault.return_value = mock_vault_instance + + resolver = VaultSecretResolver() + secret_params = {'path': '/secret/path'} + + result = resolver.resolve('vault', secret_params) + + assert result == {'key': 'value'} + mock_vault_instance.get_path.assert_called_once_with('/secret/path') + + @patch('himl.simplevault.SimpleVault') + def test_resolve_key(self, mock_simple_vault): + """Test Vault key resolution""" + mock_vault_instance = MagicMock() + mock_vault_instance.get_key.return_value = 'secret_value' + mock_simple_vault.return_value = mock_vault_instance + + resolver = VaultSecretResolver() + secret_params = {'key': '/secret/path/key_name'} + + result = resolver.resolve('vault', secret_params) + + assert result == 'secret_value' + mock_vault_instance.get_key.assert_called_once_with('/secret/path', 'key_name') + + +class TestSopsSecretResolver: + """Test SopsSecretResolver class""" + + def test_supports(self): + """Test supports method""" + resolver = SopsSecretResolver() + assert resolver.supports('sops') is True + assert resolver.supports('vault') is False + + @patch('himl.simplesops.SimpleSops') + def test_resolve_success(self, mock_simple_sops): + """Test successful SOPS secret resolution""" + mock_sops_instance = MagicMock() + mock_sops_instance.get.return_value = 'decrypted_value' + mock_simple_sops.return_value = mock_sops_instance + + resolver = SopsSecretResolver() + secret_params = { + 'secret_file': '/path/to/secrets.yaml', + 'secret_key': 'my_key' + } + + result = resolver.resolve('sops', secret_params) + + assert result == 'decrypted_value' + mock_sops_instance.get.assert_called_once_with( + secret_file='/path/to/secrets.yaml', + secret_key='my_key' + ) + + def test_resolve_missing_file(self): + """Test SOPS resolution without secret file""" + resolver = SopsSecretResolver() + secret_params = {'secret_key': 'my_key'} + + with pytest.raises(Exception) as exc_info: + resolver.resolve('sops', secret_params) + + assert "Could not find required key" in str(exc_info.value) + + +class TestAggregatedSecretResolver: + """Test AggregatedSecretResolver class""" + + def test_initialization(self): + """Test AggregatedSecretResolver initialization""" + resolver = AggregatedSecretResolver(default_aws_profile='test-profile') + assert len(resolver.secret_resolvers) == 4 + assert any(isinstance(r, SSMSecretResolver) for r in resolver.secret_resolvers) + assert any(isinstance(r, S3SecretResolver) for r in resolver.secret_resolvers) + assert any(isinstance(r, VaultSecretResolver) for r in resolver.secret_resolvers) + assert any(isinstance(r, SopsSecretResolver) for r in resolver.secret_resolvers) + + def test_supports_delegated(self): + """Test that supports method delegates to individual resolvers""" + with patch.dict(sys.modules, {'boto3': MagicMock(), 'hvac': MagicMock()}): + resolver = AggregatedSecretResolver() + assert resolver.supports('ssm') is True + assert resolver.supports('s3') is True + assert resolver.supports('vault') is True + assert resolver.supports('sops') is True + assert resolver.supports('unknown') is False + + @patch('himl.simplessm.SimpleSSM') + def test_resolve_delegated(self, mock_simple_ssm): + """Test that resolve method delegates to appropriate resolver""" + mock_ssm_instance = MagicMock() + mock_ssm_instance.get.return_value = 'secret_value' + mock_simple_ssm.return_value = mock_ssm_instance + + with patch.dict(sys.modules, {'boto3': MagicMock()}): + resolver = 
AggregatedSecretResolver(default_aws_profile='default') + secret_params = {'path': '/my/secret', 'aws_profile': 'test'} + + result = resolver.resolve('ssm', secret_params) + + assert result == 'secret_value' + + def test_resolve_unsupported_type(self): + """Test resolve with unsupported secret type""" + resolver = AggregatedSecretResolver() + + with pytest.raises(Exception) as exc_info: + resolver.resolve('unsupported_type', {}) + + assert "Could not resolve secret type" in str(exc_info.value) + assert "unsupported_type" in str(exc_info.value) diff --git a/tests/test_simple_modules.py b/tests/test_simple_modules.py new file mode 100644 index 00000000..cc023e48 --- /dev/null +++ b/tests/test_simple_modules.py @@ -0,0 +1,370 @@ +# Copyright 2019 Adobe. All rights reserved. +# This file is licensed to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may obtain a copy +# of the License at http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS +# OF ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +import os +import pytest + +from unittest.mock import patch, MagicMock +from botocore.exceptions import ClientError + +from himl.simples3 import SimpleS3 +from himl.simplessm import SimpleSSM +from himl.simplesops import SimpleSops, Sops, SopsError +from himl.simplevault import SimpleVault + + +class TestSimpleS3: + """Test SimpleS3 class""" + + def setup_method(self): + """Set up test fixtures""" + self.s3 = SimpleS3('test-profile', 'us-east-1') + + @patch('boto3.session.Session') + def test_get_success(self, mock_session): + """Test successful S3 object retrieval""" + mock_client = MagicMock() + mock_session_instance = MagicMock() + mock_session_instance.client.return_value = mock_client + mock_session.return_value = mock_session_instance + + mock_client.get_object.return_value = { + 'Body': MagicMock(read=MagicMock(return_value=b'file content')) + } + + result = self.s3.get('my-bucket', 'path/to/file.txt') + + assert result == b'file content' + mock_session.assert_called_once_with(profile_name='test-profile') + mock_client.get_object.assert_called_once_with(Bucket='my-bucket', Key='path/to/file.txt') + + @patch('boto3.session.Session') + def test_get_with_base64_encoding(self, mock_session): + """Test S3 retrieval with base64 encoding""" + mock_client = MagicMock() + mock_session_instance = MagicMock() + mock_session_instance.client.return_value = mock_client + mock_session.return_value = mock_session_instance + + mock_client.get_object.return_value = { + 'Body': MagicMock(read=MagicMock(return_value=b'binary content')) + } + + result = self.s3.get('my-bucket', 'path/to/file.bin', base64Encode=True) + + # Should return base64 encoded string + import base64 + expected = base64.b64encode(b'binary content').decode('utf-8') + assert result == expected + + @patch('boto3.session.Session') + def test_get_client_error(self, mock_session): + """Test S3 client error handling""" + mock_client = MagicMock() + mock_session_instance = MagicMock() + mock_session_instance.client.return_value = mock_client + mock_session.return_value = mock_session_instance + + mock_client.get_object.side_effect = ClientError( + {'Error': {'Code': 'NoSuchKey'}}, + 'GetObject' + ) + + with 
pytest.raises(Exception) as exc_info: + self.s3.get('my-bucket', 'nonexistent/file.txt') + + assert 'Error while trying to read S3 value' in str(exc_info.value) + assert 'NoSuchKey' in str(exc_info.value) + + def test_parse_data_no_encoding(self): + """Test parse_data without encoding""" + result = self.s3.parse_data(b'test content', False) + assert result == b'test content' + + def test_parse_data_with_base64(self): + """Test parse_data with base64 encoding""" + import base64 + test_data = b'test content' + result = self.s3.parse_data(test_data, True) + expected = base64.b64encode(test_data).decode('utf-8') + assert result == expected + + +class TestSimpleSSM: + """Test SimpleSSM class""" + + def setup_method(self): + """Set up test fixtures""" + self.ssm = SimpleSSM('test-profile', 'us-east-1') + + @patch.dict(os.environ, {}, clear=True) + @patch('boto3.client') + def test_get_success(self, mock_boto_client): + """Test successful SSM parameter retrieval""" + mock_client = MagicMock() + mock_boto_client.return_value = mock_client + + mock_client.get_parameter.return_value = { + 'Parameter': {'Value': 'secret_value'} + } + + result = self.ssm.get('/my/secret/key') + + assert result == 'secret_value' + mock_boto_client.assert_called_once_with('ssm', region_name='us-east-1') + mock_client.get_parameter.assert_called_once_with(Name='/my/secret/key', WithDecryption=True) + + @patch.dict(os.environ, {'AWS_PROFILE': 'original-profile'}) + @patch('boto3.client') + def test_get_preserves_original_profile(self, mock_boto_client): + """Test that original AWS profile is preserved""" + # Create SSM instance after environment is patched + ssm = SimpleSSM('test-profile', 'us-east-1') + + mock_client = MagicMock() + mock_boto_client.return_value = mock_client + + mock_client.get_parameter.return_value = { + 'Parameter': {'Value': 'secret_value'} + } + + ssm.get('/my/secret/key') + + # Original profile should be restored + assert os.environ.get('AWS_PROFILE') == 'original-profile' + + @patch.dict(os.environ, {}, clear=True) + @patch('boto3.client') + def test_get_removes_profile_when_none_initially(self, mock_boto_client): + """Test that AWS_PROFILE is removed when none was set initially""" + mock_client = MagicMock() + mock_boto_client.return_value = mock_client + + mock_client.get_parameter.return_value = { + 'Parameter': {'Value': 'secret_value'} + } + + self.ssm.get('/my/secret/key') + + # AWS_PROFILE should not be in environment + assert 'AWS_PROFILE' not in os.environ + + @patch('boto3.client') + def test_get_client_error(self, mock_boto_client): + """Test SSM client error handling""" + mock_client = MagicMock() + mock_boto_client.return_value = mock_client + + mock_client.get_parameter.side_effect = ClientError( + {'Error': {'Code': 'ParameterNotFound'}}, + 'GetParameter' + ) + + with pytest.raises(Exception) as exc_info: + self.ssm.get('/nonexistent/parameter') + + assert 'Error while trying to read SSM value' in str(exc_info.value) + assert 'ParameterNotFound' in str(exc_info.value) + + +class TestSimpleSops: + """Test SimpleSops class""" + + def setup_method(self): + """Set up test fixtures""" + self.sops = SimpleSops() + + @patch.object(Sops, 'get_keys') + def test_get_success(self, mock_get_keys): + """Test successful SOPS secret retrieval""" + mock_get_keys.return_value = 'decrypted_value' + + result = self.sops.get('/path/to/secrets.yaml', 'my_key') + + assert result == 'decrypted_value' + mock_get_keys.assert_called_once_with(secret_file='/path/to/secrets.yaml', secret_key='my_key') + + 
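+    # SopsError is raised by the underlying Sops helper with the signature
+    # (filename, exit_code, stderr, decryption); see TestSopsError at the
+    # bottom of this file for the attribute mapping.
+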
@patch.object(Sops, 'get_keys') + def test_get_sops_error(self, mock_get_keys): + """Test SOPS error handling""" + mock_get_keys.side_effect = SopsError('/path/to/secrets.yaml', 1, 'Decryption failed', True) + + with pytest.raises(Exception) as exc_info: + self.sops.get('/path/to/secrets.yaml', 'my_key') + + assert 'Error while trying to read sops value' in str(exc_info.value) + + +class TestSops: + """Test Sops utility class""" + + @patch('himl.simplesops.Popen') + def test_decrypt_success(self, mock_popen): + """Test successful SOPS decryption""" + mock_process = MagicMock() + mock_process.communicate.return_value = (b'decrypted: content\n', b'') + mock_process.returncode = 0 + mock_popen.return_value = mock_process + + result = Sops.decrypt('/path/to/encrypted.yaml') + + assert result == {'decrypted': 'content'} + mock_popen.assert_called_once() + + def test_get_keys_simple(self): + """Test get_keys with simple key""" + sops = Sops() + test_data = {'key1': 'value1', 'key2': 'value2'} + + with patch.object(Sops, 'decrypt', return_value=test_data): + result = sops.get_keys('/path/to/file.yaml', 'key1') + assert result == 'value1' + + def test_get_keys_nested(self): + """Test get_keys with nested key""" + sops = Sops() + test_data = { + 'level1': { + 'level2': { + 'key': 'nested_value' + } + } + } + + with patch.object(Sops, 'decrypt', return_value=test_data): + result = sops.get_keys('/path/to/file.yaml', "['level1']['level2']['key']") + assert result == 'nested_value' + + def test_get_keys_missing_key(self): + """Test get_keys with missing key""" + sops = Sops() + test_data = {'existing_key': 'value'} + + with patch.object(Sops, 'decrypt', return_value=test_data): + with pytest.raises(SopsError): + sops.get_keys('/path/to/file.yaml', 'missing_key') + + +class TestSimpleVault: + """Test SimpleVault class""" + + def setup_method(self): + """Set up test fixtures""" + self.vault = SimpleVault() + + @patch('hvac.Client') + def test_get_vault_client_authenticated(self, mock_hvac_client): + """Test getting authenticated Vault client""" + mock_client = MagicMock() + mock_client.is_authenticated.return_value = True + mock_hvac_client.return_value = mock_client + + result = self.vault.get_vault_client() + + assert result == mock_client + mock_hvac_client.assert_called_once() + + @patch.dict(os.environ, {'VAULT_PASSWORD': 'test_pass', 'VAULT_USERNAME': 'test_user'}) + @patch('hvac.Client') + def test_get_vault_client_ldap_fallback(self, mock_hvac_client): + """Test Vault client with LDAP fallback authentication""" + mock_client = MagicMock() + mock_client.is_authenticated.side_effect = [False, True] # First call fails, second succeeds + mock_hvac_client.return_value = mock_client + + result = self.vault.get_vault_client() + + assert result == mock_client + mock_client.auth.ldap.login.assert_called_once_with( + username='test_user', + password='test_pass' + ) + + @patch.dict(os.environ, {'VAULT_PASSWORD': 'test_pass', 'VAULT_USERNAME': 'test_user'}) + @patch('hvac.Client') + def test_get_vault_client_ldap_failure(self, mock_hvac_client): + """Test Vault client LDAP authentication failure""" + mock_client = MagicMock() + mock_client.is_authenticated.return_value = False + mock_client.auth.ldap.login.side_effect = Exception('LDAP failed') + mock_hvac_client.return_value = mock_client + + with pytest.raises(Exception) as exc_info: + self.vault.get_vault_client() + + assert 'Error authenticating Vault over LDAP' in str(exc_info.value) + + @patch.dict(os.environ, {'VAULT_ROLE': 'test_role'}) + 
@patch.object(SimpleVault, 'get_vault_client') + def test_get_token(self, mock_get_client): + """Test token generation""" + mock_client = MagicMock() + mock_client.create_token.return_value = { + 'auth': {'client_token': 'generated_token'} + } + mock_get_client.return_value = mock_client + + result = self.vault.get_token('my_policy') + + assert result == 'generated_token' + mock_client.create_token.assert_called_once_with( + policies=['my_policy'], + role='test_role', + lease='24h' + ) + + @patch.dict(os.environ, {'VAULT_MOUNT_POINT': 'secret'}) + @patch.object(SimpleVault, 'get_vault_client') + def test_get_path(self, mock_get_client): + """Test path retrieval""" + mock_client = MagicMock() + mock_client.secrets.kv.v2.read_secret_version.return_value = { + 'data': {'data': {'key1': 'value1', 'key2': 'value2'}} + } + mock_get_client.return_value = mock_client + + result = self.vault.get_path('/my/secret/path') + + assert result == {'key1': 'value1', 'key2': 'value2'} + mock_client.secrets.kv.v2.read_secret_version.assert_called_once_with( + mount_point='secret', + path='/my/secret/path' + ) + + @patch.object(SimpleVault, 'get_path') + def test_get_key(self, mock_get_path): + """Test key retrieval""" + mock_get_path.return_value = {'key1': 'value1', 'key2': 'value2'} + + result = self.vault.get_key('/my/secret/path', 'key1') + + assert result == 'value1' + mock_get_path.assert_called_once_with('/my/secret/path') + + +class TestSopsError: + """Test SopsError exception class""" + + def test_sops_error_creation(self): + """Test SopsError creation""" + error = SopsError('/path/to/file', 1, 'Error message', True) + + assert error.filename == '/path/to/file' + assert error.exit_code == 1 + assert error.stderr == 'Error message' + assert error.decryption is True + + def test_sops_error_string_representation(self): + """Test SopsError string representation""" + error = SopsError('/path/to/file', 1, 'Error message', True) + error_str = str(error) + + assert '/path/to/file' in error_str + assert 'Error message' in error_str
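+
+    # A minimal usage sketch, assuming SimpleSops.get wraps SopsError with
+    # extra context (consistent with TestSimpleSops.test_get_sops_error above):
+    #   try:
+    #       SimpleSops().get('/path/to/secrets.yaml', 'my_key')
+    #   except Exception as err:
+    #       assert 'Error while trying to read sops value' in str(err)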