diff --git a/config.go b/config.go index f26cd1b6..71980404 100644 --- a/config.go +++ b/config.go @@ -236,6 +236,10 @@ type Config struct { //nolint:dupl // ListenConfig used to create the underlying listener socket. listenConfig net.ListenConfig + + // version13 + // WIP experimental feature, see https://github.com/pion/dtls/issues/188 + version13 bool } func (c *Config) includeCertificateSuites() bool { diff --git a/conn.go b/conn.go index e08151f4..4b2c903b 100644 --- a/conn.go +++ b/conn.go @@ -92,11 +92,12 @@ type Conn struct { cancelHandshaker func() cancelHandshakeReader func() - fsm *handshakeFSM + fsm handshakeFSM replayProtectionWindow uint - handshakeConfig *handshakeConfig + handshakeConfig *handshakeConfig + handshakeConfig13 *handshakeConfig13 } // createConn creates a new DTLS connection. @@ -245,6 +246,14 @@ func createConn( }, } + if config.version13 { + handshakeConfig13 := &handshakeConfig13{ + handshakeConfig: handshakeConfig, + } + conn.handshakeConfig13 = handshakeConfig13 + conn.handshakeConfig = nil + } + conn.setRemoteEpoch(0) conn.setLocalEpoch(0) @@ -273,7 +282,7 @@ func (c *Conn) Handshake() error { // // Most uses of this package need not call HandshakeContext explicitly: the // first [Conn.Read] or [Conn.Write] will call it automatically. -func (c *Conn) HandshakeContext(ctx context.Context) error { +func (c *Conn) HandshakeContext(ctx context.Context) error { //nolint:cyclop c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() @@ -287,6 +296,18 @@ func (c *Conn) HandshakeContext(ctx context.Context) error { c.handshakeDone = handshakeDone c.closeLock.Unlock() + if c.isVersion13Enabled() { + initialFlight := flight13_0 + initialFSMState := handshakePreparing + + if err := c.handshake(ctx, flightVal(initialFlight), initialFSMState); err != nil { + return err + } + c.log.Trace("Handshake DTLS 1.3 Completed") + + return nil + } + // rfc5246#section-7.4.3 // In addition, the hash and signature algorithms MUST be compatible // with the key in the server's end-entity certificate. @@ -319,7 +340,7 @@ func (c *Conn) HandshakeContext(ctx context.Context) error { initialFSMState = handshakePreparing } // Do handshake - if err := c.handshake(ctx, c.handshakeConfig, initialFlight, initialFSMState); err != nil { + if err := c.handshake(ctx, initialFlight, initialFSMState); err != nil { return err } @@ -471,6 +492,7 @@ func (c *Conn) Write(payload []byte) (int, error) { return 0, err } + // todo: check for version return len(payload), c.writePackets(c.writeDeadline, []*packet{ { record: &recordlayer.RecordLayer{ @@ -797,7 +819,7 @@ var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals }, } -func (c *Conn) readAndBuffer(ctx context.Context) error { //nolint:cyclop +func (c *Conn) readAndBuffer(ctx context.Context) error { //nolint:cyclop,gocognit bufptr, ok := poolReadBuffer.Get().(*[]byte) if !ok { return errFailedToAccessPoolReadBuffer @@ -817,6 +839,7 @@ func (c *Conn) readAndBuffer(ctx context.Context) error { //nolint:cyclop var hasHandshake, isRetransmit bool for _, p := range pkts { + // todo: check version hs, rtx, alert, err := c.handleIncomingPacket(ctx, p, rAddr, true) if alert != nil { if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { @@ -864,6 +887,7 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error { c.lock.Unlock() for _, p := range pkts { + // todo: check version _, _, alert, err := c.handleIncomingPacket(ctx, p.data, p.rAddr, false) // don't re-enqueue if alert != nil { if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { @@ -897,6 +921,17 @@ func (c *Conn) enqueueEncryptedPackets(packet addrPkt) bool { return false } +// nolint:unusedparams +func (c *Conn) handleIncomingPacket13( + ctx context.Context, + buf []byte, + rAddr net.Addr, + enqueue bool, +) (bool, bool, *alert.Alert, error) { + // Placeholder function + return false, false, nil, nil +} + //nolint:gocognit,gocyclo,cyclop,maintidx func (c *Conn) handleIncomingPacket( ctx context.Context, @@ -1119,16 +1154,28 @@ func (c *Conn) recvHandshake() <-chan recvHandshakeState { func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error { if level == alert.Fatal && len(c.state.SessionID) > 0 { - // According to the RFC, we need to delete the stored session. - // https://datatracker.ietf.org/doc/html/rfc5246#section-7.2 - if ss := c.fsm.cfg.sessionStore; ss != nil { - c.log.Tracef("clean invalid session: %s", c.state.SessionID) - if err := ss.Del(c.sessionKey()); err != nil { - return err + if c.isVersion13Enabled() { + // With compability mode for 1.3, CH uses a non-empty session_id + // https://datatracker.ietf.org/doc/html/rfc8446#appendix-D.4 + if ss := c.fsm.(*handshakeFSM13).cfg.sessionStore; ss != nil { + c.log.Tracef("clean invalid session: %s", c.state.SessionID) + if err := ss.Del(c.sessionKey()); err != nil { + return err + } + } + } else { + // According to the RFC, we need to delete the stored session. + // https://datatracker.ietf.org/doc/html/rfc5246#section-7.2 + if ss := c.fsm.(*handshakeFSM12).cfg.sessionStore; ss != nil { + c.log.Tracef("clean invalid session: %s", c.state.SessionID) + if err := ss.Del(c.sessionKey()); err != nil { + return err + } } } } + // This should be updated with DTLS 1.3 record encoding. return c.writePackets(ctx, []*packet{ { record: &recordlayer.RecordLayer{ @@ -1158,20 +1205,41 @@ func (c *Conn) isHandshakeCompletedSuccessfully() bool { //nolint:cyclop,gocognit,contextcheck func (c *Conn) handshake( ctx context.Context, - cfg *handshakeConfig, initialFlight flightVal, initialState handshakeState, ) error { - c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight) - done := make(chan struct{}) - ctxRead, cancelRead := context.WithCancel(context.Background()) - cfg.onFlightState = func(_ flightVal, s handshakeState) { - if s == handshakeFinished && c.setHandshakeCompletedSuccessfully() { - close(done) + if c.isVersion13Enabled() { + c.fsm = &handshakeFSM13{ + currentFlight: flightVal13(initialFlight), + state: &c.state, + cache: c.handshakeCache, + cfg: c.handshakeConfig13, + retransmitInterval: c.handshakeConfig13.initialRetransmitInterval, + closed: make(chan struct{}), + } + c.handshakeConfig13.onFlightState13 = func(_ flightVal13, s handshakeState) { + if c.fsm.(*handshakeFSM13).currentFlight.isLastSendFlight() { + close(done) + } + } + } else { + c.fsm = &handshakeFSM12{ + currentFlight: initialFlight, + state: &c.state, + cache: c.handshakeCache, + cfg: c.handshakeConfig, + retransmitInterval: c.handshakeConfig.initialRetransmitInterval, + closed: make(chan struct{}), + } + c.handshakeConfig.onFlightState = func(_ flightVal, s handshakeState) { + if s == handshakeFinished && c.setHandshakeCompletedSuccessfully() { + close(done) + } } } + ctxRead, cancelRead := context.WithCancel(context.Background()) ctxHs, cancel := context.WithCancel(context.Background()) c.closeLock.Lock() @@ -1368,7 +1436,11 @@ func (c *Conn) sessionKey() []byte { // As ServerName can be like 0.example.com, it's better to add // delimiter character which is not allowed to be in // neither address or domain name. - return []byte(c.rAddr.String() + "_" + c.fsm.cfg.serverName) + if c.isVersion13Enabled() { + return []byte(c.rAddr.String() + "_" + c.fsm.(*handshakeFSM13).cfg.serverName) + } + + return []byte(c.rAddr.String() + "_" + c.fsm.(*handshakeFSM12).cfg.serverName) } return c.state.SessionID @@ -1395,3 +1467,7 @@ func (c *Conn) SetWriteDeadline(t time.Time) error { // Write deadline is also fully managed by this layer. return nil } + +func (c *Conn) isVersion13Enabled() bool { + return c.handshakeConfig13 != nil +} diff --git a/conn_test.go b/conn_test.go index e7e33b05..43688bb2 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3127,10 +3127,10 @@ func TestEllipticCurveConfiguration(t *testing.T) { assert.True(t, ok, "Failed to default Elliptic curves") if len(test.ConfigCurves) != 0 { - assert.Equal(t, len(test.HandshakeCurves), len(server.fsm.cfg.ellipticCurves), "Failed to configure Elliptic curves") + assert.Equal(t, len(test.HandshakeCurves), len(server.fsm.(*handshakeFSM12).cfg.ellipticCurves), "Failed to configure Elliptic curves") for i, c := range test.ConfigCurves { - assert.Equal(t, c, server.fsm.cfg.ellipticCurves[i], "Failed to maintain Elliptic curve order") + assert.Equal(t, c, server.fsm.(*handshakeFSM12).cfg.ellipticCurves[i], "Failed to maintain Elliptic curve order") } } @@ -3500,3 +3500,50 @@ func TestCloseWithoutHandshake(t *testing.T) { assert.NoError(t, err) assert.NoError(t, server.Close()) } + +// WIP! Tests if DTLS 1.3 handshake flow is enabled and the correct error is returned. +func TestDTLS13Config(t *testing.T) { + ca, cb := dpipe.Pipe() + + // Setup client + clientCert, err := selfsign.GenerateSelfSigned() + assert.NoError(t, err) + + clientcfg, err := buildClientConfig( + WithCertificates(clientCert), + WithInsecureSkipVerify(true), + withVersion13(true), + ) + + assert.NoError(t, err) + + client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), clientcfg) + assert.NoError(t, err) + defer func() { + _ = client.Close() + }() + + _, ok := client.ConnectionState() + assert.False(t, ok) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + errorChannel := make(chan error) + go func() { + errC := client.HandshakeContext(ctx) + errorChannel <- errC + }() + + // Setup server, ignore error + server, _ := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{}, true) + assert.NoError(t, err) + + defer func() { + _ = server.Close() + }() + + err = <-errorChannel + if err.Error() == errFlightUnimplemented13.Error() { + return + } +} diff --git a/errors.go b/errors.go index 0db0de67..f9e6d836 100644 --- a/errors.go +++ b/errors.go @@ -129,6 +129,10 @@ var ( //nolint:err113 errInvalidFlight = &InternalError{Err: errors.New("invalid flight number")} //nolint:err113 + errFlightUnimplemented13 = &InternalError{Err: errors.New("unimplemeted DTLS 1.3 flight")} + //nolint:err113 + errStateUnimplemented13 = &InternalError{Err: errors.New("unimplemeted DTLS 1.3 handshake state")} + //nolint:err113 errKeySignatureGenerateUnimplemented = &InternalError{ Err: errors.New("unable to generate key signature, unimplemented"), } diff --git a/flight_13.go b/flight_13.go new file mode 100644 index 00000000..d4d694d1 --- /dev/null +++ b/flight_13.go @@ -0,0 +1,190 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +type flightVal13 uint8 + +/* +// [RFC9147 Section-5.7] + +Full DTLS Handshake (with Cookie Exchange): + +Client Server + + +----------+ + ClientHello | Flight 1 | + --------> +----------+ + + +----------+ + <-------- HelloRetryRequest | Flight 2 | + + cookie +----------+ + + + +----------+ +ClientHello | Flight 3 | + + cookie --------> +----------+ + + + + ServerHello + {EncryptedExtensions} +----------+ + {CertificateRequest*} | Flight 4 | + {Certificate*} +----------+ + {CertificateVerify*} + {Finished} + <-------- [Application Data*] + + + + {Certificate*} +----------+ + {CertificateVerify*} | Flight 5 | + {Finished} --------> +----------+ + [Application Data] + +----------+ + <-------- [ACK] | Flight 6 | + [Application Data*] +----------+ + + [Application Data] <-------> [Application Data] + + + + +Resumption and PSK Handshake (without Cookie Exchange): + +Client Server + + ClientHello +-----------+ + + pre_shared_key | Flight 3a | + + psk_key_exchange_modes +-----------+ + + key_share* --------> + + + ServerHello + + pre_shared_key +-----------+ + + key_share* | Flight 4a | + {EncryptedExtensions} +-----------+ + <-------- {Finished} + [Application Data*] + +-----------+ + {Finished} --------> | Flight 5a | + [Application Data*] +-----------+ + + +-----------+ + <-------- [ACK] | Flight 6a | + [Application Data*] +-----------+ + + [Application Data] <-------> [Application Data] + + +Zero-RTT Handshake: + +Client Server + + ClientHello + + early_data + + psk_key_exchange_modes +-----------+ + + key_share* | Flight 3b | + + pre_shared_key +-----------+ + (Application Data*) --------> + + ServerHello + + pre_shared_key + + key_share* +-----------+ + {EncryptedExtensions} | Flight 4b | + {Finished} +-----------+ + <-------- [Application Data*] + + + +-----------+ + {Finished} --------> | Flight 5b | + [Application Data*] +-----------+ + + +-----------+ + <-------- [ACK] | Flight 6b | + [Application Data*] +-----------+ + + [Application Data] <-------> [Application Data] + + +NewSessionTicket Message: + +Client Server + + +-----------+ + <-------- [NewSessionTicket] | Flight 4c | + +-----------+ + + +-----------+ +[ACK] --------> | Flight 5c | + +-----------+ +*/ + +const ( + flight13_0 flightVal13 = iota + 1 + flight13_1 + flight13_2 + flight13_3 + flight13_3a + flight13_3b + flight13_4 + flight13_4a + flight13_4b + flight13_4c + flight13_5 + flight13_5a + flight13_5b + flight13_5c + flight13_6 + flight13_6a + flight13_6b +) + +func (f flightVal13) String() string { //nolint:cyclop + switch f { + case flight13_0: + return "Flight13 0" + case flight13_1: + return "Flight13 1" + case flight13_2: + return "Flight13 2" + case flight13_3: + return "Flight13 3" + case flight13_3a: + return "Flight13 3a" + case flight13_3b: + return "Flight13 3b" + case flight13_4: + return "Flight13 4" + case flight13_4a: + return "Flight13 4a" + case flight13_4b: + return "Flight13 4b" + case flight13_4c: + return "Flight13 4c" + case flight13_5: + return "Flight13 5" + case flight13_5a: + return "Flight13 5a" + case flight13_5b: + return "Flight13 5b" + case flight13_5c: + return "Flight13 5c" + case flight13_6: + return "Flight13 6" + case flight13_6a: + return "Flight13 6a" + case flight13_6b: + return "Flight13 6b" + default: + return "Invalid Flight" + } +} + +func (f flightVal13) isLastSendFlight() bool { // nolint: unused + return f == flight13_6 || f == flight13_6a || f == flight13_6b || f == flight13_5c +} + +func (f flightVal13) isLastRecvFlight() bool { // nolint: unused + return f == flight13_5 || f == flight13_5a || f == flight13_5b || f == flight13_4c +} diff --git a/flighthandler_13.go b/flighthandler_13.go new file mode 100644 index 00000000..615f7985 --- /dev/null +++ b/flighthandler_13.go @@ -0,0 +1,32 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "context" + + "github.com/pion/dtls/v3/pkg/protocol/alert" +) + +// Parse received handshakes and return next flightVal. +type flightParser13 func( //nolint:unused + context.Context, + flightConn13, + *State, + *handshakeCache, + *handshakeConfig13, +) (flightVal13, *alert.Alert, error) + +//nolint:unused +type flightGenerator13 func(flightConn13, *State, *handshakeCache, *handshakeConfig13) ([]*packet, *alert.Alert, error) + +//nolint:unused +func (f flightVal13) getFlightParser13() (flightParser13, error) { + return nil, errFlightUnimplemented13 +} + +//nolint:unused +func (f flightVal13) getFlightGenerator13() (gen flightGenerator13, retransmit bool, err error) { + return nil, false, errFlightUnimplemented13 +} diff --git a/flighthandlers_client_13.go b/flighthandlers_client_13.go new file mode 100644 index 00000000..cc7b8db5 --- /dev/null +++ b/flighthandlers_client_13.go @@ -0,0 +1,26 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +// we'll add the flight handlers for the DTLS 1.3 client here. +// +// +----------+ +// | Flight 1 | +// | Flight 3 | +// | Flight 5 | +// +----------+ +// +// +-----------+ +// | Flight 3a | +// | Flight 5a | +// +-----------+ +// +// +-----------+ +// | Flight 3b | +// | Flight 5b | +// +-----------+ +// +// +-----------+ +// | Flight 5c | +// +-----------+ diff --git a/flighthandlers_server_13.go b/flighthandlers_server_13.go new file mode 100644 index 00000000..505c5b56 --- /dev/null +++ b/flighthandlers_server_13.go @@ -0,0 +1,26 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +// we'll add the flight handlers for the DTLS 1.3 server here. +// +// +----------+ +// | Flight 2 | +// | Flight 4 | +// | Flight 6 | +// +----------+ +// +// +-----------+ +// | Flight 4a | +// | Flight 6a | +// +-----------+ +// +// +-----------+ +// | Flight 4b | +// | Flight 6b | +// +-----------+ +// +// +-----------+ +// | Flight 4c | +// +-----------+ diff --git a/handshaker.go b/handshaker.go index 279ab14c..4c3b61f6 100644 --- a/handshaker.go +++ b/handshaker.go @@ -81,7 +81,7 @@ func (s handshakeState) String() string { } } -type handshakeFSM struct { +type handshakeFSM12 struct { currentFlight flightVal flights []*packet retransmit bool @@ -167,11 +167,11 @@ func srvCliStr(isClient bool) string { return "server" } -func newHandshakeFSM( +func newHandshakeFSM12( s *State, cache *handshakeCache, cfg *handshakeConfig, initialFlight flightVal, -) *handshakeFSM { - return &handshakeFSM{ +) *handshakeFSM12 { + return &handshakeFSM12{ currentFlight: initialFlight, state: s, cache: cache, @@ -181,7 +181,16 @@ func newHandshakeFSM( } } -func (s *handshakeFSM) Run(ctx context.Context, conn flightConn, initialState handshakeState) error { +type handshakeFSM interface { + Done() <-chan struct{} + Run(ctx context.Context, conn flightConn, initialState handshakeState) error + finish(ctx context.Context, c flightConn) (handshakeState, error) + prepare(ctx context.Context, conn flightConn) (handshakeState, error) + send(ctx context.Context, c flightConn) (handshakeState, error) + wait(ctx context.Context, conn flightConn) (handshakeState, error) +} + +func (s *handshakeFSM12) Run(ctx context.Context, conn flightConn, initialState handshakeState) error { state := initialState defer func() { close(s.closed) @@ -210,11 +219,11 @@ func (s *handshakeFSM) Run(ctx context.Context, conn flightConn, initialState ha } } -func (s *handshakeFSM) Done() <-chan struct{} { +func (s *handshakeFSM12) Done() <-chan struct{} { return s.closed } -func (s *handshakeFSM) prepare(ctx context.Context, conn flightConn) (handshakeState, error) { +func (s *handshakeFSM12) prepare(ctx context.Context, conn flightConn) (handshakeState, error) { s.flights = nil // Prepare flights var ( @@ -262,7 +271,7 @@ func (s *handshakeFSM) prepare(ctx context.Context, conn flightConn) (handshakeS return handshakeSending, nil } -func (s *handshakeFSM) send(ctx context.Context, c flightConn) (handshakeState, error) { +func (s *handshakeFSM12) send(ctx context.Context, c flightConn) (handshakeState, error) { // Send flights if err := c.writePackets(ctx, s.flights); err != nil { return handshakeErrored, err @@ -275,7 +284,7 @@ func (s *handshakeFSM) send(ctx context.Context, c flightConn) (handshakeState, return handshakeWaiting, nil } -func (s *handshakeFSM) wait(ctx context.Context, conn flightConn) (handshakeState, error) { //nolint:gocognit,cyclop +func (s *handshakeFSM12) wait(ctx context.Context, conn flightConn) (handshakeState, error) { //nolint:gocognit,cyclop parse, errFlight := s.currentFlight.getFlightParser() if errFlight != nil { if alertErr := conn.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil { @@ -347,7 +356,7 @@ func (s *handshakeFSM) wait(ctx context.Context, conn flightConn) (handshakeStat } } -func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState, error) { +func (s *handshakeFSM12) finish(ctx context.Context, c flightConn) (handshakeState, error) { select { case state := <-c.recvHandshake(): close(state.done) diff --git a/handshaker_13.go b/handshaker_13.go new file mode 100644 index 00000000..7a6984a8 --- /dev/null +++ b/handshaker_13.go @@ -0,0 +1,134 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "context" + "time" + + "github.com/pion/dtls/v3/pkg/protocol/alert" +) + +// [RFC9147 Section-5.8.1] +// +-----------+ +// | PREPARING | +// +----------> | | +// | | | +// | +-----------+ +// | | +// | | Buffer next flight +// | | +// | \|/ +// | +-----------+ +// | | | +// | | SENDING |<------------------+ +// | | | | +// | +-----------+ | +// Receive | | | +// next | | Send flight or partial | +// flight | | flight | +// | | | +// | | Set retransmit timer | +// | \|/ | +// | +-----------+ | +// | | | | +// +------------| WAITING |-------------------+ +// | +----->| | Timer expires | +// | | +-----------+ | +// | | | | | | +// | | | | | | +// | +----------+ | +--------------------+ +// | Receive record | Read retransmit or ACK +// Receive | (Maybe Send ACK) | +// last | | +// flight | | Receive ACK +// | | for last flight +// \|/ | +// | +// +-----------+ | +// | | <---------+ +// | FINISHED | +// | | +// +-----------+ +// | /|\ +// | | +// | | +// +---+ +// +// Server read retransmit +// Retransmit ACK + +type handshakeFSM13 struct { + currentFlight flightVal13 + // 1.3 uses new record layer! We should replace with new packet struct. + // flights []*packet + retransmit bool //nolint:unused + retransmitInterval time.Duration + state *State + cache *handshakeCache + cfg *handshakeConfig13 + closed chan struct{} +} + +type handshakeConfig13 struct { + *handshakeConfig + onFlightState13 func(flightVal13, handshakeState) +} + +type flightConn13 interface { //nolint:unused + notify(ctx context.Context, level alert.Level, desc alert.Description) error + writePackets(context.Context, []*packet) error + recvHandshake() <-chan recvHandshakeState + handleQueuedPackets(context.Context) error + sessionKey() []byte +} + +func (s *handshakeFSM13) Run(ctx context.Context, conn flightConn, initialState handshakeState) error { + state := initialState + defer func() { + close(s.closed) + }() + for { + s.cfg.log.Tracef("[handshake13:%s] %s: %s", srvCliStr(s.state.isClient), s.currentFlight.String(), state.String()) + if s.cfg.onFlightState13 != nil { + s.cfg.onFlightState13(s.currentFlight, state) + } + var err error + switch state { + case handshakePreparing: + state, err = s.prepare(ctx, conn) + case handshakeSending: + state, err = s.send(ctx, conn) + case handshakeWaiting: + state, err = s.wait(ctx, conn) + case handshakeFinished: + state, err = s.finish(ctx, conn) + default: + return errInvalidFSMTransition + } + if err != nil { + return err + } + } +} + +func (s *handshakeFSM13) Done() <-chan struct{} { + return s.closed +} + +func (s *handshakeFSM13) prepare(ctx context.Context, conn flightConn) (handshakeState, error) { + return handshakeErrored, errStateUnimplemented13 +} + +func (s *handshakeFSM13) send(ctx context.Context, c flightConn) (handshakeState, error) { + return handshakeErrored, errStateUnimplemented13 +} + +func (s *handshakeFSM13) wait(ctx context.Context, conn flightConn) (handshakeState, error) { + return handshakeErrored, errStateUnimplemented13 +} + +func (s *handshakeFSM13) finish(ctx context.Context, c flightConn) (handshakeState, error) { + return handshakeErrored, errStateUnimplemented13 +} diff --git a/handshaker_test.go b/handshaker_test.go index 82e8cbc8..61189428 100644 --- a/handshaker_test.go +++ b/handshaker_test.go @@ -289,7 +289,7 @@ func TestHandshaker(t *testing.T) { //nolint:gocyclo,cyclop,maintidx initialRetransmitInterval: nonZeroRetransmitInterval, } - fsm := newHandshakeFSM(&ca.state, ca.handshakeCache, cfg, flight1) + fsm := newHandshakeFSM12(&ca.state, ca.handshakeCache, cfg, flight1) err := fsm.Run(ctx, ca, handshakePreparing) switch { case errors.Is(err, context.Canceled): @@ -322,7 +322,7 @@ func TestHandshaker(t *testing.T) { //nolint:gocyclo,cyclop,maintidx initialRetransmitInterval: nonZeroRetransmitInterval, } - fsm := newHandshakeFSM(&cb.state, cb.handshakeCache, cfg, flight0) + fsm := newHandshakeFSM12(&cb.state, cb.handshakeCache, cfg, flight0) err := fsm.Run(ctx, cb, handshakePreparing) switch { case errors.Is(err, context.Canceled): diff --git a/options.go b/options.go index a1fcb8dd..ae0f746e 100644 --- a/options.go +++ b/options.go @@ -80,6 +80,7 @@ type dtlsConfig struct { //nolint:dupl certificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message onConnectionAttempt func(net.Addr) error listenConfig net.ListenConfig + version13 bool } // applyDefaults applies default values to the config. @@ -124,6 +125,7 @@ func (c *dtlsConfig) toConfig() *Config { CertificateRequestMessageHook: c.certificateRequestMessageHook, OnConnectionAttempt: c.onConnectionAttempt, listenConfig: c.listenConfig, + version13: c.version13, } if len(c.certificates) > 0 { @@ -561,6 +563,17 @@ func WithClientHelloMessageHook(fn func(handshake.MessageClientHello) handshake. }) } +// WithVersion13 enables version 1.3. +// WIP experimental feature, see https://github.com/pion/dtls/issues/188 +// https://datatracker.ietf.org/doc/html/rfc9147 +func withVersion13(b bool) Option { + return sharedOption(func(c *dtlsConfig) error { + c.version13 = b + + return nil + }) +} + // serverOnlyOption wraps an apply function for server-only options. type serverOnlyOption func(*dtlsConfig) error