Skip to content

Commit 49e3775

Browse files
committed
Add AST copier review coverage tests
1 parent ad21264 commit 49e3775

2 files changed

Lines changed: 117 additions & 0 deletions

File tree

tests/language-feature/generics/variadic-pack-count-constraint.slang

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,15 @@ float packCountDifferentiableLiteralBwd<each TIndex>(float x)
115115
return x * float(countof(TIndex));
116116
}
117117

118+
struct NestedPackCount<let each N : int>
119+
{
120+
[BackwardDifferentiable]
121+
static float scale<each TIndex>(float x) where countof(TIndex) == countof(N)
122+
{
123+
return x * float(countof(TIndex));
124+
}
125+
}
126+
118127
void main()
119128
{
120129
printf("%d\n", load<3>(1, 2, 3));
@@ -161,4 +170,9 @@ void main()
161170
bwd_diff(packCountDifferentiableLiteralBwd<int, int>)(literalDpx, 1.0);
162171
printf("%f\n", literalDpx.d);
163172
// CHECK: 2.000000
173+
174+
var nestedDpx = diffPair(5.0, 0.0);
175+
bwd_diff(NestedPackCount<1, 2, 3>::scale<int, float, bool>)(nestedDpx, 1.0);
176+
printf("%f\n", nestedDpx.d);
177+
// CHECK: 3.000000
164178
}

tests/language-feature/interfaces/default-itensor-load-backward-differentiable.slang

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,96 @@ struct CountedImpl : ICounted<3>
5858
}
5959
}
6060

61+
interface IHasIntValue
62+
{
63+
int getValue();
64+
}
65+
66+
struct CoerceFrom
67+
{
68+
int value;
69+
70+
__init(int value)
71+
{
72+
this.value = value;
73+
}
74+
}
75+
76+
struct CoerceTo : IHasIntValue
77+
{
78+
int value;
79+
80+
__implicit_conversion(100)
81+
__init(CoerceFrom from)
82+
{
83+
value = from.value + 5;
84+
}
85+
86+
int getValue()
87+
{
88+
return value;
89+
}
90+
}
91+
92+
interface IOptionalValue
93+
{
94+
int getOptionalValue();
95+
}
96+
97+
struct OptionalValue : IOptionalValue
98+
{
99+
int value;
100+
101+
__init(int value)
102+
{
103+
this.value = value;
104+
}
105+
106+
int getOptionalValue()
107+
{
108+
return value;
109+
}
110+
}
111+
112+
struct PlainValue
113+
{
114+
int value;
115+
116+
__init(int value)
117+
{
118+
this.value = value;
119+
}
120+
}
121+
122+
interface IConstraintDefaults
123+
{
124+
associatedtype Stored : IHasIntValue;
125+
126+
int coerceDefault<ToType, FromType>(ToType input)
127+
where ToType(FromType) implicit
128+
where ToType : IHasIntValue
129+
{
130+
return input.getValue();
131+
}
132+
133+
int optionalDefault<T>(T value) where optional T : IOptionalValue
134+
{
135+
if (T is IOptionalValue)
136+
return value.getOptionalValue();
137+
return -7;
138+
}
139+
140+
int assocDefault<U>(U value) where U == Stored
141+
{
142+
return value.getValue();
143+
}
144+
}
145+
146+
struct ConstraintDefaultsImpl : IConstraintDefaults
147+
{
148+
typealias Stored = CoerceTo;
149+
}
150+
61151
[BackwardDifferentiable]
62152
float useTensor<T : ITensor<1>>(no_diff T tensor, float x)
63153
{
@@ -86,4 +176,17 @@ void main()
86176
CountedImpl counted;
87177
printf("%d\n", counted.reportInner<int, float, bool>());
88178
// CHECK: 106
179+
180+
ConstraintDefaultsImpl defaults;
181+
printf("%d\n", defaults.coerceDefault<CoerceTo, CoerceFrom>(CoerceFrom(11)));
182+
// CHECK: 16
183+
184+
printf("%d\n", defaults.optionalDefault<OptionalValue>(OptionalValue(8)));
185+
// CHECK: 8
186+
187+
printf("%d\n", defaults.optionalDefault<PlainValue>(PlainValue(9)));
188+
// CHECK: -7
189+
190+
printf("%d\n", defaults.assocDefault<CoerceTo>(CoerceTo(CoerceFrom(17))));
191+
// CHECK: 22
89192
}

0 commit comments

Comments
 (0)