Skip to content

Commit 477a5b4

Browse files
FiloSottile and gopherbot
authored and committed
sha3: make APIs usable with zero allocations
The "buf points into storage" pattern is nice, but it causes the whole state struct to escape, since escape analysis can't track the pointer once it's assigned to buf.

Change-Id: I31c0e83f946d66bedb5a180e96ab5d5e936eb322
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/544817
Reviewed-by: Cherry Mui <[email protected]>
LUCI-TryBot-Result: Go LUCI <[email protected]>
Reviewed-by: Roland Shoemaker <[email protected]>
Reviewed-by: Mauri de Souza Meneguzzo <[email protected]>
Auto-Submit: Filippo Valsorda <[email protected]>
1 parent 59b5a86 commit 477a5b4

File tree

2 files changed

+77
-36
lines changed

2 files changed

+77
-36
lines changed

sha3/allocations_test.go

+53
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,53 @@
1+
// Copyright 2023 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
//go:build !noopt
6+
7+
package sha3_test
8+
9+
import (
10+
"testing"
11+
12+
"golang.org/x/crypto/sha3"
13+
)
14+
15+
var sink byte
16+
17+
func TestAllocations(t *testing.T) {
18+
t.Run("New", func(t *testing.T) {
19+
if allocs := testing.AllocsPerRun(10, func() {
20+
h := sha3.New256()
21+
b := []byte("ABC")
22+
h.Write(b)
23+
out := make([]byte, 0, 32)
24+
out = h.Sum(out)
25+
sink ^= out[0]
26+
}); allocs > 0 {
27+
t.Errorf("expected zero allocations, got %0.1f", allocs)
28+
}
29+
})
30+
t.Run("NewShake", func(t *testing.T) {
31+
if allocs := testing.AllocsPerRun(10, func() {
32+
h := sha3.NewShake128()
33+
b := []byte("ABC")
34+
h.Write(b)
35+
out := make([]byte, 0, 32)
36+
out = h.Sum(out)
37+
sink ^= out[0]
38+
h.Read(out)
39+
sink ^= out[0]
40+
}); allocs > 0 {
41+
t.Errorf("expected zero allocations, got %0.1f", allocs)
42+
}
43+
})
44+
t.Run("Sum", func(t *testing.T) {
45+
if allocs := testing.AllocsPerRun(10, func() {
46+
b := []byte("ABC")
47+
out := sha3.Sum256(b)
48+
sink ^= out[0]
49+
}); allocs > 0 {
50+
t.Errorf("expected zero allocations, got %0.1f", allocs)
51+
}
52+
})
53+
}

sha3/sha3.go

+24 -36
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,6 @@ const (
2323
type state struct {
2424
// Generic sponge components.
2525
a [25]uint64 // main state of the hash
26-
buf []byte // points into storage
2726
rate int // the number of bytes of state to use
2827

2928
// dsbyte contains the "domain separation" bits and the first bit of
@@ -40,6 +39,7 @@ type state struct {
4039
// Extendable-Output Functions (May 2014)"
4140
dsbyte byte
4241

42+
i, n int // storage[i:n] is the buffer, i is only used while squeezing
4343
storage [maxRate]byte
4444

4545
// Specific to SHA-3 and SHAKE.
@@ -54,24 +54,18 @@ func (d *state) BlockSize() int { return d.rate }
5454
func (d *state) Size() int { return d.outputLen }
5555

5656
// Reset clears the internal state by zeroing the sponge state and
57-
// the byte buffer, and setting Sponge.state to absorbing.
57+
// the buffer indexes, and setting Sponge.state to absorbing.
5858
func (d *state) Reset() {
5959
// Zero the permutation's state.
6060
for i := range d.a {
6161
d.a[i] = 0
6262
}
6363
d.state = spongeAbsorbing
64-
d.buf = d.storage[:0]
64+
d.i, d.n = 0, 0
6565
}
6666

6767
func (d *state) clone() *state {
6868
ret := *d
69-
if ret.state == spongeAbsorbing {
70-
ret.buf = ret.storage[:len(ret.buf)]
71-
} else {
72-
ret.buf = ret.storage[d.rate-cap(d.buf) : d.rate]
73-
}
74-
7569
return &ret
7670
}
7771

@@ -82,43 +76,40 @@ func (d *state) permute() {
8276
case spongeAbsorbing:
8377
// If we're absorbing, we need to xor the input into the state
8478
// before applying the permutation.
85-
xorIn(d, d.buf)
86-
d.buf = d.storage[:0]
79+
xorIn(d, d.storage[:d.rate])
80+
d.n = 0
8781
keccakF1600(&d.a)
8882
case spongeSqueezing:
8983
// If we're squeezing, we need to apply the permutation before
9084
// copying more output.
9185
keccakF1600(&d.a)
92-
d.buf = d.storage[:d.rate]
93-
copyOut(d, d.buf)
86+
d.i = 0
87+
copyOut(d, d.storage[:d.rate])
9488
}
9589
}
9690

9791
// padAndPermute appends the domain separation bits in dsbyte, applies
9892
// the multi-bitrate 10..1 padding rule, and permutes the state.
99-
func (d *state) padAndPermute(dsbyte byte) {
100-
if d.buf == nil {
101-
d.buf = d.storage[:0]
102-
}
93+
func (d *state) padAndPermute() {
10394
// Pad with this instance's domain-separator bits. We know that there's
10495
// at least one byte of space in d.storage[:d.rate] because, if it were full,
10596
// permute would have been called to empty it. dsbyte also contains the
10697
// first one bit for the padding. See the comment in the state struct.
107-
d.buf = append(d.buf, dsbyte)
108-
zerosStart := len(d.buf)
109-
d.buf = d.storage[:d.rate]
110-
for i := zerosStart; i < d.rate; i++ {
111-
d.buf[i] = 0
98+
d.storage[d.n] = d.dsbyte
99+
d.n++
100+
for d.n < d.rate {
101+
d.storage[d.n] = 0
102+
d.n++
112103
}
113104
// This adds the final one bit for the padding. Because of the way that
114105
// bits are numbered from the LSB upwards, the final bit is the MSB of
115106
// the last byte.
116-
d.buf[d.rate-1] ^= 0x80
107+
d.storage[d.rate-1] ^= 0x80
117108
// Apply the permutation
118109
d.permute()
119110
d.state = spongeSqueezing
120-
d.buf = d.storage[:d.rate]
121-
copyOut(d, d.buf)
111+
d.n = d.rate
112+
copyOut(d, d.storage[:d.rate])
122113
}
123114

124115
// Write absorbs more data into the hash's state. It panics if any
@@ -127,28 +118,25 @@ func (d *state) Write(p []byte) (written int, err error) {
127118
if d.state != spongeAbsorbing {
128119
panic("sha3: Write after Read")
129120
}
130-
if d.buf == nil {
131-
d.buf = d.storage[:0]
132-
}
133121
written = len(p)
134122

135123
for len(p) > 0 {
136-
if len(d.buf) == 0 && len(p) >= d.rate {
124+
if d.n == 0 && len(p) >= d.rate {
137125
// The fast path; absorb a full "rate" bytes of input and apply the permutation.
138126
xorIn(d, p[:d.rate])
139127
p = p[d.rate:]
140128
keccakF1600(&d.a)
141129
} else {
142130
// The slow path; buffer the input until we can fill the sponge, and then xor it in.
143-
todo := d.rate - len(d.buf)
131+
todo := d.rate - d.n
144132
if todo > len(p) {
145133
todo = len(p)
146134
}
147-
d.buf = append(d.buf, p[:todo]...)
135+
d.n += copy(d.storage[d.n:], p[:todo])
148136
p = p[todo:]
149137

150138
// If the sponge is full, apply the permutation.
151-
if len(d.buf) == d.rate {
139+
if d.n == d.rate {
152140
d.permute()
153141
}
154142
}
@@ -161,19 +149,19 @@ func (d *state) Write(p []byte) (written int, err error) {
161149
func (d *state) Read(out []byte) (n int, err error) {
162150
// If we're still absorbing, pad and apply the permutation.
163151
if d.state == spongeAbsorbing {
164-
d.padAndPermute(d.dsbyte)
152+
d.padAndPermute()
165153
}
166154

167155
n = len(out)
168156

169157
// Now, do the squeezing.
170158
for len(out) > 0 {
171-
n := copy(out, d.buf)
172-
d.buf = d.buf[n:]
159+
n := copy(out, d.storage[d.i:d.n])
160+
d.i += n
173161
out = out[n:]
174162

175163
// Apply the permutation if we've squeezed the sponge dry.
176-
if len(d.buf) == 0 {
164+
if d.i == d.rate {
177165
d.permute()
178166
}
179167
}

0 commit comments

Comments
 (0)