Skip to content

Commit c731171

Browse files
Fix TxnVarMap new-key concurrency bugs (#51)
Three related fixes for when a transaction adds a new key to a TxnVarMap and the underlying TxnVar doesn't exist yet: - Commit-time lock fallback: TxnLogUpdateVarMapEntry.lock now falls back to the map's structural commitLock when getTxnVar returns None, ensuring new-key commits serialise with structural changes. - addOrUpdate TOCTOU fix: moved the value.get + match inside the internalStructureLock to eliminate the race between checking for key existence and performing the add/update. - delete TOCTOU fix: same pattern — moved the check-then-act inside the lock.
1 parent af6cfcb commit c731171

File tree

4 files changed

+68
-29
lines changed

4 files changed

+68
-29
lines changed

src/main/scala/bengal/stm/model/TxnVarMap.scala

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -82,38 +82,34 @@ case class TxnVarMap[F[_]: STM: Async, K, V](
8282
): F[TxnVarRuntimeId] =
8383
Async[F].delay(getRuntimeExistentialId(key).addParent(runtimeId))
8484

85-
// Only called when key is known to not exist
86-
private def add(newKey: K, newValue: V): F[Unit] =
87-
for {
88-
newTxnVar <- TxnVar.of(newValue)
89-
_ <- withLock(internalStructureLock)(
90-
value.update(_ += (newKey -> newTxnVar))
91-
)
92-
} yield ()
93-
9485
private[stm] def addOrUpdate(key: K, newValue: V): F[Unit] =
95-
for {
96-
txnVarMap <- value.get
97-
_ <- txnVarMap.get(key) match {
98-
case Some(tVar) =>
99-
withLock(internalStructureLock)(
86+
withLock(internalStructureLock) {
87+
for {
88+
txnVarMap <- value.get
89+
_ <- txnVarMap.get(key) match {
90+
case Some(tVar) =>
10091
tVar.set(newValue)
101-
)
102-
case None =>
103-
add(key, newValue)
104-
}
105-
} yield ()
92+
case None =>
93+
for {
94+
newTxnVar <- TxnVar.of(newValue)
95+
_ <- value.update(_ += (key -> newTxnVar))
96+
} yield ()
97+
}
98+
} yield ()
99+
}
106100

107101
private[stm] def delete(key: K): F[Unit] =
108-
for {
109-
txnVarMap <- value.get
110-
_ <- txnVarMap.get(key) match {
111-
case Some(_) =>
112-
withLock(internalStructureLock)(value.update(_ -= key))
113-
case None =>
114-
Async[F].unit
115-
}
116-
} yield ()
102+
withLock(internalStructureLock) {
103+
for {
104+
txnVarMap <- value.get
105+
_ <- txnVarMap.get(key) match {
106+
case Some(_) =>
107+
value.update(_ -= key)
108+
case None =>
109+
Async[F].unit
110+
}
111+
} yield ()
112+
}
117113
}
118114

119115
object TxnVarMap {

src/main/scala/bengal/stm/runtime/TxnLogContext.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ private[stm] trait TxnLogContext[F[_]] {
276276
override private[stm] lazy val lock: F[Option[Semaphore[F]]] =
277277
for {
278278
oTxnVar <- txnVarMap.getTxnVar(key)
279-
} yield oTxnVar.map(_.commitLock)
279+
} yield Some(oTxnVar.map(_.commitLock).getOrElse(txnVarMap.commitLock))
280280

281281
override private[stm] lazy val idFootprint: F[IdFootprint] =
282282
txnVarMap.getRuntimeId(key).map { rid =>

src/test/scala/bengal/stm/runtime/TxnLogEntrySpec.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,5 +485,13 @@ class TxnLogEntrySpec extends AsyncFreeSpec with AsyncIOSpec with Matchers {
485485
footprint.updatedIds shouldBe Set(rid)
486486
}
487487
}
488+
489+
"lock returns Some for new key (falls back to map commitLock)" in withRuntime { implicit stm =>
490+
for {
491+
tvarMap <- TxnVarMap.of(Map("a" -> 1))
492+
entry = stm.TxnLogUpdateVarMapEntry[String, Int]("newkey", None, Some(5), tvarMap)
493+
lock <- entry.lock
494+
} yield lock shouldBe Some(tvarMap.commitLock)
495+
}
488496
}
489497
}

src/test/scala/model/TxnVarMapSpec.scala

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package ai.entrolution
1818
package model
1919

2020
import cats.effect.IO
21+
import cats.effect.implicits._
2122
import cats.effect.testing.scalatest.AsyncIOSpec
2223
import org.scalatest.EitherValues
2324
import org.scalatest.freespec.AsyncFreeSpec
@@ -235,4 +236,38 @@ class TxnVarMapSpec extends AsyncFreeSpec with AsyncIOSpec with Matchers with Ei
235236
.asserting(_ shouldBe None)
236237
}
237238
}
239+
240+
"concurrent new-key operations" - {
241+
"concurrent set of same new key" in {
242+
STM
243+
.runtime[IO]
244+
.flatMap { implicit stm =>
245+
for {
246+
tVarMap <- TxnVarMap.of(Map.empty[String, Int])
247+
_ <- (
248+
tVarMap.set("x", 1).commit,
249+
tVarMap.set("x", 2).commit
250+
).parTupled
251+
result <- tVarMap.get("x").commit
252+
} yield result
253+
}
254+
.asserting(_ shouldBe defined)
255+
}
256+
257+
"concurrent delete of different keys" in {
258+
STM
259+
.runtime[IO]
260+
.flatMap { implicit stm =>
261+
for {
262+
tVarMap <- TxnVarMap.of(Map("a" -> 1, "b" -> 2))
263+
_ <- (
264+
tVarMap.remove("a").commit,
265+
tVarMap.remove("b").commit
266+
).parTupled
267+
result <- tVarMap.get.commit
268+
} yield result
269+
}
270+
.asserting(_ shouldBe empty)
271+
}
272+
}
238273
}

0 commit comments

Comments
 (0)