From 9ef2d244324a77289615327315a96e00db6dcaa3 Mon Sep 17 00:00:00 2001 From: Sonic Date: Thu, 9 Apr 2026 17:08:39 +0300 Subject: [PATCH] fix: message bugs --- message.go | 55 +- message_test.go | 709 ++++++++++++++++++ .../address-lookup-table/instruction_test.go | 2 +- rpc/client_test.go | 4 +- sanitize.go | 244 ++++++ sanitize_test.go | 673 +++++++++++++++++ transaction.go | 21 +- transaction_test.go | 557 +++++++++++++- transaction_v0_test.go | 4 +- 9 files changed, 2227 insertions(+), 42 deletions(-) create mode 100644 message_test.go create mode 100644 sanitize.go create mode 100644 sanitize_test.go diff --git a/message.go b/message.go index bdc3b994f..d44b61905 100644 --- a/message.go +++ b/message.go @@ -86,6 +86,12 @@ const ( MessageVersionV0 MessageVersion = 1 // v0 ) +// messageVersionPrefix is the high bit mask used to indicate a versioned message. +// If the first byte has this bit set, the message is versioned; the remaining +// 7 bits encode the version number (0 for V0, 1 for V1, etc.). +// See: https://github.com/anza-xyz/solana-sdk/blob/master/message/src/versions/mod.rs +const messageVersionPrefix = 0x80 + type Message struct { version MessageVersion // List of base-58 encoded public keys used by the transaction, @@ -343,11 +349,14 @@ func (mx *Message) MarshalV0() ([]byte, error) { buf = append(buf, instruction.Data...) } } - versionNum := byte(mx.version) // TODO: what number is this? - if versionNum > 127 { + // The actual Solana version number is the Go enum value minus 1 + // (MessageVersionV0=1 maps to Solana version 0). + // The wire prefix is messageVersionPrefix (0x80) OR'd with the version number. + solanaVersion := byte(mx.version - 1) + if solanaVersion > 0x7F { return nil, fmt.Errorf("invalid message version: %d", mx.version) } - buf = append([]byte{byte(versionNum + 127)}, buf...) + buf = append([]byte{messageVersionPrefix | solanaVersion}, buf...) bin.EncodeCompactU16Length(&buf, len(mx.AddressTableLookups)) for _, lookup := range mx.AddressTableLookups { @@ -383,8 +392,9 @@ func (mx *Message) UnmarshalWithDecoder(decoder *bin.Decoder) (err error) { if err != nil { return err } - // TODO: is this the right way to determine if this is a legacy or v0 message? - if versionNum[0] < 127 { + // If the high bit (0x80) is set, this is a versioned message; + // otherwise it is a legacy message where this byte is numRequiredSignatures. + if versionNum[0]&messageVersionPrefix == 0 { mx.version = MessageVersionLegacy } else { mx.version = MessageVersionV0 @@ -502,12 +512,15 @@ func (mx Message) getStaticKeys() (keys PublicKeySlice) { } func (mx *Message) UnmarshalV0(decoder *bin.Decoder) (err error) { - version, err := decoder.ReadByte() + prefix, err := decoder.ReadByte() if err != nil { - return fmt.Errorf("failed to read message version: %w", err) + return fmt.Errorf("failed to read message version prefix: %w", err) + } + solanaVersion := prefix & 0x7F + if solanaVersion != 0 { + return fmt.Errorf("unsupported message version: %d", solanaVersion) } - // TODO: check version - mx.version = MessageVersion(version - 127) + mx.version = MessageVersion(solanaVersion + 1) // map Solana version 0 → MessageVersionV0 (1) // The middle of the message is the same as the legacy message: err = mx.UnmarshalLegacy(decoder) @@ -521,7 +534,7 @@ func (mx *Message) UnmarshalV0(decoder *bin.Decoder) (err error) { } if addressTableLookupsLen > 0 { mx.AddressTableLookups = make([]MessageAddressTableLookup, addressTableLookupsLen) - for i := 0; i < int(addressTableLookupsLen); i++ { + for i := range addressTableLookupsLen { // read account pubkey _, err = decoder.Read(mx.AddressTableLookups[i].AccountKey[:]) if err != nil { @@ -584,7 +597,7 @@ func (mx *Message) UnmarshalLegacy(decoder *bin.Decoder) (err error) { return fmt.Errorf("numAccountKeys %d is too large for remaining bytes %d", numAccountKeys, decoder.Remaining()) } mx.AccountKeys = make(PublicKeySlice, numAccountKeys) - for i := 0; i < numAccountKeys; i++ { + for i := range numAccountKeys { _, err := decoder.Read(mx.AccountKeys[i][:]) if err != nil { return fmt.Errorf("unable to decode mx.AccountKeys[%d]: %w", i, err) @@ -606,7 +619,7 @@ func (mx *Message) UnmarshalLegacy(decoder *bin.Decoder) (err error) { return fmt.Errorf("numInstructions %d is greater than remaining bytes %d", numInstructions, decoder.Remaining()) } mx.Instructions = make([]CompiledInstruction, numInstructions) - for instructionIndex := 0; instructionIndex < numInstructions; instructionIndex++ { + for instructionIndex := range numInstructions { programIDIndex, err := decoder.ReadUint8() if err != nil { return fmt.Errorf("unable to decode mx.Instructions[%d].ProgramIDIndex: %w", instructionIndex, err) @@ -622,7 +635,7 @@ func (mx *Message) UnmarshalLegacy(decoder *bin.Decoder) (err error) { return fmt.Errorf("ix[%v]: numAccounts %d is greater than remaining bytes %d", instructionIndex, numAccounts, decoder.Remaining()) } mx.Instructions[instructionIndex].Accounts = make([]uint16, numAccounts) - for i := 0; i < numAccounts; i++ { + for i := range numAccounts { accountIndex, err := decoder.ReadUint8() if err != nil { return fmt.Errorf("unable to decode accountIndex for ix[%d].Accounts[%d]: %w", instructionIndex, i, err) @@ -843,10 +856,11 @@ func (m *Message) IsWritableStatic(account PublicKey) bool { } h := m.Header if index >= int(h.NumRequiredSignatures) { - // unsignedAccountIndex < numWritableUnsignedAccounts - return index-int(h.NumRequiredSignatures) < (m.numStaticAccounts()-int(h.NumRequiredSignatures))-int(h.NumReadonlyUnsignedAccounts) + // Use int arithmetic to avoid underflow (Rust uses saturating_sub here). + numWritableUnsigned := max(m.numStaticAccounts()-int(h.NumRequiredSignatures)-int(h.NumReadonlyUnsignedAccounts), 0) + return index-int(h.NumRequiredSignatures) < numWritableUnsigned } - return index < int(h.NumRequiredSignatures-h.NumReadonlySignedAccounts) + return index < max(int(h.NumRequiredSignatures)-int(h.NumReadonlySignedAccounts), 0) } func (m Message) IsWritable(account PublicKey) (bool, error) { @@ -875,10 +889,13 @@ func (m Message) IsWritable(account PublicKey) (bool, error) { if index >= m.numStaticAccounts() { return m.isWritableInLookups(index), nil } else if index >= int(h.NumRequiredSignatures) { - // unsignedAccountIndex < numWritableUnsignedAccounts - return index-int(h.NumRequiredSignatures) < (m.numStaticAccounts()-int(h.NumRequiredSignatures))-int(h.NumReadonlyUnsignedAccounts), nil + // Use int arithmetic to avoid underflow (Rust uses saturating_sub here). + numWritableUnsigned := max(m.numStaticAccounts()-int(h.NumRequiredSignatures)-int(h.NumReadonlyUnsignedAccounts), 0) + return index-int(h.NumRequiredSignatures) < numWritableUnsigned, nil } - return index < int(h.NumRequiredSignatures-h.NumReadonlySignedAccounts), nil + // Use int arithmetic to avoid uint8 underflow (Rust uses saturating_sub here). + numWritableSigned := max(int(h.NumRequiredSignatures)-int(h.NumReadonlySignedAccounts), 0) + return index < numWritableSigned, nil } func (m Message) signerKeys() PublicKeySlice { diff --git a/message_test.go b/message_test.go new file mode 100644 index 000000000..19cd77e3c --- /dev/null +++ b/message_test.go @@ -0,0 +1,709 @@ +package solana + +import ( + "testing" + + bin "github.com/gagliardetto/binary" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Ported from solana-sdk/message/src/versions/mod.rs: test_legacy_message_serialization +func TestLegacyMessageSerializationRoundtrip(t *testing.T) { + key0 := newUniqueKey() + key1 := newUniqueKey() + key2 := newUniqueKey() + blockhash := Hash{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32} + + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 2, + NumReadonlySignedAccounts: 1, + NumReadonlyUnsignedAccounts: 1, + }, + AccountKeys: PublicKeySlice{key0, key1, key2}, + RecentBlockhash: blockhash, + Instructions: []CompiledInstruction{ + { + ProgramIDIndex: 2, + Accounts: []uint16{0, 1}, + Data: []byte{0xAA, 0xBB}, + }, + }, + } + msg.version = MessageVersionLegacy + + data, err := msg.MarshalBinary() + require.NoError(t, err) + + // First byte should be numRequiredSignatures for legacy messages. + assert.Equal(t, byte(2), data[0], "first byte should be numRequiredSignatures") + + var decoded Message + err = decoded.UnmarshalWithDecoder(bin.NewBinDecoder(data)) + require.NoError(t, err) + + assert.Equal(t, MessageVersionLegacy, decoded.GetVersion()) + assert.Equal(t, msg.Header, decoded.Header) + assert.Equal(t, msg.AccountKeys, decoded.AccountKeys) + assert.Equal(t, msg.RecentBlockhash, decoded.RecentBlockhash) + require.Equal(t, len(msg.Instructions), len(decoded.Instructions)) + assert.Equal(t, msg.Instructions[0].ProgramIDIndex, decoded.Instructions[0].ProgramIDIndex) + assert.Equal(t, msg.Instructions[0].Accounts, decoded.Instructions[0].Accounts) + assert.Equal(t, Base58(msg.Instructions[0].Data), decoded.Instructions[0].Data) +} + +// Ported from solana-sdk/message/src/versions/mod.rs: test_versioned_message_serialization +func TestV0MessageSerializationRoundtrip(t *testing.T) { + key0 := newUniqueKey() + tableKey0 := newUniqueKey() + tableKey1 := newUniqueKey() + blockhash := Hash{9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22} + + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + NumReadonlySignedAccounts: 0, + NumReadonlyUnsignedAccounts: 0, + }, + AccountKeys: PublicKeySlice{key0}, + RecentBlockhash: blockhash, + Instructions: []CompiledInstruction{ + { + ProgramIDIndex: 1, + Accounts: []uint16{0, 2, 3, 4}, + Data: []byte{}, + }, + }, + AddressTableLookups: []MessageAddressTableLookup{ + { + AccountKey: tableKey0, + WritableIndexes: []uint8{1}, + ReadonlyIndexes: []uint8{0}, + }, + { + AccountKey: tableKey1, + WritableIndexes: []uint8{0}, + ReadonlyIndexes: []uint8{1}, + }, + }, + } + msg.version = MessageVersionV0 + + data, err := msg.MarshalBinary() + require.NoError(t, err) + + // First byte must have high bit set for versioned messages (0x80). + assert.Equal(t, byte(0x80), data[0], "first byte should be version prefix 0x80 for V0") + + var decoded Message + err = decoded.UnmarshalWithDecoder(bin.NewBinDecoder(data)) + require.NoError(t, err) + + assert.Equal(t, MessageVersionV0, decoded.GetVersion()) + assert.Equal(t, msg.Header, decoded.Header) + assert.Equal(t, msg.AccountKeys, decoded.AccountKeys) + assert.Equal(t, msg.RecentBlockhash, decoded.RecentBlockhash) + require.Equal(t, len(msg.Instructions), len(decoded.Instructions)) + assert.Equal(t, msg.Instructions[0].ProgramIDIndex, decoded.Instructions[0].ProgramIDIndex) + assert.Equal(t, msg.Instructions[0].Accounts, decoded.Instructions[0].Accounts) + + require.Equal(t, 2, len(decoded.AddressTableLookups)) + assert.Equal(t, tableKey0, decoded.AddressTableLookups[0].AccountKey) + assert.Equal(t, Uint8SliceAsNum{1}, decoded.AddressTableLookups[0].WritableIndexes) + assert.Equal(t, Uint8SliceAsNum{0}, decoded.AddressTableLookups[0].ReadonlyIndexes) + assert.Equal(t, tableKey1, decoded.AddressTableLookups[1].AccountKey) + assert.Equal(t, Uint8SliceAsNum{0}, decoded.AddressTableLookups[1].WritableIndexes) + assert.Equal(t, Uint8SliceAsNum{1}, decoded.AddressTableLookups[1].ReadonlyIndexes) +} + +// Tests the version prefix detection logic ported from solana-sdk/message/src/versions/mod.rs. +// In Rust: MESSAGE_VERSION_PREFIX = 0x80; if first_byte & 0x80 != 0 → versioned. +// This specifically tests the bug fix where byte value 127 (0x7F) was incorrectly +// classified as versioned — it should be legacy (numRequiredSignatures = 127). +func TestVersionDetection_PrefixByte(t *testing.T) { + tests := []struct { + name string + firstByte byte + expectedVersion MessageVersion + }{ + {"numRequiredSignatures=1 is legacy", 1, MessageVersionLegacy}, + {"numRequiredSignatures=64 is legacy", 64, MessageVersionLegacy}, + {"numRequiredSignatures=126 is legacy", 126, MessageVersionLegacy}, + {"numRequiredSignatures=127 is legacy", 127, MessageVersionLegacy}, // was buggy before fix + {"0x80 is V0", 0x80, MessageVersionV0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Build a minimal valid message with the given first byte. + // For legacy: first byte is numRequiredSignatures. + // For V0: first byte is 0x80 | version. + var buf []byte + if tt.expectedVersion == MessageVersionLegacy { + // legacy: header(3) + compact_len(1) + key(32) + blockhash(32) + compact_len(1)=0 instructions + buf = make([]byte, 0, 69) + buf = append(buf, tt.firstByte, 0, 0) // header + buf = append(buf, 1) // 1 account key + buf = append(buf, make([]byte, 32)...) // account key + buf = append(buf, make([]byte, 32)...) // blockhash + buf = append(buf, 0) // 0 instructions + } else { + // v0: prefix(1) + header(3) + compact_len(1) + key(32) + blockhash(32) + compact_len(1) + compact_len(1) lookups + buf = make([]byte, 0, 71) + buf = append(buf, tt.firstByte) // version prefix + buf = append(buf, 1, 0, 0) // header + buf = append(buf, 1) // 1 account key + buf = append(buf, make([]byte, 32)...) // account key + buf = append(buf, make([]byte, 32)...) // blockhash + buf = append(buf, 0) // 0 instructions + buf = append(buf, 0) // 0 address table lookups + } + + var msg Message + err := msg.UnmarshalWithDecoder(bin.NewBinDecoder(buf)) + require.NoError(t, err) + assert.Equal(t, tt.expectedVersion, msg.GetVersion()) + }) + } +} + +// Tests that unsupported version numbers (> 0) in versioned messages are rejected. +// Ported from solana-sdk/message/src/versions/v0/mod.rs version validation. +func TestVersionDetection_UnsupportedVersion(t *testing.T) { + // 0x81 = messageVersionPrefix | 1 → version 1 (unsupported) + buf := []byte{0x81, 1, 0, 0} + buf = append(buf, 1) // 1 account key + buf = append(buf, make([]byte, 32)...) // account key + buf = append(buf, make([]byte, 32)...) // blockhash + buf = append(buf, 0) // 0 instructions + buf = append(buf, 0) // 0 lookups + + var msg Message + err := msg.UnmarshalWithDecoder(bin.NewBinDecoder(buf)) + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported message version") +} + +// Ported from solana-sdk/message/src/legacy.rs: test_is_writable_index_saturating_behavior. +// Tests edge cases where header values exceed the number of account keys. +func TestIsWritable_SaturatingBehavior(t *testing.T) { + // Case 1: num_readonly_signed (2) > num_required_signatures (1) + // Index 0 is signed but readonly count exceeds signature count → not writable. + key0 := newUniqueKey() + msg1 := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + NumReadonlySignedAccounts: 2, + NumReadonlyUnsignedAccounts: 0, + }, + AccountKeys: PublicKeySlice{key0}, + } + w, err := msg1.IsWritable(key0) + require.NoError(t, err) + assert.False(t, w, "case 1: readonly signed exceeds required signatures") + + // Case 2: num_readonly_unsigned (2) > num unsigned accounts (1) + // Only 1 account, 0 signers, all are unsigned but readonly count exceeds → not writable. + key1 := newUniqueKey() + msg2 := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 0, + NumReadonlySignedAccounts: 0, + NumReadonlyUnsignedAccounts: 2, + }, + AccountKeys: PublicKeySlice{key1}, + } + w, err = msg2.IsWritable(key1) + require.NoError(t, err) + assert.False(t, w, "case 2: readonly unsigned exceeds unsigned accounts") + + // Case 3: 1 signer, 0 readonly signed, 2 readonly unsigned but only 1 account. + // Index 0 is a writable signer. + msg3 := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + NumReadonlySignedAccounts: 0, + NumReadonlyUnsignedAccounts: 2, + }, + AccountKeys: PublicKeySlice{key0}, + } + w, err = msg3.IsWritable(key0) + require.NoError(t, err) + assert.True(t, w, "case 3: signer with no readonly signed is writable") + + // Case 4: 2 accounts, 1 signer, 0 readonly signed, 3 readonly unsigned. + // Index 0: writable signer; index 1: unsigned but readonly exceeds → not writable. + key2 := newUniqueKey() + msg4 := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + NumReadonlySignedAccounts: 0, + NumReadonlyUnsignedAccounts: 3, + }, + AccountKeys: PublicKeySlice{key0, key2}, + } + w, err = msg4.IsWritable(key0) + require.NoError(t, err) + assert.True(t, w, "case 4: key0 writable signer") + w, err = msg4.IsWritable(key2) + require.NoError(t, err) + assert.False(t, w, "case 4: key1 readonly unsigned") + + // Case 5: 2 accounts, 1 signer with 2 readonly signed, 3 readonly unsigned. + // Both accounts should be readonly. + msg5 := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + NumReadonlySignedAccounts: 2, + NumReadonlyUnsignedAccounts: 3, + }, + AccountKeys: PublicKeySlice{key0, key2}, + } + w, err = msg5.IsWritable(key0) + require.NoError(t, err) + assert.False(t, w, "case 5: key0 readonly (readonly_signed exceeds required)") + w, err = msg5.IsWritable(key2) + require.NoError(t, err) + assert.False(t, w, "case 5: key1 readonly unsigned") +} + +// Ported from solana-sdk/message/src/legacy.rs: test_is_maybe_writable. +// Tests the standard writability layout: +// +// Header: 3 signers (2 readonly), 1 readonly unsigned → 6 accounts total. +// idx 0: writable signer +// idx 1: readonly signer +// idx 2: readonly signer +// idx 3: writable unsigned +// idx 4: writable unsigned +// idx 5: readonly unsigned +func TestIsWritable_StandardLayout(t *testing.T) { + keys := [6]PublicKey{} + for i := range keys { + keys[i] = newUniqueKey() + } + + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 3, + NumReadonlySignedAccounts: 2, + NumReadonlyUnsignedAccounts: 1, + }, + AccountKeys: PublicKeySlice{keys[0], keys[1], keys[2], keys[3], keys[4], keys[5]}, + } + + expected := []bool{true, false, false, true, true, false} + for i, key := range keys { + w, err := msg.IsWritable(key) + require.NoError(t, err) + assert.Equal(t, expected[i], w, "index %d", i) + } +} + +// Ported from solana-sdk/message/src/legacy.rs: test_message_signed_keys_len. +func TestIsSigner(t *testing.T) { + key0 := newUniqueKey() + key1 := newUniqueKey() + programID := newUniqueKey() + + // No signers. + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 0, + }, + AccountKeys: PublicKeySlice{key0, programID}, + } + assert.False(t, msg.IsSigner(key0)) + + // One signer. + msg = Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + }, + AccountKeys: PublicKeySlice{key0, programID}, + } + assert.True(t, msg.IsSigner(key0)) + assert.False(t, msg.IsSigner(programID)) + + // Two signers. + msg = Message{ + Header: MessageHeader{ + NumRequiredSignatures: 2, + }, + AccountKeys: PublicKeySlice{key0, key1, programID}, + } + assert.True(t, msg.IsSigner(key0)) + assert.True(t, msg.IsSigner(key1)) + assert.False(t, msg.IsSigner(programID)) + + // Unknown key is not a signer. + assert.False(t, msg.IsSigner(newUniqueKey())) +} + +// Ported from solana-sdk/message/src/legacy.rs: test_program_ids. +func TestProgramIDs(t *testing.T) { + key0 := newUniqueKey() + key1 := newUniqueKey() + programID := newUniqueKey() + + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + NumReadonlySignedAccounts: 0, + NumReadonlyUnsignedAccounts: 1, // programID at index 2 is readonly unsigned + }, + AccountKeys: PublicKeySlice{key0, key1, programID}, + Instructions: []CompiledInstruction{ + { + ProgramIDIndex: 2, + Accounts: []uint16{0, 1}, + Data: []byte{}, + }, + }, + } + + resolved, err := msg.Program(2) + require.NoError(t, err) + assert.Equal(t, programID, resolved) + + _, err = msg.Program(3) // out of range + require.Error(t, err) +} + +// Tests that IsWritableStatic only considers static accounts, ignoring lookups. +func TestIsWritableStatic_IgnoresLookups(t *testing.T) { + keys := [4]PublicKey{} + for i := range keys { + keys[i] = newUniqueKey() + } + + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 2, + NumReadonlySignedAccounts: 1, + NumReadonlyUnsignedAccounts: 1, + }, + AccountKeys: PublicKeySlice{keys[0], keys[1], keys[2], keys[3]}, + } + msg.version = MessageVersionV0 + + assert.True(t, msg.IsWritableStatic(keys[0]), "writable signer") + assert.False(t, msg.IsWritableStatic(keys[1]), "readonly signer") + assert.True(t, msg.IsWritableStatic(keys[2]), "writable unsigned") + assert.False(t, msg.IsWritableStatic(keys[3]), "readonly unsigned") + assert.False(t, msg.IsWritableStatic(newUniqueKey()), "unknown key") +} + +// Tests JSON serialization roundtrip for both legacy and V0 messages. +func TestMessageJSONRoundtrip(t *testing.T) { + t.Run("legacy", func(t *testing.T) { + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + NumReadonlySignedAccounts: 0, + NumReadonlyUnsignedAccounts: 1, + }, + AccountKeys: PublicKeySlice{newUniqueKey(), newUniqueKey()}, + RecentBlockhash: Hash{1, 2, 3}, + Instructions: []CompiledInstruction{ + {ProgramIDIndex: 1, Accounts: []uint16{0}, Data: []byte{0xFF}}, + }, + } + msg.version = MessageVersionLegacy + + data, err := msg.MarshalJSON() + require.NoError(t, err) + require.NotEmpty(t, data) + + // Should not contain addressTableLookups for legacy. + assert.NotContains(t, string(data), "addressTableLookups") + }) + + t.Run("v0", func(t *testing.T) { + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + NumReadonlySignedAccounts: 0, + NumReadonlyUnsignedAccounts: 0, + }, + AccountKeys: PublicKeySlice{newUniqueKey()}, + RecentBlockhash: Hash{4, 5, 6}, + Instructions: []CompiledInstruction{ + {ProgramIDIndex: 0, Accounts: []uint16{0}, Data: []byte{0x01}}, + }, + AddressTableLookups: []MessageAddressTableLookup{ + { + AccountKey: newUniqueKey(), + WritableIndexes: []uint8{0}, + ReadonlyIndexes: []uint8{1}, + }, + }, + } + msg.version = MessageVersionV0 + + data, err := msg.MarshalJSON() + require.NoError(t, err) + assert.Contains(t, string(data), "addressTableLookups") + }) +} + +// Tests that the V0 prefix byte is exactly 0x80 for version 0, +// matching Rust's MESSAGE_VERSION_PREFIX | 0 = 0x80. +func TestMarshalV0_PrefixByte(t *testing.T) { + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + NumReadonlySignedAccounts: 0, + NumReadonlyUnsignedAccounts: 0, + }, + AccountKeys: PublicKeySlice{newUniqueKey()}, + RecentBlockhash: Hash{}, + Instructions: []CompiledInstruction{}, + } + msg.version = MessageVersionV0 + + data, err := msg.MarshalBinary() + require.NoError(t, err) + assert.Equal(t, byte(0x80), data[0]) + + // Second byte should be numRequiredSignatures. + assert.Equal(t, byte(1), data[1]) +} + +// Tests that legacy message does NOT have the 0x80 prefix. +func TestMarshalLegacy_NoPrefixByte(t *testing.T) { + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 3, + NumReadonlySignedAccounts: 1, + NumReadonlyUnsignedAccounts: 1, + }, + AccountKeys: PublicKeySlice{newUniqueKey(), newUniqueKey(), newUniqueKey()}, + RecentBlockhash: Hash{}, + Instructions: []CompiledInstruction{}, + } + msg.version = MessageVersionLegacy + + data, err := msg.MarshalBinary() + require.NoError(t, err) + + // First byte is numRequiredSignatures directly, not a version prefix. + assert.Equal(t, byte(3), data[0]) + assert.Equal(t, byte(0), data[0]&0x80, "high bit should not be set for legacy") +} + +// Tests Account() method for both static and resolved lookup accounts. +func TestAccount_StaticAndLookup(t *testing.T) { + keys := [4]PublicKey{} + for i := range keys { + keys[i] = newUniqueKey() + } + tableKey := newUniqueKey() + lookupKey := newUniqueKey() + + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + NumReadonlySignedAccounts: 0, + NumReadonlyUnsignedAccounts: 0, + }, + AccountKeys: PublicKeySlice{keys[0], keys[1]}, + AddressTableLookups: []MessageAddressTableLookup{ + { + AccountKey: tableKey, + WritableIndexes: []uint8{0}, + ReadonlyIndexes: []uint8{}, + }, + }, + } + msg.version = MessageVersionV0 + err := msg.SetAddressTables(map[PublicKey]PublicKeySlice{ + tableKey: {lookupKey}, + }) + require.NoError(t, err) + + // Static account. + acct, err := msg.Account(0) + require.NoError(t, err) + assert.Equal(t, keys[0], acct) + + acct, err = msg.Account(1) + require.NoError(t, err) + assert.Equal(t, keys[1], acct) + + // Lookup account (index 2 = first lookup). + acct, err = msg.Account(2) + require.NoError(t, err) + assert.Equal(t, lookupKey, acct) + + // Out of range. + _, err = msg.Account(3) + require.Error(t, err) +} + +// Tests HasAccount and GetAccountIndex. +func TestHasAccountAndGetIndex(t *testing.T) { + keys := [3]PublicKey{} + for i := range keys { + keys[i] = newUniqueKey() + } + + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + }, + AccountKeys: PublicKeySlice{keys[0], keys[1], keys[2]}, + } + + for i, key := range keys { + has, err := msg.HasAccount(key) + require.NoError(t, err) + assert.True(t, has) + + idx, err := msg.GetAccountIndex(key) + require.NoError(t, err) + assert.Equal(t, uint16(i), idx) + } + + unknown := newUniqueKey() + has, err := msg.HasAccount(unknown) + require.NoError(t, err) + assert.False(t, has) + + _, err = msg.GetAccountIndex(unknown) + require.Error(t, err) +} + +// Tests Signers() returns only the first numRequiredSignatures accounts. +func TestSigners(t *testing.T) { + keys := [5]PublicKey{} + for i := range keys { + keys[i] = newUniqueKey() + } + + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 3, + NumReadonlySignedAccounts: 1, + NumReadonlyUnsignedAccounts: 0, + }, + AccountKeys: PublicKeySlice{keys[0], keys[1], keys[2], keys[3], keys[4]}, + } + + signers := msg.Signers() + require.Equal(t, 3, len(signers)) + assert.Equal(t, keys[0], signers[0]) + assert.Equal(t, keys[1], signers[1]) + assert.Equal(t, keys[2], signers[2]) +} + +// Tests Writable() returns only writable accounts across static and lookup. +func TestWritable(t *testing.T) { + keys := [4]PublicKey{} + for i := range keys { + keys[i] = newUniqueKey() + } + + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 2, + NumReadonlySignedAccounts: 1, + NumReadonlyUnsignedAccounts: 1, + }, + AccountKeys: PublicKeySlice{keys[0], keys[1], keys[2], keys[3]}, + } + + writable, err := msg.Writable() + require.NoError(t, err) + // idx 0: writable signer, idx 1: readonly signer, idx 2: writable unsigned, idx 3: readonly unsigned + require.Equal(t, 2, len(writable)) + assert.Equal(t, keys[0], writable[0]) + assert.Equal(t, keys[2], writable[1]) +} + +// Tests that SetVersion validates input. +func TestSetVersion_Validation(t *testing.T) { + msg := &Message{} + + _, err := msg.SetVersion(MessageVersionLegacy) + require.NoError(t, err) + assert.Equal(t, MessageVersionLegacy, msg.GetVersion()) + + _, err = msg.SetVersion(MessageVersionV0) + require.NoError(t, err) + assert.Equal(t, MessageVersionV0, msg.GetVersion()) + + _, err = msg.SetVersion(MessageVersion(99)) + require.Error(t, err) +} + +// Tests IsVersioned(). +func TestIsVersioned(t *testing.T) { + msg := Message{} + msg.version = MessageVersionLegacy + assert.False(t, msg.IsVersioned()) + + msg.version = MessageVersionV0 + assert.True(t, msg.IsVersioned()) +} + +// Tests that SetAddressTables can only be called once. +func TestSetAddressTables_OnlyOnce(t *testing.T) { + msg := &Message{} + msg.version = MessageVersionV0 + + err := msg.SetAddressTables(map[PublicKey]PublicKeySlice{}) + require.NoError(t, err) + + err = msg.SetAddressTables(map[PublicKey]PublicKeySlice{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "already set") +} + +// Tests that checkPreconditions fails when address tables are needed but not set. +func TestCheckPreconditions_MissingTables(t *testing.T) { + msg := Message{ + AddressTableLookups: []MessageAddressTableLookup{ + { + AccountKey: newUniqueKey(), + WritableIndexes: []uint8{0}, + ReadonlyIndexes: []uint8{}, + }, + }, + } + msg.version = MessageVersionV0 + + _, err := msg.AccountMetaList() + require.Error(t, err) + assert.Contains(t, err.Error(), "without address tables") +} + +// Tests base64 roundtrip. +func TestMarshalUnmarshalBase64(t *testing.T) { + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + NumReadonlySignedAccounts: 0, + NumReadonlyUnsignedAccounts: 1, + }, + AccountKeys: PublicKeySlice{newUniqueKey(), newUniqueKey()}, + RecentBlockhash: Hash{42}, + Instructions: []CompiledInstruction{ + {ProgramIDIndex: 1, Accounts: []uint16{0}, Data: []byte{0xDE, 0xAD}}, + }, + } + msg.version = MessageVersionLegacy + + b64 := msg.ToBase64() + require.NotEmpty(t, b64) + + var decoded Message + err := decoded.UnmarshalBase64(b64) + require.NoError(t, err) + assert.Equal(t, msg.Header, decoded.Header) + assert.Equal(t, msg.AccountKeys, decoded.AccountKeys) + assert.Equal(t, msg.RecentBlockhash, decoded.RecentBlockhash) +} diff --git a/programs/address-lookup-table/instruction_test.go b/programs/address-lookup-table/instruction_test.go index da04c3287..a549d293c 100644 --- a/programs/address-lookup-table/instruction_test.go +++ b/programs/address-lookup-table/instruction_test.go @@ -87,7 +87,7 @@ func TestEncodingInstruction(t *testing.T) { } func TestEncodeExtendLookupTable(t *testing.T) { - addr1 := solana.MustPublicKeyFromBase58("11111111111111111111111111111111") + addr1 := solana.SystemProgramID addr2 := solana.MustPublicKeyFromBase58("TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA") lookupTable := solana.NewWallet().PublicKey() diff --git a/rpc/client_test.go b/rpc/client_test.go index a4d069d1f..cb01a110c 100644 --- a/rpc/client_test.go +++ b/rpc/client_test.go @@ -73,7 +73,7 @@ func TestClient_GetAccountInfo(t *testing.T) { }, Value: &Account{ Lamports: 999999, - Owner: solana.MustPublicKeyFromBase58("11111111111111111111111111111111"), + Owner: solana.SystemProgramID, Data: &DataBytesOrJSON{ rawDataEncoding: solana.EncodingBase64, asDecodedBinary: solana.Data{ @@ -1444,7 +1444,7 @@ func TestClient_GetMultipleAccounts(t *testing.T) { Value: []*Account{ { Lamports: 19039980000, - Owner: solana.MustPublicKeyFromBase58("11111111111111111111111111111111"), + Owner: solana.SystemProgramID, Data: &DataBytesOrJSON{ asDecodedBinary: solana.Data{ Content: []byte{}, diff --git a/sanitize.go b/sanitize.go new file mode 100644 index 000000000..789b3524e --- /dev/null +++ b/sanitize.go @@ -0,0 +1,244 @@ +package solana + +import ( + "errors" + "fmt" +) + +// sanitizeError represents a message or transaction validation error. +type sanitizeError struct { + msg string +} + +func (e *sanitizeError) Error() string { + return e.msg +} + +func newSanitizeError(format string, args ...any) error { + return &sanitizeError{msg: fmt.Sprintf(format, args...)} +} + +// IsSanitizeError reports whether err is a sanitization validation error. +func IsSanitizeError(err error) bool { + var se *sanitizeError + return errors.As(err, &se) +} + +// maxAccountKeys is the maximum number of accounts a message can reference. +// Account indices are encoded as u8, so the limit is 256. +const maxAccountKeys = 256 + +// Sanitize validates the structural integrity of a Message. +// Ported from solana-sdk/message: legacy.rs sanitize() and v0/mod.rs sanitize(). +func (m *Message) Sanitize() error { + if m.IsVersioned() { + return m.sanitizeV0() + } + return m.sanitizeLegacy() +} + +func (m *Message) sanitizeLegacy() error { + numKeys := len(m.AccountKeys) + + // Signing area and read-only non-signing area should not overlap. + if int(m.Header.NumRequiredSignatures)+int(m.Header.NumReadonlyUnsignedAccounts) > numKeys { + return newSanitizeError("header references more accounts than available: required_signatures(%d) + readonly_unsigned(%d) > account_keys(%d)", + m.Header.NumRequiredSignatures, m.Header.NumReadonlyUnsignedAccounts, numKeys) + } + + // There should be at least 1 RW fee-payer account. + if m.Header.NumReadonlySignedAccounts >= m.Header.NumRequiredSignatures { + return newSanitizeError("no writable signer: readonly_signed(%d) >= required_signatures(%d)", + m.Header.NumReadonlySignedAccounts, m.Header.NumRequiredSignatures) + } + + for i, ci := range m.Instructions { + if int(ci.ProgramIDIndex) >= numKeys { + return newSanitizeError("instruction %d: program_id_index %d out of bounds (account_keys len %d)", i, ci.ProgramIDIndex, numKeys) + } + // A program cannot be the payer. + if ci.ProgramIDIndex == 0 { + return newSanitizeError("instruction %d: program_id_index cannot be 0 (fee payer)", i) + } + for _, ai := range ci.Accounts { + if int(ai) >= numKeys { + return newSanitizeError("instruction %d: account index %d out of bounds (account_keys len %d)", i, ai, numKeys) + } + } + } + + return nil +} + +func (m *Message) sanitizeV0() error { + numStaticKeys := len(m.AccountKeys) + + // Signing area and read-only non-signing area should not overlap. + if int(m.Header.NumRequiredSignatures)+int(m.Header.NumReadonlyUnsignedAccounts) > numStaticKeys { + return newSanitizeError("header references more accounts than available: required_signatures(%d) + readonly_unsigned(%d) > static_keys(%d)", + m.Header.NumRequiredSignatures, m.Header.NumReadonlyUnsignedAccounts, numStaticKeys) + } + + // There should be at least 1 RW fee-payer account. + if m.Header.NumReadonlySignedAccounts >= m.Header.NumRequiredSignatures { + return newSanitizeError("no writable signer: readonly_signed(%d) >= required_signatures(%d)", + m.Header.NumReadonlySignedAccounts, m.Header.NumRequiredSignatures) + } + + // Count dynamic keys from address table lookups. + numDynamicKeys := 0 + for _, lookup := range m.AddressTableLookups { + numLookupIndexes := len(lookup.WritableIndexes) + len(lookup.ReadonlyIndexes) + // Each lookup table must be used to load at least one account. + if numLookupIndexes == 0 { + return newSanitizeError("address table lookup for %s loads no accounts", lookup.AccountKey) + } + numDynamicKeys += numLookupIndexes + } + + if numStaticKeys == 0 { + return newSanitizeError("message has no account keys") + } + + // The combined number of static and dynamic account keys must be <= 256 + // since account indices are encoded as u8. + totalKeys := numStaticKeys + numDynamicKeys + if totalKeys > maxAccountKeys { + return newSanitizeError("total account keys %d exceeds maximum %d", totalKeys, maxAccountKeys) + } + + maxAccountIdx := totalKeys - 1 + // Program IDs must be in static keys only (not from lookup tables). + maxProgramIdx := numStaticKeys - 1 + + for i, ci := range m.Instructions { + if int(ci.ProgramIDIndex) > maxProgramIdx { + return newSanitizeError("instruction %d: program_id_index %d exceeds static keys (max %d)", i, ci.ProgramIDIndex, maxProgramIdx) + } + // A program cannot be the payer. + if ci.ProgramIDIndex == 0 { + return newSanitizeError("instruction %d: program_id_index cannot be 0 (fee payer)", i) + } + for _, ai := range ci.Accounts { + if int(ai) > maxAccountIdx { + return newSanitizeError("instruction %d: account index %d out of bounds (max %d)", i, ai, maxAccountIdx) + } + } + } + + return nil +} + +// HasDuplicates checks if the message has duplicate account keys. +// Uses O(n^2) comparison but requires no heap allocation, which is faster +// for the typically small number of accounts in a message. +// Ported from solana-sdk/message/legacy.rs has_duplicates(). +func (m *Message) HasDuplicates() bool { + keys := m.AccountKeys + for i := 1; i < len(keys); i++ { + for j := i; j < len(keys); j++ { + if keys[i-1].Equals(keys[j]) { + return true + } + } + } + return false +} + +// Sanitize validates the structural integrity of a Transaction. +// It checks that the signature count matches the message header and +// that the message itself is valid. +// Ported from solana-sdk/transaction: lib.rs and versioned/mod.rs sanitize(). +func (tx *Transaction) Sanitize() error { + numSigs := len(tx.Signatures) + numRequired := int(tx.Message.Header.NumRequiredSignatures) + numStaticKeys := len(tx.Message.AccountKeys) + + // Signature count must exactly match num_required_signatures. + if numRequired > numSigs { + return newSanitizeError("not enough signatures: required %d, got %d", numRequired, numSigs) + } + if numRequired < numSigs { + return newSanitizeError("too many signatures: required %d, got %d", numRequired, numSigs) + } + + // Signatures must not exceed static account keys count + // (signatures are verified before lookup keys are loaded). + if numSigs > numStaticKeys { + return newSanitizeError("more signatures (%d) than static account keys (%d)", numSigs, numStaticKeys) + } + + return tx.Message.Sanitize() +} + +// VerifyWithResults verifies each signature independently and returns +// a per-signature boolean result. +// Ported from solana-sdk/transaction/lib.rs verify_with_results(). +func (tx *Transaction) VerifyWithResults() ([]bool, error) { + msg, err := tx.Message.MarshalBinary() + if err != nil { + return nil, err + } + + results := make([]bool, len(tx.Signatures)) + for i, sig := range tx.Signatures { + if i < len(tx.Message.AccountKeys) { + results[i] = sig.Verify(tx.Message.AccountKeys[i], msg) + } + } + return results, nil +} + +// isAdvanceNonceInstructionData checks if the instruction data starts with +// the AdvanceNonceAccount discriminant (u32 LE value 4). +func isAdvanceNonceInstructionData(data []byte) bool { + return len(data) >= 4 && data[0] == 4 && data[1] == 0 && data[2] == 0 && data[3] == 0 +} + +// nonceAdvanceInstruction returns the first instruction if it is a +// System Program AdvanceNonceAccount instruction, or nil otherwise. +func (tx *Transaction) nonceAdvanceInstruction() *CompiledInstruction { + if len(tx.Message.Instructions) == 0 { + return nil + } + ix := &tx.Message.Instructions[0] + + // Check that the program is the System Program. + if int(ix.ProgramIDIndex) >= len(tx.Message.AccountKeys) { + return nil + } + if !tx.Message.AccountKeys[ix.ProgramIDIndex].Equals(SystemProgramID) { + return nil + } + if !isAdvanceNonceInstructionData(ix.Data) { + return nil + } + return ix +} + +// UsesDurableNonce checks whether this transaction uses a durable nonce +// by inspecting the first instruction. Returns true if the first instruction +// is a System Program AdvanceNonceAccount instruction. +// Ported from solana-sdk/transaction: uses_durable_nonce(). +func (tx *Transaction) UsesDurableNonce() bool { + return tx.nonceAdvanceInstruction() != nil +} + +// GetNonceAccount returns the public key of the nonce account if this +// transaction uses a durable nonce. The nonce account is the first account +// of the first instruction (the AdvanceNonceAccount instruction). +// Returns the zero PublicKey and false if this is not a nonce transaction. +func (tx *Transaction) GetNonceAccount() (PublicKey, bool) { + ix := tx.nonceAdvanceInstruction() + if ix == nil { + return PublicKey{}, false + } + if len(ix.Accounts) == 0 { + return PublicKey{}, false + } + nonceAccountIdx := ix.Accounts[0] + if int(nonceAccountIdx) >= len(tx.Message.AccountKeys) { + return PublicKey{}, false + } + return tx.Message.AccountKeys[nonceAccountIdx], true +} diff --git a/sanitize_test.go b/sanitize_test.go new file mode 100644 index 000000000..69ae35825 --- /dev/null +++ b/sanitize_test.go @@ -0,0 +1,673 @@ +package solana + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- Message Sanitize Tests --- +// Ported from solana-sdk/message/src/legacy.rs and versions/v0/mod.rs + +func TestMessageSanitize_LegacyValid(t *testing.T) { + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + NumReadonlySignedAccounts: 0, + NumReadonlyUnsignedAccounts: 1, + }, + AccountKeys: PublicKeySlice{newUniqueKey(), newUniqueKey()}, + Instructions: []CompiledInstruction{ + {ProgramIDIndex: 1, Accounts: []uint16{0}, Data: []byte{}}, + }, + } + msg.version = MessageVersionLegacy + require.NoError(t, msg.Sanitize()) +} + +// Ported from legacy.rs: test_sanitize_txs (signing area + readonly overlap). +func TestMessageSanitize_Legacy_HeaderOverflow(t *testing.T) { + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 2, + NumReadonlySignedAccounts: 0, + NumReadonlyUnsignedAccounts: 2, + }, + // Only 3 keys, but header needs 2+2=4. + AccountKeys: PublicKeySlice{newUniqueKey(), newUniqueKey(), newUniqueKey()}, + } + msg.version = MessageVersionLegacy + require.Error(t, msg.Sanitize()) +} + +// Ported from v0/mod.rs: test_sanitize_without_writable_signer. +func TestMessageSanitize_NoWritableSigner(t *testing.T) { + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + NumReadonlySignedAccounts: 1, // all signers are readonly + NumReadonlyUnsignedAccounts: 0, + }, + AccountKeys: PublicKeySlice{newUniqueKey()}, + } + msg.version = MessageVersionLegacy + err := msg.Sanitize() + require.Error(t, err) + assert.Contains(t, err.Error(), "no writable signer") +} + +// Ported from v0/mod.rs: test_sanitize_without_signer. +func TestMessageSanitize_NoSigner(t *testing.T) { + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 0, + }, + AccountKeys: PublicKeySlice{newUniqueKey()}, + } + msg.version = MessageVersionLegacy + err := msg.Sanitize() + require.Error(t, err) +} + +// Ported from legacy.rs: program_id_index out of bounds. +func TestMessageSanitize_Legacy_InvalidProgramID(t *testing.T) { + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + NumReadonlySignedAccounts: 0, + NumReadonlyUnsignedAccounts: 0, + }, + AccountKeys: PublicKeySlice{newUniqueKey()}, + Instructions: []CompiledInstruction{ + {ProgramIDIndex: 5, Accounts: []uint16{0}, Data: []byte{}}, // out of bounds + }, + } + msg.version = MessageVersionLegacy + err := msg.Sanitize() + require.Error(t, err) + assert.Contains(t, err.Error(), "program_id_index") +} + +// Ported from v0/mod.rs: test_sanitize_with_instruction — program at index 0 (payer) is invalid. +func TestMessageSanitize_ProgramIsPayer(t *testing.T) { + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + NumReadonlySignedAccounts: 0, + NumReadonlyUnsignedAccounts: 0, + }, + AccountKeys: PublicKeySlice{newUniqueKey(), newUniqueKey()}, + Instructions: []CompiledInstruction{ + {ProgramIDIndex: 0, Accounts: []uint16{1}, Data: []byte{}}, + }, + } + msg.version = MessageVersionLegacy + err := msg.Sanitize() + require.Error(t, err) + assert.Contains(t, err.Error(), "fee payer") +} + +// Ported from legacy.rs: account index out of bounds. +func TestMessageSanitize_Legacy_InvalidAccountIndex(t *testing.T) { + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + NumReadonlySignedAccounts: 0, + NumReadonlyUnsignedAccounts: 0, + }, + AccountKeys: PublicKeySlice{newUniqueKey(), newUniqueKey()}, + Instructions: []CompiledInstruction{ + {ProgramIDIndex: 1, Accounts: []uint16{5}, Data: []byte{}}, // out of bounds + }, + } + msg.version = MessageVersionLegacy + err := msg.Sanitize() + require.Error(t, err) + assert.Contains(t, err.Error(), "account index") +} + +// --- V0 Sanitize Tests --- + +// Ported from v0/mod.rs: test_sanitize. +func TestMessageSanitize_V0_Valid(t *testing.T) { + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + NumReadonlySignedAccounts: 0, + NumReadonlyUnsignedAccounts: 0, + }, + AccountKeys: PublicKeySlice{newUniqueKey()}, + } + msg.version = MessageVersionV0 + require.NoError(t, msg.Sanitize()) +} + +// Ported from v0/mod.rs: test_sanitize_with_table_lookup. +func TestMessageSanitize_V0_WithTableLookup(t *testing.T) { + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + }, + AccountKeys: PublicKeySlice{newUniqueKey()}, + AddressTableLookups: []MessageAddressTableLookup{ + { + AccountKey: newUniqueKey(), + WritableIndexes: []uint8{1, 2, 3}, + ReadonlyIndexes: []uint8{0}, + }, + }, + } + msg.version = MessageVersionV0 + require.NoError(t, msg.Sanitize()) +} + +// Ported from v0/mod.rs: test_sanitize_with_empty_table_lookup. +func TestMessageSanitize_V0_EmptyTableLookup(t *testing.T) { + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + }, + AccountKeys: PublicKeySlice{newUniqueKey()}, + AddressTableLookups: []MessageAddressTableLookup{ + { + AccountKey: newUniqueKey(), + WritableIndexes: []uint8{}, + ReadonlyIndexes: []uint8{}, + }, + }, + } + msg.version = MessageVersionV0 + err := msg.Sanitize() + require.Error(t, err) + assert.Contains(t, err.Error(), "loads no accounts") +} + +// Ported from v0/mod.rs: test_sanitize_with_max_account_keys (256 = ok). +func TestMessageSanitize_V0_MaxAccountKeys(t *testing.T) { + keys := make(PublicKeySlice, 256) + for i := range keys { + keys[i] = newUniqueKey() + } + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + }, + AccountKeys: keys, + } + msg.version = MessageVersionV0 + require.NoError(t, msg.Sanitize()) +} + +// Ported from v0/mod.rs: test_sanitize_with_too_many_account_keys (257 = error). +func TestMessageSanitize_V0_TooManyAccountKeys(t *testing.T) { + keys := make(PublicKeySlice, 257) + for i := range keys { + keys[i] = newUniqueKey() + } + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + }, + AccountKeys: keys, + } + msg.version = MessageVersionV0 + err := msg.Sanitize() + require.Error(t, err) + assert.Contains(t, err.Error(), "exceeds maximum") +} + +// Ported from v0/mod.rs: test_sanitize_with_table_lookup_and_ix_with_dynamic_program_id. +// Program IDs loaded from lookup tables should be rejected. +func TestMessageSanitize_V0_DynamicProgramID(t *testing.T) { + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + }, + AccountKeys: PublicKeySlice{newUniqueKey()}, + AddressTableLookups: []MessageAddressTableLookup{ + { + AccountKey: newUniqueKey(), + WritableIndexes: []uint8{1, 2, 3}, + ReadonlyIndexes: []uint8{0}, + }, + }, + Instructions: []CompiledInstruction{ + {ProgramIDIndex: 4, Accounts: []uint16{0, 1, 2, 3}, Data: []byte{}}, // index 4 is in lookup table + }, + } + msg.version = MessageVersionV0 + err := msg.Sanitize() + require.Error(t, err) + assert.Contains(t, err.Error(), "static keys") +} + +// Ported from v0/mod.rs: test_sanitize_with_table_lookup_and_ix_with_static_program_id. +func TestMessageSanitize_V0_StaticProgramID(t *testing.T) { + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + }, + AccountKeys: PublicKeySlice{newUniqueKey(), newUniqueKey()}, + AddressTableLookups: []MessageAddressTableLookup{ + { + AccountKey: newUniqueKey(), + WritableIndexes: []uint8{1, 2, 3}, + ReadonlyIndexes: []uint8{0}, + }, + }, + Instructions: []CompiledInstruction{ + {ProgramIDIndex: 1, Accounts: []uint16{2, 3, 4, 5}, Data: []byte{}}, + }, + } + msg.version = MessageVersionV0 + require.NoError(t, msg.Sanitize()) +} + +// Ported from v0/mod.rs: test_sanitize_with_invalid_ix_account. +func TestMessageSanitize_V0_InvalidIxAccount(t *testing.T) { + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + }, + AccountKeys: PublicKeySlice{newUniqueKey(), newUniqueKey()}, + AddressTableLookups: []MessageAddressTableLookup{ + { + AccountKey: newUniqueKey(), + WritableIndexes: []uint8{}, + ReadonlyIndexes: []uint8{0}, + }, + }, + Instructions: []CompiledInstruction{ + {ProgramIDIndex: 1, Accounts: []uint16{3}, Data: []byte{}}, // index 3 out of bounds (2 static + 1 lookup = 3 total, max index = 2) + }, + } + msg.version = MessageVersionV0 + err := msg.Sanitize() + require.Error(t, err) + assert.Contains(t, err.Error(), "account index") +} + +// Ported from v0/mod.rs: test_sanitize_with_too_many_table_loaded_keys. +func TestMessageSanitize_V0_TooManyDynamicKeys(t *testing.T) { + writable := make([]uint8, 128) + readonly := make([]uint8, 128) + for i := range writable { + writable[i] = uint8(i * 2) + readonly[i] = uint8(i*2 + 1) + } + msg := Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + }, + AccountKeys: PublicKeySlice{newUniqueKey()}, + AddressTableLookups: []MessageAddressTableLookup{ + { + AccountKey: newUniqueKey(), + WritableIndexes: writable, + ReadonlyIndexes: readonly, + }, + }, + } + msg.version = MessageVersionV0 + err := msg.Sanitize() + require.Error(t, err) + assert.Contains(t, err.Error(), "exceeds maximum") +} + +// --- HasDuplicates Tests --- + +func TestMessageHasDuplicates(t *testing.T) { + key := newUniqueKey() + + t.Run("no duplicates", func(t *testing.T) { + msg := Message{AccountKeys: PublicKeySlice{newUniqueKey(), newUniqueKey(), newUniqueKey()}} + assert.False(t, msg.HasDuplicates()) + }) + + t.Run("with duplicates", func(t *testing.T) { + msg := Message{AccountKeys: PublicKeySlice{key, newUniqueKey(), key}} + assert.True(t, msg.HasDuplicates()) + }) + + t.Run("adjacent duplicates", func(t *testing.T) { + msg := Message{AccountKeys: PublicKeySlice{key, key}} + assert.True(t, msg.HasDuplicates()) + }) + + t.Run("single key", func(t *testing.T) { + msg := Message{AccountKeys: PublicKeySlice{key}} + assert.False(t, msg.HasDuplicates()) + }) + + t.Run("empty", func(t *testing.T) { + msg := Message{AccountKeys: PublicKeySlice{}} + assert.False(t, msg.HasDuplicates()) + }) +} + +// --- Transaction Sanitize Tests --- +// Ported from solana-sdk/transaction/src/lib.rs and versioned/mod.rs + +// Ported from versioned/mod.rs: test_sanitize_signatures_inner — exact match. +func TestTransactionSanitize_Valid(t *testing.T) { + signer := NewWallet().PrivateKey + + tx, err := NewTransaction([]Instruction{ + &testTransactionInstructions{ + accounts: []*AccountMeta{{PublicKey: signer.PublicKey(), IsSigner: true, IsWritable: true}}, + data: []byte{0x01}, + programID: SystemProgramID, + }, + }, Hash{1, 2, 3}) + require.NoError(t, err) + + _, err = tx.Sign(func(key PublicKey) *PrivateKey { + if key.Equals(signer.PublicKey()) { + return &signer + } + return nil + }) + require.NoError(t, err) + + require.NoError(t, tx.Sanitize()) +} + +// Ported from lib.rs: test_sanitize_txs — not enough signatures. +func TestTransactionSanitize_NotEnoughSignatures(t *testing.T) { + tx := &Transaction{ + Message: Message{ + Header: MessageHeader{ + NumRequiredSignatures: 2, + NumReadonlySignedAccounts: 0, + NumReadonlyUnsignedAccounts: 1, + }, + AccountKeys: PublicKeySlice{newUniqueKey(), newUniqueKey(), newUniqueKey()}, + }, + Signatures: []Signature{{}}, // only 1, need 2 + } + err := tx.Sanitize() + require.Error(t, err) + assert.Contains(t, err.Error(), "not enough signatures") +} + +// Ported from versioned/mod.rs: test_sanitize_signatures_inner — too many signatures. +func TestTransactionSanitize_TooManySignatures(t *testing.T) { + tx := &Transaction{ + Message: Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + NumReadonlySignedAccounts: 0, + NumReadonlyUnsignedAccounts: 0, + }, + AccountKeys: PublicKeySlice{newUniqueKey(), newUniqueKey()}, + }, + Signatures: []Signature{{}, {}}, // 2, but only 1 required + } + err := tx.Sanitize() + require.Error(t, err) + assert.Contains(t, err.Error(), "too many signatures") +} + +// Ported from versioned/mod.rs: signatures exceed static account keys. +func TestTransactionSanitize_SignaturesExceedKeys(t *testing.T) { + tx := &Transaction{ + Message: Message{ + Header: MessageHeader{ + NumRequiredSignatures: 3, + NumReadonlySignedAccounts: 0, + NumReadonlyUnsignedAccounts: 0, + }, + AccountKeys: PublicKeySlice{newUniqueKey(), newUniqueKey()}, // only 2 keys + }, + Signatures: []Signature{{}, {}, {}}, // 3 signatures + } + err := tx.Sanitize() + require.Error(t, err) + assert.Contains(t, err.Error(), "static account keys") +} + +// --- VerifyWithResults Tests --- + +func TestVerifyWithResults_AllValid(t *testing.T) { + signers := []PrivateKey{ + NewWallet().PrivateKey, + NewWallet().PrivateKey, + } + + tx, err := NewTransaction([]Instruction{ + &testTransactionInstructions{ + accounts: []*AccountMeta{ + {PublicKey: signers[0].PublicKey(), IsSigner: true, IsWritable: true}, + {PublicKey: signers[1].PublicKey(), IsSigner: true, IsWritable: false}, + }, + data: []byte{0x01}, + programID: SystemProgramID, + }, + }, Hash{42}) + require.NoError(t, err) + + _, err = tx.Sign(func(key PublicKey) *PrivateKey { + for _, s := range signers { + if key.Equals(s.PublicKey()) { + return &s + } + } + return nil + }) + require.NoError(t, err) + + results, err := tx.VerifyWithResults() + require.NoError(t, err) + require.Equal(t, 2, len(results)) + assert.True(t, results[0], "first signature should be valid") + assert.True(t, results[1], "second signature should be valid") +} + +func TestVerifyWithResults_OneBad(t *testing.T) { + signers := []PrivateKey{ + NewWallet().PrivateKey, + NewWallet().PrivateKey, + } + + tx, err := NewTransaction([]Instruction{ + &testTransactionInstructions{ + accounts: []*AccountMeta{ + {PublicKey: signers[0].PublicKey(), IsSigner: true, IsWritable: true}, + {PublicKey: signers[1].PublicKey(), IsSigner: true, IsWritable: false}, + }, + data: []byte{0x01}, + programID: SystemProgramID, + }, + }, Hash{42}) + require.NoError(t, err) + + _, err = tx.Sign(func(key PublicKey) *PrivateKey { + for _, s := range signers { + if key.Equals(s.PublicKey()) { + return &s + } + } + return nil + }) + require.NoError(t, err) + + // Corrupt the second signature. + tx.Signatures[1][0] ^= 0xFF + + results, err := tx.VerifyWithResults() + require.NoError(t, err) + assert.True(t, results[0], "first signature should still be valid") + assert.False(t, results[1], "second signature should be invalid") +} + +// --- UsesDurableNonce Tests --- +// Ported from solana-sdk/transaction: tx_uses_nonce_* tests. + +func makeNonceAdvanceIxData() []byte { + // System instruction discriminant for AdvanceNonceAccount = 4 (LE u32). + return []byte{4, 0, 0, 0} +} + +// Ported from lib.rs: tx_uses_nonce_ok. +func TestUsesDurableNonce_Valid(t *testing.T) { + nonceAccount := newUniqueKey() + nonceAuthority := newUniqueKey() + recentBlockhashes := newUniqueKey() + + tx := &Transaction{ + Message: Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + NumReadonlySignedAccounts: 0, + NumReadonlyUnsignedAccounts: 2, + }, + AccountKeys: PublicKeySlice{ + nonceAuthority, // 0: signer + nonceAccount, // 1: writable + recentBlockhashes, // 2: readonly + SystemProgramID, // 3: system program + }, + Instructions: []CompiledInstruction{ + { + ProgramIDIndex: 3, // system program + Accounts: []uint16{1, 2, 0}, // nonce account, recent blockhashes, authority + Data: makeNonceAdvanceIxData(), + }, + }, + }, + Signatures: []Signature{{}}, + } + + assert.True(t, tx.UsesDurableNonce()) + + account, ok := tx.GetNonceAccount() + require.True(t, ok) + assert.Equal(t, nonceAccount, account) +} + +// Ported from lib.rs: tx_uses_nonce_empty_ix_fail. +func TestUsesDurableNonce_EmptyInstructions(t *testing.T) { + tx := &Transaction{ + Message: Message{ + Header: MessageHeader{NumRequiredSignatures: 1}, + AccountKeys: PublicKeySlice{ + newUniqueKey(), + SystemProgramID, + }, + Instructions: []CompiledInstruction{}, + }, + Signatures: []Signature{{}}, + } + assert.False(t, tx.UsesDurableNonce()) +} + +// Ported from lib.rs: tx_uses_nonce_bad_prog_id_idx_fail. +func TestUsesDurableNonce_BadProgramIDIndex(t *testing.T) { + tx := &Transaction{ + Message: Message{ + Header: MessageHeader{NumRequiredSignatures: 1}, + AccountKeys: PublicKeySlice{newUniqueKey()}, + Instructions: []CompiledInstruction{ + { + ProgramIDIndex: 255, // out of bounds + Accounts: []uint16{0}, + Data: makeNonceAdvanceIxData(), + }, + }, + }, + Signatures: []Signature{{}}, + } + assert.False(t, tx.UsesDurableNonce()) +} + +// Ported from lib.rs: tx_uses_nonce_first_prog_id_not_nonce_fail. +func TestUsesDurableNonce_WrongProgram(t *testing.T) { + tx := &Transaction{ + Message: Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + NumReadonlyUnsignedAccounts: 1, + }, + AccountKeys: PublicKeySlice{ + newUniqueKey(), + newUniqueKey(), // not system program + }, + Instructions: []CompiledInstruction{ + { + ProgramIDIndex: 1, + Accounts: []uint16{0}, + Data: makeNonceAdvanceIxData(), + }, + }, + }, + Signatures: []Signature{{}}, + } + assert.False(t, tx.UsesDurableNonce()) +} + +// Ported from lib.rs: tx_uses_nonce_wrong_first_nonce_ix_fail. +func TestUsesDurableNonce_WrongInstruction(t *testing.T) { + tx := &Transaction{ + Message: Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + NumReadonlyUnsignedAccounts: 1, + }, + AccountKeys: PublicKeySlice{ + newUniqueKey(), + SystemProgramID, + }, + Instructions: []CompiledInstruction{ + { + ProgramIDIndex: 1, + Accounts: []uint16{0}, + Data: []byte{2, 0, 0, 0}, // Transfer, not AdvanceNonce + }, + }, + }, + Signatures: []Signature{{}}, + } + assert.False(t, tx.UsesDurableNonce()) +} + +// Tests GetNonceAccount returns false for non-nonce transactions. +func TestGetNonceAccount_NotNonce(t *testing.T) { + tx := &Transaction{ + Message: Message{ + Header: MessageHeader{NumRequiredSignatures: 1}, + AccountKeys: PublicKeySlice{newUniqueKey()}, + Instructions: []CompiledInstruction{ + {ProgramIDIndex: 0, Accounts: []uint16{0}, Data: []byte{0x01}}, + }, + }, + Signatures: []Signature{{}}, + } + _, ok := tx.GetNonceAccount() + assert.False(t, ok) +} + +// Tests that nonce instruction with short data is not detected. +func TestUsesDurableNonce_ShortData(t *testing.T) { + tx := &Transaction{ + Message: Message{ + Header: MessageHeader{ + NumRequiredSignatures: 1, + NumReadonlyUnsignedAccounts: 1, + }, + AccountKeys: PublicKeySlice{ + newUniqueKey(), + SystemProgramID, + }, + Instructions: []CompiledInstruction{ + { + ProgramIDIndex: 1, + Accounts: []uint16{0}, + Data: []byte{4, 0}, // too short — need 4 bytes for u32 + }, + }, + }, + Signatures: []Signature{{}}, + } + assert.False(t, tx.UsesDurableNonce()) +} diff --git a/transaction.go b/transaction.go index 34d180957..8af306acd 100644 --- a/transaction.go +++ b/transaction.go @@ -253,7 +253,7 @@ func NewTransaction(instructions []Instruction, recentBlockHash Hash, opts ...Tr } } if !found { - return nil, fmt.Errorf("cannot determine fee payer. You can ether pass the fee payer via the 'TransactionWithInstructions' option parameter or it falls back to the first instruction's first signer") + return nil, fmt.Errorf("cannot determine fee payer. You can either pass the fee payer via the 'TransactionWithInstructions' option parameter or it falls back to the first instruction's first signer") } } @@ -539,15 +539,12 @@ func (tx *Transaction) UnmarshalWithDecoder(decoder *bin.Decoder) (err error) { if err != nil { return fmt.Errorf("unable to read numSignatures: %w", err) } - if numSignatures < 0 { - return fmt.Errorf("numSignatures is negative") - } if numSignatures > decoder.Remaining()/64 { return fmt.Errorf("numSignatures %d is too large for remaining bytes %d", numSignatures, decoder.Remaining()) } tx.Signatures = make([]Signature, numSignatures) - for i := 0; i < numSignatures; i++ { + for i := range numSignatures { _, err := decoder.Read(tx.Signatures[i][:]) if err != nil { return fmt.Errorf("unable to read tx.Signatures[%d]: %w", i, err) @@ -666,7 +663,6 @@ func (tx *Transaction) EncodeToTree(parent treeout.Branches) { message.Child(spew.Sdump(decodedInstruction)) } } else { - // TODO: log error? message.Child(fmt.Sprintf(text.RedBG("cannot decode instruction for %s program: %s"), progKey, err)). Child(text.IndigoBG("Program") + ": " + text.Bold("") + " " + text.ColorizeBG(progKey.String())). // @@ -796,10 +792,10 @@ func countWriteableAccounts(tx *Transaction) (count int) { } return count } - numStatisKeys := len(tx.Message.AccountKeys) - statisKeys := tx.Message.AccountKeys + numStaticKeys := len(tx.Message.AccountKeys) + staticKeys := tx.Message.AccountKeys h := tx.Message.Header - for _, key := range statisKeys { + for _, key := range staticKeys { accIndex, ok := getStaticAccountIndex(tx, key) if !ok { continue @@ -807,10 +803,11 @@ func countWriteableAccounts(tx *Transaction) (count int) { index := int(accIndex) is := false if index >= int(h.NumRequiredSignatures) { - // unsignedAccountIndex < numWritableUnsignedAccounts - is = index-int(h.NumRequiredSignatures) < (numStatisKeys-int(h.NumRequiredSignatures))-int(h.NumReadonlyUnsignedAccounts) + // Use int arithmetic to avoid underflow (Rust uses saturating_sub here). + numWritableUnsigned := max(numStaticKeys-int(h.NumRequiredSignatures)-int(h.NumReadonlyUnsignedAccounts), 0) + is = index-int(h.NumRequiredSignatures) < numWritableUnsigned } else { - is = index < int(h.NumRequiredSignatures-h.NumReadonlySignedAccounts) + is = index < max(int(h.NumRequiredSignatures)-int(h.NumReadonlySignedAccounts), 0) } if is { count++ diff --git a/transaction_test.go b/transaction_test.go index 4fb05e937..e5801036e 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -28,6 +28,15 @@ import ( "go.uber.org/zap" ) +// newTestInstruction is a shorthand for creating a testTransactionInstructions. +func newTestInstruction(programID PublicKey, accounts []*AccountMeta, data []byte) *testTransactionInstructions { + return &testTransactionInstructions{ + accounts: accounts, + data: data, + programID: programID, + } +} + type testTransactionInstructions struct { accounts []*AccountMeta data []byte @@ -56,7 +65,7 @@ func TestNewTransaction(t *testing.T) { {PublicKey: MustPublicKeyFromBase58("9hFtYBYmBJCVguRYs9pBTWKYAFoKfjYR7zBPpEkVsmD"), IsSigner: true, IsWritable: true}, }, data: []byte{0xaa, 0xbb}, - programID: MustPublicKeyFromBase58("11111111111111111111111111111111"), + programID: SystemProgramID, }, &testTransactionInstructions{ accounts: []*AccountMeta{ @@ -89,7 +98,7 @@ func TestNewTransaction(t *testing.T) { MustPublicKeyFromBase58("9hFtYBYmBJCVguRYs9pBTWKYAFoKfjYR7zBPpEkVsmD"), MustPublicKeyFromBase58("6FzXPEhCJoBx7Zw3SN9qhekHemd6E2b8kVguitmVAngW"), MustPublicKeyFromBase58("SysvarS1otHashes111111111111111111111111111"), - MustPublicKeyFromBase58("11111111111111111111111111111111"), + SystemProgramID, MustPublicKeyFromBase58("SysvarC1ock11111111111111111111111111111111"), MustPublicKeyFromBase58("Vote111111111111111111111111111111111111111"), }) @@ -122,7 +131,7 @@ func TestPartialSignTransaction(t *testing.T) { {PublicKey: signers[2].PublicKey(), IsSigner: true, IsWritable: false}, }, data: []byte{0xaa, 0xbb}, - programID: MustPublicKeyFromBase58("11111111111111111111111111111111"), + programID: SystemProgramID, }, } @@ -177,7 +186,7 @@ func TestSignTransaction(t *testing.T) { {PublicKey: signers[1].PublicKey(), IsSigner: true, IsWritable: true}, }, data: []byte{0xaa, 0xbb}, - programID: MustPublicKeyFromBase58("11111111111111111111111111111111"), + programID: SystemProgramID, }, } @@ -245,7 +254,7 @@ func TestTransactionDecode(t *testing.T) { PublicKeySlice{ MustPublicKeyFromBase58("52NGrUqh6tSGhr59ajGxsH3VnAaoRdSdTbAaV9G3UW35"), MustPublicKeyFromBase58("SRMuApVNdxXokk5GT7XD5cUUgXMBCoAz2LHeuAoKWRt"), - MustPublicKeyFromBase58("11111111111111111111111111111111"), + SystemProgramID, }, tx.Message.AccountKeys, ) @@ -340,7 +349,7 @@ func TestTransactionSerializePumpFunSwap(t *testing.T) { {PublicKey: MPK("9zpyjwrYdRWNMyqicoiuL3gUcrbvrkd5Kq9nxui1znw1"), IsSigner: false, IsWritable: true}, {PublicKey: MPK("BdQqJnuqqFhNZUNYGEEsuhBidpf8qHqfjDQvcjDN3nti"), IsSigner: false, IsWritable: true}, {PublicKey: MPK("o7RY6P2vQMuGSu1TrLM81weuzgDjaCRTXYRaXJwWcvc"), IsSigner: true, IsWritable: true}, - {PublicKey: MPK("11111111111111111111111111111111"), IsSigner: false, IsWritable: false}, + {PublicKey: SystemProgramID, IsSigner: false, IsWritable: false}, {PublicKey: MPK("TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA"), IsSigner: false, IsWritable: false}, {PublicKey: MPK("SysvarRent111111111111111111111111111111111"), IsSigner: false, IsWritable: false}, {PublicKey: MPK("Ce6TQqeHC9p8KetsN6JsjHK7UTZk7nasjjnr7XxXp9F1"), IsSigner: false, IsWritable: false}, @@ -399,3 +408,539 @@ func BenchmarkTransactionVerifySignatures(b *testing.B) { tx.VerifySignatures() } } + +// Ported from solana-sdk/transaction/src/lib.rs: test_transaction_serialize. +// Tests that a transaction survives binary serialization roundtrip. +func TestTransactionSerializationRoundtrip(t *testing.T) { + signers := []PrivateKey{ + NewWallet().PrivateKey, + NewWallet().PrivateKey, + } + instructions := []Instruction{ + newTestInstruction( + SystemProgramID, + []*AccountMeta{ + {PublicKey: signers[0].PublicKey(), IsSigner: true, IsWritable: true}, + {PublicKey: signers[1].PublicKey(), IsSigner: true, IsWritable: false}, + }, + []byte{0x01, 0x02, 0x03}, + ), + } + + blockhash, err := HashFromBase58("A9QnpgfhCkmiBSjgBuWk76Wo3HxzxvDopUq9x6UUMmjn") + require.NoError(t, err) + + tx, err := NewTransaction(instructions, blockhash) + require.NoError(t, err) + + _, err = tx.Sign(func(key PublicKey) *PrivateKey { + for _, signer := range signers { + if key.Equals(signer.PublicKey()) { + return &signer + } + } + return nil + }) + require.NoError(t, err) + + // Marshal to binary. + data, err := tx.MarshalBinary() + require.NoError(t, err) + + // Unmarshal back. + var decoded Transaction + err = decoded.UnmarshalWithDecoder(bin.NewBinDecoder(data)) + require.NoError(t, err) + + // Compare. + assert.Equal(t, tx.Signatures, decoded.Signatures) + assert.Equal(t, tx.Message.Header, decoded.Message.Header) + assert.Equal(t, tx.Message.AccountKeys, decoded.Message.AccountKeys) + assert.Equal(t, tx.Message.RecentBlockhash, decoded.Message.RecentBlockhash) + require.Equal(t, len(tx.Message.Instructions), len(decoded.Message.Instructions)) + for i := range tx.Message.Instructions { + assert.Equal(t, tx.Message.Instructions[i].ProgramIDIndex, decoded.Message.Instructions[i].ProgramIDIndex) + assert.Equal(t, tx.Message.Instructions[i].Accounts, decoded.Message.Instructions[i].Accounts) + } + + // Verify signatures still valid after roundtrip. + require.NoError(t, decoded.VerifySignatures()) +} + +// Ported from solana-sdk/transaction/src/lib.rs: test_sanitize_txs. +// Tests that signature count must match num_required_signatures. +func TestVerifySignatures_SignatureCountMismatch(t *testing.T) { + signer := NewWallet().PrivateKey + + tx, err := NewTransaction([]Instruction{ + newTestInstruction( + SystemProgramID, + []*AccountMeta{{PublicKey: signer.PublicKey(), IsSigner: true, IsWritable: true}}, + []byte{0x01}, + ), + }, Hash{1, 2, 3}) + require.NoError(t, err) + + _, err = tx.Sign(func(key PublicKey) *PrivateKey { + if key.Equals(signer.PublicKey()) { + return &signer + } + return nil + }) + require.NoError(t, err) + require.NoError(t, tx.VerifySignatures()) + + // Too many signatures. + tx.Signatures = append(tx.Signatures, tx.Signatures[0]) + err = tx.VerifySignatures() + require.Error(t, err) + assert.Contains(t, err.Error(), "signers") + + // Too few signatures. + tx.Signatures = nil + err = tx.VerifySignatures() + require.Error(t, err) +} + +// Ported from solana-sdk/transaction/src/lib.rs: test_transaction_instruction_with_duplicate_keys. +// Tests that duplicate account keys in instructions are deduplicated. +func TestNewTransaction_DuplicateAccountKeys(t *testing.T) { + key := newUniqueKey() + + tx, err := NewTransaction([]Instruction{ + newTestInstruction( + SystemProgramID, + []*AccountMeta{ + {PublicKey: key, IsSigner: true, IsWritable: true}, + {PublicKey: key, IsSigner: true, IsWritable: true}, // duplicate + }, + []byte{0x01}, + ), + }, Hash{}) + require.NoError(t, err) + + // The duplicate should be deduplicated in AccountKeys. + // Should have: key (payer/signer) + system program = 2 keys. + assert.Equal(t, 2, len(tx.Message.AccountKeys)) + assert.Equal(t, uint8(1), tx.Message.Header.NumRequiredSignatures) +} + +// Ported from solana-sdk/transaction/src/lib.rs: test_transaction_correct_key. +// Tests that signing with the correct key produces verifiable signatures. +func TestTransaction_SignAndVerify(t *testing.T) { + signer := NewWallet().PrivateKey + + tx, err := NewTransaction([]Instruction{ + newTestInstruction( + SystemProgramID, + []*AccountMeta{{PublicKey: signer.PublicKey(), IsSigner: true, IsWritable: true}}, + []byte{0xDE, 0xAD}, + ), + }, Hash{42}) + require.NoError(t, err) + + _, err = tx.Sign(func(key PublicKey) *PrivateKey { + if key.Equals(signer.PublicKey()) { + return &signer + } + return nil + }) + require.NoError(t, err) + + // Valid signature. + require.NoError(t, tx.VerifySignatures()) + + // Tamper with the signature → should fail. + tx.Signatures[0][0] ^= 0xFF + err = tx.VerifySignatures() + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid signature") +} + +// Tests that NewTransaction requires at least one instruction. +func TestNewTransaction_EmptyInstructions(t *testing.T) { + _, err := NewTransaction(nil, Hash{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "requires at-least one instruction") + + _, err = NewTransaction([]Instruction{}, Hash{}) + require.Error(t, err) +} + +// Tests TransactionBuilder. +func TestTransactionBuilder(t *testing.T) { + signer := NewWallet().PrivateKey + programID := SystemProgramID + blockhash := Hash{1, 2, 3} + + tx, err := NewTransactionBuilder(). + AddInstruction(newTestInstruction( + programID, + []*AccountMeta{{PublicKey: signer.PublicKey(), IsSigner: true, IsWritable: true}}, + []byte{0x01}, + )). + SetRecentBlockHash(blockhash). + SetFeePayer(signer.PublicKey()). + Build() + require.NoError(t, err) + + assert.Equal(t, blockhash, tx.Message.RecentBlockhash) + assert.Equal(t, uint8(1), tx.Message.Header.NumRequiredSignatures) + assert.Equal(t, signer.PublicKey(), tx.Message.AccountKeys[0]) +} + +// Tests NumWriteableAccounts, NumReadonlyAccounts, NumSigners for legacy transactions. +// Ported from solana-sdk/transaction/src/lib.rs header validation tests. +func TestTransaction_AccountCounts_Legacy(t *testing.T) { + keys := [5]PublicKey{} + for i := range keys { + keys[i] = newUniqueKey() + } + + // 2 signers (1 writable, 1 readonly), 3 unsigned (1 writable, 2 readonly) + // → header: 2 required, 1 readonly_signed, 2 readonly_unsigned + tx, err := NewTransaction([]Instruction{ + newTestInstruction( + keys[4], + []*AccountMeta{ + {PublicKey: keys[0], IsSigner: true, IsWritable: true}, + {PublicKey: keys[1], IsSigner: true, IsWritable: false}, + {PublicKey: keys[2], IsSigner: false, IsWritable: true}, + {PublicKey: keys[3], IsSigner: false, IsWritable: false}, + }, + []byte{0x01}, + ), + }, Hash{}, TransactionPayer(keys[0])) + require.NoError(t, err) + + assert.Equal(t, uint8(2), tx.Message.Header.NumRequiredSignatures) + assert.Equal(t, uint8(1), tx.Message.Header.NumReadonlySignedAccounts) + // keys[3] (readonly unsigned) + keys[4] (program, readonly unsigned) = 2 + assert.Equal(t, uint8(2), tx.Message.Header.NumReadonlyUnsignedAccounts) + + assert.Equal(t, 2, tx.NumSigners()) + assert.Equal(t, 3, tx.NumReadonlyAccounts()) // 1 readonly signed + 2 readonly unsigned + assert.Equal(t, 2, tx.NumWriteableAccounts()) // keys[0] writable signer + keys[2] writable unsigned +} + +// Tests GetProgramIDs. +func TestTransaction_GetProgramIDs(t *testing.T) { + prog1 := newUniqueKey() + prog2 := newUniqueKey() + signer := newUniqueKey() + + tx, err := NewTransaction([]Instruction{ + newTestInstruction(prog1, []*AccountMeta{{PublicKey: signer, IsSigner: true, IsWritable: true}}, []byte{0x01}), + newTestInstruction(prog2, []*AccountMeta{{PublicKey: signer, IsSigner: false, IsWritable: false}}, []byte{0x02}), + }, Hash{}, TransactionPayer(signer)) + require.NoError(t, err) + + programIDs, err := tx.GetProgramIDs() + require.NoError(t, err) + require.Equal(t, 2, len(programIDs)) + + // Verify both programs are present. + found := map[PublicKey]bool{} + for _, pid := range programIDs { + found[pid] = true + } + assert.True(t, found[prog1], "prog1 should be in program IDs") + assert.True(t, found[prog2], "prog2 should be in program IDs") +} + +// Tests IsVote. +func TestTransaction_IsVote(t *testing.T) { + signer := newUniqueKey() + t.Run("non-vote transaction", func(t *testing.T) { + tx, err := NewTransaction([]Instruction{ + newTestInstruction(SystemProgramID, []*AccountMeta{{PublicKey: signer, IsSigner: true, IsWritable: true}}, []byte{0x01}), + }, Hash{}, TransactionPayer(signer)) + require.NoError(t, err) + assert.False(t, tx.IsVote()) + }) + + t.Run("vote transaction", func(t *testing.T) { + tx, err := NewTransaction([]Instruction{ + newTestInstruction(VoteProgramID, []*AccountMeta{{PublicKey: signer, IsSigner: true, IsWritable: true}}, []byte{0x01}), + }, Hash{}, TransactionPayer(signer)) + require.NoError(t, err) + assert.True(t, tx.IsVote()) + }) +} + +// Tests HasAccount, IsSigner, IsWritable on Transaction. +func TestTransaction_AccountQueries(t *testing.T) { + signer := newUniqueKey() + writable := newUniqueKey() + readonly := newUniqueKey() + programID := SystemProgramID + + tx, err := NewTransaction([]Instruction{ + newTestInstruction(programID, []*AccountMeta{ + {PublicKey: signer, IsSigner: true, IsWritable: true}, + {PublicKey: writable, IsSigner: false, IsWritable: true}, + {PublicKey: readonly, IsSigner: false, IsWritable: false}, + }, []byte{0x01}), + }, Hash{}, TransactionPayer(signer)) + require.NoError(t, err) + + // HasAccount. + has, err := tx.HasAccount(signer) + require.NoError(t, err) + assert.True(t, has) + + has, err = tx.HasAccount(newUniqueKey()) + require.NoError(t, err) + assert.False(t, has) + + // IsSigner. + assert.True(t, tx.IsSigner(signer)) + assert.False(t, tx.IsSigner(writable)) + assert.False(t, tx.IsSigner(readonly)) + + // IsWritable. + w, err := tx.IsWritable(signer) + require.NoError(t, err) + assert.True(t, w, "signer should be writable") + + w, err = tx.IsWritable(writable) + require.NoError(t, err) + assert.True(t, w, "writable account should be writable") + + w, err = tx.IsWritable(readonly) + require.NoError(t, err) + assert.False(t, w, "readonly account should not be writable") +} + +// Tests that MarshalBinary pads missing signatures with zeroes. +// Ported from solana-web3.js reference in the Go code comment. +func TestMarshalBinary_PadsMissingSignatures(t *testing.T) { + signer := newUniqueKey() + + tx := &Transaction{ + Message: Message{ + Header: MessageHeader{ + NumRequiredSignatures: 2, + NumReadonlySignedAccounts: 0, + NumReadonlyUnsignedAccounts: 1, + }, + AccountKeys: PublicKeySlice{ + signer, + newUniqueKey(), + SystemProgramID, + }, + RecentBlockhash: Hash{1}, + Instructions: []CompiledInstruction{ + {ProgramIDIndex: 2, Accounts: []uint16{0, 1}, Data: []byte{0x01}}, + }, + }, + // Only 0 signatures provided, but 2 required. + Signatures: nil, + } + + data, err := tx.MarshalBinary() + require.NoError(t, err) + + // Decode and check that 2 dummy signatures were added. + var decoded Transaction + err = decoded.UnmarshalWithDecoder(bin.NewBinDecoder(data)) + require.NoError(t, err) + assert.Equal(t, 2, len(decoded.Signatures)) + // Both should be zero-filled. + assert.Equal(t, Signature{}, decoded.Signatures[0]) + assert.Equal(t, Signature{}, decoded.Signatures[1]) +} + +// Tests base64 roundtrip on Transaction. +func TestTransaction_Base64Roundtrip(t *testing.T) { + signer := NewWallet().PrivateKey + + tx, err := NewTransaction([]Instruction{ + newTestInstruction( + SystemProgramID, + []*AccountMeta{{PublicKey: signer.PublicKey(), IsSigner: true, IsWritable: true}}, + []byte{0xCA, 0xFE}, + ), + }, Hash{99}, TransactionPayer(signer.PublicKey())) + require.NoError(t, err) + + _, err = tx.Sign(func(key PublicKey) *PrivateKey { + if key.Equals(signer.PublicKey()) { + return &signer + } + return nil + }) + require.NoError(t, err) + + b64, err := tx.ToBase64() + require.NoError(t, err) + require.NotEmpty(t, b64) + + decoded, err := TransactionFromBase64(b64) + require.NoError(t, err) + require.NoError(t, decoded.VerifySignatures()) + + assert.Equal(t, tx.Signatures, decoded.Signatures) + assert.Equal(t, tx.Message.Header, decoded.Message.Header) + assert.Equal(t, tx.Message.AccountKeys, decoded.Message.AccountKeys) +} + +// Tests TransactionFromBytes and TransactionFromBase58. +func TestTransaction_FromVariousFormats(t *testing.T) { + // Use a known valid transaction. + b64 := "AfjEs3XhTc3hrxEvlnMPkm/cocvAUbFNbCl00qKnrFue6J53AhEqIFmcJJlJW3EDP5RmcMz+cNTTcZHW/WJYwAcBAAEDO8hh4VddzfcO5jbCt95jryl6y8ff65UcgukHNLWH+UQGgxCGGpgyfQVQV02EQYqm4QwzUt2qf9f1gVLM7rI4hwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA6ANIF55zOZWROWRkeh+lExxZBnKFqbvIxZDLE7EijjoBAgIAAQwCAAAAOTAAAAAAAAA=" + data, err := base64.StdEncoding.DecodeString(b64) + require.NoError(t, err) + + // FromBytes. + tx1, err := TransactionFromBytes(data) + require.NoError(t, err) + require.NotNil(t, tx1) + require.NoError(t, tx1.VerifySignatures()) + + // FromBase64. + tx2, err := TransactionFromBase64(b64) + require.NoError(t, err) + require.NotNil(t, tx2) + assert.Equal(t, tx1.Signatures, tx2.Signatures) + assert.Equal(t, tx1.Message.Header, tx2.Message.Header) +} + +// Tests Transaction.String() doesn't panic (EncodeToTree coverage). +func TestTransaction_String_NoPanic(t *testing.T) { + signer := NewWallet().PrivateKey + + tx, err := NewTransaction([]Instruction{ + newTestInstruction( + SystemProgramID, + []*AccountMeta{{PublicKey: signer.PublicKey(), IsSigner: true, IsWritable: true}}, + []byte{0x01}, + ), + }, Hash{}, TransactionPayer(signer.PublicKey())) + require.NoError(t, err) + + _, err = tx.Sign(func(key PublicKey) *PrivateKey { + if key.Equals(signer.PublicKey()) { + return &signer + } + return nil + }) + require.NoError(t, err) + + require.NotPanics(t, func() { + s := tx.String() + assert.NotEmpty(t, s) + }) +} + +// Tests that fee payer is always first in account keys. +// Ported from solana-sdk/transaction/src/lib.rs: test_message_payer_first. +func TestNewTransaction_FeePayerFirst(t *testing.T) { + payer := newUniqueKey() + other := newUniqueKey() + programID := SystemProgramID + + tx, err := NewTransaction([]Instruction{ + newTestInstruction(programID, []*AccountMeta{{PublicKey: other, IsSigner: true, IsWritable: true}}, []byte{0x01}), + }, Hash{}, TransactionPayer(payer)) + require.NoError(t, err) + + // Fee payer must be first. + assert.Equal(t, payer, tx.Message.AccountKeys[0]) + // Fee payer is always a writable signer. + assert.True(t, tx.IsSigner(payer)) + w, err := tx.IsWritable(payer) + require.NoError(t, err) + assert.True(t, w) +} + +// Tests NumWriteableAccounts for V0 transactions with address table lookups. +func TestTransaction_NumWriteableAccounts_V0(t *testing.T) { + payer := newUniqueKey() + programID := newUniqueKey() + acctA := newUniqueKey() + acctB := newUniqueKey() + tableKey := newUniqueKey() + + tables := map[PublicKey]PublicKeySlice{ + tableKey: {acctA, acctB}, + } + + tx, err := NewTransaction( + []Instruction{ + newTestInstruction(programID, []*AccountMeta{ + {PublicKey: acctA, IsSigner: false, IsWritable: true}, + {PublicKey: acctB, IsSigner: false, IsWritable: false}, + }, []byte{0x01}), + }, + Hash{}, + TransactionPayer(payer), + TransactionAddressTables(tables), + ) + require.NoError(t, err) + require.True(t, tx.Message.IsVersioned()) + + // payer (writable signer) + acctA (writable lookup) = 2 writable + assert.Equal(t, 2, tx.NumWriteableAccounts()) +} + +// Tests PartialSign with no signer provided doesn't corrupt state. +func TestPartialSign_NoSignerProvided(t *testing.T) { + signer := NewWallet().PrivateKey + + tx, err := NewTransaction([]Instruction{ + newTestInstruction( + SystemProgramID, + []*AccountMeta{{PublicKey: signer.PublicKey(), IsSigner: true, IsWritable: true}}, + []byte{0x01}, + ), + }, Hash{}, TransactionPayer(signer.PublicKey())) + require.NoError(t, err) + + // PartialSign with no matching key — should succeed but leave signature as zero. + sigs, err := tx.PartialSign(func(key PublicKey) *PrivateKey { + return nil + }) + require.NoError(t, err) + assert.Equal(t, 1, len(sigs)) + assert.Equal(t, Signature{}, sigs[0]) +} + +// Tests that multiple instructions with the same program ID are handled correctly. +func TestNewTransaction_MultipleInstructionsSameProgram(t *testing.T) { + signer := newUniqueKey() + target := newUniqueKey() + programID := SystemProgramID + + tx, err := NewTransaction([]Instruction{ + newTestInstruction(programID, []*AccountMeta{ + {PublicKey: signer, IsSigner: true, IsWritable: true}, + {PublicKey: target, IsSigner: false, IsWritable: true}, + }, []byte{0x01}), + newTestInstruction(programID, []*AccountMeta{ + {PublicKey: signer, IsSigner: true, IsWritable: true}, + {PublicKey: target, IsSigner: false, IsWritable: true}, + }, []byte{0x02}), + }, Hash{}, TransactionPayer(signer)) + require.NoError(t, err) + + // Program should appear only once in AccountKeys despite being used in 2 instructions. + programCount := 0 + for _, key := range tx.Message.AccountKeys { + if key.Equals(programID) { + programCount++ + } + } + assert.Equal(t, 1, programCount, "program should appear only once") + + // Both instructions should reference the same program index. + assert.Equal(t, tx.Message.Instructions[0].ProgramIDIndex, tx.Message.Instructions[1].ProgramIDIndex) + assert.Equal(t, 2, len(tx.Message.Instructions)) +} + +// Tests nil transaction edge cases for count functions. +func TestTransaction_NilCounts(t *testing.T) { + assert.Equal(t, -1, countSigners(nil)) + assert.Equal(t, -1, countReadonlyAccounts(nil)) + assert.Equal(t, -1, countWriteableAccounts(nil)) +} diff --git a/transaction_v0_test.go b/transaction_v0_test.go index 8e6de5c8d..5e8fe4509 100644 --- a/transaction_v0_test.go +++ b/transaction_v0_test.go @@ -54,7 +54,7 @@ func TestTransactionV0(t *testing.T) { MPK("2m4eNwBVqu6SgFk23HgE3W5MW89yT5z1vspz2WsiFBHF"), MPK("G6NDx85GM481GPjT5kUBAvjLxzDMsgRMQ1EAxzGswEJn"), MPK("81o7hHYN5a8fc5wdjjfznK9ziJ9wcuKXwbZnuYpanxMQ"), - MPK("11111111111111111111111111111111"), + SystemProgramID, MPK("MemoSq4gqABAXKb96qnH8TysNcWxMyWCqXgDLGmfcHr"), MPK("FKN5imdi7yadX4axe4hxaqBET4n6DBDRF5LKo5aBF53j"), MPK("3or4uF7ZyuQW5GGmcmdXDJasNiSZUURF2az1UrRPYQTg"), @@ -87,7 +87,7 @@ func TestTransactionV0(t *testing.T) { MPK("2m4eNwBVqu6SgFk23HgE3W5MW89yT5z1vspz2WsiFBHF"), MPK("G6NDx85GM481GPjT5kUBAvjLxzDMsgRMQ1EAxzGswEJn"), MPK("81o7hHYN5a8fc5wdjjfznK9ziJ9wcuKXwbZnuYpanxMQ"), - MPK("11111111111111111111111111111111"), + SystemProgramID, MPK("MemoSq4gqABAXKb96qnH8TysNcWxMyWCqXgDLGmfcHr"), MPK("FKN5imdi7yadX4axe4hxaqBET4n6DBDRF5LKo5aBF53j"), MPK("3or4uF7ZyuQW5GGmcmdXDJasNiSZUURF2az1UrRPYQTg"),