Skip to content

Commit

Permalink
Merge e4b4f14 into 677c10c
Browse files Browse the repository at this point in the history
  • Loading branch information
asterite authored Feb 25, 2025
2 parents 677c10c + e4b4f14 commit b0e1584
Show file tree
Hide file tree
Showing 34 changed files with 467 additions and 30 deletions.
126 changes: 126 additions & 0 deletions compiler/noirc_frontend/src/elaborator/input_validations.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
use noirc_errors::Location;

use crate::ast::{
BlockExpression, CallExpression, Expression, ExpressionKind, Ident, IfExpression, Param, Path,
PathKind, PathSegment, Pattern, Statement, StatementKind, UnaryOp,
};

use super::Elaborator;

impl<'context> Elaborator<'context> {
/// Adds statements to `statements` to validate the given parameters.
///
/// For example, this function:
///
/// fn main(x: u8, y: u8) {
/// assert_eq(x, y);
/// }
///
/// is transformed into this one:
///
/// fn main(x: u8, y: u8) {
/// if !std::runtime::is_unconstrained() {
/// std::validation::AssertsIsValidInput::assert_is_valid_input(x);
/// std::validation::AssertsIsValidInput::assert_is_valid_input(y);
/// }
/// assert_eq(x, y);
/// }
pub(super) fn add_entry_point_parameters_validation(
&self,
params: &[Param],
statements: &mut Vec<Statement>,
) {
if params.is_empty() {
return;
}

let location = params[0].location;

let mut consequence_statements = Vec::with_capacity(params.len());
for param in params {
self.add_entry_point_pattern_validation(&param.pattern, &mut consequence_statements);
}

let consequence = BlockExpression { statements: consequence_statements };
let consequence = ExpressionKind::Block(consequence);
let consequence = Expression::new(consequence, location);

let func = path(&["std", "runtime", "is_unconstrained"], location);
let func = var(func);
let not = not(call(func, Vec::new()));
let if_ = if_then(not, consequence);
let statement = Statement { kind: StatementKind::Expression(if_), location };
statements.insert(0, statement);
}

fn add_entry_point_pattern_validation(
&self,
pattern: &Pattern,
statements: &mut Vec<Statement>,
) {
match pattern {
Pattern::Identifier(ident) => {
if ident.0.contents == "_" {
return;
}

let location = ident.location();
let segments =
["std", "validation", "AssertsIsValidInput", "assert_is_valid_input"];
let func = path(&segments, location);
let func = var(func);
let argument = var(Path::from_ident(ident.clone()));
let call = call(func, vec![argument]);
let call = Statement { kind: StatementKind::Semi(call), location };
statements.push(call);
}
Pattern::Mutable(pattern, ..) => {
self.add_entry_point_pattern_validation(pattern, statements);
}
Pattern::Tuple(patterns, ..) => {
for pattern in patterns {
self.add_entry_point_pattern_validation(pattern, statements);
}
}
Pattern::Struct(..) => todo!("add_entry_point_pattern_validation for Struct pattern"),
Pattern::Interned(interned_pattern, ..) => {
let pattern = self.interner.get_pattern(*interned_pattern);
self.add_entry_point_pattern_validation(pattern, statements);
}
}
}
}

fn path(segments: &[&str], location: Location) -> Path {
let segments = segments.iter().map(|segment| PathSegment {
ident: Ident::new(segment.to_string(), location),
generics: None,
location,
});
Path { segments: segments.collect(), kind: PathKind::Plain, location }
}

fn var(path: Path) -> Expression {
let location = path.location;
let var = ExpressionKind::Variable(path);
Expression::new(var, location)
}

fn call(func: Expression, arguments: Vec<Expression>) -> Expression {
let location = func.location;
let call = CallExpression { func: Box::new(func), arguments, is_macro_call: false };
Expression::new(ExpressionKind::Call(Box::new(call)), location)
}

fn not(rhs: Expression) -> Expression {
let location = rhs.location;
let not = ExpressionKind::prefix(UnaryOp::Not, rhs);
Expression::new(not, location)
}

fn if_then(condition: Expression, consequence: Expression) -> Expression {
let location = condition.location;
let if_ =
ExpressionKind::If(Box::new(IfExpression { condition, consequence, alternative: None }));
Expression::new(if_, location)
}
9 changes: 7 additions & 2 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ use crate::{
mod comptime;
mod enums;
mod expressions;
mod input_validations;
mod lints;
mod options;
mod path_resolution;
Expand Down Expand Up @@ -1018,7 +1019,12 @@ impl<'context> Elaborator<'context> {
.filter_map(|generic| self.find_generic(&generic.ident().0.contents).cloned())
.collect();

let statements = std::mem::take(&mut func.def.body.statements);
let mut statements = std::mem::take(&mut func.def.body.statements);

if is_entry_point && self.crate_graph.try_stdlib_crate_id().is_some() {
self.add_entry_point_parameters_validation(func.parameters(), &mut statements);
}

let body = BlockExpression { statements };

let struct_id = if let Some(Type::DataType(struct_type, _)) = &self.self_type {
Expand Down Expand Up @@ -1060,7 +1066,6 @@ impl<'context> Elaborator<'context> {
self.scopes.end_function();
self.current_item = None;
}

fn mark_type_as_used(&mut self, typ: &Type) {
match typ {
Type::Array(_n, typ) => self.mark_type_as_used(typ),
Expand Down
9 changes: 5 additions & 4 deletions compiler/noirc_frontend/src/graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,11 @@ impl CrateGraph {
}

pub fn stdlib_crate_id(&self) -> &CrateId {
self.arena
.keys()
.find(|crate_id| crate_id.is_stdlib())
.expect("ICE: The stdlib should exist in the CrateGraph")
self.try_stdlib_crate_id().expect("ICE: The stdlib should exist in the CrateGraph")
}

pub fn try_stdlib_crate_id(&self) -> Option<&CrateId> {
self.arena.keys().find(|crate_id| crate_id.is_stdlib())
}

pub fn add_crate_root(&mut self, file_id: FileId) -> CrateId {
Expand Down
15 changes: 14 additions & 1 deletion noir_stdlib/src/collections/bounded_vec.nr
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::{cmp::Eq, convert::From, runtime::is_unconstrained, static_assert};
use crate::{
cmp::Eq, convert::From, runtime::is_unconstrained, static_assert,
validation::AssertsIsValidInput,
};

/// A `BoundedVec<T, MaxLen>` is a growable storage similar to a `Vec<T>` except that it
/// is bounded with a maximum possible length. Unlike `Vec`, `BoundedVec` is not implemented
Expand Down Expand Up @@ -526,6 +529,16 @@ impl<T, let MaxLen: u32, let Len: u32> From<[T; Len]> for BoundedVec<T, MaxLen>
}
}

impl<T, let MaxLen: u32> AssertsIsValidInput for BoundedVec<T, MaxLen>
where
T: AssertsIsValidInput,
{
fn assert_is_valid_input(self) {
assert(self.len <= MaxLen);
self.storage.assert_is_valid_input();
}
}

mod bounded_vec_tests {

mod get {
Expand Down
9 changes: 9 additions & 0 deletions noir_stdlib/src/embedded_curve_ops.nr
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::cmp::Eq;
use crate::ops::arith::{Add, Neg, Sub};
use super::validation::AssertsIsValidInput;

/// A point on the embedded elliptic curve
/// By definition, the base field of the embedded curve is the scalar field of the proof system curve, i.e the Noir Field.
Expand Down Expand Up @@ -53,6 +54,14 @@ impl Eq for EmbeddedCurvePoint {
}
}

impl AssertsIsValidInput for EmbeddedCurvePoint {
fn assert_is_valid_input(self) {
self.x.assert_is_valid_input();
self.y.assert_is_valid_input();
self.is_infinite.assert_is_valid_input();
}
}

/// Scalar for the embedded curve represented as low and high limbs
/// By definition, the scalar field of the embedded curve is base field of the proving system curve.
/// It may not fit into a Field element, so it is represented with two Field elements; its low and high limbs.
Expand Down
1 change: 1 addition & 0 deletions noir_stdlib/src/lib.nr
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub mod append;
pub mod mem;
pub mod panic;
pub mod hint;
pub mod validation;

use convert::AsPrimitive;

Expand Down
9 changes: 8 additions & 1 deletion noir_stdlib/src/uint128.nr
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::cmp::{Eq, Ord, Ordering};
use crate::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Not, Rem, Shl, Shr, Sub};
use crate::static_assert;
use super::{convert::AsPrimitive, default::Default};
use super::{convert::AsPrimitive, default::Default, validation::AssertsIsValidInput};

global pow64: Field = 18446744073709551616; //2^64;
global pow63: Field = 9223372036854775808; // 2^63;
Expand Down Expand Up @@ -329,6 +329,13 @@ impl Default for U128 {
}
}

impl AssertsIsValidInput for U128 {
fn assert_is_valid_input(self) {
self.hi.assert_is_valid_input();
self.lo.assert_is_valid_input();
}
}

mod tests {
use crate::default::Default;
use crate::ops::Not;
Expand Down
Loading

0 comments on commit b0e1584

Please sign in to comment.