|
1 | 1 | package base |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "container/list" |
4 | 5 | "context" |
| 6 | + "fmt" |
5 | 7 |
|
6 | 8 | sdk "github.com/cosmos/cosmos-sdk/types" |
7 | 9 | sdkmempool "github.com/cosmos/cosmos-sdk/types/mempool" |
| 10 | + |
| 11 | + signer_extraction "github.com/skip-mev/block-sdk/v2/adapters/signer_extraction_adapter" |
8 | 12 | ) |
9 | 13 |
|
10 | 14 | var ( |
11 | 15 | _ MempoolInterface = (*DefaultMempool[int64])(nil) |
12 | 16 | _ sdkmempool.Iterator = (*DefaultIterator)(nil) |
13 | 17 | ) |
14 | 18 |
|
15 | | -// DefaultMempool implements a simple mempool that stores all transactions |
| 19 | +// DefaultMempool implements a FIFO mempool with duplicate detection based on sender:nonce. |
| 20 | +// It uses a linked list for FIFO ordering and a map for O(1) duplicate detection. |
16 | 21 | type DefaultMempool[C comparable] struct { |
17 | | - txs []sdk.Tx |
18 | | - MaxTx int |
| 22 | + txs *list.List // Linked list for FIFO ordering |
| 23 | + seen map[string]*list.Element // Map from sender:nonce to list element |
| 24 | + MaxTx int // Maximum number of transactions |
| 25 | + signerExtractor signer_extraction.Adapter // For extracting signer info |
19 | 26 | } |
20 | 27 |
|
21 | | -// NewDefaultMempool creates a new DefaultMempool |
22 | | -func NewDefaultMempool[C comparable](maxTxs int) *DefaultMempool[C] { |
| 28 | +// NewDefaultMempool creates a new FIFO mempool with duplicate detection |
| 29 | +func NewDefaultMempool[C comparable](maxTxs int, signerExtractor signer_extraction.Adapter) *DefaultMempool[C] { |
23 | 30 | return &DefaultMempool[C]{ |
24 | | - txs: make([]sdk.Tx, 0), |
25 | | - MaxTx: maxTxs, |
| 31 | + txs: list.New(), |
| 32 | + seen: make(map[string]*list.Element), |
| 33 | + MaxTx: maxTxs, |
| 34 | + signerExtractor: signerExtractor, |
| 35 | + } |
| 36 | +} |
| 37 | + |
| 38 | +// getTxKey creates a unique key from sender:nonce combination |
| 39 | +func (mp *DefaultMempool[C]) getTxKey(tx sdk.Tx) (string, error) { |
| 40 | + signers, err := mp.signerExtractor.GetSigners(tx) |
| 41 | + if err != nil { |
| 42 | + return "", err |
26 | 43 | } |
| 44 | + if len(signers) == 0 { |
| 45 | + return "", fmt.Errorf("tx must have at least one signer") |
| 46 | + } |
| 47 | + |
| 48 | + sig := signers[0] |
| 49 | + nonce := sig.Sequence |
| 50 | + sender := sig.Signer.String() |
| 51 | + |
| 52 | + return fmt.Sprintf("%s:%d", sender, nonce), nil |
27 | 53 | } |
28 | 54 |
|
29 | 55 | // Insert implements MempoolInterface. |
30 | 56 | func (mp *DefaultMempool[C]) Insert(_ context.Context, tx sdk.Tx) error { |
| 57 | + if mp.MaxTx < 0 { |
| 58 | + return nil // No-op if MaxTx is negative |
| 59 | + } |
| 60 | + |
| 61 | + key, err := mp.getTxKey(tx) |
| 62 | + if err != nil { |
| 63 | + return fmt.Errorf("failed to get tx key for insertion: %w", err) |
| 64 | + } |
| 65 | + |
| 66 | + // Check if this tx already exists |
| 67 | + if _, exists := mp.seen[key]; exists { |
| 68 | + // No-op: transaction with same sender:nonce already exists |
| 69 | + return nil |
| 70 | + } |
| 71 | + |
| 72 | + // Check capacity before adding new transaction |
31 | 73 | if mp.MaxTx > 0 && mp.CountTx() >= mp.MaxTx { |
32 | 74 | return sdkmempool.ErrMempoolTxMaxCapacity |
33 | | - } else if mp.MaxTx < 0 { |
34 | | - return nil |
35 | 75 | } |
36 | | - mp.txs = append(mp.txs, tx) |
| 76 | + |
| 77 | + // Add new transaction |
| 78 | + element := mp.txs.PushBack(tx) |
| 79 | + mp.seen[key] = element |
| 80 | + |
37 | 81 | return nil |
38 | 82 | } |
39 | 83 |
|
40 | 84 | // Remove implements MempoolInterface. |
41 | 85 | func (mp *DefaultMempool[C]) Remove(tx sdk.Tx) error { |
42 | | - for i, t := range mp.txs { |
43 | | - if t == tx { |
44 | | - mp.txs = append(mp.txs[:i], mp.txs[i+1:]...) |
45 | | - return nil |
46 | | - } |
| 86 | + key, err := mp.getTxKey(tx) |
| 87 | + if err != nil { |
| 88 | + return fmt.Errorf("failed to get tx key for removal: %w", err) |
| 89 | + } |
| 90 | + |
| 91 | + // Remove by key |
| 92 | + if element, exists := mp.seen[key]; exists { |
| 93 | + mp.txs.Remove(element) |
| 94 | + delete(mp.seen, key) |
47 | 95 | } |
| 96 | + |
48 | 97 | return nil |
49 | 98 | } |
50 | 99 |
|
51 | 100 | // Select implements MempoolInterface. |
52 | 101 | func (mp *DefaultMempool[C]) Select(_ context.Context, _ [][]byte) sdkmempool.Iterator { |
53 | | - if len(mp.txs) == 0 { |
| 102 | + if mp.txs.Len() == 0 { |
54 | 103 | return nil |
55 | 104 | } |
56 | 105 |
|
57 | 106 | return &DefaultIterator{ |
58 | | - txs: mp.txs, |
59 | | - curr: 0, |
| 107 | + current: mp.txs.Front(), |
60 | 108 | } |
61 | 109 | } |
62 | 110 |
|
63 | 111 | // CountTx implements MempoolInterface. |
64 | 112 | func (mp *DefaultMempool[C]) CountTx() int { |
65 | | - return len(mp.txs) |
| 113 | + return mp.txs.Len() |
66 | 114 | } |
67 | 115 |
|
68 | 116 | // Contains implements MempoolInterface. |
69 | 117 | func (mp *DefaultMempool[C]) Contains(tx sdk.Tx) bool { |
70 | | - for _, t := range mp.txs { |
71 | | - if t == tx { |
72 | | - return true |
73 | | - } |
| 118 | + key, err := mp.getTxKey(tx) |
| 119 | + if err != nil { |
| 120 | + return false |
| 121 | + } |
| 122 | + |
| 123 | + // Check if we have this sender:nonce combination |
| 124 | + if element, exists := mp.seen[key]; exists { |
| 125 | + // Return true only if it's the exact same transaction object |
| 126 | + return element.Value.(sdk.Tx) == tx |
74 | 127 | } |
| 128 | + |
75 | 129 | return false |
76 | 130 | } |
77 | 131 |
|
78 | | -// DefaultIterator implements sdkmempool.Iterator |
| 132 | +// DefaultIterator implements sdkmempool.Iterator for FIFO mempool |
79 | 133 | type DefaultIterator struct { |
80 | | - txs []sdk.Tx |
81 | | - curr int |
| 134 | + current *list.Element |
82 | 135 | } |
83 | 136 |
|
84 | 137 | // Next implements sdkmempool.Iterator |
85 | 138 | func (i *DefaultIterator) Next() sdkmempool.Iterator { |
86 | | - if i.curr >= len(i.txs)-1 { |
| 139 | + if i.current == nil { |
87 | 140 | return nil |
88 | 141 | } |
89 | | - i.curr++ |
| 142 | + |
| 143 | + i.current = i.current.Next() |
| 144 | + if i.current == nil { |
| 145 | + return nil |
| 146 | + } |
| 147 | + |
90 | 148 | return i |
91 | 149 | } |
92 | 150 |
|
93 | 151 | // Tx implements sdkmempool.Iterator |
94 | 152 | func (i *DefaultIterator) Tx() sdk.Tx { |
95 | | - return i.txs[i.curr] |
| 153 | + if i.current == nil { |
| 154 | + return nil |
| 155 | + } |
| 156 | + |
| 157 | + return i.current.Value.(sdk.Tx) |
96 | 158 | } |
0 commit comments