Skip to content

add support for logical replication protocol v3 #68

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,27 +46,46 @@ func (t MessageType) String() string {
return "StreamCommit"
case MessageTypeStreamAbort:
return "StreamAbort"
case MessageTypeBeginPrepare:
return "BeginPrepare"
case MessageTypePrepare:
return "Prepare"
case MessageTypeCommitPrepared:
return "CommitPrepared"
case MessageTypeRollbackPrepared:
return "RollbackPrepared"
case MessageTypeStreamPrepare:
return "StreamPrepare"
default:
return "Unknown"
}
}

// List of types of logical replication messages.
const (
MessageTypeBegin MessageType = 'B'
MessageTypeMessage MessageType = 'M'
MessageTypeCommit MessageType = 'C'
MessageTypeOrigin MessageType = 'O'
MessageTypeRelation MessageType = 'R'
MessageTypeType MessageType = 'Y'
MessageTypeInsert MessageType = 'I'
MessageTypeUpdate MessageType = 'U'
MessageTypeDelete MessageType = 'D'
MessageTypeTruncate MessageType = 'T'
MessageTypeBegin MessageType = 'B'
MessageTypeMessage MessageType = 'M'
MessageTypeCommit MessageType = 'C'
MessageTypeOrigin MessageType = 'O'
MessageTypeRelation MessageType = 'R'
MessageTypeType MessageType = 'Y'
MessageTypeInsert MessageType = 'I'
MessageTypeUpdate MessageType = 'U'
MessageTypeDelete MessageType = 'D'
MessageTypeTruncate MessageType = 'T'

// introduced in protocol version 2
MessageTypeStreamStart MessageType = 'S'
MessageTypeStreamStop MessageType = 'E'
MessageTypeStreamCommit MessageType = 'c'
MessageTypeStreamAbort MessageType = 'A'

// introduced in protocol version 3
MessageTypeBeginPrepare MessageType = 'b'
MessageTypePrepare MessageType = 'P'
MessageTypeCommitPrepared MessageType = 'K'
MessageTypeRollbackPrepared MessageType = 'r'
MessageTypeStreamPrepare MessageType = 'p'
)

// Message is a message received from server.
Expand Down Expand Up @@ -182,7 +201,6 @@ func (m *BeginMessage) Decode(src []byte) error {
m.Xid = binary.BigEndian.Uint32(src[low:])

m.SetType(MessageTypeBegin)

return nil
}

Expand Down
1 change: 0 additions & 1 deletion messageV2.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

// MessageDecoderV2 decodes message from V2 protocol into struct.
type MessageDecoderV2 interface {
MessageDecoder
Copy link

@serprex serprex Mar 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why remove this? similarly, why not have DecoderV3 inherit DecoderV2? It seems this is because V3 introduces new message types, & Decode was passing here off of baseMesssage?

would it be better to pass version as a parameter to Decode?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Talking more, idea is that ParseV2 should return result of Parse like ParseV3 returns ParseV2 rather than using getCommonDecoder to have a V2 or V1 decoder where we call V2 (ignoring V1 method) if V2 implemented, otherwise call V1

DecodeV2(src []byte, inStream bool) error
}

Expand Down
5 changes: 2 additions & 3 deletions messageV2_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package pglogrepl

import (
"fmt"
"github.com/stretchr/testify/suite"
"testing"

"github.com/stretchr/testify/suite"
)

func TestLogicalDecodingMessageV2Suite(t *testing.T) {
Expand Down Expand Up @@ -123,7 +123,6 @@ func (s *streamCommitSuite) Test() {

msg[0] = 'c'
bigEndian.PutUint32(msg[1:], xid)
fmt.Printf("%+v\n", msg)
msg[5] = flags
bigEndian.PutUint64(msg[6:], uint64(commitLSN))
bigEndian.PutUint64(msg[14:], uint64(transactionEndLSN))
Expand Down
215 changes: 215 additions & 0 deletions messageV3.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
package pglogrepl

import (
"time"
)

type MessageDecoderV3 interface {
DecodeV3(src []byte, inStream bool) error
}

type BeginPrepareMessageV3 struct {
baseMessage
PrepareLSN LSN
TransactionEndLSN LSN
// The time at which the transaction was prepared.
PrepareTime time.Time
// The transaction ID of the prepared transaction.
Xid uint32
// The user defined GID of the prepared transaction.
Gid string
}

func (m *BeginPrepareMessageV3) DecodeV3(src []byte, _ bool) (err error) {
if len(src) < 29 {
return m.lengthError("BeginPrepareMessage", 29, len(src))
}

var low, used int
m.PrepareLSN, used = m.decodeLSN(src)
low += used
m.TransactionEndLSN, used = m.decodeLSN(src[low:])
low += used
m.PrepareTime, used = m.decodeTime(src[low:])
low += used
m.Xid, used = m.decodeUint32(src[low:])
low += used
m.Gid, _ = m.decodeString(src[low:])
m.SetType(MessageTypeBeginPrepare)

return nil
}

type PrepareMessageV3 struct {
baseMessage
// Flags currently unused (must be 0).
Flags uint8
PrepareLSN LSN
TransactionEndLSN LSN
// The time at which the transaction was prepared.
PrepareTime time.Time
// The transaction ID of the prepared transaction.
Xid uint32
// The user defined GID of the prepared transaction.
Gid string
}

func (m *PrepareMessageV3) DecodeV3(src []byte, _ bool) (err error) {
if len(src) < 30 {
return m.lengthError("PrepareMessage", 30, len(src))
}

var low, used int
m.Flags = src[low]
low += 1
m.PrepareLSN, used = m.decodeLSN(src[low:])
low += used
m.TransactionEndLSN, used = m.decodeLSN(src[low:])
low += used
m.PrepareTime, used = m.decodeTime(src[low:])
low += used
m.Xid, used = m.decodeUint32(src[low:])
low += used
m.Gid, _ = m.decodeString(src[low:])
m.SetType(MessageTypePrepare)

return nil
}

type CommitPreparedMessageV3 struct {
baseMessage
// Flags currently unused (must be 0).
Flags uint8
CommitLSN LSN
TransactionEndLSN LSN
CommitTime time.Time
Xid uint32
// The user defined GID of the prepared transaction.
Gid string
}

func (m *CommitPreparedMessageV3) DecodeV3(src []byte, _ bool) (err error) {
if len(src) < 30 {
return m.lengthError("CommitPreparedMessage", 30, len(src))
}

var low, used int
m.Flags = src[low]
low += 1
m.CommitLSN, used = m.decodeLSN(src[low:])
low += used
m.TransactionEndLSN, used = m.decodeLSN(src[low:])
low += used
m.CommitTime, used = m.decodeTime(src[low:])
low += used
m.Xid, used = m.decodeUint32(src[low:])
low += used
m.Gid, _ = m.decodeString(src[low:])
m.SetType(MessageTypeCommitPrepared)

return nil
}

type RollbackPreparedMessageV3 struct {
baseMessage
// Flags currently unused (must be 0).
Flags uint8
TransactionEndLSN LSN
// The end LSN of the rollback of the prepared transaction.
TransactionRollbackLSN LSN
PrepareTime time.Time
RollbackTime time.Time
Xid uint32
// The user defined GID of the prepared transaction.
Gid string
}

func (m *RollbackPreparedMessageV3) DecodeV3(src []byte, _ bool) (err error) {
if len(src) < 38 {
return m.lengthError("RollbackPreparedMessage", 38, len(src))
}

var low, used int
m.Flags = src[low]
low += 1
m.TransactionEndLSN, used = m.decodeLSN(src[low:])
low += used
m.TransactionRollbackLSN, used = m.decodeLSN(src[low:])
low += used
m.PrepareTime, used = m.decodeTime(src[low:])
low += used
m.RollbackTime, used = m.decodeTime(src[low:])
low += used
m.Xid, used = m.decodeUint32(src[low:])
low += used
m.Gid, _ = m.decodeString(src[low:])
m.SetType(MessageTypeRollbackPrepared)

return nil
}

type StreamPrepareMessageV3 struct {
baseMessage
// Flags currently unused (must be 0).
Flags uint8
PrepareLSN LSN
TransactionEndLSN LSN
PrepareTime time.Time
Xid uint32
// The user defined GID of the prepared transaction.
Gid string
}

func (m *StreamPrepareMessageV3) DecodeV3(src []byte, _ bool) (err error) {
if len(src) < 30 {
return m.lengthError("StreamPrepareMessage", 30, len(src))
}

var low, used int
m.Flags = src[low]
low += 1
m.PrepareLSN, used = m.decodeLSN(src[low:])
low += used
m.TransactionEndLSN, used = m.decodeLSN(src[low:])
low += used
m.PrepareTime, used = m.decodeTime(src[low:])
low += used
m.Xid, used = m.decodeUint32(src[low:])
low += used
m.Gid, _ = m.decodeString(src[low:])
m.SetType(MessageTypeStreamPrepare)

return nil
}

// ParseV3 parse a logical replication message from protocol version #3
// it accepts a slice of bytes read from PG and inStream parameter
// inStream must be true when StreamStartMessageV2 has been read
// it must be false after StreamStopMessageV2 has been read
func ParseV3(data []byte, inStream bool) (m Message, err error) {
var decoder MessageDecoderV3
msgType := MessageType(data[0])

switch msgType {
case MessageTypeBeginPrepare:
decoder = new(BeginPrepareMessageV3)
case MessageTypePrepare:
decoder = new(PrepareMessageV3)
case MessageTypeCommitPrepared:
decoder = new(CommitPreparedMessageV3)
case MessageTypeRollbackPrepared:
decoder = new(RollbackPreparedMessageV3)
case MessageTypeStreamPrepare:
decoder = new(StreamPrepareMessageV3)
default:
// all messages from V2 are unchanged in V3
// so we can just call ParseV2
return ParseV2(data, inStream)
}

if err = decoder.DecodeV3(data[1:], inStream); err != nil {
return nil, err
}

return decoder.(Message), nil
}
Loading
Loading