diff --git a/message.go b/message.go index 5b804f512..87dd2ff83 100644 --- a/message.go +++ b/message.go @@ -23,6 +23,7 @@ import ( bin "github.com/gagliardetto/binary" "github.com/gagliardetto/treeout" + jsoniter "github.com/json-iterator/go" "github.com/gagliardetto/solana-go/text" ) @@ -195,42 +196,69 @@ func (mx *Message) NumWritableLookups() int { func (mx Message) MarshalJSON() ([]byte, error) { if mx.version == MessageVersionLegacy { out := struct { - AccountKeys []string `json:"accountKeys"` + AccountKeys PublicKeySlice `json:"accountKeys"` Header MessageHeader `json:"header"` - RecentBlockhash string `json:"recentBlockhash"` + RecentBlockhash Hash `json:"recentBlockhash"` Instructions []CompiledInstruction `json:"instructions"` }{ - AccountKeys: make([]string, len(mx.AccountKeys)), + AccountKeys: mx.AccountKeys, Header: mx.Header, - RecentBlockhash: mx.RecentBlockhash.String(), + RecentBlockhash: mx.RecentBlockhash, Instructions: mx.Instructions, } - for i, key := range mx.AccountKeys { - out.AccountKeys[i] = key.String() - } return json.Marshal(out) } // Versioned message: + lookups := mx.AddressTableLookups + if lookups == nil { + lookups = MessageAddressTableLookupSlice{} + } out := struct { - AccountKeys []string `json:"accountKeys"` - Header MessageHeader `json:"header"` - RecentBlockhash string `json:"recentBlockhash"` - Instructions []CompiledInstruction `json:"instructions"` - AddressTableLookups []MessageAddressTableLookup `json:"addressTableLookups"` + AccountKeys PublicKeySlice `json:"accountKeys"` + Header MessageHeader `json:"header"` + RecentBlockhash Hash `json:"recentBlockhash"` + Instructions []CompiledInstruction `json:"instructions"` + AddressTableLookups MessageAddressTableLookupSlice `json:"addressTableLookups"` }{ - AccountKeys: make([]string, len(mx.AccountKeys)), + AccountKeys: mx.AccountKeys, Header: mx.Header, - RecentBlockhash: mx.RecentBlockhash.String(), + RecentBlockhash: mx.RecentBlockhash, Instructions: mx.Instructions, - AddressTableLookups: mx.AddressTableLookups, + AddressTableLookups: lookups, } - for i, key := range mx.AccountKeys { - out.AccountKeys[i] = key.String() + return json.Marshal(out) +} + +// UnmarshalJSON decodes the message from JSON and determines its version. +// The Solana RPC emits `addressTableLookups` only for versioned (V0+) +// messages; its presence in the JSON is what distinguishes V0 from legacy, +// since the private `version` field has no wire representation. +func (mx *Message) UnmarshalJSON(data []byte) error { + // Decode `addressTableLookups` via a RawMessage pointer so presence of the + // key can be detected in a single parse. A non-nil pointer means the key + // was present in the JSON (even if its value is `null`), which selects V0. + aux := struct { + AccountKeys PublicKeySlice `json:"accountKeys"` + Header MessageHeader `json:"header"` + RecentBlockhash Hash `json:"recentBlockhash"` + Instructions []CompiledInstruction `json:"instructions"` + AddressTableLookups *jsoniter.RawMessage `json:"addressTableLookups"` + }{} + if err := json.Unmarshal(data, &aux); err != nil { + return err } - if out.AddressTableLookups == nil { - out.AddressTableLookups = make([]MessageAddressTableLookup, 0) + mx.AccountKeys = aux.AccountKeys + mx.Header = aux.Header + mx.RecentBlockhash = aux.RecentBlockhash + mx.Instructions = aux.Instructions + + if aux.AddressTableLookups == nil { + mx.version = MessageVersionLegacy + mx.AddressTableLookups = nil + return nil } - return json.Marshal(out) + mx.version = MessageVersionV0 + return json.Unmarshal(*aux.AddressTableLookups, &mx.AddressTableLookups) } func (mx *Message) EncodeToTree(txTree treeout.Branches) { diff --git a/message_test.go b/message_test.go index b01ccddc5..58ddfec49 100644 --- a/message_test.go +++ b/message_test.go @@ -1233,6 +1233,72 @@ func TestHasDuplicates_Loaded(t *testing.T) { }) } +// Regression test for https://github.com/solana-foundation/solana-go/issues/339: +// when decoding an RPC response with `encoding: json`, the Message version was +// always left as Legacy because the private `version` field has no JSON tag. +// The presence of `addressTableLookups` in the raw JSON now selects V0. +func TestMessageJSONVersionDetection(t *testing.T) { + legacyJSON := []byte(`{ + "accountKeys": ["11111111111111111111111111111111"], + "header": {"numRequiredSignatures":1,"numReadonlySignedAccounts":0,"numReadonlyUnsignedAccounts":0}, + "recentBlockhash": "11111111111111111111111111111111", + "instructions": [] + }`) + v0JSON := []byte(`{ + "accountKeys": ["11111111111111111111111111111111"], + "header": {"numRequiredSignatures":1,"numReadonlySignedAccounts":0,"numReadonlyUnsignedAccounts":0}, + "recentBlockhash": "11111111111111111111111111111111", + "instructions": [], + "addressTableLookups": [] + }`) + + var legacy Message + require.NoError(t, json.Unmarshal(legacyJSON, &legacy)) + assert.Equal(t, MessageVersionLegacy, legacy.GetVersion()) + assert.False(t, legacy.IsVersioned()) + + var versioned Message + require.NoError(t, json.Unmarshal(v0JSON, &versioned)) + assert.Equal(t, MessageVersionV0, versioned.GetVersion()) + assert.True(t, versioned.IsVersioned()) +} + +// TestMessageJSONVersionRoundtrip ensures MarshalJSON/UnmarshalJSON preserve +// the message version across the round-trip. +func TestMessageJSONVersionRoundtrip(t *testing.T) { + blockhash := Hash{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32} + + for _, tc := range []struct { + name string + version MessageVersion + lookups MessageAddressTableLookupSlice + }{ + {"legacy", MessageVersionLegacy, nil}, + {"v0_no_lookups", MessageVersionV0, nil}, + {"v0_with_lookups", MessageVersionV0, MessageAddressTableLookupSlice{ + {AccountKey: newUniqueKey(), WritableIndexes: []uint8{0, 1}, ReadonlyIndexes: []uint8{2}}, + }}, + } { + t.Run(tc.name, func(t *testing.T) { + original := Message{ + version: tc.version, + Header: MessageHeader{NumRequiredSignatures: 1}, + AccountKeys: PublicKeySlice{newUniqueKey()}, + RecentBlockhash: blockhash, + Instructions: []CompiledInstruction{}, + AddressTableLookups: tc.lookups, + } + data, err := json.Marshal(original) + require.NoError(t, err) + + var decoded Message + require.NoError(t, json.Unmarshal(data, &decoded)) + assert.Equal(t, tc.version, decoded.GetVersion()) + }) + } +} + // hasDuplicates is a test helper matching Rust's has_duplicates check. func hasDuplicates(keys PublicKeySlice) bool { seen := make(map[PublicKey]struct{}, len(keys)) diff --git a/transaction.go b/transaction.go index 8af306acd..ddbb37e53 100644 --- a/transaction.go +++ b/transaction.go @@ -453,17 +453,17 @@ func NewTransaction(instructions []Instruction, recentBlockHash Hash, opts ...Tr } var idx uint16 - accountKeyIndex := make(map[string]uint16, len(message.AccountKeys)+len(lookupsWritableKeys)+len(lookupsReadOnlyKeys)) + accountKeyIndex := make(map[PublicKey]uint16, len(message.AccountKeys)+len(lookupsWritableKeys)+len(lookupsReadOnlyKeys)) for _, acc := range message.AccountKeys { - accountKeyIndex[acc.String()] = idx + accountKeyIndex[acc] = idx idx++ } for _, acc := range lookupsWritableKeys { - accountKeyIndex[acc.String()] = idx + accountKeyIndex[acc] = idx idx++ } for _, acc := range lookupsReadOnlyKeys { - accountKeyIndex[acc.String()] = idx + accountKeyIndex[acc] = idx idx++ } @@ -479,14 +479,14 @@ func NewTransaction(instructions []Instruction, recentBlockHash Hash, opts ...Tr accounts = instruction.Accounts() accountIndex := make([]uint16, len(accounts)) for idx, acc := range accounts { - accountIndex[idx] = accountKeyIndex[acc.PublicKey.String()] + accountIndex[idx] = accountKeyIndex[acc.PublicKey] } data, err := instruction.Data() if err != nil { return nil, fmt.Errorf("unable to encode instructions [%d]: %w", txIdx, err) } message.Instructions = append(message.Instructions, CompiledInstruction{ - ProgramIDIndex: accountKeyIndex[instruction.ProgramID().String()], + ProgramIDIndex: accountKeyIndex[instruction.ProgramID()], Accounts: accountIndex, Data: data, })