From 1a3e693e92bf0a33244f1a809f70fbd907172c4f Mon Sep 17 00:00:00 2001 From: Sonic Date: Sun, 15 Mar 2026 01:21:59 +0200 Subject: [PATCH] fix: pre-balance divergence and get rid of some panics --- cmd/mithril/node/node.go | 22 ++- config.example.toml | 10 ++ pkg/replay/bankhash_verify.go | 202 ++++++++++++++++++++++ pkg/replay/block.go | 26 ++- pkg/replay/transaction.go | 4 +- pkg/replay/transaction_test.go | 301 +++++++++++++++++++++++++++++++++ 6 files changed, 556 insertions(+), 9 deletions(-) create mode 100644 pkg/replay/bankhash_verify.go create mode 100644 pkg/replay/transaction_test.go diff --git a/cmd/mithril/node/node.go b/cmd/mithril/node/node.go index 89a713bd..12fa59be 100644 --- a/cmd/mithril/node/node.go +++ b/cmd/mithril/node/node.go @@ -89,6 +89,9 @@ var ( debugAcctWrites []string cpuprofPath string + bankhashVerifyEndpoint string // URL of a reference mithril node for bankhash verification + bankhashVerifyMode string // "warn" or "panic" + paramArenaSizeMB uint64 borrowedAccountArenaSize uint64 @@ -133,6 +136,8 @@ func init() { // [debug] section flags Run.Flags().StringSliceVar(&debugTxs, "transaction-signatures", []string{}, "Pass tx signature strings to enable debug logging during that transaction's execution") Run.Flags().StringSliceVar(&debugAcctWrites, "account-writes", []string{}, "Pass account pubkeys to enable debug logging of transactions that modify the account") + Run.Flags().StringVar(&bankhashVerifyEndpoint, "bankhash-verify-endpoint", "", "URL of a reference mithril node for bankhash verification (e.g. http://reference-node:8899)") + Run.Flags().StringVar(&bankhashVerifyMode, "bankhash-verify-mode", "warn", "Bankhash verification mode: 'warn' (log mismatch) or 'panic' (halt on mismatch)") // Top-level flags Run.Flags().StringVar(&scratchDirectory, "scratch-directory", "/tmp", "Path for downloads (e.g. snapshots) and other temp state") @@ -443,6 +448,10 @@ func initConfigAndBindFlags(cmd *cobra.Command) error { if len(debugAcctWrites) == 0 { debugAcctWrites = getStringSlice("account-writes", "development.debug.account_writes") } + bankhashVerifyEndpoint = getString("bankhash-verify-endpoint", "debug.bankhash_verify_endpoint") + if v := getString("bankhash-verify-mode", "debug.bankhash_verify_mode"); v != "" { + bankhashVerifyMode = v + } // [tuning] section (with fallback to legacy [development]) paramArenaSizeMB = getUint64("param-arena-size-mb", "tuning.param_arena_size_mb") @@ -2210,6 +2219,17 @@ func runReplayWithRecovery( } }() - result = replay.ReplayBlocks(ctx, accountsDb, accountsDbPath, mithrilState, resumeState, startSlot, endSlot, rpcEndpoints, blockDir, txParallelism, isLive, useLightbringer, dbgOpts, metricsWriter, rpcServer, blockFetchOpts, onCancelWriteState) + // Create bankhash verifier if configured + var bhVerifier *replay.BankhashVerifier + if bankhashVerifyEndpoint != "" { + mode := replay.BankhashVerifyWarn + if bankhashVerifyMode == "panic" { + mode = replay.BankhashVerifyPanic + } + bhVerifier = replay.NewBankhashVerifier(bankhashVerifyEndpoint, mode) + mlog.Log.Infof("Bankhash verification enabled: endpoint=%s mode=%s", bankhashVerifyEndpoint, bankhashVerifyMode) + } + + result = replay.ReplayBlocks(ctx, accountsDb, accountsDbPath, mithrilState, resumeState, startSlot, endSlot, rpcEndpoints, blockDir, txParallelism, isLive, useLightbringer, dbgOpts, metricsWriter, rpcServer, blockFetchOpts, onCancelWriteState, bhVerifier) return result } diff --git a/config.example.toml b/config.example.toml index cb6841a2..ddd643f4 100644 --- a/config.example.toml +++ b/config.example.toml @@ -290,6 +290,16 @@ name = "mithril" # Account pubkeys to enable debug logging of transactions that modify them # account_writes = ["pubkey1", "pubkey2"] + # Bankhash verification: compare computed bankhashes against a reference + # mithril node to detect state divergences early. The reference node must + # have its RPC server enabled (see [rpc] section). + + # Endpoint of a trusted reference mithril node (empty = disabled) + # bankhash_verify_endpoint = "http://reference-node:8899" + + # What to do on mismatch: "warn" (log and continue) or "panic" (halt) + # bankhash_verify_mode = "warn" + # ============================================================================ # [reporting] - Metrics & Reporting # ============================================================================ diff --git a/pkg/replay/bankhash_verify.go b/pkg/replay/bankhash_verify.go new file mode 100644 index 00000000..713826b5 --- /dev/null +++ b/pkg/replay/bankhash_verify.go @@ -0,0 +1,202 @@ +package replay + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "sync" + "time" + + "github.com/Overclock-Validator/mithril/pkg/base58" + "github.com/Overclock-Validator/mithril/pkg/mlog" +) + +// BankhashVerifier checks computed bankhashes against a reference mithril node. +// Verification is async and does not block the replay loop. +type BankhashVerifier struct { + endpoint string + httpClient *http.Client + mode BankhashVerifyMode + + mu sync.Mutex + lastVerified uint64 + mismatches int + verified int + errors int + verifyCh chan verifyRequest + done chan struct{} +} + +type BankhashVerifyMode int + +const ( + BankhashVerifyWarn BankhashVerifyMode = iota // log and continue + BankhashVerifyPanic // halt on mismatch +) + +type verifyRequest struct { + slot uint64 + computed []byte +} + +type jsonRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID int `json:"id"` + Method string `json:"method"` + Params []interface{} `json:"params"` +} + +type jsonRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID int `json:"id"` + Result json.RawMessage `json:"result"` + Error *jsonRPCError `json:"error"` +} + +type jsonRPCError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +func NewBankhashVerifier(endpoint string, mode BankhashVerifyMode) *BankhashVerifier { + v := &BankhashVerifier{ + endpoint: endpoint, + httpClient: &http.Client{ + Timeout: 5 * time.Second, + }, + mode: mode, + verifyCh: make(chan verifyRequest, 64), + done: make(chan struct{}), + } + go v.worker() + return v +} + +// Submit queues a bankhash for async verification. +func (v *BankhashVerifier) Submit(slot uint64, computedHash []byte) { + hashCopy := make([]byte, len(computedHash)) + copy(hashCopy, computedHash) + select { + case v.verifyCh <- verifyRequest{slot: slot, computed: hashCopy}: + default: + // channel full, skip to avoid blocking replay + v.mu.Lock() + v.errors++ + v.mu.Unlock() + } +} + +// Stop shuts down the verification worker and logs a summary. +func (v *BankhashVerifier) Stop() { + close(v.verifyCh) + <-v.done + v.mu.Lock() + defer v.mu.Unlock() + mlog.Log.Infof("bankhash verify shutdown: verified=%d mismatches=%d errors=%d last_slot=%d", + v.verified, v.mismatches, v.errors, v.lastVerified) +} + +func (v *BankhashVerifier) worker() { + defer close(v.done) + for req := range v.verifyCh { + v.verify(req) + } +} + +func (v *BankhashVerifier) verify(req verifyRequest) { + expected, err := v.fetchBankHash(req.slot) + if err != nil { + v.mu.Lock() + v.errors++ + errCount := v.errors + v.mu.Unlock() + if errCount%100 == 1 { + mlog.Log.Infof("bankhash verify: fetch error for slot %d: %v (total errors: %d)", req.slot, err, errCount) + } + return + } + + v.mu.Lock() + v.verified++ + v.lastVerified = req.slot + v.mu.Unlock() + + if !bytes.Equal(req.computed, expected) { + v.mu.Lock() + v.mismatches++ + v.mu.Unlock() + + computedStr := base58.Encode(req.computed) + expectedStr := base58.Encode(expected) + + if v.mode == BankhashVerifyPanic { + panic(fmt.Sprintf("DIVERGENCE in slot %d: bankhash mismatch: computed=%s expected=%s", req.slot, computedStr, expectedStr)) + } + mlog.Log.Errorf("DIVERGENCE in slot %d: bankhash mismatch: computed=%s expected=%s", req.slot, computedStr, expectedStr) + } +} + +func (v *BankhashVerifier) fetchBankHash(slot uint64) ([]byte, error) { + reqBody := jsonRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "getBankHash", + Params: []interface{}{slot}, + } + + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + httpReq, err := http.NewRequestWithContext(ctx, "POST", v.endpoint, bytes.NewReader(bodyBytes)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := v.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("http request: %w", err) + } + defer resp.Body.Close() + + respBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + var rpcResp jsonRPCResponse + if err := json.Unmarshal(respBytes, &rpcResp); err != nil { + return nil, fmt.Errorf("unmarshal response: %w", err) + } + + if rpcResp.Error != nil { + return nil, fmt.Errorf("rpc error %d: %s", rpcResp.Error.Code, rpcResp.Error.Message) + } + + var hashStr string + if err := json.Unmarshal(rpcResp.Result, &hashStr); err != nil { + return nil, fmt.Errorf("unmarshal result: %w", err) + } + + hashBytes, err := base58.DecodeFromString(hashStr) + if err != nil { + return nil, fmt.Errorf("decode base58 hash: %w", err) + } + + return hashBytes[:], nil +} + +// Stats returns current verification statistics. +func (v *BankhashVerifier) Stats() (verified, mismatches, errors int) { + v.mu.Lock() + defer v.mu.Unlock() + return v.verified, v.mismatches, v.errors +} diff --git a/pkg/replay/block.go b/pkg/replay/block.go index 8f2e4433..e5c34363 100644 --- a/pkg/replay/block.go +++ b/pkg/replay/block.go @@ -1189,7 +1189,7 @@ func ReplayBlocks( acctsDb *accountsdb.AccountsDb, acctsDbPath string, mithrilState *state.MithrilState, // State file with manifest_* seed fields - resumeState *ResumeState, // nil if not resuming, contains parent slot info when resuming + resumeState *ResumeState, // nil if not resuming, contains parent slot info when resuming startSlot, endSlot uint64, rpcEndpoints []string, // RPC endpoints in priority order (first = primary, rest = fallbacks) blockDir string, @@ -1201,6 +1201,7 @@ func ReplayBlocks( rpcServer *rpcserver.RpcServer, blockFetchOpts *BlockFetchOpts, onCancelWriteState OnCancelWriteState, // callback to write state immediately on cancellation (can be nil) + bankhashVerifier *BankhashVerifier, // optional bankhash verifier (nil to disable) ) *ReplayResult { result := &ReplayResult{} @@ -1648,6 +1649,11 @@ func ReplayBlocks( fmt.Fprintf(bankhashLogFile, "%d %s\n", block.Slot, base58.Encode(lastSlotCtx.FinalBankhash)) } + // Verify bankhash against reference node (async, non-blocking) + if bankhashVerifier != nil { + bankhashVerifier.Submit(block.Slot, lastSlotCtx.FinalBankhash) + } + statsd.Count(statsd.SlotReplays, 1, nil) statsd.Timing(statsd.SlotReplayDurationMs, uint64(slotReplayDuration.Nanoseconds())/1e6, nil) statsd.Gauge(statsd.Epoch, float64(block.Epoch), nil) @@ -1918,6 +1924,11 @@ func ReplayBlocks( // Serialize all epoch stakes for persistence result.ComputedEpochStakes = serializeAllEpochStakes() + // Stop bankhash verifier and log summary + if bankhashVerifier != nil { + bankhashVerifier.Stop() + } + return result } @@ -2026,7 +2037,7 @@ func sequentialTxLoop(slotCtx *sealevel.SlotCtx, sigverifyWg *sync.WaitGroup, bl } txFeeInfo, txErr := ProcessTransaction(slotCtx, sigverifyWg, tx, txMeta, dbgOpts, nil) - if txErr != nil { + if txMeta != nil && txErr != nil { if txMeta.Err == nil && tx.IsVote() { mlog.Log.Errorf("[run:%s] DIVERGENCE in slot %d: vote tx %s failed locally but succeeded onchain => bankhash mismatch at parent slot %d", CurrentRunID, block.Slot, tx.Signatures[0], block.ParentSlot) @@ -2035,11 +2046,11 @@ func sequentialTxLoop(slotCtx *sealevel.SlotCtx, sigverifyWg *sync.WaitGroup, bl } // check for success-failure return value divergences - if txErr == nil && txMeta.Err != nil { + if txMeta != nil && txErr == nil && txMeta.Err != nil { mlog.Log.Errorf("[run:%s] DIVERGENCE in slot %d: tx %s succeeded locally but failed onchain: %+v", CurrentRunID, block.Slot, tx.Signatures[0], block.TxMetas[idx].Err) panic(fmt.Sprintf("tx %s return value divergence: txErr was nil, but onchain err was %+v", tx.Signatures[0], block.TxMetas[idx].Err)) - } else if txErr != nil && txMeta.Err == nil { + } else if txMeta != nil && txErr != nil && txMeta.Err == nil { mlog.Log.Errorf("[run:%s] DIVERGENCE in slot %d: tx %s failed locally (%v) but succeeded onchain", CurrentRunID, block.Slot, tx.Signatures[0], txErr) panic(fmt.Sprintf("tx %s return value divergence: txErr was %+v (%s), but onchain err was nil", tx.Signatures[0], txErr, txErr)) @@ -2058,11 +2069,14 @@ func parallelTxLoop(slotCtx *sealevel.SlotCtx, sigverifyWg *sync.WaitGroup, bloc if rblock.FromLightbringer { wg := &sync.WaitGroup{} - workerPool, _ := ants.NewPoolWithFunc(txParallelism, func(i interface{}) { + workerPool, poolErr := ants.NewPoolWithFunc(txParallelism, func(i interface{}) { defer wg.Done() idx := i.(uint64) txFeeInfos[idx], errs[idx] = ProcessTransaction(slotCtx, sigverifyWg, rblock.Transactions[idx], nil, dbgOpts, nil) }) + if poolErr != nil { + panic(fmt.Sprintf("failed to create worker pool: %s", poolErr)) + } for _, entry := range rblock.Entries { for _, txIdx := range entry.Indices { @@ -2157,7 +2171,7 @@ func ProcessBlock( for i := range block.Transactions { unresolvedBlock.Transactions[i] = &solana.Transaction{} *(unresolvedBlock.Transactions[i]) = *block.Transactions[i] - if unresolvedBlock.TxMetas != nil && !block.FromLightbringer { + if block.TxMetas != nil && !block.FromLightbringer { unresolvedBlock.TxMetas[i] = &rpc.TransactionMeta{} *(unresolvedBlock.TxMetas[i]) = *block.TxMetas[i] } diff --git a/pkg/replay/transaction.go b/pkg/replay/transaction.go index 5658d53c..2dbfacb6 100644 --- a/pkg/replay/transaction.go +++ b/pkg/replay/transaction.go @@ -419,7 +419,7 @@ func ProcessTransaction(slotCtx *sealevel.SlotCtx, sigverifyWg *sync.WaitGroup, start = time.Now() txFeeInfo, _, err := fees.CalculateAndDeductTxFees(tx, txMeta, instrs, &execCtx.TransactionContext.Accounts, computeBudgetLimits, slotCtx.Features) if err != nil { - return txFeeInfo, nil + return &fees.TxFeeInfo{}, nil } metrics.GlobalBlockReplay.CalcAndDeductFees.AddTimingSince(start) @@ -495,7 +495,7 @@ func ProcessTransaction(slotCtx *sealevel.SlotCtx, sigverifyWg *sync.WaitGroup, } // check for CU consumed divergences - if instrErr == nil && *txMeta.ComputeUnitsConsumed != execCtx.ComputeMeter.Used() { + if instrErr == nil && txMeta != nil && txMeta.ComputeUnitsConsumed != nil && *txMeta.ComputeUnitsConsumed != execCtx.ComputeMeter.Used() { discrepancy := max(execCtx.ComputeMeter.Used(), *txMeta.ComputeUnitsConsumed) - min(execCtx.ComputeMeter.Used(), *txMeta.ComputeUnitsConsumed) var sign byte if execCtx.ComputeMeter.Used() > *txMeta.ComputeUnitsConsumed { diff --git a/pkg/replay/transaction_test.go b/pkg/replay/transaction_test.go new file mode 100644 index 00000000..4c8d0c6f --- /dev/null +++ b/pkg/replay/transaction_test.go @@ -0,0 +1,301 @@ +package replay + +import ( + "fmt" + "math" + "strings" + "sync" + "testing" + + "github.com/Overclock-Validator/mithril/pkg/accounts" + a "github.com/Overclock-Validator/mithril/pkg/addresses" + "github.com/Overclock-Validator/mithril/pkg/features" + "github.com/Overclock-Validator/mithril/pkg/sealevel" + "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go/rpc" +) + +func makeSlotCtx(acctList []*accounts.Account) *sealevel.SlotCtx { + f := features.NewFeaturesDefault() + memAccts := accounts.NewMemAccounts() + for _, acct := range acctList { + memAccts.SetAccountWithoutLock(acct.Key, acct) + } + return &sealevel.SlotCtx{ + Accounts: memAccts, + ParentAccts: accounts.NewMemAccounts(), + Slot: 405551840, + Epoch: 938, + Features: f, + FeeRateGovernor: &sealevel.FeeRateGovernor{ + LamportsPerSignature: 5000, + PrevLamportsPerSignature: 5000, + }, + AcctMapsMu: &sync.Mutex{}, + ModifiedAccts: make(map[solana.PublicKey]bool), + WritableAccts: make(map[solana.PublicKey]bool), + VoteTimestampMu: &sync.Mutex{}, + VoteTimestamps: make(map[solana.PublicKey]sealevel.BlockTimestamp), + } +} + +func makeVoteTx(payer, voteAcct solana.PublicKey) *solana.Transaction { + voteProgramPk := solana.PublicKeyFromBytes(a.VoteProgramAddr[:]) + tx := &solana.Transaction{ + Signatures: []solana.Signature{testSig(0x01)}, + } + tx.Message.Header.NumRequiredSignatures = 1 + tx.Message.Header.NumReadonlyUnsignedAccounts = 1 + tx.Message.AccountKeys = []solana.PublicKey{payer, voteAcct, voteProgramPk} + tx.Message.Instructions = []solana.CompiledInstruction{ + { + ProgramIDIndex: 2, + Accounts: []uint16{0, 1}, + Data: []byte{}, + }, + } + return tx +} + +func checkPreBalances(slotCtx *sealevel.SlotCtx, tx *solana.Transaction, txMeta *rpc.TransactionMeta) *string { + if txMeta == nil { + return nil + } + + f := slotCtx.Features + instrs, acctMetasPerInstr, err := instrsAndAcctMetasFromTx(tx, f) + if err != nil { + msg := fmt.Sprintf("instrsAndAcctMetasFromTx: %s", err) + return &msg + } + _ = acctMetasPerInstr + + computeBudgetLimits, err := sealevel.ComputeBudgetExecuteInstructions(instrs, f) + if err != nil { + msg := fmt.Sprintf("ComputeBudgetExecuteInstructions: %s", err) + return &msg + } + _ = computeBudgetLimits + + // Load accounts the same way ProcessTransaction does + instrsAcct := sealevel.MakeInstructionsSysvarAccount(instrs) + transactionAccts, _, err := loadAndValidateTxAccts(slotCtx, acctMetasPerInstr, tx, instrs, instrsAcct, computeBudgetLimits.LoadedAccountBytes) + if err != nil { + msg := fmt.Sprintf("loadAndValidateTxAccts: %s", err) + return &msg + } + + // Run the pre-balance check (same logic as transaction.go:395-416) + for count := uint64(0); count < uint64(len(tx.Message.AccountKeys)); count++ { + txAcct, err := transactionAccts.GetAccount(count) + if err != nil { + msg := fmt.Sprintf("unable to get tx acct %d", count) + return &msg + } + + if !isNativeProgram(txAcct.Key) && !txAcct.IsDummy { + if txAcct.Lamports != txMeta.PreBalances[count] { + msg := fmt.Sprintf("tx %s pre-balance divergence: lamport balance for %s was %d but onchain lamport balance was %d", + tx.Signatures[0], txAcct.Key, txAcct.Lamports, txMeta.PreBalances[count]) + return &msg + } + } + transactionAccts.Unlock(count) + } + return nil +} + +func TestPreBalanceDivergenceDetected(t *testing.T) { + payerPk := testPk(0x76) + voteAcctPk := testPk(0x8a) + + mithrilLamports := uint64(51_492_474_108) + onchainLamports := uint64(51_492_669_108) + + payerAcct := &accounts.Account{ + Key: payerPk, + Lamports: mithrilLamports, + Owner: a.SystemProgramAddr, + RentEpoch: math.MaxUint64, + } + voteAcct := &accounts.Account{ + Key: voteAcctPk, + Lamports: 14_572_195_217_575, + Owner: a.VoteProgramAddr, + RentEpoch: math.MaxUint64, + Data: make([]byte, 3762), + } + voteProgramAcct := &accounts.Account{ + Key: solana.PublicKeyFromBytes(a.VoteProgramAddr[:]), + Lamports: 1, + Owner: a.NativeLoaderAddr, + Executable: true, + } + + slotCtx := makeSlotCtx([]*accounts.Account{payerAcct, voteAcct, voteProgramAcct}) + tx := makeVoteTx(payerPk, voteAcctPk) + + txMeta := &rpc.TransactionMeta{ + Fee: 5000, + PreBalances: []uint64{onchainLamports, 14_572_195_217_575, 1}, + PostBalances: []uint64{onchainLamports - 5000, 14_572_195_217_575, 1}, + } + + divergence := checkPreBalances(slotCtx, tx, txMeta) + if divergence == nil { + t.Fatal("expected pre-balance divergence to be detected, but check passed") + } + + if !strings.Contains(*divergence, "pre-balance divergence") { + t.Fatalf("expected 'pre-balance divergence' in message, got: %s", *divergence) + } + if !strings.Contains(*divergence, payerPk.String()) { + t.Fatalf("expected payer pubkey in message, got: %s", *divergence) + } + if !strings.Contains(*divergence, fmt.Sprintf("%d", mithrilLamports)) { + t.Fatalf("expected mithril lamports in message, got: %s", *divergence) + } + if !strings.Contains(*divergence, fmt.Sprintf("%d", onchainLamports)) { + t.Fatalf("expected onchain lamports in message, got: %s", *divergence) + } + + t.Logf("divergence correctly detected: %s", *divergence) +} + +func TestPreBalanceMatchPasses(t *testing.T) { + payerPk := testPk(0x76) + voteAcctPk := testPk(0x8a) + + lamports := uint64(51_492_669_108) + + payerAcct := &accounts.Account{ + Key: payerPk, + Lamports: lamports, + Owner: a.SystemProgramAddr, + RentEpoch: math.MaxUint64, + } + voteAcct := &accounts.Account{ + Key: voteAcctPk, + Lamports: 14_572_195_217_575, + Owner: a.VoteProgramAddr, + RentEpoch: math.MaxUint64, + Data: make([]byte, 3762), + } + voteProgramAcct := &accounts.Account{ + Key: solana.PublicKeyFromBytes(a.VoteProgramAddr[:]), + Lamports: 1, + Owner: a.NativeLoaderAddr, + Executable: true, + } + + slotCtx := makeSlotCtx([]*accounts.Account{payerAcct, voteAcct, voteProgramAcct}) + tx := makeVoteTx(payerPk, voteAcctPk) + + txMeta := &rpc.TransactionMeta{ + Fee: 5000, + PreBalances: []uint64{lamports, 14_572_195_217_575, 1}, + PostBalances: []uint64{lamports - 5000, 14_572_195_217_575, 1}, + } + + divergence := checkPreBalances(slotCtx, tx, txMeta) + if divergence != nil { + t.Fatalf("expected no divergence, but got: %s", *divergence) + } +} + +func TestPreBalanceDivergenceAmountIs195000(t *testing.T) { + payerPk := testPk(0x76) + voteAcctPk := testPk(0x8a) + + mithrilLamports := uint64(51_492_474_108) + onchainLamports := uint64(51_492_669_108) + expectedDiff := onchainLamports - mithrilLamports + + if expectedDiff != 195_000 { + t.Fatalf("expected difference of 195,000 lamports, got %d", expectedDiff) + } + if expectedDiff%5000 != 0 { + t.Fatalf("expected difference to be a multiple of 5,000 (vote tx fee), got %d", expectedDiff) + } + t.Logf("divergence = %d lamports = %d * 5,000 (vote tx fee)", expectedDiff, expectedDiff/5000) + + payerAcct := &accounts.Account{ + Key: payerPk, + Lamports: mithrilLamports, + Owner: a.SystemProgramAddr, + RentEpoch: math.MaxUint64, + } + voteAcct := &accounts.Account{ + Key: voteAcctPk, + Lamports: 14_572_195_217_575, + Owner: a.VoteProgramAddr, + RentEpoch: math.MaxUint64, + Data: make([]byte, 3762), + } + voteProgramAcct := &accounts.Account{ + Key: solana.PublicKeyFromBytes(a.VoteProgramAddr[:]), + Lamports: 1, + Owner: a.NativeLoaderAddr, + Executable: true, + } + + slotCtx := makeSlotCtx([]*accounts.Account{payerAcct, voteAcct, voteProgramAcct}) + tx := makeVoteTx(payerPk, voteAcctPk) + + txMeta := &rpc.TransactionMeta{ + Fee: 5000, + PreBalances: []uint64{onchainLamports, 14_572_195_217_575, 1}, + PostBalances: []uint64{onchainLamports - 5000, 14_572_195_217_575, 1}, + } + + divergence := checkPreBalances(slotCtx, tx, txMeta) + if divergence == nil { + t.Fatal("expected pre-balance divergence to be detected") + } + + if mithrilLamports >= onchainLamports { + t.Fatal("test assumes mithril has fewer lamports than onchain") + } +} + +func TestNativeProgramSkippedInPreBalanceCheck(t *testing.T) { + payerPk := testPk(0x76) + voteAcctPk := testPk(0x8a) + voteProgramPk := solana.PublicKeyFromBytes(a.VoteProgramAddr[:]) + + lamports := uint64(51_492_669_108) + + payerAcct := &accounts.Account{ + Key: payerPk, + Lamports: lamports, + Owner: a.SystemProgramAddr, + RentEpoch: math.MaxUint64, + } + voteAcct := &accounts.Account{ + Key: voteAcctPk, + Lamports: 14_572_195_217_575, + Owner: a.VoteProgramAddr, + RentEpoch: math.MaxUint64, + Data: make([]byte, 3762), + } + voteProgramAcct := &accounts.Account{ + Key: voteProgramPk, + Lamports: 1, + Owner: a.NativeLoaderAddr, + Executable: true, + } + + slotCtx := makeSlotCtx([]*accounts.Account{payerAcct, voteAcct, voteProgramAcct}) + tx := makeVoteTx(payerPk, voteAcctPk) + + txMeta := &rpc.TransactionMeta{ + Fee: 5000, + PreBalances: []uint64{lamports, 14_572_195_217_575, 999_999}, + PostBalances: []uint64{lamports - 5000, 14_572_195_217_575, 999_999}, + } + + divergence := checkPreBalances(slotCtx, tx, txMeta) + if divergence != nil { + t.Fatalf("native program account should be skipped in pre-balance check, but got: %s", *divergence) + } +}