Skip to content

Commit a2cc53c

Browse files
committed
feat(tx): FP-friendly resource-block API [closes #99]
1 parent bd3dff8 commit a2cc53c

File tree

3 files changed

+136
-59
lines changed

3 files changed

+136
-59
lines changed

scalasql/core/src/DbApi.scala

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -194,12 +194,12 @@ object DbApi {
194194
trait Txn extends DbApi {
195195

196196
/**
197-
* Creates a SQL Savepoint that is active within the given block; automatically
197+
* Returns a ResourceBlock that creates a SQL Savepoint that is active within the given block; automatically
198198
* releases the savepoint if the block completes successfully and rolls it back
199199
* if the block terminates with an exception, and allows you to roll back the
200200
* savepoint manually via the [[DbApi.Savepoint]] parameter passed to that block
201201
*/
202-
def savepoint[T](block: DbApi.Savepoint => T): T
202+
def savepoint: UseBlock[DbApi.Savepoint]
203203

204204
/**
205205
* Rolls back any active Savepoints and then rolls back this Transaction
@@ -575,31 +575,28 @@ object DbApi {
575575
DbApi.renderSql(query, config, dialect.withCastParams(castParams))
576576
}
577577

578-
val savepointStack = collection.mutable.ArrayDeque.empty[java.sql.Savepoint]
579-
580-
def savepoint[T](block: DbApi.Savepoint => T): T = {
581-
val savepoint = connection.setSavepoint()
582-
savepointStack.append(savepoint)
583-
584-
try {
585-
val res = block(new DbApi.SavepointImpl(savepoint, () => rollbackSavepoint(savepoint)))
586-
if (dialect.supportSavepointRelease && savepointStack.lastOption.exists(_ eq savepoint)) {
587-
// Only release if this savepoint has not been rolled back,
588-
// directly or indirectly
589-
connection.releaseSavepoint(savepoint)
590-
}
591-
res
592-
} catch {
593-
case e: Throwable =>
594-
rollbackSavepoint(savepoint)
595-
throw e
578+
private val savepointStack = collection.mutable.ArrayDeque.empty[java.sql.Savepoint]
579+
580+
lazy val savepoint: UseBlock[DbApi.Savepoint] = UseBlockImpl {
581+
val jSavepoint = connection.setSavepoint()
582+
savepointStack.append(jSavepoint)
583+
jSavepoint
584+
}((jSp, error) => {
585+
error match {
586+
case None =>
587+
if (dialect.supportSavepointRelease && savepointStack.lastOption.exists(_ eq jSp)) {
588+
// Only release if this savepoint has not been rolled back,
589+
// directly or indirectly
590+
connection.releaseSavepoint(jSp)
591+
}
592+
case Some(_) => rollbackSavepoint(jSp)
596593
}
597-
}
594+
}).map(jSp => new DbApi.SavepointImpl(jSp, () => rollbackSavepoint(jSp)))
598595

599596
// Make sure we keep track of what savepoints are active on the stack, so we do
600597
// not release or rollback the same savepoint multiple times even in the case of
601598
// exceptions or explicit rollbacks
602-
def rollbackSavepoint(savepoint: java.sql.Savepoint) = {
599+
private def rollbackSavepoint(savepoint: java.sql.Savepoint) = {
603600
savepointStack.indexOf(savepoint) match {
604601
case -1 => // do nothing
605602
case savepointIndex =>
@@ -608,7 +605,7 @@ object DbApi {
608605
}
609606
}
610607

611-
def rollback() = {
608+
def rollback(): Unit = {
612609
try {
613610
notifyListeners(listeners)(_.beforeRollback())
614611
} finally {
@@ -618,6 +615,11 @@ object DbApi {
618615
}
619616
}
620617

618+
/** Attempts rollback, adding any exceptions as suppressed to the cause */
619+
def rollbackCause(cause: Throwable): Unit =
620+
try rollback()
621+
catch { case e: Throwable => cause.addSuppressed(e) }
622+
621623
private def cast[T](t: Any): T = t.asInstanceOf[T]
622624

623625
private def flattenParamPuts[T](flattened: SqlStr.Flattened) = {

scalasql/core/src/DbClient.scala

Lines changed: 24 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
package scalasql.core
22

3-
import scalasql.core.DialectConfig
4-
53
/**
64
* A database client. Primarily allows you to access the database within a [[transaction]]
75
* block or via [[getAutoCommitClientConnection]]
@@ -14,12 +12,12 @@ trait DbClient {
1412
def renderSql[Q, R](query: Q, castParams: Boolean = false)(implicit qr: Queryable[Q, R]): String
1513

1614
/**
17-
* Opens a database transaction within the given [[block]], automatically committing it
15+
* Returns a [[UseBlock]] for database transaction, automatically committing it
1816
* if the block returns successfully and rolling it back if the blow fails with an uncaught
1917
* exception. Within the block, you provides a [[DbApi.Txn]] you can use to run queries, create
2018
* savepoints, or roll back the transaction.
2119
*/
22-
def transaction[T](block: DbApi.Txn => T): T
20+
def transaction: UseBlock[DbApi.Txn]
2321

2422
/**
2523
* Provides a [[DbApi]] that you can use to run queries in "auto-commit" mode, such
@@ -71,39 +69,29 @@ object DbClient {
7169
DbApi.renderSql(query, config, dialect.withCastParams(castParams))
7270
}
7371

74-
def transaction[T](block: DbApi.Txn => T): T = {
72+
lazy val transaction: UseBlock[DbApi.Txn] = UseBlockImpl[DbApi.Impl] {
7573
connection.setAutoCommit(false)
7674
val txn = new DbApi.Impl(connection, config, dialect, listeners, autoCommit = false)
77-
var rolledBack = false
78-
try {
79-
notifyListeners(txn.listeners)(_.begin())
80-
val result = block(txn)
81-
notifyListeners(txn.listeners)(_.beforeCommit())
82-
result
83-
} catch {
84-
case e: Throwable =>
85-
rolledBack = true
75+
notifyListeners(txn.listeners)(_.begin())
76+
txn
77+
}((txn, error) => {
78+
error match {
79+
case None =>
8680
try {
87-
notifyListeners(txn.listeners)(_.beforeRollback())
81+
notifyListeners(txn.listeners)(_.beforeCommit())
8882
} catch {
89-
case e2: Throwable => e.addSuppressed(e2)
90-
} finally {
91-
connection.rollback()
92-
try {
93-
notifyListeners(txn.listeners)(_.afterRollback())
94-
} catch {
95-
case e3: Throwable => e.addSuppressed(e3)
96-
}
83+
case beforeCommitHookErr: Throwable =>
84+
txn.rollbackCause(beforeCommitHookErr)
85+
throw beforeCommitHookErr
9786
}
98-
throw e
99-
} finally {
100-
// this commits uncommitted operations, if any
101-
connection.setAutoCommit(true)
102-
if (!rolledBack) {
87+
// this commits uncommitted operations, if any
88+
connection.setAutoCommit(true)
89+
// afterCommit exceptions just propagate - commit already done
10390
notifyListeners(txn.listeners)(_.afterCommit())
104-
}
91+
92+
case Some(useError) => txn.rollbackCause(useError)
10593
}
106-
}
94+
})
10795

10896
def getAutoCommitClientConnection: DbApi = {
10997
connection.setAutoCommit(true)
@@ -130,13 +118,13 @@ object DbClient {
130118
DbApi.renderSql(query, config, dialect.withCastParams(castParams))
131119
}
132120

133-
private def withConnection[T](f: DbClient.Connection => T): T = {
134-
val connection = dataSource.getConnection
135-
try f(new DbClient.Connection(connection, config, listeners))
136-
finally connection.close()
137-
}
121+
private lazy val withConnectionImpl: UseBlockImpl[Connection] = UseBlockImpl
122+
.autoCloseable(dataSource.getConnection)
123+
.map(new Connection(_, config, listeners))
124+
125+
lazy val withConnection: UseBlock[Connection] = withConnectionImpl
138126

139-
def transaction[T](block: DbApi.Txn => T): T = withConnection(_.transaction(block))
127+
lazy val transaction: UseBlock[DbApi.Txn] = withConnectionImpl.flatMap(_.transaction)
140128

141129
def getAutoCommitClientConnection: DbApi = {
142130
val connection = dataSource.getConnection

scalasql/core/src/UseBlock.scala

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package scalasql.core
2+
3+
/**
4+
* A block abstraction that wraps resource usage with proper acquiring and close/release.
5+
* Exposes lifecycle operations separately, which is useful for integration with FP effect libraries.
6+
* Should not be implemented outside of library code, so sealed.
7+
*
8+
* @tparam A The resource type provided during the `use` phase
9+
*/
10+
sealed trait UseBlock[+A] {
11+
12+
/**
13+
* Acquires the resource. Called once at the start of the block.
14+
* @return The acquired resource and release function.
15+
*/
16+
def allocate(): (A, Option[Throwable] => Unit)
17+
18+
/**
19+
* Combines acquire, use, and release into a single operation.
20+
* Makes it friendly to traditional block-style API.
21+
*/
22+
final def apply[T](use: A => T): T = {
23+
val (resource, release) = allocate()
24+
var usedOk = false
25+
try {
26+
val result = use(resource)
27+
usedOk = true
28+
release(None)
29+
result
30+
} catch {
31+
case e: Throwable =>
32+
if (!usedOk) {
33+
release(Some(e))
34+
} // else - we had an error in `release(None)`, so just propagating
35+
throw e
36+
}
37+
}
38+
}
39+
40+
/** Package-private implementation for internal composition */
41+
private[core] class UseBlockImpl[+A](alloc: () => (A, Option[Throwable] => Unit))
42+
extends UseBlock[A] {
43+
44+
def allocate(): (A, Option[Throwable] => Unit) = alloc()
45+
46+
def map[B](f: A => B): UseBlockImpl[B] = new UseBlockImpl[B](() => {
47+
val (a, release) = alloc()
48+
(f(a), release)
49+
})
50+
51+
def flatMap[B](f: A => UseBlock[B]): UseBlockImpl[B] = new UseBlockImpl[B](() => {
52+
val (outerResource, outerRelease) = alloc()
53+
val (innerResource, innerRelease) =
54+
try {
55+
f(outerResource).allocate()
56+
} catch {
57+
case e: Throwable =>
58+
outerRelease(Some(e))
59+
throw e
60+
}
61+
def combinedRelease(errOpt: Option[Throwable]): Unit = {
62+
var errorForOuter = errOpt
63+
try {
64+
innerRelease(errOpt)
65+
} catch {
66+
case e: Throwable =>
67+
errorForOuter = Some(e)
68+
throw e
69+
} finally {
70+
outerRelease(errorForOuter)
71+
}
72+
}
73+
(innerResource, combinedRelease)
74+
})
75+
}
76+
77+
private[core] object UseBlockImpl {
78+
79+
def apply[A](acquire: => A)(release: (A, Option[Throwable]) => Unit): UseBlockImpl[A] =
80+
new UseBlockImpl[A](() => {
81+
val a = acquire
82+
(a, errOpt => release(a, errOpt))
83+
})
84+
85+
def autoCloseable[A <: AutoCloseable](acquire: => A): UseBlockImpl[A] =
86+
apply(acquire)((a, _) => a.close())
87+
}

0 commit comments

Comments
 (0)