[WIP] partition protobuf dependency inference by any "resolve-like" fields from plugins #21918

Draft · wants to merge 4 commits into base: main
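This change introduces a generic `ResolveLikeField` hook in `pants.core.target_types` and uses it to partition the protobuf file-to-address mapping, and therefore dependency inference, by resolve value instead of building one global mapping. A minimal sketch of how a backend could opt a field into the hook, mirroring the `PythonResolveField` changes further down in this diff (the `JvmResolveField` names and the defaulting logic are hypothetical, for illustration only):

```python
# Sketch only: mirrors the pattern this PR applies to PythonResolveField below.
# JvmResolveField, its request type, and the "jvm-default" fallback are hypothetical.
from pants.core.target_types import (
    ResolveLikeField,
    ResolveLikeFieldToValueRequest,
    ResolveLikeFieldToValueResult,
)
from pants.engine.rules import collect_rules, rule
from pants.engine.target import StringField
from pants.engine.unions import UnionRule


class JvmResolveLikeFieldToValueRequest(ResolveLikeFieldToValueRequest):
    pass


class JvmResolveField(StringField, ResolveLikeField):
    alias = "jvm_resolve"

    def get_resolve_like_field_to_value_request(self) -> type[ResolveLikeFieldToValueRequest]:
        return JvmResolveLikeFieldToValueRequest


@rule
async def jvm_resolve_field_to_value(
    request: JvmResolveLikeFieldToValueRequest,
) -> ResolveLikeFieldToValueResult:
    # Map the wrapped target's field to its effective resolve (hypothetical defaulting).
    return ResolveLikeFieldToValueResult(value=request.target[JvmResolveField].value or "jvm-default")


def rules():
    return (
        *collect_rules(),
        UnionRule(ResolveLikeFieldToValueRequest, JvmResolveLikeFieldToValueRequest),
    )
```

With such a field registered on `protobuf_source` targets, `map_protobuf_files` keys the mapping by `ProtobufMappingResolveKey(field_type=..., resolve=...)`; when no resolve-like field is registered, everything falls under the `_NO_RESOLVE_LIKE_FIELDS_DEFINED` sentinel and behavior is unchanged.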
src/python/pants/backend/codegen/protobuf/protobuf_dependency_inference.py
@@ -4,6 +4,7 @@
from __future__ import annotations

import re
import typing
from collections import defaultdict
from dataclasses import dataclass
from typing import DefaultDict
@@ -13,6 +14,12 @@
AllProtobufTargets,
ProtobufDependenciesField,
ProtobufSourceField,
ProtobufSourceTarget,
)
from pants.core.target_types import (
ResolveLikeField,
ResolveLikeFieldToValueRequest,
ResolveLikeFieldToValueResult,
)
from pants.core.util_rules.stripped_source_files import StrippedFileName, StrippedFileNameRequest
from pants.engine.addresses import Address
@@ -21,29 +28,44 @@
from pants.engine.target import (
DependenciesRequest,
ExplicitlyProvidedDependencies,
Field,
FieldSet,
HydratedSources,
HydrateSourcesRequest,
InferDependenciesRequest,
InferredDependencies,
Target,
WrappedTarget,
WrappedTargetRequest,
)
from pants.engine.unions import UnionRule
from pants.engine.unions import UnionMembership, UnionRule
from pants.util.frozendict import FrozenDict
from pants.util.logging import LogLevel
from pants.util.ordered_set import FrozenOrderedSet, OrderedSet
from pants.util.strutil import softwrap


@dataclass(frozen=True)
class ProtobufMappingResolveKey:
field_type: type[Field]
resolve: str


_NO_RESOLVE_LIKE_FIELDS_DEFINED = ProtobufMappingResolveKey(
field_type=ProtobufSourceField, resolve="NONE"
)


@dataclass(frozen=True)
class ProtobufMapping:
"""A mapping of stripped .proto file names to their owning file address."""
"""A mapping of stripped .proto file names to their owning file address indirectly mapped by
resolve-like fields."""

mapping: FrozenDict[str, Address]
ambiguous_modules: FrozenDict[str, tuple[Address, ...]]
mapping: FrozenDict[ProtobufMappingResolveKey, FrozenDict[str, Address]]
ambiguous_modules: FrozenDict[ProtobufMappingResolveKey, FrozenDict[str, tuple[Address, ...]]]


@rule(desc="Creating map of Protobuf file names to Protobuf targets", level=LogLevel.DEBUG)
async def map_protobuf_files(protobuf_targets: AllProtobufTargets) -> ProtobufMapping:
async def _map_single_pseudo_resolve(protobuf_targets: AllProtobufTargets) -> ProtobufMapping:
stripped_file_per_target = await MultiGet(
Get(StrippedFileName, StrippedFileNameRequest(tgt[ProtobufSourceField].file_path))
for tgt in protobuf_targets
@@ -64,9 +86,116 @@ async def map_protobuf_files(protobuf_targets: AllProtobufTargets) -> ProtobufMa
stripped_files_to_addresses.pop(ambiguous_stripped_f)

return ProtobufMapping(
mapping=FrozenDict(sorted(stripped_files_to_addresses.items())),
mapping=FrozenDict(
{
_NO_RESOLVE_LIKE_FIELDS_DEFINED: FrozenDict(
sorted(stripped_files_to_addresses.items())
)
}
),
ambiguous_modules=FrozenDict(
{
_NO_RESOLVE_LIKE_FIELDS_DEFINED: FrozenDict(
(k, tuple(sorted(v)))
for k, v in sorted(stripped_files_with_multiple_owners.items())
)
}
),
)


@rule(desc="Creating map of Protobuf file names to Protobuf targets", level=LogLevel.DEBUG)
async def map_protobuf_files(
protobuf_targets: AllProtobufTargets, union_membership: UnionMembership
) -> ProtobufMapping:
# Determine the resolve-like fields installed on the `protobuf_source` target type.
resolve_like_field_types: set[type[Field]] = set()
for field_type in ProtobufSourceTarget.class_field_types(union_membership):
if issubclass(field_type, ResolveLikeField):
resolve_like_field_types.add(field_type)
if not resolve_like_field_types:
return await _map_single_pseudo_resolve(protobuf_targets)

# Discover which resolves are present in the protobuf_source targets.
resolve_requests: list[ResolveLikeFieldToValueRequest] = []
target_and_field_type_for_resolve_requests: list[tuple[Target, type[Field]]] = []
for tgt in protobuf_targets:
saw_at_least_one_field = False
for field_type in resolve_like_field_types:
if tgt.has_field(field_type):
resolve_request_type = typing.cast(
ResolveLikeField, tgt[field_type]
).get_resolve_like_field_to_value_request()
resolve_request = resolve_request_type(target=tgt)
resolve_requests.append(resolve_request)
target_and_field_type_for_resolve_requests.append((tgt, field_type))
saw_at_least_one_field = True

if not saw_at_least_one_field:
raise ValueError(f"Did not find a resolve field on target at address `{tgt.address}`.")

# Obtain the resolves for each target and then partition.
resolve_results = await MultiGet(
Get(ResolveLikeFieldToValueResult, ResolveLikeFieldToValueRequest, resolve_request)
for resolve_request in resolve_requests
)
targets_partitioned_by_resolve: dict[ProtobufMappingResolveKey, list[Target]] = defaultdict(
list
)
for resolve_result, (target, field_type) in zip(
resolve_results, target_and_field_type_for_resolve_requests
):
resolve_key = ProtobufMappingResolveKey(field_type=field_type, resolve=resolve_result.value)
targets_partitioned_by_resolve[resolve_key].append(target)

stripped_file_per_target = await MultiGet(
Get(StrippedFileName, StrippedFileNameRequest(tgt[ProtobufSourceField].file_path))
for tgt in protobuf_targets
)
target_to_stripped_file: dict[Target, StrippedFileName] = dict(
zip(protobuf_targets, stripped_file_per_target)
)

stripped_files_to_addresses: dict[ProtobufMappingResolveKey, dict[str, Address]] = defaultdict(
dict
)
stripped_files_with_multiple_owners: dict[
ProtobufMappingResolveKey, dict[str, set[Address]]
] = defaultdict(lambda: defaultdict(set))

for resolve_key, targets_in_resolve in targets_partitioned_by_resolve.items():
for tgt in targets_in_resolve:
stripped_file = target_to_stripped_file[tgt]
if stripped_file.value in stripped_files_to_addresses[resolve_key]:
stripped_files_with_multiple_owners[resolve_key][stripped_file.value].update(
{stripped_files_to_addresses[resolve_key][stripped_file.value], tgt.address}
)
else:
stripped_files_to_addresses[resolve_key][stripped_file.value] = tgt.address

# Remove files with ambiguous owners in each resolve.
for (
resolve_key,
stripped_files_with_multiple_owners_in_resolve,
) in stripped_files_with_multiple_owners.items():
for ambiguous_stripped_f in stripped_files_with_multiple_owners_in_resolve:
stripped_files_to_addresses[resolve_key].pop(ambiguous_stripped_f)

return ProtobufMapping(
mapping=FrozenDict(
{
resolve_key: FrozenDict(sorted(stripped_files_to_addresses_in_resolve.items()))
for resolve_key, stripped_files_to_addresses_in_resolve in stripped_files_to_addresses.items()
}
),
ambiguous_modules=FrozenDict(
(k, tuple(sorted(v))) for k, v in sorted(stripped_files_with_multiple_owners.items())
{
resolve_key: FrozenDict(
(k, tuple(sorted(v)))
for k, v in sorted(stripped_files_with_multiple_owners_in_resolve.items())
)
for resolve_key, stripped_files_with_multiple_owners_in_resolve in stripped_files_with_multiple_owners.items()
}
),
)

@@ -96,6 +225,35 @@ class InferProtobufDependencies(InferDependenciesRequest):
infer_from = ProtobufDependencyInferenceFieldSet


async def get_resolve_key_from_target(address: Address) -> ProtobufMappingResolveKey:
wrapped_target = await Get(
WrappedTarget, WrappedTargetRequest(address=address, description_of_origin="protobuf")
)
resolve_field_type: type[Field] | None = None
for field_type in wrapped_target.target.field_types:
if issubclass(field_type, ResolveLikeField):
if resolve_field_type is not None:
raise NotImplementedError(
f"TODO: Multiple resolve-like fields on target at address `{address}`."
)
resolve_field_type = field_type
if resolve_field_type is None:
raise ValueError(f"Failed to find resolve-like field on target at address `{address}.")

resolve_request_type = typing.cast(
ResolveLikeField, wrapped_target.target[resolve_field_type]
).get_resolve_like_field_to_value_request()
resolve_request = resolve_request_type(target=wrapped_target.target)
resolve_result = await Get(
ResolveLikeFieldToValueResult, ResolveLikeFieldToValueRequest, resolve_request
)

return ProtobufMappingResolveKey(
field_type=resolve_field_type,
resolve=resolve_result.value,
)


@rule(desc="Inferring Protobuf dependencies by analyzing imports")
async def infer_protobuf_dependencies(
request: InferProtobufDependencies, protobuf_mapping: ProtobufMapping, protoc: Protoc
@@ -104,6 +262,13 @@ async def infer_protobuf_dependencies(
return InferredDependencies([])

address = request.field_set.address

resolve_key: ProtobufMappingResolveKey
if _NO_RESOLVE_LIKE_FIELDS_DEFINED in protobuf_mapping.mapping:
resolve_key = _NO_RESOLVE_LIKE_FIELDS_DEFINED
else:
resolve_key = await get_resolve_key_from_target(address)

explicitly_provided_deps, hydrated_sources = await MultiGet(
Get(ExplicitlyProvidedDependencies, DependenciesRequest(request.field_set.dependencies)),
Get(HydratedSources, HydrateSourcesRequest(request.field_set.source)),
@@ -114,8 +279,14 @@

result: OrderedSet[Address] = OrderedSet()
for import_path in parse_proto_imports(file_content.content.decode()):
unambiguous = protobuf_mapping.mapping.get(import_path)
ambiguous = protobuf_mapping.ambiguous_modules.get(import_path)
mapping_in_resolve = protobuf_mapping.mapping.get(resolve_key)
unambiguous = mapping_in_resolve.get(import_path) if mapping_in_resolve else None

ambiguous_modules_in_resolve = protobuf_mapping.ambiguous_modules.get(resolve_key)
ambiguous = (
ambiguous_modules_in_resolve.get(import_path) if ambiguous_modules_in_resolve else None
)

if unambiguous:
result.add(unambiguous)
elif ambiguous:
src/python/pants/backend/codegen/protobuf/protobuf_dependency_inference_test.py
@@ -8,6 +8,7 @@

from pants.backend.codegen.protobuf import protobuf_dependency_inference
from pants.backend.codegen.protobuf.protobuf_dependency_inference import (
_NO_RESOLVE_LIKE_FIELDS_DEFINED,
InferProtobufDependencies,
ProtobufDependencyInferenceFieldSet,
ProtobufMapping,
@@ -106,15 +107,23 @@ def test_protobuf_mapping(rule_runner: RuleRunner) -> None:
assert result == ProtobufMapping(
mapping=FrozenDict(
{
"protos/f1.proto": Address("root1/protos", relative_file_path="f1.proto"),
"protos/f2.proto": Address("root1/protos", relative_file_path="f2.proto"),
_NO_RESOLVE_LIKE_FIELDS_DEFINED: FrozenDict(
{
"protos/f1.proto": Address("root1/protos", relative_file_path="f1.proto"),
"protos/f2.proto": Address("root1/protos", relative_file_path="f2.proto"),
}
)
}
),
ambiguous_modules=FrozenDict(
{
"two_owners/f.proto": (
Address("root1/two_owners", relative_file_path="f.proto"),
Address("root2/two_owners", relative_file_path="f.proto"),
_NO_RESOLVE_LIKE_FIELDS_DEFINED: FrozenDict(
{
"two_owners/f.proto": (
Address("root1/two_owners", relative_file_path="f.proto"),
Address("root2/two_owners", relative_file_path="f.proto"),
)
}
)
}
),
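For contrast with the sentinel-keyed result asserted above, a mapping built from targets that do carry a resolve-like field is keyed per field type and resolve value, so the same stripped path can have a different owner in each resolve without being ambiguous. This is an illustration only, not part of the diff; it reuses the hypothetical `JvmResolveField` from the sketch near the top of this PR:

```python
# Illustration of the partitioned mapping shape; JvmResolveField and the resolve
# names ("resolve-a", "resolve-b") are hypothetical.
from pants.backend.codegen.protobuf.protobuf_dependency_inference import (
    ProtobufMapping,
    ProtobufMappingResolveKey,
)
from pants.engine.addresses import Address
from pants.util.frozendict import FrozenDict

expected = ProtobufMapping(
    mapping=FrozenDict(
        {
            ProtobufMappingResolveKey(field_type=JvmResolveField, resolve="resolve-a"): FrozenDict(
                {"protos/f1.proto": Address("root1/protos", relative_file_path="f1.proto")}
            ),
            ProtobufMappingResolveKey(field_type=JvmResolveField, resolve="resolve-b"): FrozenDict(
                {"protos/f1.proto": Address("root2/protos", relative_file_path="f1.proto")}
            ),
        }
    ),
    # No two targets share a stripped path within the same resolve, so nothing is ambiguous.
    ambiguous_modules=FrozenDict({}),
)
```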
10 changes: 9 additions & 1 deletion src/python/pants/backend/python/target_types.py
@@ -35,6 +35,7 @@
TestsBatchCompatibilityTagField,
TestSubsystem,
)
from pants.core.target_types import ResolveLikeField, ResolveLikeFieldToValueRequest
from pants.core.util_rules.environments import EnvironmentField
from pants.engine.addresses import Address, Addresses
from pants.engine.target import (
@@ -137,7 +138,11 @@ def value_or_global_default(self, python_setup: PythonSetup) -> Tuple[str, ...]:
return python_setup.compatibility_or_constraints(self.value)


class PythonResolveField(StringField, AsyncFieldMixin):
class PythonResolveLikeFieldToValueRequest(ResolveLikeFieldToValueRequest):
pass


class PythonResolveField(StringField, AsyncFieldMixin, ResolveLikeField):
alias = "resolve"
required = False
help = help_text(
@@ -163,6 +168,9 @@ def normalized_value(self, python_setup: PythonSetup) -> str:
)
return resolve

def get_resolve_like_field_to_value_request(self) -> type[ResolveLikeFieldToValueRequest]:
return PythonResolveLikeFieldToValueRequest


class PrefixedPythonResolveField(PythonResolveField):
alias = "python_resolve"
11 changes: 11 additions & 0 deletions src/python/pants/backend/python/target_types_rules.py
@@ -39,6 +39,7 @@
PythonFilesGeneratorSettingsRequest,
PythonProvidesField,
PythonResolveField,
PythonResolveLikeFieldToValueRequest,
ResolvedPexEntryPoint,
ResolvedPythonDistributionEntryPoints,
ResolvePexEntryPointRequest,
@@ -49,6 +50,7 @@
)
from pants.backend.python.util_rules.interpreter_constraints import interpreter_constraints_contains
from pants.backend.python.util_rules.package_dists import InvalidEntryPoint
from pants.core.target_types import ResolveLikeFieldToValueRequest, ResolveLikeFieldToValueResult
from pants.core.util_rules.unowned_dependency_behavior import (
UnownedDependencyError,
UnownedDependencyUsage,
@@ -635,6 +637,14 @@ async def validate_python_dependencies(
return ValidatedDependencies()


@rule
async def python_resolve_field_to_string(
request: PythonResolveLikeFieldToValueRequest, python_setup: PythonSetup
) -> ResolveLikeFieldToValueResult:
resolve = request.target[PythonResolveField].normalized_value(python_setup)
return ResolveLikeFieldToValueResult(value=resolve)


def rules():
return (
*collect_rules(),
@@ -645,4 +655,5 @@ def rules():
UnionRule(InferDependenciesRequest, InferPexBinaryEntryPointDependency),
UnionRule(InferDependenciesRequest, InferPythonDistributionDependencies),
UnionRule(ValidateDependenciesRequest, PythonValidateDependenciesRequest),
UnionRule(ResolveLikeFieldToValueRequest, PythonResolveLikeFieldToValueRequest),
)
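For repositories without resolve-like fields nothing changes: the mapping keeps the single `_NO_RESOLVE_LIKE_FIELDS_DEFINED` partition. Where a resolve field such as the Python backend's `python_resolve` is available on `protobuf_source` targets (that plugin-field registration is not part of this diff), imports are only matched against files in the same resolve. A hypothetical BUILD file:

```python
# BUILD — illustration only; assumes a `python_resolve` field is registered on protobuf_source.
protobuf_source(
    name="models",
    source="models.proto",
    python_resolve="data-science",  # imports are inferred only against files in this resolve
)
```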