Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand FastAPI + SQLAlchemy example #8

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions fastapi-sqlalchemy/alembic/versions/2b3c8830abf6_add_auth_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Add auth model

Revision ID: 2b3c8830abf6
Revises: bea5e58f3328
Create Date: 2021-09-10 16:22:57.620951

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "2b3c8830abf6"
down_revision = "bea5e58f3328"
branch_labels = None
depends_on = None


def upgrade():
op.create_table(
"users",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("email", sa.String(), nullable=False),
sa.Column("password_hash", sa.String(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
op.create_index(op.f("ix_users_id"), "users", ["id"], unique=False)


def downgrade():
op.drop_index(op.f("ix_users_id"), table_name="users")
op.drop_index(op.f("ix_users_email"), table_name="users")
op.drop_table("users")
16 changes: 16 additions & 0 deletions fastapi-sqlalchemy/api/definitions/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import strawberry

from main.models import User as UserModel


@strawberry.type
class User:
id: int
email: str

@classmethod
def from_instance(cls, instance: UserModel):
return cls(
id=instance.id,
email=instance.email,
)
13 changes: 13 additions & 0 deletions fastapi-sqlalchemy/api/mutation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from strawberry.tools import create_type

from .mutations.register_user import register_user
from .mutations.login_user import login_user


Mutation = create_type(
"Mutation",
[
register_user,
login_user,
],
)
Empty file.
44 changes: 44 additions & 0 deletions fastapi-sqlalchemy/api/mutations/login_user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import re
import strawberry

from main.auth import login
from main.models import get_user_by_email

from ..definitions.user import User


@strawberry.type
class LoginUserSuccess:
user: User


@strawberry.type
class LoginUserError:
error_message: str


LoginUserResponse = strawberry.union(
"LoginUserResponse", types=(LoginUserSuccess, LoginUserError)
)


@strawberry.mutation
def login_user(info, email: str, password: str) -> LoginUserResponse:
if not re.fullmatch(r"[^@]+@[^@]+\.[^@]+", email):
return LoginUserError(error_message="Invalid email")

db = info.context["db"]

user = get_user_by_email(db, email)
if not user:
return LoginUserError(error_message="User not found")

if not user.check_password(password):
return LoginUserError(error_message="Invalid password")

# Login user
login(info.context["request"], user)

return LoginUserSuccess(
user=User.from_instance(user),
)
60 changes: 60 additions & 0 deletions fastapi-sqlalchemy/api/mutations/register_user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import re
import strawberry

from main.auth import login
from main.models import User as UserModel, get_user_by_email

from ..definitions.user import User


@strawberry.input
class RegisterUserInput:
email: str
password: str


@strawberry.type
class RegisterUserSuccess:
user: User


@strawberry.type
class RegisterUserError:
error_message: str


RegisterUserResponse = strawberry.union(
"RegisterUserResponse", types=(RegisterUserSuccess, RegisterUserError)
)


@strawberry.mutation
def register_user(info, data: RegisterUserInput) -> RegisterUserResponse:
email = data.email
password = data.password

if not re.fullmatch(r"[^@]+@[^@]+\.[^@]+", email):
return RegisterUserError(error_message="Invalid email")

if len(password) < 4:
return RegisterUserError(error_message="Password too short")

db = info.context["db"]

existing_user = get_user_by_email(db, email)
if existing_user:
return RegisterUserError(error_message="User already exists")

user = UserModel(
email=email,
)
user.set_password(password)
db.add(user)
db.commit()

# Login user
login(info.context["request"], user)

return RegisterUserSuccess(
user=User.from_instance(user),
)
15 changes: 13 additions & 2 deletions fastapi-sqlalchemy/api/schema.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import List
from typing import List, Optional

import strawberry
from strawberry.extensions import Extension

from main.models import get_movies
from main.database import SessionLocal

from .mutation import Mutation
from .definitions.movie import Movie
from .definitions.user import User


class SQLAlchemySession(Extension):
Expand All @@ -25,5 +27,14 @@ def top_rated_movies(self, info, limit: int = 250) -> List[Movie]:
movies = get_movies(db, limit=limit)
return [Movie.from_instance(movie) for movie in movies]

@strawberry.field
def current_user(self, info) -> Optional[User]:
request = info.context["request"]

if request.user.is_authenticated:
return User.from_instance(request.user)

return None


schema = strawberry.Schema(Query, extensions=[SQLAlchemySession])
schema = strawberry.Schema(Query, mutation=Mutation, extensions=[SQLAlchemySession])
13 changes: 12 additions & 1 deletion fastapi-sqlalchemy/main/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
from fastapi import FastAPI
from strawberry.asgi import GraphQL
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.sessions import SessionMiddleware

from api.schema import schema

from .middleware import SessionBackend


middleware = [
Middleware(SessionMiddleware, secret_key="supersecretkey"),
Middleware(AuthenticationMiddleware, backend=SessionBackend()),
]

graphql_app = GraphQL(schema)

app = FastAPI()
app = FastAPI(middleware=middleware)
app.mount("/graphql", graphql_app)
33 changes: 33 additions & 0 deletions fastapi-sqlalchemy/main/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import Optional
from sqlalchemy import select
from sqlalchemy.exc import NoResultFound

from main.models import User


SESSION_KEY = "_auth_user_id"


def login(request, user: User):
session = request.session
session[SESSION_KEY] = user.id


def logout(request):
session = request.session
del session[SESSION_KEY]


def get_user(db, request) -> Optional[User]:
session = request.session
if SESSION_KEY not in session:
return None

try:
user = db.execute(
select(User).filter_by(id=request.session[SESSION_KEY])
).scalar_one()
except NoResultFound:
return None

return user
29 changes: 29 additions & 0 deletions fastapi-sqlalchemy/main/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from starlette.authentication import AuthenticationBackend, BaseUser, AuthCredentials

from main.auth import get_user
from main.database import SessionLocal
from main.models import User


class ProxyDBUser(BaseUser):
def __init__(self, instance: User):
self.instance = instance

# Proxy attributes from instance
def __getattr__(self, name):
return getattr(self.instance, name)

@property
def is_authenticated(self) -> bool:
return True


class SessionBackend(AuthenticationBackend):
async def authenticate(self, request):
db = SessionLocal()
user = get_user(db, request)

if not user:
return

return AuthCredentials(["authenticated"]), ProxyDBUser(user)
25 changes: 25 additions & 0 deletions fastapi-sqlalchemy/main/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Optional
from sqlalchemy import Column, Integer, String, Float, ForeignKey, select
from sqlalchemy.orm import relationship, joinedload
from sqlalchemy.orm import Session
from sqlalchemy.exc import NoResultFound
from passlib.hash import pbkdf2_sha256

from .database import Base

Expand Down Expand Up @@ -37,3 +40,25 @@ def get_movies(db: Session, limit: int = 250):

result = db.execute(query).unique()
return result.scalars()


class User(Base):
__tablename__ = "users"

id: int = Column(Integer, primary_key=True, index=True, nullable=False)
email: str = Column(String, unique=True, index=True, nullable=False)
password_hash: str = Column(String, nullable=False)

def set_password(self, password: str):
self.password_hash = pbkdf2_sha256.hash(password)

def check_password(self, password: str):
return pbkdf2_sha256.verify(password, self.password_hash)


def get_user_by_email(db, email: str) -> Optional[User]:
try:
user = db.execute(select(User).filter_by(email=email)).scalar_one()
return user
except NoResultFound:
return None
Loading