Skip to content

Commit 3a00a2f

Browse files
committed
add marhsalInto across all content
1 parent 4b744a7 commit 3a00a2f

18 files changed

+575
-144
lines changed

pkg/protocol/alert/alert.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,26 @@ func (a Alert) ContentType() protocol.ContentType {
145145
return protocol.ContentTypeAlert
146146
}
147147

148+
// Size returns the minimal buffer size required for MarshalInto.
149+
func (a Alert) Size() int {
150+
return 2
151+
}
152+
148153
// Marshal returns the encoded alert.
149154
func (a *Alert) Marshal() ([]byte, error) {
150-
return []byte{byte(a.Level), byte(a.Description)}, nil
155+
out := make([]byte, 2)
156+
err := a.MarshalInto(out)
157+
return out, err
158+
}
159+
160+
// MarshalInto returns the encoded alert.
161+
func (a *Alert) MarshalInto(out []byte) error {
162+
if len(out) < 2 {
163+
return errBufferTooSmall
164+
}
165+
out[0] = byte(a.Level)
166+
out[1] = byte(a.Description)
167+
return nil
151168
}
152169

153170
// Unmarshal populates the alert from binary data.

pkg/protocol/application_data.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,20 @@ func (a ApplicationData) ContentType() ContentType {
1919

2020
// Marshal encodes the ApplicationData to binary.
2121
func (a *ApplicationData) Marshal() ([]byte, error) {
22-
return a.Data, nil
22+
out := make([]byte, len(a.Data))
23+
err := a.MarshalInto(out)
24+
return out, err
25+
}
26+
27+
// Marshal encodes the ApplicationData to binary.
28+
func (a *ApplicationData) MarshalInto(out []byte) error {
29+
copy(out, a.Data)
30+
return nil
31+
}
32+
33+
// Size returns the size required for MarshalInto.
34+
func (a ApplicationData) Size() int {
35+
return len(a.Data)
2336
}
2437

2538
// Unmarshal populates the ApplicationData from binary.

pkg/protocol/change_cipher_spec.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,25 @@ func (c ChangeCipherSpec) ContentType() ContentType {
1515
return ContentTypeChangeCipherSpec
1616
}
1717

18+
// Size returns the minimal buffer size required for MarshalInto.
19+
func (c ChangeCipherSpec) Size() int {
20+
return 1
21+
}
22+
1823
// Marshal encodes the ChangeCipherSpec to binary.
1924
func (c *ChangeCipherSpec) Marshal() ([]byte, error) {
20-
return []byte{0x01}, nil
25+
out := make([]byte, 1)
26+
err := c.MarshalInto(out)
27+
return out, err
28+
}
29+
30+
// MarshalInto encodes the ChangeCipherSpec to binary into a pre-allocated buffer.
31+
func (c *ChangeCipherSpec) MarshalInto(out []byte) error {
32+
if len(out) < 1 {
33+
return errBufferTooSmall
34+
}
35+
out[0] = 0x01
36+
return nil
2137
}
2238

2339
// Unmarshal populates the ChangeCipherSpec from binary.

pkg/protocol/content.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,5 @@ type Content interface {
2222
ContentType() ContentType
2323
Marshal() ([]byte, error)
2424
Unmarshal(data []byte) error
25+
//Size() int
2526
}

pkg/protocol/extension/extension.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,16 @@ func Marshal(e []Extension) ([]byte, error) {
134134

135135
return append(out, extensions...), nil
136136
}
137+
138+
// Size returns the length of extensions marshal.
139+
func Size(e []Extension) int {
140+
total := 2
141+
for _, e := range e {
142+
raw, err := e.Marshal()
143+
if err != nil {
144+
return 0
145+
}
146+
total += len(raw)
147+
}
148+
return total
149+
}

pkg/protocol/handshake/handshake.go

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ func (t Type) String() string { //nolint:cyclop
6262
// Message is the body of a Handshake datagram.
6363
type Message interface {
6464
Marshal() ([]byte, error)
65+
MarshalInto([]byte) error
66+
Size() int
6567
Unmarshal(data []byte) error
6668
Type() Type
6769
}
@@ -84,28 +86,49 @@ func (h Handshake) ContentType() protocol.ContentType {
8486
return protocol.ContentTypeHandshake
8587
}
8688

89+
// Size returns the minimal buffer size required for MarshalInto.
90+
func (h *Handshake) Size() int {
91+
return HeaderLength + h.Message.Size()
92+
}
93+
8794
// Marshal encodes a handshake into a binary message.
8895
func (h *Handshake) Marshal() ([]byte, error) {
8996
if h.Message == nil {
9097
return nil, errHandshakeMessageUnset
9198
} else if h.Header.FragmentOffset != 0 {
9299
return nil, errUnableToMarshalFragmented
93100
}
101+
out := make([]byte, h.Size())
102+
err := h.MarshalInto(out)
103+
return out, err
104+
}
105+
106+
// Marshal encodes a handshake into a binary message.
107+
func (h *Handshake) MarshalInto(out []byte) error {
108+
if h.Message == nil {
109+
return errHandshakeMessageUnset
110+
} else if h.Header.FragmentOffset != 0 {
111+
return errUnableToMarshalFragmented
112+
}
113+
114+
if len(out) < h.Size() {
115+
return errBufferTooSmall
116+
}
94117

95-
msg, err := h.Message.Marshal()
118+
err := h.Message.MarshalInto(out[HeaderLength:])
96119
if err != nil {
97-
return nil, err
120+
return err
98121
}
99122

100-
h.Header.Length = uint32(len(msg)) //nolint:gosec // G115
123+
h.Header.Length = uint32(h.Message.Size()) //nolint:gosec // G115
101124
h.Header.FragmentLength = h.Header.Length
102125
h.Header.Type = h.Message.Type()
103-
header, err := h.Header.Marshal()
126+
err = h.Header.MarshalInto(out[:])
104127
if err != nil {
105-
return nil, err
128+
return err
106129
}
107130

108-
return append(header, msg...), nil
131+
return nil
109132
}
110133

111134
// Unmarshal decodes a handshake from a binary message.

pkg/protocol/handshake/message_certificate.go

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
// https://tools.ietf.org/html/rfc5246#section-7.4.2
1414
type MessageCertificate struct {
1515
Certificate [][]byte
16-
cache []byte
1716
}
1817

1918
// Type returns the Handshake Type.
@@ -25,22 +24,29 @@ const (
2524
handshakeMessageCertificateLengthFieldSize = 3
2625
)
2726

28-
// Marshal encodes the Handshake.
29-
func (m *MessageCertificate) Marshal() ([]byte, error) {
30-
if m.cache != nil {
31-
return m.cache, nil
32-
}
27+
// MarshalInto returns the minimal size required for MarshalInto.
28+
func (m *MessageCertificate) Size() int {
3329
total := handshakeMessageCertificateLengthFieldSize
3430

3531
for _, cert := range m.Certificate {
3632
total += handshakeMessageCertificateLengthFieldSize + len(cert)
3733
}
3834

39-
out := make([]byte, total)
35+
return total
36+
}
37+
38+
// Marshal encodes the Handshake.
39+
func (m *MessageCertificate) Marshal() ([]byte, error) {
40+
out := make([]byte, m.Size())
41+
err := m.MarshalInto(out)
42+
return out, err
43+
}
4044

45+
// MarshalInto encodes the Handshake into a pre-allocated buffer.
46+
func (m *MessageCertificate) MarshalInto(out []byte) error {
4147
// Total Payload Size
4248
//nolint:gosec // G115
43-
util.PutBigEndianUint24(out, uint32(total-handshakeMessageCertificateLengthFieldSize))
49+
util.PutBigEndianUint24(out, uint32(m.Size()-handshakeMessageCertificateLengthFieldSize))
4450
offset := handshakeMessageCertificateLengthFieldSize
4551

4652
for _, cert := range m.Certificate {
@@ -54,9 +60,7 @@ func (m *MessageCertificate) Marshal() ([]byte, error) {
5460
offset += len(cert)
5561
}
5662

57-
m.cache = out
58-
59-
return out, nil
63+
return nil
6064
}
6165

6266
// Unmarshal populates the message from encoded data.

pkg/protocol/handshake/message_certificate_13.go

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -71,44 +71,77 @@ func (m *MessageCertificate13) Marshal() ([]byte, error) {
7171
return nil, errCertificateRequestContextTooLong
7272
}
7373

74+
out := make([]byte, m.Size())
75+
err := m.MarshalInto(out)
76+
return out, err
77+
}
78+
79+
func (m *MessageCertificate13) Size() int {
80+
return 1 + len(m.CertificateRequestContext) + cert13CertLengthFieldSize + m.certsSize()
81+
}
82+
83+
func (m *MessageCertificate13) certsSize() int {
84+
certificateListSize := 0
85+
for _, entry := range m.CertificateList {
86+
certificateListSize += cert13CertLengthFieldSize
87+
certificateListSize += len(entry.CertificateData)
88+
certificateListSize += cert13ExtLengthFieldSize
89+
certificateListSize += extension.Size(entry.Extensions)
90+
}
91+
return certificateListSize
92+
}
93+
94+
// MarshalInto, same as Marshal but uses a pre-allocated buffer.
95+
func (m *MessageCertificate13) MarshalInto(out []byte) error {
96+
// Validate certificate_request_context length
97+
if len(m.CertificateRequestContext) > cert13ContextMaxLength {
98+
return errCertificateRequestContextTooLong
99+
}
100+
101+
if len(out) < m.Size() {
102+
return errBufferTooSmall
103+
}
104+
105+
// Check size of certificate_list is still within bounds
106+
if m.certsSize() > maxUint24 {
107+
return errCertificateListTooLong
108+
}
109+
74110
// Start with certificate_request_context (1-byte length prefix)
75111
//nolint:gosec // G115: certificate_request_context length is validated to be <= 255 above.
76-
out := []byte{byte(len(m.CertificateRequestContext))}
77-
out = append(out, m.CertificateRequestContext...)
112+
offset := 0
113+
out[0] = byte(len(m.CertificateRequestContext))
114+
offset += 1
115+
n := copy(out[offset:], m.CertificateRequestContext)
116+
offset += n
117+
118+
// Add certificate_list with 3-byte length prefix
119+
util.PutBigEndianUint24(out[offset:], uint32(m.certsSize())) //nolint:gosec // G115
120+
offset += 3
78121

79122
// Build certificate_list
80-
certificateList := []byte{}
81123
for _, entry := range m.CertificateList {
82124
// Add cert_data as a 3-byte length prefix
83125
certDataLen := len(entry.CertificateData)
84126
if certDataLen == 0 || certDataLen > maxUint24 {
85-
return nil, errInvalidCertificateEntry
127+
return errInvalidCertificateEntry
86128
}
87-
certDataLenBytes := make([]byte, cert13CertLengthFieldSize)
88-
util.PutBigEndianUint24(certDataLenBytes, uint32(certDataLen)) //nolint:gosec // G115
89-
certificateList = append(certificateList, certDataLenBytes...)
90-
certificateList = append(certificateList, entry.CertificateData...)
129+
util.PutBigEndianUint24(out[offset:], uint32(certDataLen)) //nolint:gosec // G115
130+
offset += 3
131+
n = copy(out[offset:], entry.CertificateData)
132+
offset += n
91133

92134
// Marshal extensions (includes a 2-byte length prefix)
93135
extensionsData, err := extension.Marshal(entry.Extensions)
94136
if err != nil {
95-
return nil, err
137+
return err
96138
}
97-
certificateList = append(certificateList, extensionsData...)
139+
n = copy(out[offset:], extensionsData)
140+
offset += n
98141

99-
// Check size of certificate_list is still within bounds
100-
if len(certificateList) > maxUint24 {
101-
return nil, errCertificateListTooLong
102-
}
103142
}
104143

105-
// Add certificate_list with 3-byte length prefix
106-
certificateListLenBytes := make([]byte, cert13CertLengthFieldSize)
107-
util.PutBigEndianUint24(certificateListLenBytes, uint32(len(certificateList))) //nolint:gosec // G115
108-
out = append(out, certificateListLenBytes...)
109-
out = append(out, certificateList...)
110-
111-
return out, nil
144+
return nil
112145
}
113146

114147
// parseCertificate13Entry parses a single certificate entry from the cryptobyte string.

0 commit comments

Comments
 (0)