Skip to content

Commit 1b2487b

Browse files
authored
Merge pull request #29019 from ProvableHQ/mohammadfawaz/28961
A few fixes to options in type checking and option lowering
2 parents cd6e63a + 54578fe commit 1b2487b

14 files changed

Lines changed: 979 additions & 76 deletions

File tree

compiler/passes/src/destructuring/ast.rs

Lines changed: 116 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,64 @@ impl AstReconstructor for DestructuringVisitor<'_> {
2424
type AdditionalInput = ();
2525
type AdditionalOutput = Vec<Statement>;
2626

27+
/// Reconstructs a binary expression, expanding equality and inequality over
28+
/// tuples into elementwise comparisons. When both sides are tuples and the
29+
/// operator is `==` or `!=`, it generates per-element comparisons and folds
30+
/// them with AND/OR; otherwise the expression is rebuilt normally.
31+
///
32+
/// Example: `(a, b) == (c, d)` → `(a == c) && (b == d)`
33+
/// Example: `(a, b, c) != (x, y, z)` → `(a != x) || (b != y) || (c != z)`
34+
fn reconstruct_binary(
35+
&mut self,
36+
input: BinaryExpression,
37+
_additional: &Self::AdditionalInput,
38+
) -> (Expression, Self::AdditionalOutput) {
39+
let (left, mut statements) = self.reconstruct_expression_tuple(input.left);
40+
let (right, statements2) = self.reconstruct_expression_tuple(input.right);
41+
statements.extend(statements2);
42+
43+
use BinaryOperation::*;
44+
45+
// Tuple equality / inequality expansion
46+
if let (Expression::Tuple(tuple_left), Expression::Tuple(tuple_right)) = (&left, &right)
47+
&& matches!(input.op, Eq | Neq)
48+
{
49+
assert_eq!(tuple_left.elements.len(), tuple_right.elements.len());
50+
51+
// Directly build elementwise (l OP r)
52+
let pieces: Vec<Expression> = tuple_left
53+
.elements
54+
.iter()
55+
.zip(&tuple_right.elements)
56+
.map(|(l, r)| {
57+
let expr: Expression = BinaryExpression {
58+
op: input.op,
59+
left: l.clone(),
60+
right: r.clone(),
61+
span: Default::default(),
62+
id: self.state.node_builder.next_id(),
63+
}
64+
.into();
65+
66+
self.state.type_table.insert(expr.id(), Type::Boolean);
67+
expr
68+
})
69+
.collect();
70+
71+
// Fold appropriately
72+
let op = match input.op {
73+
Eq => BinaryOperation::And,
74+
Neq => BinaryOperation::Or,
75+
_ => unreachable!(),
76+
};
77+
78+
return (self.fold_with_op(op, pieces.into_iter()), statements);
79+
}
80+
81+
// Fallback
82+
(BinaryExpression { op: input.op, left, right, ..input }.into(), Default::default())
83+
}
84+
2785
/// Replaces a tuple access expression with the appropriate expression.
2886
fn reconstruct_tuple_access(
2987
&mut self,
@@ -66,9 +124,12 @@ impl AstReconstructor for DestructuringVisitor<'_> {
66124
mut input: TernaryExpression,
67125
_additional: &(),
68126
) -> (Expression, Self::AdditionalOutput) {
69-
let (if_true, mut statements) = self.reconstruct_expression_tuple(std::mem::take(&mut input.if_true));
70-
let (if_false, statements2) = self.reconstruct_expression_tuple(std::mem::take(&mut input.if_false));
127+
let (condition, mut statements) =
128+
self.reconstruct_expression(std::mem::take(&mut input.condition), &Default::default());
129+
let (if_true, statements2) = self.reconstruct_expression_tuple(std::mem::take(&mut input.if_true));
71130
statements.extend(statements2);
131+
let (if_false, statements3) = self.reconstruct_expression_tuple(std::mem::take(&mut input.if_false));
132+
statements.extend(statements3);
72133

73134
match (if_true, if_false) {
74135
(Expression::Tuple(tuple_true), Expression::Tuple(tuple_false)) => {
@@ -77,20 +138,17 @@ impl AstReconstructor for DestructuringVisitor<'_> {
77138
panic!("Should have tuple type");
78139
};
79140

80-
// We'll be reusing `input.condition`, so assign it to a variable.
81-
let cond = if let Expression::Path(..) = input.condition {
82-
input.condition
141+
// We'll be reusing `condition`, so assign it to a variable.
142+
let cond = if let Expression::Path(..) = condition {
143+
condition
83144
} else {
84145
let place = Identifier::new(
85146
self.state.assigner.unique_symbol("cond", "$$"),
86147
self.state.node_builder.next_id(),
87148
);
88149

89-
let definition = self.state.assigner.simple_definition(
90-
place,
91-
input.condition,
92-
self.state.node_builder.next_id(),
93-
);
150+
let definition =
151+
self.state.assigner.simple_definition(place, condition, self.state.node_builder.next_id());
94152

95153
statements.push(definition);
96154

@@ -144,12 +202,59 @@ impl AstReconstructor for DestructuringVisitor<'_> {
144202
}
145203
(if_true, if_false) => {
146204
// This isn't a tuple. Just rebuild it and otherwise leave it alone.
147-
(TernaryExpression { if_true, if_false, ..input }.into(), statements)
205+
(TernaryExpression { condition, if_true, if_false, ..input }.into(), statements)
148206
}
149207
}
150208
}
151209

152210
/* Statements */
211+
/// `assert_eq` and `assert_neq` comparing tuples should be expanded to as many asserts as
212+
/// the length of each tuple.
213+
fn reconstruct_assert(&mut self, input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
214+
match input.variant {
215+
AssertVariant::Assert(expr) => {
216+
// Simple assert, just reconstruct the expression.
217+
let (expr, _) = self.reconstruct_expression(expr, &Default::default());
218+
(AssertStatement { variant: AssertVariant::Assert(expr), ..input }.into(), Default::default())
219+
}
220+
AssertVariant::AssertEq(ref left, ref right) | AssertVariant::AssertNeq(ref left, ref right) => {
221+
let (left, mut statements) = self.reconstruct_expression_tuple(left.clone());
222+
let (right, statements2) = self.reconstruct_expression_tuple(right.clone());
223+
statements.extend(statements2);
224+
225+
match (&left, &right) {
226+
(Expression::Tuple(tuple_left), Expression::Tuple(tuple_right)) => {
227+
// Ensure the tuple lengths match
228+
assert_eq!(tuple_left.elements.len(), tuple_right.elements.len());
229+
230+
for (l, r) in tuple_left.elements.iter().zip(&tuple_right.elements) {
231+
let assert_variant = match input.variant {
232+
AssertVariant::AssertEq(_, _) => AssertVariant::AssertEq(l.clone(), r.clone()),
233+
AssertVariant::AssertNeq(_, _) => AssertVariant::AssertNeq(l.clone(), r.clone()),
234+
_ => unreachable!(),
235+
};
236+
237+
let stmt = AssertStatement { variant: assert_variant, ..input.clone() }.into();
238+
statements.push(stmt);
239+
}
240+
241+
// We don't need the original statement, just the ones we've created.
242+
(Statement::dummy(), statements)
243+
}
244+
_ => {
245+
// Not tuples, just keep the original assert
246+
let variant = match input.variant {
247+
AssertVariant::AssertEq(_, _) => AssertVariant::AssertEq(left, right),
248+
AssertVariant::AssertNeq(_, _) => AssertVariant::AssertNeq(left, right),
249+
_ => unreachable!(),
250+
};
251+
(AssertStatement { variant, ..input }.into(), Default::default())
252+
}
253+
}
254+
}
255+
}
256+
}
257+
153258
/// Modify assignments to tuples to become assignments to the corresponding variables.
154259
///
155260
/// There are two cases we handle:

compiler/passes/src/destructuring/visitor.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ use crate::CompilerState;
1818

1919
use leo_ast::{
2020
AstReconstructor,
21+
BinaryExpression,
22+
BinaryOperation,
2123
DefinitionPlace,
2224
DefinitionStatement,
2325
Expression,
@@ -104,6 +106,34 @@ impl DestructuringVisitor<'_> {
104106
}
105107
}
106108

109+
/// Folds an iterator of expressions into a left-associated chain using `op`.
110+
///
111+
/// Given expressions `[e1, e2, e3]`, this produces `((e1 op e2) op e3)`.
112+
/// Each intermediate node is assigned a fresh ID and recorded as `Boolean`
113+
/// in the type table.
114+
///
115+
/// Panics if the iterator is empty.
116+
pub fn fold_with_op<I>(&mut self, op: BinaryOperation, pieces: I) -> Expression
117+
where
118+
I: Iterator<Item = Expression>,
119+
{
120+
pieces
121+
.reduce(|left, right| {
122+
let expr: Expression = BinaryExpression {
123+
op,
124+
left,
125+
right,
126+
span: Default::default(),
127+
id: self.state.node_builder.next_id(),
128+
}
129+
.into();
130+
131+
self.state.type_table.insert(expr.id(), Type::Boolean);
132+
expr
133+
})
134+
.expect("fold_with_op called with empty iterator")
135+
}
136+
107137
// Given the `expression` of tuple type, make a definition assigning variable to its members.
108138
//
109139
// That is, `let (mem1, mem2, mem3...) = expression;`

compiler/passes/src/option_lowering/ast.rs

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,19 @@ impl leo_ast::AstReconstructor for OptionLoweringVisitor<'_> {
161161
span: Span::default(),
162162
id: self.state.node_builder.next_id(),
163163
};
164+
let Some(Type::Optional(OptionalType { inner })) = self.state.type_table.get(&optional_expr.id())
165+
else {
166+
panic!("guaranteed by type checking");
167+
};
168+
self.state.type_table.insert(val_access.id(), *inner);
164169

165170
let is_some_access = MemberAccess {
166171
inner: reconstructed_optional_expr.clone(),
167172
name: Identifier::new(Symbol::intern("is_some"), self.state.node_builder.next_id()),
168173
span: Span::default(),
169174
id: self.state.node_builder.next_id(),
170175
};
176+
self.state.type_table.insert(is_some_access.id(), Type::Boolean);
171177

172178
// Create assertion: ensure `is_some` is `true`.
173179
let assert_stmt = AssertStatement {
@@ -206,13 +212,15 @@ impl leo_ast::AstReconstructor for OptionLoweringVisitor<'_> {
206212
span: Span::default(),
207213
id: self.state.node_builder.next_id(),
208214
};
215+
self.state.type_table.insert(val_access.id(), *expected_inner_type.clone());
209216

210217
let is_some_access = MemberAccess {
211218
inner: reconstructed_optional_expr,
212219
name: Identifier::new(Symbol::intern("is_some"), self.state.node_builder.next_id()),
213220
span: Span::default(),
214221
id: self.state.node_builder.next_id(),
215222
};
223+
self.state.type_table.insert(is_some_access.id(), Type::Boolean);
216224

217225
// s.is_some ? s.val : fallback
218226
let ternary_expr = TernaryExpression {
@@ -222,6 +230,7 @@ impl leo_ast::AstReconstructor for OptionLoweringVisitor<'_> {
222230
span: Span::default(),
223231
id: self.state.node_builder.next_id(),
224232
};
233+
self.state.type_table.insert(ternary_expr.id(), *expected_inner_type);
225234

226235
stmts1.extend(stmts2);
227236
(ternary_expr.into(), stmts1)
@@ -504,25 +513,28 @@ impl leo_ast::AstReconstructor for OptionLoweringVisitor<'_> {
504513
mut input: TupleExpression,
505514
additional: &Option<Type>,
506515
) -> (Expression, Self::AdditionalOutput) {
507-
let mut all_stmts = Vec::new();
508-
let mut new_elements = Vec::with_capacity(input.elements.len());
509-
510-
// Extract tuple element types if additional type info is Some(Type::Tuple).
516+
// Determine the expected tuple element types
511517
let expected_types = additional
512-
.as_ref()
513-
.and_then(|ty| {
514-
let mut ty = ty.clone();
515-
516-
// Unwrap Optional if any.
518+
.clone()
519+
.or_else(|| self.state.type_table.get(&input.id))
520+
.and_then(|mut ty| {
521+
// Unwrap Optional if any
517522
if let Type::Optional(inner) = ty {
518523
ty = *inner.inner;
519524
}
520-
// Expect Tuple type.
521-
if let Type::Tuple(tuple_ty) = ty { Some(tuple_ty.elements.clone()) } else { None }
525+
526+
// Expect Tuple type
527+
match ty {
528+
Type::Tuple(tuple_ty) => Some(tuple_ty.elements.clone()),
529+
_ => None,
530+
}
522531
})
523532
.expect("guaranteed by type checking");
524533

525-
// Zip elements with expected types and reconstruct with expected type.
534+
let mut all_stmts = Vec::new();
535+
let mut new_elements = Vec::with_capacity(input.elements.len());
536+
537+
// Zip elements with expected types and reconstruct with expected type
526538
for (element, expected_ty) in input.elements.into_iter().zip(expected_types) {
527539
let (expr, mut stmts) = self.reconstruct_expression(element, &Some(expected_ty));
528540
all_stmts.append(&mut stmts);

compiler/passes/src/test_passes.rs

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ This structure ensures that **adding a new compiler pass test is minimal**:
8181
*/
8282

8383
use crate::*;
84-
use leo_ast::{Ast, NetworkName, NodeBuilder};
84+
use leo_ast::NetworkName;
8585
use leo_errors::{BufferEmitter, Handler};
8686
use leo_parser::parse_ast;
8787
use leo_span::{create_session_if_not_set_then, source_map::FileName, with_session_globals};
@@ -110,21 +110,11 @@ macro_rules! compiler_passes {
110110
};
111111
}
112112

113-
/// Parse a Leo source program into an AST, returning errors via the handler.
114-
fn parse_program(source: &str, handler: &Handler) -> Result<Ast, ()> {
115-
let node_builder = NodeBuilder::default();
116-
let filename = FileName::Custom("test".into());
117-
118-
// Add the source to the session's source map
119-
let source_file = with_session_globals(|s| s.source_map.new_source(source, filename));
120-
121-
handler.extend_if_error(parse_ast(handler.clone(), &node_builder, &source_file, &[], NetworkName::TestnetV0))
122-
}
123-
124113
/// Macro to generate a single runner function for a compiler pass.
125114
///
126115
/// Each runner:
127116
/// - Sets up a BufferEmitter and Handler for error/warning reporting.
117+
/// - Parse the test into an AST.
128118
/// - Runs the first three fixed passes: PathResolution, SymbolTableCreation, TypeChecking.
129119
/// - Runs the specified compiler pass.
130120
/// - Returns the resulting AST or formatted errors/warnings.
@@ -135,9 +125,16 @@ macro_rules! make_runner {
135125
let handler = Handler::new(buf.clone());
136126

137127
create_session_if_not_set_then(|_| {
138-
// Parse program into AST
139-
let mut state = match parse_program(source, &handler) {
140-
Ok(ast) => CompilerState { ast, handler: handler.clone(), ..Default::default() },
128+
let mut state = CompilerState { handler: handler.clone(), ..Default::default() };
129+
130+
state.ast = match handler.extend_if_error(parse_ast(
131+
handler.clone(),
132+
&state.node_builder,
133+
&with_session_globals(|s| s.source_map.new_source(source, FileName::Custom("test".into()))),
134+
&[],
135+
NetworkName::TestnetV0,
136+
)) {
137+
Ok(ast) => ast,
141138
Err(()) => return format!("{}{}", buf.extract_errs(), buf.extract_warnings()),
142139
};
143140

0 commit comments

Comments
 (0)