Skip to content

Commit 2ff19d1

Browse files
committed
address pr comments
1 parent 4a039c4 commit 2ff19d1

File tree

5 files changed

+121
-81
lines changed

5 files changed

+121
-81
lines changed

client-server-protocol/src/main/scala/zio/raft/protocol/package.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ package object protocol {
8686
def isLowerThan(other: RequestId): Boolean = value < other.value
8787

8888
def isGreaterThan(other: RequestId): Boolean = value > other.value
89+
90+
def increaseBy(x: Int) = RequestId(value + x)
8991
}
9092
}
9193

session-state-machine/src/main/scala/zio/raft/sessionstatemachine/SessionCommand.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ object SessionCommand:
2121
/** A client request containing a user command to execute.
2222
*
2323
* This is the primary command type. The SessionStateMachine base class:
24-
* 1. Updates highestLowestRequestIdSeen if lowestRequestId is higher 2. Checks if requestId <
25-
* highestLowestRequestIdSeen AND not in cache → return Left(SessionError.ResponseEvicted) 3. Checks if
24+
* 1. Updates highestLowestRequestIdSeen if lowestRequestId is higher 2. Checks if requestId <=
25+
* highestLowestRequestIdSeen AND not in cache → return Left(RequestError.ResponseEvicted) 3. Checks if
2626
* (sessionId, requestId) is already in the cache (idempotency) 4. If cached, returns the cached response 5. If
2727
* not cached, narrows state to UserSchema and calls user's applyCommand method 6. Caches the response and
2828
* returns it
@@ -33,13 +33,13 @@ object SessionCommand:
3333
* The request ID (for idempotency checking)
3434
* @param lowestRequestId
3535
* The lowest request ID for which client hasn't received response (for cache cleanup). Client is saying "I have
36-
* received all responses for requestIds < this value"
36+
* received all responses for requestIds <= this value (inclusive)"
3737
* @param command
3838
* The user's command to execute
3939
*
4040
* @note
41-
* Response type is Either[RequestError, (command.Response, List[ServerRequestWithContext[SR]])] Left: error (e.g.,
42-
* response evicted) Right: (user command response, server requests with context)
41+
* Response type is Either[RequestError, (command.Response, List[ServerRequestEnvelope[SR]])] Left: error (e.g.,
42+
* response evicted) Right: (user command response, server request envelopes)
4343
* @note
4444
* lowestRequestId enables the "Lowest Sequence Number Protocol" from Raft dissertation Ch. 6.3
4545
* @note
@@ -54,8 +54,8 @@ object SessionCommand:
5454
lowestRequestId: RequestId,
5555
command: UC
5656
) extends SessionCommand[UC, SR]:
57-
// Response type can be an error or the user command's response with server requests
58-
type Response = Either[RequestError, (command.Response, List[ServerRequestWithContext[SR]])]
57+
// Response type can be an error or the user command's response with server request envelopes
58+
type Response = Either[RequestError, (command.Response, List[ServerRequestEnvelope[SR]])]
5959

6060
/** Acknowledgment from a client for a server-initiated request.
6161
*
@@ -90,7 +90,7 @@ object SessionCommand:
9090
sessionId: SessionId,
9191
capabilities: Map[String, String]
9292
) extends SessionCommand[Nothing, SR]:
93-
type Response = List[ServerRequestWithContext[SR]] // server requests with context
93+
type Response = List[ServerRequestEnvelope[SR]] // server request envelopes
9494

9595
/** Notification that a session has expired.
9696
*
@@ -105,7 +105,7 @@ object SessionCommand:
105105
createdAt: Instant,
106106
sessionId: SessionId
107107
) extends SessionCommand[Nothing, SR]:
108-
type Response = List[ServerRequestWithContext[SR]] // server requests with context
108+
type Response = List[ServerRequestEnvelope[SR]] // server request envelopes
109109

110110
/** Command to atomically retrieve requests needing retry and update lastSentAt.
111111
*

session-state-machine/src/main/scala/zio/raft/sessionstatemachine/SessionStateMachine.scala

Lines changed: 52 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import zio.{UIO, Chunk}
44
import zio.prelude.State
55
import zio.raft.{Command, HMap, StateMachine, Index}
66
import zio.raft.protocol.{SessionId, RequestId}
7+
import zio.raft.protocol.RequestId.RequestIdSyntax
78
import zio.stream.Stream
89
import java.time.Instant
910

@@ -267,14 +268,15 @@ trait SessionStateMachine[UC <: Command, R, SR, UserSchema <: Tuple]
267268
*
268269
* Implementation of Raft dissertation Chapter 6.3 session management protocol:
269270
* 1. Check cache for (sessionId, requestId) 2. If cache hit, return cached response 3. If cache miss, check if
270-
* requestId < highestLowestRequestIdSeen → response was evicted, return error 4. If cache miss + requestId >=
271+
* requestId <= highestLowestRequestIdSeen → response was evicted, return error 4. If cache miss + requestId >
271272
* highestLowestRequestIdSeen, execute command and update highestLowestRequestIdSeen
272273
*
273274
* This correctly handles out-of-order requests. The lowestRequestId from the client tells us which responses have
274-
* been acknowledged and can be evicted. We only update highestLowestRequestIdSeen for requests we actually process.
275+
* been acknowledged (inclusive) and can be evicted. We only update highestLowestRequestIdSeen for requests we
276+
* actually process.
275277
*/
276278
private def handleClientRequest(cmd: SessionCommand.ClientRequest[UC, SR])
277-
: State[HMap[Schema], Either[RequestError, (cmd.command.Response, List[ServerRequestWithContext[SR]])]] =
279+
: State[HMap[Schema], Either[RequestError, (cmd.command.Response, List[ServerRequestEnvelope[SR]])]] =
278280
for
279281
highestLowestSeen <- getHighestLowestRequestIdSeen(cmd.sessionId)
280282
cachedOpt <- getCachedResponse((cmd.sessionId, cmd.requestId))
@@ -285,9 +287,9 @@ trait SessionStateMachine[UC <: Command, R, SR, UserSchema <: Tuple]
285287

286288
case None =>
287289
// Cache miss - check if response was evicted
288-
// If requestId < highestLowestRequestIdSeen, client has acknowledged receiving this response
289-
if cmd.requestId.isLowerThan(highestLowestSeen) then
290-
// Client said "I have responses for all requestIds < highestLowest", so this was evicted
290+
// If requestId <= highestLowestRequestIdSeen, client has acknowledged receiving this response
291+
if cmd.requestId.isLowerOrEqual(highestLowestSeen) then
292+
// Client said "I have responses for all requestIds <= highestLowest", so this was evicted
291293
State.succeed(Left(RequestError.ResponseEvicted(cmd.sessionId, cmd.requestId)))
292294
else
293295
// requestId >= highestLowestRequestIdSeen
@@ -327,8 +329,8 @@ trait SessionStateMachine[UC <: Command, R, SR, UserSchema <: Tuple]
327329

328330
/** Update the highest lowestRequestId seen from the client (only if lowestRequestId > current highest).
329331
*
330-
* The lowestRequestId from the client indicates "I have received all responses for requestIds < this value". We
331-
* track the highest such value to detect evicted responses.
332+
* The lowestRequestId from the client indicates "I have received all responses for requestIds <= this value
333+
* (inclusive)". We track the highest such value to detect evicted responses.
332334
*/
333335
private def updateHighestLowestRequestIdSeen(
334336
sessionId: SessionId,
@@ -354,7 +356,7 @@ trait SessionStateMachine[UC <: Command, R, SR, UserSchema <: Tuple]
354356
val keysToRemove = state.range["serverRequests"](
355357
(cmd.sessionId, RequestId.zero),
356358
(cmd.sessionId, upperBoundExclusive)
357-
).map(_._1)
359+
).map((key, _) => key)
358360

359361
// Remove all acknowledged requests in one efficient operation
360362
state.removedAll["serverRequests"](keysToRemove)
@@ -363,7 +365,7 @@ trait SessionStateMachine[UC <: Command, R, SR, UserSchema <: Tuple]
363365
/** Handle CreateSession command.
364366
*/
365367
private def handleCreateSession(cmd: SessionCommand.CreateSession[SR])
366-
: State[HMap[Schema], List[ServerRequestWithContext[SR]]] =
368+
: State[HMap[Schema], List[ServerRequestEnvelope[SR]]] =
367369
for
368370
_ <- createSessionMetadata(cmd.sessionId, cmd.capabilities, cmd.createdAt)
369371
(serverRequestsLog, _) <- handleSessionCreated(cmd.sessionId, cmd.capabilities, cmd.createdAt).withLog
@@ -382,7 +384,7 @@ trait SessionStateMachine[UC <: Command, R, SR, UserSchema <: Tuple]
382384
/** Handle SessionExpired command (internal name to avoid conflict with abstract method).
383385
*/
384386
private def handleSessionExpired_internal(cmd: SessionCommand.SessionExpired[SR])
385-
: State[HMap[Schema], List[ServerRequestWithContext[SR]]] =
387+
: State[HMap[Schema], List[ServerRequestEnvelope[SR]]] =
386388
for
387389
capabilities <- getSessionCapabilities(cmd.sessionId)
388390
(serverRequestsLog, _) <- handleSessionExpired(cmd.sessionId, capabilities, cmd.createdAt).withLog
@@ -408,9 +410,7 @@ trait SessionStateMachine[UC <: Command, R, SR, UserSchema <: Tuple]
408410
// Use foldRight to collect updated requests and build new state
409411
// This avoids vars and the need to reverse at the end
410412
state.iterator["serverRequests"].foldRight((List.empty[PendingServerRequest[SR]], state)) {
411-
case (((sessionId, requestId), pendingAny), (accumulated, currentState)) =>
412-
val pending = pendingAny.asInstanceOf[PendingServerRequest[SR]]
413-
413+
case (((sessionId, requestId), pending), (accumulated, currentState)) =>
414414
// Check if this request needs retry (lastSentAt before threshold)
415415
if pending.lastSentAt.isBefore(cmd.lastSentBefore) then
416416
// Update lastSentAt and add to accumulated list
@@ -439,7 +439,7 @@ trait SessionStateMachine[UC <: Command, R, SR, UserSchema <: Tuple]
439439
private def addServerRequests(
440440
createdAt: Instant,
441441
serverRequests: Chunk[ServerRequestForSession[SR]]
442-
): State[HMap[Schema], List[ServerRequestWithContext[SR]]] =
442+
): State[HMap[Schema], List[ServerRequestEnvelope[SR]]] =
443443
if serverRequests.isEmpty then
444444
State.succeed(Nil)
445445
else
@@ -448,52 +448,45 @@ trait SessionStateMachine[UC <: Command, R, SR, UserSchema <: Tuple]
448448
val groupedBySession = serverRequests.groupBy(_.sessionId)
449449

450450
// Process each session's requests
451-
val (finalState, allRequestsWithContext) =
452-
groupedBySession.foldLeft((state, List.empty[ServerRequestWithContext[SR]])) {
451+
val (finalState, allEnvelopes) =
452+
groupedBySession.foldLeft((state, List.empty[ServerRequestEnvelope[SR]])) {
453453
case ((currentState, accumulated), (sessionId, sessionRequests)) =>
454454
// Get last assigned ID for this session
455455
val lastId = currentState.get["lastServerRequestId"](sessionId)
456456
.getOrElse(RequestId.zero)
457457

458-
// Assign IDs starting from lastId + 1
459-
val requestsWithIds = sessionRequests.zipWithIndex.map { case (reqWithSession, index) =>
460-
val newId = RequestId(RequestId.unwrap(lastId) + index + 1)
461-
(
462-
sessionId,
463-
newId,
464-
PendingServerRequest(
465-
payload = reqWithSession.payload,
466-
lastSentAt = createdAt
467-
)
468-
)
469-
}
470-
471-
// Update lastServerRequestId for this session
472-
val newLastId = requestsWithIds.last._2
458+
// Calculate new last ID: lastId + number of requests
459+
val newLastId = lastId.increaseBy(sessionRequests.length)
473460
val stateWithNewId = currentState.updated["lastServerRequestId"](
474461
sessionId,
475462
newLastId
476463
)
477464

478-
// Add all requests with composite keys
479-
val stateWithRequests = requestsWithIds.foldLeft(stateWithNewId) {
480-
case (s, (sid, requestId, pending)) =>
481-
s.updated["serverRequests"]((sid, requestId), pending)
482-
}
483-
484-
// Create requests with context
485-
val requestsWithContext = requestsWithIds.map { case (sid, requestId, pending) =>
486-
ServerRequestWithContext(
487-
sessionId = sid,
488-
requestId = requestId,
489-
payload = pending.payload
490-
)
491-
}.toList
492-
493-
(stateWithRequests, accumulated ++ requestsWithContext)
465+
// Add all requests with composite keys and create envelopes in one pass
466+
val (stateWithRequests, envelopes) =
467+
sessionRequests.zipWithIndex.foldLeft((stateWithNewId, List.empty[ServerRequestEnvelope[SR]])) {
468+
case ((s, envs), (reqWithSession, index)) =>
469+
val requestId = lastId.increaseBy(index + 1)
470+
val pending = PendingServerRequest(
471+
payload = reqWithSession.payload,
472+
lastSentAt = createdAt
473+
)
474+
val updatedState = s.updated["serverRequests"]((sessionId, requestId), pending)
475+
val envelope = ServerRequestEnvelope(
476+
sessionId = sessionId,
477+
requestId = requestId,
478+
payload = reqWithSession.payload
479+
)
480+
(updatedState, envelope :: envs)
481+
}
482+
483+
// Reverse to maintain order (since we prepended)
484+
val orderedEnvelopes = envelopes.reverse
485+
486+
(stateWithRequests, accumulated ++ orderedEnvelopes)
494487
}
495488

496-
(allRequestsWithContext, finalState)
489+
(allEnvelopes, finalState)
497490
}
498491

499492
/** Remove all session data when session expires.
@@ -504,13 +497,13 @@ trait SessionStateMachine[UC <: Command, R, SR, UserSchema <: Tuple]
504497
val cacheKeysToRemove = state.range["cache"](
505498
(sessionId, RequestId.zero),
506499
(sessionId, RequestId.max)
507-
).map(_._1)
500+
).map((key, _) => key)
508501

509502
// Remove all server requests for this session using range query
510503
val serverRequestKeysToRemove = state.range["serverRequests"](
511504
(sessionId, RequestId.zero),
512505
(sessionId, RequestId.max)
513-
).map(_._1)
506+
).map((key, _) => key)
514507

515508
// Remove all session data in batch
516509
state
@@ -527,28 +520,30 @@ trait SessionStateMachine[UC <: Command, R, SR, UserSchema <: Tuple]
527520

528521
/** Clean up cached responses based on lowestRequestId (Lowest Sequence Number Protocol).
529522
*
530-
* Removes all cached responses for the session with requestId < lowestRequestId. This allows the client to control
523+
* Removes all cached responses for the session with requestId <= lowestRequestId. This allows the client to control
531524
* cache cleanup by telling the server which responses it no longer needs (Chapter 6.3 of Raft dissertation).
532525
*
526+
* The client sends lowestRequestId to indicate "I have received all responses up to and including this ID".
527+
*
533528
* Uses range queries to efficiently find and remove old cache entries.
534529
*/
535530
private def cleanupCache(
536531
sessionId: SessionId,
537532
lowestRequestId: RequestId
538533
): State[HMap[Schema], Unit] =
539534
State.update { state =>
540-
// Use range to find all cache entries for this session with requestId < lowestRequestId
541-
// Range is [from, until), so we use (sessionId, RequestId.zero) to (sessionId, lowestRequestId)
535+
// Use range to find all cache entries for this session with requestId <= lowestRequestId
536+
// Range is [from, until), so to include lowestRequestId, we use lowestRequestId.next as upper bound
542537
val keysToRemove = state.range["cache"](
543538
(sessionId, RequestId.zero),
544-
(sessionId, lowestRequestId)
545-
).map(_._1)
539+
(sessionId, lowestRequestId.next)
540+
).map((key, _) => key)
546541

547542
// Remove all old cache entries in one efficient operation
548543
state.removedAll["cache"](keysToRemove)
549544
}
550545

551-
/** Dirty read helper (FR-027) - check if ANY session has pending requests needing retry.
546+
/** Dirty read helper - check if ANY session has pending requests needing retry.
552547
*
553548
* This method can be called directly (outside Raft consensus) to optimize the retry process. The retry process
554549
* performs a dirty read, applies policy locally, and only sends GetRequestsForRetry command if retries are needed.

session-state-machine/src/main/scala/zio/raft/sessionstatemachine/package.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ package object sessionstatemachine:
5555

5656
/** Server request with assigned ID after being added to state.
5757
*/
58-
case class ServerRequestWithContext[SR](
58+
case class ServerRequestEnvelope[SR](
5959
sessionId: SessionId,
6060
requestId: RequestId,
6161
payload: SR
@@ -140,8 +140,8 @@ package object sessionstatemachine:
140140
/** Response was cached but has been evicted. Client must create a new session.
141141
*
142142
* This error occurs when:
143-
* 1. Client retries a request with requestId < highestLowestRequestIdSeen for the session 2. The response is not
144-
* in the cache (was evicted via cleanupCache)
143+
* 1. Client retries a request with requestId <= highestLowestRequestIdSeen for the session 2. The response is
144+
* not in the cache (was evicted via cleanupCache)
145145
*
146146
* Per Raft dissertation Chapter 6.3, the client should create a new session and retry the operation.
147147
*/
@@ -168,9 +168,9 @@ package object sessionstatemachine:
168168
* - No data duplication: sessionId and requestId are not stored in the value, only in the key
169169
*
170170
* The highestLowestRequestIdSeen prefix enables detection of evicted responses:
171-
* - Client sends lowestRequestId indicating "I have received all responses < this ID"
171+
* - Client sends lowestRequestId indicating "I have received all responses <= this ID (inclusive)"
172172
* - We track the highest lowestRequestId value we've seen from the client
173-
* - When a ClientRequest arrives, we check if requestId < highestLowestRequestIdSeen
173+
* - When a ClientRequest arrives, we check if requestId <= highestLowestRequestIdSeen
174174
* - If yes and response is not in cache, we know it was evicted (client already acknowledged it)
175175
* - This correctly handles out-of-order requests while preventing re-execution of acknowledged commands
176176
*/

0 commit comments

Comments
 (0)