Skip to content

Commit b0e1584

Browse files
authored
Merge e4b4f14 into 677c10c
2 parents 677c10c + e4b4f14 commit b0e1584

File tree

34 files changed

+467
-30
lines changed

34 files changed

+467
-30
lines changed
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
use noirc_errors::Location;
2+
3+
use crate::ast::{
4+
BlockExpression, CallExpression, Expression, ExpressionKind, Ident, IfExpression, Param, Path,
5+
PathKind, PathSegment, Pattern, Statement, StatementKind, UnaryOp,
6+
};
7+
8+
use super::Elaborator;
9+
10+
impl<'context> Elaborator<'context> {
11+
/// Adds statements to `statements` to validate the given parameters.
12+
///
13+
/// For example, this function:
14+
///
15+
/// fn main(x: u8, y: u8) {
16+
/// assert_eq(x, y);
17+
/// }
18+
///
19+
/// is transformed into this one:
20+
///
21+
/// fn main(x: u8, y: u8) {
22+
/// if !std::runtime::is_unconstrained() {
23+
/// std::validation::AssertsIsValidInput::assert_is_valid_input(x);
24+
/// std::validation::AssertsIsValidInput::assert_is_valid_input(y);
25+
/// }
26+
/// assert_eq(x, y);
27+
/// }
28+
pub(super) fn add_entry_point_parameters_validation(
29+
&self,
30+
params: &[Param],
31+
statements: &mut Vec<Statement>,
32+
) {
33+
if params.is_empty() {
34+
return;
35+
}
36+
37+
let location = params[0].location;
38+
39+
let mut consequence_statements = Vec::with_capacity(params.len());
40+
for param in params {
41+
self.add_entry_point_pattern_validation(&param.pattern, &mut consequence_statements);
42+
}
43+
44+
let consequence = BlockExpression { statements: consequence_statements };
45+
let consequence = ExpressionKind::Block(consequence);
46+
let consequence = Expression::new(consequence, location);
47+
48+
let func = path(&["std", "runtime", "is_unconstrained"], location);
49+
let func = var(func);
50+
let not = not(call(func, Vec::new()));
51+
let if_ = if_then(not, consequence);
52+
let statement = Statement { kind: StatementKind::Expression(if_), location };
53+
statements.insert(0, statement);
54+
}
55+
56+
fn add_entry_point_pattern_validation(
57+
&self,
58+
pattern: &Pattern,
59+
statements: &mut Vec<Statement>,
60+
) {
61+
match pattern {
62+
Pattern::Identifier(ident) => {
63+
if ident.0.contents == "_" {
64+
return;
65+
}
66+
67+
let location = ident.location();
68+
let segments =
69+
["std", "validation", "AssertsIsValidInput", "assert_is_valid_input"];
70+
let func = path(&segments, location);
71+
let func = var(func);
72+
let argument = var(Path::from_ident(ident.clone()));
73+
let call = call(func, vec![argument]);
74+
let call = Statement { kind: StatementKind::Semi(call), location };
75+
statements.push(call);
76+
}
77+
Pattern::Mutable(pattern, ..) => {
78+
self.add_entry_point_pattern_validation(pattern, statements);
79+
}
80+
Pattern::Tuple(patterns, ..) => {
81+
for pattern in patterns {
82+
self.add_entry_point_pattern_validation(pattern, statements);
83+
}
84+
}
85+
Pattern::Struct(..) => todo!("add_entry_point_pattern_validation for Struct pattern"),
86+
Pattern::Interned(interned_pattern, ..) => {
87+
let pattern = self.interner.get_pattern(*interned_pattern);
88+
self.add_entry_point_pattern_validation(pattern, statements);
89+
}
90+
}
91+
}
92+
}
93+
94+
fn path(segments: &[&str], location: Location) -> Path {
95+
let segments = segments.iter().map(|segment| PathSegment {
96+
ident: Ident::new(segment.to_string(), location),
97+
generics: None,
98+
location,
99+
});
100+
Path { segments: segments.collect(), kind: PathKind::Plain, location }
101+
}
102+
103+
fn var(path: Path) -> Expression {
104+
let location = path.location;
105+
let var = ExpressionKind::Variable(path);
106+
Expression::new(var, location)
107+
}
108+
109+
fn call(func: Expression, arguments: Vec<Expression>) -> Expression {
110+
let location = func.location;
111+
let call = CallExpression { func: Box::new(func), arguments, is_macro_call: false };
112+
Expression::new(ExpressionKind::Call(Box::new(call)), location)
113+
}
114+
115+
fn not(rhs: Expression) -> Expression {
116+
let location = rhs.location;
117+
let not = ExpressionKind::prefix(UnaryOp::Not, rhs);
118+
Expression::new(not, location)
119+
}
120+
121+
fn if_then(condition: Expression, consequence: Expression) -> Expression {
122+
let location = condition.location;
123+
let if_ =
124+
ExpressionKind::If(Box::new(IfExpression { condition, consequence, alternative: None }));
125+
Expression::new(if_, location)
126+
}

compiler/noirc_frontend/src/elaborator/mod.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ use crate::{
5252
mod comptime;
5353
mod enums;
5454
mod expressions;
55+
mod input_validations;
5556
mod lints;
5657
mod options;
5758
mod path_resolution;
@@ -1018,7 +1019,12 @@ impl<'context> Elaborator<'context> {
10181019
.filter_map(|generic| self.find_generic(&generic.ident().0.contents).cloned())
10191020
.collect();
10201021

1021-
let statements = std::mem::take(&mut func.def.body.statements);
1022+
let mut statements = std::mem::take(&mut func.def.body.statements);
1023+
1024+
if is_entry_point && self.crate_graph.try_stdlib_crate_id().is_some() {
1025+
self.add_entry_point_parameters_validation(func.parameters(), &mut statements);
1026+
}
1027+
10221028
let body = BlockExpression { statements };
10231029

10241030
let struct_id = if let Some(Type::DataType(struct_type, _)) = &self.self_type {
@@ -1060,7 +1066,6 @@ impl<'context> Elaborator<'context> {
10601066
self.scopes.end_function();
10611067
self.current_item = None;
10621068
}
1063-
10641069
fn mark_type_as_used(&mut self, typ: &Type) {
10651070
match typ {
10661071
Type::Array(_n, typ) => self.mark_type_as_used(typ),

compiler/noirc_frontend/src/graph/mod.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,10 +182,11 @@ impl CrateGraph {
182182
}
183183

184184
pub fn stdlib_crate_id(&self) -> &CrateId {
185-
self.arena
186-
.keys()
187-
.find(|crate_id| crate_id.is_stdlib())
188-
.expect("ICE: The stdlib should exist in the CrateGraph")
185+
self.try_stdlib_crate_id().expect("ICE: The stdlib should exist in the CrateGraph")
186+
}
187+
188+
pub fn try_stdlib_crate_id(&self) -> Option<&CrateId> {
189+
self.arena.keys().find(|crate_id| crate_id.is_stdlib())
189190
}
190191

191192
pub fn add_crate_root(&mut self, file_id: FileId) -> CrateId {

noir_stdlib/src/collections/bounded_vec.nr

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use crate::{cmp::Eq, convert::From, runtime::is_unconstrained, static_assert};
1+
use crate::{
2+
cmp::Eq, convert::From, runtime::is_unconstrained, static_assert,
3+
validation::AssertsIsValidInput,
4+
};
25

36
/// A `BoundedVec<T, MaxLen>` is a growable storage similar to a `Vec<T>` except that it
47
/// is bounded with a maximum possible length. Unlike `Vec`, `BoundedVec` is not implemented
@@ -526,6 +529,16 @@ impl<T, let MaxLen: u32, let Len: u32> From<[T; Len]> for BoundedVec<T, MaxLen>
526529
}
527530
}
528531

532+
impl<T, let MaxLen: u32> AssertsIsValidInput for BoundedVec<T, MaxLen>
533+
where
534+
T: AssertsIsValidInput,
535+
{
536+
fn assert_is_valid_input(self) {
537+
assert(self.len <= MaxLen);
538+
self.storage.assert_is_valid_input();
539+
}
540+
}
541+
529542
mod bounded_vec_tests {
530543

531544
mod get {

noir_stdlib/src/embedded_curve_ops.nr

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::cmp::Eq;
22
use crate::ops::arith::{Add, Neg, Sub};
3+
use super::validation::AssertsIsValidInput;
34

45
/// A point on the embedded elliptic curve
56
/// By definition, the base field of the embedded curve is the scalar field of the proof system curve, i.e the Noir Field.
@@ -53,6 +54,14 @@ impl Eq for EmbeddedCurvePoint {
5354
}
5455
}
5556

57+
impl AssertsIsValidInput for EmbeddedCurvePoint {
58+
fn assert_is_valid_input(self) {
59+
self.x.assert_is_valid_input();
60+
self.y.assert_is_valid_input();
61+
self.is_infinite.assert_is_valid_input();
62+
}
63+
}
64+
5665
/// Scalar for the embedded curve represented as low and high limbs
5766
/// By definition, the scalar field of the embedded curve is base field of the proving system curve.
5867
/// It may not fit into a Field element, so it is represented with two Field elements; its low and high limbs.

noir_stdlib/src/lib.nr

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ pub mod append;
2626
pub mod mem;
2727
pub mod panic;
2828
pub mod hint;
29+
pub mod validation;
2930

3031
use convert::AsPrimitive;
3132

noir_stdlib/src/uint128.nr

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::cmp::{Eq, Ord, Ordering};
22
use crate::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Not, Rem, Shl, Shr, Sub};
33
use crate::static_assert;
4-
use super::{convert::AsPrimitive, default::Default};
4+
use super::{convert::AsPrimitive, default::Default, validation::AssertsIsValidInput};
55

66
global pow64: Field = 18446744073709551616; //2^64;
77
global pow63: Field = 9223372036854775808; // 2^63;
@@ -329,6 +329,13 @@ impl Default for U128 {
329329
}
330330
}
331331

332+
impl AssertsIsValidInput for U128 {
333+
fn assert_is_valid_input(self) {
334+
self.hi.assert_is_valid_input();
335+
self.lo.assert_is_valid_input();
336+
}
337+
}
338+
332339
mod tests {
333340
use crate::default::Default;
334341
use crate::ops::Not;

0 commit comments

Comments
 (0)