diff --git a/src/boolean/eq.rs b/src/boolean/eq.rs index 8ba664e..74d9a29 100644 --- a/src/boolean/eq.rs +++ b/src/boolean/eq.rs @@ -28,32 +28,38 @@ impl EqGadget for Boolean { // This works because a - b == 0 if and only if a = 0 and b = 0, or a = 1 and b // = 1, which is exactly the definition of a == b. - if condition != &Constant(false) { - let cs = self.cs().or(other.cs()).or(condition.cs()); - match (self, other) { - // 1 == 1; 0 == 0 - (Constant(true), Constant(true)) | (Constant(false), Constant(false)) => { - return Ok(()) - }, - // false != true - (Constant(_), Constant(_)) => return Err(SynthesisError::Unsatisfiable), - // handled below - (_, _) => (), - }; - let difference = || match (self, other) { - // 1 - a - (Constant(true), Var(a)) | (Var(a), Constant(true)) => { - lc_diff![one, a.variable()] - }, - // a - 0 = a - (Constant(false), Var(a)) | (Var(a), Constant(false)) => a.variable().into(), - // b - a, - (Var(a), Var(b)) => lc_diff![b.variable(), a.variable()], - // handled above - (_, _) => unreachable!(), - }; - cs.enforce_r1cs_constraint(difference, || condition.lc(), || lc!())?; + // If condition is false, this is a no-op. + if condition == &Constant(false) { + return Ok(()); } + + let cs = self.cs().or(other.cs()).or(condition.cs()); + match (self, other) { + // 1 == 1; 0 == 0 + (Constant(true), Constant(true)) | (Constant(false), Constant(false)) => return Ok(()), + // Constants unequal: if condition is Constant(true), unsatisfiable; if condition + // is variable, we should enforce (a - b) * condition == 0 which will imply + // condition = 0 in this case. + (Constant(_), Constant(_)) if condition == &Constant(true) => { + return Err(SynthesisError::Unsatisfiable) + }, + _ => (), + }; + + let difference = || match (self, other) { + // 1 - a + (Constant(true), Var(a)) | (Var(a), Constant(true)) => lc_diff![one, a.variable()], + // a - 0 = a + (Constant(false), Var(a)) | (Var(a), Constant(false)) => a.variable().into(), + // b - a + (Var(a), Var(b)) => lc_diff![b.variable(), a.variable()], + // both constants unequal: return ±1 accordingly + (Constant(true), Constant(false)) => one.into(), + (Constant(false), Constant(true)) => lc!() - one, + // equal constants handled earlier + (Constant(false), Constant(false)) | (Constant(true), Constant(true)) => lc!(), + }; + cs.enforce_r1cs_constraint(difference, || condition.lc(), || lc!())?; Ok(()) } @@ -78,7 +84,7 @@ impl EqGadget for Boolean { }, // false == false and true == true (Constant(_), Constant(_)) => return Err(SynthesisError::Unsatisfiable), - (_, _) => (), + (..) => (), } let sum = || match (self, other) { // 1 + a @@ -90,7 +96,7 @@ impl EqGadget for Boolean { // b + a, (Var(a), Var(b)) => lc![b.variable(), a.variable()], // handled above - (_, _) => unreachable!(), + (..) => unreachable!(), }; cs.enforce_r1cs_constraint(sum, || should_enforce.lc(), || one.into())?; } @@ -244,4 +250,58 @@ mod tests { }) .unwrap() } + + #[test] + fn conditional_enforce_equal_const_unequal_cond_var_false() { + use ark_relations::gr1cs::ConstraintSystem; + use ark_test_curves::bls12_381::Fr; + + let cs = ConstraintSystem::::new_ref(); + let a = Boolean::::TRUE; + let b = Boolean::::FALSE; + let cond = Boolean::new_witness(cs.clone(), || Ok(false)).unwrap(); + + a.conditional_enforce_equal(&b, &cond).unwrap(); + assert!(cs.is_satisfied().unwrap()); + } + + #[test] + fn conditional_enforce_equal_const_unequal_cond_var_true_unsat() { + use ark_relations::gr1cs::ConstraintSystem; + use ark_test_curves::bls12_381::Fr; + + let cs = ConstraintSystem::::new_ref(); + let a = Boolean::::TRUE; + let b = Boolean::::FALSE; + let cond = Boolean::new_witness(cs.clone(), || Ok(true)).unwrap(); + + a.conditional_enforce_equal(&b, &cond).unwrap(); + assert!(!cs.is_satisfied().unwrap()); + } + + #[test] + fn conditional_enforce_equal_const_unequal_cond_const_false_noop() { + use ark_relations::gr1cs::ConstraintSystem; + use ark_test_curves::bls12_381::Fr; + + let cs = ConstraintSystem::::new_ref(); + let a = Boolean::::TRUE; + let b = Boolean::::FALSE; + let cond = Boolean::::FALSE; + + a.conditional_enforce_equal(&b, &cond).unwrap(); + assert!(cs.is_satisfied().unwrap()); + } + + #[test] + fn conditional_enforce_equal_const_unequal_cond_const_true_errors() { + use ark_test_curves::bls12_381::Fr; + + let a = Boolean::::TRUE; + let b = Boolean::::FALSE; + let cond = Boolean::::TRUE; + + let err = a.conditional_enforce_equal(&b, &cond).unwrap_err(); + assert!(matches!(err, SynthesisError::Unsatisfiable)); + } }