Skip to content

Commit 978e745

Browse files
authored
Add Fp to extension field implementation (#2427)
In our PIL std lib, we have an `Ext` that abstracts over `Fp2` and `Fp4`. But in cases where we don't need any extension field (e.g. BN254), we wrapped base field elements as `Fp2` anyway. This is a bit weird and can lead to issues when the optimizer removes the unused dimension (e.g., it might still be referenced by a prover function or witgen annotation). I had this problem in #2426. Now, there is an `Ext<T>::Fp` variant that simply wraps `T`.
1 parent 196a0ca commit 978e745

File tree

3 files changed

+20
-10
lines changed

3 files changed

+20
-10
lines changed

std/math/extension_field.asm

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,66 +17,75 @@ let required_extension_size: -> int = || match known_field() {
1717
None => panic("The permutation/lookup argument is not implemented for the current field!")
1818
};
1919

20-
/// Wrapper around Fp2 and Fp4 to abstract which extension field is used.
20+
/// Wrapper around T, Fp2<T> and Fp4<T> to abstract which extension field is used (if any).
2121
/// Once PIL supports traits, we can remove this type and the functions below.
2222
enum Ext<T> {
23+
Fp(T),
2324
Fp2(std::math::fp2::Fp2<T>),
2425
Fp4(std::math::fp4::Fp4<T>)
2526
}
2627

2728
let<T: Add> add_ext: Ext<T>, Ext<T> -> Ext<T> = |a, b| match (a, b) {
29+
(Ext::Fp(aa), Ext::Fp(bb)) => Ext::Fp(aa + bb),
2830
(Ext::Fp2(aa), Ext::Fp2(bb)) => Ext::Fp2(std::math::fp2::add_ext(aa, bb)),
2931
(Ext::Fp4(aa), Ext::Fp4(bb)) => Ext::Fp4(std::math::fp4::add_ext(aa, bb)),
3032
_ => panic("Operands have different types")
3133
};
3234

3335
let<T: Sub> sub_ext: Ext<T>, Ext<T> -> Ext<T> = |a, b| match (a, b) {
36+
(Ext::Fp(aa), Ext::Fp(bb)) => Ext::Fp(aa - bb),
3437
(Ext::Fp2(aa), Ext::Fp2(bb)) => Ext::Fp2(std::math::fp2::sub_ext(aa, bb)),
3538
(Ext::Fp4(aa), Ext::Fp4(bb)) => Ext::Fp4(std::math::fp4::sub_ext(aa, bb)),
3639
_ => panic("Operands have different types")
3740
};
3841

3942
let<T: Add + FromLiteral + Mul> mul_ext: Ext<T>, Ext<T> -> Ext<T> = |a, b| match (a, b) {
43+
(Ext::Fp(aa), Ext::Fp(bb)) => Ext::Fp(aa * bb),
4044
(Ext::Fp2(aa), Ext::Fp2(bb)) => Ext::Fp2(std::math::fp2::mul_ext(aa, bb)),
4145
(Ext::Fp4(aa), Ext::Fp4(bb)) => Ext::Fp4(std::math::fp4::mul_ext(aa, bb)),
4246
_ => panic("Operands have different types")
4347
};
4448

4549
let eval_ext: Ext<expr> -> Ext<fe> = query |a| match a {
50+
Ext::Fp(aa) => Ext::Fp(std::prover::eval(aa)),
4651
Ext::Fp2(aa) => Ext::Fp2(std::math::fp2::eval_ext(aa)),
4752
Ext::Fp4(aa) => Ext::Fp4(std::math::fp4::eval_ext(aa)),
4853
};
4954

5055
let inv_ext: Ext<fe> -> Ext<fe> = query |a| match a {
56+
Ext::Fp(aa) => Ext::Fp(std::math::ff::inv_field(aa)),
5157
Ext::Fp2(aa) => Ext::Fp2(std::math::fp2::inv_ext(aa)),
5258
Ext::Fp4(aa) => Ext::Fp4(std::math::fp4::inv_ext(aa)),
5359
};
5460

5561
let<T> unpack_ext_array: Ext<T> -> T[] = |a| match a {
62+
Ext::Fp(aa) => [aa],
5663
Ext::Fp2(aa) => std::math::fp2::unpack_ext_array(aa),
5764
Ext::Fp4(aa) => std::math::fp4::unpack_ext_array(aa),
5865
};
5966

6067
let next_ext: Ext<expr> -> Ext<expr> = |a| match a {
68+
Ext::Fp(aa) => Ext::Fp(aa'),
6169
Ext::Fp2(aa) => Ext::Fp2(std::math::fp2::next_ext(aa)),
6270
Ext::Fp4(aa) => Ext::Fp4(std::math::fp4::next_ext(aa)),
6371
};
6472

6573
let<T: FromLiteral> from_base: T -> Ext<T> = |x| match required_extension_size() {
66-
1 => Ext::Fp2(std::math::fp2::from_base(x)),
74+
1 => Ext::Fp(x),
6775
2 => Ext::Fp2(std::math::fp2::from_base(x)),
6876
4 => Ext::Fp4(std::math::fp4::from_base(x)),
6977
_ => panic("Expected 1, 2, or 4")
7078
};
7179

72-
let<T: FromLiteral> from_array: T[] -> Ext<T> = |arr| match len(arr) {
73-
1 => Ext::Fp2(std::math::fp2::from_array(arr)),
80+
let<T> from_array: T[] -> Ext<T> = |arr| match len(arr) {
81+
1 => Ext::Fp(arr[0]),
7482
2 => Ext::Fp2(std::math::fp2::from_array(arr)),
7583
4 => Ext::Fp4(std::math::fp4::Fp4::Fp4(arr[0], arr[1], arr[2], arr[3])),
7684
_ => panic("Expected 1, 2, or 4")
7785
};
7886

7987
let constrain_eq_ext: Ext<expr>, Ext<expr> -> Constr[] = |a, b| match (a, b) {
88+
(Ext::Fp(aa), Ext::Fp(bb)) => [aa = bb],
8089
(Ext::Fp2(aa), Ext::Fp2(bb)) => std::math::fp2::constrain_eq_ext(aa, bb),
8190
(Ext::Fp4(aa), Ext::Fp4(bb)) => std::math::fp4::constrain_eq_ext(aa, bb),
8291
_ => panic("Operands have different types")

std/math/fp2.asm

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,13 +141,9 @@ let<T> unpack_ext_array: Fp2<T> -> T[] = |a| match a {
141141
};
142142

143143
/// Constructs an extension field element `a0 + a1 * X` from either `[a0, a1]` or `[a0]` (setting `a1`to zero in that case)
144-
let<T: FromLiteral> from_array: T[] -> Fp2<T> = |arr| match len(arr) {
145-
1 => {
146-
let _ = assert(!needs_extension(), || "The field is too small and needs to move to the extension field. Pass two elements instead!");
147-
from_base(arr[0])
148-
},
144+
let<T> from_array: T[] -> Fp2<T> = |arr| match len(arr) {
149145
2 => Fp2::Fp2(arr[0], arr[1]),
150-
_ => panic("Expected array of length 1 or 2")
146+
_ => panic("Expected array of length 2")
151147
};
152148

153149
mod test {

std/protocols/fingerprint.asm

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ let fingerprint_inter: expr[], Ext<expr> -> Ext<expr> = |expr_array, alpha| if l
3737

3838
// Recursively compute the fingerprint as expr_array[0] + alpha * fingerprint(expr_array[1:], alpha)
3939
let intermediate_fingerprint = match fingerprint_inter(array::sub_array(expr_array, 1, len(expr_array) - 1), alpha) {
40+
Ext::Fp(a) => {
41+
let intermediate_fingerprint: inter = a;
42+
Ext::Fp(intermediate_fingerprint)
43+
},
4044
Ext::Fp2(std::math::fp2::Fp2::Fp2(a0, a1)) => {
4145
let intermediate_fingerprint_0: inter = a0;
4246
let intermediate_fingerprint_1: inter = a1;
@@ -70,6 +74,7 @@ mod test {
7074
let assert_fingerprint_equal: fe[], expr, fe -> () = query |tuple, challenge, expected| {
7175
let result = fingerprint(tuple, from_base(challenge));
7276
match result {
77+
Ext::Fp(actual) => assert(expected == actual, || "expected != actual"),
7378
Ext::Fp2(std::math::fp2::Fp2::Fp2(actual, zero)) => {
7479
assert(zero == 0, || "Returned an extension field element");
7580
assert(expected == actual, || "expected != actual");

0 commit comments

Comments
 (0)