Skip to content

Commit 372d6ac

Browse files
committed
Replace context.TODO() with appropriate context
Signed-off-by: Takeshi Arabiki <takeshi.arabiki@datachain.jp>
1 parent b4a60cf commit 372d6ac

File tree

6 files changed

+47
-45
lines changed

6 files changed

+47
-45
lines changed

module/facade.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ func GetCurrentEpoch(v uint64) uint64 {
1616
return getCurrentEpoch(v)
1717
}
1818

19-
func QueryFinalizedHeader(fn getHeaderFn, height uint64, limitHeight uint64) ([]*ETHHeader, error) {
20-
return queryFinalizedHeader(context.TODO(), fn, height, limitHeight)
19+
func QueryFinalizedHeader(ctx context.Context, fn getHeaderFn, height uint64, limitHeight uint64) ([]*ETHHeader, error) {
20+
return queryFinalizedHeader(ctx, fn, height, limitHeight)
2121
}
2222

23-
func QueryValidatorSetAndTurnLength(fn getHeaderFn, height uint64) (Validators, uint8, error) {
24-
return queryValidatorSetAndTurnLength(context.TODO(), fn, height)
23+
func QueryValidatorSetAndTurnLength(ctx context.Context, fn getHeaderFn, height uint64) (Validators, uint8, error) {
24+
return queryValidatorSetAndTurnLength(ctx, fn, height)
2525
}
2626

2727
func ExtractValidatorSetAndTurnLength(h *types.Header) (Validators, uint8, error) {

module/prover.go

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,15 @@ func (pr *Prover) SetupForRelay(ctx context.Context) error {
5252
// These states will be submitted to the counterparty chain as MsgCreateClient.
5353
// If `height` is nil, the latest finalized height is selected automatically.
5454
func (pr *Prover) CreateInitialLightClientState(ctx context.Context, height exported.Height) (exported.ClientState, exported.ConsensusState, error) {
55-
latestHeight, err := pr.chain.LatestHeight(context.TODO())
55+
latestHeight, err := pr.chain.LatestHeight(ctx)
5656
if err != nil {
5757
return nil, nil, err
5858
}
5959
var finalizedHeader []*ETHHeader
6060
if height == nil {
61-
_, finalizedHeader, err = queryLatestFinalizedHeader(context.TODO(), pr.chain.Header, latestHeight.GetRevisionHeight())
61+
_, finalizedHeader, err = queryLatestFinalizedHeader(ctx, pr.chain.Header, latestHeight.GetRevisionHeight())
6262
} else {
63-
finalizedHeader, err = queryFinalizedHeader(context.TODO(), pr.chain.Header, height.GetRevisionHeight(), latestHeight.GetRevisionHeight())
63+
finalizedHeader, err = queryFinalizedHeader(ctx, pr.chain.Header, height.GetRevisionHeight(), latestHeight.GetRevisionHeight())
6464
}
6565
if err != nil {
6666
return nil, nil, err
@@ -69,18 +69,18 @@ func (pr *Prover) CreateInitialLightClientState(ctx context.Context, height expo
6969
return nil, nil, fmt.Errorf("no finalized headers were found up to %d", latestHeight.GetRevisionHeight())
7070
}
7171
//Header should be Finalized, not necessarily Verifiable.
72-
return pr.buildInitialState(&Header{
72+
return pr.buildInitialState(ctx, &Header{
7373
Headers: finalizedHeader,
7474
})
7575
}
7676

7777
// GetLatestFinalizedHeader returns the latest finalized header from the chain
7878
func (pr *Prover) GetLatestFinalizedHeader(ctx context.Context) (out core.Header, err error) {
79-
latestHeight, err := pr.chain.LatestHeight(context.TODO())
79+
latestHeight, err := pr.chain.LatestHeight(ctx)
8080
if err != nil {
8181
return nil, err
8282
}
83-
header, err := pr.GetLatestFinalizedHeaderByLatestHeight(context.TODO(), latestHeight.GetRevisionHeight())
83+
header, err := pr.GetLatestFinalizedHeaderByLatestHeight(ctx, latestHeight.GetRevisionHeight())
8484
if err != nil {
8585
return nil, err
8686
}
@@ -90,31 +90,31 @@ func (pr *Prover) GetLatestFinalizedHeader(ctx context.Context) (out core.Header
9090

9191
// GetLatestFinalizedHeaderByLatestHeight returns the latest finalized verifiable header from the chain
9292
func (pr *Prover) GetLatestFinalizedHeaderByLatestHeight(ctx context.Context, latestBlockNumber uint64) (core.Header, error) {
93-
height, finalizedHeader, err := queryLatestFinalizedHeader(context.TODO(), pr.chain.Header, latestBlockNumber)
93+
height, finalizedHeader, err := queryLatestFinalizedHeader(ctx, pr.chain.Header, latestBlockNumber)
9494
if err != nil {
9595
return nil, err
9696
}
9797
// Make headers verifiable
98-
return pr.withValidators(height, finalizedHeader)
98+
return pr.withValidators(ctx, height, finalizedHeader)
9999
}
100100

101101
// SetupHeadersForUpdate creates a new header based on a given header
102102
func (pr *Prover) SetupHeadersForUpdate(ctx context.Context, counterparty core.FinalityAwareChain, latestFinalizedHeader core.Header) ([]core.Header, error) {
103103
header := latestFinalizedHeader.(*Header)
104104
// LCP doesn't need height / EVM needs latest height
105-
latestHeightOnDstChain, err := counterparty.LatestHeight(context.TODO())
105+
latestHeightOnDstChain, err := counterparty.LatestHeight(ctx)
106106
if err != nil {
107107
return nil, err
108108
}
109-
csRes, err := counterparty.QueryClientState(core.NewQueryContext(context.TODO(), latestHeightOnDstChain))
109+
csRes, err := counterparty.QueryClientState(core.NewQueryContext(ctx, latestHeightOnDstChain))
110110
if err != nil {
111111
return nil, fmt.Errorf("no client state found : SetupHeadersForUpdate: height = %d, %+v", latestHeightOnDstChain.GetRevisionHeight(), err)
112112
}
113113
var cs exported.ClientState
114114
if err = pr.chain.Codec().UnpackAny(csRes.ClientState, &cs); err != nil {
115115
return nil, err
116116
}
117-
return pr.SetupHeadersForUpdateByLatestHeight(context.TODO(), cs.GetLatestHeight(), header)
117+
return pr.SetupHeadersForUpdateByLatestHeight(ctx, cs.GetLatestHeight(), header)
118118
}
119119

120120
func (pr *Prover) SetupHeadersForUpdateByLatestHeight(ctx context.Context, clientStateLatestHeight exported.Height, latestFinalizedHeader *Header) ([]core.Header, error) {
@@ -127,14 +127,14 @@ func (pr *Prover) SetupHeadersForUpdateByLatestHeight(ctx context.Context, clien
127127
if ethHeaders == nil {
128128
return nil, nil
129129
}
130-
return pr.withValidators(height, ethHeaders)
130+
return pr.withValidators(ctx, height, ethHeaders)
131131
}
132-
latestHeight, err := pr.chain.LatestHeight(context.TODO())
132+
latestHeight, err := pr.chain.LatestHeight(ctx)
133133
if err != nil {
134134
return nil, err
135135
}
136136
return setupHeadersForUpdate(
137-
context.TODO(),
137+
ctx,
138138
queryVerifiableNeighboringEpochHeader,
139139
pr.chain.Header,
140140
clientStateLatestHeight,
@@ -144,7 +144,7 @@ func (pr *Prover) SetupHeadersForUpdateByLatestHeight(ctx context.Context, clien
144144

145145
func (pr *Prover) ProveState(ctx core.QueryContext, path string, value []byte) ([]byte, clienttypes.Height, error) {
146146
proofHeight := toHeight(ctx.Height())
147-
accountProof, commitmentProof, err := pr.getStateCommitmentProof(context.TODO(), []byte(path), proofHeight)
147+
accountProof, commitmentProof, err := pr.getStateCommitmentProof(ctx.Context(), []byte(path), proofHeight)
148148
if err != nil {
149149
return nil, proofHeight, err
150150
}
@@ -161,11 +161,11 @@ func (pr *Prover) ProveState(ctx core.QueryContext, path string, value []byte) (
161161
}
162162

163163
func (pr *Prover) CheckRefreshRequired(ctx context.Context, counterparty core.ChainInfoICS02Querier) (bool, error) {
164-
cpQueryHeight, err := counterparty.LatestHeight(context.TODO())
164+
cpQueryHeight, err := counterparty.LatestHeight(ctx)
165165
if err != nil {
166166
return false, fmt.Errorf("failed to get the latest height of the counterparty chain: %+v", err)
167167
}
168-
cpQueryCtx := core.NewQueryContext(context.TODO(), cpQueryHeight)
168+
cpQueryCtx := core.NewQueryContext(ctx, cpQueryHeight)
169169

170170
resCs, err := counterparty.QueryClientState(cpQueryCtx)
171171
if err != nil {
@@ -188,12 +188,12 @@ func (pr *Prover) CheckRefreshRequired(ctx context.Context, counterparty core.Ch
188188
}
189189
lcLastTimestamp := time.Unix(0, int64(cons.GetTimestamp()))
190190

191-
selfQueryHeight, err := pr.chain.LatestHeight(context.TODO())
191+
selfQueryHeight, err := pr.chain.LatestHeight(ctx)
192192
if err != nil {
193193
return false, fmt.Errorf("failed to get the latest height of the self chain: %+v", err)
194194
}
195195

196-
selfTimestamp, err := pr.chain.Timestamp(context.TODO(), selfQueryHeight)
196+
selfTimestamp, err := pr.chain.Timestamp(ctx, selfQueryHeight)
197197
if err != nil {
198198
return false, fmt.Errorf("failed to get timestamp of the self chain: %+v", err)
199199
}
@@ -227,19 +227,19 @@ func (pr *Prover) CheckRefreshRequired(ctx context.Context, counterparty core.Ch
227227

228228
}
229229

230-
func (pr *Prover) withValidators(height uint64, ethHeaders []*ETHHeader) (core.Header, error) {
231-
return withValidators(context.TODO(), pr.chain.Header, height, ethHeaders)
230+
func (pr *Prover) withValidators(ctx context.Context, height uint64, ethHeaders []*ETHHeader) (core.Header, error) {
231+
return withValidators(ctx, pr.chain.Header, height, ethHeaders)
232232
}
233233

234-
func (pr *Prover) buildInitialState(dstHeader core.Header) (exported.ClientState, exported.ConsensusState, error) {
234+
func (pr *Prover) buildInitialState(ctx context.Context, dstHeader core.Header) (exported.ClientState, exported.ConsensusState, error) {
235235
currentEpoch := getCurrentEpoch(dstHeader.GetHeight().GetRevisionHeight())
236-
currentValidators, currentTurnLength, err := queryValidatorSetAndTurnLength(context.TODO(), pr.chain.Header, currentEpoch)
236+
currentValidators, currentTurnLength, err := queryValidatorSetAndTurnLength(ctx, pr.chain.Header, currentEpoch)
237237
if err != nil {
238238
return nil, nil, err
239239
}
240240

241241
previousEpoch := getPreviousEpoch(dstHeader.GetHeight().GetRevisionHeight())
242-
previousValidators, previousTurnLength, err := queryValidatorSetAndTurnLength(context.TODO(), pr.chain.Header, previousEpoch)
242+
previousValidators, previousTurnLength, err := queryValidatorSetAndTurnLength(ctx, pr.chain.Header, previousEpoch)
243243
if err != nil {
244244
return nil, nil, err
245245
}
@@ -248,7 +248,7 @@ func (pr *Prover) buildInitialState(dstHeader core.Header) (exported.ClientState
248248
return nil, nil, err
249249
}
250250

251-
chainID, err := pr.chain.CanonicalChainID(context.TODO())
251+
chainID, err := pr.chain.CanonicalChainID(ctx)
252252
if err != nil {
253253
return nil, nil, err
254254
}

module/prover_test.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ func (ts *ProverTestSuite) TestQueryClientStateWithProof() {
156156
bzCs, err := ts.prover.chain.Codec().Marshal(cs)
157157
ts.Require().NoError(err)
158158

159-
ctx := core.NewQueryContext(context.TODO(), clienttypes.NewHeight(0, 21400))
159+
ctx := core.NewQueryContext(context.Background(), clienttypes.NewHeight(0, 21400))
160160
proof, proofHeight, err := ts.prover.ProveState(ctx, host.FullClientStatePath(ts.prover.chain.Path().ClientID), bzCs)
161161
ts.Require().NoError(err)
162162

@@ -248,41 +248,42 @@ func (ts *ProverTestSuite) TestCheckRefreshRequired() {
248248
ts.chain.trustedHeight = 0
249249
}()
250250

251+
ctx := context.Background()
251252
now := time.Now()
252253
chainHeight := clienttypes.NewHeight(0, 0)
253254
csHeight := clienttypes.NewHeight(0, 0)
254255
ts.chain.chainTimestamp[chainHeight] = uint64(now.Unix())
255256

256257
// should refresh by trusting_period
257258
ts.chain.consensusStateTimestamp[csHeight] = uint64(now.Add(-51 * time.Second).UnixNano())
258-
required, err := ts.prover.CheckRefreshRequired(context.TODO(), dst)
259+
required, err := ts.prover.CheckRefreshRequired(ctx, dst)
259260
ts.Require().NoError(err)
260261
ts.Require().True(required)
261262

262263
// needless by trusting_period
263264
ts.chain.consensusStateTimestamp[csHeight] = uint64(now.Add(-50 * time.Second).UnixNano())
264-
required, err = ts.prover.CheckRefreshRequired(context.TODO(), dst)
265+
required, err = ts.prover.CheckRefreshRequired(ctx, dst)
265266
ts.Require().NoError(err)
266267
ts.Require().False(required)
267268

268269
// should refresh by block difference
269270
ts.chain.latestHeight = 2
270271
ts.prover.config.RefreshBlockDifferenceThreshold = 1
271-
required, err = ts.prover.CheckRefreshRequired(context.TODO(), dst)
272+
required, err = ts.prover.CheckRefreshRequired(ctx, dst)
272273
ts.Require().NoError(err)
273274
ts.Require().True(required)
274275

275276
// needless by block difference
276277
ts.prover.config.RefreshBlockDifferenceThreshold = 2
277-
required, err = ts.prover.CheckRefreshRequired(context.TODO(), dst)
278+
required, err = ts.prover.CheckRefreshRequired(ctx, dst)
278279
ts.Require().NoError(err)
279280
ts.Require().False(required)
280281

281282
// needless by invalid block difference
282283
ts.chain.latestHeight = 1
283284
ts.chain.trustedHeight = 3
284285
ts.prover.config.RefreshBlockDifferenceThreshold = 1
285-
required, err = ts.prover.CheckRefreshRequired(context.TODO(), dst)
286+
required, err = ts.prover.CheckRefreshRequired(ctx, dst)
286287
ts.Require().NoError(err)
287288
ts.Require().False(required)
288289
}
@@ -299,7 +300,7 @@ func (ts *ProverTestSuite) TestProveHostConsensusState() {
299300
(*exported.ConsensusState)(nil),
300301
&ConsensusState{},
301302
)
302-
ctx := core.NewQueryContext(context.TODO(), clienttypes.NewHeight(0, 0))
303+
ctx := core.NewQueryContext(context.Background(), clienttypes.NewHeight(0, 0))
303304
prove, err := ts.prover.ProveHostConsensusState(ctx, clienttypes.NewHeight(0, 0), &cs)
304305
ts.Require().NoError(err)
305306
ts.Require().Len(prove, 150)

module/validator_set_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@ package module
33
import (
44
"context"
55
"errors"
6-
"github.com/ethereum/go-ethereum/core/types"
7-
"github.com/stretchr/testify/suite"
86
"math/big"
97
"testing"
8+
9+
"github.com/ethereum/go-ethereum/core/types"
10+
"github.com/stretchr/testify/suite"
1011
)
1112

1213
type ValidatorSetTestSuite struct {
@@ -54,7 +55,7 @@ func (ts *ValidatorSetTestSuite) TestSuccessQueryValidatorSet() {
5455
fn := func(ctx context.Context, height uint64) (*types.Header, error) {
5556
return epochHeader(), nil
5657
}
57-
validators, turnLength, err := QueryValidatorSetAndTurnLength(fn, 400)
58+
validators, turnLength, err := QueryValidatorSetAndTurnLength(context.Background(), fn, 400)
5859
ts.Require().NoError(err)
5960
ts.Require().Len(validators, 4)
6061
ts.Require().Equal(turnLength, uint8(1))
@@ -64,7 +65,7 @@ func (ts *ValidatorSetTestSuite) TestErrorQueryValidatorSet() {
6465
fn := func(ctx context.Context, height uint64) (*types.Header, error) {
6566
return nil, errors.New("error")
6667
}
67-
_, _, err := QueryValidatorSetAndTurnLength(fn, 200)
68+
_, _, err := QueryValidatorSetAndTurnLength(context.Background(), fn, 200)
6869
ts.Require().Equal(err.Error(), "error")
6970
}
7071

tests/prover_network_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func (ts *ProverNetworkTestSuite) SetupTest() {
4747
}
4848

4949
func (ts *ProverNetworkTestSuite) TestQueryLatestFinalizedHeader() {
50-
header, err := ts.prover.GetLatestFinalizedHeader(context.TODO())
50+
header, err := ts.prover.GetLatestFinalizedHeader(context.Background())
5151
ts.Require().NoError(err)
5252
ts.Require().NoError(header.ValidateBasic())
5353
ts.Require().Len(header.(*module.Header).Headers, 3)
@@ -65,9 +65,9 @@ func (ts *ProverNetworkTestSuite) TestSetupHeadersForUpdate() {
6565
dst := dstChain{
6666
Chain: ts.makeChain("http://localhost:8645", "ibc0"),
6767
}
68-
header, err := ts.prover.GetLatestFinalizedHeader(context.TODO())
68+
header, err := ts.prover.GetLatestFinalizedHeader(context.Background())
6969
ts.Require().NoError(err)
70-
setupDone, err := ts.prover.SetupHeadersForUpdate(context.TODO(), dst, header)
70+
setupDone, err := ts.prover.SetupHeadersForUpdate(context.Background(), dst, header)
7171
ts.Require().NoError(err)
7272
ts.Require().True(len(setupDone) > 0)
7373
for _, h := range setupDone {
@@ -76,7 +76,7 @@ func (ts *ProverNetworkTestSuite) TestSetupHeadersForUpdate() {
7676
}
7777

7878
func (ts *ProverNetworkTestSuite) TestSuccessCreateInitialLightClientState() {
79-
s1, s2, err := ts.prover.CreateInitialLightClientState(context.TODO(), nil)
79+
s1, s2, err := ts.prover.CreateInitialLightClientState(context.Background(), nil)
8080
ts.Require().NoError(err)
8181

8282
cs := s1.(*module.ClientState)

tool/testdata/internal/membership/verify_membership.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ func (m *verifyMembershipModule) proveState(chainID uint64, path string, value [
9090
}
9191

9292
proof, proofHeight, err := prover.ProveState(ctx, path, value)
93-
storageRoot, err := prover.GetStorageRoot(context.TODO(), header)
93+
storageRoot, err := prover.GetStorageRoot(ctx.Context(), header)
9494
if err != nil {
9595
return common.Hash{}, nil, types.Height{}, err
9696
}

0 commit comments

Comments
 (0)