Skip to content

Commit 82c04d7

Browse files
committed
fix(storage): added context state interface
1 parent b7e0426 commit 82c04d7

1 file changed

Lines changed: 31 additions & 30 deletions

File tree

pkg/storage/state.go

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package storage
22

33
import (
4+
"context"
45
"encoding/json"
56
"errors"
67
"github.com/gotd/td/telegram/updates"
@@ -16,7 +17,7 @@ func NewState(kv kv.KV) updates.StateStorage {
1617
return &State{kv: kv}
1718
}
1819

19-
func (s *State) Get(key string, v interface{}) error {
20+
func (s *State) Get(_ context.Context, key string, v interface{}) error {
2021
data, err := s.kv.Get(key)
2122
if err != nil {
2223
return err
@@ -25,7 +26,7 @@ func (s *State) Get(key string, v interface{}) error {
2526
return json.Unmarshal(data, v)
2627
}
2728

28-
func (s *State) Set(key string, v interface{}) error {
29+
func (s *State) Set(_ context.Context, key string, v interface{}) error {
2930
data, err := json.Marshal(v)
3031
if err != nil {
3132
return err
@@ -34,10 +35,10 @@ func (s *State) Set(key string, v interface{}) error {
3435
return s.kv.Set(key, data)
3536
}
3637

37-
func (s *State) GetState(userID int64) (updates.State, bool, error) {
38+
func (s *State) GetState(ctx context.Context, userID int64) (updates.State, bool, error) {
3839
state := updates.State{}
3940

40-
if err := s.Get(key.State(userID), &state); err != nil {
41+
if err := s.Get(ctx, key.State(userID), &state); err != nil {
4142
if errors.Is(err, kv.ErrNotFound) {
4243
return state, false, nil
4344
}
@@ -47,69 +48,69 @@ func (s *State) GetState(userID int64) (updates.State, bool, error) {
4748
return state, true, nil
4849
}
4950

50-
func (s *State) SetState(userID int64, state updates.State) error {
51-
if err := s.Set(key.State(userID), state); err != nil {
51+
func (s *State) SetState(ctx context.Context, userID int64, state updates.State) error {
52+
if err := s.Set(ctx, key.State(userID), state); err != nil {
5253
return err
5354
}
5455

55-
return s.Set(key.StateChannel(userID), struct{}{})
56+
return s.Set(ctx, key.StateChannel(userID), struct{}{})
5657
}
5758

58-
func (s *State) SetPts(userID int64, pts int) error {
59+
func (s *State) SetPts(ctx context.Context, userID int64, pts int) error {
5960
state, k := updates.State{}, key.State(userID)
6061

61-
if err := s.Get(k, &state); err != nil {
62+
if err := s.Get(ctx, k, &state); err != nil {
6263
return err
6364
}
6465
state.Pts = pts
65-
return s.Set(k, state)
66+
return s.Set(ctx, k, state)
6667
}
6768

68-
func (s *State) SetQts(userID int64, qts int) error {
69+
func (s *State) SetQts(ctx context.Context, userID int64, qts int) error {
6970
state, k := updates.State{}, key.State(userID)
7071

71-
if err := s.Get(k, &state); err != nil {
72+
if err := s.Get(ctx, k, &state); err != nil {
7273
return err
7374
}
7475
state.Qts = qts
75-
return s.Set(k, state)
76+
return s.Set(ctx, k, state)
7677
}
7778

78-
func (s *State) SetDate(userID int64, date int) error {
79+
func (s *State) SetDate(ctx context.Context, userID int64, date int) error {
7980
state, k := updates.State{}, key.State(userID)
8081

81-
if err := s.Get(k, &state); err != nil {
82+
if err := s.Get(ctx, k, &state); err != nil {
8283
return err
8384
}
8485
state.Date = date
85-
return s.Set(k, state)
86+
return s.Set(ctx, k, state)
8687
}
8788

88-
func (s *State) SetSeq(userID int64, seq int) error {
89+
func (s *State) SetSeq(ctx context.Context, userID int64, seq int) error {
8990
state, k := updates.State{}, key.State(userID)
9091

91-
if err := s.Get(k, &state); err != nil {
92+
if err := s.Get(ctx, k, &state); err != nil {
9293
return err
9394
}
9495
state.Seq = seq
95-
return s.Set(k, state)
96+
return s.Set(ctx, k, state)
9697
}
9798

98-
func (s *State) SetDateSeq(userID int64, date, seq int) error {
99+
func (s *State) SetDateSeq(ctx context.Context, userID int64, date, seq int) error {
99100
state, k := updates.State{}, key.State(userID)
100101

101-
if err := s.Get(k, &state); err != nil {
102+
if err := s.Get(ctx, k, &state); err != nil {
102103
return err
103104
}
104105
state.Date = date
105106
state.Seq = seq
106-
return s.Set(k, state)
107+
return s.Set(ctx, k, state)
107108
}
108109

109-
func (s *State) GetChannelPts(userID, channelID int64) (int, bool, error) {
110+
func (s *State) GetChannelPts(ctx context.Context, userID, channelID int64) (int, bool, error) {
110111
c := make(map[int64]int)
111112

112-
if err := s.Get(key.StateChannel(userID), &c); err != nil {
113+
if err := s.Get(ctx, key.StateChannel(userID), &c); err != nil {
113114
if errors.Is(err, kv.ErrNotFound) {
114115
return 0, false, nil
115116
}
@@ -124,25 +125,25 @@ func (s *State) GetChannelPts(userID, channelID int64) (int, bool, error) {
124125
return pts, true, nil
125126
}
126127

127-
func (s *State) SetChannelPts(userID, channelID int64, pts int) error {
128+
func (s *State) SetChannelPts(ctx context.Context, userID, channelID int64, pts int) error {
128129
c, k := make(map[int64]int), key.StateChannel(userID)
129130

130-
if err := s.Get(k, &c); err != nil {
131+
if err := s.Get(ctx, k, &c); err != nil {
131132
return err
132133
}
133134
c[channelID] = pts
134-
return s.Set(k, c)
135+
return s.Set(ctx, k, c)
135136
}
136137

137-
func (s *State) ForEachChannels(userID int64, f func(channelID int64, pts int) error) error {
138+
func (s *State) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error {
138139
c := make(map[int64]int)
139140

140-
if err := s.Get(key.StateChannel(userID), &c); err != nil {
141+
if err := s.Get(ctx, key.StateChannel(userID), &c); err != nil {
141142
return err
142143
}
143144

144145
for channelID, pts := range c {
145-
if err := f(channelID, pts); err != nil {
146+
if err := f(ctx, channelID, pts); err != nil {
146147
return err
147148
}
148149
}

0 commit comments

Comments
 (0)