diff --git a/client/backend_arguments.py b/client/backend_arguments.py index c5a3703155e..e83dca58ec3 100644 --- a/client/backend_arguments.py +++ b/client/backend_arguments.py @@ -69,7 +69,11 @@ def serialize(self) -> Dict[str, object]: } def get_checked_directory_allowlist(self) -> Set[str]: - return {element.path() for element in self.elements} + allow_list = set() + for element in self.elements: + if expected_element_path := element.path() is not None: + allow_list.add(expected_element_path) + return allow_list def cleanup(self) -> None: pass @@ -98,7 +102,11 @@ def serialize(self) -> Dict[str, object]: } def get_checked_directory_allowlist(self) -> Set[str]: - return {element.path() for element in self.elements} + allow_list = set() + for element in self.elements: + if expected_element_path := element.path() is not None: + allow_list.add(expected_element_path) + return allow_list def cleanup(self) -> None: pass diff --git a/client/configuration/search_path.py b/client/configuration/search_path.py index da7601a7558..6af6172f8fa 100644 --- a/client/configuration/search_path.py +++ b/client/configuration/search_path.py @@ -19,13 +19,20 @@ import glob import logging import os -from typing import Dict, Iterable, List, Sequence, Union +import re +from typing import Dict, Iterable, List, Sequence, Tuple, Union from .. import filesystem from . import exceptions LOG: logging.Logger = logging.getLogger(__name__) +dist_info_in_root: Dict[str, List[str]] = {} + +_site_filter = re.compile(r".*-([0-99]\.)*dist-info") # type: re.Pattern[str] + +_PYCACHE = re.compile("__pycache__(/)*.*") # type: re.Pattern[str] + def _expand_relative_root(path: str, relative_root: str) -> str: if not path.startswith("//"): @@ -35,11 +42,11 @@ def _expand_relative_root(path: str, relative_root: str) -> str: class Element(abc.ABC): @abc.abstractmethod - def path(self) -> str: + def path(self) -> Union[str, None]: raise NotImplementedError() @abc.abstractmethod - def command_line_argument(self) -> str: + def command_line_argument(self) -> Union[str, None]: raise NotImplementedError() @@ -72,15 +79,55 @@ class SitePackageElement(Element): package_name: str is_toplevel_module: bool = False - def package_path(self) -> str: - module_suffix = ".py" if self.is_toplevel_module else "" - return self.package_name + module_suffix + def package_path(self) -> Union[str, None]: + if not self.is_toplevel_module: + return self.package_name - def path(self) -> str: - return os.path.join(self.site_root, self.package_path()) + this_pkg_filter = re.compile( + r"{}-([0-99]\.)*dist-info(/)*.*".format(self.package_name) + ) # type: re.Pattern[str] - def command_line_argument(self) -> str: - return self.site_root + "$" + self.package_path() + if self.site_root not in dist_info_in_root: + dist_info_in_root[self.site_root] = [] + for directory in os.listdir(self.site_root): + if _site_filter.fullmatch(directory) is not None: + dist_info_in_root[self.site_root].append(directory) + + for directory in dist_info_in_root[self.site_root]: + if this_pkg_filter.fullmatch(directory): + dist_info_path = f"{self.site_root}/{directory}" + break + else: + return None + + not_toplevel_patterns = ( + this_pkg_filter, + _PYCACHE, + ) # type: Tuple[re.Pattern[str], re.Pattern[str]] + + # pyre-fixme[61]: Local variable `dist_info_path` is undefined, or not always defined. + with open(file=f"{dist_info_path}/RECORD", mode="r") as record: + for line in record.readlines(): + file_name = line.split(",")[0] + for pattern in not_toplevel_patterns: + if pattern.fullmatch(file_name): + break + else: + return file_name + + return None + + def path(self) -> Union[str, None]: + excepted_package_path: Union[str, None] = self.package_path() + if excepted_package_path is None: + return None + return os.path.join(self.site_root, excepted_package_path) + + def command_line_argument(self) -> Union[str, None]: + excepted_package_path: Union[str, None] = self.package_path() + if excepted_package_path is None: + return None + return self.site_root + "$" + excepted_package_path class RawElement(abc.ABC): @@ -223,10 +270,11 @@ def process_raw_elements( elements: List[Element] = [] def add_if_exists(element: Element) -> bool: - if os.path.exists(element.path()): - elements.append(element) - return True - return False + expected_path = element.path() + if expected_path is None or not os.path.exists(expected_path): + return False + elements.append(element) + return True for raw_element in raw_elements: expanded_raw_elements = raw_element.expand_glob() diff --git a/client/configuration/tests/search_path_test.py b/client/configuration/tests/search_path_test.py index 90a57b2586b..0c3c30516c8 100644 --- a/client/configuration/tests/search_path_test.py +++ b/client/configuration/tests/search_path_test.py @@ -70,24 +70,45 @@ def test_create_raw_element(self) -> None: ) def test_path(self) -> None: - self.assertEqual(SimpleElement("foo").path(), "foo") - self.assertEqual(SubdirectoryElement("foo", "bar").path(), "foo/bar") - self.assertEqual(SitePackageElement("foo", "bar").path(), "foo/bar") + with tempfile.TemporaryDirectory() as temp_root: + ensure_directories_exists(Path(temp_root), ("foo", "foo/bar")) + self.assertEqual(SimpleElement("foo").path(), "foo") + self.assertEqual(SubdirectoryElement("foo", "bar").path(), "foo/bar") + self.assertEqual( + SitePackageElement(f"{temp_root}/foo", "bar").path(), + f"{temp_root}/foo/bar", + ) def test_command_line_argument(self) -> None: - self.assertEqual(SimpleElement("foo").command_line_argument(), "foo") - self.assertEqual( - SubdirectoryElement("foo", "bar").command_line_argument(), - "foo$bar", - ) - self.assertEqual( - SitePackageElement("foo", "bar").command_line_argument(), - "foo$bar", - ) - self.assertEqual( - SitePackageElement("foo", "bar", True).command_line_argument(), - "foo$bar.py", - ) + with tempfile.TemporaryDirectory() as temp_root: + ensure_directories_exists(Path(temp_root), ("foo", "foo/bar")) + self.assertEqual(SimpleElement("foo").command_line_argument(), "foo") + self.assertEqual( + SubdirectoryElement("foo", "bar").command_line_argument(), + "foo$bar", + ) + self.assertEqual( + SitePackageElement(f"{temp_root}/foo", "bar").command_line_argument(), + f"{temp_root}/foo$bar", + ) + + with tempfile.TemporaryDirectory() as temp_root: + ensure_directories_exists( + Path(temp_root), ("foo", "foo/bar-1.0.0.dist-info") + ) + Path.touch(Path(f"{temp_root}/foo/bar.py")) + + with open( + f"{temp_root}/foo/bar-1.0.0.dist-info/RECORD", "w", encoding="UTF-8" + ) as f: + f.write("bar.py,,") + + self.assertEqual( + SitePackageElement( + f"{temp_root}/foo", "bar", True + ).command_line_argument(), + f"{temp_root}/foo$bar.py", + ) def test_expand_global_root(self) -> None: self.assertEqual( @@ -241,3 +262,26 @@ def test_process_required_raw_elements_site_package_nonexistence(self) -> None: site_roots=[], required=True, ) + + def test_toplevel_module_not_pyfile(self) -> None: + with tempfile.TemporaryDirectory() as temp_root: + ensure_directories_exists( + Path(temp_root), ("foo", "foo/bar-1.0.0.dist-info") + ) + Path.touch(Path(f"{temp_root}/foo/bar.so")) + + with open( + f"{temp_root}/foo/bar-1.0.0.dist-info/RECORD", "w", encoding="UTF-8" + ) as f: + f.write("bar.so,,") + + self.assertEqual( + SitePackageElement(f"{temp_root}/foo", "bar", True).path(), + f"{temp_root}/foo/bar.so", + ) + self.assertEqual( + process_raw_elements( + [SitePackageRawElement("bar", True)], [f"{temp_root}/foo"] + ), + [SitePackageElement(f"{temp_root}/foo", "bar", True)], + ) diff --git a/client/tests/backend_arguments_test.py b/client/tests/backend_arguments_test.py index a114fe2f891..e08ae09eb33 100644 --- a/client/tests/backend_arguments_test.py +++ b/client/tests/backend_arguments_test.py @@ -576,15 +576,20 @@ def test_get_source_path__confliciting_source_specified(self) -> None: ) def test_get_checked_directory_for_simple_source_path(self) -> None: - element0 = search_path.SimpleElement("ozzie") - element1 = search_path.SubdirectoryElement("diva", "flea") - element2 = search_path.SitePackageElement("super", "slash") - self.assertCountEqual( - SimpleSourcePath( - [element0, element1, element2, element0] - ).get_checked_directory_allowlist(), - [element0.path(), element1.path(), element2.path()], - ) + with tempfile.TemporaryDirectory() as temp_root: + Path.mkdir(Path(f"{temp_root}/super")) + Path.mkdir(Path(f"{temp_root}/super/slash")) + + element0 = search_path.SimpleElement("ozzie") + element1 = search_path.SubdirectoryElement("diva", "flea") + element2 = search_path.SitePackageElement(f"{temp_root}/super", "slash") + + self.assertCountEqual( + SimpleSourcePath( + [element0, element1, element2, element0] + ).get_checked_directory_allowlist(), + [element0.path(), element1.path(), element2.path()], + ) def test_get_checked_directory_for_buck_source_path(self) -> None: self.assertCountEqual(