Skip to content

Commit 1fd2201

Browse files
Merge pull request #404 from sonicfromnewyoke/sonic/fix-json-msg-ver-detect
fix(message): json version detection
2 parents 8ed3105 + 7b75471 commit 1fd2201

3 files changed

Lines changed: 120 additions & 26 deletions

File tree

message.go

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323

2424
bin "github.com/gagliardetto/binary"
2525
"github.com/gagliardetto/treeout"
26+
jsoniter "github.com/json-iterator/go"
2627

2728
"github.com/gagliardetto/solana-go/text"
2829
)
@@ -195,42 +196,69 @@ func (mx *Message) NumWritableLookups() int {
195196
func (mx Message) MarshalJSON() ([]byte, error) {
196197
if mx.version == MessageVersionLegacy {
197198
out := struct {
198-
AccountKeys []string `json:"accountKeys"`
199+
AccountKeys PublicKeySlice `json:"accountKeys"`
199200
Header MessageHeader `json:"header"`
200-
RecentBlockhash string `json:"recentBlockhash"`
201+
RecentBlockhash Hash `json:"recentBlockhash"`
201202
Instructions []CompiledInstruction `json:"instructions"`
202203
}{
203-
AccountKeys: make([]string, len(mx.AccountKeys)),
204+
AccountKeys: mx.AccountKeys,
204205
Header: mx.Header,
205-
RecentBlockhash: mx.RecentBlockhash.String(),
206+
RecentBlockhash: mx.RecentBlockhash,
206207
Instructions: mx.Instructions,
207208
}
208-
for i, key := range mx.AccountKeys {
209-
out.AccountKeys[i] = key.String()
210-
}
211209
return json.Marshal(out)
212210
}
213211
// Versioned message:
212+
lookups := mx.AddressTableLookups
213+
if lookups == nil {
214+
lookups = MessageAddressTableLookupSlice{}
215+
}
214216
out := struct {
215-
AccountKeys []string `json:"accountKeys"`
216-
Header MessageHeader `json:"header"`
217-
RecentBlockhash string `json:"recentBlockhash"`
218-
Instructions []CompiledInstruction `json:"instructions"`
219-
AddressTableLookups []MessageAddressTableLookup `json:"addressTableLookups"`
217+
AccountKeys PublicKeySlice `json:"accountKeys"`
218+
Header MessageHeader `json:"header"`
219+
RecentBlockhash Hash `json:"recentBlockhash"`
220+
Instructions []CompiledInstruction `json:"instructions"`
221+
AddressTableLookups MessageAddressTableLookupSlice `json:"addressTableLookups"`
220222
}{
221-
AccountKeys: make([]string, len(mx.AccountKeys)),
223+
AccountKeys: mx.AccountKeys,
222224
Header: mx.Header,
223-
RecentBlockhash: mx.RecentBlockhash.String(),
225+
RecentBlockhash: mx.RecentBlockhash,
224226
Instructions: mx.Instructions,
225-
AddressTableLookups: mx.AddressTableLookups,
227+
AddressTableLookups: lookups,
226228
}
227-
for i, key := range mx.AccountKeys {
228-
out.AccountKeys[i] = key.String()
229+
return json.Marshal(out)
230+
}
231+
232+
// UnmarshalJSON decodes the message from JSON and determines its version.
233+
// The Solana RPC emits `addressTableLookups` only for versioned (V0+)
234+
// messages; its presence in the JSON is what distinguishes V0 from legacy,
235+
// since the private `version` field has no wire representation.
236+
func (mx *Message) UnmarshalJSON(data []byte) error {
237+
// Decode `addressTableLookups` via a RawMessage pointer so presence of the
238+
// key can be detected in a single parse. A non-nil pointer means the key
239+
// was present in the JSON (even if its value is `null`), which selects V0.
240+
aux := struct {
241+
AccountKeys PublicKeySlice `json:"accountKeys"`
242+
Header MessageHeader `json:"header"`
243+
RecentBlockhash Hash `json:"recentBlockhash"`
244+
Instructions []CompiledInstruction `json:"instructions"`
245+
AddressTableLookups *jsoniter.RawMessage `json:"addressTableLookups"`
246+
}{}
247+
if err := json.Unmarshal(data, &aux); err != nil {
248+
return err
229249
}
230-
if out.AddressTableLookups == nil {
231-
out.AddressTableLookups = make([]MessageAddressTableLookup, 0)
250+
mx.AccountKeys = aux.AccountKeys
251+
mx.Header = aux.Header
252+
mx.RecentBlockhash = aux.RecentBlockhash
253+
mx.Instructions = aux.Instructions
254+
255+
if aux.AddressTableLookups == nil {
256+
mx.version = MessageVersionLegacy
257+
mx.AddressTableLookups = nil
258+
return nil
232259
}
233-
return json.Marshal(out)
260+
mx.version = MessageVersionV0
261+
return json.Unmarshal(*aux.AddressTableLookups, &mx.AddressTableLookups)
234262
}
235263

236264
func (mx *Message) EncodeToTree(txTree treeout.Branches) {

message_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,6 +1233,72 @@ func TestHasDuplicates_Loaded(t *testing.T) {
12331233
})
12341234
}
12351235

1236+
// Regression test for https://github.com/solana-foundation/solana-go/issues/339:
1237+
// when decoding an RPC response with `encoding: json`, the Message version was
1238+
// always left as Legacy because the private `version` field has no JSON tag.
1239+
// The presence of `addressTableLookups` in the raw JSON now selects V0.
1240+
func TestMessageJSONVersionDetection(t *testing.T) {
1241+
legacyJSON := []byte(`{
1242+
"accountKeys": ["11111111111111111111111111111111"],
1243+
"header": {"numRequiredSignatures":1,"numReadonlySignedAccounts":0,"numReadonlyUnsignedAccounts":0},
1244+
"recentBlockhash": "11111111111111111111111111111111",
1245+
"instructions": []
1246+
}`)
1247+
v0JSON := []byte(`{
1248+
"accountKeys": ["11111111111111111111111111111111"],
1249+
"header": {"numRequiredSignatures":1,"numReadonlySignedAccounts":0,"numReadonlyUnsignedAccounts":0},
1250+
"recentBlockhash": "11111111111111111111111111111111",
1251+
"instructions": [],
1252+
"addressTableLookups": []
1253+
}`)
1254+
1255+
var legacy Message
1256+
require.NoError(t, json.Unmarshal(legacyJSON, &legacy))
1257+
assert.Equal(t, MessageVersionLegacy, legacy.GetVersion())
1258+
assert.False(t, legacy.IsVersioned())
1259+
1260+
var versioned Message
1261+
require.NoError(t, json.Unmarshal(v0JSON, &versioned))
1262+
assert.Equal(t, MessageVersionV0, versioned.GetVersion())
1263+
assert.True(t, versioned.IsVersioned())
1264+
}
1265+
1266+
// TestMessageJSONVersionRoundtrip ensures MarshalJSON/UnmarshalJSON preserve
1267+
// the message version across the round-trip.
1268+
func TestMessageJSONVersionRoundtrip(t *testing.T) {
1269+
blockhash := Hash{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
1270+
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}
1271+
1272+
for _, tc := range []struct {
1273+
name string
1274+
version MessageVersion
1275+
lookups MessageAddressTableLookupSlice
1276+
}{
1277+
{"legacy", MessageVersionLegacy, nil},
1278+
{"v0_no_lookups", MessageVersionV0, nil},
1279+
{"v0_with_lookups", MessageVersionV0, MessageAddressTableLookupSlice{
1280+
{AccountKey: newUniqueKey(), WritableIndexes: []uint8{0, 1}, ReadonlyIndexes: []uint8{2}},
1281+
}},
1282+
} {
1283+
t.Run(tc.name, func(t *testing.T) {
1284+
original := Message{
1285+
version: tc.version,
1286+
Header: MessageHeader{NumRequiredSignatures: 1},
1287+
AccountKeys: PublicKeySlice{newUniqueKey()},
1288+
RecentBlockhash: blockhash,
1289+
Instructions: []CompiledInstruction{},
1290+
AddressTableLookups: tc.lookups,
1291+
}
1292+
data, err := json.Marshal(original)
1293+
require.NoError(t, err)
1294+
1295+
var decoded Message
1296+
require.NoError(t, json.Unmarshal(data, &decoded))
1297+
assert.Equal(t, tc.version, decoded.GetVersion())
1298+
})
1299+
}
1300+
}
1301+
12361302
// hasDuplicates is a test helper matching Rust's has_duplicates check.
12371303
func hasDuplicates(keys PublicKeySlice) bool {
12381304
seen := make(map[PublicKey]struct{}, len(keys))

transaction.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -453,17 +453,17 @@ func NewTransaction(instructions []Instruction, recentBlockHash Hash, opts ...Tr
453453
}
454454

455455
var idx uint16
456-
accountKeyIndex := make(map[string]uint16, len(message.AccountKeys)+len(lookupsWritableKeys)+len(lookupsReadOnlyKeys))
456+
accountKeyIndex := make(map[PublicKey]uint16, len(message.AccountKeys)+len(lookupsWritableKeys)+len(lookupsReadOnlyKeys))
457457
for _, acc := range message.AccountKeys {
458-
accountKeyIndex[acc.String()] = idx
458+
accountKeyIndex[acc] = idx
459459
idx++
460460
}
461461
for _, acc := range lookupsWritableKeys {
462-
accountKeyIndex[acc.String()] = idx
462+
accountKeyIndex[acc] = idx
463463
idx++
464464
}
465465
for _, acc := range lookupsReadOnlyKeys {
466-
accountKeyIndex[acc.String()] = idx
466+
accountKeyIndex[acc] = idx
467467
idx++
468468
}
469469

@@ -479,14 +479,14 @@ func NewTransaction(instructions []Instruction, recentBlockHash Hash, opts ...Tr
479479
accounts = instruction.Accounts()
480480
accountIndex := make([]uint16, len(accounts))
481481
for idx, acc := range accounts {
482-
accountIndex[idx] = accountKeyIndex[acc.PublicKey.String()]
482+
accountIndex[idx] = accountKeyIndex[acc.PublicKey]
483483
}
484484
data, err := instruction.Data()
485485
if err != nil {
486486
return nil, fmt.Errorf("unable to encode instructions [%d]: %w", txIdx, err)
487487
}
488488
message.Instructions = append(message.Instructions, CompiledInstruction{
489-
ProgramIDIndex: accountKeyIndex[instruction.ProgramID().String()],
489+
ProgramIDIndex: accountKeyIndex[instruction.ProgramID()],
490490
Accounts: accountIndex,
491491
Data: data,
492492
})

0 commit comments

Comments
 (0)