Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 36 additions & 19 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ const (
MessageVersionV0 MessageVersion = 1 // v0
)

// messageVersionPrefix is the high bit mask used to indicate a versioned message.
// If the first byte has this bit set, the message is versioned; the remaining
// 7 bits encode the version number (0 for V0, 1 for V1, etc.).
// See: https://github.com/anza-xyz/solana-sdk/blob/master/message/src/versions/mod.rs
const messageVersionPrefix = 0x80

type Message struct {
version MessageVersion
// List of base-58 encoded public keys used by the transaction,
Expand Down Expand Up @@ -343,11 +349,14 @@ func (mx *Message) MarshalV0() ([]byte, error) {
buf = append(buf, instruction.Data...)
}
}
versionNum := byte(mx.version) // TODO: what number is this?
if versionNum > 127 {
// The actual Solana version number is the Go enum value minus 1
// (MessageVersionV0=1 maps to Solana version 0).
// The wire prefix is messageVersionPrefix (0x80) OR'd with the version number.
solanaVersion := byte(mx.version - 1)
if solanaVersion > 0x7F {
return nil, fmt.Errorf("invalid message version: %d", mx.version)
}
buf = append([]byte{byte(versionNum + 127)}, buf...)
buf = append([]byte{messageVersionPrefix | solanaVersion}, buf...)

bin.EncodeCompactU16Length(&buf, len(mx.AddressTableLookups))
for _, lookup := range mx.AddressTableLookups {
Expand Down Expand Up @@ -383,8 +392,9 @@ func (mx *Message) UnmarshalWithDecoder(decoder *bin.Decoder) (err error) {
if err != nil {
return err
}
// TODO: is this the right way to determine if this is a legacy or v0 message?
if versionNum[0] < 127 {
// If the high bit (0x80) is set, this is a versioned message;
// otherwise it is a legacy message where this byte is numRequiredSignatures.
if versionNum[0]&messageVersionPrefix == 0 {
mx.version = MessageVersionLegacy
} else {
mx.version = MessageVersionV0
Expand Down Expand Up @@ -502,12 +512,15 @@ func (mx Message) getStaticKeys() (keys PublicKeySlice) {
}

func (mx *Message) UnmarshalV0(decoder *bin.Decoder) (err error) {
version, err := decoder.ReadByte()
prefix, err := decoder.ReadByte()
if err != nil {
return fmt.Errorf("failed to read message version: %w", err)
return fmt.Errorf("failed to read message version prefix: %w", err)
}
solanaVersion := prefix & 0x7F
if solanaVersion != 0 {
return fmt.Errorf("unsupported message version: %d", solanaVersion)
}
// TODO: check version
mx.version = MessageVersion(version - 127)
mx.version = MessageVersion(solanaVersion + 1) // map Solana version 0 → MessageVersionV0 (1)

// The middle of the message is the same as the legacy message:
err = mx.UnmarshalLegacy(decoder)
Expand All @@ -521,7 +534,7 @@ func (mx *Message) UnmarshalV0(decoder *bin.Decoder) (err error) {
}
if addressTableLookupsLen > 0 {
mx.AddressTableLookups = make([]MessageAddressTableLookup, addressTableLookupsLen)
for i := 0; i < int(addressTableLookupsLen); i++ {
for i := range addressTableLookupsLen {
// read account pubkey
_, err = decoder.Read(mx.AddressTableLookups[i].AccountKey[:])
if err != nil {
Expand Down Expand Up @@ -584,7 +597,7 @@ func (mx *Message) UnmarshalLegacy(decoder *bin.Decoder) (err error) {
return fmt.Errorf("numAccountKeys %d is too large for remaining bytes %d", numAccountKeys, decoder.Remaining())
}
mx.AccountKeys = make(PublicKeySlice, numAccountKeys)
for i := 0; i < numAccountKeys; i++ {
for i := range numAccountKeys {
_, err := decoder.Read(mx.AccountKeys[i][:])
if err != nil {
return fmt.Errorf("unable to decode mx.AccountKeys[%d]: %w", i, err)
Expand All @@ -606,7 +619,7 @@ func (mx *Message) UnmarshalLegacy(decoder *bin.Decoder) (err error) {
return fmt.Errorf("numInstructions %d is greater than remaining bytes %d", numInstructions, decoder.Remaining())
}
mx.Instructions = make([]CompiledInstruction, numInstructions)
for instructionIndex := 0; instructionIndex < numInstructions; instructionIndex++ {
for instructionIndex := range numInstructions {
programIDIndex, err := decoder.ReadUint8()
if err != nil {
return fmt.Errorf("unable to decode mx.Instructions[%d].ProgramIDIndex: %w", instructionIndex, err)
Expand All @@ -622,7 +635,7 @@ func (mx *Message) UnmarshalLegacy(decoder *bin.Decoder) (err error) {
return fmt.Errorf("ix[%v]: numAccounts %d is greater than remaining bytes %d", instructionIndex, numAccounts, decoder.Remaining())
}
mx.Instructions[instructionIndex].Accounts = make([]uint16, numAccounts)
for i := 0; i < numAccounts; i++ {
for i := range numAccounts {
accountIndex, err := decoder.ReadUint8()
if err != nil {
return fmt.Errorf("unable to decode accountIndex for ix[%d].Accounts[%d]: %w", instructionIndex, i, err)
Expand Down Expand Up @@ -843,10 +856,11 @@ func (m *Message) IsWritableStatic(account PublicKey) bool {
}
h := m.Header
if index >= int(h.NumRequiredSignatures) {
// unsignedAccountIndex < numWritableUnsignedAccounts
return index-int(h.NumRequiredSignatures) < (m.numStaticAccounts()-int(h.NumRequiredSignatures))-int(h.NumReadonlyUnsignedAccounts)
// Use int arithmetic to avoid underflow (Rust uses saturating_sub here).
numWritableUnsigned := max(m.numStaticAccounts()-int(h.NumRequiredSignatures)-int(h.NumReadonlyUnsignedAccounts), 0)
return index-int(h.NumRequiredSignatures) < numWritableUnsigned
}
return index < int(h.NumRequiredSignatures-h.NumReadonlySignedAccounts)
return index < max(int(h.NumRequiredSignatures)-int(h.NumReadonlySignedAccounts), 0)
}

func (m Message) IsWritable(account PublicKey) (bool, error) {
Expand Down Expand Up @@ -875,10 +889,13 @@ func (m Message) IsWritable(account PublicKey) (bool, error) {
if index >= m.numStaticAccounts() {
return m.isWritableInLookups(index), nil
} else if index >= int(h.NumRequiredSignatures) {
// unsignedAccountIndex < numWritableUnsignedAccounts
return index-int(h.NumRequiredSignatures) < (m.numStaticAccounts()-int(h.NumRequiredSignatures))-int(h.NumReadonlyUnsignedAccounts), nil
// Use int arithmetic to avoid underflow (Rust uses saturating_sub here).
numWritableUnsigned := max(m.numStaticAccounts()-int(h.NumRequiredSignatures)-int(h.NumReadonlyUnsignedAccounts), 0)
return index-int(h.NumRequiredSignatures) < numWritableUnsigned, nil
}
return index < int(h.NumRequiredSignatures-h.NumReadonlySignedAccounts), nil
// Use int arithmetic to avoid uint8 underflow (Rust uses saturating_sub here).
numWritableSigned := max(int(h.NumRequiredSignatures)-int(h.NumReadonlySignedAccounts), 0)
return index < numWritableSigned, nil
}

func (m Message) signerKeys() PublicKeySlice {
Expand Down
Loading
Loading