Skip to content

Commit 88a5b52

Browse files
Merge pull request #381 from sonicfromnewyoke/sonic/fix-message-bugs
fix: msg/txn inconsistencies with original SDK
2 parents 801a013 + 9ef2d24 commit 88a5b52

9 files changed

Lines changed: 2227 additions & 42 deletions

File tree

message.go

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ const (
8686
MessageVersionV0 MessageVersion = 1 // v0
8787
)
8888

89+
// messageVersionPrefix is the high bit mask used to indicate a versioned message.
90+
// If the first byte has this bit set, the message is versioned; the remaining
91+
// 7 bits encode the version number (0 for V0, 1 for V1, etc.).
92+
// See: https://github.com/anza-xyz/solana-sdk/blob/master/message/src/versions/mod.rs
93+
const messageVersionPrefix = 0x80
94+
8995
type Message struct {
9096
version MessageVersion
9197
// List of base-58 encoded public keys used by the transaction,
@@ -343,11 +349,14 @@ func (mx *Message) MarshalV0() ([]byte, error) {
343349
buf = append(buf, instruction.Data...)
344350
}
345351
}
346-
versionNum := byte(mx.version) // TODO: what number is this?
347-
if versionNum > 127 {
352+
// The actual Solana version number is the Go enum value minus 1
353+
// (MessageVersionV0=1 maps to Solana version 0).
354+
// The wire prefix is messageVersionPrefix (0x80) OR'd with the version number.
355+
solanaVersion := byte(mx.version - 1)
356+
if solanaVersion > 0x7F {
348357
return nil, fmt.Errorf("invalid message version: %d", mx.version)
349358
}
350-
buf = append([]byte{byte(versionNum + 127)}, buf...)
359+
buf = append([]byte{messageVersionPrefix | solanaVersion}, buf...)
351360

352361
bin.EncodeCompactU16Length(&buf, len(mx.AddressTableLookups))
353362
for _, lookup := range mx.AddressTableLookups {
@@ -383,8 +392,9 @@ func (mx *Message) UnmarshalWithDecoder(decoder *bin.Decoder) (err error) {
383392
if err != nil {
384393
return err
385394
}
386-
// TODO: is this the right way to determine if this is a legacy or v0 message?
387-
if versionNum[0] < 127 {
395+
// If the high bit (0x80) is set, this is a versioned message;
396+
// otherwise it is a legacy message where this byte is numRequiredSignatures.
397+
if versionNum[0]&messageVersionPrefix == 0 {
388398
mx.version = MessageVersionLegacy
389399
} else {
390400
mx.version = MessageVersionV0
@@ -502,12 +512,15 @@ func (mx Message) getStaticKeys() (keys PublicKeySlice) {
502512
}
503513

504514
func (mx *Message) UnmarshalV0(decoder *bin.Decoder) (err error) {
505-
version, err := decoder.ReadByte()
515+
prefix, err := decoder.ReadByte()
506516
if err != nil {
507-
return fmt.Errorf("failed to read message version: %w", err)
517+
return fmt.Errorf("failed to read message version prefix: %w", err)
518+
}
519+
solanaVersion := prefix & 0x7F
520+
if solanaVersion != 0 {
521+
return fmt.Errorf("unsupported message version: %d", solanaVersion)
508522
}
509-
// TODO: check version
510-
mx.version = MessageVersion(version - 127)
523+
mx.version = MessageVersion(solanaVersion + 1) // map Solana version 0 → MessageVersionV0 (1)
511524

512525
// The middle of the message is the same as the legacy message:
513526
err = mx.UnmarshalLegacy(decoder)
@@ -521,7 +534,7 @@ func (mx *Message) UnmarshalV0(decoder *bin.Decoder) (err error) {
521534
}
522535
if addressTableLookupsLen > 0 {
523536
mx.AddressTableLookups = make([]MessageAddressTableLookup, addressTableLookupsLen)
524-
for i := 0; i < int(addressTableLookupsLen); i++ {
537+
for i := range addressTableLookupsLen {
525538
// read account pubkey
526539
_, err = decoder.Read(mx.AddressTableLookups[i].AccountKey[:])
527540
if err != nil {
@@ -584,7 +597,7 @@ func (mx *Message) UnmarshalLegacy(decoder *bin.Decoder) (err error) {
584597
return fmt.Errorf("numAccountKeys %d is too large for remaining bytes %d", numAccountKeys, decoder.Remaining())
585598
}
586599
mx.AccountKeys = make(PublicKeySlice, numAccountKeys)
587-
for i := 0; i < numAccountKeys; i++ {
600+
for i := range numAccountKeys {
588601
_, err := decoder.Read(mx.AccountKeys[i][:])
589602
if err != nil {
590603
return fmt.Errorf("unable to decode mx.AccountKeys[%d]: %w", i, err)
@@ -606,7 +619,7 @@ func (mx *Message) UnmarshalLegacy(decoder *bin.Decoder) (err error) {
606619
return fmt.Errorf("numInstructions %d is greater than remaining bytes %d", numInstructions, decoder.Remaining())
607620
}
608621
mx.Instructions = make([]CompiledInstruction, numInstructions)
609-
for instructionIndex := 0; instructionIndex < numInstructions; instructionIndex++ {
622+
for instructionIndex := range numInstructions {
610623
programIDIndex, err := decoder.ReadUint8()
611624
if err != nil {
612625
return fmt.Errorf("unable to decode mx.Instructions[%d].ProgramIDIndex: %w", instructionIndex, err)
@@ -622,7 +635,7 @@ func (mx *Message) UnmarshalLegacy(decoder *bin.Decoder) (err error) {
622635
return fmt.Errorf("ix[%v]: numAccounts %d is greater than remaining bytes %d", instructionIndex, numAccounts, decoder.Remaining())
623636
}
624637
mx.Instructions[instructionIndex].Accounts = make([]uint16, numAccounts)
625-
for i := 0; i < numAccounts; i++ {
638+
for i := range numAccounts {
626639
accountIndex, err := decoder.ReadUint8()
627640
if err != nil {
628641
return fmt.Errorf("unable to decode accountIndex for ix[%d].Accounts[%d]: %w", instructionIndex, i, err)
@@ -843,10 +856,11 @@ func (m *Message) IsWritableStatic(account PublicKey) bool {
843856
}
844857
h := m.Header
845858
if index >= int(h.NumRequiredSignatures) {
846-
// unsignedAccountIndex < numWritableUnsignedAccounts
847-
return index-int(h.NumRequiredSignatures) < (m.numStaticAccounts()-int(h.NumRequiredSignatures))-int(h.NumReadonlyUnsignedAccounts)
859+
// Use int arithmetic to avoid underflow (Rust uses saturating_sub here).
860+
numWritableUnsigned := max(m.numStaticAccounts()-int(h.NumRequiredSignatures)-int(h.NumReadonlyUnsignedAccounts), 0)
861+
return index-int(h.NumRequiredSignatures) < numWritableUnsigned
848862
}
849-
return index < int(h.NumRequiredSignatures-h.NumReadonlySignedAccounts)
863+
return index < max(int(h.NumRequiredSignatures)-int(h.NumReadonlySignedAccounts), 0)
850864
}
851865

852866
func (m Message) IsWritable(account PublicKey) (bool, error) {
@@ -875,10 +889,13 @@ func (m Message) IsWritable(account PublicKey) (bool, error) {
875889
if index >= m.numStaticAccounts() {
876890
return m.isWritableInLookups(index), nil
877891
} else if index >= int(h.NumRequiredSignatures) {
878-
// unsignedAccountIndex < numWritableUnsignedAccounts
879-
return index-int(h.NumRequiredSignatures) < (m.numStaticAccounts()-int(h.NumRequiredSignatures))-int(h.NumReadonlyUnsignedAccounts), nil
892+
// Use int arithmetic to avoid underflow (Rust uses saturating_sub here).
893+
numWritableUnsigned := max(m.numStaticAccounts()-int(h.NumRequiredSignatures)-int(h.NumReadonlyUnsignedAccounts), 0)
894+
return index-int(h.NumRequiredSignatures) < numWritableUnsigned, nil
880895
}
881-
return index < int(h.NumRequiredSignatures-h.NumReadonlySignedAccounts), nil
896+
// Use int arithmetic to avoid uint8 underflow (Rust uses saturating_sub here).
897+
numWritableSigned := max(int(h.NumRequiredSignatures)-int(h.NumReadonlySignedAccounts), 0)
898+
return index < numWritableSigned, nil
882899
}
883900

884901
func (m Message) signerKeys() PublicKeySlice {

0 commit comments

Comments
 (0)