Skip to content

Commit a991eeb

Browse files
authored
Perform bit decomposition as assignment if offset is known. (#2612)
If we detect a bit decomposition pattern like `X + Y * 256 + Z * 65536 = T`, the witgen inference returns a bit decomposition effect. But if the value of `T` is a known number, we can actually extract the concrete values for `X`, `Y` and `Z` at compile-time and continue solving with this additional information.
1 parent 16238c7 commit a991eeb

File tree

3 files changed

+140
-49
lines changed

3 files changed

+140
-49
lines changed

executor/src/witgen/jit/affine_symbolic_expression.rs

+63-36
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::{
66

77
use itertools::Itertools;
88
use num_traits::Zero;
9-
use powdr_number::{log2_exact, FieldElement};
9+
use powdr_number::{log2_exact, FieldElement, LargeInt};
1010

1111
use crate::witgen::jit::effect::Assertion;
1212

@@ -243,66 +243,93 @@ impl<T: FieldElement, V: Ord + Clone + Display> AffineSymbolicExpression<T, V> {
243243
.coefficients
244244
.iter()
245245
.map(|(var, coeff)| {
246-
let c = coeff.try_to_number()?;
246+
let coeff = coeff.try_to_number()?;
247247
let rc = self.range_constraints.get(var)?;
248-
Some((var.clone(), c, rc))
248+
let is_negative = !coeff.is_in_lower_half();
249+
let coeff_abs = if is_negative { -coeff } else { coeff };
250+
// We could work with non-powers of two, but it would require
251+
// division instead of shifts.
252+
let exponent = log2_exact(coeff_abs.to_arbitrary_integer())?;
253+
// We negate here because we are solving
254+
// c_1 * x_1 + c_2 * x_2 + ... + offset = 0,
255+
// instead of
256+
// c_1 * x_1 + c_2 * x_2 + ... = offset.
257+
Some((var.clone(), rc, !is_negative, coeff_abs, exponent))
249258
})
250259
.collect::<Option<Vec<_>>>();
251260
let Some(constrained_coefficients) = constrained_coefficients else {
252261
return Ok(ProcessResult::empty());
253262
};
254263

264+
// If the offset is a known number, we gradually remove the
265+
// components from this number.
266+
let mut offset = self.offset.try_to_number();
267+
let mut concrete_assignments = vec![];
268+
255269
// Check if they are mutually exclusive and compute assignments.
256270
let mut covered_bits: <T as FieldElement>::Integer = 0.into();
257271
let mut components = vec![];
258-
for (variable, coeff, constraint) in constrained_coefficients {
259-
let is_negative = !coeff.is_in_lower_half();
260-
let coeff_abs = if is_negative { -coeff } else { coeff };
261-
let Some(exponent) = log2_exact(coeff_abs.to_arbitrary_integer()) else {
262-
// We could work with non-powers of two, but it would require
263-
// division instead of shifts.
264-
return Ok(ProcessResult::empty());
265-
};
272+
for (variable, constraint, is_negative, coeff_abs, exponent) in constrained_coefficients
273+
.into_iter()
274+
.sorted_by_key(|(_, _, _, _, exponent)| *exponent)
275+
{
266276
let bit_mask = *constraint.multiple(coeff_abs).mask();
267277
if !(bit_mask & covered_bits).is_zero() {
268278
// Overlapping range constraints.
269279
return Ok(ProcessResult::empty());
270280
} else {
271281
covered_bits |= bit_mask;
272282
}
273-
components.push(BitDecompositionComponent {
274-
variable,
275-
// We negate here because we are solving
276-
// c_1 * x_1 + c_2 * x_2 + ... + offset = 0,
277-
// instead of
278-
// c_1 * x_1 + c_2 * x_2 + ... = offset.
279-
is_negative: !is_negative,
280-
exponent: exponent as u64,
281-
bit_mask,
282-
});
283+
284+
// If the offset is a known number, we create concrete assignments and modify the offset.
285+
// if it is not known, we return a BitDecomposition effect.
286+
if let Some(offset) = &mut offset {
287+
let mut component = if is_negative { -*offset } else { *offset }.to_integer();
288+
if component > (T::modulus() - 1.into()) >> 1 {
289+
// Convert a signed finite field element into two's complement.
290+
// a regular subtraction would underflow, so we do this.
291+
// We add the difference between negative numbers in the field
292+
// and negative numbers in two's complement.
293+
component += T::Integer::MAX - T::modulus() + 1.into();
294+
};
295+
component &= bit_mask;
296+
concrete_assignments.push(Effect::Assignment(
297+
variable.clone(),
298+
T::from(component >> exponent).into(),
299+
));
300+
if is_negative {
301+
*offset += T::from(component);
302+
} else {
303+
*offset -= T::from(component);
304+
}
305+
} else {
306+
components.push(BitDecompositionComponent {
307+
variable,
308+
is_negative,
309+
exponent: exponent as u64,
310+
bit_mask,
311+
});
312+
}
283313
}
284314

285315
if covered_bits >= T::modulus() {
286316
return Ok(ProcessResult::empty());
287317
}
288318

289-
if !components.iter().any(|c| c.is_negative) {
290-
// If all coefficients are positive and the offset is known, we can check
291-
// that all bits are covered. If not, then there is no way to extract
292-
// the components and thus we have a conflict.
293-
if let Some(offset) = self.offset.try_to_number() {
294-
if offset.to_integer() & !covered_bits != 0.into() {
295-
return Err(Error::ConflictingRangeConstraints);
296-
}
319+
if let Some(offset) = offset {
320+
if offset != 0.into() {
321+
return Err(Error::ConstraintUnsatisfiable);
297322
}
323+
assert_eq!(concrete_assignments.len(), self.coefficients.len());
324+
Ok(ProcessResult::complete(concrete_assignments))
325+
} else {
326+
Ok(ProcessResult::complete(vec![Effect::BitDecomposition(
327+
BitDecomposition {
328+
value: self.offset.clone(),
329+
components,
330+
},
331+
)]))
298332
}
299-
300-
Ok(ProcessResult::complete(vec![Effect::BitDecomposition(
301-
BitDecomposition {
302-
value: self.offset.clone(),
303-
components,
304-
},
305-
)]))
306333
}
307334

308335
fn transfer_constraints(&self) -> Option<Effect<T, V>> {

executor/src/witgen/jit/block_machine_processor.rs

+71-7
Original file line numberDiff line numberDiff line change
@@ -520,19 +520,19 @@ main_binary::operation_id_next[0] = main_binary::operation_id[1];
520520
call_var(9, 0, 0) = main_binary::operation_id_next[0];
521521
main_binary::operation_id_next[1] = main_binary::operation_id[2];
522522
call_var(9, 1, 0) = main_binary::operation_id_next[1];
523-
2**24 * main_binary::A_byte[2] + 2**0 * main_binary::A[2] := main_binary::A[3];
523+
2**0 * main_binary::A[2] + 2**24 * main_binary::A_byte[2] := main_binary::A[3];
524524
call_var(9, 2, 1) = main_binary::A_byte[2];
525-
2**16 * main_binary::A_byte[1] + 2**0 * main_binary::A[1] := main_binary::A[2];
525+
2**0 * main_binary::A[1] + 2**16 * main_binary::A_byte[1] := main_binary::A[2];
526526
call_var(9, 1, 1) = main_binary::A_byte[1];
527-
2**8 * main_binary::A_byte[0] + 2**0 * main_binary::A[0] := main_binary::A[1];
527+
2**0 * main_binary::A[0] + 2**8 * main_binary::A_byte[0] := main_binary::A[1];
528528
call_var(9, 0, 1) = main_binary::A_byte[0];
529529
main_binary::A_byte[-1] = main_binary::A[0];
530530
call_var(9, -1, 1) = main_binary::A_byte[-1];
531-
2**24 * main_binary::B_byte[2] + 2**0 * main_binary::B[2] := main_binary::B[3];
531+
2**0 * main_binary::B[2] + 2**24 * main_binary::B_byte[2] := main_binary::B[3];
532532
call_var(9, 2, 2) = main_binary::B_byte[2];
533-
2**16 * main_binary::B_byte[1] + 2**0 * main_binary::B[1] := main_binary::B[2];
533+
2**0 * main_binary::B[1] + 2**16 * main_binary::B_byte[1] := main_binary::B[2];
534534
call_var(9, 1, 2) = main_binary::B_byte[1];
535-
2**8 * main_binary::B_byte[0] + 2**0 * main_binary::B[0] := main_binary::B[1];
535+
2**0 * main_binary::B[0] + 2**8 * main_binary::B_byte[0] := main_binary::B[1];
536536
call_var(9, 0, 2) = main_binary::B_byte[0];
537537
main_binary::B_byte[-1] = main_binary::B[0];
538538
call_var(9, -1, 2) = main_binary::B_byte[-1];
@@ -607,7 +607,7 @@ params[1] = Sub::b[0];"
607607
assert_eq!(
608608
format_code(&code),
609609
"SubM::a[0] = params[0];
610-
2**8 * SubM::b[0] + 2**0 * SubM::c[0] := SubM::a[0];
610+
2**0 * SubM::c[0] + 2**8 * SubM::b[0] := SubM::a[0];
611611
params[1] = SubM::b[0];
612612
params[2] = SubM::c[0];
613613
call_var(1, 0, 0) = SubM::c[0];
@@ -816,4 +816,68 @@ S::Z[0] = params[1];
816816
params[2] = S::carry[0];"
817817
);
818818
}
819+
820+
#[test]
821+
fn bit_decomp_negative_concrete() {
822+
let input = "
823+
namespace Main(256);
824+
col witness a, b, c;
825+
[a, b, c] is [S.Y, S.Z, S.carry];
826+
namespace S(256);
827+
let BYTE: col = |i| i & 0xff;
828+
let X;
829+
let Y;
830+
let Z;
831+
Y = 19;
832+
Z = 16;
833+
let carry;
834+
carry * (1 - carry) = 0;
835+
[ X ] in [ BYTE ];
836+
[ Y ] in [ BYTE ];
837+
[ Z ] in [ BYTE ];
838+
X + Y = Z + 256 * carry;
839+
";
840+
let code = format_code(&generate_for_block_machine(input, "S", 2, 1).unwrap().code);
841+
assert_eq!(
842+
code,
843+
"\
844+
S::Y[0] = params[0];
845+
S::Z[0] = params[1];
846+
S::X[0] = 253;
847+
S::carry[0] = 1;
848+
params[2] = 1;"
849+
);
850+
}
851+
852+
#[test]
853+
fn bit_decomp_negative_concrete_2() {
854+
let input = "
855+
namespace Main(256);
856+
col witness a, b, c;
857+
[a, b, c] is [S.Y, S.Z, S.carry];
858+
namespace S(256);
859+
let BYTE: col = |i| i & 0xff;
860+
let X;
861+
let Y;
862+
let Z;
863+
Y = 1;
864+
Z = 16;
865+
let carry;
866+
carry * (1 - carry) = 0;
867+
[ X ] in [ BYTE ];
868+
[ Y ] in [ BYTE ];
869+
[ Z ] in [ BYTE ];
870+
X + Y = Z + 256 * carry;
871+
";
872+
let code = format_code(&generate_for_block_machine(input, "S", 2, 1).unwrap().code);
873+
assert_eq!(
874+
code,
875+
"\
876+
S::Y[0] = params[0];
877+
S::Z[0] = params[1];
878+
S::X[0] = 15;
879+
S::carry[0] = 0;
880+
params[2] = 0;"
881+
);
882+
}
819883
}

executor/src/witgen/jit/witgen_inference.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -835,15 +835,15 @@ namespace Xor(256 * 256);
835835
assert_eq!(
836836
code,
837837
"\
838-
2**24 * Xor::A_byte[6] + 2**0 * Xor::A[6] := Xor::A[7];
839-
2**24 * Xor::C_byte[6] + 2**0 * Xor::C[6] := Xor::C[7];
840-
2**16 * Xor::A_byte[5] + 2**0 * Xor::A[5] := Xor::A[6];
841-
2**16 * Xor::C_byte[5] + 2**0 * Xor::C[5] := Xor::C[6];
838+
2**0 * Xor::A[6] + 2**24 * Xor::A_byte[6] := Xor::A[7];
839+
2**0 * Xor::C[6] + 2**24 * Xor::C_byte[6] := Xor::C[7];
840+
2**0 * Xor::A[5] + 2**16 * Xor::A_byte[5] := Xor::A[6];
841+
2**0 * Xor::C[5] + 2**16 * Xor::C_byte[5] := Xor::C[6];
842842
call_var(0, 6, 0) = Xor::A_byte[6];
843843
call_var(0, 6, 2) = Xor::C_byte[6];
844844
machine_call(1, [Known(call_var(0, 6, 0)), Unknown(call_var(0, 6, 1)), Known(call_var(0, 6, 2))]);
845-
2**8 * Xor::A_byte[4] + 2**0 * Xor::A[4] := Xor::A[5];
846-
2**8 * Xor::C_byte[4] + 2**0 * Xor::C[4] := Xor::C[5];
845+
2**0 * Xor::A[4] + 2**8 * Xor::A_byte[4] := Xor::A[5];
846+
2**0 * Xor::C[4] + 2**8 * Xor::C_byte[4] := Xor::C[5];
847847
call_var(0, 5, 0) = Xor::A_byte[5];
848848
call_var(0, 5, 2) = Xor::C_byte[5];
849849
machine_call(1, [Known(call_var(0, 5, 0)), Unknown(call_var(0, 5, 1)), Known(call_var(0, 5, 2))]);

0 commit comments

Comments
 (0)