diff --git a/app/api/dao/user.py b/app/api/dao/user.py index c38343917..a1af42068 100644 --- a/app/api/dao/user.py +++ b/app/api/dao/user.py @@ -4,6 +4,7 @@ from typing import Dict from flask_restx import marshal from sqlalchemy import func +from sqlalchemy.exc import IntegrityError from app import messages from app.api.email_utils import confirm_token @@ -44,30 +45,27 @@ def create_user(data: Dict[str, str]): password = data["password"] email = data["email"] terms_and_conditions_checked = data["terms_and_conditions_checked"] - - existing_user = UserModel.find_by_username(data["username"]) - if existing_user: - return ( - messages.USER_USES_A_USERNAME_THAT_ALREADY_EXISTS, - HTTPStatus.CONFLICT, - ) - else: - existing_user = UserModel.find_by_email(data["email"]) - if existing_user: - return ( - messages.USER_USES_AN_EMAIL_ID_THAT_ALREADY_EXISTS, - HTTPStatus.CONFLICT, - ) - user = UserModel(name, username, password, email, terms_and_conditions_checked) + if "need_mentoring" in data: user.need_mentoring = data["need_mentoring"] if "available_to_mentor" in data: user.available_to_mentor = data["available_to_mentor"] - user.save_to_db() - + try: + user.save_to_db() + except IntegrityError as e: + if e.args.__str__().__contains__("username"): + return ( + messages.USER_USES_A_USERNAME_THAT_ALREADY_EXISTS, + HTTPStatus.CONFLICT, + ) + if e.args.__str__().__contains__("email"): + return ( + messages.USER_USES_AN_EMAIL_ID_THAT_ALREADY_EXISTS, + HTTPStatus.CONFLICT, + ) return messages.USER_WAS_CREATED_SUCCESSFULLY, HTTPStatus.CREATED @staticmethod diff --git a/app/database/models/user.py b/app/database/models/user.py index 075b7eb07..32b6f0d76 100644 --- a/app/database/models/user.py +++ b/app/database/models/user.py @@ -1,6 +1,7 @@ from werkzeug.security import generate_password_hash, check_password_hash import time from app.database.sqlalchemy_extension import db +from sqlalchemy.exc import IntegrityError class UserModel(db.Model): @@ -148,8 +149,12 @@ def check_password(self, password_plain_text: str) -> bool: def save_to_db(self) -> None: """Adds a user to the database.""" - db.session.add(self) - db.session.commit() + try: + db.session.add(self) + db.session.commit() + except IntegrityError as e: + db.session.rollback() + raise e def delete_from_db(self) -> None: """Deletes a user from the database."""