Skip to content

Commit 89aa6ad

Browse files
new expression api (#15)
* new expression api * chores: clippy * fmt --------- Co-authored-by: kunxian xia <[email protected]>
1 parent 5f6c787 commit 89aa6ad

File tree

2 files changed

+73
-29
lines changed

2 files changed

+73
-29
lines changed

crates/multilinear_extensions/src/expression.rs

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,6 +1068,7 @@ pub struct WitIn {
10681068
)]
10691069
#[repr(C)]
10701070
pub enum StructuralWitInType {
1071+
Empty,
10711072
/// The correspeonding evaluation vector is the sequence: M = M' * multi_factor * descending + offset
10721073
/// where M' = [0, 1, 2, ..., max_len - 1] and descending = if descending { -1 } else { 1 }
10731074
EqualDistanceSequence {
@@ -1078,23 +1079,34 @@ pub enum StructuralWitInType {
10781079
},
10791080
/// The corresponding evaluation vector is the sequence: [0, 0, 0, 1, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7, ..., 0, 1, 2, 3, ..., 2^max_bits-1]
10801081
/// The length of the vectors is 2^(max_bits + 1)
1081-
StackedIncrementalSequence { max_bits: usize },
1082+
StackedIncrementalSequence {
1083+
max_bits: usize,
1084+
},
10821085
/// The corresponding evaluation vector is the sequence: [0, 0] + [1, 1] + [2] * 4 + [3] * 8 + ... + [max_value] * (2^max_value)
10831086
/// The length of the vectors is 2^(max_value + 1)
1084-
StackedConstantSequence { max_value: usize },
1087+
StackedConstantSequence {
1088+
max_value: usize,
1089+
},
10851090
/// The corresponding evaluation vector is the sequence: [0, ..., 0, 1, ..., 1, ..., 2^(n-k-1)-1, ..., 2^(n-k-1)-1]
10861091
/// where each element is repeated by 2^k times
10871092
/// The total length of the vector is 2^n
1088-
InnerRepeatingIncrementalSequence { k: usize, n: usize },
1093+
InnerRepeatingIncrementalSequence {
1094+
k: usize,
1095+
n: usize,
1096+
},
10891097
/// The corresponding evaluation vector is the sequence: [0, ..., 2^k-1]
10901098
/// repeated by 2^(n-k) times
10911099
/// The total length of the vector is 2^n
1092-
OuterRepeatingIncrementalSequence { k: usize, n: usize },
1100+
OuterRepeatingIncrementalSequence {
1101+
k: usize,
1102+
n: usize,
1103+
},
10931104
}
10941105

10951106
impl StructuralWitInType {
10961107
pub fn max_len(&self) -> usize {
10971108
match self {
1109+
StructuralWitInType::Empty => 0,
10981110
StructuralWitInType::EqualDistanceSequence { max_len, .. } => *max_len,
10991111
StructuralWitInType::StackedIncrementalSequence { max_bits } => 1 << (max_bits + 1),
11001112
StructuralWitInType::StackedConstantSequence { max_value } => 1 << (max_value + 1),
@@ -1224,25 +1236,14 @@ fn eval_expr_at_index<E: ExtensionField>(
12241236
pub fn wit_infer_by_monomial_expr<'a, E: ExtensionField>(
12251237
flat_expr: &[Term<Expression<E>, Expression<E>>],
12261238
witness: &[ArcMultilinearExtension<'a, E>],
1227-
instance: &[ArcMultilinearExtension<'a, E>],
1239+
pub_io_evals: &[Either<E::BaseField, E>],
12281240
challenges: &[E],
12291241
) -> ArcMultilinearExtension<'a, E> {
12301242
let eval_leng = witness[0].evaluations().len();
12311243

1232-
let witness = chain!(witness, instance).cloned().collect_vec();
1233-
1234-
// evaluate all scalar terms first
1235-
// when instance was access in scalar, we only take its first item
1236-
// this operation is sound
1237-
let instance_first_element = instance
1238-
.iter()
1239-
.map(|instance| instance.evaluations.index(0))
1240-
.collect_vec();
12411244
let scalar_evals = flat_expr
12421245
.par_iter()
1243-
.map(|Term { scalar, .. }| {
1244-
eval_by_expr_constant(&instance_first_element, challenges, scalar)
1245-
})
1246+
.map(|Term { scalar, .. }| eval_by_expr_constant(pub_io_evals, challenges, scalar))
12461247
.collect::<Vec<_>>();
12471248

12481249
let evaluations = (0..eval_leng)
@@ -1257,7 +1258,7 @@ pub fn wit_infer_by_monomial_expr<'a, E: ExtensionField>(
12571258
product
12581259
.iter()
12591260
.fold(Either::Left(E::BaseField::ONE), |acc, e| {
1260-
let v = eval_expr_at_index(e, i, &witness, challenges);
1261+
let v = eval_expr_at_index(e, i, witness, challenges);
12611262
combine_cumulative_either!(v, acc, |v, acc| v * acc)
12621263
});
12631264

@@ -1282,21 +1283,22 @@ pub fn wit_infer_by_monomial_expr<'a, E: ExtensionField>(
12821283
pub fn wit_infer_by_expr<'a, E: ExtensionField>(
12831284
expr: &Expression<E>,
12841285
n_witin: WitnessId,
1285-
n_structural_witin: WitnessId,
12861286
n_fixed: WitnessId,
1287+
n_instance: usize,
12871288
fixed: &[ArcMultilinearExtension<'a, E>],
12881289
witnesses: &[ArcMultilinearExtension<'a, E>],
12891290
structual_witnesses: &[ArcMultilinearExtension<'a, E>],
1290-
instance: &[ArcMultilinearExtension<'a, E>],
1291+
pub_io_mles: &[ArcMultilinearExtension<'a, E>],
1292+
pub_io_evals: &[Either<E::BaseField, E>],
12911293
challenges: &[E],
12921294
) -> ArcMultilinearExtension<'a, E> {
1293-
let witin = chain!(witnesses, structual_witnesses, fixed)
1295+
let witin = chain!(witnesses, fixed, pub_io_mles, structual_witnesses)
12941296
.cloned()
12951297
.collect_vec();
12961298
wit_infer_by_monomial_expr(
1297-
&monomialize_expr_to_wit_terms(expr, n_witin, n_structural_witin, n_fixed),
1299+
&monomialize_expr_to_wit_terms(expr, n_witin, n_fixed, n_instance),
12981300
&witin,
1299-
instance,
1301+
pub_io_evals,
13001302
challenges,
13011303
)
13021304
}
@@ -1700,6 +1702,7 @@ mod tests {
17001702
],
17011703
&[],
17021704
&[],
1705+
&[],
17031706
&[E::ONE],
17041707
);
17051708
res.get_ext_field_vec();

crates/multilinear_extensions/src/expression/utils.rs

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,13 @@ pub fn eval_by_expr_with_instance<E: ExtensionField>(
115115
pub fn monomialize_expr_to_wit_terms<E: ExtensionField>(
116116
expr: &Expression<E>,
117117
num_witin: WitnessId,
118-
num_structural_witin: WitnessId,
119118
num_fixed: WitnessId,
119+
num_instance: usize,
120120
) -> Vec<Term<Expression<E>, Expression<E>>> {
121121
let witid_offset = 0 as WitnessId;
122-
let structural_witin_offset = witid_offset + num_witin;
123-
let fixed_offset = structural_witin_offset + num_structural_witin;
122+
let fixed_offset = witid_offset + num_witin;
124123
let instance_offset = fixed_offset + num_fixed;
124+
let structural_witin_offset = instance_offset + num_instance as WitnessId;
125125

126126
let monomial_terms_expr = expr.get_monomial_terms();
127127
monomial_terms_expr
@@ -133,19 +133,60 @@ pub fn monomialize_expr_to_wit_terms<E: ExtensionField>(
133133
}| {
134134
product.iter_mut().for_each(|t| match t {
135135
Expression::WitIn(_) => (),
136-
Expression::StructuralWitIn(structural_wit_id, _) => {
137-
*t = Expression::WitIn(structural_witin_offset + *structural_wit_id);
138-
}
139136
Expression::Fixed(Fixed(fixed_id)) => {
140137
*t = Expression::WitIn(fixed_offset + (*fixed_id as u16));
141138
}
142139
Expression::Instance(Instance(instance_id)) => {
143140
*t = Expression::WitIn(instance_offset + (*instance_id as u16));
144141
}
142+
Expression::StructuralWitIn(structural_wit_id, _) => {
143+
*t = Expression::WitIn(structural_witin_offset + *structural_wit_id);
144+
}
145145
e => panic!("unknown monomial terms {:?}", e),
146146
});
147147
Term { scalar, product }
148148
},
149149
)
150150
.collect_vec()
151151
}
152+
153+
/// convert complex expression into monomial form to WitIn
154+
/// orders WitIn ++ StructuralWitIn ++ Fixed
155+
pub fn expr_convert_to_witins<E: ExtensionField>(
156+
expr: &mut Expression<E>,
157+
num_witin: WitnessId,
158+
num_fixed: WitnessId,
159+
num_instance: usize,
160+
) {
161+
let witid_offset = 0 as WitnessId;
162+
let fixed_offset = witid_offset + num_witin;
163+
let instance_offset = fixed_offset + num_fixed;
164+
let structural_witin_offset = instance_offset + num_instance as WitnessId;
165+
166+
match expr {
167+
Expression::Fixed(fixed_id) => {
168+
*expr = Expression::WitIn(fixed_offset + (fixed_id.0 as u16))
169+
}
170+
Expression::WitIn(..) => (),
171+
Expression::StructuralWitIn(structural_wit_id, ..) => {
172+
*expr = Expression::WitIn(structural_witin_offset + *structural_wit_id)
173+
}
174+
Expression::Instance(i) => *expr = Expression::WitIn(instance_offset + (i.0 as u16)),
175+
Expression::InstanceScalar(..) => (),
176+
Expression::Constant(..) => (),
177+
Expression::Sum(a, b) => {
178+
expr_convert_to_witins(a, num_witin, num_fixed, num_instance);
179+
expr_convert_to_witins(b, num_witin, num_fixed, num_instance);
180+
}
181+
Expression::Product(a, b) => {
182+
expr_convert_to_witins(a, num_witin, num_fixed, num_instance);
183+
expr_convert_to_witins(b, num_witin, num_fixed, num_instance);
184+
}
185+
Expression::ScaledSum(x, a, b) => {
186+
expr_convert_to_witins(x, num_witin, num_fixed, num_instance);
187+
expr_convert_to_witins(a, num_witin, num_fixed, num_instance);
188+
expr_convert_to_witins(b, num_witin, num_fixed, num_instance);
189+
}
190+
Expression::Challenge(..) => (),
191+
}
192+
}

0 commit comments

Comments
 (0)