Skip to content
Draft
20 changes: 20 additions & 0 deletions wallet/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ type runtimeCache interface {
GetAccount(ctx context.Context,
query db.GetAccountQuery) (*db.AccountInfo, error)

// GetAccountSecret returns encrypted account-level signing material.
// The result mirrors the underlying db.AccountStore.GetAccountSecret
// contract and never contains plaintext key material.
GetAccountSecret(ctx context.Context,
query db.GetAccountSecretQuery) (*db.AccountSecret, error)

// ListAccounts returns accounts matching the given query. The result
// mirrors the underlying db.AccountStore.ListAccounts contract.
ListAccounts(ctx context.Context,
Expand Down Expand Up @@ -74,6 +80,20 @@ func (c *storeRuntimeCache) GetAccount(ctx context.Context,
return c.store.GetAccount(ctx, query)
}

// GetAccountSecret delegates to the underlying db.Store.
//
// NOTE: pass-through today. See storeRuntimeCache's TODO(yy).
//
// TODO(yy): drop the wrapcheck exemption once the cache layer wraps
// store errors with its own typed errors.
//
//nolint:wrapcheck
func (c *storeRuntimeCache) GetAccountSecret(ctx context.Context,
query db.GetAccountSecretQuery) (*db.AccountSecret, error) {

return c.store.GetAccountSecret(ctx, query)
}

// ListAccounts delegates to the underlying db.Store.
//
// NOTE: pass-through today. See storeRuntimeCache's TODO(yy).
Expand Down
1 change: 0 additions & 1 deletion wallet/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ var (
errMock = errors.New("mock error")
errChainMock = errors.New("chain error")
errPutMock = errors.New("put error")
errLockMock = errors.New("lock fail")
errDBFail = errors.New("db fail")
errDeriveFail = errors.New("derive fail")
errLoadStateFail = errors.New("load state fail")
Expand Down
39 changes: 21 additions & 18 deletions wallet/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -700,8 +700,10 @@ func (w *Wallet) handleUnlockReq(req unlockReq) {
return
}

// Attempt to unlock the underlying address manager.
err = w.DBUnlock(w.lifetimeCtx, req.req.Passphrase)
// Attempt to unlock the key vault. We pass a negative timeout to
// disable the vault's own auto-lock: the controller keeps owning the
// auto-lock schedule through its lockTimer below.
err = w.keyVault.Unlock(w.lifetimeCtx, req.req.Passphrase, -1)
if err != nil {
req.resp <- err
return
Expand Down Expand Up @@ -752,22 +754,12 @@ func (w *Wallet) handleLockReq(req lockReq) {
}
}

// Signal the address manager to lock, clearing sensitive data.
err = w.addrStore.Lock()
if err != nil {
log.Errorf("Could not lock wallet: %v", err)

// If the wallet is already locked, we consider this a success
// (idempotency) and proceed to ensure our state is consistent.
if !waddrmgr.IsError(err, waddrmgr.ErrLocked) {
req.resp <- err

return
}
}
// Signal the key vault to lock, clearing sensitive data. Lock is void
// and idempotent: the vault swallows an already-locked condition and
// logs any other failure internally.
w.keyVault.Lock()

// Even if an error occurred (e.g. already locked), we ensure the
// wallet's high-level state is synchronized to 'locked'.
// Synchronize the wallet's high-level state to 'locked'.
w.state.toLocked()

// Report the result back to the caller.
Expand All @@ -786,8 +778,19 @@ func (w *Wallet) handleChangePassphraseReq(req changePassphraseReq) {
return
}

// Delegate the cryptographic rotation to the database layer.
// Persist the passphrase rotation to the database.
err = w.DBPutPassphrase(w.lifetimeCtx, req.req)
if err != nil {
req.resp <- err
return
}

// A private passphrase change rotates the secret crypto keys, so let
// the key vault refresh any runtime state it caches under the new
// passphrase.
if req.req.ChangePrivate {
err = w.keyVault.RefreshPrivatePassphrase(req.req.PrivateNew)
}

// Report the result back to the caller.
req.resp <- err
Expand Down
115 changes: 37 additions & 78 deletions wallet/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ func TestHandleUnlockReq(t *testing.T) {
pass := []byte("password")
req := newUnlockReq(UnlockRequest{Passphrase: pass})

// Setup the expected call to the address manager's Unlock method.
deps.addrStore.On("Unlock", mock.Anything, pass).Return(nil).Once()
// Setup the expected call to the key vault's Unlock method.
deps.vault.On(
"Unlock", mock.Anything, pass, mock.Anything,
).Return(nil).Once()

// Act: Dispatch the unlock request to the handler.
w.handleUnlockReq(req)
Expand Down Expand Up @@ -107,7 +109,9 @@ func TestHandleUnlockReq_Errors(t *testing.T) {

pass := []byte("password")
req := newUnlockReq(UnlockRequest{Passphrase: pass})
deps.addrStore.On("Unlock", mock.Anything, pass).Return(
deps.vault.On(
"Unlock", mock.Anything, pass, mock.Anything,
).Return(
errDBMock,
).Once()

Expand Down Expand Up @@ -135,8 +139,8 @@ func TestHandleLockReq(t *testing.T) {

req := newLockReq()

// Setup the expected call to the address manager's Lock method.
deps.addrStore.On("Lock").Return(nil).Once()
// Setup the expected call to the key vault's Lock method.
deps.vault.On("Lock").Return().Once()

// Act: Dispatch the lock request to the handler.
w.handleLockReq(req)
Expand All @@ -148,42 +152,6 @@ func TestHandleLockReq(t *testing.T) {
require.False(t, w.state.isUnlocked())
}

// TestHandleLockReq_Idempotency verifies that if the wallet is already locked
// (indicated by waddrmgr.ErrLocked), the lock request treats it as a success
// and ensures the state is consistent.
func TestHandleLockReq_Idempotency(t *testing.T) {
t.Parallel()

// Arrange: Create a test wallet and transition it to 'Started'.
w, deps := createTestWalletWithMocks(t)
require.NoError(t, w.state.toStarting())
require.NoError(t, w.state.toStarted())

// Transition the wallet to the 'Unlocked' state for testing.
w.state.toUnlocked()

req := newLockReq()

// Setup the expected call to the address manager's Lock method
// returning ErrLocked.
errLocked := waddrmgr.ManagerError{
ErrorCode: waddrmgr.ErrLocked,
Description: "address manager is locked",
}
deps.addrStore.On("Lock").Return(errLocked).Once()

// Act: Dispatch the lock request to the handler.
w.handleLockReq(req)

// Assert: Verify that the response indicates success and the wallet
// state is 'Locked'.
resp := <-req.resp
require.NoError(t, resp)
require.False(t, w.state.isUnlocked())
}

// TestHandleLockReq_Errors verifies that handleLockReq correctly handles error
// conditions, such as attempting to lock a stopped wallet.
func TestHandleLockReq_Errors(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -249,12 +217,15 @@ func TestHandleChangePassphraseReq(t *testing.T) {
}
req := newChangePassphraseReq(reqStruct)

// Setup the expected call to the address manager's ChangePassphrase
// method.
// DBPutPassphrase drives the legacy address manager for the private
// rotation, then the controller refreshes the vault's runtime state.
deps.addrStore.On(
"ChangePassphrase", mock.Anything, []byte("old"),
[]byte("new"), true, mock.Anything,
).Return(nil).Once()
deps.vault.On(
"RefreshPrivatePassphrase", []byte("new"),
).Return(nil).Once()

// Act: Call the handler.
w.handleChangePassphraseReq(req)
Expand Down Expand Up @@ -622,8 +593,8 @@ func TestControllerLock(t *testing.T) {
w.state.toUnlocked()
require.True(t, w.state.isUnlocked())

// Expect a call to the address manager's Lock method.
deps.addrStore.On("Lock").Return(nil).Once()
// Expect a call to the key vault's Lock method.
deps.vault.On("Lock").Return().Once()

// Act: Call the Lock method.
err := w.Lock(t.Context())
Expand Down Expand Up @@ -661,8 +632,10 @@ func TestControllerUnlock(t *testing.T) {

pass := []byte("password")

// Expect a call to the address manager's Unlock method.
deps.addrStore.On("Unlock", mock.Anything, pass).Return(nil).Once()
// Expect a call to the key vault's Unlock method.
deps.vault.On(
"Unlock", mock.Anything, pass, mock.Anything,
).Return(nil).Once()

// Act: Call the Unlock method.
err := w.Unlock(t.Context(), UnlockRequest{Passphrase: pass})
Expand Down Expand Up @@ -704,10 +677,14 @@ func TestControllerChangePassphrase(t *testing.T) {
PrivateNew: []byte("new"),
}

// Expect a call to ChangePassphrase in the address store.
// DBPutPassphrase drives the legacy address manager for the private
// rotation, then the controller refreshes the vault's runtime state.
deps.addrStore.On(
"ChangePassphrase", mock.Anything, []byte("old"), []byte("new"),
true, mock.Anything,
"ChangePassphrase", mock.Anything, []byte("old"),
[]byte("new"), true, mock.Anything,
).Return(nil).Once()
deps.vault.On(
"RefreshPrivatePassphrase", []byte("new"),
).Return(nil).Once()

// Act: Call ChangePassphrase.
Expand Down Expand Up @@ -861,9 +838,9 @@ func TestMainLoop_AutoLock(t *testing.T) {
w.lockTimer = time.NewTimer(time.Millisecond * 10)

lockCalled := make(chan struct{})
deps.addrStore.On("Lock").Run(func(args mock.Arguments) {
deps.vault.On("Lock").Run(func(args mock.Arguments) {
close(lockCalled)
}).Return(nil).Once()
}).Return().Once()

// Act: Start main loop.
w.wg.Add(1)
Expand Down Expand Up @@ -1437,10 +1414,12 @@ func TestControllerUnlock_DefaultTimeout(t *testing.T) {

pass := []byte("pass")
req := UnlockRequest{Passphrase: pass}
deps.addrStore.On("Unlock", mock.Anything, pass).Return(nil).Once()
deps.vault.On(
"Unlock", mock.Anything, pass, mock.Anything,
).Return(nil).Once()
// Auto-lock might trigger if the test runs slowly, but it's not
// guaranteed.
deps.addrStore.On("Lock").Return(nil).Maybe()
deps.vault.On("Lock").Return().Maybe()

// Act: Perform Unlock with default timeout.
err := w.Unlock(t.Context(), req)
Expand Down Expand Up @@ -1497,7 +1476,9 @@ func TestControllerUnlock_NegativeTimeout(t *testing.T) {

pass := []byte("pass")
req := UnlockRequest{Passphrase: pass, Timeout: -1}
deps.addrStore.On("Unlock", mock.Anything, pass).Return(nil).Once()
deps.vault.On(
"Unlock", mock.Anything, pass, mock.Anything,
).Return(nil).Once()

// Act: Perform Unlock with negative timeout (no auto-lock).
err := w.Unlock(t.Context(), req)
Expand Down Expand Up @@ -1526,7 +1507,7 @@ func TestControllerUnlock_DBUnlockFail(t *testing.T) {
go w.mainLoop()

pass := []byte("pass")
deps.addrStore.On("Unlock", mock.Anything, pass).Return(
deps.vault.On("Unlock", mock.Anything, pass, mock.Anything).Return(
errDBMock).Once()

// Act: Attempt Unlock.
Expand All @@ -1540,28 +1521,6 @@ func TestControllerUnlock_DBUnlockFail(t *testing.T) {
w.wg.Wait()
}

// TestHandleLockReq_LockError verifies error handling when Lock fails.
func TestHandleLockReq_LockError(t *testing.T) {
t.Parallel()

// Arrange: Setup mock expectations where internal lock fails.
w, deps := createTestWalletWithMocks(t)

require.NoError(t, w.state.toStarting())
require.NoError(t, w.state.toStarted())

req := lockReq{resp: make(chan error, 1)}

deps.addrStore.On("Lock").Return(errLockMock).Once()

// Act: Handle lock request.
w.handleLockReq(req)
err := <-req.resp

// Assert: Verify error.
require.ErrorContains(t, err, "lock fail")
}

// TestSubmitRescanRequest_HeightOverflow verifies large start height rejection.
func TestSubmitRescanRequest_HeightOverflow(t *testing.T) {
t.Parallel()
Expand Down
17 changes: 0 additions & 17 deletions wallet/db_ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,23 +198,6 @@ func (w *Wallet) DBDeleteExpiredLockedOutputs(_ context.Context) error {
return nil
}

// DBUnlock attempts to unlock the wallet's address manager with the provided
// passphrase.
//
// TODO(yy): Refactor this in the `Store` implementation - the only db
// operation needed is to load the account info and derive the private keys.
func (w *Wallet) DBUnlock(_ context.Context, passphrase []byte) error {
err := walletdb.View(w.cfg.DB, func(tx walletdb.ReadTx) error {
addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
return w.addrStore.Unlock(addrmgrNs, passphrase)
})
if err != nil {
return fmt.Errorf("view: %w", err)
}

return nil
}

// DBPutPassphrase updates the wallet's public or private passphrases.
//
// TODO(yy): Refactor this in the `Store` implementation - we can call
Expand Down
19 changes: 0 additions & 19 deletions wallet/db_ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,25 +142,6 @@ func TestDBBirthdayBlock(t *testing.T) {
require.Equal(t, block, retBlock)
}

// TestDBUnlock verifies that the wallet can successfully unlock its address
// manager using the provided passphrase.
func TestDBUnlock(t *testing.T) {
t.Parallel()

// Arrange: Create a test wallet and setup the expected mock call for
// unlocking the address manager.
w, mocks := createTestWalletWithMocks(t)
pass := []byte("password")

mocks.addrStore.On("Unlock", mock.Anything, pass).Return(nil).Once()

// Act: Attempt to unlock the wallet with the passphrase.
err := w.DBUnlock(t.Context(), pass)

// Assert: Verify that the unlock operation succeeded.
require.NoError(t, err)
}

// TestDBDeleteExpiredLockedOutputs verifies that the wallet successfully
// invokes the transaction store to remove any expired output locks.
func TestDBDeleteExpiredLockedOutputs(t *testing.T) {
Expand Down
3 changes: 1 addition & 2 deletions wallet/deprecated.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import (
"github.com/btcsuite/btcwallet/internal/prompt"
"github.com/btcsuite/btcwallet/waddrmgr"
kvdb "github.com/btcsuite/btcwallet/wallet/internal/db/kvdb"
"github.com/btcsuite/btcwallet/wallet/internal/keyvault"
"github.com/btcsuite/btcwallet/wallet/txauthor"
"github.com/btcsuite/btcwallet/wallet/txrules"
"github.com/btcsuite/btcwallet/walletdb"
Expand Down Expand Up @@ -7215,7 +7214,7 @@ func OpenWithRetry(db walletdb.DB, pubPass []byte, cbs *waddrmgr.OpenCallbacks,
id: walletID,
addrStore: addrMgr,
store: store,
keyVault: keyvault.NewDBVault(store, walletID),
keyVault: kvdb.NewLegacyManagerVault(db, addrMgr),
txStore: txMgr,
walletDeprecated: deprecated,
}
Expand Down
Loading
Loading