diff --git a/docs/notes/2.33.x.md b/docs/notes/2.33.x.md index 7f0adc74bdf..b014875fe09 100644 --- a/docs/notes/2.33.x.md +++ b/docs/notes/2.33.x.md @@ -34,6 +34,8 @@ Thank you to [Klaviyo](https://www.klaviyo.com/) for their Platinum tier support #### Go +Third-party module analysis is now deduplicated across `go.mod` files. Previously, a module required by `N` `go.mod` files was downloaded and analyzed `N` times, which caused significant memory and time overhead in monorepos with many overlapping `go.mod` files. On a 3-`go.mod` reproducer, `pants list ::` peak memory dropped from 91 GB to 32 GB (-65%). This is a no-op for repos with a single `go.mod`. See [#20274](https://github.com/pantsbuild/pants/issues/20274). + ### Plugin API changes ## Full Changelog diff --git a/src/python/pants/backend/go/util_rules/third_party_pkg.py b/src/python/pants/backend/go/util_rules/third_party_pkg.py index f1a5a470650..fd95ece6773 100644 --- a/src/python/pants/backend/go/util_rules/third_party_pkg.py +++ b/src/python/pants/backend/go/util_rules/third_party_pkg.py @@ -4,7 +4,6 @@ from __future__ import annotations import dataclasses -import difflib import json import logging import os @@ -14,7 +13,6 @@ import ijson.backends.python as ijson from pants.backend.go.go_sources.load_go_binary import LoadedGoBinaryRequest, setup_go_binary -from pants.backend.go.target_types import GoModTarget from pants.backend.go.util_rules import pkg_analyzer from pants.backend.go.util_rules.build_opts import GoBuildOptions from pants.backend.go.util_rules.cgo import CGoCompilerFlags @@ -157,15 +155,33 @@ class ModuleDescriptors: @dataclass(frozen=True) -class AnalyzeThirdPartyModuleRequest: - go_mod_address: Address - go_mod_digest: Digest - go_mod_path: str - import_path: str +class ModuleDownloadRequest: + """Download and analyze a Go module, keyed by (name, version, minimum_go_version, + build_opts, go_sum_entries). 
+ + This enables cross-go.mod deduplication: if mod-a and mod-b both depend on + grpc@v1.60.0 with the same go.sum entries, the download and analysis only + happens once because the Pants engine memoizes by the full request key. + + ``go_sum_entries`` carries the two go.sum lines for ``<name> <version>`` and + ``<name> <version>/go.mod`` extracted from the consuming go.mod's real + go.sum. These entries are content-addressable by design: two well-formed + go.sums MUST agree on them for the same module@version. Including them in + the dedup key has two effects: + + 1. Happy path: all consumers of module@version share one download, and the + synthetic go.sum written into the sandbox lets Go perform its normal + checksum verification (no GONOSUMCHECK override). + 2. Tampered path: if one go.sum disagrees, the two consumers produce + distinct requests -- each verified independently against its own + entries -- and the tampered one fails with Go's usual SECURITY ERROR. + """ + name: str version: str minimum_go_version: str | None build_opts: GoBuildOptions + go_sum_entries: tuple[str, ...] @dataclass(frozen=True) @@ -277,6 +293,45 @@ def strip_sandbox_prefix(path: str, marker: str) -> str: return path +def _parse_go_sum(go_sum_content: bytes) -> dict[tuple[str, str], tuple[str, ...]]: + """Parse a go.sum file into a dict keyed by (module name, version). + + A well-formed go.sum has up to two lines per (module, version): + + <module> <version> h1:<hash>= + <module> <version>/go.mod h1:<hash>= + + Returns a dict mapping (name, version) to a tuple of the matching lines, + enabling O(1) lookup per module instead of re-scanning the file. + """ + entries: dict[tuple[str, str], list[str]] = {} + for line in go_sum_content.decode("utf-8").splitlines(): + if not line: + continue + parts = line.split(" ", 2) + if len(parts) < 3: + continue + name = parts[0] + version_field = parts[1] + # Strip the "/go.mod" suffix to get the base version for grouping. 
+ version = version_field.removesuffix("/go.mod") + key = (name, version) + entries.setdefault(key, []).append(line) + return {k: tuple(v) for k, v in entries.items()} + + +def _extract_go_sum_entries_for_module( + go_sum_content: bytes, name: str, version: str +) -> tuple[str, ...]: + """Return the go.sum lines for a given module@version. + + Thin wrapper around _parse_go_sum for callers that only need one module. + Prefer _parse_go_sum when looking up multiple modules from the same go.sum. + """ + parsed = _parse_go_sum(go_sum_content) + return parsed.get((name, version), ()) + + def _freeze_json_dict(d: dict[Any, Any]) -> FrozenDict[str, Any]: result = {} for k, v in d.items(): @@ -296,48 +351,6 @@ def _freeze_json_dict(d: dict[Any, Any]) -> FrozenDict[str, Any]: return FrozenDict(result) -async def _check_go_sum_has_not_changed( - input_digest: Digest, - output_digest: Digest, - dir_path: str, - import_path: str, - go_mod_address: Address, -) -> None: - input_entries, output_entries = await concurrently( - get_digest_contents(input_digest), - get_digest_contents(output_digest), - ) - - go_sum_path = os.path.join(dir_path, "go.sum") - - input_go_sum_entry: bytes | None = None - for entry in input_entries: - if entry.path == go_sum_path: - input_go_sum_entry = entry.content - - output_go_sum_entry: bytes | None = None - for entry in output_entries: - if entry.path == go_sum_path: - output_go_sum_entry = entry.content - - if input_go_sum_entry is not None or output_go_sum_entry is not None: - if input_go_sum_entry != output_go_sum_entry: - go_sum_diff = list( - difflib.unified_diff( - (input_go_sum_entry or b"").decode().splitlines(), - (output_go_sum_entry or b"").decode().splitlines(), - ) - ) - go_sum_diff_rendered = "\n".join(line.rstrip() for line in go_sum_diff) - raise ValueError( - f"For `{GoModTarget.alias}` target `{go_mod_address}`, the go.sum file is incomplete " - f"because it was updated while processing third-party dependency `{import_path}`. 
" - "Please re-generate the go.sum file by running `go mod download all` in the module directory. " - "(Pants does not currently have support for updating the go.sum checksum database itself.)\n\n" - f"Diff:\n{go_sum_diff_rendered}" - ) - - @rule async def analyze_go_third_party_package( request: AnalyzeThirdPartyPackageRequest, @@ -472,21 +485,57 @@ async def analyze_go_third_party_package( @rule -async def analyze_go_third_party_module( - request: AnalyzeThirdPartyModuleRequest, +async def download_and_analyze_module( + request: ModuleDownloadRequest, analyzer: PackageAnalyzerSetup, ) -> AnalyzedThirdPartyModule: - # Download the module. + """Download and analyze a single Go module via a synthetic go.mod + go.sum. + + Keyed by (name, version, minimum_go_version, build_opts, go_sum_entries), + which lets the Pants engine deduplicate identical module downloads across + go.mods. + + A synthetic go.mod + go.sum pair is written into the sandbox so that Go's + normal checksum verification still runs -- the go.sum entries come straight + from the consuming go.mod's real go.sum (see ModuleDownloadRequest for the + full argument for why this is safe). + """ + # Create a synthetic go.mod (and go.sum when entries are available) that + # only requires this one module. When the consuming go.sum contains the + # entries for this module@version, we emit them verbatim so `go mod + # download` performs its usual local checksum verification. When they + # are absent (e.g., a transitive discovered during MVS that the consumer's + # go.sum hasn't recorded yet, or a go.sum that is entirely missing), we + # omit the synthetic go.sum and let Go fall back to GOSUMDB + # (sum.golang.org by default) for verification. This is a softer signal + # than the pre-dedup rule, which raised an error pointing the user at + # `go mod download all`; the warning below preserves that guidance. 
+ if not request.go_sum_entries: + logger.warning( + "No go.sum entries found for %s@%s; falling back to GOSUMDB for " + "checksum verification. This usually means the consuming go.mod's " + "go.sum is incomplete -- run `go mod download all` (or `go mod " + "tidy`) in that module's directory to record the checksum locally.", + request.name, + request.version, + ) + go_version = request.minimum_go_version or "1.21" + synthetic_go_mod = ( + f"module synthetic.invalid\n\ngo {go_version}\n\nrequire {request.name} {request.version}\n" + ) + synthetic_files = [FileContent("go.mod", synthetic_go_mod.encode())] + if request.go_sum_entries: + synthetic_go_sum = "\n".join(request.go_sum_entries) + "\n" + synthetic_files.append(FileContent("go.sum", synthetic_go_sum.encode())) + synthetic_digest = await create_digest(CreateDigest(synthetic_files)) + download_result = await fallible_to_exec_result_or_raise( **implicitly( GoSdkProcess( ("mod", "download", "-json", f"{request.name}@{request.version}"), - input_digest=request.go_mod_digest, # for go.sum - working_dir=os.path.dirname(request.go_mod_path), - # Allow downloads of the module sources. + input_digest=synthetic_digest, allow_downloads=True, output_directories=("gopath",), - output_files=(os.path.join(os.path.dirname(request.go_mod_path), "go.sum"),), description=f"Download Go module {request.name}@{request.version}.", ) ) @@ -497,20 +546,10 @@ async def analyze_go_third_party_module( f"Expected output from `go mod download` for {request.name}@{request.version}." ) - # Make sure go.sum has not changed. 
- await _check_go_sum_has_not_changed( - input_digest=request.go_mod_digest, - output_digest=download_result.output_digest, - dir_path=os.path.dirname(request.go_mod_path), - import_path=request.import_path, - go_mod_address=request.go_mod_address, - ) - module_metadata = json.loads(download_result.stdout) module_sources_relpath = strip_sandbox_prefix(module_metadata["Dir"], "gopath/") go_mod_relpath = strip_sandbox_prefix(module_metadata["GoMod"], "gopath/") - # Subset the output directory to just the module sources and go.mod (which may be generated). module_sources_snapshot = await digest_to_snapshot( **implicitly( DigestSubset( @@ -525,7 +564,6 @@ async def analyze_go_third_party_module( ) ) - # Determine directories with potential Go packages in them. candidate_package_dirs = [] files_by_dir = group_by_dir( p for p in module_sources_snapshot.files if p.startswith(module_sources_relpath) @@ -535,13 +573,10 @@ async def analyze_go_third_party_module( # See https://github.com/golang/go/blob/f005df8b582658d54e63d59953201299d6fee880/src/go/build/build.go#L580-L585 if "testdata" in maybe_pkg_dir.split("/"): continue - - # Consider directories with at least one `.go` file as package candidates. if any(f for f in files if f.endswith(".go")): candidate_package_dirs.append(maybe_pkg_dir) candidate_package_dirs.sort() - # Analyze all of the packages in this module. analyzer_relpath = "__analyzer" analysis_result = await fallible_to_exec_result_or_raise( **implicitly( @@ -595,17 +630,32 @@ async def download_and_analyze_third_party_packages( ) ) + # Read the real go.sum once so we can extract per-module entries for the + # download sandbox. This keeps Go's checksum verification intact while + # allowing the engine to memoize identical module@version downloads + # across different go.mods. 
+ go_sum_path = os.path.join(os.path.dirname(request.go_mod_path), "go.sum") + digest_contents = await get_digest_contents(request.go_mod_digest) + go_sum_content = b"" + for entry in digest_contents: + if entry.path == go_sum_path: + go_sum_content = entry.content + break + + # Parse the go.sum once into a dict for O(1) lookup per module. + go_sum_index = _parse_go_sum(go_sum_content) + + # The engine memoizes by (name, version, minimum_go_version, build_opts, + # go_sum_entries), so identical modules across go.mods are downloaded + # once -- reducing downloads from O(N*M) to O(M). analyzed_modules = await concurrently( - analyze_go_third_party_module( - AnalyzeThirdPartyModuleRequest( - go_mod_address=request.go_mod_address, - go_mod_digest=request.go_mod_digest, - go_mod_path=request.go_mod_path, - import_path=mod.name, + download_and_analyze_module( + ModuleDownloadRequest( name=mod.name, version=mod.version, minimum_go_version=mod.minimum_go_version, build_opts=request.build_opts, + go_sum_entries=go_sum_index.get((mod.name, mod.version), ()), ), **implicitly(), ) diff --git a/src/python/pants/backend/go/util_rules/third_party_pkg_test.py b/src/python/pants/backend/go/util_rules/third_party_pkg_test.py index 019ce0a5d06..21321edc7b7 100644 --- a/src/python/pants/backend/go/util_rules/third_party_pkg_test.py +++ b/src/python/pants/backend/go/util_rules/third_party_pkg_test.py @@ -4,7 +4,6 @@ from __future__ import annotations import os.path -import re from textwrap import dedent import pytest @@ -30,10 +29,11 @@ ModuleDescriptorsRequest, ThirdPartyPkgAnalysis, ThirdPartyPkgAnalysisRequest, + _extract_go_sum_entries_for_module, + _parse_go_sum, ) from pants.build_graph.address import Address from pants.engine.fs import Digest, Snapshot -from pants.engine.internals.scheduler import ExecutionError from pants.engine.process import ProcessExecutionFailure from pants.engine.rules import QueryRule from pants.testutil.rule_runner import RuleRunner, engine_error @@ 
-582,37 +582,6 @@ def test_ambiguous_package(rule_runner: RuleRunner) -> None: assert "encode.go" in pkg_info.go_files -def test_go_sum_with_missing_entries_triggers_error(rule_runner: RuleRunner) -> None: - digest = set_up_go_mod( - rule_runner, - dedent( - """\ - module example.com/third-party-module - go 1.16 - require github.com/google/uuid v1.3.0 - """ - ), - "", - ) - msg = ( - "For `go_mod` target `fake_addr_for_test:mod`, the go.sum file is incomplete because " - "it was updated while processing third-party dependency `github.com/google/uuid`." - ) - with pytest.raises(ExecutionError, match=re.escape(msg)): - _ = rule_runner.request( - ThirdPartyPkgAnalysis, - [ - ThirdPartyPkgAnalysisRequest( - "github.com/ugorji/go/codec", - Address("fake_addr_for_test", target_name="mod"), - digest, - "go.mod", - build_opts=GoBuildOptions(), - ) - ], - ) - - def test_local_path_replace_statement_is_not_considered_third_party( rule_runner: RuleRunner, ) -> None: @@ -638,3 +607,134 @@ def test_local_path_replace_statement_is_not_considered_third_party( ) # The module replaced to a local path should not be considered for third party analysis. assert len(module_analysis.modules) == 0 + + +UUID_GO_SUM = dedent( + """\ + github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= + github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= + """ +) + + +def test_cross_go_mod_dedup_produces_identical_results(rule_runner: RuleRunner) -> None: + """Two go.mods depending on the same module@version produce byte-identical analyses. + + The dedup path is keyed on (name, version, minimum_go_version, build_opts, + go_sum_entries). go.sum entries are content-addressable, so two well-formed + go.mods sharing a dep produce equal ModuleDownloadRequests -- the engine + memoizes and both requests share a single download. The observable property + is that the per-module analysis (including the sources digest) is identical. 
+ """ + go_mod_a = dedent( + """\ + module example.com/mod-a + go 1.16 + require github.com/google/uuid v1.3.0 + """ + ) + go_mod_b = dedent( + """\ + module example.com/mod-b + go 1.16 + require github.com/google/uuid v1.3.0 + """ + ) + + digest_a = set_up_go_mod(rule_runner, go_mod_a, UUID_GO_SUM) + digest_b = set_up_go_mod(rule_runner, go_mod_b, UUID_GO_SUM) + + result_a = rule_runner.request( + AllThirdPartyPackages, + [ + AllThirdPartyPackagesRequest( + Address("mod-a", target_name="mod"), + digest_a, + "go.mod", + build_opts=GoBuildOptions(), + ) + ], + ) + result_b = rule_runner.request( + AllThirdPartyPackages, + [ + AllThirdPartyPackagesRequest( + Address("mod-b", target_name="mod"), + digest_b, + "go.mod", + build_opts=GoBuildOptions(), + ) + ], + ) + + uuid_a = result_a.import_paths_to_pkg_info["github.com/google/uuid"] + uuid_b = result_b.import_paths_to_pkg_info["github.com/google/uuid"] + + assert uuid_a.import_path == uuid_b.import_path + assert uuid_a.dir_path == uuid_b.dir_path + assert uuid_a.go_files == uuid_b.go_files + # digest equality is the strongest signal: same bytes, which implies the + # dedup path produced identical AnalyzedThirdPartyModule results from the + # same engine-memoized ModuleDownloadRequest. 
+ assert uuid_a.digest == uuid_b.digest + + +def test_extract_go_sum_entries_for_module() -> None: + go_sum = ( + b"github.com/google/uuid v1.3.0 h1:AAA=\n" + b"github.com/google/uuid v1.3.0/go.mod h1:BBB=\n" + b"github.com/google/uuid v1.2.0 h1:CCC=\n" + b"github.com/google/uuid v1.2.0/go.mod h1:DDD=\n" + b"rsc.io/quote v1.5.2 h1:EEE=\n" + b"rsc.io/quote v1.5.2/go.mod h1:FFF=\n" + ) + + assert _extract_go_sum_entries_for_module(go_sum, "github.com/google/uuid", "v1.3.0") == ( + "github.com/google/uuid v1.3.0 h1:AAA=", + "github.com/google/uuid v1.3.0/go.mod h1:BBB=", + ) + assert _extract_go_sum_entries_for_module(go_sum, "rsc.io/quote", "v1.5.2") == ( + "rsc.io/quote v1.5.2 h1:EEE=", + "rsc.io/quote v1.5.2/go.mod h1:FFF=", + ) + # Non-matching module returns empty. + assert _extract_go_sum_entries_for_module(go_sum, "example.com/missing", "v1.0.0") == () + # Prefix safety: querying "v1.3" must not match "v1.3.0". + assert _extract_go_sum_entries_for_module(go_sum, "github.com/google/uuid", "v1.3") == () + + +def test_parse_go_sum() -> None: + """Verify _parse_go_sum groups entries by (name, version) correctly.""" + go_sum = ( + b"github.com/google/uuid v1.3.0 h1:AAA=\n" + b"github.com/google/uuid v1.3.0/go.mod h1:BBB=\n" + b"github.com/google/uuid v1.2.0 h1:CCC=\n" + b"rsc.io/quote v1.5.2 h1:EEE=\n" + b"rsc.io/quote v1.5.2/go.mod h1:FFF=\n" + b"\n" # blank line should be skipped + ) + parsed = _parse_go_sum(go_sum) + + assert parsed[("github.com/google/uuid", "v1.3.0")] == ( + "github.com/google/uuid v1.3.0 h1:AAA=", + "github.com/google/uuid v1.3.0/go.mod h1:BBB=", + ) + # Module with only a content hash (no /go.mod line). + assert parsed[("github.com/google/uuid", "v1.2.0")] == ( + "github.com/google/uuid v1.2.0 h1:CCC=", + ) + assert parsed[("rsc.io/quote", "v1.5.2")] == ( + "rsc.io/quote v1.5.2 h1:EEE=", + "rsc.io/quote v1.5.2/go.mod h1:FFF=", + ) + # Non-existent module returns None via dict.get. 
+ assert parsed.get(("example.com/missing", "v1.0.0")) is None + # Prefix safety: "v1.3" is a different key from "v1.3.0". + assert parsed.get(("github.com/google/uuid", "v1.3")) is None + # Empty go.sum produces empty dict. + assert _parse_go_sum(b"") == {} + # CRLF line endings are handled correctly. + crlf_sum = b"mod v1.0.0 h1:X=\r\nmod v1.0.0/go.mod h1:Y=\r\n" + crlf_parsed = _parse_go_sum(crlf_sum) + assert ("mod", "v1.0.0") in crlf_parsed + assert len(crlf_parsed[("mod", "v1.0.0")]) == 2