Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions cdk/lib/__snapshots__/registration.test.ts.snap
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ exports[`The Registration stack matches the snapshot for CODE 1`] = `
"Description": "ACM Certificate for app use",
"Type": "String",
},
"DatabaseAccessSecurityGroup": {
"AllowedValues": [
"/CODE/mobile-notifications/registrations-db/postgres-access-security-group",
"/PROD/mobile-notifications/registrations-db/postgres-access-security-group",
],
"Description": "The security group that allows access to the database",
"Type": "AWS::SSM::Parameter::Value<AWS::EC2::SecurityGroup::Id>",
},
"DistBucket": {
"Description": "The name of the s3 bucket containing the server artifact",
"Type": "String",
Expand Down Expand Up @@ -112,10 +120,6 @@ exports[`The Registration stack matches the snapshot for CODE 1`] = `
"Description": "Environment name",
"Type": "String",
},
"VPCSecurityGroup": {
"Description": "The default security group of the VPC",
"Type": "AWS::EC2::SecurityGroup::Id",
},
"VpcId": {
"Description": "The VPC",
"Type": "AWS::EC2::VPC::Id",
Expand Down Expand Up @@ -795,7 +799,7 @@ exports[`The Registration stack matches the snapshot for CODE 1`] = `
"Ref": "InstanceSecurityGroup",
},
{
"Ref": "VPCSecurityGroup",
"Ref": "DatabaseAccessSecurityGroup",
},
],
"UserData": {
Expand Down Expand Up @@ -917,6 +921,14 @@ exports[`The Registration stack matches the snapshot for PROD 1`] = `
"Description": "ACM Certificate for app use",
"Type": "String",
},
"DatabaseAccessSecurityGroup": {
"AllowedValues": [
"/CODE/mobile-notifications/registrations-db/postgres-access-security-group",
"/PROD/mobile-notifications/registrations-db/postgres-access-security-group",
],
"Description": "The security group that allows access to the database",
"Type": "AWS::SSM::Parameter::Value<AWS::EC2::SecurityGroup::Id>",
},
"DistBucket": {
"Description": "The name of the s3 bucket containing the server artifact",
"Type": "String",
Expand Down Expand Up @@ -965,10 +977,6 @@ exports[`The Registration stack matches the snapshot for PROD 1`] = `
"Description": "Environment name",
"Type": "String",
},
"VPCSecurityGroup": {
"Description": "The default security group of the VPC",
"Type": "AWS::EC2::SecurityGroup::Id",
},
"VpcId": {
"Description": "The VPC",
"Type": "AWS::EC2::VPC::Id",
Expand Down Expand Up @@ -1648,7 +1656,7 @@ exports[`The Registration stack matches the snapshot for PROD 1`] = `
"Ref": "InstanceSecurityGroup",
},
{
"Ref": "VPCSecurityGroup",
"Ref": "DatabaseAccessSecurityGroup",
},
],
"UserData": {
Expand Down
1 change: 1 addition & 0 deletions common/src/main/scala/db/RegistrationRepository.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ trait RegistrationRepository[F[_], S[_[_], _]] {
def delete(sub: Registration): ConnectionIO[Int]
def deleteByToken(token: String): ConnectionIO[Int]
def deleteByDate(olderThanDays: Int): ConnectionIO[Int]
def simpleSelectForHealthCheck(): S[F, TopicCount]
def topicCounts(countsThreshold: Int): S[F, TopicCount]
}
2 changes: 2 additions & 0 deletions common/src/main/scala/db/RegistrationService.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class RegistrationService[F[_]: Async, S[_[_], _]](repository: RegistrationRepos
}

def topicCounts(countThreshold: Int): S[F, TopicCount] = repository.topicCounts(countThreshold)

def simpleSelectForHealthCheck(): S[F, TopicCount] = repository.simpleSelectForHealthCheck()
}


Expand Down
17 changes: 17 additions & 0 deletions common/src/main/scala/db/SqlRegistrationRepository.scala
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,23 @@ class SqlRegistrationRepository[F[_]: Async](xa: Transactor[F])
.transact(xa)
}

/**
* Used to verify that the DB connection is healthy.
* We just select one topic and a constant value as this is sufficient to check connectivity (we don't care about what data is returned).
*/
override def simpleSelectForHealthCheck(): Stream[F, TopicCount] = {
logger.info("Performing a query to check DB connectivity")
sql"""
SELECT topic
, 1
FROM registrations
LIMIT 1
"""
.query[TopicCount]
.stream
.transact(xa)
}

override def findTokens(topics: NonEmptyList[String], shardRange: Option[Range]): Stream[F, HarvestedToken] = {
val queryStatement = (sql"""
SELECT token, platform, buildTier
Expand Down
29 changes: 29 additions & 0 deletions registration-db.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ Resources:
GroupName: !Sub registrations-db-${Stage}
GroupDescription: !Sub Security group allowing VPC only traffic
SecurityGroupIngress:
# Join PostgresAccessSecurityGroup to allow access to postgres to the registration db
- SourceSecurityGroupId: !Ref PostgresAccessSecurityGroup
FromPort: 5432
IpProtocol: tcp
ToPort: 5432

# TODO Remove this rule once all applications are using the PostgresAccessSecurityGroup
- SourceSecurityGroupId: !Ref VPCSecurityGroup
FromPort: 5432
IpProtocol: tcp
Expand All @@ -66,6 +73,27 @@ Resources:
Value: registrations-db
VpcId: !Ref VpcId

PostgresAccessSecurityGroup:
Type: AWS::EC2::SecurityGroup
Properties:
GroupName: !Sub registrations-db-${Stage}-access
GroupDescription: !Sub Security group allowing access to the registrations db
Tags:
- Key: Stage
Value: !Ref Stage
- Key: Stack
Value: mobile-notifications
- Key: App
Value: registrations-db
VpcId: !Ref VpcId

PostgresAccessSecurityGroupName:
Type: AWS::SSM::Parameter
Properties:
Name: !Sub /${Stage}/mobile-notifications/registrations-db/postgres-access-security-group
Type: String
Value: !Ref PostgresAccessSecurityGroup

PrivateRegistrationPostgres13DB:
Type: AWS::RDS::DBInstance
DependsOn: PrivateRegistrationDBSubnetGroup
Expand Down Expand Up @@ -179,3 +207,4 @@ Resources:
Outputs:
DBUrl:
Value: !GetAtt PrivateRegistrationPostgres13DB.Endpoint.Address

30 changes: 20 additions & 10 deletions registration/app/registration/controllers/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,26 @@ final class Main(

private val logger: Logger = LoggerFactory.getLogger(this.getClass)

def healthCheck: Action[AnyContent] = Action {
// This forces Play to close the connection rather than allowing
// keep-alive (because the content length is unknown)
Ok.sendEntity(
HttpEntity.Streamed(
data = Source(Array(ByteString("Good")).toVector),
contentLength = None,
contentType = Some("text/plain")
)
)
// Check if we can talk to the registration database
private lazy val dbConnectivityCheck = registrar.dbHealthCheck()

def healthCheck: Action[AnyContent] = Action.async {
dbConnectivityCheck
.map(_ => {
// This forces Play to close the connection rather than allowing
// keep-alive (because the content length is unknown)
Ok.sendEntity(
HttpEntity.Streamed(
data = Source(Array(ByteString("Good")).toVector),
contentLength = None,
contentType = Some("text/plain")
)
)
})
.recover { _ => {
logger.error("Failing to connect to database")
InternalServerError
} }
}

def newsstandRegister: Action[LegacyNewsstandRegistration] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,19 @@ import fs2.Stream
import com.amazonaws.services.cloudwatch.model.StandardUnit
import metrics.{MetricDataPoint, Metrics}


import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}

class DatabaseRegistrar(
registrationService: RegistrationService[IO, Stream],
metrics: Metrics
)(implicit ec: ExecutionContext) extends NotificationRegistrar {
def dbHealthCheck(): Future[List[TopicCount]] = {
val simpleSelect = registrationService.simpleSelectForHealthCheck()
simpleSelect.compile.toList.unsafeToFuture()
}


override val providerIdentifier: String = "DatabaseRegistrar"

override def register(deviceToken: DeviceToken, registration: Registration): RegistrarResponse[RegistrationResponse] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ trait NotificationRegistrar {
import NotificationRegistrar.RegistrarResponse
val providerIdentifier: String
def register(deviceToken: DeviceToken, registration: Registration): RegistrarResponse[RegistrationResponse]
def dbHealthCheck(): Future[List[TopicCount]]
}

object NotificationRegistrar {
Expand Down
11 changes: 7 additions & 4 deletions registration/conf/registration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,12 @@ Parameters:
- CODE
- PROD
Description: Environment name
VPCSecurityGroup:
Type: AWS::EC2::SecurityGroup::Id
Description: The default security group of the VPC
DatabaseAccessSecurityGroup:
Type: AWS::SSM::Parameter::Value<AWS::EC2::SecurityGroup::Id>
Description: The security group that allows access to the database
AllowedValues:
- /CODE/mobile-notifications/registrations-db/postgres-access-security-group
- /PROD/mobile-notifications/registrations-db/postgres-access-security-group
AlarmTopic:
Type: String
Description: The ARN of the SNS topic to send all the cloudwatch alarms to
Expand Down Expand Up @@ -291,7 +294,7 @@ Resources:
InstanceType: !FindInMap [StageVariables, !Ref Stage, InstanceType]
SecurityGroups:
- !Ref InstanceSecurityGroup
- !Ref VPCSecurityGroup
- !Ref DatabaseAccessSecurityGroup
MetadataOptions:
HttpTokens: required
UserData:
Expand Down
2 changes: 2 additions & 0 deletions registration/conf/riff-raff.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ deployments:
templateStageParameters:
CODE:
LoggingStreamName: /account/services/logging.stream.name.code
DatabaseAccessSecurityGroup: /CODE/mobile-notifications/registrations-db/postgres-access-security-group
PROD:
LoggingStreamName: /account/services/logging.stream.name
DatabaseAccessSecurityGroup: /PROD/mobile-notifications/registrations-db/postgres-access-security-group
registration:
type: autoscaling
parameters:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ trait DelayedRegistrationsBase extends RegistrationsBase {
provider = Unknown
))
}

// Not needed for tests
override def dbHealthCheck(): Future[List[TopicCount]] = Future.successful(List.empty)
}
}

Expand Down Expand Up @@ -71,6 +74,8 @@ trait RegistrationsBase extends WithPlayApp with RegistrationsJson {
))
}

// Not needed for tests
override def dbHealthCheck(): Future[List[TopicCount]] = Future.successful(List.empty)
}


Expand Down