@@ -12,7 +12,10 @@ import (
1212
1313const stateContentsAttributeKey string = "state"
1414
15- var errNoFreeShardGroups = errors .New ("No free shard groups" )
15+ var (
16+ errNoFreeShardGroups = errors .New ("No free shard groups" )
17+ errShardRetention = errors .New ("Could not retain shard group" )
18+ )
1619
1720type stateHandler struct {
1821 logger logger.Logger
@@ -33,8 +36,16 @@ func newStateHandler(member *member) (*stateHandler, error) {
3336func (sh * stateHandler ) start () error {
3437
3538 // stops on stop()
36- go sh .refreshStatePeriodically ()
39+ go func () {
40+ if err := sh .refreshStatePeriodically (); err != nil {
41+ if errors .RootCause (err ) == errShardRetention {
3742
43+ // signal that the Handler needs to be restarted
44+ sh .logger .ErrorWith ("Aborting member" , "memberID" , sh .member .id )
45+ sh .member .handler .Abort (sh .member .session ) // nolint: errcheck
46+ }
47+ }
48+ }()
3849 return nil
3950}
4051
@@ -77,7 +88,7 @@ func (sh *stateHandler) getSessionState(state *State, memberID string) (*Session
7788 return nil , errors .Errorf ("Member state not found: %s" , memberID )
7889}
7990
80- func (sh * stateHandler ) refreshStatePeriodically () {
91+ func (sh * stateHandler ) refreshStatePeriodically () error {
8192 var err error
8293
8394 // guaranteed to only be REPLACED by a new instance - not edited. as such, once this is initialized
@@ -94,6 +105,13 @@ func (sh *stateHandler) refreshStatePeriodically() {
94105 } else {
95106 lastState , err = sh .refreshState ()
96107 if err != nil {
108+
109+ // in case of shard retention error we want to signal the member to restart
110+ if errors .RootCause (err ) == errShardRetention {
111+ sh .logger .WarnWith ("Failed getting state on shard retention (requested by member)" ,
112+ "err" , errors .GetErrorStackString (err , 10 ))
113+ return errors .Wrap (err , "Failed refreshing state by demand" )
114+ }
97115 sh .logger .WarnWith ("Failed getting state" , "err" , errors .GetErrorStackString (err , 10 ))
98116 }
99117
@@ -105,14 +123,21 @@ func (sh *stateHandler) refreshStatePeriodically() {
105123 case <- time .After (sh .member .streamConsumerGroup .config .Session .HeartbeatInterval ):
106124 lastState , err = sh .refreshState ()
107125 if err != nil {
126+
127+ // in case of shard retention error we want to signal the member to restart
128+ if errors .RootCause (err ) == errShardRetention {
129+ sh .logger .WarnWith ("Failed getting state on shard retention (periodic refresh)" ,
130+ "err" , errors .GetErrorStackString (err , 10 ))
131+ return errors .Wrap (err , "Failed refreshing state periodically" )
132+ }
108133 sh .logger .WarnWith ("Failed refreshing state" , "err" , errors .GetErrorStackString (err , 10 ))
109134 continue
110135 }
111136
112137 // if we're told to stop, exit the loop
113138 case <- sh .stopChan :
114139 sh .logger .Debug ("Stopping" )
115- return
140+ return nil
116141 }
117142 }
118143}
@@ -142,6 +167,14 @@ func (sh *stateHandler) refreshState() (*State, error) {
142167 }
143168
144169 return state , nil
170+
171+ }, func () error {
172+
173+ // set retainShards flag to true only after the new state has been saved in persistency
174+ // (meaning the shards have been assigned successfully)
175+ sh .member .retainShards = true
176+
177+ return nil
145178 })
146179}
147180
@@ -150,16 +183,42 @@ func (sh *stateHandler) createSessionState(state *State) error {
150183 state .SessionStates = []* SessionState {}
151184 }
152185
153- // assign shards
154- shards , err := sh .assignShards (sh .member .streamConsumerGroup .maxReplicas , sh .member .streamConsumerGroup .totalNumShards , state )
155- if err != nil {
156- return errors .Wrap (err , "Failed resolving shards for session" )
186+ var shards []int
187+ var err error
188+
189+ if sh .member .retainShards {
190+
191+ // try to retain the originally assigned shard group
192+ shards , err = sh .retainShards (sh .member .shardGroupToRetain , sh .member .id , state )
193+
194+ // shards were "stolen" - set retainShards flag to false and commit suicide
195+ if err != nil {
196+ sh .logger .ErrorWith ("Failed to retain shards" ,
197+ "memberID" , sh .member .id ,
198+ "shardsToRetain" , sh .member .shardGroupToRetain ,
199+ "state" , state ,
200+ "error" , err .Error ())
201+ sh .member .retainShards = false
202+ return err
203+ }
204+ } else {
205+
206+ // assign shards
207+ shards , err = sh .assignShards (sh .member .streamConsumerGroup .maxReplicas ,
208+ sh .member .streamConsumerGroup .totalNumShards ,
209+ state )
210+ if err != nil {
211+ return errors .Wrap (err , "Failed resolving shards for session" )
212+ }
157213 }
158214
159215 sh .logger .DebugWith ("Assigned shards" ,
160216 "shards" , shards ,
161217 "state" , state )
162218
219+ // save shards to retain on the member itself
220+ sh .member .shardGroupToRetain = shards
221+
163222 state .SessionStates = append (state .SessionStates , & SessionState {
164223 MemberID : sh .member .id ,
165224 LastHeartbeat : time .Now (),
@@ -209,6 +268,23 @@ func (sh *stateHandler) assignShards(maxReplicas int, numShards int, state *Stat
209268 return nil , errNoFreeShardGroups
210269}
211270
271+ func (sh * stateHandler ) retainShards (memberShardGroup []int , memberID string , state * State ) ([]int , error ) {
272+
273+ for _ , sessionState := range state .SessionStates {
274+ if common .IntSlicesEqual (memberShardGroup , sessionState .Shards ) {
275+ if sessionState .MemberID == memberID {
276+ return memberShardGroup , nil
277+ }
278+
279+ // original shard group was taken
280+ return nil , errShardRetention
281+ }
282+ }
283+
284+ // shard group to retain is not taken by any member - original member can retain it
285+ return memberShardGroup , nil
286+ }
287+
212288func (sh * stateHandler ) getReplicaShardGroups (maxReplicas int , numShards int ) ([][]int , error ) {
213289 var replicaShardGroups [][]int
214290 shards := common .MakeRange (0 , numShards )
0 commit comments