diff --git a/src/apm_cli/compilation/agents_compiler.py b/src/apm_cli/compilation/agents_compiler.py index a1619648..ca670abd 100644 --- a/src/apm_cli/compilation/agents_compiler.py +++ b/src/apm_cli/compilation/agents_compiler.py @@ -187,11 +187,17 @@ def compile(self, config: CompilationConfig, primitives: Optional[PrimitiveColle if primitives is None: if config.local_only: # Use basic discovery for local-only mode - primitives = discover_primitives(str(self.base_dir)) + primitives = discover_primitives( + str(self.base_dir), + exclude_patterns=config.exclude, + ) else: # Use enhanced discovery with dependencies (Task 4 integration) from ..primitives.discovery import discover_primitives_with_dependencies - primitives = discover_primitives_with_dependencies(str(self.base_dir)) + primitives = discover_primitives_with_dependencies( + str(self.base_dir), + exclude_patterns=config.exclude, + ) # Route to targets based on config.target results: List[CompilationResult] = [] diff --git a/src/apm_cli/compilation/context_optimizer.py b/src/apm_cli/compilation/context_optimizer.py index 256ecb32..0992f9cf 100644 --- a/src/apm_cli/compilation/context_optimizer.py +++ b/src/apm_cli/compilation/context_optimizer.py @@ -22,6 +22,7 @@ PlacementStrategy, PlacementSummary ) from ..utils.paths import portable_relpath +from ..utils.exclude import should_exclude, validate_exclude_patterns # CRITICAL: Shadow Click commands to prevent namespace collision # When this module is imported during 'apm compile', Click's active context @@ -132,8 +133,8 @@ def __init__(self, base_dir: str = ".", exclude_patterns: Optional[List[str]] = self._errors: List[str] = [] self._start_time: Optional[float] = None - # Configurable exclusion patterns - self._exclude_patterns = exclude_patterns or [] + # Configurable exclusion patterns (validated at init time) + self._exclude_patterns = validate_exclude_patterns(exclude_patterns) def enable_timing(self, verbose: bool = False): """Enable performance timing 
instrumentation.""" @@ -503,110 +504,7 @@ def _should_exclude_path(self, path: Path) -> bool: Returns: True if path should be excluded, False otherwise """ - if not self._exclude_patterns: - return False - - # Get path relative to base_dir for pattern matching - # Resolve the path first to handle cross-platform differences - # (e.g., on Windows Path('/test') != Path('C:/test') after resolve) - try: - resolved = path.resolve() - except (OSError, FileNotFoundError): - resolved = path.absolute() - try: - rel_path = resolved.relative_to(self.base_dir.resolve()) - except ValueError: - # Path is not relative to base_dir, don't exclude - return False - - # Check each exclusion pattern - for pattern in self._exclude_patterns: - if self._matches_pattern(rel_path, pattern): - return True - - return False - - def _matches_pattern(self, rel_path: Path, pattern: str) -> bool: - """Check if a relative path matches an exclusion pattern. - - Supports glob patterns including ** for recursive matching. - - Args: - rel_path: Path relative to base_dir - pattern: Exclusion pattern (glob syntax) - - Returns: - True if path matches pattern, False otherwise - """ - # Normalize both pattern and path to use forward slashes for consistent matching - # This handles Windows paths (backslashes) and Unix paths (forward slashes) - # Users can provide patterns with either separator - normalized_pattern = pattern.replace('\\', '/').replace(os.sep, '/') - - # Convert path to string with forward slashes - rel_path_str = str(rel_path).replace(os.sep, '/') - - # Handle ** patterns (match any number of directories) - if '**' in normalized_pattern: - # Convert ** glob to regex-like matching - # Split pattern into parts - parts = normalized_pattern.split('/') - path_parts = rel_path_str.split('/') - - # Try to match using recursive logic - return self._match_glob_recursive(path_parts, parts) - - # Simple fnmatch for patterns without ** - if fnmatch.fnmatch(rel_path_str, normalized_pattern): - return True 
- - # Also check if the path starts with the pattern (for directory matching) - # This handles cases like "apm_modules/" matching "apm_modules/foo/bar" - if normalized_pattern.endswith('/'): - if rel_path_str.startswith(normalized_pattern) or rel_path_str == normalized_pattern.rstrip('/'): - return True - else: - # Check if pattern with trailing slash would match - if rel_path_str.startswith(normalized_pattern + '/') or rel_path_str == normalized_pattern: - return True - - return False - - def _match_glob_recursive(self, path_parts: list, pattern_parts: list) -> bool: - """Recursively match path parts against pattern parts with ** support. - - Args: - path_parts: List of path components - pattern_parts: List of pattern components - - Returns: - True if path matches pattern, False otherwise - """ - if not pattern_parts: - return not path_parts - - if not path_parts: - # Check if remaining pattern parts are all ** or empty - # Empty parts can occur from patterns like "foo/" which split to ['foo', ''] - # or from consecutive slashes like "foo//bar" - return all(p == '**' or p == '' for p in pattern_parts) - - pattern_part = pattern_parts[0] - - if pattern_part == '**': - # ** matches zero or more directories - # Try matching with zero directories - if self._match_glob_recursive(path_parts, pattern_parts[1:]): - return True - # Try matching with one or more directories - if self._match_glob_recursive(path_parts[1:], pattern_parts): - return True - return False - else: - # Regular pattern part - must match current path part - if fnmatch.fnmatch(path_parts[0], pattern_part): - return self._match_glob_recursive(path_parts[1:], pattern_parts[1:]) - return False + return should_exclude(path, self.base_dir, self._exclude_patterns) def _find_optimal_placements( self, diff --git a/src/apm_cli/primitives/discovery.py b/src/apm_cli/primitives/discovery.py index bc9d97be..7d10baa8 100644 --- a/src/apm_cli/primitives/discovery.py +++ b/src/apm_cli/primitives/discovery.py @@ -1,12 
+1,16 @@ """Discovery functionality for primitive files.""" +import logging import os import glob from pathlib import Path -from typing import List, Dict +from typing import List, Dict, Optional from .models import PrimitiveCollection from .parser import parse_primitive_file, parse_skill_file +from ..utils.exclude import should_exclude, validate_exclude_patterns + +logger = logging.getLogger(__name__) from ..models.apm_package import APMPackage from ..deps.lockfile import LockFile @@ -52,7 +56,10 @@ } -def discover_primitives(base_dir: str = ".") -> PrimitiveCollection: +def discover_primitives( + base_dir: str = ".", + exclude_patterns: Optional[List[str]] = None, +) -> PrimitiveCollection: """Find all APM primitive files in the project. Searches for .chatmode.md, .instructions.md, .context.md, .memory.md files @@ -60,17 +67,23 @@ def discover_primitives(base_dir: str = ".") -> PrimitiveCollection: Args: base_dir (str): Base directory to search in. Defaults to current directory. + exclude_patterns (Optional[List[str]]): Glob patterns for paths to exclude. Returns: PrimitiveCollection: Collection of discovered and parsed primitives. 
""" collection = PrimitiveCollection() + base_path = Path(base_dir) + safe_patterns = validate_exclude_patterns(exclude_patterns) # Find and parse files for each primitive type for primitive_type, patterns in LOCAL_PRIMITIVE_PATTERNS.items(): files = find_primitive_files(base_dir, patterns) for file_path in files: + if should_exclude(file_path, base_path, safe_patterns): + logger.debug("Excluded by pattern: %s", file_path) + continue try: primitive = parse_primitive_file(file_path, source="local") collection.add_primitive(primitive) @@ -78,12 +91,15 @@ def discover_primitives(base_dir: str = ".") -> PrimitiveCollection: print(f"Warning: Failed to parse {file_path}: {e}") # Discover SKILL.md at project root - _discover_local_skill(base_dir, collection) + _discover_local_skill(base_dir, collection, exclude_patterns=safe_patterns) return collection -def discover_primitives_with_dependencies(base_dir: str = ".") -> PrimitiveCollection: +def discover_primitives_with_dependencies( + base_dir: str = ".", + exclude_patterns: Optional[List[str]] = None, +) -> PrimitiveCollection: """Enhanced primitive discovery including dependency sources. Priority Order: @@ -93,17 +109,19 @@ def discover_primitives_with_dependencies(base_dir: str = ".") -> PrimitiveColle Args: base_dir (str): Base directory to search in. Defaults to current directory. + exclude_patterns (Optional[List[str]]): Glob patterns for paths to exclude. Returns: PrimitiveCollection: Collection of discovered and parsed primitives with source tracking. 
""" collection = PrimitiveCollection() + safe_patterns = validate_exclude_patterns(exclude_patterns) # Phase 1: Local primitives (highest priority) - scan_local_primitives(base_dir, collection) + scan_local_primitives(base_dir, collection, exclude_patterns=safe_patterns) # Phase 1b: Local SKILL.md - _discover_local_skill(base_dir, collection) + _discover_local_skill(base_dir, collection, exclude_patterns=safe_patterns) # Phase 2: Dependency primitives (lower priority, with conflict detection) # Plugins are normalized into standard APM packages during install @@ -113,12 +131,17 @@ def discover_primitives_with_dependencies(base_dir: str = ".") -> PrimitiveColle return collection -def scan_local_primitives(base_dir: str, collection: PrimitiveCollection) -> None: +def scan_local_primitives( + base_dir: str, + collection: PrimitiveCollection, + exclude_patterns: Optional[List[str]] = None, +) -> None: """Scan local .apm/ directory for primitives. Args: base_dir (str): Base directory to search in. collection (PrimitiveCollection): Collection to add primitives to. + exclude_patterns (Optional[List[str]]): Pre-validated exclude patterns. 
""" # Find and parse files for each primitive type for primitive_type, patterns in LOCAL_PRIMITIVE_PATTERNS.items(): @@ -131,8 +154,13 @@ def scan_local_primitives(base_dir: str, collection: PrimitiveCollection) -> Non for file_path in files: # Only include files that are NOT in apm_modules directory - if not _is_under_directory(file_path, apm_modules_path): - local_files.append(file_path) + if _is_under_directory(file_path, apm_modules_path): + continue + # Apply compilation.exclude patterns + if should_exclude(file_path, base_path, exclude_patterns): + logger.debug("Excluded by pattern: %s", file_path) + continue + local_files.append(file_path) for file_path in local_files: try: @@ -159,6 +187,7 @@ def _is_under_directory(file_path: Path, directory: Path) -> bool: return False + def scan_dependency_primitives(base_dir: str, collection: PrimitiveCollection) -> None: """Scan all dependencies in apm_modules/ with priority handling. @@ -302,15 +331,23 @@ def scan_directory_with_source(directory: Path, collection: PrimitiveCollection, _discover_skill_in_directory(directory, collection, source) -def _discover_local_skill(base_dir: str, collection: PrimitiveCollection) -> None: +def _discover_local_skill( + base_dir: str, + collection: PrimitiveCollection, + exclude_patterns: Optional[List[str]] = None, +) -> None: """Discover SKILL.md at the project root. Args: base_dir (str): Base directory to search in. collection (PrimitiveCollection): Collection to add skill to. + exclude_patterns (Optional[List[str]]): Pre-validated exclude patterns. 
""" skill_path = Path(base_dir) / "SKILL.md" if skill_path.exists() and _is_readable(skill_path): + if should_exclude(skill_path, Path(base_dir), exclude_patterns): + logger.debug("Excluded by pattern: %s", skill_path) + return try: skill = parse_skill_file(skill_path, source="local") collection.add_primitive(skill) diff --git a/src/apm_cli/utils/exclude.py b/src/apm_cli/utils/exclude.py new file mode 100644 index 00000000..e9fcb232 --- /dev/null +++ b/src/apm_cli/utils/exclude.py @@ -0,0 +1,170 @@ +"""Shared exclude-pattern matching for compilation and primitive discovery. + +Provides glob-style pattern matching with ** (recursive directory) support. +Used by both the context optimizer and primitive discovery to filter paths +against compilation.exclude patterns from apm.yml. +""" + +import fnmatch +import logging +import os +from pathlib import Path +from typing import List, Optional + +logger = logging.getLogger(__name__) + +# Maximum number of ** segments allowed in a single pattern. +# Prevents exponential recursion blowup (2^N branches per ** segment). +_MAX_DOUBLE_STAR_SEGMENTS = 5 + + +def validate_exclude_patterns(patterns: Optional[List[str]]) -> List[str]: + """Validate and normalize exclude patterns, rejecting dangerous ones. + + Args: + patterns: Raw patterns from apm.yml compilation.exclude. + + Returns: + List of validated, forward-slash-normalized patterns. + + Raises: + ValueError: If a pattern exceeds the ** segment safety limit. 
+ """ + if not patterns: + return [] + + validated = [] + for pattern in patterns: + normalized = pattern.replace("\\", "/") + # Collapse consecutive ** segments (semantically identical to single **) + parts = normalized.split("/") + collapsed = [] + for p in parts: + if p == "**" and collapsed and collapsed[-1] == "**": + continue + collapsed.append(p) + normalized = "/".join(collapsed) + star_count = collapsed.count("**") + if star_count > _MAX_DOUBLE_STAR_SEGMENTS: + raise ValueError( + f"Exclude pattern '{pattern}' has {star_count} '**' segments " + f"(max {_MAX_DOUBLE_STAR_SEGMENTS}). Simplify the pattern." + ) + validated.append(normalized) + return validated + + +def should_exclude( + file_path: Path, + base_dir: Path, + exclude_patterns: Optional[List[str]], +) -> bool: + """Check whether a file path should be excluded. + + Args: + file_path: Absolute or relative path of the discovered file. + base_dir: Project base directory for computing relative paths. + exclude_patterns: Pre-validated, forward-slash-normalized patterns. + + Returns: + True if the file matches any exclusion pattern. + """ + if not exclude_patterns: + return False + + try: + resolved = file_path.resolve() + except (OSError, FileNotFoundError): + resolved = file_path.absolute() + try: + rel_path = resolved.relative_to(base_dir.resolve()) + except ValueError: + return False + + rel_path_str = str(rel_path).replace(os.sep, "/") + + for pattern in exclude_patterns: + if _matches_pattern(rel_path_str, pattern): + return True + + return False + + +def _matches_pattern(rel_path_str: str, pattern: str) -> bool: + """Check if a relative path string matches a single exclusion pattern. + + Supports glob wildcards including ** for recursive directory matching. 
+ """ + if "**" in pattern: + path_parts = rel_path_str.split("/") + pattern_parts = pattern.split("/") + return _match_glob_recursive(path_parts, pattern_parts) + + if fnmatch.fnmatch(rel_path_str, pattern): + return True + + # Directory prefix matching: "docs/" or "docs" should match "docs/foo.md" + if pattern.endswith("/"): + if rel_path_str.startswith(pattern) or rel_path_str == pattern.rstrip("/"): + return True + else: + if rel_path_str.startswith(pattern + "/") or rel_path_str == pattern: + return True + + return False + + +def _match_glob_recursive(path_parts: list, pattern_parts: list) -> bool: + """Match path components against pattern components with ** support. + + Consumes the leading non-** segments iteratively, then hands the + remainder (from the first ** onward) to _match_double_star, which + recurses; validate_exclude_patterns caps the number of ** segments. + """ + # Strip trailing empty parts left by trailing slashes in patterns + while pattern_parts and pattern_parts[-1] == "": + pattern_parts = pattern_parts[:-1] + + pi = 0 # pattern index + xi = 0 # path index + + # Fast iterative path for leading non-** segments + while pi < len(pattern_parts) and xi < len(path_parts): + part = pattern_parts[pi] + if part == "**": + break + if fnmatch.fnmatch(path_parts[xi], part): + pi += 1 + xi += 1 + else: + return False + + # Pattern fully consumed without reaching a **: the path must be consumed too + if pi == len(pattern_parts): + return xi == len(path_parts) + + # Delegate remaining ** matching via bounded recursion + return _match_double_star(path_parts[xi:], pattern_parts[pi:]) + + +def _match_double_star(path_parts: list, pattern_parts: list) -> bool: + """Handle ** segments with bounded recursion.""" + if not pattern_parts: + return not path_parts + + if not path_parts: + return all(p == "**" or p == "" for p in pattern_parts) + + part = pattern_parts[0] + + if part == "**": + # ** matches zero or more directories + if _match_double_star(path_parts, pattern_parts[1:]): + return True + if _match_double_star(path_parts[1:], 
pattern_parts): + return True + return False + else: + if fnmatch.fnmatch(path_parts[0], part): + return _match_double_star(path_parts[1:], pattern_parts[1:]) + return False diff --git a/tests/unit/compilation/test_agents_compiler_coverage.py b/tests/unit/compilation/test_agents_compiler_coverage.py index 9274fe6e..2790a8aa 100644 --- a/tests/unit/compilation/test_agents_compiler_coverage.py +++ b/tests/unit/compilation/test_agents_compiler_coverage.py @@ -205,7 +205,9 @@ def test_compile_local_only_calls_basic_discover(self): ) as mock_disc: result = compiler.compile(config) # no primitives passed → discovers - mock_disc.assert_called_once_with(str(compiler.base_dir)) + mock_disc.assert_called_once_with( + str(compiler.base_dir), exclude_patterns=config.exclude + ) # --------------------------------------------------------------------------- diff --git a/tests/unit/primitives/test_discovery_parser.py b/tests/unit/primitives/test_discovery_parser.py index 4d3fc8e8..f482f8ed 100644 --- a/tests/unit/primitives/test_discovery_parser.py +++ b/tests/unit/primitives/test_discovery_parser.py @@ -613,6 +613,163 @@ def test_parse_error_warns_and_continues(self): self.assertEqual(collection.count(), 0) +class TestExcludePatternsInDiscovery(unittest.TestCase): + """Tests for compilation.exclude filtering during primitive discovery.""" + + def setUp(self): + self.tmp = tempfile.mkdtemp() + + def tearDown(self): + import shutil + + shutil.rmtree(self.tmp, ignore_errors=True) + + def test_scan_local_primitives_excludes_matching_directory(self): + """Primitives under excluded directories are filtered out.""" + base = Path(self.tmp) + # Local instruction (should be kept) + _write( + base / ".apm" / "instructions" / "general.instructions.md", + INSTRUCTION_CONTENT, + ) + # Instruction inside docs/ (should be excluded) + _write( + base + / "docs" + / "labs" + / ".github" + / "instructions" + / "react.instructions.md", + INSTRUCTION_CONTENT, + ) + collection = PrimitiveCollection() + 
scan_local_primitives( + self.tmp, collection, exclude_patterns=["docs/**"] + ) + self.assertEqual(len(collection.instructions), 1) + + def test_scan_local_primitives_no_exclude_discovers_all(self): + """Without exclude patterns, all primitives are discovered.""" + base = Path(self.tmp) + _write( + base / ".apm" / "instructions" / "general.instructions.md", + INSTRUCTION_CONTENT, + ) + _write( + base + / "docs" + / "labs" + / ".github" + / "instructions" + / "react.instructions.md", + INSTRUCTION_CONTENT, + ) + collection = PrimitiveCollection() + scan_local_primitives(self.tmp, collection, exclude_patterns=None) + self.assertEqual(len(collection.instructions), 2) + + def test_scan_local_primitives_multiple_exclude_patterns(self): + """Multiple exclude patterns each filter their respective files.""" + base = Path(self.tmp) + _write( + base / ".apm" / "instructions" / "kept.instructions.md", + INSTRUCTION_CONTENT, + ) + _write( + base + / "docs" + / ".github" + / "instructions" + / "a.instructions.md", + INSTRUCTION_CONTENT, + ) + _write( + base + / "tmp" + / ".github" + / "instructions" + / "b.instructions.md", + INSTRUCTION_CONTENT, + ) + collection = PrimitiveCollection() + scan_local_primitives( + self.tmp, collection, exclude_patterns=["docs/**", "tmp/**"] + ) + self.assertEqual(len(collection.instructions), 1) + + def test_discover_primitives_respects_exclude(self): + """discover_primitives() filters with exclude_patterns.""" + base = Path(self.tmp) + _write( + base / ".apm" / "instructions" / "general.instructions.md", + INSTRUCTION_CONTENT, + ) + _write( + base + / "docs" + / ".github" + / "instructions" + / "leak.instructions.md", + INSTRUCTION_CONTENT, + ) + from apm_cli.primitives.discovery import discover_primitives + + collection = discover_primitives( + self.tmp, exclude_patterns=["docs/**"] + ) + self.assertEqual(len(collection.instructions), 1) + + def test_discover_primitives_with_dependencies_respects_exclude(self): + 
"""discover_primitives_with_dependencies() filters with exclude_patterns.""" + base = Path(self.tmp) + _write( + base / ".apm" / "instructions" / "general.instructions.md", + INSTRUCTION_CONTENT, + ) + _write( + base + / "docs" + / ".github" + / "instructions" + / "leak.instructions.md", + INSTRUCTION_CONTENT, + ) + # Create minimal apm.yml for the function to work + (base / "apm.yml").write_text( + "name: test\nversion: 1.0.0\n", encoding="utf-8" + ) + from apm_cli.primitives.discovery import ( + discover_primitives_with_dependencies, + ) + + collection = discover_primitives_with_dependencies( + self.tmp, exclude_patterns=["docs/**"] + ) + self.assertEqual(len(collection.instructions), 1) + + def test_discover_primitives_excludes_skill_md(self): + """SKILL.md at project root is excluded when matching pattern.""" + base = Path(self.tmp) + skill_content = "# My Skill\n\nSome skill content." + (base / "SKILL.md").write_text(skill_content, encoding="utf-8") + from apm_cli.primitives.discovery import discover_primitives + + # Without exclusion -- SKILL.md found + collection = discover_primitives(self.tmp, exclude_patterns=None) + self.assertEqual(len(collection.skills), 1) + + # With exclusion matching SKILL.md + collection = discover_primitives(self.tmp, exclude_patterns=["SKILL.md"]) + self.assertEqual(len(collection.skills), 0) + + def test_validate_rejects_dos_pattern(self): + """Patterns with excessive non-consecutive ** segments are rejected.""" + from apm_cli.utils.exclude import validate_exclude_patterns + # 7 non-consecutive ** segments (consecutive ones collapse) + with self.assertRaises(ValueError): + validate_exclude_patterns(["a/**/b/**/c/**/d/**/e/**/f/**/g/**"]) + + class TestIsReadable(unittest.TestCase): """Tests for _is_readable.""" diff --git a/tests/unit/test_exclude.py b/tests/unit/test_exclude.py new file mode 100644 index 00000000..25aa9e47 --- /dev/null +++ b/tests/unit/test_exclude.py @@ -0,0 +1,200 @@ +"""Tests for the shared exclude-pattern 
matching utility.""" + +import tempfile +import shutil +import unittest +from pathlib import Path + +from apm_cli.utils.exclude import ( + _match_double_star, + _match_glob_recursive, + _matches_pattern, + should_exclude, + validate_exclude_patterns, +) + + +class TestValidateExcludePatterns(unittest.TestCase): + """Tests for pattern validation and DoS guard.""" + + def test_none_returns_empty(self): + self.assertEqual(validate_exclude_patterns(None), []) + + def test_empty_list_returns_empty(self): + self.assertEqual(validate_exclude_patterns([]), []) + + def test_valid_patterns_returned(self): + result = validate_exclude_patterns(["docs/**", "tmp", "*.log"]) + self.assertEqual(result, ["docs/**", "tmp", "*.log"]) + + def test_backslashes_normalized(self): + result = validate_exclude_patterns(["docs\\labs\\**"]) + self.assertEqual(result, ["docs/labs/**"]) + + def test_rejects_excessive_double_star(self): + # 6 non-consecutive ** segments (separated by literals, can't collapse) + pattern = "a/**/b/**/c/**/d/**/e/**/f/**" + with self.assertRaises(ValueError) as ctx: + validate_exclude_patterns([pattern]) + self.assertIn("6", str(ctx.exception)) + self.assertIn("max", str(ctx.exception).lower()) + + def test_allows_max_double_star(self): + pattern = "/".join(["**"] * 5) # exactly 5 -- at the limit + result = validate_exclude_patterns([pattern]) + self.assertEqual(len(result), 1) + + def test_double_star_count_ignores_non_star(self): + # Only ** segments count, not * or other parts + pattern = "a/*/b/**/c/**/d" + result = validate_exclude_patterns([pattern]) + self.assertEqual(result, [pattern]) + + def test_consecutive_double_stars_collapsed(self): + # **/**/** is semantically identical to ** + result = validate_exclude_patterns(["**/**/**/*.md"]) + self.assertEqual(result, ["**/*.md"]) + + def test_consecutive_collapse_then_count(self): + # After collapsing, 6 consecutive ** become 1 -- well under limit + pattern = "/".join(["**"] * 6) + "/*.txt" + result = 
validate_exclude_patterns([pattern]) + self.assertEqual(result, ["**/*.txt"]) + + +class TestMatchesPattern(unittest.TestCase): + """Tests for individual pattern matching logic.""" + + def test_simple_fnmatch(self): + self.assertTrue(_matches_pattern("foo.log", "*.log")) + + def test_simple_fnmatch_no_match(self): + self.assertFalse(_matches_pattern("foo.txt", "*.log")) + + def test_directory_prefix_with_slash(self): + self.assertTrue(_matches_pattern("docs/foo.md", "docs/")) + + def test_directory_prefix_without_slash(self): + self.assertTrue(_matches_pattern("docs/foo.md", "docs")) + + def test_exact_match(self): + self.assertTrue(_matches_pattern("docs", "docs")) + + def test_double_star_recursive(self): + self.assertTrue(_matches_pattern("a/b/c/d.txt", "a/**/d.txt")) + + def test_double_star_zero_dirs(self): + self.assertTrue(_matches_pattern("a/d.txt", "a/**/d.txt")) + + def test_leading_double_star(self): + self.assertTrue(_matches_pattern("a/b/c.md", "**/*.md")) + + def test_trailing_double_star(self): + self.assertTrue(_matches_pattern("docs/a/b/c", "docs/**")) + + +class TestMatchGlobRecursive(unittest.TestCase): + """Tests for the glob recursive matcher.""" + + def test_exact_match(self): + self.assertTrue(_match_glob_recursive(["a", "b"], ["a", "b"])) + + def test_wildcard(self): + self.assertTrue(_match_glob_recursive(["a", "foo.md"], ["a", "*.md"])) + + def test_double_star_matches_multiple(self): + self.assertTrue( + _match_glob_recursive(["a", "b", "c", "d"], ["a", "**", "d"]) + ) + + def test_double_star_matches_zero(self): + self.assertTrue(_match_glob_recursive(["a", "d"], ["a", "**", "d"])) + + def test_no_match(self): + self.assertFalse(_match_glob_recursive(["a", "b"], ["c", "d"])) + + def test_trailing_empty_part_from_slash(self): + # A trailing "**" matches the remaining components (no "" part is exercised here) + self.assertTrue( + _match_glob_recursive(["foo", "bar"], ["foo", "**"]) + ) + + +class TestShouldExclude(unittest.TestCase): + """Integration tests for should_exclude with 
real filesystem.""" + + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.base = Path(self.tmp) + + def tearDown(self): + shutil.rmtree(self.tmp, ignore_errors=True) + + def _touch(self, rel_path: str) -> Path: + p = self.base / rel_path + p.parent.mkdir(parents=True, exist_ok=True) + p.write_text("test", encoding="utf-8") + return p + + def test_no_patterns_returns_false(self): + f = self._touch("a.txt") + self.assertFalse(should_exclude(f, self.base, None)) + self.assertFalse(should_exclude(f, self.base, [])) + + def test_excludes_matching_file(self): + f = self._touch("docs/api.md") + self.assertTrue(should_exclude(f, self.base, ["docs/**"])) + + def test_keeps_non_matching_file(self): + f = self._touch("src/main.py") + self.assertFalse(should_exclude(f, self.base, ["docs/**"])) + + def test_path_outside_base_not_excluded(self): + import os + outside = Path(tempfile.mkdtemp()) + try: + f = outside / "secret.md" + f.write_text("test", encoding="utf-8") + self.assertFalse(should_exclude(f, self.base, ["**"])) + finally: + shutil.rmtree(outside, ignore_errors=True) + + def test_directory_name_match(self): + f = self._touch("tmp/build/out.bin") + self.assertTrue(should_exclude(f, self.base, ["tmp"])) + + def test_multiple_patterns(self): + f1 = self._touch("docs/a.md") + f2 = self._touch("tmp/b.txt") + f3 = self._touch("src/c.py") + patterns = ["docs/**", "tmp/**"] + self.assertTrue(should_exclude(f1, self.base, patterns)) + self.assertTrue(should_exclude(f2, self.base, patterns)) + self.assertFalse(should_exclude(f3, self.base, patterns)) + + +class TestDoubleStarDoSGuard(unittest.TestCase): + """Ensure pathological patterns are rejected before reaching recursion.""" + + def test_twelve_stars_rejected(self): + # 12 non-consecutive ** segments (consecutive ones collapse) + parts = [] + for i in range(12): + parts.extend([f"d{i}", "**"]) + pattern = "/".join(parts) + with self.assertRaises(ValueError): + validate_exclude_patterns([pattern]) + + def 
test_normal_patterns_fast(self): + import time + patterns = validate_exclude_patterns(["docs/**/*.md"]) + start = time.monotonic() + for _ in range(1000): + _matches_pattern("docs/a/b/c/d/e/f/g.md", patterns[0]) + elapsed = time.monotonic() - start + # 1000 iterations should complete in well under 1 second + self.assertLess(elapsed, 1.0) + + +if __name__ == "__main__": + unittest.main()