Skip to content

Commit 47bd2fd

Browse files
Healthcheck now checks if registration db can be reached.
Co-authored-by: Akash <akash1810@users.noreply.github.com>
1 parent a8d0a26 commit 47bd2fd

File tree

7 files changed

+54
-13
lines changed

7 files changed

+54
-13
lines changed

common/src/main/scala/db/RegistrationRepository.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@ trait RegistrationRepository[F[_], S[_[_], _]] {
1111
def delete(sub: Registration): ConnectionIO[Int]
1212
def deleteByToken(token: String): ConnectionIO[Int]
1313
def deleteByDate(olderThanDays: Int): ConnectionIO[Int]
14+
def simpleSelectForHealthCheck(): S[F, TopicCount]
1415
def topicCounts(countsThreshold: Int): S[F, TopicCount]
1516
}

common/src/main/scala/db/RegistrationService.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ class RegistrationService[F[_]: Async, S[_[_], _]](repository: RegistrationRepos
4444
}
4545

4646
def topicCounts(countThreshold: Int): S[F, TopicCount] = repository.topicCounts(countThreshold)
47+
48+
def simpleSelectForHealthCheck(): S[F, TopicCount] = repository.simpleSelectForHealthCheck()
4749
}
4850

4951

common/src/main/scala/db/SqlRegistrationRepository.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,22 @@ class SqlRegistrationRepository[F[_]: Async](xa: Transactor[F])
7777
.transact(xa)
7878
}
7979

80+
/**
81+
* Used to verify that the DB connection is healthy.
82+
* We just select one topic and a constant value as this is sufficient to check connectivity (we don't care about what data is returned).
83+
*/
84+
override def simpleSelectForHealthCheck(): Stream[F, TopicCount] = {
85+
sql"""
86+
SELECT topic
87+
, 1
88+
FROM registrations
89+
LIMIT 1
90+
"""
91+
.query[TopicCount]
92+
.stream
93+
.transact(xa)
94+
}
95+
8096
override def findTokens(topics: NonEmptyList[String], shardRange: Option[Range]): Stream[F, HarvestedToken] = {
8197
val queryStatement = (sql"""
8298
SELECT token, platform, buildTier

registration/app/registration/controllers/Main.scala

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@ import registration.models.{LegacyNewsstandRegistration, LegacyRegistration}
1313
import registration.services._
1414
import registration.services.topic.TopicValidator
1515

16-
import scala.concurrent.{ExecutionContext, Future}
16+
import scala.concurrent.{Await, ExecutionContext, Future}
1717
import org.slf4j.{Logger, LoggerFactory}
1818
import play.api.http.HttpEntity
1919
import providers.ProviderError
2020

21+
import scala.concurrent.duration.Duration
2122
import scala.util.{Success, Try}
2223

2324
final class Main(
@@ -33,16 +34,25 @@ final class Main(
3334

3435
private val logger: Logger = LoggerFactory.getLogger(this.getClass)
3536

36-
def healthCheck: Action[AnyContent] = Action {
37-
// This forces Play to close the connection rather than allowing
38-
// keep-alive (because the content length is unknown)
39-
Ok.sendEntity(
40-
HttpEntity.Streamed(
41-
data = Source(Array(ByteString("Good")).toVector),
42-
contentLength = None,
43-
contentType = Some("text/plain")
44-
)
45-
)
37+
38+
def healthCheck: Action[AnyContent] = Action.async {
39+
// Check if we can talk to the registration database
40+
registrar.dbHealthCheck()
41+
.map(_ => {
42+
// This forces Play to close the connection rather than allowing
43+
// keep-alive (because the content length is unknown)
44+
Ok.sendEntity(
45+
HttpEntity.Streamed(
46+
data = Source(Array(ByteString("Good")).toVector),
47+
contentLength = None,
48+
contentType = Some("text/plain")
49+
)
50+
)
51+
})
52+
.recover { _ => {
53+
logger.error("Failing to connect to database")
54+
InternalServerError
55+
} }
4656
}
4757

4858
def newsstandRegister: Action[LegacyNewsstandRegistration] =

registration/app/registration/services/DatabaseRegistrar.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,20 @@ import fs2.Stream
88
import com.amazonaws.services.cloudwatch.model.StandardUnit
99
import metrics.{MetricDataPoint, Metrics}
1010

11-
12-
import scala.concurrent.{ExecutionContext, Future}
11+
import scala.concurrent.duration.Duration
12+
import scala.concurrent.{Await, ExecutionContext, Future}
1313
import scala.util.{Failure, Success}
1414

1515
class DatabaseRegistrar(
1616
registrationService: RegistrationService[IO, Stream],
1717
metrics: Metrics
1818
)(implicit ec: ExecutionContext) extends NotificationRegistrar {
19+
def dbHealthCheck(): Future[List[TopicCount]] = {
20+
val simpleSelect = registrationService.simpleSelectForHealthCheck()
21+
simpleSelect.compile.toList.unsafeToFuture()
22+
}
23+
24+
1925
override val providerIdentifier: String = "DatabaseRegistrar"
2026

2127
override def register(deviceToken: DeviceToken, registration: Registration): RegistrarResponse[RegistrationResponse] = {

registration/app/registration/services/NotificationRegistrar.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ trait NotificationRegistrar {
2222
import NotificationRegistrar.RegistrarResponse
2323
val providerIdentifier: String
2424
def register(deviceToken: DeviceToken, registration: Registration): RegistrarResponse[RegistrationResponse]
25+
def dbHealthCheck(): Future[List[TopicCount]]
2526
}
2627

2728
object NotificationRegistrar {

registration/test/registration/controllers/RegistrationsFixtures.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ trait DelayedRegistrationsBase extends RegistrationsBase {
3535
provider = Unknown
3636
))
3737
}
38+
39+
// Not needed for tests
40+
override def dbHealthCheck(): Future[List[TopicCount]] = Future.successful(List.empty)
3841
}
3942
}
4043

@@ -71,6 +74,8 @@ trait RegistrationsBase extends WithPlayApp with RegistrationsJson {
7174
))
7275
}
7376

77+
// Not needed for tests
78+
override def dbHealthCheck(): Future[List[TopicCount]] = Future.successful(List.empty)
7479
}
7580

7681

0 commit comments

Comments
 (0)