Skip to content

Commit 60fb69c

Browse files
committed
Add MarhsalInto & Size across all content
1 parent 4b744a7 commit 60fb69c

21 files changed

+602
-162
lines changed

conn.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error {
537537
c.lock.Lock()
538538
defer c.lock.Unlock()
539539

540-
var rawPackets [][]byte
540+
rawPackets := make([][]byte, 0, len(pkts))
541541

542542
for _, pkt := range pkts {
543543
if dtlsHandshake, ok := pkt.record.Content.(*handshake.Handshake); ok {
@@ -591,8 +591,8 @@ func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte {
591591
return rawPackets
592592
}
593593

594-
combinedRawPackets := make([][]byte, 0)
595-
currentCombinedRawPacket := make([]byte, 0)
594+
combinedRawPackets := make([][]byte, 0, len(rawPackets))
595+
var currentCombinedRawPacket []byte
596596

597597
for _, rawPacket := range rawPackets {
598598
if len(currentCombinedRawPacket) == 0 && len(rawPacket) >= c.maximumTransmissionUnit {
@@ -679,8 +679,6 @@ func (c *Conn) processPacket(pkt *packet) ([]byte, error) { //nolint:cyclop
679679

680680
//nolint:cyclop
681681
func (c *Conn) processHandshakePacket(pkt *packet, dtlsHandshake *handshake.Handshake) ([][]byte, error) {
682-
var rawPackets [][]byte
683-
684682
handshakeFragments, err := c.fragmentHandshake(dtlsHandshake)
685683
if err != nil {
686684
return nil, err
@@ -690,6 +688,7 @@ func (c *Conn) processHandshakePacket(pkt *packet, dtlsHandshake *handshake.Hand
690688
c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
691689
}
692690

691+
rawPackets := make([][]byte, 0, len(handshakeFragments))
693692
for _, handshakeFragment := range handshakeFragments {
694693
seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
695694
if seq > recordlayer.MaxSequenceNumber {
@@ -764,11 +763,10 @@ func (c *Conn) fragmentHandshake(dtlsHandshake *handshake.Handshake) ([][]byte,
764763
return nil, err
765764
}
766765

767-
var fragmentedHandshakes [][]byte
768-
769766
contentFragments := splitBytes(content, c.maximumTransmissionUnit)
770767

771768
offset := 0
769+
fragmentedHandshakes := make([][]byte, 0, len(contentFragments))
772770
for _, contentFragment := range contentFragments {
773771
contentFragmentLen := len(contentFragment)
774772

handshaker.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,8 @@ func (s *handshakeFSM) Run(ctx context.Context, conn flightConn, initialState ha
187187
close(s.closed)
188188
}()
189189
for {
190-
s.cfg.log.Tracef("[handshake:%s] %s: %s", srvCliStr(s.state.isClient), s.currentFlight.String(), state.String())
190+
s.cfg.log.Tracef("[handshake:%s] %s: %s",
191+
srvCliStr(s.state.isClient), s.currentFlight.String(), state.String())
191192
if s.cfg.onFlightState != nil {
192193
s.cfg.onFlightState(s.currentFlight, state)
193194
}

pkg/protocol/alert/alert.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,28 @@ func (a Alert) ContentType() protocol.ContentType {
145145
return protocol.ContentTypeAlert
146146
}
147147

148+
// Size returns the minimal buffer size required for MarshalInto.
149+
func (a Alert) Size() int {
150+
return 2
151+
}
152+
148153
// Marshal returns the encoded alert.
149154
func (a *Alert) Marshal() ([]byte, error) {
150-
return []byte{byte(a.Level), byte(a.Description)}, nil
155+
out := make([]byte, a.Size())
156+
err := a.MarshalInto(out)
157+
158+
return out, err
159+
}
160+
161+
// MarshalInto returns the encoded alert.
162+
func (a *Alert) MarshalInto(out []byte) error {
163+
if len(out) < a.Size() {
164+
return errBufferTooSmall
165+
}
166+
out[0] = byte(a.Level)
167+
out[1] = byte(a.Description)
168+
169+
return nil
151170
}
152171

153172
// Unmarshal populates the alert from binary data.

pkg/protocol/application_data.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,22 @@ func (a ApplicationData) ContentType() ContentType {
1919

2020
// Marshal encodes the ApplicationData to binary.
2121
func (a *ApplicationData) Marshal() ([]byte, error) {
22-
return a.Data, nil
22+
out := make([]byte, len(a.Data))
23+
err := a.MarshalInto(out)
24+
25+
return out, err
26+
}
27+
28+
// MarshalInto encodes the ApplicationData to binary into a pre-allocated buffer.
29+
func (a *ApplicationData) MarshalInto(out []byte) error {
30+
copy(out, a.Data)
31+
32+
return nil
33+
}
34+
35+
// Size returns the size required for MarshalInto.
36+
func (a ApplicationData) Size() int {
37+
return len(a.Data)
2338
}
2439

2540
// Unmarshal populates the ApplicationData from binary.

pkg/protocol/change_cipher_spec.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,27 @@ func (c ChangeCipherSpec) ContentType() ContentType {
1515
return ContentTypeChangeCipherSpec
1616
}
1717

18+
// Size returns the minimal buffer size required for MarshalInto.
19+
func (c ChangeCipherSpec) Size() int {
20+
return 1
21+
}
22+
1823
// Marshal encodes the ChangeCipherSpec to binary.
1924
func (c *ChangeCipherSpec) Marshal() ([]byte, error) {
20-
return []byte{0x01}, nil
25+
out := make([]byte, 1)
26+
err := c.MarshalInto(out)
27+
28+
return out, err
29+
}
30+
31+
// MarshalInto encodes the ChangeCipherSpec to binary into a pre-allocated buffer.
32+
func (c *ChangeCipherSpec) MarshalInto(out []byte) error {
33+
if len(out) < c.Size() {
34+
return errBufferTooSmall
35+
}
36+
out[0] = 0x01
37+
38+
return nil
2139
}
2240

2341
// Unmarshal populates the ChangeCipherSpec from binary.

pkg/protocol/content.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,7 @@ const (
2121
type Content interface {
2222
ContentType() ContentType
2323
Marshal() ([]byte, error)
24+
MarshalInto([]byte) error
2425
Unmarshal(data []byte) error
26+
Size() int
2527
}

pkg/protocol/extension/extension.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,17 @@ func Marshal(e []Extension) ([]byte, error) {
134134

135135
return append(out, extensions...), nil
136136
}
137+
138+
// Size returns the length of extensions marshal.
139+
func Size(e []Extension) int {
140+
total := 2
141+
for _, e := range e {
142+
raw, err := e.Marshal()
143+
if err != nil {
144+
return 0
145+
}
146+
total += len(raw)
147+
}
148+
149+
return total
150+
}

pkg/protocol/handshake/handshake.go

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ func (t Type) String() string { //nolint:cyclop
6262
// Message is the body of a Handshake datagram.
6363
type Message interface {
6464
Marshal() ([]byte, error)
65+
MarshalInto([]byte) error
66+
Size() int
6567
Unmarshal(data []byte) error
6668
Type() Type
6769
}
@@ -84,28 +86,50 @@ func (h Handshake) ContentType() protocol.ContentType {
8486
return protocol.ContentTypeHandshake
8587
}
8688

89+
// Size returns the minimal buffer size required for MarshalInto.
90+
func (h *Handshake) Size() int {
91+
return HeaderLength + h.Message.Size()
92+
}
93+
8794
// Marshal encodes a handshake into a binary message.
8895
func (h *Handshake) Marshal() ([]byte, error) {
8996
if h.Message == nil {
9097
return nil, errHandshakeMessageUnset
9198
} else if h.Header.FragmentOffset != 0 {
9299
return nil, errUnableToMarshalFragmented
93100
}
101+
out := make([]byte, h.Size())
102+
err := h.MarshalInto(out)
94103

95-
msg, err := h.Message.Marshal()
96-
if err != nil {
97-
return nil, err
104+
return out, err
105+
}
106+
107+
// MarshalInto encodes a handshake into a binary message into a pre-allocated buffer.
108+
func (h *Handshake) MarshalInto(out []byte) error {
109+
if h.Message == nil {
110+
return errHandshakeMessageUnset
111+
} else if h.Header.FragmentOffset != 0 {
112+
return errUnableToMarshalFragmented
98113
}
99114

100-
h.Header.Length = uint32(len(msg)) //nolint:gosec // G115
115+
if len(out) < h.Size() {
116+
return errBufferTooSmall
117+
}
118+
119+
h.Header.Length = uint32(h.Message.Size()) //nolint:gosec // G115
101120
h.Header.FragmentLength = h.Header.Length
102121
h.Header.Type = h.Message.Type()
103-
header, err := h.Header.Marshal()
122+
err := h.Header.MarshalInto(out)
104123
if err != nil {
105-
return nil, err
124+
return err
125+
}
126+
127+
err = h.Message.MarshalInto(out[HeaderLength:])
128+
if err != nil {
129+
return err
106130
}
107131

108-
return append(header, msg...), nil
132+
return nil
109133
}
110134

111135
// Unmarshal decodes a handshake from a binary message.

pkg/protocol/handshake/message_certificate.go

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
// https://tools.ietf.org/html/rfc5246#section-7.4.2
1414
type MessageCertificate struct {
1515
Certificate [][]byte
16-
cache []byte
1716
}
1817

1918
// Type returns the Handshake Type.
@@ -25,22 +24,30 @@ const (
2524
handshakeMessageCertificateLengthFieldSize = 3
2625
)
2726

28-
// Marshal encodes the Handshake.
29-
func (m *MessageCertificate) Marshal() ([]byte, error) {
30-
if m.cache != nil {
31-
return m.cache, nil
32-
}
27+
// Size returns the minimal size required for MarshalInto.
28+
func (m *MessageCertificate) Size() int {
3329
total := handshakeMessageCertificateLengthFieldSize
3430

3531
for _, cert := range m.Certificate {
3632
total += handshakeMessageCertificateLengthFieldSize + len(cert)
3733
}
3834

39-
out := make([]byte, total)
35+
return total
36+
}
4037

38+
// Marshal encodes the Handshake.
39+
func (m *MessageCertificate) Marshal() ([]byte, error) {
40+
out := make([]byte, m.Size())
41+
err := m.MarshalInto(out)
42+
43+
return out, err
44+
}
45+
46+
// MarshalInto encodes the Handshake into a pre-allocated buffer.
47+
func (m *MessageCertificate) MarshalInto(out []byte) error {
4148
// Total Payload Size
4249
//nolint:gosec // G115
43-
util.PutBigEndianUint24(out, uint32(total-handshakeMessageCertificateLengthFieldSize))
50+
util.PutBigEndianUint24(out, uint32(m.Size()-handshakeMessageCertificateLengthFieldSize))
4451
offset := handshakeMessageCertificateLengthFieldSize
4552

4653
for _, cert := range m.Certificate {
@@ -54,9 +61,7 @@ func (m *MessageCertificate) Marshal() ([]byte, error) {
5461
offset += len(cert)
5562
}
5663

57-
m.cache = out
58-
59-
return out, nil
64+
return nil
6065
}
6166

6267
// Unmarshal populates the message from encoded data.

pkg/protocol/handshake/message_certificate_13.go

Lines changed: 55 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -71,44 +71,77 @@ func (m *MessageCertificate13) Marshal() ([]byte, error) {
7171
return nil, errCertificateRequestContextTooLong
7272
}
7373

74+
out := make([]byte, m.Size())
75+
err := m.MarshalInto(out)
76+
77+
return out, err
78+
}
79+
80+
func (m *MessageCertificate13) Size() int {
81+
return 1 + len(m.CertificateRequestContext) + cert13CertLengthFieldSize + m.certsSize()
82+
}
83+
84+
func (m *MessageCertificate13) certsSize() int {
85+
certificateListSize := 0
86+
for _, entry := range m.CertificateList {
87+
certificateListSize += cert13CertLengthFieldSize
88+
certificateListSize += len(entry.CertificateData)
89+
certificateListSize += extension.Size(entry.Extensions)
90+
}
91+
92+
return certificateListSize
93+
}
94+
95+
// MarshalInto is same as Marshal but uses a pre-allocated buffer.
96+
func (m *MessageCertificate13) MarshalInto(out []byte) error {
97+
// Validate certificate_request_context length
98+
if len(m.CertificateRequestContext) > cert13ContextMaxLength {
99+
return errCertificateRequestContextTooLong
100+
}
101+
102+
if len(out) < m.Size() {
103+
return errBufferTooSmall
104+
}
105+
106+
// Check size of certificate_list is still within bounds
107+
if m.certsSize() > maxUint24 {
108+
return errCertificateListTooLong
109+
}
110+
74111
// Start with certificate_request_context (1-byte length prefix)
75112
//nolint:gosec // G115: certificate_request_context length is validated to be <= 255 above.
76-
out := []byte{byte(len(m.CertificateRequestContext))}
77-
out = append(out, m.CertificateRequestContext...)
113+
offset := 0
114+
out[0] = byte(len(m.CertificateRequestContext)) //nolint:gosec // G115
115+
offset += 1
116+
n := copy(out[offset:], m.CertificateRequestContext) //nolint:gosec // G115
117+
offset += n
118+
119+
// Add certificate_list with 3-byte length prefix
120+
util.PutBigEndianUint24(out[offset:], uint32(m.certsSize())) //nolint:gosec // G115
121+
offset += 3
78122

79123
// Build certificate_list
80-
certificateList := []byte{}
81124
for _, entry := range m.CertificateList {
82125
// Add cert_data as a 3-byte length prefix
83126
certDataLen := len(entry.CertificateData)
84127
if certDataLen == 0 || certDataLen > maxUint24 {
85-
return nil, errInvalidCertificateEntry
128+
return errInvalidCertificateEntry
86129
}
87-
certDataLenBytes := make([]byte, cert13CertLengthFieldSize)
88-
util.PutBigEndianUint24(certDataLenBytes, uint32(certDataLen)) //nolint:gosec // G115
89-
certificateList = append(certificateList, certDataLenBytes...)
90-
certificateList = append(certificateList, entry.CertificateData...)
130+
util.PutBigEndianUint24(out[offset:], uint32(certDataLen)) //nolint:gosec // G115
131+
offset += 3
132+
n = copy(out[offset:], entry.CertificateData)
133+
offset += n
91134

92135
// Marshal extensions (includes a 2-byte length prefix)
93136
extensionsData, err := extension.Marshal(entry.Extensions)
94137
if err != nil {
95-
return nil, err
96-
}
97-
certificateList = append(certificateList, extensionsData...)
98-
99-
// Check size of certificate_list is still within bounds
100-
if len(certificateList) > maxUint24 {
101-
return nil, errCertificateListTooLong
138+
return err
102139
}
140+
n = copy(out[offset:], extensionsData)
141+
offset += n
103142
}
104143

105-
// Add certificate_list with 3-byte length prefix
106-
certificateListLenBytes := make([]byte, cert13CertLengthFieldSize)
107-
util.PutBigEndianUint24(certificateListLenBytes, uint32(len(certificateList))) //nolint:gosec // G115
108-
out = append(out, certificateListLenBytes...)
109-
out = append(out, certificateList...)
110-
111-
return out, nil
144+
return nil
112145
}
113146

114147
// parseCertificate13Entry parses a single certificate entry from the cryptobyte string.

0 commit comments

Comments
 (0)