Skip to content

Commit 559a4b0

Browse files
canonicalize specilize args to always use SpecializationArg::Struct/Enum (#9021)
1 parent 3bbe98a commit 559a4b0

File tree

6 files changed

+1209
-1160
lines changed

6 files changed

+1209
-1160
lines changed

crates/cairo-lang-lowering/src/lower/test_data/loop

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1383,7 +1383,7 @@ Final lowering:
13831383
Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin
13841384
blk0 (root):
13851385
Statements:
1386-
(v2: core::RangeCheck, v3: core::gas::GasBuiltin, v4: core::panics::PanicResult::<(test::A, core::felt252, ())>) <- test::foo[58-134]{0, { 0: core::felt252 }, }(v0, v1)
1386+
(v2: core::RangeCheck, v3: core::gas::GasBuiltin, v4: core::panics::PanicResult::<(test::A, core::felt252, ())>) <- test::foo[58-134]{0, { 0 }, }(v0, v1)
13871387
End:
13881388
Match(match_enum(v4) {
13891389
PanicResult::Ok(v5) => blk1,

crates/cairo-lang-lowering/src/lower/test_data/specialized

Lines changed: 108 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,12 @@ End:
3232
Parameters:
3333
blk0 (root):
3434
Statements:
35-
(v0: core::bool) <- bool::False({})
36-
(v1: core::felt252) <- 1
37-
(v2: ()) <- test::bar(v0, v1)
35+
(v1: ()) <- struct_construct()
36+
(v0: core::bool) <- bool::False(v1)
37+
(v2: core::felt252) <- 1
38+
(v3: ()) <- test::bar(v0, v2)
3839
End:
39-
Return(v2)
40+
Return(v3)
4041

4142
//! > semantic_diagnostics
4243

@@ -78,11 +79,12 @@ End:
7879
Parameters:
7980
blk0 (root):
8081
Statements:
81-
(v0: core::bool) <- bool::False({})
82-
(v1: core::box::Box::<core::felt252>) <- 2.into_box()
83-
(v2: ()) <- test::bar(v0, v1)
82+
(v1: ()) <- struct_construct()
83+
(v0: core::bool) <- bool::False(v1)
84+
(v2: core::box::Box::<core::felt252>) <- 2.into_box()
85+
(v3: ()) <- test::bar(v0, v2)
8486
End:
85-
Return(v2)
87+
Return(v3)
8688

8789
//! > semantic_diagnostics
8890

@@ -125,11 +127,12 @@ End:
125127
Parameters:
126128
blk0 (root):
127129
Statements:
128-
(v0: core::bool) <- bool::False({})
129-
(v1: core::array::Array::<core::felt252>) <- core::array::array_new::<core::felt252>()
130-
(v2: ()) <- test::bar(v0, v1)
130+
(v1: ()) <- struct_construct()
131+
(v0: core::bool) <- bool::False(v1)
132+
(v2: core::array::Array::<core::felt252>) <- core::array::array_new::<core::felt252>()
133+
(v3: ()) <- test::bar(v0, v2)
131134
End:
132-
Return(v2)
135+
Return(v3)
133136

134137
//! > semantic_diagnostics
135138

@@ -176,17 +179,18 @@ End:
176179
Parameters:
177180
blk0 (root):
178181
Statements:
179-
(v0: core::bool) <- bool::False({})
180-
(v6: core::array::Array::<core::felt252>) <- core::array::array_new::<core::felt252>()
181-
(v7: core::felt252) <- 1
182-
(v4: core::array::Array::<core::felt252>) <- core::array::array_append::<core::felt252>(v6, v7)
183-
(v5: core::felt252) <- 2
184-
(v2: core::array::Array::<core::felt252>) <- core::array::array_append::<core::felt252>(v4, v5)
185-
(v3: core::felt252) <- 3
186-
(v1: core::array::Array::<core::felt252>) <- core::array::array_append::<core::felt252>(v2, v3)
187-
(v8: ()) <- test::bar(v0, v1)
182+
(v1: ()) <- struct_construct()
183+
(v0: core::bool) <- bool::False(v1)
184+
(v7: core::array::Array::<core::felt252>) <- core::array::array_new::<core::felt252>()
185+
(v8: core::felt252) <- 1
186+
(v5: core::array::Array::<core::felt252>) <- core::array::array_append::<core::felt252>(v7, v8)
187+
(v6: core::felt252) <- 2
188+
(v3: core::array::Array::<core::felt252>) <- core::array::array_append::<core::felt252>(v5, v6)
189+
(v4: core::felt252) <- 3
190+
(v2: core::array::Array::<core::felt252>) <- core::array::array_append::<core::felt252>(v3, v4)
191+
(v9: ()) <- test::bar(v0, v2)
188192
End:
189-
Return(v8)
193+
Return(v9)
190194

191195
//! > semantic_diagnostics
192196

@@ -229,11 +233,12 @@ End:
229233
Parameters:
230234
blk0 (root):
231235
Statements:
232-
(v0: core::bool) <- bool::False({})
233-
(v1: core::array::Array::<core::felt252>) <- core::array::array_new::<core::felt252>()
234-
(v2: ()) <- test::bar(v0, v1)
236+
(v1: ()) <- struct_construct()
237+
(v0: core::bool) <- bool::False(v1)
238+
(v2: core::array::Array::<core::felt252>) <- core::array::array_new::<core::felt252>()
239+
(v3: ()) <- test::bar(v0, v2)
235240
End:
236-
Return(v2)
241+
Return(v3)
237242

238243
//! > semantic_diagnostics
239244

@@ -278,19 +283,20 @@ End:
278283
Return()
279284

280285
//! > specialized_lowering
281-
Parameters: v6: core::felt252
286+
Parameters: v7: core::felt252
282287
blk0 (root):
283288
Statements:
284-
(v0: core::bool) <- bool::False({})
285-
(v1: core::array::Array::<core::felt252>) <- core::array::array_new::<core::felt252>()
286-
(v4: core::array::Array::<core::felt252>) <- core::array::array_new::<core::felt252>()
287-
(v5: core::array::Array::<core::felt252>, v3: @core::array::Array::<core::felt252>) <- snapshot(v4)
288-
(v2: core::array::Span::<core::felt252>) <- struct_construct(v3)
289-
(v7: core::felt252) <- 0
290-
(v8: core::box::Box::<core::felt252>) <- 1.into_box()
291-
(v9: core::array::Array::<core::felt252>, v10: ()) <- test::bar(v0, v1, v2, v6, v7, v8)
289+
(v1: ()) <- struct_construct()
290+
(v0: core::bool) <- bool::False(v1)
291+
(v2: core::array::Array::<core::felt252>) <- core::array::array_new::<core::felt252>()
292+
(v5: core::array::Array::<core::felt252>) <- core::array::array_new::<core::felt252>()
293+
(v6: core::array::Array::<core::felt252>, v4: @core::array::Array::<core::felt252>) <- snapshot(v5)
294+
(v3: core::array::Span::<core::felt252>) <- struct_construct(v4)
295+
(v8: core::felt252) <- 0
296+
(v9: core::box::Box::<core::felt252>) <- 1.into_box()
297+
(v10: core::array::Array::<core::felt252>, v11: ()) <- test::bar(v0, v2, v3, v7, v8, v9)
292298
End:
293-
Return(v9, v10)
299+
Return(v10, v11)
294300

295301
//! > semantic_diagnostics
296302

@@ -339,17 +345,18 @@ End:
339345
Return()
340346

341347
//! > specialized_lowering
342-
Parameters: v3: core::felt252, v5: core::felt252, v6: core::felt252
348+
Parameters: v4: core::felt252, v6: core::felt252, v7: core::felt252
343349
blk0 (root):
344350
Statements:
345-
(v0: core::bool) <- bool::False({})
346-
(v2: core::felt252) <- 1
347-
(v4: core::felt252) <- 3
348-
(v7: core::felt252) <- 6
349-
(v1: test::S6) <- struct_construct(v2, v3, v4, v5, v6, v7)
350-
(v8: ()) <- test::bar(v0, v1)
351+
(v1: ()) <- struct_construct()
352+
(v0: core::bool) <- bool::False(v1)
353+
(v3: core::felt252) <- 1
354+
(v5: core::felt252) <- 3
355+
(v8: core::felt252) <- 6
356+
(v2: test::S6) <- struct_construct(v3, v4, v5, v6, v7, v8)
357+
(v9: ()) <- test::bar(v0, v2)
351358
End:
352-
Return(v8)
359+
Return(v9)
353360

354361
//! > semantic_diagnostics
355362

@@ -405,13 +412,14 @@ End:
405412
Return()
406413

407414
//! > specialized_lowering
408-
Parameters: v1: test::SA1
415+
Parameters: v2: test::SA1
409416
blk0 (root):
410417
Statements:
411-
(v0: core::bool) <- bool::False({})
412-
(v2: ()) <- test::bar(v0, v1)
418+
(v1: ()) <- struct_construct()
419+
(v0: core::bool) <- bool::False(v1)
420+
(v3: ()) <- test::bar(v0, v2)
413421
End:
414-
Return(v2)
422+
Return(v3)
415423

416424
//! > semantic_diagnostics
417425

@@ -465,13 +473,14 @@ End:
465473
Return()
466474

467475
//! > specialized_lowering
468-
Parameters: v1: test::SA1
476+
Parameters: v2: test::SA1
469477
blk0 (root):
470478
Statements:
471-
(v0: core::bool) <- bool::False({})
472-
(v2: ()) <- test::bar(v0, v1)
479+
(v1: ()) <- struct_construct()
480+
(v0: core::bool) <- bool::False(v1)
481+
(v3: ()) <- test::bar(v0, v2)
473482
End:
474-
Return(v2)
483+
Return(v3)
475484

476485
//! > semantic_diagnostics
477486

@@ -538,13 +547,14 @@ End:
538547
Return()
539548

540549
//! > specialized_lowering
541-
Parameters: v1: test::SB
550+
Parameters: v2: test::SB
542551
blk0 (root):
543552
Statements:
544-
(v0: core::bool) <- bool::False({})
545-
(v2: ()) <- test::bar(v0, v1)
553+
(v1: ()) <- struct_construct()
554+
(v0: core::bool) <- bool::False(v1)
555+
(v3: ()) <- test::bar(v0, v2)
546556
End:
547-
Return(v2)
557+
Return(v3)
548558

549559
//! > semantic_diagnostics
550560

@@ -589,15 +599,16 @@ End:
589599
Return()
590600

591601
//! > specialized_lowering
592-
Parameters: v2: core::felt252
602+
Parameters: v3: core::felt252
593603
blk0 (root):
594604
Statements:
595-
(v0: core::bool) <- bool::False({})
596-
(v3: core::felt252) <- 5
597-
(v1: test::S) <- struct_construct(v2, v3)
598-
(v4: ()) <- test::bar(v0, v1)
605+
(v1: ()) <- struct_construct()
606+
(v0: core::bool) <- bool::False(v1)
607+
(v4: core::felt252) <- 5
608+
(v2: test::S) <- struct_construct(v3, v4)
609+
(v5: ()) <- test::bar(v0, v2)
599610
End:
600-
Return(v4)
611+
Return(v5)
601612

602613
//! > semantic_diagnostics
603614

@@ -660,23 +671,24 @@ End:
660671
Return(v12)
661672

662673
//! > specialized_lowering
663-
Parameters: v9: core::felt252, v12: core::felt252, v11: core::felt252
674+
Parameters: v10: core::felt252, v13: core::felt252, v12: core::felt252
664675
blk0 (root):
665676
Statements:
666-
(v0: core::bool) <- bool::False({})
667-
(v4: core::array::Array::<test::Outer>) <- core::array::array_new::<test::Outer>()
668-
(v8: core::felt252) <- 1
669-
(v6: test::Inner) <- struct_construct(v8, v9)
670-
(v7: core::felt252) <- 3
671-
(v5: test::Outer) <- struct_construct(v6, v7)
672-
(v2: core::array::Array::<test::Outer>) <- core::array::array_append::<test::Outer>(v4, v5)
673-
(v13: core::felt252) <- 5
674-
(v10: test::Inner) <- struct_construct(v12, v13)
675-
(v3: test::Outer) <- struct_construct(v10, v11)
676-
(v1: core::array::Array::<test::Outer>) <- core::array::array_append::<test::Outer>(v2, v3)
677-
(v14: ()) <- test::bar(v0, v1)
678-
End:
679-
Return(v14)
677+
(v1: ()) <- struct_construct()
678+
(v0: core::bool) <- bool::False(v1)
679+
(v5: core::array::Array::<test::Outer>) <- core::array::array_new::<test::Outer>()
680+
(v9: core::felt252) <- 1
681+
(v7: test::Inner) <- struct_construct(v9, v10)
682+
(v8: core::felt252) <- 3
683+
(v6: test::Outer) <- struct_construct(v7, v8)
684+
(v3: core::array::Array::<test::Outer>) <- core::array::array_append::<test::Outer>(v5, v6)
685+
(v14: core::felt252) <- 5
686+
(v11: test::Inner) <- struct_construct(v13, v14)
687+
(v4: test::Outer) <- struct_construct(v11, v12)
688+
(v2: core::array::Array::<test::Outer>) <- core::array::array_append::<test::Outer>(v3, v4)
689+
(v15: ()) <- test::bar(v0, v2)
690+
End:
691+
Return(v15)
680692

681693
//! > semantic_diagnostics
682694

@@ -725,14 +737,15 @@ End:
725737
Return()
726738

727739
//! > specialized_lowering
728-
Parameters: v2: core::felt252
740+
Parameters: v3: core::felt252
729741
blk0 (root):
730742
Statements:
731-
(v0: core::bool) <- bool::False({})
732-
(v1: test::E) <- E::A(v2)
733-
(v3: ()) <- test::bar(v0, v1)
743+
(v1: ()) <- struct_construct()
744+
(v0: core::bool) <- bool::False(v1)
745+
(v2: test::E) <- E::A(v3)
746+
(v4: ()) <- test::bar(v0, v2)
734747
End:
735-
Return(v3)
748+
Return(v4)
736749

737750
//! > ==========================================================================
738751

@@ -783,16 +796,17 @@ End:
783796
Return()
784797

785798
//! > specialized_lowering
786-
Parameters: v4: core::felt252
799+
Parameters: v5: core::felt252
787800
blk0 (root):
788801
Statements:
789-
(v0: core::bool) <- bool::False({})
790-
(v3: core::felt252) <- 1
791-
(v2: test::S) <- struct_construct(v3, v4)
792-
(v1: test::E) <- E::A(v2)
793-
(v5: ()) <- test::bar(v0, v1)
802+
(v1: ()) <- struct_construct()
803+
(v0: core::bool) <- bool::False(v1)
804+
(v4: core::felt252) <- 1
805+
(v3: test::S) <- struct_construct(v4, v5)
806+
(v2: test::E) <- E::A(v3)
807+
(v6: ()) <- test::bar(v0, v2)
794808
End:
795-
Return(v5)
809+
Return(v6)
796810

797811
//! > ==========================================================================
798812

@@ -860,8 +874,9 @@ End:
860874
Parameters:
861875
blk0 (root):
862876
Statements:
863-
(v0: core::bool) <- bool::False({})
864-
(v1: core::felt252) <- 100000
865-
(v2: core::felt252) <- test::bar1(v0, v1)
877+
(v1: ()) <- struct_construct()
878+
(v0: core::bool) <- bool::False(v1)
879+
(v2: core::felt252) <- 100000
880+
(v3: core::felt252) <- test::bar1(v0, v2)
866881
End:
867-
Return(v2)
882+
Return(v3)

crates/cairo-lang-lowering/src/optimizations/const_folding.rs

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,35 @@ use crate::{
4747
StatementStructDestructure, VarRemapping, VarUsage, Variable, VariableArena, VariableId,
4848
};
4949

50+
/// Converts a const value to a specialization arg.
51+
/// For struct and enum const values, recursively converts to SpecializationArg::Struct/Enum.
52+
fn const_to_specialization_arg<'db>(
53+
db: &'db dyn Database,
54+
value: ConstValueId<'db>,
55+
boxed: bool,
56+
) -> SpecializationArg<'db> {
57+
match value.long(db) {
58+
ConstValue::Struct(members, ty) => {
59+
// Only convert to SpecializationArg::Struct if the type is actually a concrete struct,
60+
// not a closure or fixed size array.
61+
if matches!(ty.long(db), TypeLongId::Concrete(ConcreteTypeId::Struct(_))) {
62+
let args = members
63+
.iter()
64+
.map(|member| const_to_specialization_arg(db, *member, false))
65+
.collect();
66+
SpecializationArg::Struct(args)
67+
} else {
68+
SpecializationArg::Const { value, boxed }
69+
}
70+
}
71+
ConstValue::Enum(variant, payload) => SpecializationArg::Enum {
72+
variant: *variant,
73+
payload: Box::new(const_to_specialization_arg(db, *payload, false)),
74+
},
75+
_ => SpecializationArg::Const { value, boxed },
76+
}
77+
}
78+
5079
/// Keeps track of equivalent values that variables might be replaced with.
5180
/// Note: We don't keep track of types as we assume the usage is always correct.
5281
#[derive(Debug, Clone)]
@@ -1310,7 +1339,7 @@ impl<'db, 'mt> ConstFoldingContext<'db, 'mt> {
13101339
}
13111340

13121341
match var_info {
1313-
VarInfo::Const(value) => Some(SpecializationArg::Const { value, boxed: false }),
1342+
VarInfo::Const(value) => Some(const_to_specialization_arg(self.db, value, false)),
13141343
VarInfo::Box(info) => try_extract_matches!(info.as_ref(), VarInfo::Const)
13151344
.map(|value| SpecializationArg::Const { value: *value, boxed: true }),
13161345
VarInfo::Snapshot(info) => {

0 commit comments

Comments
 (0)