Skip to content

Avoid returning expired tokens #217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions server/src/main/resources/application.conf
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,10 @@ notificationsTask {
# the NotificationsTask will run every time the period is fullfilled
interval = 1 minutes
interval = ${?NOTIFICATIONS_TASK_INTERVAL}
}

expiredTokensTask {
# the expiredTokensTask will run every time the period is fullfilled
interval = 5 minutes
interval = ${?EXPIRED_TOKENS_TASK_INTERVAL}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import net.wiringbits.common.models.Password
import net.wiringbits.config.{JwtConfig, UserTokensConfig}
import net.wiringbits.repositories.{UserTokensRepository, UsersRepository}
import net.wiringbits.util.{EmailMessage, JwtUtils, TokensHelper}
import net.wiringbits.validations.ValidateUserToken
import org.mindrot.jbcrypt.BCrypt

import java.time.Clock
Expand All @@ -30,7 +29,6 @@ class ResetPasswordAction @Inject() (
// When the token valid
tokenMaybe <- userTokensRepository.find(userId, hmacToken)
token = tokenMaybe.getOrElse(throw new RuntimeException(s"Token for user $userId wasn't found"))
_ = ValidateUserToken(token)

// We trigger the reset password flow
userMaybe <- usersRepository.find(userId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ import net.wiringbits.api.models.VerifyEmail
import net.wiringbits.config.UserTokensConfig
import net.wiringbits.repositories.{UserTokensRepository, UsersRepository}
import net.wiringbits.util.{EmailMessage, TokensHelper}
import net.wiringbits.validations.{ValidateUserIsNotVerified, ValidateUserToken}
import net.wiringbits.validations.ValidateUserIsNotVerified

import java.time.Clock
import java.util.UUID
import javax.inject.Inject
import scala.concurrent.{ExecutionContext, Future}
Expand All @@ -16,8 +15,7 @@ class VerifyUserEmailAction @Inject() (
userTokensRepository: UserTokensRepository,
userTokensConfig: UserTokensConfig
)(implicit
ec: ExecutionContext,
clock: Clock
ec: ExecutionContext
) {
def apply(userId: UUID, token: UUID): Future[VerifyEmail.Response] = for {
// when the user is not verified
Expand All @@ -29,7 +27,6 @@ class VerifyUserEmailAction @Inject() (
hmacToken = TokensHelper.doHMACSHA1(token.toString.getBytes, userTokensConfig.hmacSecret)
tokenMaybe <- userTokensRepository.find(userId, hmacToken)
userToken = tokenMaybe.getOrElse(throw new RuntimeException(s"Token for user $userId wasn't found"))
_ = ValidateUserToken(userToken)

// then, the user is marked as verified
emailMessage = EmailMessage.confirm(user.name)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package net.wiringbits.actions.internal

import net.wiringbits.repositories.UserTokensRepository
import net.wiringbits.repositories.models.UserToken

import javax.inject.Inject
import scala.concurrent.Future

class DeleteExpiredTokenAction @Inject() (userTokensRepository: UserTokensRepository) {
def apply(token: UserToken): Future[Unit] = {
userTokensRepository.delete(tokenId = token.id, userId = token.userId)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package net.wiringbits.actions.internal

import net.wiringbits.repositories.UserTokensRepository
import net.wiringbits.repositories.models.UserToken

import javax.inject.Inject
import scala.concurrent.Future

class GetExpiredTokensAction @Inject() (
userTokensRepository: UserTokensRepository
) {
def apply(): Future[List[UserToken]] = {
userTokensRepository.getExpiredTokens
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package net.wiringbits.config

import play.api.Configuration

import scala.concurrent.duration.FiniteDuration

case class ExpiredTokensConfig(interval: FiniteDuration) {
override def toString: String = {
s"ExpiredTokensConfig(interval = $interval)"
}
}

object ExpiredTokensConfig {
def apply(config: Configuration): ExpiredTokensConfig = {
val interval = config.get[FiniteDuration]("interval")
ExpiredTokensConfig(interval)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,12 @@ class ConfigModule extends AbstractModule {
logger.info(s"Config loaded: $config")
config
}

@Provides
@Singleton
def expiredTokensConfig(global: Configuration): ExpiredTokensConfig = {
val config = ExpiredTokensConfig(global.get[Configuration]("expiredTokensTask"))
logger.info(s"Config loaded: $config")
config
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package net.wiringbits.modules

import com.google.inject.AbstractModule
import net.wiringbits.tasks.NotificationsTask
import net.wiringbits.tasks.{ExpiredTokensTask, NotificationsTask}

class TasksModule extends AbstractModule {

override def configure(): Unit = {
bind(classOf[NotificationsTask]).asEagerSingleton()
bind(classOf[ExpiredTokensTask]).asEagerSingleton()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@ import net.wiringbits.repositories.daos.UserTokensDAO
import net.wiringbits.repositories.models.UserToken
import play.api.db.Database

import java.time.Clock
import java.util.UUID
import javax.inject.Inject
import scala.concurrent.Future

class UserTokensRepository @Inject() (
database: Database
)(implicit
ec: DatabaseExecutionContext
ec: DatabaseExecutionContext,
clock: Clock
) {

def create(request: UserToken.Create): Future[Unit] = Future {
Expand All @@ -33,6 +35,12 @@ class UserTokensRepository @Inject() (
}
}

def getExpiredTokens: Future[List[UserToken]] = Future {
database.withConnection { implicit conn =>
UserTokensDAO.getExpiredTokens()
}
}

def delete(tokenId: UUID, userId: UUID): Future[Unit] = Future {
database.withConnection { implicit conn =>
UserTokensDAO.delete(tokenId, userId: UUID)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import anorm.SqlStringInterpolation
import net.wiringbits.repositories.models.UserToken

import java.sql.Connection
import java.time.Clock
import java.util.UUID

object UserTokensDAO {
Expand Down Expand Up @@ -44,6 +45,15 @@ object UserTokensDAO {
""".as(tokenParser.*)
}

def getExpiredTokens()(implicit conn: Connection, clock: Clock): List[UserToken] = {
SQL"""
SELECT user_token_id, token, token_type, created_at, expires_at, user_id
FROM user_tokens
WHERE expires_at > ${clock.instant()}
ORDER BY created_at DESC, user_token_id
""".as(tokenParser.*)
}

def delete(tokenId: UUID, userId: UUID)(implicit conn: Connection): Unit = {
val _ = SQL"""
DELETE FROM user_tokens
Expand Down
49 changes: 49 additions & 0 deletions server/src/main/scala/net/wiringbits/tasks/ExpiredTokensTask.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package net.wiringbits.tasks

import akka.actor.ActorSystem
import net.wiringbits.actions.internal.{DeleteExpiredTokenAction, GetExpiredTokensAction}
import net.wiringbits.config.ExpiredTokensConfig
import org.slf4j.LoggerFactory

import javax.inject.Inject
import scala.concurrent.ExecutionContext
import scala.util.{Failure, Success}

class ExpiredTokensTask @Inject() (
expiredTokensConfig: ExpiredTokensConfig,
getExpiredTokens: GetExpiredTokensAction,
deleteExpiredTokenAction: DeleteExpiredTokenAction
)(implicit
ec: ExecutionContext,
actorSystem: ActorSystem
) {
val logger = LoggerFactory.getLogger(this.getClass)

logger.info("Starting the expired tokens task")
actorSystem.scheduler.scheduleOnce(
expiredTokensConfig.interval
) {
run()
}

def run(): Unit = {
getExpiredTokens()
.onComplete {
case Failure(exception) => logger.error("Failed to get expired tokens", exception)
case Success(expiredTokens) =>
val message = s"There's ${expiredTokens.size} expired tokens"
if (expiredTokens.isEmpty) logger.trace(message)
else logger.info(message)
expiredTokens.foreach { expiredToken =>
deleteExpiredTokenAction(expiredToken).onComplete {
case Failure(ex) =>
logger.info(s"There was an error trying to send notification with id = ${expiredToken.id}", ex)
case Success(_) => ()
}
}
}

actorSystem.scheduler.scheduleOnce(expiredTokensConfig.interval) { run() }
()
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ class UsersControllerSpec extends PlayPostgresSpec with LoginUtils {
.verifyEmail(VerifyEmail.Request(UserToken(user.id, verificationToken)))
.expectError

error must be("Token is expired")
error must be(s"Token for user ${user.id} wasn't found")
}

"login after successful email confirmation" in withApiClient { client =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ trait RepositorySpec extends AnyWordSpec with PostgresSpec {

def withRepositories[T](clock: Clock = Clock.systemUTC)(runTest: RepositoryComponents => T): T = withDatabase { db =>
val users = new UsersRepository(db, UserTokensConfig(1.hour, 1.hour, "secret"))(Executors.databaseEC, clock)
val userTokens = new UserTokensRepository(db)(Executors.databaseEC)
val userTokens = new UserTokensRepository(db)(Executors.databaseEC, clock)
val userNotifications = new UserNotificationsRepository(db)(Executors.databaseEC, clock)
val userLogs = new UserLogsRepository(db)(Executors.databaseEC)
val components =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,20 @@ package net.wiringbits.repositories
import net.wiringbits.common.models.{Email, Name}
import net.wiringbits.core.RepositorySpec
import net.wiringbits.repositories.models.{User, UserToken, UserTokenType}
import org.mockito.MockitoSugar.{mock, when}
import org.scalatest.OptionValues.convertOptionToValuable
import org.scalatest.concurrent.ScalaFutures._
import org.scalatest.matchers.must.Matchers._

import java.time.Instant
import java.time.{Clock, Instant}
import java.time.temporal.ChronoUnit
import java.util.UUID

class UserTokensRepositorySpec extends RepositorySpec {

private val clock = mock[Clock]
when(clock.instant()).thenAnswer(Instant.now())

"create" should {
"work" in withRepositories() { repositories =>
val request = User.CreateUser(
Expand Down Expand Up @@ -87,6 +92,33 @@ class UserTokensRepositorySpec extends RepositorySpec {
val response = repositories.userTokens.find(UUID.randomUUID()).futureValue
response.isEmpty must be(true)
}

"return no results if tokens are expired" in withRepositories(clock) { repositories =>
val request = User.CreateUser(
id = UUID.randomUUID(),
email = Email.trusted("[email protected]"),
name = Name.trusted("Sample"),
hashedPassword = "password",
verifyEmailToken = "token"
)
repositories.users.create(request).futureValue

val tokenRequest =
UserToken.Create(
id = UUID.randomUUID(),
token = "test",
tokenType = UserTokenType.ResetPassword,
createdAt = Instant.now(),
expiresAt = Instant.now.plus(1, ChronoUnit.HOURS),
userId = request.id
)
repositories.userTokens.create(tokenRequest).futureValue

when(clock.instant()).thenAnswer(Instant.now().plus(2, ChronoUnit.HOURS))

val response = repositories.userTokens.find(request.id).futureValue
response.isEmpty must be(true)
}
}

"find(userId, token)" should {
Expand Down Expand Up @@ -119,6 +151,69 @@ class UserTokensRepositorySpec extends RepositorySpec {
val response = repositories.userTokens.find(UUID.randomUUID(), "test").futureValue
response.isEmpty must be(true)
}

"return no results if tokens are expired" in withRepositories(clock) { repositories =>
val request = User.CreateUser(
id = UUID.randomUUID(),
email = Email.trusted("[email protected]"),
name = Name.trusted("Sample"),
hashedPassword = "password",
verifyEmailToken = "token"
)
repositories.users.create(request).futureValue

val tokenRequest =
UserToken.Create(
id = UUID.randomUUID(),
token = "test",
tokenType = UserTokenType.ResetPassword,
createdAt = Instant.now(),
expiresAt = Instant.now.plus(1, ChronoUnit.HOURS),
userId = request.id
)
repositories.userTokens.create(tokenRequest).futureValue

when(clock.instant()).thenAnswer(Instant.now().plus(2, ChronoUnit.HOURS))

val response = repositories.userTokens.find(request.id, tokenRequest.token).futureValue
response.isEmpty must be(true)
}
}

"getExpiredTokens" should {
"return expired tokens" in withRepositories() { repositories =>
val request = User.CreateUser(
id = UUID.randomUUID(),
email = Email.trusted("[email protected]"),
name = Name.trusted("Sample"),
hashedPassword = "password",
verifyEmailToken = "token"
)
repositories.users.create(request).futureValue

val tokenRequest =
UserToken.Create(
id = UUID.randomUUID(),
token = "test",
tokenType = UserTokenType.ResetPassword,
createdAt = Instant.now(),
expiresAt = Instant.now.plus(1, ChronoUnit.HOURS),
userId = request.id
)
repositories.userTokens.create(tokenRequest).futureValue

when(clock.instant()).thenAnswer(Instant.now().plus(2, ChronoUnit.HOURS))

val expiredUserTokens = repositories.userTokens.getExpiredTokens.futureValue

// two tokens: creating an account and token created using tokenRequest
expiredUserTokens.length must be(2)
}

"return no results" in withRepositories() { repositories =>
val response = repositories.userTokens.getExpiredTokens.futureValue
response.isEmpty must be(true)
}
}

"delete" should {
Expand Down