-
Notifications
You must be signed in to change notification settings - Fork 231
Expand file tree
/
Copy pathdomain.go
More file actions
432 lines (362 loc) · 12.2 KB
/
domain.go
File metadata and controls
432 lines (362 loc) · 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
// Copyright 2020-2025 Consensys Software Inc.
// Licensed under the Apache License, Version 2.0. See the LICENSE file for details.
// Code generated by consensys/gnark-crypto DO NOT EDIT
package fft
import (
"encoding/binary"
"errors"
"io"
"math/big"
"math/bits"
"runtime"
"sync"
"weak"
"github.com/consensys/gnark-crypto/ecc/bn254/fr"
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark-crypto/utils"
)
// Domain with a power of 2 cardinality
// compute a 2^k-th root of unity and store it in Generator
// all other values are derived from it (e.g. GeneratorInv)
type Domain struct {
// Cardinality is the size of the domain; createDomain sets it to the requested size rounded up to a power of 2.
Cardinality uint64
// CardinalityInv is the inverse of Cardinality as a field element.
CardinalityInv fr.Element
// Generator is the root of unity returned by Generator(m) that generates the domain.
Generator fr.Element
// GeneratorInv is the inverse of Generator.
GeneratorInv fr.Element
// FrMultiplicativeGen is the coset shift: the shift option when one is provided,
// otherwise GeneratorFullMultiplicativeGroup().
FrMultiplicativeGen fr.Element // generator of Fr*
// FrMultiplicativeGenInv is the inverse of FrMultiplicativeGen.
FrMultiplicativeGenInv fr.Element
// withPrecompute is copied from the domain options (cleared by the WithoutPrecompute option)
// and is serialized by WriteTo / restored by ReadFrom;
// if true, the domain does some pre-computation and stores it.
// if false, the FFT will compute the twiddles on the fly (this is less CPU efficient, but uses less memory)
withPrecompute bool
// the following slices are not serialized and are (re)computed through domain.preComputeTwiddles()
// twiddles factor for the FFT using Generator for each stage of the recursive FFT
twiddles [][]fr.Element
// twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT
twiddlesInv [][]fr.Element
// we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover
// cosetTable <1, u, u², ..., uⁿ⁻¹> where u is the shifting element (FrMultiplicativeGen)
cosetTable []fr.Element
// cosetTableInv same as cosetTable but with u⁻¹ (FrMultiplicativeGenInv)
cosetTableInv []fr.Element
}
// GeneratorFullMultiplicativeGroup returns a generator of 𝔽ᵣˣ.
// The generator used by this package is the field element 5.
func GeneratorFullMultiplicativeGroup() fr.Element {
	return *new(fr.Element).SetUint64(5)
}
// domainCacheKey is the composite key for the cache.
// It uses a struct with comparable types as the map key.
type domainCacheKey struct {
// m is the minimum cardinality exactly as passed to NewDomain (it is not rounded up before being used as key).
m uint64
// gen is the coset shift: the shift option when set, otherwise GeneratorFullMultiplicativeGroup().
gen fr.Element
}
// Package-level state backing the domain cache used by NewDomain.
//
// Lock order as used by NewDomain and its cleanup: keyMapLock, then the per-key
// mutex, then domainMapLock.
var (
// domainCache holds weak pointers so that the cache alone does not keep a domain alive.
domainCache = make(map[domainCacheKey]weak.Pointer[Domain])
domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain
keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map
domainMapLock sync.Mutex // Ensures exclusive access to domainCache map
)
// NewDomain returns a subgroup with a power of 2 cardinality >= m.
//
// Parameters:
// - m: minimum cardinality (will be rounded up to next power of 2)
// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.)
//
// The domain is looked up in / stored into the package-level cache only when
// both withCache and withPrecompute are enabled; otherwise a fresh domain is
// built on every call. The cache holds weak pointers, and a cleanup registered
// on the domain removes its cache entry once the domain is garbage collected.
//
// NewDomain panics if createDomain panics (i.e. when Generator(m) returns an error).
func NewDomain(m uint64, opts ...DomainOption) *Domain {
opt := domainOptions(opts...)
// Skip caching if disabled or precomputation is off
if !opt.withCache || !opt.withPrecompute {
return createDomain(m, opt)
}
// Compute the cache key.
// NOTE(review): the key uses m as passed by the caller, while createDomain
// rounds it up to a power of 2; two values of m that round to the same
// cardinality therefore get separate cache entries.
key := domainCacheKey{m: m}
if opt.shift != nil {
key.gen.Set(opt.shift)
} else {
key.gen = GeneratorFullMultiplicativeGroup() // Default generator
}
// Let's ensure that only one goroutine is generating a domain for this
// specific key. We acquire it already here to ensure that if there is an
// existing goroutine generating a domain for this key, we wait for it to
// finish and then we can just return the cached domain.
// NOTE(review): keyLock.Lock() is called while keyMapLock is still held, so
// while this goroutine waits for a key that is being generated, callers for
// other keys wait on keyMapLock as well.
keyMapLock.Lock()
keyLock := domainGenLocks[key]
if keyLock == nil {
keyLock = new(sync.Mutex)
domainGenLocks[key] = keyLock
}
keyLock.Lock()
defer keyLock.Unlock()
keyMapLock.Unlock()
// Check cache first. But for the cache, we need to lock the cache map (we
// currently only hold the per-key lock, not global cache lock).
domainMapLock.Lock()
// we don't defer it because we want to release it while creating the
// domain. And domain creation can panic, leading to double unlock which
// hides the original panic.
if weakDomain, exists := domainCache[key]; exists {
// a nil Value means the cached domain was already collected: fall through and rebuild it
if domain := weakDomain.Value(); domain != nil {
domainMapLock.Unlock()
return domain
}
}
// Let's release the global cache lock while we do this so that other keys
// can be added to cache.
domainMapLock.Unlock()
// Create a new domain (expensive operation, but only blocks same key).
domain := createDomain(m, opt)
// Store in cache with cleanup
weakDomain := weak.Make(domain)
domainMapLock.Lock()
domainCache[key] = weakDomain
domainMapLock.Unlock()
// Add cleanup to remove from cache when domain is garbage collected
runtime.AddCleanup(domain, func(key domainCacheKey) {
// cleanup *may* be called concurrently, but could be sequential. We run
// it in a separate goroutine to avoid blocking other cleanups from being
// run if this cleanup is being run on the same key which is being
// generated (thus lock being held).
go func() {
keyMapLock.Lock()
defer keyMapLock.Unlock()
if keyLock, ok := domainGenLocks[key]; ok {
keyLock.Lock()
defer keyLock.Unlock()
// We can now safely delete from both maps. But we only do if
// the cached weak pointer is the same one we created. Otherwise
// this means this cleanup is running after a new domain was
// already cached (double cleanup).
// We also want to hold both per-key and cache lock to avoid the
// maps being out of sync
domainMapLock.Lock()
defer domainMapLock.Unlock()
if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain {
delete(domainCache, key)
delete(domainGenLocks, key)
}
}
}()
}, key)
return domain
}
// createDomain builds a fresh Domain of cardinality NextPowerOfTwo(m) without
// touching the cache.
//
// The coset shift is opt.shift when provided, otherwise the default generator
// of 𝔽ᵣˣ. Twiddles and coset tables are precomputed only when
// opt.withPrecompute is set.
//
// It panics if Generator(m) returns an error.
func createDomain(m uint64, opt domainConfig) *Domain {
	size := ecc.NextPowerOfTwo(m)

	d := new(Domain)
	d.Cardinality = uint64(size)

	// coset shift: a caller-provided shift wins over the default generator
	if opt.shift == nil {
		d.FrMultiplicativeGen = GeneratorFullMultiplicativeGroup()
	} else {
		d.FrMultiplicativeGen.Set(opt.shift)
	}
	d.FrMultiplicativeGenInv.Inverse(&d.FrMultiplicativeGen)

	gen, err := Generator(m)
	if err != nil {
		panic(err)
	}
	d.Generator = gen
	d.GeneratorInv.Inverse(&d.Generator)
	d.CardinalityInv.SetUint64(uint64(size)).Inverse(&d.CardinalityInv)

	// twiddle factors and coset tables
	if d.withPrecompute = opt.withPrecompute; d.withPrecompute {
		d.preComputeTwiddles()
	}
	return d
}
// Generator returns a generator for Z/2^(log(m))Z,
// or an error if m is too big (the required root of unity doesn't exist).
// It delegates to fr.Generator.
func Generator(m uint64) (fr.Element, error) {
	gen, err := fr.Generator(m)
	return gen, err
}
// Twiddles returns the per-stage twiddle factors of the FFT built from Generator.
// It returns an error when they have not been precomputed (e.g. the domain was
// created with the WithoutPrecompute option).
func (d *Domain) Twiddles() ([][]fr.Element, error) {
	if t := d.twiddles; t != nil {
		return t, nil
	}
	return nil, errors.New("twiddles not precomputed")
}
// TwiddlesInv returns the per-stage twiddle factors of the FFT built from GeneratorInv.
// It returns an error when they have not been precomputed (e.g. the domain was
// created with the WithoutPrecompute option).
func (d *Domain) TwiddlesInv() ([][]fr.Element, error) {
	if t := d.twiddlesInv; t != nil {
		return t, nil
	}
	return nil, errors.New("twiddles not precomputed")
}
// CosetTable returns the coset table <1, u, u², ..., uⁿ⁻¹> where u is the
// shifting element (FrMultiplicativeGen).
// It returns an error when the table has not been precomputed (e.g. the domain
// was created with the WithoutPrecompute option).
func (d *Domain) CosetTable() ([]fr.Element, error) {
	if table := d.cosetTable; table != nil {
		return table, nil
	}
	return nil, errors.New("cosetTable not precomputed")
}
// CosetTableInv returns the inverse coset table <1, u⁻¹, u⁻², ..., u⁻⁽ⁿ⁻¹⁾> where
// u is the shifting element (u⁻¹ is FrMultiplicativeGenInv).
// It returns an error when the table has not been precomputed (e.g. the domain
// was created with the WithoutPrecompute option).
func (d *Domain) CosetTableInv() ([]fr.Element, error) {
	if table := d.cosetTableInv; table != nil {
		return table, nil
	}
	return nil, errors.New("cosetTableInv not precomputed")
}
// preComputeTwiddles allocates and fills the four precomputed tables of the
// domain: twiddles, twiddlesInv, cosetTable and cosetTableInv.
// Each table is built in its own goroutine; the call returns once all four
// are done.
func (d *Domain) preComputeTwiddles() {
	// number of FFT stages: trailing zero count of the cardinality
	// (its log2 when the cardinality is a power of 2)
	stages := uint64(bits.TrailingZeros64(d.Cardinality))

	d.twiddles = make([][]fr.Element, stages)
	d.twiddlesInv = make([][]fr.Element, stages)
	d.cosetTable = make([]fr.Element, d.Cardinality)
	d.cosetTableInv = make([]fr.Element, d.Cardinality)

	// the four tables are independent of each other
	jobs := [...]func(){
		func() { buildTwiddles(d.twiddles, d.Generator, stages) },
		func() { buildTwiddles(d.twiddlesInv, d.GeneratorInv, stages) },
		func() { BuildExpTable(d.FrMultiplicativeGen, d.cosetTable) },
		func() { BuildExpTable(d.FrMultiplicativeGenInv, d.cosetTableInv) },
	}

	var wg sync.WaitGroup
	wg.Add(len(jobs))
	for _, job := range jobs {
		go func(run func()) {
			run()
			wg.Done()
		}(job)
	}
	wg.Wait()
}
// buildTwiddles fills t with one twiddle table per FFT stage, derived from omega.
//
// t[0] holds the powers 1, ω, ..., ω^(2^(nbStages-1)); t[i] is t[0] sampled
// with a stride of 2^i and has 1+2^(nbStages-i-1) entries.
//
// It does nothing when nbStages is 0 and panics if len(t) != nbStages.
func buildTwiddles(t [][]fr.Element, omega fr.Element, nbStages uint64) {
	switch {
	case nbStages == 0:
		return
	case len(t) != int(nbStages):
		panic("invalid twiddle table")
	}

	// only the first stage needs actual field multiplications
	first := make([]fr.Element, 1+(1<<(nbStages-1)))
	t[0] = first
	BuildExpTable(omega, first)

	// every later stage is a strided copy of the first one
	for s := uint64(1); s < nbStages; s++ {
		stage := make([]fr.Element, 1+(1<<(nbStages-s-1)))
		t[s] = stage
		for j := range stage {
			stage[j] = first[j<<s]
		}
	}
}
// BuildExpTable precomputes the first n = len(table) powers of w, possibly in parallel:
//
//	table[0] = w^0
//	table[1] = w^1
//	...
//	table[n-1] = w^(n-1)
//
// An empty table is left untouched.
func BuildExpTable(w fr.Element, table []fr.Element) {
	n := len(table)
	if n == 0 {
		// nothing to fill: there is not even a slot for w^0
		return
	}
	table[0].SetOne()

	// see if it makes sense to parallelize exp tables pre-computation:
	// interval is the chunk size each goroutine would handle (0 = stay sequential)
	interval := 0
	if numCPU := runtime.NumCPU(); numCPU >= 4 {
		interval = (n - 1) / (numCPU / 4)
	}

	// this ratio roughly correspond to the number of multiplication one can do in place of a Exp operation
	// TODO @gbotrel revisit this; Exps in this context will be by a "small power of 2" so faster than this ref ratio.
	const ratioExpMul = 6000 / 17

	if interval < ratioExpMul {
		// chunks would be too small to amortize the Exp that starts each of them
		precomputeExpTableChunk(w, 1, table[1:])
		return
	}

	// we parallelize: each chunk starts with one Exp, then only multiplications
	var wg sync.WaitGroup
	for start := 1; start < n; start += interval {
		end := min(start+interval, n)
		wg.Add(1)
		go func(start, end int) {
			defer wg.Done()
			precomputeExpTableChunk(w, uint64(start), table[start:end])
		}(start, end)
	}
	wg.Wait()
}
// precomputeExpTableChunk fills table with consecutive powers of w starting at
// w^power: table[i] = w^(power+i).
// The first entry costs one exponentiation, every following one a single
// multiplication by w.
func precomputeExpTableChunk(w fr.Element, power uint64, table []fr.Element) {
	// an empty chunk is legal: this ensures that creating a domain of size 1
	// with cosets doesn't fail
	if len(table) == 0 {
		return
	}

	exponent := new(big.Int).SetUint64(power)
	table[0].Exp(w, exponent)
	for i := range table[1:] {
		table[i+1].Mul(&table[i], &w)
	}
}
// WriteTo writes a binary representation of the domain (without the precomputed twiddle factors)
// to the provided writer.
//
// It returns the number of bytes written and the first error encountered, if any.
func (d *Domain) WriteTo(w io.Writer) (int64, error) {
	// note to stay retro compatible with previous version using ecc/encoder, we encode as:
	// d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute
	var written int64

	if err := binary.Write(w, binary.BigEndian, d.Cardinality); err != nil {
		return written, err
	}
	written += 8

	toEncode := [...]*fr.Element{
		&d.CardinalityInv,
		&d.Generator,
		&d.GeneratorInv,
		&d.FrMultiplicativeGen,
		&d.FrMultiplicativeGenInv,
	}
	for _, v := range toEncode {
		buf := v.Bytes()
		n, err := w.Write(buf[:])
		// count what was actually written, so a failed partial write is still reported
		written += int64(n)
		if err != nil {
			return written, err
		}
	}

	if err := binary.Write(w, binary.BigEndian, d.withPrecompute); err != nil {
		return written, err
	}
	written++

	return written, nil
}
// ReadFrom attempts to decode a domain from Reader, using the layout produced by WriteTo:
// Cardinality, CardinalityInv, Generator, GeneratorInv, FrMultiplicativeGen,
// FrMultiplicativeGenInv, withPrecompute.
//
// When the decoded withPrecompute flag is set, the twiddles and coset tables
// are recomputed before returning.
//
// It returns the number of bytes read and the first error encountered, if any.
func (d *Domain) ReadFrom(r io.Reader) (int64, error) {
	var read int64

	if err := binary.Read(r, binary.BigEndian, &d.Cardinality); err != nil {
		return read, err
	}
	read += 8

	toDecode := [...]*fr.Element{
		&d.CardinalityInv,
		&d.Generator,
		&d.GeneratorInv,
		&d.FrMultiplicativeGen,
		&d.FrMultiplicativeGenInv,
	}
	var buf [fr.Bytes]byte
	for _, v := range toDecode {
		// io.ReadFull, because a single Read may legally return fewer than
		// fr.Bytes bytes without an error
		n, err := io.ReadFull(r, buf[:])
		read += int64(n)
		if err != nil {
			return read, err
		}
		if *v, err = fr.BigEndian.Element(&buf); err != nil {
			return read, err
		}
	}

	if err := binary.Read(r, binary.BigEndian, &d.withPrecompute); err != nil {
		return read, err
	}
	read++

	if d.withPrecompute {
		d.preComputeTwiddles()
	}

	return read, nil
}
// BitReverse applies the bit-reversal permutation to v.
//
// The length of v must be a power of 2.
//
// Deprecated: Use [utils.BitReverse] instead.
func BitReverse[T any](v []T) {
// thin wrapper: forwards v unchanged to utils.BitReverse
utils.BitReverse(v)
}