Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(frontend)!: Restrict capturing mutable variable in lambdas #7488

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
34 changes: 25 additions & 9 deletions compiler/noirc_frontend/src/elaborator/statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,24 @@ impl Elaborator<'_> {
let (lvalue, lvalue_type, mutable) = self.elaborate_lvalue(assign.lvalue);

if !mutable {
let (name, location) = self.get_lvalue_name_and_location(&lvalue);
let (_, name, location) = self.get_lvalue_error_info(&lvalue);
self.push_err(TypeCheckError::VariableMustBeMutable { name, location });
} else if let Some(lambda_context) = self.lambda_stack.last() {
// We must check whether the mutable variable we are attempting to assign
// comes from a lambda capture. All captures are immutable so we want to error
// if the user attempts to mutate a captured variable inside of a lambda without mutable references.
let capture_ids =
lambda_context.captures.iter().map(|var| var.ident.id).collect::<Vec<_>>();
let (id, name, location) = self.get_lvalue_error_info(&lvalue);
let typ = self.interner.definition_type(id);
for capture_id in capture_ids {
if capture_id == id && !typ.is_mutable_ref() {
self.push_err(TypeCheckError::MutableCaptureWithoutRef {
name: name.clone(),
location,
});
Comment on lines +157 to +169
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking instead of checking on assignment or mutable-ref creation or method calls we could go to the source of where a lambda capture is introduced into a lambda's scope and change that variable to be immutable. Is that possible?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I looked over the lambda capture code some more and we don't create any new definition ids.

So instead of changing the definition (which we can't since it is not a copy and would change the original to be immutable), we should have a single function to verify whether something is mutable and have the check only in that function. I looked across the codebase and found that we actually have already accidentally duplicated this code. &mut x calls check_can_mutate while obj.mutating_method() calls verify_mutable_reference. Assignment uses elaborate_lvalue` which returns a mutable boolean, although that method operates on an LValue rather than an ExprId so it is a bit different.

We should unify these methods into one (probably keeping elaborate_lvalue) and add the mutable capture check on the Ident case for this shared method. elaborate_lvalue should also call a new helper method to check the capture mutability on its Ident case as well.

}
}
Comment on lines +160 to +171
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't try it but I thought of a way to avoid collecting the capture IDs into a Vec:

Suggested change
let capture_ids =
lambda_context.captures.iter().map(|var| var.ident.id).collect::<Vec<_>>();
let (id, name, location) = self.get_lvalue_error_info(&lvalue);
let typ = self.interner.definition_type(id);
for capture_id in capture_ids {
if capture_id == id && !typ.is_mutable_ref() {
self.push_err(TypeCheckError::MutableCaptureWithoutRef {
name: name.clone(),
location,
});
}
}
let (id, name, location) = self.get_lvalue_error_info(&lvalue);
let typ = self.interner.definition_type(id);
if !typ.is_mutable_ref() && lambda_context.captures.iter().any(|var| var.ident.id == id)
{
let name = name.clone();
self.push_err(TypeCheckError::MutableCaptureWithoutRef { name, location });
}

}

self.unify_with_coercions(&expr_type, &lvalue_type, expression, expr_location, || {
Expand Down Expand Up @@ -331,20 +347,20 @@ impl Elaborator<'_> {
(expr, self.interner.next_type_variable())
}

fn get_lvalue_name_and_location(&self, lvalue: &HirLValue) -> (String, Location) {
fn get_lvalue_error_info(&self, lvalue: &HirLValue) -> (DefinitionId, String, Location) {
match lvalue {
HirLValue::Ident(name, _) => {
let location = name.location;

if let Some(definition) = self.interner.try_definition(name.id) {
(definition.name.clone(), location)
(name.id, definition.name.clone(), location)
} else {
("(undeclared variable)".into(), location)
(DefinitionId::dummy_id(), "(undeclared variable)".into(), location)
}
}
HirLValue::MemberAccess { object, .. } => self.get_lvalue_name_and_location(object),
HirLValue::Index { array, .. } => self.get_lvalue_name_and_location(array),
HirLValue::Dereference { lvalue, .. } => self.get_lvalue_name_and_location(lvalue),
HirLValue::MemberAccess { object, .. } => self.get_lvalue_error_info(object),
HirLValue::Index { array, .. } => self.get_lvalue_error_info(array),
HirLValue::Dereference { lvalue, .. } => self.get_lvalue_error_info(lvalue),
}
}

Expand Down Expand Up @@ -446,8 +462,8 @@ impl Elaborator<'_> {
Type::Slice(elem_type) => *elem_type,
Type::Error => Type::Error,
Type::String(_) => {
let (_lvalue_name, lvalue_location) =
self.get_lvalue_name_and_location(&lvalue);
let (_id, _lvalue_name, lvalue_location) =
self.get_lvalue_error_info(&lvalue);
self.push_err(TypeCheckError::StringIndexAssign {
location: lvalue_location,
});
Expand Down
12 changes: 11 additions & 1 deletion compiler/noirc_frontend/src/hir/type_check/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ pub enum TypeCheckError {
VariableMustBeMutable { name: String, location: Location },
#[error("Cannot mutate immutable variable `{name}`")]
CannotMutateImmutableVariable { name: String, location: Location },
#[error("Variable {name} captured in lambda must be a mutable reference")]
MutableCaptureWithoutRef { name: String, location: Location },
#[error("No method named '{method_name}' found for type '{object_type}'")]
UnresolvedMethodCall { method_name: String, object_type: Type, location: Location },
#[error("Cannot invoke function field '{method_name}' on type '{object_type}' as a method")]
Expand Down Expand Up @@ -321,9 +323,12 @@ impl TypeCheckError {
| TypeCheckError::CyclicType { location, .. }
| TypeCheckError::TypeAnnotationsNeededForIndex { location }
| TypeCheckError::UnnecessaryUnsafeBlock { location }
| TypeCheckError::NestedUnsafeBlock { location } => *location,
| TypeCheckError::NestedUnsafeBlock { location }
| TypeCheckError::MutableCaptureWithoutRef { location, .. } => *location,

TypeCheckError::DuplicateNamedTypeArg { name: ident, .. }
| TypeCheckError::NoSuchNamedTypeArg { name: ident, .. } => ident.location(),

TypeCheckError::NoMatchingImplFound(no_matching_impl_found_error) => {
no_matching_impl_found_error.location
}
Expand Down Expand Up @@ -476,6 +481,11 @@ impl<'a> From<&'a TypeCheckError> for Diagnostic {
| TypeCheckError::InvalidShiftSize { location } => {
Diagnostic::simple_error(error.to_string(), String::new(), *location)
}
TypeCheckError::MutableCaptureWithoutRef { name, location } => Diagnostic::simple_error(
format!("Mutable variable {name} captured in lambda must be a mutable reference"),
"Use '&mut' instead of 'mut' to capture a mutable variable.".to_string(),
*location,
),
TypeCheckError::PublicReturnType { typ, location } => Diagnostic::simple_error(
"Functions cannot declare a public return type".to_string(),
format!("return type is {typ}"),
Expand Down
4 changes: 4 additions & 0 deletions compiler/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1250,6 +1250,10 @@ impl Type {
}
}

pub(crate) fn is_mutable_ref(&self) -> bool {
matches!(self.follow_bindings_shallow().as_ref(), Type::MutableReference(_))
}

/// True if this type can be used as a parameter to `main` or a contract function.
/// This is only false for unsized types like slices or slices that do not make sense
/// as a program input such as named generics or mutable references.
Expand Down
61 changes: 61 additions & 0 deletions compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3888,3 +3888,64 @@ fn subtract_to_int_min() {
let errors = get_program_errors(src);
assert_eq!(errors.len(), 0);
}

#[test]
fn mutate_with_reference_in_lambda() {
let src = r#"
fn main() {
let x = &mut 3;
let f = || {
*x += 2;
};
f();
assert(*x == 5);
}
"#;

assert_no_errors(src);
}

#[test]
fn mutate_with_reference_marked_mutable_in_lambda() {
let src = r#"
fn main() {
let mut x = &mut 3;
let f = || {
*x += 2;
};
f();
assert(*x == 5);
}
"#;
assert_no_errors(src);
}

#[test]
fn deny_capturing_mut_variable_without_reference_in_lambda() {
let src = r#"
fn main() {
let mut x = 3;
let f = || {
x += 2;
^ Mutable variable x captured in lambda must be a mutable reference
~ Use '&mut' instead of 'mut' to capture a mutable variable.
};
f();
assert(x == 5);
}
"#;
check_errors(src);
}

#[test]
fn allow_capturing_mut_variable_only_used_immutably() {
let src = r#"
fn main() {
let mut x = 3;
let f = || x;
let _x2 = f();
assert(x == 3);
}
"#;
assert_no_errors(src);
}
Loading