@@ -80,7 +80,6 @@ public class FanOutKinesisShardSubscription {
8080 // Queue is meant for eager retrieval of records from the Kinesis stream. We will always have 2
8181 // record batches available on next read.
8282 private final BlockingQueue <SubscribeToShardEvent > eventQueue = new LinkedBlockingQueue <>(2 );
83- private final AtomicBoolean subscriptionActive = new AtomicBoolean (false );
8483 private final AtomicReference <Throwable > subscriptionException = new AtomicReference <>();
8584
8685 // Store the current starting position for this subscription. Will be updated each time new
@@ -108,7 +107,8 @@ public void activateSubscription() {
108107 shardId ,
109108 startingPosition ,
110109 consumerArn );
111- if (subscriptionActive .get ()) {
110+ if (shardSubscriber != null
111+ && shardSubscriber .getSubscriptionState () == SubscriptionState .SUBSCRIBED ) {
112112 LOG .warn ("Skipping activation of subscription since it is already active." );
113113 return ;
114114 }
@@ -166,9 +166,9 @@ public void activateSubscription() {
166166 shardId ,
167167 startingPosition ,
168168 consumerArn );
169- subscriptionActive .set (true );
170169 // Request first batch of records.
171170 shardSubscriber .requestRecords ();
171+
172172 } else {
173173 String errorMessage =
174174 "Timeout when subscribing to shard "
@@ -236,16 +236,37 @@ public SubscribeToShardEvent nextEvent() {
236236 throw new KinesisStreamsSourceException (
237237 "Subscription encountered unrecoverable exception." , throwable );
238238 }
239+ final SubscriptionState state =
240+ Optional .ofNullable (shardSubscriber )
241+ .map (FanOutShardSubscriber ::getSubscriptionState )
242+ .orElse (SubscriptionState .NOT_STARTED );
239243
240- if (!subscriptionActive .get ()) {
241- LOG .debug (
242- "Subscription to shard {} for consumer {} is not yet active. Skipping." ,
243- shardId ,
244- consumerArn );
245- return null ;
244+ switch (state ) {
245+ case NOT_STARTED :
246+ LOG .debug (
247+ "Subscription to shard {} for consumer {} is not yet active. Skipping." ,
248+ shardId ,
249+ consumerArn );
250+ return null ;
251+ case COMPLETED :
252+ if (shardSubscriber .isShardEndReached ()) {
253+ LOG .info (
254+ "Subscription reached SHARD_END for shard {} for consumer {}." ,
255+ shardId ,
256+ consumerArn );
257+ return null ;
258+ }
259+ LOG .info (
260+ "Subscription expired to shard {} for consumer {}. Restarting." ,
261+ shardId ,
262+ consumerArn );
263+ activateSubscription ();
264+ return null ;
265+ case SUBSCRIBED :
266+ return eventQueue .poll ();
267+ default :
268+ throw new IllegalStateException ("Unknown subscription state: " + state );
246269 }
247-
248- return eventQueue .poll ();
249270 }
250271
251272 /**
@@ -254,26 +275,48 @@ public SubscribeToShardEvent nextEvent() {
254275 */
255276 private class FanOutShardSubscriber implements Subscriber <SubscribeToShardEventStream > {
256277 private final CountDownLatch subscriptionLatch ;
257-
258278 private Subscription subscription ;
259279
280+ private final AtomicReference <SubscriptionState > subscriptionState =
281+ new AtomicReference <>(SubscriptionState .NOT_STARTED );
282+ private final AtomicBoolean isShardEnd = new AtomicBoolean (false );
283+
260284 private FanOutShardSubscriber (CountDownLatch subscriptionLatch ) {
261285 this .subscriptionLatch = subscriptionLatch ;
262286 }
263287
288+ /**
289+ * Fetch the state that the subscriber is in.
290+ *
291+ * @return Subscription state for the subscriber.
292+ */
293+ public SubscriptionState getSubscriptionState () {
294+ return subscriptionState .get ();
295+ }
296+
297+ /**
298+ * Boolean whether this subscriber has reached the end of a shard.
299+ *
300+ * @return True if ShardEnd. false otherwise.
301+ */
302+ public boolean isShardEndReached () {
303+ return isShardEnd .get ();
304+ }
305+
264306 public void requestRecords () {
265307 subscription .request (1 );
266308 }
267309
268310 public void cancel () {
269- if (! subscriptionActive . get ()) {
311+ if (this . subscriptionState . get () == SubscriptionState . COMPLETED ) {
270312 LOG .warn ("Trying to cancel inactive subscription. Ignoring." );
271313 return ;
272314 }
273- subscriptionActive . set ( false );
315+
274316 if (subscription != null ) {
275317 subscription .cancel ();
276318 }
319+ this .subscriptionState .set (SubscriptionState .COMPLETED );
277320 }
278321
279322 @ Override
@@ -284,6 +327,7 @@ public void onSubscribe(Subscription subscription) {
284327 startingPosition ,
285328 consumerArn );
286329 this .subscription = subscription ;
330+ this .subscriptionState .set (SubscriptionState .SUBSCRIBED );
287331 subscriptionLatch .countDown ();
288332 }
289333
@@ -300,6 +344,11 @@ public void visit(SubscribeToShardEvent event) {
300344 event );
301345 eventQueue .put (event );
302346
347+ if (event .continuationSequenceNumber () == null ) {
348+ isShardEnd .set (true );
349+ return ;
350+ }
351+
303352 // Update the starting position in case we have to recreate the
304353 // subscription
305354 startingPosition =
@@ -330,8 +379,14 @@ public void onError(Throwable throwable) {
330379 @ Override
331380 public void onComplete () {
332381 LOG .info ("Subscription complete - {} ({})" , shardId , consumerArn );
333- cancel ();
334- activateSubscription ();
382+ this .subscriptionState .set (SubscriptionState .COMPLETED );
335383 }
336384 }
385+
386+ /** States that the {@code FanOutShardSubscriber} may be in. */
387+ private enum SubscriptionState {
388+ NOT_STARTED ,
389+ SUBSCRIBED ,
390+ COMPLETED
391+ }
337392}
0 commit comments