Skip to content

fix(frontend)!: Restrict capturing mutable variable in lambdas #7488

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions compiler/noirc_frontend/src/elaborator/statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,24 @@ impl Elaborator<'_> {
let (lvalue, lvalue_type, mutable) = self.elaborate_lvalue(assign.lvalue);

if !mutable {
let (name, location) = self.get_lvalue_name_and_location(&lvalue);
let (_, name, location) = self.get_lvalue_error_info(&lvalue);
self.push_err(TypeCheckError::VariableMustBeMutable { name, location });
} else if let Some(lambda_context) = self.lambda_stack.last() {
// We must check whether the mutable variable we are attempting to assign
// comes from a lambda capture. All captures are immutable so we want to error
// if the user attempts to mutate a captured variable inside of a lambda without mutable references.
let capture_ids =
lambda_context.captures.iter().map(|var| var.ident.id).collect::<Vec<_>>();
let (id, name, location) = self.get_lvalue_error_info(&lvalue);
let typ = self.interner.definition_type(id);
for capture_id in capture_ids {
if capture_id == id && !typ.is_mutable_ref() {
self.push_err(TypeCheckError::MutableCaptureWithoutRef {
name: name.clone(),
location,
});
}
}
}

self.unify_with_coercions(&expr_type, &lvalue_type, expression, expr_location, || {
Expand Down Expand Up @@ -331,20 +347,20 @@ impl Elaborator<'_> {
(expr, self.interner.next_type_variable())
}

fn get_lvalue_name_and_location(&self, lvalue: &HirLValue) -> (String, Location) {
fn get_lvalue_error_info(&self, lvalue: &HirLValue) -> (DefinitionId, String, Location) {
match lvalue {
HirLValue::Ident(name, _) => {
let location = name.location;

if let Some(definition) = self.interner.try_definition(name.id) {
(definition.name.clone(), location)
(name.id, definition.name.clone(), location)
} else {
("(undeclared variable)".into(), location)
(DefinitionId::dummy_id(), "(undeclared variable)".into(), location)
}
}
HirLValue::MemberAccess { object, .. } => self.get_lvalue_name_and_location(object),
HirLValue::Index { array, .. } => self.get_lvalue_name_and_location(array),
HirLValue::Dereference { lvalue, .. } => self.get_lvalue_name_and_location(lvalue),
HirLValue::MemberAccess { object, .. } => self.get_lvalue_error_info(object),
HirLValue::Index { array, .. } => self.get_lvalue_error_info(array),
HirLValue::Dereference { lvalue, .. } => self.get_lvalue_error_info(lvalue),
}
}

Expand Down Expand Up @@ -446,8 +462,8 @@ impl Elaborator<'_> {
Type::Slice(elem_type) => *elem_type,
Type::Error => Type::Error,
Type::String(_) => {
let (_lvalue_name, lvalue_location) =
self.get_lvalue_name_and_location(&lvalue);
let (_id, _lvalue_name, lvalue_location) =
self.get_lvalue_error_info(&lvalue);
self.push_err(TypeCheckError::StringIndexAssign {
location: lvalue_location,
});
Expand Down
12 changes: 11 additions & 1 deletion compiler/noirc_frontend/src/hir/type_check/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ pub enum TypeCheckError {
VariableMustBeMutable { name: String, location: Location },
#[error("Cannot mutate immutable variable `{name}`")]
CannotMutateImmutableVariable { name: String, location: Location },
#[error("Variable {name} captured in lambda must be a mutable reference")]
MutableCaptureWithoutRef { name: String, location: Location },
#[error("No method named '{method_name}' found for type '{object_type}'")]
UnresolvedMethodCall { method_name: String, object_type: Type, location: Location },
#[error("Cannot invoke function field '{method_name}' on type '{object_type}' as a method")]
Expand Down Expand Up @@ -321,9 +323,12 @@ impl TypeCheckError {
| TypeCheckError::CyclicType { location, .. }
| TypeCheckError::TypeAnnotationsNeededForIndex { location }
| TypeCheckError::UnnecessaryUnsafeBlock { location }
| TypeCheckError::NestedUnsafeBlock { location } => *location,
| TypeCheckError::NestedUnsafeBlock { location }
| TypeCheckError::MutableCaptureWithoutRef { location, .. } => *location,

TypeCheckError::DuplicateNamedTypeArg { name: ident, .. }
| TypeCheckError::NoSuchNamedTypeArg { name: ident, .. } => ident.location(),

TypeCheckError::NoMatchingImplFound(no_matching_impl_found_error) => {
no_matching_impl_found_error.location
}
Expand Down Expand Up @@ -476,6 +481,11 @@ impl<'a> From<&'a TypeCheckError> for Diagnostic {
| TypeCheckError::InvalidShiftSize { location } => {
Diagnostic::simple_error(error.to_string(), String::new(), *location)
}
TypeCheckError::MutableCaptureWithoutRef { name, location } => Diagnostic::simple_error(
format!("Mutable variable {name} captured in lambda must be a mutable reference"),
"Use '&mut' instead of 'mut' to capture a mutable variable.".to_string(),
*location,
),
TypeCheckError::PublicReturnType { typ, location } => Diagnostic::simple_error(
"Functions cannot declare a public return type".to_string(),
format!("return type is {typ}"),
Expand Down
4 changes: 4 additions & 0 deletions compiler/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1250,6 +1250,10 @@ impl Type {
}
}

pub(crate) fn is_mutable_ref(&self) -> bool {
matches!(self.follow_bindings_shallow().as_ref(), Type::MutableReference(_))
}

/// True if this type can be used as a parameter to `main` or a contract function.
/// This is only false for unsized types like slices or slices that do not make sense
/// as a program input such as named generics or mutable references.
Expand Down
61 changes: 61 additions & 0 deletions compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3888,3 +3888,64 @@ fn subtract_to_int_min() {
let errors = get_program_errors(src);
assert_eq!(errors.len(), 0);
}

#[test]
fn mutate_with_reference_in_lambda() {
let src = r#"
fn main() {
let x = &mut 3;
let f = || {
*x += 2;
};
f();
assert(*x == 5);
}
"#;

assert_no_errors(src);
}

#[test]
fn mutate_with_reference_marked_mutable_in_lambda() {
let src = r#"
fn main() {
let mut x = &mut 3;
let f = || {
*x += 2;
};
f();
assert(*x == 5);
}
"#;
assert_no_errors(src);
}

#[test]
fn deny_capturing_mut_variable_without_reference_in_lambda() {
let src = r#"
fn main() {
let mut x = 3;
let f = || {
x += 2;
^ Mutable variable x captured in lambda must be a mutable reference
~ Use '&mut' instead of 'mut' to capture a mutable variable.
};
f();
assert(x == 5);
}
"#;
check_errors(src);
}

#[test]
fn allow_capturing_mut_variable_only_used_immutably() {
let src = r#"
fn main() {
let mut x = 3;
let f = || x;
let _x2 = f();
assert(x == 3);
}
"#;
assert_no_errors(src);
}
Loading