@@ -4,6 +4,7 @@ import zio.{UIO, Chunk}
44import zio .prelude .State
55import zio .raft .{Command , HMap , StateMachine , Index }
66import zio .raft .protocol .{SessionId , RequestId }
7+ import zio .raft .protocol .RequestId .RequestIdSyntax
78import zio .stream .Stream
89import java .time .Instant
910
@@ -267,14 +268,15 @@ trait SessionStateMachine[UC <: Command, R, SR, UserSchema <: Tuple]
267268 *
268269 * Implementation of Raft dissertation Chapter 6.3 session management protocol:
269270 * 1. Check cache for (sessionId, requestId) 2. If cache hit, return cached response 3. If cache miss, check if
270- * requestId < highestLowestRequestIdSeen → response was evicted, return error 4. If cache miss + requestId >=
271+ * requestId <= highestLowestRequestIdSeen → response was evicted, return error 4. If cache miss + requestId >
271272 * highestLowestRequestIdSeen, execute command and update highestLowestRequestIdSeen
272273 *
273274 * This correctly handles out-of-order requests. The lowestRequestId from the client tells us which responses have
274- * been acknowledged and can be evicted. We only update highestLowestRequestIdSeen for requests we actually process.
275+ * been acknowledged (inclusive) and can be evicted. We only update highestLowestRequestIdSeen for requests we
276+ * actually process.
275277 */
276278 private def handleClientRequest (cmd : SessionCommand .ClientRequest [UC , SR ])
277- : State [HMap [Schema ], Either [RequestError , (cmd.command.Response , List [ServerRequestWithContext [SR ]])]] =
279+ : State [HMap [Schema ], Either [RequestError , (cmd.command.Response , List [ServerRequestEnvelope [SR ]])]] =
278280 for
279281 highestLowestSeen <- getHighestLowestRequestIdSeen(cmd.sessionId)
280282 cachedOpt <- getCachedResponse((cmd.sessionId, cmd.requestId))
@@ -285,9 +287,9 @@ trait SessionStateMachine[UC <: Command, R, SR, UserSchema <: Tuple]
285287
286288 case None =>
287289 // Cache miss - check if response was evicted
288- // If requestId < highestLowestRequestIdSeen, client has acknowledged receiving this response
289- if cmd.requestId.isLowerThan (highestLowestSeen) then
290- // Client said "I have responses for all requestIds < highestLowest", so this was evicted
290+ // If requestId <= highestLowestRequestIdSeen, client has acknowledged receiving this response
291+ if cmd.requestId.isLowerOrEqual (highestLowestSeen) then
292+ // Client said "I have responses for all requestIds <= highestLowest", so this was evicted
291293 State .succeed(Left (RequestError .ResponseEvicted (cmd.sessionId, cmd.requestId)))
292294 else
293295 // requestId >= highestLowestRequestIdSeen
@@ -327,8 +329,8 @@ trait SessionStateMachine[UC <: Command, R, SR, UserSchema <: Tuple]
327329
328330 /** Update the highest lowestRequestId seen from the client (only if lowestRequestId > current highest).
329331 *
330- * The lowestRequestId from the client indicates "I have received all responses for requestIds < this value". We
331- * track the highest such value to detect evicted responses.
332+ * The lowestRequestId from the client indicates "I have received all responses for requestIds <= this value
333+ * (inclusive)". We track the highest such value to detect evicted responses.
332334 */
333335 private def updateHighestLowestRequestIdSeen (
334336 sessionId : SessionId ,
@@ -354,7 +356,7 @@ trait SessionStateMachine[UC <: Command, R, SR, UserSchema <: Tuple]
354356 val keysToRemove = state.range[" serverRequests" ](
355357 (cmd.sessionId, RequestId .zero),
356358 (cmd.sessionId, upperBoundExclusive)
357- ).map(_._1 )
359+ ).map((key, _) => key )
358360
359361 // Remove all acknowledged requests in one efficient operation
360362 state.removedAll[" serverRequests" ](keysToRemove)
@@ -363,7 +365,7 @@ trait SessionStateMachine[UC <: Command, R, SR, UserSchema <: Tuple]
363365 /** Handle CreateSession command.
364366 */
365367 private def handleCreateSession (cmd : SessionCommand .CreateSession [SR ])
366- : State [HMap [Schema ], List [ServerRequestWithContext [SR ]]] =
368+ : State [HMap [Schema ], List [ServerRequestEnvelope [SR ]]] =
367369 for
368370 _ <- createSessionMetadata(cmd.sessionId, cmd.capabilities, cmd.createdAt)
369371 (serverRequestsLog, _) <- handleSessionCreated(cmd.sessionId, cmd.capabilities, cmd.createdAt).withLog
@@ -382,7 +384,7 @@ trait SessionStateMachine[UC <: Command, R, SR, UserSchema <: Tuple]
382384 /** Handle SessionExpired command (internal name to avoid conflict with abstract method).
383385 */
384386 private def handleSessionExpired_internal (cmd : SessionCommand .SessionExpired [SR ])
385- : State [HMap [Schema ], List [ServerRequestWithContext [SR ]]] =
387+ : State [HMap [Schema ], List [ServerRequestEnvelope [SR ]]] =
386388 for
387389 capabilities <- getSessionCapabilities(cmd.sessionId)
388390 (serverRequestsLog, _) <- handleSessionExpired(cmd.sessionId, capabilities, cmd.createdAt).withLog
@@ -408,9 +410,7 @@ trait SessionStateMachine[UC <: Command, R, SR, UserSchema <: Tuple]
408410 // Use foldRight to collect updated requests and build new state
409411 // This avoids vars and the need to reverse at the end
410412 state.iterator[" serverRequests" ].foldRight((List .empty[PendingServerRequest [SR ]], state)) {
411- case (((sessionId, requestId), pendingAny), (accumulated, currentState)) =>
412- val pending = pendingAny.asInstanceOf [PendingServerRequest [SR ]]
413-
413+ case (((sessionId, requestId), pending), (accumulated, currentState)) =>
414414 // Check if this request needs retry (lastSentAt before threshold)
415415 if pending.lastSentAt.isBefore(cmd.lastSentBefore) then
416416 // Update lastSentAt and add to accumulated list
@@ -439,7 +439,7 @@ trait SessionStateMachine[UC <: Command, R, SR, UserSchema <: Tuple]
439439 private def addServerRequests (
440440 createdAt : Instant ,
441441 serverRequests : Chunk [ServerRequestForSession [SR ]]
442- ): State [HMap [Schema ], List [ServerRequestWithContext [SR ]]] =
442+ ): State [HMap [Schema ], List [ServerRequestEnvelope [SR ]]] =
443443 if serverRequests.isEmpty then
444444 State .succeed(Nil )
445445 else
@@ -448,52 +448,45 @@ trait SessionStateMachine[UC <: Command, R, SR, UserSchema <: Tuple]
448448 val groupedBySession = serverRequests.groupBy(_.sessionId)
449449
450450 // Process each session's requests
451- val (finalState, allRequestsWithContext ) =
452- groupedBySession.foldLeft((state, List .empty[ServerRequestWithContext [SR ]])) {
451+ val (finalState, allEnvelopes ) =
452+ groupedBySession.foldLeft((state, List .empty[ServerRequestEnvelope [SR ]])) {
453453 case ((currentState, accumulated), (sessionId, sessionRequests)) =>
454454 // Get last assigned ID for this session
455455 val lastId = currentState.get[" lastServerRequestId" ](sessionId)
456456 .getOrElse(RequestId .zero)
457457
458- // Assign IDs starting from lastId + 1
459- val requestsWithIds = sessionRequests.zipWithIndex.map { case (reqWithSession, index) =>
460- val newId = RequestId (RequestId .unwrap(lastId) + index + 1 )
461- (
462- sessionId,
463- newId,
464- PendingServerRequest (
465- payload = reqWithSession.payload,
466- lastSentAt = createdAt
467- )
468- )
469- }
470-
471- // Update lastServerRequestId for this session
472- val newLastId = requestsWithIds.last._2
458+ // Calculate new last ID: lastId + number of requests
459+ val newLastId = lastId.increaseBy(sessionRequests.length)
473460 val stateWithNewId = currentState.updated[" lastServerRequestId" ](
474461 sessionId,
475462 newLastId
476463 )
477464
478- // Add all requests with composite keys
479- val stateWithRequests = requestsWithIds.foldLeft(stateWithNewId) {
480- case (s, (sid, requestId, pending)) =>
481- s.updated[" serverRequests" ]((sid, requestId), pending)
482- }
483-
484- // Create requests with context
485- val requestsWithContext = requestsWithIds.map { case (sid, requestId, pending) =>
486- ServerRequestWithContext (
487- sessionId = sid,
488- requestId = requestId,
489- payload = pending.payload
490- )
491- }.toList
492-
493- (stateWithRequests, accumulated ++ requestsWithContext)
465+ // Add all requests with composite keys and create envelopes in one pass
466+ val (stateWithRequests, envelopes) =
467+ sessionRequests.zipWithIndex.foldLeft((stateWithNewId, List .empty[ServerRequestEnvelope [SR ]])) {
468+ case ((s, envs), (reqWithSession, index)) =>
469+ val requestId = lastId.increaseBy(index + 1 )
470+ val pending = PendingServerRequest (
471+ payload = reqWithSession.payload,
472+ lastSentAt = createdAt
473+ )
474+ val updatedState = s.updated[" serverRequests" ]((sessionId, requestId), pending)
475+ val envelope = ServerRequestEnvelope (
476+ sessionId = sessionId,
477+ requestId = requestId,
478+ payload = reqWithSession.payload
479+ )
480+ (updatedState, envelope :: envs)
481+ }
482+
483+ // Reverse to maintain order (since we prepended)
484+ val orderedEnvelopes = envelopes.reverse
485+
486+ (stateWithRequests, accumulated ++ orderedEnvelopes)
494487 }
495488
496- (allRequestsWithContext , finalState)
489+ (allEnvelopes , finalState)
497490 }
498491
499492 /** Remove all session data when session expires.
@@ -504,13 +497,13 @@ trait SessionStateMachine[UC <: Command, R, SR, UserSchema <: Tuple]
504497 val cacheKeysToRemove = state.range[" cache" ](
505498 (sessionId, RequestId .zero),
506499 (sessionId, RequestId .max)
507- ).map(_._1 )
500+ ).map((key, _) => key )
508501
509502 // Remove all server requests for this session using range query
510503 val serverRequestKeysToRemove = state.range[" serverRequests" ](
511504 (sessionId, RequestId .zero),
512505 (sessionId, RequestId .max)
513- ).map(_._1 )
506+ ).map((key, _) => key )
514507
515508 // Remove all session data in batch
516509 state
@@ -527,28 +520,30 @@ trait SessionStateMachine[UC <: Command, R, SR, UserSchema <: Tuple]
527520
528521 /** Clean up cached responses based on lowestRequestId (Lowest Sequence Number Protocol).
529522 *
530- * Removes all cached responses for the session with requestId < lowestRequestId. This allows the client to control
523+ * Removes all cached responses for the session with requestId <= lowestRequestId. This allows the client to control
531524 * cache cleanup by telling the server which responses it no longer needs (Chapter 6.3 of Raft dissertation).
532525 *
526+ * The client sends lowestRequestId to indicate "I have received all responses up to and including this ID".
527+ *
533528 * Uses range queries to efficiently find and remove old cache entries.
534529 */
535530 private def cleanupCache (
536531 sessionId : SessionId ,
537532 lowestRequestId : RequestId
538533 ): State [HMap [Schema ], Unit ] =
539534 State .update { state =>
540- // Use range to find all cache entries for this session with requestId < lowestRequestId
541- // Range is [from, until), so we use (sessionId, RequestId.zero) to (sessionId, lowestRequestId)
535+ // Use range to find all cache entries for this session with requestId <= lowestRequestId
536+ // Range is [from, until), so to include lowestRequestId, we use lowestRequestId.next as upper bound
542537 val keysToRemove = state.range[" cache" ](
543538 (sessionId, RequestId .zero),
544- (sessionId, lowestRequestId)
545- ).map(_._1 )
539+ (sessionId, lowestRequestId.next )
540+ ).map((key, _) => key )
546541
547542 // Remove all old cache entries in one efficient operation
548543 state.removedAll[" cache" ](keysToRemove)
549544 }
550545
551- /** Dirty read helper (FR-027) - check if ANY session has pending requests needing retry.
546+ /** Dirty read helper - check if ANY session has pending requests needing retry.
552547 *
553548 * This method can be called directly (outside Raft consensus) to optimize the retry process. The retry process
554549 * performs a dirty read, applies policy locally, and only sends GetRequestsForRetry command if retries are needed.
0 commit comments