Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 87 additions & 27 deletions src/boolean/eq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,32 +28,38 @@ impl<F: Field> EqGadget<F> for Boolean<F> {
// This works because a - b == 0 if and only if a = 0 and b = 0, or a = 1 and b
// = 1, which is exactly the definition of a == b.

if condition != &Constant(false) {
let cs = self.cs().or(other.cs()).or(condition.cs());
match (self, other) {
// 1 == 1; 0 == 0
(Constant(true), Constant(true)) | (Constant(false), Constant(false)) => {
return Ok(())
},
// false != true
(Constant(_), Constant(_)) => return Err(SynthesisError::Unsatisfiable),
// handled below
(_, _) => (),
};
let difference = || match (self, other) {
// 1 - a
(Constant(true), Var(a)) | (Var(a), Constant(true)) => {
lc_diff![one, a.variable()]
},
// a - 0 = a
(Constant(false), Var(a)) | (Var(a), Constant(false)) => a.variable().into(),
// b - a,
(Var(a), Var(b)) => lc_diff![b.variable(), a.variable()],
// handled above
(_, _) => unreachable!(),
};
cs.enforce_r1cs_constraint(difference, || condition.lc(), || lc!())?;
// If condition is false, this is a no-op.
if condition == &Constant(false) {
return Ok(());
}

let cs = self.cs().or(other.cs()).or(condition.cs());
match (self, other) {
// 1 == 1; 0 == 0
(Constant(true), Constant(true)) | (Constant(false), Constant(false)) => return Ok(()),
// Constants unequal: if condition is Constant(true), unsatisfiable; if condition
// is variable, we should enforce (a - b) * condition == 0 which will imply
// condition = 0 in this case.
(Constant(_), Constant(_)) if condition == &Constant(true) => {
return Err(SynthesisError::Unsatisfiable)
},
_ => (),
};

let difference = || match (self, other) {
// 1 - a
(Constant(true), Var(a)) | (Var(a), Constant(true)) => lc_diff![one, a.variable()],
// a - 0 = a
(Constant(false), Var(a)) | (Var(a), Constant(false)) => a.variable().into(),
// b - a
(Var(a), Var(b)) => lc_diff![b.variable(), a.variable()],
// both constants unequal: return ±1 accordingly
(Constant(true), Constant(false)) => one.into(),
(Constant(false), Constant(true)) => lc!() - one,
// equal constants handled earlier
(Constant(false), Constant(false)) | (Constant(true), Constant(true)) => lc!(),
};
cs.enforce_r1cs_constraint(difference, || condition.lc(), || lc!())?;
Ok(())
}

Expand All @@ -78,7 +84,7 @@ impl<F: Field> EqGadget<F> for Boolean<F> {
},
// false == false and true == true
(Constant(_), Constant(_)) => return Err(SynthesisError::Unsatisfiable),
(_, _) => (),
(..) => (),
}
let sum = || match (self, other) {
// 1 + a
Expand All @@ -90,7 +96,7 @@ impl<F: Field> EqGadget<F> for Boolean<F> {
// b + a,
(Var(a), Var(b)) => lc![b.variable(), a.variable()],
// handled above
(_, _) => unreachable!(),
(..) => unreachable!(),
};
cs.enforce_r1cs_constraint(sum, || should_enforce.lc(), || one.into())?;
}
Expand Down Expand Up @@ -244,4 +250,58 @@ mod tests {
})
.unwrap()
}

#[test]
fn conditional_enforce_equal_const_unequal_cond_var_false() {
use ark_relations::gr1cs::ConstraintSystem;
use ark_test_curves::bls12_381::Fr;

let cs = ConstraintSystem::<Fr>::new_ref();
let a = Boolean::<Fr>::TRUE;
let b = Boolean::<Fr>::FALSE;
let cond = Boolean::new_witness(cs.clone(), || Ok(false)).unwrap();

a.conditional_enforce_equal(&b, &cond).unwrap();
assert!(cs.is_satisfied().unwrap());
}

#[test]
fn conditional_enforce_equal_const_unequal_cond_var_true_unsat() {
use ark_relations::gr1cs::ConstraintSystem;
use ark_test_curves::bls12_381::Fr;

let cs = ConstraintSystem::<Fr>::new_ref();
let a = Boolean::<Fr>::TRUE;
let b = Boolean::<Fr>::FALSE;
let cond = Boolean::new_witness(cs.clone(), || Ok(true)).unwrap();

a.conditional_enforce_equal(&b, &cond).unwrap();
assert!(!cs.is_satisfied().unwrap());
}

#[test]
fn conditional_enforce_equal_const_unequal_cond_const_false_noop() {
use ark_relations::gr1cs::ConstraintSystem;
use ark_test_curves::bls12_381::Fr;

let cs = ConstraintSystem::<Fr>::new_ref();
let a = Boolean::<Fr>::TRUE;
let b = Boolean::<Fr>::FALSE;
let cond = Boolean::<Fr>::FALSE;

a.conditional_enforce_equal(&b, &cond).unwrap();
assert!(cs.is_satisfied().unwrap());
}

#[test]
fn conditional_enforce_equal_const_unequal_cond_const_true_errors() {
use ark_test_curves::bls12_381::Fr;

let a = Boolean::<Fr>::TRUE;
let b = Boolean::<Fr>::FALSE;
let cond = Boolean::<Fr>::TRUE;

let err = a.conditional_enforce_equal(&b, &cond).unwrap_err();
assert!(matches!(err, SynthesisError::Unsatisfiable));
}
}