diff --git a/cdk/lib/__snapshots__/registration.test.ts.snap b/cdk/lib/__snapshots__/registration.test.ts.snap index 5834e6dc5..467313976 100644 --- a/cdk/lib/__snapshots__/registration.test.ts.snap +++ b/cdk/lib/__snapshots__/registration.test.ts.snap @@ -64,6 +64,14 @@ exports[`The Registration stack matches the snapshot for CODE 1`] = ` "Description": "ACM Certificate for app use", "Type": "String", }, + "DatabaseAccessSecurityGroup": { + "AllowedValues": [ + "/CODE/mobile-notifications/registrations-db/postgres-access-security-group", + "/PROD/mobile-notifications/registrations-db/postgres-access-security-group", + ], + "Description": "The security group that allows access to the database", + "Type": "AWS::SSM::Parameter::Value", + }, "DistBucket": { "Description": "The name of the s3 bucket containing the server artifact", "Type": "String", @@ -112,10 +120,6 @@ exports[`The Registration stack matches the snapshot for CODE 1`] = ` "Description": "Environment name", "Type": "String", }, - "VPCSecurityGroup": { - "Description": "The default security group of the VPC", - "Type": "AWS::EC2::SecurityGroup::Id", - }, "VpcId": { "Description": "The VPC", "Type": "AWS::EC2::VPC::Id", @@ -795,7 +799,7 @@ exports[`The Registration stack matches the snapshot for CODE 1`] = ` "Ref": "InstanceSecurityGroup", }, { - "Ref": "VPCSecurityGroup", + "Ref": "DatabaseAccessSecurityGroup", }, ], "UserData": { @@ -917,6 +921,14 @@ exports[`The Registration stack matches the snapshot for PROD 1`] = ` "Description": "ACM Certificate for app use", "Type": "String", }, + "DatabaseAccessSecurityGroup": { + "AllowedValues": [ + "/CODE/mobile-notifications/registrations-db/postgres-access-security-group", + "/PROD/mobile-notifications/registrations-db/postgres-access-security-group", + ], + "Description": "The security group that allows access to the database", + "Type": "AWS::SSM::Parameter::Value", + }, "DistBucket": { "Description": "The name of the s3 bucket containing the server artifact", "Type": "String", @@ -965,10 +977,6 @@ exports[`The Registration stack matches the snapshot for PROD 1`] = ` "Description": "Environment name", "Type": "String", }, - "VPCSecurityGroup": { - "Description": "The default security group of the VPC", - "Type": "AWS::EC2::SecurityGroup::Id", - }, "VpcId": { "Description": "The VPC", "Type": "AWS::EC2::VPC::Id", @@ -1648,7 +1656,7 @@ exports[`The Registration stack matches the snapshot for PROD 1`] = ` "Ref": "InstanceSecurityGroup", }, { - "Ref": "VPCSecurityGroup", + "Ref": "DatabaseAccessSecurityGroup", }, ], "UserData": { diff --git a/common/src/main/scala/db/RegistrationRepository.scala b/common/src/main/scala/db/RegistrationRepository.scala index a0ded8206..30f6fcf47 100644 --- a/common/src/main/scala/db/RegistrationRepository.scala +++ b/common/src/main/scala/db/RegistrationRepository.scala @@ -11,5 +11,6 @@ trait RegistrationRepository[F[_], S[_[_], _]] { def delete(sub: Registration): ConnectionIO[Int] def deleteByToken(token: String): ConnectionIO[Int] def deleteByDate(olderThanDays: Int): ConnectionIO[Int] + def simpleSelectForHealthCheck(): S[F, TopicCount] def topicCounts(countsThreshold: Int): S[F, TopicCount] } diff --git a/common/src/main/scala/db/RegistrationService.scala b/common/src/main/scala/db/RegistrationService.scala index a52cc0057..4a2507e52 100644 --- a/common/src/main/scala/db/RegistrationService.scala +++ b/common/src/main/scala/db/RegistrationService.scala @@ -44,6 +44,8 @@ class RegistrationService[F[_]: Async, S[_[_], _]](repository: RegistrationRepos } def topicCounts(countThreshold: Int): S[F, TopicCount] = repository.topicCounts(countThreshold) + + def simpleSelectForHealthCheck(): S[F, TopicCount] = repository.simpleSelectForHealthCheck() } diff --git a/common/src/main/scala/db/SqlRegistrationRepository.scala b/common/src/main/scala/db/SqlRegistrationRepository.scala index 439f7296f..9fb099b40 100644 --- a/common/src/main/scala/db/SqlRegistrationRepository.scala +++ b/common/src/main/scala/db/SqlRegistrationRepository.scala @@ -77,6 +77,23 @@ class SqlRegistrationRepository[F[_]: Async](xa: Transactor[F]) .transact(xa) } + /** + * Used to verify that the DB connection is healthy. + * We just select one topic and a constant value as this is sufficient to check connectivity (we don't care about what data is returned). + */ + override def simpleSelectForHealthCheck(): Stream[F, TopicCount] = { + logger.info("Performing a query to check DB connectivity") + sql""" + SELECT topic + , 1 + FROM registrations + LIMIT 1 + """ + .query[TopicCount] + .stream + .transact(xa) + } + override def findTokens(topics: NonEmptyList[String], shardRange: Option[Range]): Stream[F, HarvestedToken] = { val queryStatement = (sql""" SELECT token, platform, buildTier diff --git a/registration-db.yaml b/registration-db.yaml index 06a0d8494..a276fe23d 100644 --- a/registration-db.yaml +++ b/registration-db.yaml @@ -49,6 +49,13 @@ Resources: GroupName: !Sub registrations-db-${Stage} GroupDescription: !Sub Security group allowing VPC only traffic SecurityGroupIngress: + # Join PostgresAccessSecurityGroup to allow access to postgres to the registration db + - SourceSecurityGroupId: !Ref PostgresAccessSecurityGroup + FromPort: 5432 + IpProtocol: tcp + ToPort: 5432 + + # TODO Remove this rule once all applications are using the PostgresAccessSecurityGroup - SourceSecurityGroupId: !Ref VPCSecurityGroup FromPort: 5432 IpProtocol: tcp @@ -66,6 +73,27 @@ Resources: Value: registrations-db VpcId: !Ref VpcId + PostgresAccessSecurityGroup: + Type: AWS::EC2::SecurityGroup + Properties: + GroupName: !Sub registrations-db-${Stage}-access + GroupDescription: !Sub Security group allowing access to the registrations db + Tags: + - Key: Stage + Value: !Ref Stage + - Key: Stack + Value: mobile-notifications + - Key: App + Value: registrations-db + VpcId: !Ref VpcId + + PostgresAccessSecurityGroupName: + Type: AWS::SSM::Parameter + Properties: + Name: !Sub /${Stage}/mobile-notifications/registrations-db/postgres-access-security-group + Type: String + Value: !Ref PostgresAccessSecurityGroup + PrivateRegistrationPostgres13DB: Type: AWS::RDS::DBInstance DependsOn: PrivateRegistrationDBSubnetGroup @@ -179,3 +207,4 @@ Resources: Outputs: DBUrl: Value: !GetAtt PrivateRegistrationPostgres13DB.Endpoint.Address + diff --git a/registration/app/registration/controllers/Main.scala b/registration/app/registration/controllers/Main.scala index 990e76d89..da1030edd 100644 --- a/registration/app/registration/controllers/Main.scala +++ b/registration/app/registration/controllers/Main.scala @@ -33,16 +33,26 @@ final class Main( private val logger: Logger = LoggerFactory.getLogger(this.getClass) - def healthCheck: Action[AnyContent] = Action { - // This forces Play to close the connection rather than allowing - // keep-alive (because the content length is unknown) - Ok.sendEntity( - HttpEntity.Streamed( - data = Source(Array(ByteString("Good")).toVector), - contentLength = None, - contentType = Some("text/plain") - ) - ) + // Check if we can talk to the registration database + private lazy val dbConnectivityCheck = registrar.dbHealthCheck() + + def healthCheck: Action[AnyContent] = Action.async { + dbConnectivityCheck + .map(_ => { + // This forces Play to close the connection rather than allowing + // keep-alive (because the content length is unknown) + Ok.sendEntity( + HttpEntity.Streamed( + data = Source(Array(ByteString("Good")).toVector), + contentLength = None, + contentType = Some("text/plain") + ) + ) + }) + .recover { _ => { + logger.error("Failing to connect to database") + InternalServerError + } } } def newsstandRegister: Action[LegacyNewsstandRegistration] = diff --git a/registration/app/registration/services/DatabaseRegistrar.scala b/registration/app/registration/services/DatabaseRegistrar.scala index 7baa46061..b08ba7821 100644 --- a/registration/app/registration/services/DatabaseRegistrar.scala +++ b/registration/app/registration/services/DatabaseRegistrar.scala @@ -8,7 +8,6 @@ import fs2.Stream import com.amazonaws.services.cloudwatch.model.StandardUnit import metrics.{MetricDataPoint, Metrics} - import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} @@ -16,6 +15,12 @@ class DatabaseRegistrar( registrationService: RegistrationService[IO, Stream], metrics: Metrics )(implicit ec: ExecutionContext) extends NotificationRegistrar { + def dbHealthCheck(): Future[List[TopicCount]] = { + val simpleSelect = registrationService.simpleSelectForHealthCheck() + simpleSelect.compile.toList.unsafeToFuture() + } + + override val providerIdentifier: String = "DatabaseRegistrar" override def register(deviceToken: DeviceToken, registration: Registration): RegistrarResponse[RegistrationResponse] = { diff --git a/registration/app/registration/services/NotificationRegistrar.scala b/registration/app/registration/services/NotificationRegistrar.scala index c7e1db74c..d6d713e6b 100644 --- a/registration/app/registration/services/NotificationRegistrar.scala +++ b/registration/app/registration/services/NotificationRegistrar.scala @@ -22,6 +22,7 @@ trait NotificationRegistrar { import NotificationRegistrar.RegistrarResponse val providerIdentifier: String def register(deviceToken: DeviceToken, registration: Registration): RegistrarResponse[RegistrationResponse] + def dbHealthCheck(): Future[List[TopicCount]] } object NotificationRegistrar { diff --git a/registration/conf/registration.yaml b/registration/conf/registration.yaml index 360a5e0ac..df3df7393 100644 --- a/registration/conf/registration.yaml +++ b/registration/conf/registration.yaml @@ -38,9 +38,12 @@ Parameters: - CODE - PROD Description: Environment name - VPCSecurityGroup: - Type: AWS::EC2::SecurityGroup::Id - Description: The default security group of the VPC + DatabaseAccessSecurityGroup: + Type: AWS::SSM::Parameter::Value + Description: The security group that allows access to the database + AllowedValues: + - /CODE/mobile-notifications/registrations-db/postgres-access-security-group + - /PROD/mobile-notifications/registrations-db/postgres-access-security-group AlarmTopic: Type: String Description: The ARN of the SNS topic to send all the cloudwatch alarms to @@ -291,7 +294,7 @@ Resources: InstanceType: !FindInMap [StageVariables, !Ref Stage, InstanceType] SecurityGroups: - !Ref InstanceSecurityGroup - - !Ref VPCSecurityGroup + - !Ref DatabaseAccessSecurityGroup MetadataOptions: HttpTokens: required UserData: diff --git a/registration/conf/riff-raff.yaml b/registration/conf/riff-raff.yaml index 46d41a04a..71e6e4770 100644 --- a/registration/conf/riff-raff.yaml +++ b/registration/conf/riff-raff.yaml @@ -16,8 +16,10 @@ deployments: templateStageParameters: CODE: LoggingStreamName: /account/services/logging.stream.name.code + DatabaseAccessSecurityGroup: /CODE/mobile-notifications/registrations-db/postgres-access-security-group PROD: LoggingStreamName: /account/services/logging.stream.name + DatabaseAccessSecurityGroup: /PROD/mobile-notifications/registrations-db/postgres-access-security-group registration: type: autoscaling parameters: diff --git a/registration/test/registration/controllers/RegistrationsFixtures.scala b/registration/test/registration/controllers/RegistrationsFixtures.scala index c5bdf3f3e..7852d67ed 100644 --- a/registration/test/registration/controllers/RegistrationsFixtures.scala +++ b/registration/test/registration/controllers/RegistrationsFixtures.scala @@ -35,6 +35,9 @@ trait DelayedRegistrationsBase extends RegistrationsBase { provider = Unknown )) } + + // Not needed for tests + override def dbHealthCheck(): Future[List[TopicCount]] = Future.successful(List.empty) } } @@ -71,6 +74,8 @@ trait RegistrationsBase extends WithPlayApp with RegistrationsJson { )) } + // Not needed for tests + override def dbHealthCheck(): Future[List[TopicCount]] = Future.successful(List.empty) }