Skip to content

Commit ac7125b

Browse files
authored
feat: Add SetAccounts implementation to be able to access accounts for Create instruction after decoding (#337)
* Add SetAccounts implementation to be able to access accounts for Create instruction after decoding * associated-token-account: Create expects 6 accounts (remove rent sysvar); fix tree count; lowercase error strings * (fix): marshalling unsigned transaction * undo a change to message.go * undo more changes to message
1 parent 48bd64f commit ac7125b

5 files changed

Lines changed: 240 additions & 38 deletions

File tree

message.go

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -349,22 +349,19 @@ func (mx *Message) MarshalV0() ([]byte, error) {
349349
}
350350
buf = append([]byte{byte(versionNum + 127)}, buf...)
351351

352-
if mx.AddressTableLookups != nil && len(mx.AddressTableLookups) > 0 {
353-
// wite length of address table lookups as u8
354-
buf = append(buf, byte(len(mx.AddressTableLookups)))
355-
for _, lookup := range mx.AddressTableLookups {
356-
// write account pubkey
357-
buf = append(buf, lookup.AccountKey[:]...)
358-
// write writable indexes
359-
bin.EncodeCompactU16Length(&buf, len(lookup.WritableIndexes))
360-
buf = append(buf, lookup.WritableIndexes...)
361-
// write readonly indexes
362-
bin.EncodeCompactU16Length(&buf, len(lookup.ReadonlyIndexes))
363-
buf = append(buf, lookup.ReadonlyIndexes...)
364-
}
365-
} else {
366-
buf = append(buf, 0)
352+
// wite length of address table lookups as u8
353+
buf = append(buf, byte(len(mx.AddressTableLookups)))
354+
for _, lookup := range mx.AddressTableLookups {
355+
// write account pubkey
356+
buf = append(buf, lookup.AccountKey[:]...)
357+
// write writable indexes
358+
bin.EncodeCompactU16Length(&buf, len(lookup.WritableIndexes))
359+
buf = append(buf, lookup.WritableIndexes...)
360+
// write readonly indexes
361+
bin.EncodeCompactU16Length(&buf, len(lookup.ReadonlyIndexes))
362+
buf = append(buf, lookup.ReadonlyIndexes...)
367363
}
364+
368365
return buf, nil
369366
}
370367

programs/associated-token-account/Create.go

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,6 @@ type Create struct {
4646
//
4747
// [5] = [] TokenProgram
4848
// ··········· SPL token program ID
49-
//
50-
// [6] = [] SysVarRent
51-
// ··········· SysVarRentPubkey
5249
solana.AccountMetaSlice `bin:"-" borsh_skip:"true"`
5350
}
5451

@@ -73,8 +70,18 @@ func (inst *Create) SetMint(mint solana.PublicKey) *Create {
7370
return inst
7471
}
7572

76-
func (inst Create) Build() *Instruction {
73+
func (inst *Create) SetAccounts(accounts []*solana.AccountMeta) error {
74+
inst.AccountMetaSlice = accounts
75+
if len(accounts) < 6 {
76+
return fmt.Errorf("insufficient accounts, Create requires at-least 6 accounts not %d", len(accounts))
77+
}
78+
inst.Payer = accounts[0].PublicKey
79+
inst.Wallet = accounts[2].PublicKey
80+
inst.Mint = accounts[3].PublicKey
81+
return nil
82+
}
7783

84+
func (inst Create) Build() *Instruction {
7885
// Find the associatedTokenAddress;
7986
associatedTokenAddress, _, _ := solana.FindAssociatedTokenAddress(
8087
inst.Wallet,
@@ -112,11 +119,6 @@ func (inst Create) Build() *Instruction {
112119
IsSigner: false,
113120
IsWritable: false,
114121
},
115-
{
116-
PublicKey: solana.SysVarRentPubkey,
117-
IsSigner: false,
118-
IsWritable: false,
119-
},
120122
}
121123

122124
inst.AccountMetaSlice = keys
@@ -139,13 +141,13 @@ func (inst Create) ValidateAndBuild() (*Instruction, error) {
139141

140142
func (inst *Create) Validate() error {
141143
if inst.Payer.IsZero() {
142-
return errors.New("Payer not set")
144+
return errors.New("payer not set")
143145
}
144146
if inst.Wallet.IsZero() {
145-
return errors.New("Wallet not set")
147+
return errors.New("wallet not set")
146148
}
147149
if inst.Mint.IsZero() {
148-
return errors.New("Mint not set")
150+
return errors.New("mint not set")
149151
}
150152
_, _, err := solana.FindAssociatedTokenAddress(
151153
inst.Wallet,
@@ -164,19 +166,17 @@ func (inst *Create) EncodeToTree(parent treeout.Branches) {
164166
programBranch.Child(format.Instruction("Create")).
165167
//
166168
ParentFunc(func(instructionBranch treeout.Branches) {
167-
168169
// Parameters of the instruction:
169170
instructionBranch.Child("Params[len=0]").ParentFunc(func(paramsBranch treeout.Branches) {})
170171

171172
// Accounts of the instruction:
172-
instructionBranch.Child("Accounts[len=7").ParentFunc(func(accountsBranch treeout.Branches) {
173+
instructionBranch.Child("Accounts[len=6").ParentFunc(func(accountsBranch treeout.Branches) {
173174
accountsBranch.Child(format.Meta(" payer", inst.AccountMetaSlice.Get(0)))
174175
accountsBranch.Child(format.Meta("associatedTokenAddress", inst.AccountMetaSlice.Get(1)))
175176
accountsBranch.Child(format.Meta(" wallet", inst.AccountMetaSlice.Get(2)))
176177
accountsBranch.Child(format.Meta(" tokenMint", inst.AccountMetaSlice.Get(3)))
177178
accountsBranch.Child(format.Meta(" systemProgram", inst.AccountMetaSlice.Get(4)))
178179
accountsBranch.Child(format.Meta(" tokenProgram", inst.AccountMetaSlice.Get(5)))
179-
accountsBranch.Child(format.Meta(" sysVarRent", inst.AccountMetaSlice.Get(6)))
180180
})
181181
})
182182
})
@@ -200,3 +200,19 @@ func NewCreateInstruction(
200200
SetWallet(walletAddress).
201201
SetMint(splTokenMintAddress)
202202
}
203+
204+
func (inst *Create) GetPayerAccount() *solana.AccountMeta {
205+
return inst.AccountMetaSlice.Get(0)
206+
}
207+
208+
func (inst *Create) GetAssociatedTokenAddressAccount() *solana.AccountMeta {
209+
return inst.AccountMetaSlice.Get(1)
210+
}
211+
212+
func (inst *Create) GetWalletAccount() *solana.AccountMeta {
213+
return inst.AccountMetaSlice.Get(2)
214+
}
215+
216+
func (inst *Create) GetMintAccount() *solana.AccountMeta {
217+
return inst.AccountMetaSlice.Get(3)
218+
}
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
// Copyright 2021 github.com/gagliardetto
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package associatedtokenaccount
16+
17+
import (
18+
"encoding/hex"
19+
"testing"
20+
21+
bin "github.com/gagliardetto/binary"
22+
solana "github.com/gagliardetto/solana-go"
23+
"github.com/stretchr/testify/assert"
24+
"github.com/stretchr/testify/require"
25+
)
26+
27+
func TestEncodingInstruction(t *testing.T) {
28+
t.Run("should encode", func(t *testing.T) {
29+
t.Run("Create", func(t *testing.T) {
30+
// Build an instruction and ensure current encoding matches implementation
31+
payer := solana.NewWallet().PublicKey()
32+
wallet := solana.NewWallet().PublicKey()
33+
mint := solana.NewWallet().PublicKey()
34+
ix := NewCreateInstructionBuilder().
35+
SetPayer(payer).
36+
SetWallet(wallet).
37+
SetMint(mint).
38+
Build()
39+
data, err := ix.Data()
40+
require.NoError(t, err)
41+
encodedHex := hex.EncodeToString(data)
42+
// Current ATA Create encodes no payload bytes
43+
require.Equal(t, "", encodedHex)
44+
})
45+
})
46+
47+
tests := []struct {
48+
name string
49+
hexData string
50+
expectInstruction *Instruction
51+
}{
52+
{
53+
name: "Create",
54+
hexData: "",
55+
expectInstruction: &Instruction{
56+
BaseVariant: bin.BaseVariant{
57+
TypeID: bin.TypeIDFromUint8(0),
58+
Impl: &Create{},
59+
},
60+
},
61+
},
62+
}
63+
64+
t.Run("should encode", func(t *testing.T) {
65+
for _, test := range tests {
66+
t.Run(test.name, func(t *testing.T) {
67+
data, err := test.expectInstruction.Data()
68+
require.NoError(t, err)
69+
encodedHex := hex.EncodeToString(data)
70+
require.Equal(t, test.hexData, encodedHex)
71+
})
72+
}
73+
})
74+
75+
t.Run("should decode", func(t *testing.T) {
76+
for _, test := range tests {
77+
t.Run(test.name, func(t *testing.T) {
78+
data, err := hex.DecodeString(test.hexData)
79+
require.NoError(t, err)
80+
var instruction *Instruction
81+
err = bin.NewBinDecoder(data).Decode(&instruction)
82+
require.NoError(t, err)
83+
assert.Equal(t, test.expectInstruction, instruction)
84+
})
85+
}
86+
})
87+
}
88+
89+
func TestDecodeSetsAccountsAndGetters(t *testing.T) {
90+
payer := solana.NewWallet().PublicKey()
91+
wallet := solana.NewWallet().PublicKey()
92+
mint := solana.NewWallet().PublicKey()
93+
94+
// Build an instruction to obtain correctly ordered accounts and data
95+
ix := NewCreateInstructionBuilder().
96+
SetPayer(payer).
97+
SetWallet(wallet).
98+
SetMint(mint).
99+
Build()
100+
101+
accounts := ix.Accounts()
102+
data, err := ix.Data()
103+
require.NoError(t, err)
104+
105+
decoded, err := DecodeInstruction(accounts, data)
106+
require.NoError(t, err)
107+
108+
create, ok := decoded.Impl.(*Create)
109+
require.True(t, ok)
110+
111+
// Check decoded fields populated via SetAccounts
112+
assert.Equal(t, payer, create.Payer)
113+
assert.Equal(t, wallet, create.Wallet)
114+
assert.Equal(t, mint, create.Mint)
115+
116+
// Check getters return expected account metas
117+
require.NotNil(t, create.GetPayerAccount())
118+
require.NotNil(t, create.GetAssociatedTokenAddressAccount())
119+
require.NotNil(t, create.GetWalletAccount())
120+
require.NotNil(t, create.GetMintAccount())
121+
122+
assert.True(t, create.GetPayerAccount().IsSigner)
123+
assert.True(t, create.GetPayerAccount().IsWritable)
124+
assert.Equal(t, payer, create.GetPayerAccount().PublicKey)
125+
assert.Equal(t, wallet, create.GetWalletAccount().PublicKey)
126+
assert.Equal(t, mint, create.GetMintAccount().PublicKey)
127+
128+
// Verify associated token address is correctly derived and placed at index 1
129+
ata, _, err := solana.FindAssociatedTokenAddress(wallet, mint)
130+
require.NoError(t, err)
131+
assert.Equal(t, ata, create.GetAssociatedTokenAddressAccount().PublicKey)
132+
}

transaction.go

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -484,16 +484,24 @@ func (tx *Transaction) MarshalBinary() ([]byte, error) {
484484
return nil, fmt.Errorf("failed to encode tx.Message to binary: %w", err)
485485
}
486486

487-
var signatureCount []byte
488-
bin.EncodeCompactU16Length(&signatureCount, len(tx.Signatures))
489-
output := make([]byte, 0, len(signatureCount)+len(signatureCount)*64+len(messageContent))
490-
output = append(output, signatureCount...)
491-
for _, sig := range tx.Signatures {
492-
output = append(output, sig[:]...)
487+
signatures := tx.Signatures
488+
for i := len(signatures); i < int(tx.Message.Header.NumRequiredSignatures); i++ {
489+
// append dummy signatures to the transaction, without it serialized transaction will be invalid
490+
// reference: https://github.com/solana-labs/solana-web3.js/blob/4e9988cfc561f3ed11f4c5016a29090a61d129a8/src/transaction/versioned.ts#L36
491+
signatures = append(signatures, SignatureFromBytes(make([]byte, SignatureLength)))
493492
}
494-
output = append(output, messageContent...)
495493

496-
return output, nil
494+
var signaturesCountBytes []byte
495+
bin.EncodeCompactU16Length(&signaturesCountBytes, len(signatures))
496+
497+
binaryTx := make([]byte, 0, len(signaturesCountBytes)+len(signatures)*64+len(messageContent))
498+
binaryTx = append(binaryTx, signaturesCountBytes...)
499+
for _, sig := range signatures {
500+
binaryTx = append(binaryTx, sig[:]...)
501+
}
502+
503+
binaryTx = append(binaryTx, messageContent...)
504+
return binaryTx, nil
497505
}
498506

499507
func (tx Transaction) MarshalWithEncoder(encoder *bin.Encoder) error {

transaction_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"github.com/mr-tron/base58"
2626
"github.com/stretchr/testify/assert"
2727
"github.com/stretchr/testify/require"
28+
"go.uber.org/zap"
2829
)
2930

3031
type testTransactionInstructions struct {
@@ -315,6 +316,54 @@ func TestTransactionVerifySignatures(t *testing.T) {
315316
}
316317
}
317318

319+
func TestTransactionSerializeExisting(t *testing.T) {
320+
// random pump amm swap transaction (3HWKcTbnAMXt3TZDi8LitZCAT5ht7tYqXCroQgNkEnRuvjWCAhmx4UAFnKTWzxS2JXxhbfTiKEdXeU3VHWKzEkNY)
321+
trueEncoded := "AXJEirR5ePYXWIemCsSrRB3kTOxBQvJ8pZ4Of+vs75/Lw4NgNf0jr+eyI+2CAZZcwEQ54v/tcIh0p5qisd6ZfQWAAQAHDuIP6DQ7XgVvZzx4ZyZS5BjxS0s5JIGa893M/2foLj/N9tJulKNTEJ+CcURWZpXddGLJc0niyvBs8fADaZXWBStZSYulGfGwKCcEVhAem2vYiJnZnDQHW8EbXuli2pgFPcs4OJAtX+MJFvqNDoxBApNOXVmhOPi+s+Xj5G6dipCzZIG9Det/kYMpmklt7LaKvckpmIhAC+cygPT/H7L6uDUm47unlLqsWvR0JhT3lhSFUnSWfDsT92IHGjplDB07Bwa4Mppngp8gB0rx3o6jx3q5zZqJ7jaF//YA5f9pM0rWAwZGb+UhFzL/7K26csOb57yM5bvF9xJrLEObOkAAAACMlyWPTiSJ8bs9ECkUjg2DC1oTmdr/EIQEjnvY2+n4WQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABt324ddloZPZy+FGzut5rBy0he1fWzeROoz1hX7/AKkMFN78gl7GdpQlCBi7ZUBl9CmNMVbVcbTU+AkMGOmoY2CQL4wWkL0iw2m16tD4VGZ2ZSCn9Rr5uDY8UTDVNj2KSwXRBHtnMBytk20Zd1LAlKbsLUEcQ8RupHeXS0aTLV9S0zqFMcANe7dmXMUOgO/qYKWak7hGKVjQNpnUxyaEPQgHAAUCkHwBAAcACQNkjocBAAAAAAgGAAEAEAkKAQELEQwADg0QAgEDBBEPCgoJCBILGDPmhaQBf4OtOLwvMTcAAABscskIAAAAAAoDAQAAAQkKAwIAAAEJCQIABQwCAAAAQEtMAAAAAAAJAgAGDAIAAABAQg8AAAAAAAEcQpjY9205kUMXxtRgI8+U286AVT/u2xYmbclxYfAs+QJDZQMS0kc="
322+
323+
tx, err := TransactionFromBase64(trueEncoded)
324+
require.NoError(t, err)
325+
326+
encoded, err := tx.ToBase64()
327+
require.NoError(t, err)
328+
require.Equal(t, trueEncoded, encoded)
329+
}
330+
331+
func TestTransactionSerializePumpFunSwap(t *testing.T) {
332+
expectedEncoded := "AQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABAAcMC8/60geEHarnEMtcmE3t7lADNe2/gxa7NAhBe5Ufe5mtEeak/ClEpPqCUb74FUJuG/soxrZkZndgfGrZ9WamRvj7gB4Ax7akKlldX2HW0ZpscDlcG0xgSGrm4qPYLjpXha3ZyoEhdb0urZWzhcxyIVTkHNfJDlRkaaaZumtUuJid6LnvUOa256MT0Ym0MG/y6Uqt2PX3ijrL9vC9eTaGKzqGXmnuD1SAyrz2Y1fk3C8Y1Y1Fwep0ifs3I9l5PHKm6c4nvkyiO9V3lgShoznE+kfvld5CoS4qMMKpnUOErY8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAbd9uHXZaGT2cvhRs7reawctIXtX1s3kTqM9YV+/wCpBqfVFxksXFEhjMlMPUrxf1ja7gibof1E49vZigAAAACs8TbrAfwcTog9I8i1hEq1mjf2at1XxemsO1PgWdNcZAFW4PaTZlrPRNsVaL8XW6pRicuX9dL/O2VdK7b9bRiw0WlzNBoKArNIcmjqJI0o+XoQGnMjSEA/HZqyYrSJgLkBCwwFAQYCAwQABwgJCgsYZgY9EgHa6+pMR9bEaRkAAHg1HjQAAAAA"
333+
instruction := &testTransactionInstructions{
334+
programID: MPK("6EF8rrecthR5Dkzon8Nwu78hRvfCKubJ14M5uBEwF6P"),
335+
accounts: []*AccountMeta{
336+
{PublicKey: MPK("4wTV1YmiEkRvAtNtsSGPtUrqRYQMe5SKy2uB4Jjaxnjf"), IsSigner: false, IsWritable: false},
337+
{PublicKey: MPK("CebN5WGQ4jvEPvsVU4EoHEpgzq1VV7AbicfhtW4xC9iM"), IsSigner: false, IsWritable: true},
338+
{PublicKey: MPK("GjgKTqtzDei5E3uZyA2CN29KQgugF564K1hoc1jHpump"), IsSigner: false, IsWritable: false},
339+
{PublicKey: MPK("HkvYAZV1Mg6kt5KMaA5YBQazZECg21zaZdQEMUiLrjKc"), IsSigner: false, IsWritable: true},
340+
{PublicKey: MPK("9zpyjwrYdRWNMyqicoiuL3gUcrbvrkd5Kq9nxui1znw1"), IsSigner: false, IsWritable: true},
341+
{PublicKey: MPK("BdQqJnuqqFhNZUNYGEEsuhBidpf8qHqfjDQvcjDN3nti"), IsSigner: false, IsWritable: true},
342+
{PublicKey: MPK("o7RY6P2vQMuGSu1TrLM81weuzgDjaCRTXYRaXJwWcvc"), IsSigner: true, IsWritable: true},
343+
{PublicKey: MPK("11111111111111111111111111111111"), IsSigner: false, IsWritable: false},
344+
{PublicKey: MPK("TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA"), IsSigner: false, IsWritable: false},
345+
{PublicKey: MPK("SysvarRent111111111111111111111111111111111"), IsSigner: false, IsWritable: false},
346+
{PublicKey: MPK("Ce6TQqeHC9p8KetsN6JsjHK7UTZk7nasjjnr7XxXp9F1"), IsSigner: false, IsWritable: false},
347+
{PublicKey: MPK("6EF8rrecthR5Dkzon8Nwu78hRvfCKubJ14M5uBEwF6P"), IsSigner: false, IsWritable: false},
348+
},
349+
data: []byte{102, 6, 61, 18, 1, 218, 235, 234, 76, 71, 214, 196, 105, 25, 0, 0, 120, 53, 30, 52, 0, 0, 0, 0},
350+
}
351+
352+
tx, err := NewTransactionBuilder().
353+
AddInstruction(instruction).
354+
SetFeePayer(MPK("o7RY6P2vQMuGSu1TrLM81weuzgDjaCRTXYRaXJwWcvc")).
355+
SetRecentBlockHash(MustHashFromBase58("F6TUDvYPMwDLP1MW4BUWTNm6S94XR1UZ2nGVyubqo6oi")).
356+
Build()
357+
require.NoError(t, err)
358+
require.NotNil(t, tx)
359+
360+
encoded, err := tx.ToBase64()
361+
require.NoError(t, err)
362+
363+
zlog.Debug("encoded", zap.String("encoded", encoded))
364+
require.Equal(t, expectedEncoded, encoded)
365+
}
366+
318367
func BenchmarkTransactionFromDecoder(b *testing.B) {
319368
txString := "Ak8jvC3ch5hq3lhOHPkACoFepIUON2zEN4KRcw4lDS6GBsQfnSdzNGPETm/yi0hPKk75/i2VXFj0FLUWnGR64ADyUbqnirFjFtaSNgcGi02+Tm7siT4CPpcaTq0jxfYQK/h9FdxXXPnLry74J+RE8yji/BtJ/Cjxbx+TIHigeIYJAgEBBByE1Y6EqCJKsr7iEupU6lsBHtBdtI4SK3yWMCFA0iEKeFPgnGmtp+1SIX1Ak+sN65iBaR7v4Iim5m1OEuFQTgi9N57UnhNpCNuUePaTt7HJaFBmyeZB3deXeKWVudpY3gAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAWVECK/n3a7QR6OKWYR4DuAVjS6FXgZj82W0dJpSIPnEBAwQAAgEDDAIAAABAQg8AAAAAAA=="
320369
txBin, err := base64.StdEncoding.DecodeString(txString)

0 commit comments

Comments
 (0)