Skip to content

Commit ee5adb8

Browse files
canonical type equality constraint (#8445)
Fixes #8439 When checked, generic type equality constraints types are now in a canonical order, allowing for a commutative type equality operator. --------- Co-authored-by: Mukund Keshava <mkeshava@nvidia.com>
1 parent a6deb5e commit ee5adb8

5 files changed

Lines changed: 201 additions & 20 deletions

File tree

source/core/slang-list.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,7 @@ class List
583583
}
584584

585585
template<typename T2>
586-
Index binarySearch(const T2& obj)
586+
Index binarySearch(const T2& obj) const
587587
{
588588
return binarySearch(
589589
obj,

source/slang/slang-check-decl.cpp

Lines changed: 92 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -365,9 +365,14 @@ struct SemanticsDeclHeaderVisitor : public SemanticsDeclVisitorBase,
365365

366366
void visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl);
367367

368+
void checkGenericTypeEqualityConstraintSubType(GenericTypeConstraintDecl* decl);
369+
368370
void visitTypeCoercionConstraintDecl(TypeCoercionConstraintDecl* decl);
369371

370-
void validateGenericConstraintSubType(GenericTypeConstraintDecl* decl, TypeExp type);
372+
bool validateGenericConstraintSubType(
373+
GenericTypeConstraintDecl* decl,
374+
TypeExp type,
375+
DiagnosticSink* sink = nullptr);
371376

372377
void checkForwardReferencesInGenericConstraint(GenericTypeConstraintDecl* decl);
373378

@@ -3250,18 +3255,33 @@ bool isProperConstraineeType(Type* type)
32503255
return true;
32513256
}
32523257

3253-
void SemanticsDeclHeaderVisitor::validateGenericConstraintSubType(
3258+
bool SemanticsDeclHeaderVisitor::validateGenericConstraintSubType(
32543259
GenericTypeConstraintDecl* decl,
3255-
TypeExp type)
3260+
TypeExp type,
3261+
DiagnosticSink* sink)
32563262
{
3263+
auto diagnose = [&]()
3264+
{
3265+
if (sink)
3266+
{
3267+
if (decl->isEqualityConstraint)
3268+
{
3269+
sink->diagnose(type.exp, Diagnostics::invalidEqualityConstraintSubType, type);
3270+
}
3271+
else
3272+
{
3273+
sink->diagnose(type.exp, Diagnostics::invalidConstraintSubType, type);
3274+
}
3275+
}
3276+
};
32573277
// Validate that the sub type of a constraint is in valid form.
32583278
//
32593279
if (auto subDeclRef = isDeclRefTypeOf<Decl>(type.type))
32603280
{
32613281
if (subDeclRef.getDecl()->parentDecl == decl->parentDecl)
32623282
{
32633283
// OK, sub type is one of the generic parameter type.
3264-
return;
3284+
return true;
32653285
}
32663286
if (as<GenericDecl>(decl->parentDecl))
32673287
{
@@ -3272,8 +3292,8 @@ void SemanticsDeclHeaderVisitor::validateGenericConstraintSubType(
32723292
auto dependentGeneric = getShared()->getDependentGenericParent(subDeclRef);
32733293
if (dependentGeneric.getDecl() != decl->parentDecl)
32743294
{
3275-
getSink()->diagnose(type.exp, Diagnostics::invalidConstraintSubType, type);
3276-
return;
3295+
diagnose();
3296+
return false;
32773297
}
32783298
}
32793299
else if (as<AssocTypeDecl>(decl->parentDecl))
@@ -3291,8 +3311,8 @@ void SemanticsDeclHeaderVisitor::validateGenericConstraintSubType(
32913311
auto lookupDeclRef = as<LookupDeclRef>(subDeclRef.declRefBase);
32923312
if (!lookupDeclRef)
32933313
{
3294-
getSink()->diagnose(type.exp, Diagnostics::invalidConstraintSubType, type);
3295-
return;
3314+
diagnose();
3315+
return false;
32963316
}
32973317

32983318
// We allow `associatedtype T where This.T : ...`.
@@ -3301,24 +3321,25 @@ void SemanticsDeclHeaderVisitor::validateGenericConstraintSubType(
33013321
//
33023322
if (lookupDeclRef->getDecl()->parentDecl == decl->parentDecl ||
33033323
lookupDeclRef->getDecl() == decl->parentDecl)
3304-
return;
3324+
return true;
33053325
auto baseType = as<Type>(lookupDeclRef->getLookupSource());
33063326
if (!baseType)
33073327
{
3308-
getSink()->diagnose(type.exp, Diagnostics::invalidConstraintSubType, type);
3309-
return;
3328+
diagnose();
3329+
return false;
33103330
}
33113331
type.type = baseType;
3312-
validateGenericConstraintSubType(decl, type);
3332+
return validateGenericConstraintSubType(decl, type, sink);
33133333
}
33143334
}
33153335
if (!isProperConstraineeType(type.type))
33163336
{
33173337
// It is meaningless for certain types to be used in type constraints.
33183338
// For example, `IFoo<T>` should not appear as the left-hand-side of a generic constraint.
3319-
getSink()->diagnose(type.exp, Diagnostics::invalidConstraintSubType, type);
3320-
return;
3339+
diagnose();
3340+
return false;
33213341
}
3342+
return true;
33223343
}
33233344

33243345
// General utility function to collect all referenced declarations from a value
@@ -3443,9 +3464,9 @@ void SemanticsDeclHeaderVisitor::visitGenericTypeConstraintDecl(GenericTypeConst
34433464
// Check for forward references in generic constraints after type translation
34443465
checkForwardReferencesInGenericConstraint(decl);
34453466

3446-
validateGenericConstraintSubType(decl, decl->sub);
34473467
if (decl->isEqualityConstraint)
34483468
{
3469+
checkGenericTypeEqualityConstraintSubType(decl);
34493470
if (!isProperConstraineeType(decl->sup) && !as<ErrorType>(decl->sup.type))
34503471
{
34513472
getSink()->diagnose(
@@ -3456,6 +3477,7 @@ void SemanticsDeclHeaderVisitor::visitGenericTypeConstraintDecl(GenericTypeConst
34563477
}
34573478
else
34583479
{
3480+
validateGenericConstraintSubType(decl, decl->sub, getSink());
34593481
if (!isValidGenericConstraintType(decl->sup) && !as<ErrorType>(decl->sup.type))
34603482
{
34613483
getSink()->diagnose(
@@ -3467,6 +3489,61 @@ void SemanticsDeclHeaderVisitor::visitGenericTypeConstraintDecl(GenericTypeConst
34673489
}
34683490
}
34693491

3492+
ContainerDecl* findDeclsLowestCommonAncestor(Decl*& a, Decl*& b);
3493+
int compareDecls(Decl* lhs, Decl* rhs);
3494+
3495+
void SemanticsDeclHeaderVisitor::checkGenericTypeEqualityConstraintSubType(
3496+
GenericTypeConstraintDecl* decl)
3497+
{
3498+
auto checkAndCompare = [&]() -> int
3499+
{
3500+
bool subOk = validateGenericConstraintSubType(decl, decl->sub);
3501+
bool supOk = validateGenericConstraintSubType(decl, decl->sup);
3502+
3503+
if (subOk != supOk) // Only one is qualified
3504+
{
3505+
return int(supOk) - int(subOk);
3506+
}
3507+
else if (!(subOk || supOk))
3508+
{
3509+
getSink()->diagnose(decl, Diagnostics::noValidEqualityConstraintSubType);
3510+
// Re-run the validation to emit the diagnostic this time
3511+
validateGenericConstraintSubType(decl, decl->sub, getSink());
3512+
validateGenericConstraintSubType(decl, decl->sup, getSink());
3513+
return -1;
3514+
}
3515+
// Both sub and sup are qualified
3516+
// For example:
3517+
// __generic <A : IA, B : IB>
3518+
// where A::T == B::T
3519+
// Sort them by declaration order in the generic (A > B)
3520+
3521+
Decl* subAncestor = as<DeclRefType>(decl->sub.type)->getDeclRef().getDecl();
3522+
Decl* supAncestor = as<DeclRefType>(decl->sup.type)->getDeclRef().getDecl();
3523+
auto ancestor = findDeclsLowestCommonAncestor(subAncestor, supAncestor);
3524+
if (!ancestor)
3525+
{
3526+
return compareDecls(subAncestor, supAncestor);
3527+
}
3528+
3529+
auto subIndex = ancestor->getMembers().binarySearch(subAncestor);
3530+
auto supIndex = ancestor->getMembers().binarySearch(supAncestor);
3531+
3532+
return int(supIndex - subIndex);
3533+
};
3534+
3535+
int cmp = checkAndCompare();
3536+
if (cmp > 0)
3537+
{
3538+
Swap(decl->sub, decl->sup);
3539+
}
3540+
else if (cmp == 0 && decl->sub != decl->sup)
3541+
{
3542+
// The comparison was not fully handled for this case.
3543+
getSink()->diagnose(decl, Diagnostics::failedEqualityConstraintCanonicalOrder);
3544+
}
3545+
}
3546+
34703547
void SemanticsDeclHeaderVisitor::visitGenericTypeParamDecl(GenericTypeParamDecl* decl)
34713548
{
34723549
// TODO: could probably push checking the default value

source/slang/slang-diagnostic-defs.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1691,6 +1691,22 @@ DIAGNOSTIC(
16911691
Error,
16921692
invalidEqualityConstraintSupType,
16931693
"type '$0' is not a proper type to use in a generic equality constraint.")
1694+
DIAGNOSTIC(
1695+
30405,
1696+
Error,
1697+
noValidEqualityConstraintSubType,
1698+
"generic equality constraint requires at least one operand to be dependant on the generic "
1699+
"declaration")
1700+
DIAGNOSTIC(
1701+
30402,
1702+
Note,
1703+
invalidEqualityConstraintSubType,
1704+
"type '$0' cannot be constrained by a type equality")
1705+
DIAGNOSTIC(
1706+
30407,
1707+
Warning,
1708+
failedEqualityConstraintCanonicalOrder,
1709+
"failed to resolve canonical order of generic equality constraint.")
16941710

16951711
// 305xx: initializer lists
16961712
DIAGNOSTIC(30500, Error, tooManyInitializers, "too many initializers (expected $0, got $1)")

source/slang/slang-parser.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1683,13 +1683,14 @@ static void maybeParseGenericConstraints(Parser* parser, ContainerDecl* genericP
16831683
bool optional = AdvanceIf(parser, "optional", &whereToken);
16841684

16851685
auto subType = parser->ParseTypeExp();
1686-
if (AdvanceIf(parser, TokenType::Colon))
1686+
Token constraintToken;
1687+
if (AdvanceIf(parser, TokenType::Colon, &constraintToken))
16871688
{
16881689
for (;;)
16891690
{
16901691
auto constraint = parser->astBuilder->create<GenericTypeConstraintDecl>();
16911692
constraint->whereTokenLoc = whereToken.loc;
1692-
parser->FillPosition(constraint);
1693+
constraint->loc = constraintToken.loc;
16931694
constraint->sub = subType;
16941695
constraint->sup = parser->ParseTypeExp();
16951696
if (optional)
@@ -1703,12 +1704,12 @@ static void maybeParseGenericConstraints(Parser* parser, ContainerDecl* genericP
17031704
break;
17041705
}
17051706
}
1706-
else if (AdvanceIf(parser, TokenType::OpEql))
1707+
else if (AdvanceIf(parser, TokenType::OpEql, &constraintToken))
17071708
{
17081709
auto constraint = parser->astBuilder->create<GenericTypeConstraintDecl>();
17091710
constraint->whereTokenLoc = whereToken.loc;
17101711
constraint->isEqualityConstraint = true;
1711-
parser->FillPosition(constraint);
1712+
constraint->loc = constraintToken.loc;
17121713
constraint->sub = subType;
17131714
constraint->sup = parser->ParseTypeExp();
17141715
if (optional)
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
//TEST:SIMPLE(filecheck=CHECK):
2+
//TEST:INTERPRET(filecheck=ICHECK):
3+
4+
interface MyInterface
5+
{
6+
associatedtype CompatibilityClass;
7+
8+
__generic <Other : MyInterface>
9+
This f(Other other) where Other::CompatibilityClass == This::CompatibilityClass;
10+
11+
__generic <Other : MyInterface>
12+
This g(Other other) where This::CompatibilityClass == Other::CompatibilityClass;
13+
14+
CompatibilityClass toCompat();
15+
};
16+
17+
struct MyCompatibilityClass {};
18+
19+
__generic <T>
20+
struct MyStruct : MyInterface
21+
{
22+
typealias CompatibilityClass = MyCompatibilityClass;
23+
24+
__generic <Other : MyInterface>
25+
This f(Other other) where CompatibilityClass == Other::CompatibilityClass
26+
{
27+
return this;
28+
}
29+
30+
__generic <Other : MyInterface>
31+
This g(Other other) where MyCompatibilityClass == Other::CompatibilityClass
32+
{
33+
return this;
34+
}
35+
36+
CompatibilityClass toCompat()
37+
{
38+
return MyCompatibilityClass();
39+
}
40+
};
41+
42+
struct TestInt : MyInterface
43+
{
44+
typealias CompatibilityClass = int;
45+
int value;
46+
47+
__init(int v)
48+
{
49+
value = v;
50+
}
51+
52+
__generic <Other : MyInterface>
53+
This f(Other other) where CompatibilityClass == Other::CompatibilityClass
54+
{
55+
return this;
56+
}
57+
58+
__generic <Other : MyInterface>
59+
This g(Other other) where int == Other::CompatibilityClass
60+
{
61+
return this;
62+
}
63+
64+
int toCompat()
65+
{
66+
return value;
67+
}
68+
}
69+
70+
__generic <T : MyInterface>
71+
void test(T t)
72+
where int == T::CompatibilityClass
73+
{
74+
printf("Success x %d!", t.toCompat());
75+
}
76+
77+
void main()
78+
{
79+
TestInt t = TestInt(12);
80+
test(t);
81+
}
82+
83+
// CHECK-NOT: 30402
84+
// CHECK-NOT: 30404
85+
// CHECK-NOT: 30405
86+
87+
// ICHECK: Success x 12!

0 commit comments

Comments
 (0)