Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 7 additions & 18 deletions pkg/user/pruning_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@ package user
import (
"fmt"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestPruningInTxTracker(t *testing.T) {
txClient := &TxClient{
txTracker: make(map[string]txInfo),
TxTracker: NewTxTracker(),
}
numTransactions := 10

Expand All @@ -20,31 +19,21 @@ func TestPruningInTxTracker(t *testing.T) {
for i := range numTransactions {
// 5 transactions will be pruned
if i%2 == 0 {
txClient.txTracker["tx"+fmt.Sprint(i)] = txInfo{
signer: "signer" + fmt.Sprint(i),
sequence: uint64(i),
timestamp: time.Now().
Add(-10 * time.Minute),
}
txClient.TxTracker.trackTransaction("signer"+fmt.Sprint(i), uint64(i), "tx"+fmt.Sprint(i), []byte(fmt.Sprintf("tx%d", i)))
txsToBePruned++
} else {
txClient.txTracker["tx"+fmt.Sprint(i)] = txInfo{
signer: "signer" + fmt.Sprint(i),
sequence: uint64(i),
timestamp: time.Now().
Add(-5 * time.Minute),
}
txClient.TxTracker.trackTransaction("signer"+fmt.Sprint(i), uint64(i), "tx"+fmt.Sprint(i), []byte(fmt.Sprintf("tx%d", i)))
txsNotReadyToBePruned++
}
}

txTrackerBeforePruning := len(txClient.txTracker)
txTrackerBeforePruning := len(txClient.TxTracker.TxQueue)

// All transactions were indexed
require.Equal(t, numTransactions, len(txClient.txTracker))
txClient.pruneTxTracker()
require.Equal(t, numTransactions, len(txClient.TxTracker.TxQueue))
txClient.TxTracker.pruneTxTracker()
// Prunes the transactions that are 10 minutes old
// 5 transactions will be pruned
require.Equal(t, txsNotReadyToBePruned, txTrackerBeforePruning-txsToBePruned)
require.Equal(t, len(txClient.txTracker), txsNotReadyToBePruned)
require.Equal(t, len(txClient.TxTracker.TxQueue), txsNotReadyToBePruned)
}
121 changes: 39 additions & 82 deletions pkg/user/tx_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,6 @@ var (

type Option func(client *TxClient)

// txInfo is a struct that holds the sequence and the signer of a transaction
// in the local tx pool.
type txInfo struct {
sequence uint64
signer string
timestamp time.Time
txBytes []byte
}

// TxResponse is a response from the chain after
// a transaction has been submitted.
type TxResponse struct {
Expand Down Expand Up @@ -164,9 +155,9 @@ func WithDefaultAccount(name string) Option {
c.defaultAddress = addr

// Update worker 0's account if tx queue already exists
if c.txQueue != nil && len(c.txQueue.workers) > 0 {
c.txQueue.workers[0].accountName = name
c.txQueue.workers[0].address = addr.String()
if c.parallelTxQueue != nil && len(c.parallelTxQueue.workers) > 0 {
c.parallelTxQueue.workers[0].accountName = name
c.parallelTxQueue.workers[0].address = addr.String()
}
}
}
Expand Down Expand Up @@ -199,16 +190,16 @@ func WithTxWorkers(numWorkers int) Option {
}

return func(c *TxClient) {
c.txQueue = newTxQueue(c, numWorkers)
c.parallelTxQueue = newTxQueue(c, numWorkers)
}
}

// WithParallelQueueSize sets the buffer size for the parallel submission job queue.
// Default is 100 if not specified.
func WithParallelQueueSize(size int) Option {
return func(c *TxClient) {
if c.txQueue != nil {
c.txQueue.jobQueue = make(chan *SubmissionJob, size)
if c.parallelTxQueue != nil {
c.parallelTxQueue.jobQueue = make(chan *SubmissionJob, size)
}
}
}
Expand All @@ -230,10 +221,10 @@ type TxClient struct {
defaultAddress sdktypes.AccAddress
// txTracker maps the tx hash to the Sequence and signer of the transaction
// that was submitted to the chain
txTracker map[string]txInfo
TxTracker *TxTracker
gasEstimationClient gasestimation.GasEstimatorClient
// txQueue manages parallel transaction submission when enabled
txQueue *txQueue
// parallelTxQueue manages parallel transaction submission when enabled
parallelTxQueue *txQueue
}

// NewTxClient returns a new TxClient
Expand Down Expand Up @@ -265,7 +256,7 @@ func NewTxClient(
pollTime: DefaultPollTime,
defaultAccount: records[0].Name,
defaultAddress: addr,
txTracker: make(map[string]txInfo),
TxTracker: NewTxTracker(),
cdc: cdc,
gasEstimationClient: gasestimation.NewGasEstimatorClient(conn),
}
Expand All @@ -276,8 +267,8 @@ func NewTxClient(

// Always create a tx queue with at least 1 worker (the default account)
// unless already configured by WithTxWorkers option
if txClient.txQueue == nil {
txClient.txQueue = newTxQueue(txClient, 1)
if txClient.parallelTxQueue == nil {
txClient.parallelTxQueue = newTxQueue(txClient, 1)
}

return txClient, nil
Expand Down Expand Up @@ -335,7 +326,7 @@ func SetupTxClient(
return nil, err
}

if err := txClient.txQueue.start(ctx); err != nil {
if err := txClient.parallelTxQueue.start(ctx); err != nil {
return nil, fmt.Errorf("failed to start tx queue: %w", err)
}

Expand Down Expand Up @@ -371,12 +362,12 @@ func (client *TxClient) SubmitPayForBlobToQueue(ctx context.Context, blobs []*sh
// to the provided channel when the transaction is confirmed. The caller is responsible for creating and
// closing the result channel.
func (client *TxClient) QueueBlob(ctx context.Context, resultC chan SubmissionResult, blobs []*share.Blob, opts ...TxOption) {
if client.txQueue == nil {
if client.parallelTxQueue == nil {
resultC <- SubmissionResult{Error: errTxQueueNotConfigured}
return
}

if !client.txQueue.isStarted() {
if !client.parallelTxQueue.isStarted() {
resultC <- SubmissionResult{Error: errTxQueueNotStarted}
return
}
Expand All @@ -388,7 +379,7 @@ func (client *TxClient) QueueBlob(ctx context.Context, resultC chan SubmissionRe
ResultsC: resultC,
}

client.txQueue.submitJob(job)
client.parallelTxQueue.submitJob(job)
}

// SubmitPayForBlobWithAccount forms a transaction from the provided blobs, signs it with the provided account, and submits it to the chain.
Expand Down Expand Up @@ -462,7 +453,7 @@ func (client *TxClient) BroadcastTx(ctx context.Context, msgs []sdktypes.Msg, op
// prune transactions that are older than 10 minutes
// pruning has to be done in broadcast, since users
// might not always call ConfirmTx().
client.pruneTxTracker()
client.TxTracker.pruneTxTracker()

account, err := client.getAccountNameFromMsgs(msgs)
if err != nil {
Expand Down Expand Up @@ -583,7 +574,8 @@ func (client *TxClient) submitToSingleConnection(ctx context.Context, txBytes []
}
// Save the sequence, signer and txBytes of the in the local txTracker
// before the sequence is incremented
client.trackTransaction(signer, resp.TxHash, txBytes)
sequence := client.signer.Account(signer).Sequence()
client.TxTracker.trackTransaction(signer, sequence, resp.TxHash, txBytes)

// Increment sequence after successful submission
if err := client.signer.IncrementSequence(signer); err != nil {
Expand Down Expand Up @@ -720,7 +712,8 @@ func (client *TxClient) submitToMultipleConnections(ctx context.Context, txBytes

// Return first successful response, if any
if resp, ok := <-respCh; ok && resp != nil {
client.trackTransaction(signer, resp.TxHash, txBytes)
sequence := client.signer.Account(signer).Sequence()
client.TxTracker.trackTransaction(signer, sequence, resp.TxHash, txBytes)

if err := client.signer.IncrementSequence(signer); err != nil {
return nil, fmt.Errorf("increment sequencing: %w", err)
Expand All @@ -736,15 +729,6 @@ func (client *TxClient) submitToMultipleConnections(ctx context.Context, txBytes
return nil, errors.Join(errs...)
}

// pruneTxTracker removes transactions from the local tx tracker that are older than 10 minutes
func (client *TxClient) pruneTxTracker() {
for hash, txInfo := range client.txTracker {
if time.Since(txInfo.timestamp) >= txTrackerPruningInterval {
delete(client.txTracker, hash)
}
}
}

// ConfirmTx periodically pings the provided node for the commitment of a transaction by its
// hash. It will continually loop until the context is cancelled, the tx is found or an error
// is encountered.
Expand Down Expand Up @@ -780,15 +764,15 @@ func (client *TxClient) ConfirmTx(ctx context.Context, txHash string) (*TxRespon
))
if resp.ExecutionCode != abci.CodeTypeOK {
span.RecordError(fmt.Errorf("txclient/ConfirmTx: execution error: %s", resp.Error))
client.deleteFromTxTracker(txHash)
client.TxTracker.deleteFromTxTracker(txHash)
return nil, client.buildExecutionError(txHash, resp)
}

span.AddEvent("txclient/ConfirmTx: transaction confirmed successfully")
client.deleteFromTxTracker(txHash)
client.TxTracker.deleteFromTxTracker(txHash)
return client.buildTxResponse(txHash, resp), nil
case core.TxStatusEvicted:
_, _, txBytes, exists := client.GetTxFromTxTracker(txHash)
_, _, txBytes, exists := client.TxTracker.GetTxFromTxTracker(txHash)
if !exists {
return nil, fmt.Errorf("tx: %s not found in txTracker; likely failed during broadcast", txHash)
}
Expand Down Expand Up @@ -819,7 +803,7 @@ func (client *TxClient) ConfirmTx(ctx context.Context, txHash string) (*TxRespon
span.AddEvent("txclient/ConfirmTx: transaction resubmitted successfully after eviction")
case core.TxStatusRejected:
span.RecordError(fmt.Errorf("txclient/ConfirmTx: transaction rejected: %s", resp.Error))
sequence, signer, _, exists := client.GetTxFromTxTracker(txHash)
sequence, signer, _, exists := client.TxTracker.GetTxFromTxTracker(txHash)
if !exists {
return nil, fmt.Errorf("tx: %s not found in tx client txTracker; likely failed during broadcast", txHash)
}
Expand All @@ -828,11 +812,11 @@ func (client *TxClient) ConfirmTx(ctx context.Context, txHash string) (*TxRespon
if err := client.signer.SetSequence(signer, sequence); err != nil {
return nil, fmt.Errorf("setting sequence: %w", err)
}
client.deleteFromTxTracker(txHash)
client.TxTracker.deleteFromTxTracker(txHash)
return nil, fmt.Errorf("tx with hash %s was rejected by the node with execution code: %d and log: %s", txHash, resp.ExecutionCode, resp.Error)
default:
span.RecordError(fmt.Errorf("txclient/ConfirmTx: unknown tx status for tx: %s", txHash))
client.deleteFromTxTracker(txHash)
client.TxTracker.deleteFromTxTracker(txHash)
if ctx.Err() != nil {
return nil, ctx.Err()
}
Expand Down Expand Up @@ -861,13 +845,6 @@ func extractSequenceError(fullError string) string {
return s
}

// deleteFromTxTracker safely deletes a transaction from the local tx tracker.
func (client *TxClient) deleteFromTxTracker(txHash string) {
client.mtx.Lock()
defer client.mtx.Unlock()
delete(client.txTracker, txHash)
}

// EstimateGasPriceAndUsage returns the estimated gas price based on the provided priority,
// and also the gas limit/used for the provided transaction.
// The gas limit is calculated by simulating the transaction and then calculating the amount of gas that was consumed during execution.
Expand Down Expand Up @@ -1045,26 +1022,6 @@ func (client *TxClient) getAccountNameFromMsgs(msgs []sdktypes.Msg) (string, err
return record.Name, nil
}

// trackTransaction tracks a transaction without acquiring the mutex.
// This should only be called when the caller already holds the mutex.
func (client *TxClient) trackTransaction(signer, txHash string, txBytes []byte) {
sequence := client.signer.Account(signer).Sequence()
client.txTracker[txHash] = txInfo{
sequence: sequence,
signer: signer,
timestamp: time.Now(),
txBytes: txBytes,
}
}

// GetTxFromTxTracker gets transaction info from the tx client's local tx tracker by its hash
func (client *TxClient) GetTxFromTxTracker(hash string) (sequence uint64, signer string, txBytes []byte, exists bool) {
client.mtx.Lock()
defer client.mtx.Unlock()
txInfo, exists := client.txTracker[hash]
return txInfo.sequence, txInfo.signer, txInfo.txBytes, exists
}

// Signer exposes the tx clients underlying signer
func (client *TxClient) Signer() *Signer {
return client.signer
Expand All @@ -1073,51 +1030,51 @@ func (client *TxClient) Signer() *Signer {
// StartTxQueueForTest starts the tx queue for testing purposes.
// This function is only intended for use in tests.
func (client *TxClient) StartTxQueueForTest(ctx context.Context) error {
if client.txQueue == nil {
if client.parallelTxQueue == nil {
return nil
}
return client.txQueue.start(ctx)
return client.parallelTxQueue.start(ctx)
}

// StopTxQueueForTest stops the tx queue for testing purposes.
// This function is only intended for use in tests.
func (client *TxClient) StopTxQueueForTest() {
if client.txQueue != nil {
client.txQueue.stop()
if client.parallelTxQueue != nil {
client.parallelTxQueue.stop()
}
}

// IsTxQueueStartedForTest returns whether the tx queue is started, for testing purposes.
// This function is only intended for use in tests.
func (client *TxClient) IsTxQueueStartedForTest() bool {
if client.txQueue == nil {
if client.parallelTxQueue == nil {
return false
}
return client.txQueue.isStarted()
return client.parallelTxQueue.isStarted()
}

// TxQueueWorkerCount returns the number of workers in the tx queue
func (client *TxClient) TxQueueWorkerCount() int {
if client.txQueue == nil {
if client.parallelTxQueue == nil {
return 0
}
return len(client.txQueue.workers)
return len(client.parallelTxQueue.workers)
}

// TxQueueWorkerAddress returns the address for the worker at the given index
func (client *TxClient) TxQueueWorkerAddress(index int) string {
if client.txQueue == nil || index < 0 || index >= len(client.txQueue.workers) {
if client.parallelTxQueue == nil || index < 0 || index >= len(client.parallelTxQueue.workers) {
return ""
}
return client.txQueue.workers[index].address
return client.parallelTxQueue.workers[index].address
}

// TxQueueWorkerAccountName returns the account name for the worker at the given index
func (client *TxClient) TxQueueWorkerAccountName(index int) string {
if client.txQueue == nil || index < 0 || index >= len(client.txQueue.workers) {
if client.parallelTxQueue == nil || index < 0 || index >= len(client.parallelTxQueue.workers) {
return ""
}
return client.txQueue.workers[index].accountName
return client.parallelTxQueue.workers[index].accountName
}

// QueryMinimumGasPrice queries both the nodes local and network wide
Expand Down
Loading