diff --git a/specfile/specfile.py b/specfile/specfile.py index 92b8d8fb..d725afb2 100644 --- a/specfile/specfile.py +++ b/specfile/specfile.py @@ -1,13 +1,27 @@ # Copyright Contributors to the Packit project. # SPDX-License-Identifier: MIT +import copy import datetime import logging import re import types from dataclasses import dataclass +from io import FileIO, StringIO from pathlib import Path -from typing import Generator, List, Optional, Tuple, Type, Union, cast +from typing import ( + IO, + Any, + Dict, + Generator, + List, + Optional, + TextIO, + Tuple, + Type, + Union, + cast, +) import rpm @@ -31,6 +45,7 @@ from specfile.sources import Patches, Sources from specfile.spec_parser import SpecParser from specfile.tags import Tag, Tags +from specfile.types import EncodingArgs from specfile.value_parser import ( SUBSTITUTION_GROUP_PREFIX, ConditionalMacroExpansion, @@ -50,19 +65,27 @@ class Specfile: autosave: Whether to automatically save any changes made. """ + ENCODING_ARGS: EncodingArgs = {"encoding": "utf8", "errors": "surrogateescape"} + def __init__( self, - path: Union[Path, str], + path: Optional[Union[Path, str]] = None, + content: Optional[str] = None, + file: Optional[IO] = None, sourcedir: Optional[Union[Path, str]] = None, autosave: bool = False, macros: Optional[List[Tuple[str, Optional[str]]]] = None, force_parse: bool = False, ) -> None: """ - Initializes a specfile object. + Initializes a specfile object. You can specify either a path to the spec file, + its content as a string or a file object representing it. `sourcedir` is optional + if `path` or a named `file` is provided and will be set to the parent directory. Args: path: Path to the spec file. + content: String containing the spec file content. + file: File object representing the spec file. sourcedir: Path to sources and patches. autosave: Whether to automatically save any changes made. macros: List of extra macro definitions. @@ -71,12 +94,29 @@ def __init__( Such sources include sources referenced from shell expansions in tag values and sources included using the _%include_ directive. """ + # count mutually exclusive arguments + if sum([file is not None, path is not None, content is not None]) > 1: + raise ValueError( + "Only one of `file`, `path` or `content` should be provided" + ) + if file is not None: + self._file = file + elif path is not None: + self._file = Path(path).open("r+", **self.ENCODING_ARGS) + elif content is not None: + self._file = StringIO(content) + else: + raise ValueError("Either `file`, `path` or `content` must be provided") + if sourcedir is None: + try: + sourcedir = Path(self._file.name).parent + except AttributeError: + raise ValueError( + "`sourcedir` is required when providing `content` or file object without a name" + ) self.autosave = autosave - self._path = Path(path) - self._lines, self._trailing_newline = self._read_lines(self._path) - self._parser = SpecParser( - Path(sourcedir or self.path.parent), macros, force_parse - ) + self._lines, self._trailing_newline = self._read_lines(self._file) + self._parser = SpecParser(Path(sourcedir), macros, force_parse) self._parser.parse(str(self)) self._dump_debug_info("After initial parsing") @@ -85,7 +125,7 @@ def __eq__(self, other: object) -> bool: return NotImplemented return ( self.autosave == other.autosave - and self._path == other._path + and self.path == other.path and self._lines == other._lines and self._parser == other._parser ) @@ -111,6 +151,39 @@ def __exit__( ) -> None: self.save() + def __deepcopy__(self, memodict: Dict[int, Any]): + """ + Deepcopies the object, handling file-like attributes. + """ + specfile = self.__class__.__new__(self.__class__) + memodict[id(self)] = specfile + + for k, v in self.__dict__.items(): + if k == "_file": + continue + setattr(specfile, k, copy.deepcopy(v, memodict)) + + try: + path = Path(cast(FileIO, self._file).name) + except AttributeError: + # IO doesn't implement getvalue() so tell mypy this is StringIO + # (could also be BytesIO) + sio = cast(StringIO, self._file) + # not a named file, try `getvalue()` + specfile._file = type(sio)(sio.getvalue()) + else: + try: + # encoding and errors are only available on TextIO objects + file = cast(TextIO, self._file) + specfile._file = path.open( + mode=file.mode, encoding=file.encoding, errors=file.errors + ) + except AttributeError: + # files open in binary mode have no `encoding`/`errors` + specfile._file = path.open(self._file.mode) + + return specfile + def _dump_debug_info(self, message) -> None: logger.debug( f"DBG: {message}:\n" @@ -119,19 +192,30 @@ def _dump_debug_info(self, message) -> None: f" {self._parser.spec!r} @ 0x{id(self._parser.spec):012x}" ) - @staticmethod - def _read_lines(path: Path) -> Tuple[List[str], bool]: - content = path.read_text(encoding="utf8", errors="surrogateescape") - return content.splitlines(), content[-1] == "\n" + @classmethod + def _read_lines(cls, file: IO) -> Tuple[List[str], bool]: + file.seek(0) + raw_content = file.read() + if isinstance(raw_content, str): + content = raw_content + else: + content = raw_content.decode(**cls.ENCODING_ARGS) + return content.splitlines(), content.endswith("\n") @property - def path(self) -> Path: + def path(self) -> Optional[Path]: """Path to the spec file.""" - return self._path + try: + return Path(cast(FileIO, self._file).name) + except AttributeError: + return None @path.setter def path(self, value: Union[Path, str]) -> None: - self._path = Path(value) + path = Path(value) + if path == self.path: + return + self._file = path.open("r+", **self.ENCODING_ARGS) @property def sourcedir(self) -> Path: @@ -179,11 +263,26 @@ def rpm_spec(self) -> rpm.spec: def reload(self) -> None: """Reloads the spec file content.""" - self._lines, self._trailing_newline = self._read_lines(self.path) + try: + path = Path(cast(FileIO, self._file).name) + except AttributeError: + pass + else: + # reopen the path in case the original file has been deleted/replaced + self._file.close() + self._file = path.open("r+", **self.ENCODING_ARGS) + self._lines, self._trailing_newline = self._read_lines(self._file) def save(self) -> None: """Saves the spec file content.""" - self.path.write_text(str(self), encoding="utf8", errors="surrogateescape") + self._file.seek(0) + self._file.truncate(0) + content = str(self) + try: + self._file.write(content) + except TypeError: + self._file.write(content.encode(**self.ENCODING_ARGS)) + self._file.flush() def expand( self, diff --git a/specfile/types.py b/specfile/types.py index c2658917..06d9fab5 100644 --- a/specfile/types.py +++ b/specfile/types.py @@ -12,3 +12,12 @@ class SupportsIndex(Protocol, metaclass=abc.ABCMeta): # type: ignore [no-redef] @abc.abstractmethod def __index__(self) -> int: ... + + +try: + from typing import TypedDict +except ImportError: + from typing_extensions import TypedDict + + +EncodingArgs = TypedDict("EncodingArgs", {"encoding": str, "errors": str}) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index ad5ec20a..dace7ba2 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,10 +1,11 @@ # Copyright Contributors to the Packit project. # SPDX-License-Identifier: MIT - +import io import shutil import pytest +from specfile import Specfile from tests.constants import ( SPEC_AUTOPATCH, SPEC_AUTOSETUP, @@ -136,3 +137,46 @@ def spec_conditionalized_version(tmp_path): specfile_path = tmp_path / SPECFILE shutil.copyfile(SPEC_CONDITIONALIZED_VERSION / SPECFILE, specfile_path) return specfile_path + + +@pytest.fixture( + params=[ + "file_path", + "text_file", + "binary_file", + "text_io_stream", + "binary_io_stream", + "content_string", + ] +) +def specfile_factory(request): + """ + pytest fixture to create a `Specfile` instance with different input modes. + + Returns: + Function that creates a `Specfile` instance. + """ + mode = request.param + + def _create_specfile(input_path, **kwargs): + kwargs.setdefault("sourcedir", input_path.parent) + + if mode == "file_path": + return Specfile(path=input_path, **kwargs) + elif mode == "text_file": + f = open(input_path, "r+", **Specfile.ENCODING_ARGS) + return Specfile(file=f, **kwargs) + elif mode == "binary_file": + f = open(input_path, "rb+") + return Specfile(file=f, **kwargs) + elif mode == "text_io_stream": + content = input_path.read_text(**Specfile.ENCODING_ARGS) + return Specfile(file=io.StringIO(content), **kwargs) + elif mode == "binary_io_stream": + content = input_path.read_bytes() + return Specfile(file=io.BytesIO(content), **kwargs) + elif mode == "content_string": + content = input_path.read_text(**Specfile.ENCODING_ARGS) + return Specfile(content=content, **kwargs) + + return _create_specfile diff --git a/tests/integration/test_specfile.py b/tests/integration/test_specfile.py index a5c4c370..965cfa29 100644 --- a/tests/integration/test_specfile.py +++ b/tests/integration/test_specfile.py @@ -26,8 +26,8 @@ def test_parse(spec_multiple_sources): assert spec.rpm_spec.prep == prep -def test_prep_traditional(spec_traditional): - spec = Specfile(spec_traditional) +def test_prep_traditional(specfile_factory, spec_traditional): + spec = specfile_factory(spec_traditional) with spec.prep() as prep: assert AutosetupMacro not in prep.macros assert AutopatchMacro not in prep.macros @@ -50,8 +50,8 @@ def test_prep_traditional(spec_traditional): assert sections.prep[1] == "%patch 0 -p2 -b .test2 -E" -def test_prep_autosetup(spec_autosetup): - spec = Specfile(spec_autosetup) +def test_prep_autosetup(specfile_factory, spec_autosetup): + spec = specfile_factory(spec_autosetup) with spec.prep() as prep: assert len(prep.macros) == 1 assert AutosetupMacro in prep.macros @@ -60,8 +60,8 @@ def test_prep_autosetup(spec_autosetup): assert prep.autosetup.options.p == 1 -def test_prep_autopatch(spec_autopatch): - spec = Specfile(spec_autopatch) +def test_prep_autopatch(specfile_factory, spec_autopatch): + spec = specfile_factory(spec_autopatch) with spec.prep() as prep: assert len(prep.macros) == 4 assert prep.macros[1].options.M == 2 @@ -75,8 +75,8 @@ def test_prep_autopatch(spec_autopatch): assert sections.prep[3] == "%autopatch -p1 0 1 2 3 4 5 6" -def test_sources(spec_minimal): - spec = Specfile(spec_minimal) +def test_sources(specfile_factory, spec_minimal): + spec = specfile_factory(spec_minimal) source = "test.tar.gz" with spec.sources() as sources: assert not sources @@ -93,8 +93,8 @@ def test_sources(spec_minimal): assert not sources -def test_patches(spec_patchlist): - spec = Specfile(spec_patchlist) +def test_patches(specfile_factory, spec_patchlist): + spec = specfile_factory(spec_patchlist) patch = "test.patch" with spec.patches() as patches: patches.insert(0, patch) @@ -214,19 +214,13 @@ def test_patches(spec_patchlist): ], ) def test_add_changelog_entry( - spec_minimal, - entry, - author, - email, - timestamp, - evr, - result, + specfile_factory, spec_minimal, entry, author, email, timestamp, evr, result ): if author is None: flexmock(specfile.specfile).should_receive("guess_packager").and_return( "John Doe " ).once() - spec = Specfile(spec_minimal) + spec = specfile_factory(spec_minimal) spec.add_changelog_entry(entry, author, email, timestamp, evr) with spec.sections() as sections: assert sections.changelog[: len(result)] == result @@ -240,8 +234,8 @@ def test_add_changelog_entry( ("1.4.6", "0.1rc5"), ], ) -def test_set_version_and_release(spec_minimal, version, release): - spec = Specfile(spec_minimal) +def test_set_version_and_release(specfile_factory, spec_minimal, version, release): + spec = specfile_factory(spec_minimal) spec.set_version_and_release(version, release) assert spec.version == version assert spec.release == release @@ -266,8 +260,8 @@ def test_set_version_and_release(spec_minimal, version, release): ("patch3.patch", 3, "patch3"), ], ) -def test_add_patch(spec_autosetup, location, number, comment): - spec = Specfile(spec_autosetup) +def test_add_patch(specfile_factory, spec_autosetup, location, number, comment): + spec = specfile_factory(spec_autosetup) if number == 0 or location == "patch2.patch": with pytest.raises(SpecfileException): spec.add_patch(location, number, comment) @@ -284,8 +278,8 @@ def test_add_patch(spec_autosetup, location, number, comment): assert sections.package[-4] == f"# {comment}" -def test_remove_patches(spec_commented_patches): - spec = Specfile(spec_commented_patches) +def test_remove_patches(specfile_factory, spec_commented_patches): + spec = specfile_factory(spec_commented_patches) with spec.patches() as patches: del patches[1:3] patches.remove_numbered(5) @@ -320,8 +314,8 @@ def test_remove_patches(spec_commented_patches): ("%{obsrel}.%{autorelease}", True), ], ) -def test_autorelease(spec_rpmautospec, raw_release, has_autorelease): - spec = Specfile(spec_rpmautospec) +def test_autorelease(specfile_factory, spec_rpmautospec, raw_release, has_autorelease): + spec = specfile_factory(spec_rpmautospec) spec.raw_release = raw_release assert spec.has_autorelease == has_autorelease @@ -330,9 +324,9 @@ def test_autorelease(spec_rpmautospec, raw_release, has_autorelease): rpm.__version__ < "4.16", reason="%autochangelog requires rpm 4.16 or higher" ) def test_autochangelog( - spec_rpmautospec, spec_conditionalized_changelog, spec_autosetup + specfile_factory, spec_rpmautospec, spec_conditionalized_changelog, spec_autosetup ): - spec = Specfile(spec_rpmautospec) + spec = specfile_factory(spec_rpmautospec) assert spec.has_autochangelog with spec.changelog() as changelog: assert len(changelog) == 0 @@ -341,7 +335,7 @@ def test_autochangelog( spec.add_changelog_entry("test") with spec.sections() as sections: assert sections.changelog == changelog - spec = Specfile(spec_conditionalized_changelog) + spec = specfile_factory(spec_conditionalized_changelog) assert spec.has_autochangelog with spec.sections() as sections: changelog = sections.changelog.copy() @@ -352,7 +346,7 @@ def test_autochangelog( assert changelogs[0] == changelog with spec.changelog(changelogs[1]) as changelog: assert changelog[-1].content == ["test"] - spec = Specfile(spec_autosetup) + spec = specfile_factory(spec_autosetup) with spec.changelog() as changelog: changelog[0].content += "%" assert not spec.has_autochangelog @@ -362,8 +356,8 @@ def test_autochangelog( rpm.__version__ < "4.16", reason="condition expression evaluation requires rpm 4.16 or higher", ) -def test_update_tag(spec_macros): - spec = Specfile(spec_macros) +def test_update_tag(specfile_factory, spec_macros): + spec = specfile_factory(spec_macros) spec.update_tag("Version", "1.2.3~beta4") with spec.macro_definitions() as md: assert md.majorver.body == "1" @@ -439,7 +433,7 @@ def test_update_tag(spec_macros): assert md.minorver.body == "2" with spec.sources() as sources: assert sources[1].location == "tests-86.tar.xz" - spec = Specfile(spec_macros, macros=[("use_snapshot", "1")]) + spec = specfile_factory(spec_macros, macros=[("use_snapshot", "1")]) spec.update_tag("Version", "3.2.1") with spec.macro_definitions() as md: assert md.majorver.body == "0" @@ -450,7 +444,7 @@ def test_update_tag(spec_macros): assert md.get("package_version", 13).body == "%{mainver}%{?prever:~%{prever}}" assert md.get("package_version", 15).body == "3.2.1" assert spec.version == "%{package_version}" - spec = Specfile(spec_macros) + spec = specfile_factory(spec_macros) spec.update_tag("Version", "1.2.3.4~rc5") with spec.macro_definitions() as md: assert md.majorver.body == "1.2" @@ -459,7 +453,7 @@ def test_update_tag(spec_macros): assert md.mainver.body == "%{majorver}.%{minorver}.%{patchver}" assert md.prever.body == "rc5" assert spec.version == "%{package_version}" - spec = Specfile(spec_macros) + spec = specfile_factory(spec_macros) with spec.macro_definitions() as md: md.prever.commented_out = True assert spec.expanded_version == "0.1.2" @@ -474,9 +468,9 @@ def test_update_tag(spec_macros): assert spec.version == "%{package_version}" -def test_multiple_instances(spec_minimal, spec_autosetup): +def test_multiple_instances(specfile_factory, spec_minimal, spec_autosetup): spec1 = Specfile(spec_minimal) - spec2 = Specfile(spec_autosetup) + spec2 = specfile_factory(spec_autosetup) spec1.version = "14.2" assert spec2.expanded_version == "0.1" with spec2.sources() as sources: @@ -485,8 +479,8 @@ def test_multiple_instances(spec_minimal, spec_autosetup): assert sources[1].expanded_location == "tests-0.1.tar.xz" -def test_includes(spec_includes): - spec = Specfile(spec_includes) +def test_includes(specfile_factory, spec_includes): + spec = specfile_factory(spec_includes) assert not spec.tainted with spec.patches() as patches: assert not patches @@ -503,8 +497,8 @@ def test_includes(spec_includes): for inc in ["patches.inc", "provides.inc", "description1.inc", "description2.inc"]: (spec.sourcedir / inc).unlink() with pytest.raises(RPMException): - spec = Specfile(spec_includes) - spec = Specfile(spec_includes, force_parse=True) + spec = specfile_factory(spec_includes) + spec = specfile_factory(spec_includes, force_parse=True) assert spec.tainted with spec.patches() as patches: assert not patches @@ -518,18 +512,18 @@ def test_includes(spec_includes): for inc in ["macros1.inc", "macros2.inc"]: (spec.sourcedir / inc).unlink() with pytest.raises(RPMException): - spec = Specfile(spec_includes, force_parse=True) + spec = specfile_factory(spec_includes, force_parse=True) assert not (spec.sourcedir / inc).is_file() -def test_shell_expansions(spec_shell_expansions): - spec = Specfile(spec_shell_expansions) +def test_shell_expansions(specfile_factory, spec_shell_expansions): + spec = specfile_factory(spec_shell_expansions) assert spec.expanded_version == "1035.4200" assert "C.UTF-8" in spec.expand("%numeric_locale") -def test_context_management(spec_autosetup, spec_traditional): - spec = Specfile(spec_autosetup) +def test_context_management(specfile_factory, spec_autosetup, spec_traditional): + spec = specfile_factory(spec_autosetup) with spec.tags() as tags: tags.license.value = "BSD" assert spec.license == "BSD" @@ -541,7 +535,7 @@ def test_context_management(spec_autosetup, spec_traditional): assert spec.license == "BSD-3-Clause" with spec.patches() as patches: assert patches[0].location == "patch_0.patch" - spec1 = Specfile(spec_autosetup) + spec1 = specfile_factory(spec_autosetup) spec2 = Specfile(spec_traditional) with spec1.sections() as sections1, spec2.sections() as sections2: assert sections1 is not sections2 @@ -550,8 +544,8 @@ def test_context_management(spec_autosetup, spec_traditional): assert tags1 == tags2 -def test_copy(spec_autosetup): - spec = Specfile(spec_autosetup) +def test_copy(specfile_factory, spec_autosetup): + spec = specfile_factory(spec_autosetup) shallow_copy = copy.copy(spec) assert shallow_copy == spec assert shallow_copy is not spec @@ -564,7 +558,7 @@ def test_copy(spec_autosetup): assert deep_copy._parser is not spec._parser -def test_parse_if_necessary(spec_macros): +def test_parse_if_necessary(specfile_factory, spec_macros): flexmock(SpecParser).should_call("_do_parse").once() spec1 = Specfile(spec_macros) spec2 = copy.deepcopy(spec1) @@ -582,12 +576,12 @@ def test_parse_if_necessary(spec_macros): assert spec1.expanded_version == "28.1.2~rc2" flexmock(SpecParser).should_receive("id").and_return(12345) flexmock(SpecParser).should_call("_do_parse").once() - spec = Specfile(spec_macros) + spec = specfile_factory(spec_macros) flexmock(SpecParser).should_call("_do_parse").never() assert spec.expanded_name == "test" spec = None flexmock(SpecParser).should_call("_do_parse").once() - spec = Specfile(spec_macros) + spec = specfile_factory(spec_macros) flexmock(SpecParser).should_call("_do_parse").never() assert spec.expanded_name == "test" @@ -597,9 +591,9 @@ def test_parse_if_necessary(spec_macros): reason="condition expression evaluation requires rpm 4.16 or higher", ) def test_update_version( - spec_prerelease, spec_prerelease2, spec_conditionalized_version + specfile_factory, spec_prerelease, spec_prerelease2, spec_conditionalized_version ): - spec = Specfile(spec_prerelease) + spec = specfile_factory(spec_prerelease) prerelease_suffix_pattern = r"(-)rc\d+" prerelease_suffix_macro = "prerel" spec.update_version("0.1.2", prerelease_suffix_pattern, prerelease_suffix_macro) @@ -620,7 +614,7 @@ def test_update_version( assert md.prerel.body == "rc1" assert not md.prerel.commented_out assert spec.version == "%{pkgver}" - spec = Specfile(spec_prerelease) + spec = specfile_factory(spec_prerelease) with spec.macro_definitions() as md: md.prerel.commented_out = True spec.update_version("0.1.3-rc1", prerelease_suffix_pattern) @@ -632,7 +626,7 @@ def test_update_version( assert md.prerel.body == "rc2" assert md.prerel.commented_out assert spec.version == "%{pkgver}" - spec = Specfile(spec_prerelease2) + spec = specfile_factory(spec_prerelease2) prerelease_suffix_pattern = r"(-)rc\d+" prerelease_suffix_macro = "prerel" spec.update_version("0.1.2", prerelease_suffix_pattern, prerelease_suffix_macro) @@ -653,7 +647,7 @@ def test_update_version( assert md.prerel.body == "rc1" assert not md.prerel.commented_out assert spec.version == "%{pkgver}" - spec = Specfile(spec_prerelease2) + spec = specfile_factory(spec_prerelease2) with spec.macro_definitions() as md: md.prerel.commented_out = True spec.update_version("0.1.3-rc1", prerelease_suffix_pattern) @@ -665,7 +659,7 @@ def test_update_version( assert md.prerel.body == "rc1" assert not md.prerel.commented_out assert spec.version == "%{pkgver}" - spec = Specfile(spec_conditionalized_version) + spec = specfile_factory(spec_conditionalized_version) version = "0.1.3" assert spec.version == "%{upstream_version}" spec.update_version(version, prerelease_suffix_pattern) @@ -673,7 +667,7 @@ def test_update_version( assert md.upstream_version.body == version assert spec.version == "%{upstream_version}" assert spec.expanded_version == version - spec = Specfile(spec_conditionalized_version) + spec = specfile_factory(spec_conditionalized_version) with spec.macro_definitions() as md: md.commit.commented_out = False assert spec.version == "%{upstream_version}^git%{shortcommit}" @@ -684,8 +678,35 @@ def test_update_version( assert spec.expanded_version == version -def test_trailing_newline(spec_autosetup, spec_no_trailing_newline): - spec = Specfile(spec_autosetup) +def test_trailing_newline(specfile_factory, spec_autosetup, spec_no_trailing_newline): + spec = specfile_factory(spec_autosetup) assert str(spec)[-1] == "\n" - spec = Specfile(spec_no_trailing_newline) + spec = specfile_factory(spec_no_trailing_newline) assert str(spec)[-1] != "\n" + + +@pytest.mark.parametrize("remove_spec", [False, True]) +def test_reload(specfile_factory, spec_minimal, spec_traditional, remove_spec): + spec = specfile_factory(spec_minimal) + before_reload = copy.deepcopy(spec) + + if spec.path is not None: + if remove_spec: + spec_minimal.unlink() + spec_minimal.write_bytes(spec_traditional.read_bytes()) + spec.reload() + after_reload = spec + else: + spec._file.seek(0) + spec._file.truncate() + content = spec_traditional.read_text(**Specfile.ENCODING_ARGS) + try: + spec._file.write(content) + except TypeError: + spec._file.write(content.encode(**Specfile.ENCODING_ARGS)) + spec._file.flush() + spec._file.seek(0) + spec.reload() + after_reload = spec + + assert str(before_reload) != str(after_reload)