Skip to content

Commit b70bd9e

Browse files
committed
db: add scan horizon extension helpers
Add ExtendScanHorizon and its validateHorizon, deriveHorizonRange, and deriveNextValidChild helpers, mirroring the legacy ScopedKeyManager extend-addresses semantics: derive and persist every valid child through the horizon index, skip HD-invalid children, and advance the branch next-index monotonically.
1 parent 4fe5f14 commit b70bd9e

2 files changed

Lines changed: 332 additions & 0 deletions

File tree

wallet/internal/db/scan_batch_common.go

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ package db
22

33
import (
44
"context"
5+
"errors"
6+
"fmt"
7+
8+
"github.com/btcsuite/btcd/btcutil/hdkeychain"
59
)
610

711
// externalBranch and internalBranch are the BIP44 branch numbers the horizon
@@ -73,3 +77,150 @@ type ScanHorizonOps interface {
7377
AdvanceNextIndex(ctx context.Context, accountID int64, branch uint32,
7478
nextIndex uint32) error
7579
}
80+
81+
// ExtendScanHorizon ensures every valid child through horizon.Index is derived
82+
// and persisted on the requested branch, mirroring the legacy
83+
// ScopedKeyManager.ExtendAddresses semantics used by the kvdb backend:
84+
//
85+
// - Derivation starts from the branch's current next-index and runs through
86+
// horizon.Index inclusive.
87+
// - When horizon.Index is already below the current next-index the call is a
88+
// no-op, so replaying a horizon never inserts duplicate rows.
89+
// - HD-invalid child indices are skipped (the next-index simply advances)
90+
// instead of failing, so the SQL path matches the kvdb invalid-child skip.
91+
// - After derivation the branch next-index advances to one past the last
92+
// derived child, leaving the same terminal counter the legacy path would.
93+
func ExtendScanHorizon(ctx context.Context, ops ScanHorizonOps,
94+
deriveFn AddressDerivationFunc, horizon ScanHorizon) error {
95+
96+
branch, err := validateHorizon(deriveFn, horizon)
97+
if err != nil {
98+
return err
99+
}
100+
101+
account, err := ops.GetHorizonAccount(ctx, horizon)
102+
if err != nil {
103+
return fmt.Errorf("extend horizon: %w", err)
104+
}
105+
106+
addrType, nextIndex := account.branchState(branch)
107+
108+
// Nothing to do when the scan did not advance past the persisted tip.
109+
if horizon.Index < nextIndex {
110+
return nil
111+
}
112+
113+
nextIndex, err = deriveHorizonRange(
114+
ctx, ops, deriveFn, account, horizon, branch, addrType, nextIndex,
115+
)
116+
if err != nil {
117+
return err
118+
}
119+
120+
// Persist the advanced next-index so subsequent address allocation and
121+
// horizon replays resume past the derived range.
122+
err = ops.AdvanceNextIndex(ctx, account.AccountID, branch, nextIndex)
123+
if err != nil {
124+
return fmt.Errorf("extend horizon: advance next index: %w", err)
125+
}
126+
127+
return nil
128+
}
129+
130+
// validateHorizon checks the derivation callback and horizon bounds, returning
131+
// the validated branch number.
132+
func validateHorizon(deriveFn AddressDerivationFunc,
133+
horizon ScanHorizon) (uint32, error) {
134+
135+
if deriveFn == nil {
136+
return 0, fmt.Errorf("extend horizon: %w",
137+
errNilAddressDerivationFunc)
138+
}
139+
140+
// Recovery only ever reports the external or internal branch; reject
141+
// anything else up front so an unexpected branch cannot silently derive
142+
// against the wrong next-index column.
143+
branch := horizon.Branch
144+
if branch != externalBranch && branch != internalBranch {
145+
return 0, fmt.Errorf("extend horizon: %w: branch %d",
146+
ErrInvalidParam, branch)
147+
}
148+
149+
// kvdb caps a single extension at MaxAddressesPerAccount; mirror the bound
150+
// so both backends reject the same out-of-range horizon.
151+
if horizon.Index > MaxAddressIndex {
152+
return 0, fmt.Errorf("extend horizon: %w",
153+
ErrMaxAddressIndexReached)
154+
}
155+
156+
return branch, nil
157+
}
158+
159+
// deriveHorizonRange derives and persists one valid child per index from
160+
// nextIndex through horizon.Index inclusive and returns the advanced next
161+
// index. It mirrors the nested loop in the legacy extendAddresses: each outer
162+
// step finds the next valid child, skipping HD-invalid indices even past
163+
// horizon.Index, so the terminal next index matches the address manager.
164+
func deriveHorizonRange(ctx context.Context, ops ScanHorizonOps,
165+
deriveFn AddressDerivationFunc, account *HorizonAccount,
166+
horizon ScanHorizon, branch uint32, addrType AddressType,
167+
nextIndex uint32) (uint32, error) {
168+
169+
for nextIndex <= horizon.Index {
170+
next, err := deriveNextValidChild(
171+
ctx, ops, deriveFn, account, horizon.Scope, branch, addrType,
172+
nextIndex,
173+
)
174+
if err != nil {
175+
return 0, err
176+
}
177+
178+
nextIndex = next
179+
}
180+
181+
return nextIndex, nil
182+
}
183+
184+
// deriveNextValidChild derives and persists the first valid child at or after
185+
// startIndex, returning the index immediately past the persisted child.
186+
// HD-invalid children are skipped without persisting a row, exactly like the
187+
// inner loop of the legacy extendAddresses.
188+
func deriveNextValidChild(ctx context.Context, ops ScanHorizonOps,
189+
deriveFn AddressDerivationFunc, account *HorizonAccount, scope KeyScope,
190+
branch uint32, addrType AddressType, startIndex uint32) (uint32, error) {
191+
192+
for index := startIndex; ; index++ {
193+
derived, err := deriveFn(ctx, AddressDerivationParams{
194+
Scope: scope,
195+
DerivedAccountNumber: &account.AccountNumber,
196+
Branch: branch,
197+
Index: index,
198+
AddrType: addrType,
199+
AccountPubKey: account.AccountPubKey,
200+
})
201+
if errors.Is(err, hdkeychain.ErrInvalidChild) {
202+
continue
203+
}
204+
205+
if err != nil {
206+
return 0, fmt.Errorf("extend horizon: derive index %d: %w",
207+
index, err)
208+
}
209+
210+
if derived == nil {
211+
return 0, fmt.Errorf("extend horizon: derive index %d: %w",
212+
index, errNilDerivedAddressData)
213+
}
214+
215+
err = ops.InsertDerivedAddress(
216+
ctx, account.AccountID, addrType, branch, index,
217+
derived.ScriptPubKey, derived.PubKey,
218+
)
219+
if err != nil {
220+
return 0, fmt.Errorf("extend horizon: insert index %d: %w",
221+
index, err)
222+
}
223+
224+
return index + 1, nil
225+
}
226+
}
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
package db
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/btcsuite/btcd/btcutil/hdkeychain"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
// fakeHorizonOps is a minimal ScanHorizonOps used by the ExtendScanHorizon
12+
// tests. It records the inserted child indices and the final advanced
13+
// next-index so each scenario can assert the exact derivation work performed.
14+
type fakeHorizonOps struct {
15+
account *HorizonAccount
16+
17+
inserted []uint32
18+
advancedTo uint32
19+
advanceCall bool
20+
}
21+
22+
// GetHorizonAccount returns the preconfigured account.
23+
func (o *fakeHorizonOps) GetHorizonAccount(_ context.Context,
24+
_ ScanHorizon) (*HorizonAccount, error) {
25+
26+
return o.account, nil
27+
}
28+
29+
// InsertDerivedAddress records the child index that was derived and persisted.
30+
func (o *fakeHorizonOps) InsertDerivedAddress(_ context.Context, _ int64,
31+
_ AddressType, _ uint32, index uint32, _ []byte, _ []byte) error {
32+
33+
o.inserted = append(o.inserted, index)
34+
35+
return nil
36+
}
37+
38+
// AdvanceNextIndex records the advanced next-index value.
39+
func (o *fakeHorizonOps) AdvanceNextIndex(_ context.Context, _ int64,
40+
_ uint32, nextIndex uint32) error {
41+
42+
o.advanceCall = true
43+
o.advancedTo = nextIndex
44+
45+
return nil
46+
}
47+
48+
// validChildDeriveFunc returns a derivation callback that always derives a
49+
// valid address, recording nothing of the input.
50+
func validChildDeriveFunc() AddressDerivationFunc {
51+
return func(_ context.Context,
52+
_ AddressDerivationParams) (*DerivedAddressData, error) {
53+
54+
return &DerivedAddressData{
55+
ScriptPubKey: []byte{0x00, 0x14},
56+
PubKey: []byte{0x02},
57+
}, nil
58+
}
59+
}
60+
61+
// TestExtendScanHorizonNoOpBelowNextIndex verifies that ExtendScanHorizon does
62+
// nothing when the discovered index is below the branch's current next index,
63+
// so replaying an already-covered horizon never re-derives or advances.
64+
func TestExtendScanHorizonNoOpBelowNextIndex(t *testing.T) {
65+
t.Parallel()
66+
67+
ops := &fakeHorizonOps{
68+
account: &HorizonAccount{
69+
AccountID: 7,
70+
NextExternalIndex: 5,
71+
},
72+
}
73+
74+
// The horizon index is below the persisted next-external index, so the
75+
// call must be a no-op.
76+
err := ExtendScanHorizon(t.Context(), ops, validChildDeriveFunc(),
77+
ScanHorizon{Branch: externalBranch, Index: 3})
78+
require.NoError(t, err)
79+
80+
require.Empty(t, ops.inserted)
81+
require.False(t, ops.advanceCall)
82+
}
83+
84+
// TestExtendScanHorizonRejectsInvalidBranch verifies that a branch other than
85+
// external or internal is rejected with ErrInvalidParam before any account
86+
// load or derivation.
87+
func TestExtendScanHorizonRejectsInvalidBranch(t *testing.T) {
88+
t.Parallel()
89+
90+
ops := &fakeHorizonOps{account: &HorizonAccount{AccountID: 7}}
91+
92+
err := ExtendScanHorizon(t.Context(), ops, validChildDeriveFunc(),
93+
ScanHorizon{Branch: 2, Index: 1})
94+
require.ErrorIs(t, err, ErrInvalidParam)
95+
96+
require.Empty(t, ops.inserted)
97+
require.False(t, ops.advanceCall)
98+
}
99+
100+
// TestExtendScanHorizonRejectsMaxIndex verifies that a horizon index above
101+
// MaxAddressIndex is rejected with ErrMaxAddressIndexReached, matching the
102+
// legacy address manager's per-account child bound.
103+
func TestExtendScanHorizonRejectsMaxIndex(t *testing.T) {
104+
t.Parallel()
105+
106+
ops := &fakeHorizonOps{account: &HorizonAccount{AccountID: 7}}
107+
108+
err := ExtendScanHorizon(t.Context(), ops, validChildDeriveFunc(),
109+
ScanHorizon{Branch: externalBranch, Index: MaxAddressIndex + 1})
110+
require.ErrorIs(t, err, ErrMaxAddressIndexReached)
111+
112+
require.Empty(t, ops.inserted)
113+
require.False(t, ops.advanceCall)
114+
}
115+
116+
// TestExtendScanHorizonSkipsInvalidChild verifies that an HD-invalid child
117+
// index is skipped without persisting a row while derivation advances to the
118+
// next valid child, mirroring the legacy extendAddresses inner loop.
119+
func TestExtendScanHorizonSkipsInvalidChild(t *testing.T) {
120+
t.Parallel()
121+
122+
ops := &fakeHorizonOps{
123+
account: &HorizonAccount{
124+
AccountID: 7,
125+
NextExternalIndex: 0,
126+
},
127+
}
128+
129+
// Index 1 derives an ErrInvalidChild, so it must be skipped: only indices
130+
// 0 and 2 are persisted while the loop still reaches horizon index 2.
131+
deriveFn := func(_ context.Context,
132+
params AddressDerivationParams) (*DerivedAddressData, error) {
133+
134+
if params.Index == 1 {
135+
return nil, hdkeychain.ErrInvalidChild
136+
}
137+
138+
return &DerivedAddressData{
139+
ScriptPubKey: []byte{0x00, 0x14},
140+
PubKey: []byte{0x02},
141+
}, nil
142+
}
143+
144+
err := ExtendScanHorizon(t.Context(), ops, deriveFn,
145+
ScanHorizon{Branch: externalBranch, Index: 2})
146+
require.NoError(t, err)
147+
148+
// Index 1 was skipped; indices 0 and 2 were persisted.
149+
require.Equal(t, []uint32{0, 2}, ops.inserted)
150+
151+
// The next index advanced past the last derived child.
152+
require.True(t, ops.advanceCall)
153+
require.Equal(t, uint32(3), ops.advancedTo)
154+
}
155+
156+
// TestExtendScanHorizonAdvancesAfterInserts verifies that, after deriving the
157+
// full range, ExtendScanHorizon persists every child and advances the branch
158+
// next-index to one past the last derived child.
159+
func TestExtendScanHorizonAdvancesAfterInserts(t *testing.T) {
160+
t.Parallel()
161+
162+
ops := &fakeHorizonOps{
163+
account: &HorizonAccount{
164+
AccountID: 7,
165+
NextInternalIndex: 1,
166+
AddrSchema: ScopeAddrSchema{
167+
InternalAddrType: WitnessPubKey,
168+
ExternalAddrType: WitnessPubKey,
169+
},
170+
},
171+
}
172+
173+
// Extend the internal branch from its next index 1 through index 3.
174+
err := ExtendScanHorizon(t.Context(), ops, validChildDeriveFunc(),
175+
ScanHorizon{Branch: internalBranch, Index: 3})
176+
require.NoError(t, err)
177+
178+
require.Equal(t, []uint32{1, 2, 3}, ops.inserted)
179+
require.True(t, ops.advanceCall)
180+
require.Equal(t, uint32(4), ops.advancedTo)
181+
}

0 commit comments

Comments
 (0)