Skip to content

Commit 4330aa2

Browse files
committed
Clean up
1 parent 0a791d4 commit 4330aa2

File tree

2 files changed

+131
-59
lines changed

2 files changed

+131
-59
lines changed

lib/Sema/TypeCheckType.cpp

Lines changed: 82 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6976,36 +6976,52 @@ void TypeChecker::checkExistentialTypes(
69766976
checker.checkRequirements(genericParams->getRequirements());
69776977
}
69786978

6979-
// copied from elsewhere, to unify later
6980-
Type inferResultBuilderComponentType(NominalTypeDecl *builder) {
6981-
Type componentType;
6982-
6983-
SmallVector<ValueDecl *, 4> potentialMatches;
6979+
/// Retrieves the valid result types of the result builder
6980+
/// (the return types of `buildFinalResult` / `buildBlock` / `buildPartialResult`).
6981+
llvm::SmallVector<Type, 4> retrieveResultBuilderResultTypes(NominalTypeDecl *builder) {
69846982
ASTContext &ctx = builder->getASTContext();
6985-
bool supportsBuildBlock = TypeChecker::typeSupportsBuilderOp(
6986-
builder->getDeclaredInterfaceType(), builder, ctx.Id_buildBlock,
6987-
/*argLabels=*/{}, &potentialMatches);
6988-
if (supportsBuildBlock) {
6983+
llvm::SmallVector<Type, 4> resultTypes;
6984+
6985+
Identifier methodIds[] = {
6986+
ctx.Id_buildFinalResult,
6987+
ctx.Id_buildBlock,
6988+
ctx.Id_buildPartialBlock
6989+
};
6990+
6991+
for (auto methodId : methodIds) {
6992+
SmallVector<ValueDecl *, 4> potentialMatches;
6993+
bool supportsMethod = TypeChecker::typeSupportsBuilderOp(
6994+
builder->getDeclaredInterfaceType(), builder, methodId,
6995+
/*argLabels=*/{}, &potentialMatches);
6996+
6997+
if (!supportsMethod)
6998+
continue;
6999+
69897000
for (auto decl : potentialMatches) {
69907001
auto func = dyn_cast<FuncDecl>(decl);
69917002
if (!func || !func->isStatic())
69927003
continue;
69937004

6994-
// If we haven't seen a component type before, gather it.
6995-
if (!componentType) {
6996-
componentType = func->getResultInterfaceType();
7005+
auto resultType = func->getResultInterfaceType();
7006+
if (!resultType || resultType->hasError())
69977007
continue;
7008+
7009+
// Add the result type if we haven't seen it before
7010+
bool isDuplicate = false;
7011+
for (auto existingType : resultTypes) {
7012+
if (existingType->isEqual(resultType)) {
7013+
isDuplicate = true;
7014+
break;
7015+
}
69987016
}
69997017

7000-
// If there are inconsistent component types, bail out.
7001-
if (!componentType->isEqual(func->getResultInterfaceType())) {
7002-
componentType = Type();
7003-
break;
7018+
if (!isDuplicate) {
7019+
resultTypes.push_back(resultType);
70047020
}
70057021
}
70067022
}
7007-
7008-
return componentType;
7023+
7024+
return resultTypes;
70097025
}
70107026

70117027
Type invalidResultBuilderType(UnboundGenericType* unboundTy,
@@ -7018,10 +7034,13 @@ Type invalidResultBuilderType(UnboundGenericType* unboundTy,
70187034
return ErrorType::get(dc->getASTContext());
70197035
}
70207036

7037+
/// Opens a result builder `UnboundGenericType` into a `BoundGenericType`
7038+
/// by creating and solving a constraint between the result builder's
7039+
/// result type and the return type of the attached declaration.
70217040
Type openUnboundResultBuilderType(UnboundGenericType* unboundTy,
70227041
CustomAttr *attr, DeclContext *dc) {
7023-
auto builder = dyn_cast_or_null<NominalTypeDecl>(unboundTy->getDecl());
7024-
if (!builder) {
7042+
auto resultBuilderDecl = dyn_cast_or_null<NominalTypeDecl>(unboundTy->getDecl());
7043+
if (!resultBuilderDecl) {
70257044
return invalidResultBuilderType(unboundTy, attr, dc);
70267045
}
70277046

@@ -7043,57 +7062,61 @@ Type openUnboundResultBuilderType(UnboundGenericType* unboundTy,
70437062
if (!owningDeclResultType) {
70447063
return invalidResultBuilderType(unboundTy, attr, dc);
70457064
}
7046-
7047-
// Retrieve the result type of the result builder itself
7048-
auto componentType = inferResultBuilderComponentType(builder);
7049-
if (!componentType) {
7065+
7066+
// Retrieve the supported result types of the result builder.
7067+
auto resultTypes = retrieveResultBuilderResultTypes(resultBuilderDecl);
7068+
if (resultTypes.empty()) {
70507069
return invalidResultBuilderType(unboundTy, attr, dc);
70517070
}
7052-
7053-
using namespace constraints;
7054-
ConstraintSystem cs(dc, std::nullopt);
70557071

7056-
// Create a type variable for each of the result builder's generic params
7057-
auto genericSig = builder->getGenericSignature();
7072+
auto genericSig = resultBuilderDecl->getGenericSignature();
70587073

7059-
llvm::SmallVector<Type, 8> typeVarReplacements;
7060-
llvm::SmallVector<TypeVariableType*, 8> typeVars;
7061-
for (unsigned i = 0; i < genericSig.getGenericParams().size(); ++i) {
7062-
auto locator = cs.getConstraintLocator(builder);
7063-
auto typeVar = cs.createTypeVariable(locator, TVO_CanBindToHole);
7064-
typeVarReplacements.push_back(typeVar);
7065-
typeVars.push_back(typeVar);
7066-
}
7074+
for (auto componentType : resultTypes) {
7075+
using namespace constraints;
7076+
ConstraintSystem cs(dc, std::nullopt);
70677077

7068-
// Retrieve the generic types of the ResultBuilder component and return type
7069-
// of the attached function, replacing any reference to the ResultBuilder's
7070-
// generic arguments with the corresponding type vars.
7071-
auto subMap = SubstitutionMap::get(
7072-
genericSig,
7073-
typeVarReplacements,
7074-
LookUpConformanceInModule());
7078+
// Create a type variable for each of the result builder's generic params
7079+
llvm::SmallVector<Type, 8> typeVarReplacements;
7080+
llvm::SmallVector<TypeVariableType*, 8> typeVars;
7081+
for (unsigned i = 0; i < genericSig.getGenericParams().size(); ++i) {
7082+
auto locator = cs.getConstraintLocator(
7083+
resultBuilderDecl,
7084+
{ConstraintLocator::GenericArgument, i});
7085+
auto typeVar = cs.createTypeVariable(locator, TVO_CanBindToHole);
7086+
typeVarReplacements.push_back(typeVar);
7087+
typeVars.push_back(typeVar);
7088+
}
70757089

7076-
auto componentTypeWithTypeVars = componentType.subst(subMap);
7090+
// Replace any references to the result builder's generic params
7091+
// in the result type with the corresponding type variables.
7092+
auto subMap = SubstitutionMap::get(
7093+
genericSig,
7094+
typeVarReplacements,
7095+
LookUpConformanceInModule());
70777096

7078-
// The result builder result type should be equal to the return type of the attached declaration
7079-
cs.addConstraint(ConstraintKind::Equal,
7080-
owningDeclResultType,
7081-
componentTypeWithTypeVars,
7082-
/*preparedOverload:*/ nullptr);
7097+
auto componentTypeWithTypeVars = componentType.subst(subMap);
70837098

7084-
auto solution = cs.solveSingle();
7099+
// The result builder result type should be equal to the return type of the attached declaration.
7100+
cs.addConstraint(ConstraintKind::Equal,
7101+
owningDeclResultType,
7102+
componentTypeWithTypeVars,
7103+
/*preparedOverload:*/ nullptr);
70857104

7086-
if (!solution) {
7087-
return invalidResultBuilderType(unboundTy, attr, dc);
7088-
}
7105+
auto solution = cs.solveSingle();
70897106

7090-
// Bind the result builder type to the solved type parameters
7091-
llvm::SmallVector<Type, 8> solvedReplacements;
7092-
for (auto typeVar : typeVars) {
7093-
solvedReplacements.push_back(solution->typeBindings[typeVar]);
7107+
// If a solution exists, bind the result builder's generic params to the solved types.
7108+
if (solution) {
7109+
llvm::SmallVector<Type, 8> solvedReplacements;
7110+
for (auto typeVar : typeVars) {
7111+
solvedReplacements.push_back(solution->typeBindings[typeVar]);
7112+
}
7113+
7114+
return BoundGenericType::get(resultBuilderDecl, unboundTy->getParent(), solvedReplacements);
7115+
}
70947116
}
70957117

7096-
return BoundGenericType::get(builder, unboundTy->getParent(), solvedReplacements);
7118+
// No result type produced a valid solution
7119+
return invalidResultBuilderType(unboundTy, attr, dc);
70977120
}
70987121

70997122
Type CustomAttrTypeRequest::evaluate(Evaluator &eval, CustomAttr *attr,

test/Constraints/result_builder_diags.swift

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,5 +1157,54 @@ func testInferResultBuilderGenerics() {
11571157
("foo", 1) // expected-warning {{expression of type '(String, Int)' is unused}}
11581158
("bar", 2) // expected-warning {{expression of type '(String, Int)' is unused}}
11591159
}
1160+
1161+
@resultBuilder
1162+
enum ComplexListBuilder<Element> {
1163+
static func buildBlock(_ elements: Element...) -> [Element] {
1164+
elements
1165+
}
1166+
1167+
static func buildFinalResult(_ component: [Element]) -> Set<Element> where Element: Hashable {
1168+
Set(component)
1169+
}
1170+
1171+
static func buildFinalResult(_ component: [Element]) -> ContiguousArray<Element> {
1172+
ContiguousArray(component)
1173+
}
1174+
}
1175+
1176+
@ComplexListBuilder
1177+
var stringSetFromBuildFinalResult: Set<String> {
1178+
"foo"
1179+
"bar"
1180+
}
1181+
1182+
@ComplexListBuilder
1183+
var contiguousStringsFromBuildFinalResult: ContiguousArray<String> {
1184+
"foo"
1185+
"bar"
1186+
}
1187+
1188+
@ComplexListBuilder
1189+
var stringArrayFromBuildFinalResult: [String] { // expected-error {{cannot convert return expression of type 'Set<String>' to return type '[String]'}}
1190+
"foo"
1191+
"bar"
1192+
}
1193+
1194+
@resultBuilder
1195+
enum PartialArrayBuilder<Element> {
1196+
static func buildPartialBlock(first: Element) -> [Element] {
1197+
[first]
1198+
}
1199+
1200+
static func buildPartialBlock(accumulated: [Element], next: Element) -> [Element] {
1201+
accumulated + [next]
1202+
}
1203+
}
11601204

1205+
@PartialArrayBuilder
1206+
var stringArrayFromBuildPartial: [String] {
1207+
"foo"
1208+
"bar"
1209+
}
11611210
}

0 commit comments

Comments
 (0)