Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 48 additions & 20 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

bin "github.com/gagliardetto/binary"
"github.com/gagliardetto/treeout"
jsoniter "github.com/json-iterator/go"

"github.com/gagliardetto/solana-go/text"
)
Expand Down Expand Up @@ -195,42 +196,69 @@ func (mx *Message) NumWritableLookups() int {
func (mx Message) MarshalJSON() ([]byte, error) {
if mx.version == MessageVersionLegacy {
out := struct {
AccountKeys []string `json:"accountKeys"`
AccountKeys PublicKeySlice `json:"accountKeys"`
Header MessageHeader `json:"header"`
RecentBlockhash string `json:"recentBlockhash"`
RecentBlockhash Hash `json:"recentBlockhash"`
Instructions []CompiledInstruction `json:"instructions"`
}{
AccountKeys: make([]string, len(mx.AccountKeys)),
AccountKeys: mx.AccountKeys,
Header: mx.Header,
RecentBlockhash: mx.RecentBlockhash.String(),
RecentBlockhash: mx.RecentBlockhash,
Instructions: mx.Instructions,
}
for i, key := range mx.AccountKeys {
out.AccountKeys[i] = key.String()
}
return json.Marshal(out)
}
// Versioned message:
lookups := mx.AddressTableLookups
if lookups == nil {
lookups = MessageAddressTableLookupSlice{}
}
out := struct {
AccountKeys []string `json:"accountKeys"`
Header MessageHeader `json:"header"`
RecentBlockhash string `json:"recentBlockhash"`
Instructions []CompiledInstruction `json:"instructions"`
AddressTableLookups []MessageAddressTableLookup `json:"addressTableLookups"`
AccountKeys PublicKeySlice `json:"accountKeys"`
Header MessageHeader `json:"header"`
RecentBlockhash Hash `json:"recentBlockhash"`
Instructions []CompiledInstruction `json:"instructions"`
AddressTableLookups MessageAddressTableLookupSlice `json:"addressTableLookups"`
}{
AccountKeys: make([]string, len(mx.AccountKeys)),
AccountKeys: mx.AccountKeys,
Header: mx.Header,
RecentBlockhash: mx.RecentBlockhash.String(),
RecentBlockhash: mx.RecentBlockhash,
Instructions: mx.Instructions,
AddressTableLookups: mx.AddressTableLookups,
AddressTableLookups: lookups,
}
for i, key := range mx.AccountKeys {
out.AccountKeys[i] = key.String()
return json.Marshal(out)
}

// UnmarshalJSON decodes the message from JSON and determines its version.
// The Solana RPC emits `addressTableLookups` only for versioned (V0+)
// messages; its presence in the JSON is what distinguishes V0 from legacy,
// since the private `version` field has no wire representation.
func (mx *Message) UnmarshalJSON(data []byte) error {
// Decode `addressTableLookups` via a RawMessage pointer so presence of the
// key can be detected in a single parse. A non-nil pointer means the key
// was present in the JSON (even if its value is `null`), which selects V0.
aux := struct {
AccountKeys PublicKeySlice `json:"accountKeys"`
Header MessageHeader `json:"header"`
RecentBlockhash Hash `json:"recentBlockhash"`
Instructions []CompiledInstruction `json:"instructions"`
AddressTableLookups *jsoniter.RawMessage `json:"addressTableLookups"`
}{}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
if out.AddressTableLookups == nil {
out.AddressTableLookups = make([]MessageAddressTableLookup, 0)
mx.AccountKeys = aux.AccountKeys
mx.Header = aux.Header
mx.RecentBlockhash = aux.RecentBlockhash
mx.Instructions = aux.Instructions

if aux.AddressTableLookups == nil {
mx.version = MessageVersionLegacy
mx.AddressTableLookups = nil
return nil
}
return json.Marshal(out)
mx.version = MessageVersionV0
return json.Unmarshal(*aux.AddressTableLookups, &mx.AddressTableLookups)
}

func (mx *Message) EncodeToTree(txTree treeout.Branches) {
Expand Down
66 changes: 66 additions & 0 deletions message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1233,6 +1233,72 @@ func TestHasDuplicates_Loaded(t *testing.T) {
})
}

// Regression test for https://github.com/solana-foundation/solana-go/issues/339:
// when decoding an RPC response with `encoding: json`, the Message version was
// always left as Legacy because the private `version` field has no JSON tag.
// The presence of `addressTableLookups` in the raw JSON now selects V0.
func TestMessageJSONVersionDetection(t *testing.T) {
legacyJSON := []byte(`{
"accountKeys": ["11111111111111111111111111111111"],
"header": {"numRequiredSignatures":1,"numReadonlySignedAccounts":0,"numReadonlyUnsignedAccounts":0},
"recentBlockhash": "11111111111111111111111111111111",
"instructions": []
}`)
v0JSON := []byte(`{
"accountKeys": ["11111111111111111111111111111111"],
"header": {"numRequiredSignatures":1,"numReadonlySignedAccounts":0,"numReadonlyUnsignedAccounts":0},
"recentBlockhash": "11111111111111111111111111111111",
"instructions": [],
"addressTableLookups": []
}`)

var legacy Message
require.NoError(t, json.Unmarshal(legacyJSON, &legacy))
assert.Equal(t, MessageVersionLegacy, legacy.GetVersion())
assert.False(t, legacy.IsVersioned())

var versioned Message
require.NoError(t, json.Unmarshal(v0JSON, &versioned))
assert.Equal(t, MessageVersionV0, versioned.GetVersion())
assert.True(t, versioned.IsVersioned())
}

// TestMessageJSONVersionRoundtrip ensures MarshalJSON/UnmarshalJSON preserve
// the message version across the round-trip.
func TestMessageJSONVersionRoundtrip(t *testing.T) {
blockhash := Hash{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}

for _, tc := range []struct {
name string
version MessageVersion
lookups MessageAddressTableLookupSlice
}{
{"legacy", MessageVersionLegacy, nil},
{"v0_no_lookups", MessageVersionV0, nil},
{"v0_with_lookups", MessageVersionV0, MessageAddressTableLookupSlice{
{AccountKey: newUniqueKey(), WritableIndexes: []uint8{0, 1}, ReadonlyIndexes: []uint8{2}},
}},
} {
t.Run(tc.name, func(t *testing.T) {
original := Message{
version: tc.version,
Header: MessageHeader{NumRequiredSignatures: 1},
AccountKeys: PublicKeySlice{newUniqueKey()},
RecentBlockhash: blockhash,
Instructions: []CompiledInstruction{},
AddressTableLookups: tc.lookups,
}
data, err := json.Marshal(original)
require.NoError(t, err)

var decoded Message
require.NoError(t, json.Unmarshal(data, &decoded))
assert.Equal(t, tc.version, decoded.GetVersion())
})
}
}

// hasDuplicates is a test helper matching Rust's has_duplicates check.
func hasDuplicates(keys PublicKeySlice) bool {
seen := make(map[PublicKey]struct{}, len(keys))
Expand Down
12 changes: 6 additions & 6 deletions transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,17 +453,17 @@ func NewTransaction(instructions []Instruction, recentBlockHash Hash, opts ...Tr
}

var idx uint16
accountKeyIndex := make(map[string]uint16, len(message.AccountKeys)+len(lookupsWritableKeys)+len(lookupsReadOnlyKeys))
accountKeyIndex := make(map[PublicKey]uint16, len(message.AccountKeys)+len(lookupsWritableKeys)+len(lookupsReadOnlyKeys))
for _, acc := range message.AccountKeys {
accountKeyIndex[acc.String()] = idx
accountKeyIndex[acc] = idx
idx++
}
for _, acc := range lookupsWritableKeys {
accountKeyIndex[acc.String()] = idx
accountKeyIndex[acc] = idx
idx++
}
for _, acc := range lookupsReadOnlyKeys {
accountKeyIndex[acc.String()] = idx
accountKeyIndex[acc] = idx
idx++
}

Expand All @@ -479,14 +479,14 @@ func NewTransaction(instructions []Instruction, recentBlockHash Hash, opts ...Tr
accounts = instruction.Accounts()
accountIndex := make([]uint16, len(accounts))
for idx, acc := range accounts {
accountIndex[idx] = accountKeyIndex[acc.PublicKey.String()]
accountIndex[idx] = accountKeyIndex[acc.PublicKey]
}
data, err := instruction.Data()
if err != nil {
return nil, fmt.Errorf("unable to encode instructions [%d]: %w", txIdx, err)
}
message.Instructions = append(message.Instructions, CompiledInstruction{
ProgramIDIndex: accountKeyIndex[instruction.ProgramID().String()],
ProgramIDIndex: accountKeyIndex[instruction.ProgramID()],
Accounts: accountIndex,
Data: data,
})
Expand Down
Loading