@@ -157,16 +157,19 @@ Expr to_rounding_shift(const Call *c) {
157
157
return rounding_shift (cast (add->type , add->args [0 ]), b);
158
158
}
159
159
}
160
- // Also need to handle the annoying case of a reinterpret wrapping a widen_right_add
160
+
161
+ // Also need to handle the annoying case of a reinterpret cast wrapping a widen_right_add
161
162
// TODO: this pattern makes me want to change the semantics of this op.
162
- if (const Reinterpret *reinterp = a.as <Reinterpret >()) {
163
- if (reinterp-> type . bits () == reinterp-> value . type (). bits ()) {
164
- if (const Call *add = Call::as_intrinsic (reinterp ->value , {Call::widen_right_add})) {
163
+ if (const Cast *cast = a.as <Cast >()) {
164
+ if (cast-> is_reinterpret ()) {
165
+ if (const Call *add = Call::as_intrinsic (cast ->value , {Call::widen_right_add})) {
165
166
if (can_prove (lower_intrinsics (add->args [1 ] == round ))) {
166
- // We expect the first operand to be a reinterpet.
167
- const Reinterpret *reinterp_a = add->args [0 ].as <Reinterpret>();
168
- internal_assert (reinterp_a) << " Failed: " << add->args [0 ] << " \n " ;
169
- return rounding_shift (reinterp_a->value , b);
167
+ // We expect the first operand to be a reinterpet cast.
168
+ if (const Cast *cast_a = add->args [0 ].as <Cast>()) {
169
+ if (cast_a->is_reinterpret ()) {
170
+ return rounding_shift (cast_a->value , b);
171
+ }
172
+ }
170
173
}
171
174
}
172
175
}
@@ -245,9 +248,9 @@ class FindIntrinsics : public IRMutator {
245
248
if (b.type ().code () != narrow_a.type ().code ()) {
246
249
// Need to do a safe reinterpret.
247
250
Type t = b.type ().with_code (code);
248
- result = widen_right_add (reinterpret (t, b), narrow_a);
251
+ result = widen_right_add (cast (t, b), narrow_a);
249
252
internal_assert (result.type () != op->type );
250
- result = reinterpret (op->type , result);
253
+ result = cast (op->type , result);
251
254
} else {
252
255
result = widen_right_add (b, narrow_a);
253
256
}
@@ -258,9 +261,9 @@ class FindIntrinsics : public IRMutator {
258
261
if (a.type ().code () != narrow_b.type ().code ()) {
259
262
// Need to do a safe reinterpret.
260
263
Type t = a.type ().with_code (code);
261
- result = widen_right_add (reinterpret (t, a), narrow_b);
264
+ result = widen_right_add (cast (t, a), narrow_b);
262
265
internal_assert (result.type () != op->type );
263
- result = reinterpret (op->type , result);
266
+ result = cast (op->type , result);
264
267
} else {
265
268
result = widen_right_add (a, narrow_b);
266
269
}
@@ -328,9 +331,9 @@ class FindIntrinsics : public IRMutator {
328
331
if (a.type ().code () != narrow_b.type ().code ()) {
329
332
// Need to do a safe reinterpret.
330
333
Type t = a.type ().with_code (code);
331
- result = widen_right_sub (reinterpret (t, a), narrow_b);
334
+ result = widen_right_sub (cast (t, a), narrow_b);
332
335
internal_assert (result.type () != op->type );
333
- result = reinterpret (op->type , result);
336
+ result = cast (op->type , result);
334
337
} else {
335
338
result = widen_right_sub (a, narrow_b);
336
339
}
@@ -410,9 +413,9 @@ class FindIntrinsics : public IRMutator {
410
413
if (b.type ().code () != narrow_a.type ().code ()) {
411
414
// Need to do a safe reinterpret.
412
415
Type t = b.type ().with_code (code);
413
- result = widen_right_mul (reinterpret (t, b), narrow_a);
416
+ result = widen_right_mul (cast (t, b), narrow_a);
414
417
internal_assert (result.type () != op->type );
415
- result = reinterpret (op->type , result);
418
+ result = cast (op->type , result);
416
419
} else {
417
420
result = widen_right_mul (b, narrow_a);
418
421
}
@@ -423,9 +426,9 @@ class FindIntrinsics : public IRMutator {
423
426
if (a.type ().code () != narrow_b.type ().code ()) {
424
427
// Need to do a safe reinterpret.
425
428
Type t = a.type ().with_code (code);
426
- result = widen_right_mul (reinterpret (t, a), narrow_b);
429
+ result = widen_right_mul (cast (t, a), narrow_b);
427
430
internal_assert (result.type () != op->type );
428
- result = reinterpret (op->type , result);
431
+ result = cast (op->type , result);
429
432
} else {
430
433
result = widen_right_mul (a, narrow_b);
431
434
}
@@ -1261,8 +1264,8 @@ Expr lower_saturating_add(const Expr &a, const Expr &b) {
1261
1264
return select (sum < a, a.type ().max (), sum);
1262
1265
} else if (a.type ().is_int ()) {
1263
1266
Type u = a.type ().with_code (halide_type_uint);
1264
- Expr ua = reinterpret (u, a);
1265
- Expr ub = reinterpret (u, b);
1267
+ Expr ua = cast (u, a);
1268
+ Expr ub = cast (u, b);
1266
1269
Expr upper = make_const (u, (uint64_t (1 ) << (a.type ().bits () - 1 )) - 1 );
1267
1270
Expr lower = make_const (u, (uint64_t (1 ) << (a.type ().bits () - 1 )));
1268
1271
Expr sum = ua + ub;
@@ -1272,7 +1275,7 @@ Expr lower_saturating_add(const Expr &a, const Expr &b) {
1272
1275
// a + b >= 0 === a >= -b === a >= ~b + 1 === a > ~b
1273
1276
Expr pos_result = min (sum, upper);
1274
1277
Expr neg_result = max (sum, lower);
1275
- return simplify (reinterpret (a.type (), select (~b < a, pos_result, neg_result)));
1278
+ return simplify (cast (a.type (), select (~b < a, pos_result, neg_result)));
1276
1279
} else {
1277
1280
internal_error << " Bad type for saturating_add: " << a.type () << " \n " ;
1278
1281
return Expr ();
@@ -1288,8 +1291,8 @@ Expr lower_saturating_sub(const Expr &a, const Expr &b) {
1288
1291
} else if (a.type ().is_int ()) {
1289
1292
// Do the math in unsigned, to avoid overflow in the simplifier.
1290
1293
Type u = a.type ().with_code (halide_type_uint);
1291
- Expr ua = reinterpret (u, a);
1292
- Expr ub = reinterpret (u, b);
1294
+ Expr ua = cast (u, a);
1295
+ Expr ub = cast (u, b);
1293
1296
Expr upper = make_const (u, (uint64_t (1 ) << (a.type ().bits () - 1 )) - 1 );
1294
1297
Expr lower = make_const (u, (uint64_t (1 ) << (a.type ().bits () - 1 )));
1295
1298
Expr diff = ua - ub;
@@ -1300,7 +1303,7 @@ Expr lower_saturating_sub(const Expr &a, const Expr &b) {
1300
1303
// and saturate the negative difference to be at least -2^31 + 2^32 = 2^31
1301
1304
Expr neg_diff = max (lower, diff);
1302
1305
// Then select between them, and cast back to the signed type.
1303
- return simplify (reinterpret (a.type (), select (b <= a, pos_diff, neg_diff)));
1306
+ return simplify (cast (a.type (), select (b <= a, pos_diff, neg_diff)));
1304
1307
} else if (a.type ().is_uint ()) {
1305
1308
return simplify (select (b < a, a - b, make_zero (a.type ())));
1306
1309
} else {
0 commit comments