Skip to content

Commit

Permalink
Merge pull request #6 from matinone/users
Browse files Browse the repository at this point in the history
User authentication
  • Loading branch information
matinone authored Dec 28, 2023
2 parents d27d7cb + c4c49c0 commit 0e9cfe2
Show file tree
Hide file tree
Showing 24 changed files with 830 additions and 46 deletions.
1 change: 0 additions & 1 deletion api_design.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
* email (String, Unique): User's email address.
* password_hash (String): Hashed password for security.
* created_at (DateTime): Date and time of registration.
* last_login (DateTime): Date and time of last login.

## Quizzes Table
* Table Name: **quizzes**
Expand Down
3 changes: 2 additions & 1 deletion app/api/api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from fastapi import APIRouter

from app.api.endpoints import question, quiz
from app.api.endpoints import login, question, quiz

api_router = APIRouter()
api_router.include_router(quiz.router)
api_router.include_router(question.router)
api_router.include_router(login.router)
26 changes: 26 additions & 0 deletions app/api/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Annotated

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer

import app.models as models
from app.core.security import decode_token
from app.core.settings import Settings, get_settings
from app.models.database import AsyncSessionDep

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/tokens")


async def get_current_user(
db: AsyncSessionDep,
token: Annotated[str, Depends(oauth2_scheme)],
settings: Annotated[Settings, Depends(get_settings)],
) -> models.User:
token_data = decode_token(token=token, settings=settings)
user = await models.User.get(db=db, id=int(token_data.sub))
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)

return user
69 changes: 69 additions & 0 deletions app/api/endpoints/login.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import Annotated, Any

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm

import app.models as models
import app.schemas as schemas
from app.core.security import create_access_token, verify_password
from app.models.database import AsyncSessionDep

router = APIRouter(prefix="", tags=["login"])


@router.post(
"/tokens",
response_model=schemas.Token,
status_code=status.HTTP_201_CREATED,
summary="Get a new access token",
)
async def get_access_token_from_username(
db: AsyncSessionDep,
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
) -> Any:
"""
Get an OAuth2 access token from a user logging in with a username and password,
to use in future requests as an authenticated user.
"""
user = await models.User.get_by_username(db=db, username=form_data.username)
if not user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Incorrect username or password",
)
if not verify_password(
plain_password=form_data.password, hashed_password=user.password_hash
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Incorrect username or password",
)

response = {
"access_token": create_access_token(subject=user.id),
"token_type": "bearer",
}
return response


@router.post(
"/register",
response_model=schemas.UserReturn,
status_code=status.HTTP_201_CREATED,
summary="Register a new user",
)
async def register_user(db: AsyncSessionDep, user: schemas.UserCreate):
new_user = await models.User.get_by_username(db=db, username=user.username)
if new_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Username already exists"
)

new_user = await models.User.get_by_email(db=db, email=user.email)
if new_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered"
)

new_user = await models.User.create(db=db, user=user)
return new_user
21 changes: 19 additions & 2 deletions app/api/endpoints/question.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import app.models as models
import app.schemas as schemas
from app.api.dependencies import get_current_user
from app.models.database import AsyncSessionDep

router = APIRouter(prefix="/questions", tags=["question"])
Expand All @@ -22,6 +23,21 @@ async def get_question_from_id(
return question


async def get_question_check_user(
question: Annotated[models.Question, Depends(get_question_from_id)],
user: Annotated[models.User, Depends(get_current_user)],
db: AsyncSessionDep,
) -> models.Question:
quiz_author_id = await models.Quiz.get_quiz_created_by(db=db, id=question.quiz_id)
if quiz_author_id != user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Quiz does not belong to current user",
)

return question


@router.get(
"/{question_id}",
response_model=schemas.QuestionReturn,
Expand All @@ -44,7 +60,7 @@ async def get_question(
)
async def update_question(
update_data: schemas.QuestionUpdate,
question: Annotated[models.Question, Depends(get_question_from_id)],
question: Annotated[models.Question, Depends(get_question_check_user)],
db: AsyncSessionDep,
) -> Any:
updated_question = await models.Question.update(
Expand All @@ -60,13 +76,14 @@ async def update_question(
summary="Delete quiz by id",
)
async def delete_question(
question: Annotated[models.Question, Depends(get_question_from_id)],
question: Annotated[models.Question, Depends(get_question_check_user)],
db: AsyncSessionDep,
) -> None:
await models.Question.delete(db=db, db_obj=question)
# body will be empty when using status code 204


# TODO: only the quiz author should be able to add answer options to a question
@router.post(
"/{question_id}/options",
response_model=schemas.AnswerOptionReturn,
Expand Down
38 changes: 33 additions & 5 deletions app/api/endpoints/quiz.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import app.models as models
import app.schemas as schemas
from app.api.dependencies import get_current_user
from app.models.database import AsyncSessionDep

router = APIRouter(prefix="/quizzes", tags=["quiz"])
Expand All @@ -20,14 +21,39 @@ async def get_quiz_from_id(quiz_id: int, db: AsyncSessionDep) -> models.Quiz:
return quiz


async def get_quiz_check_user(
quiz_id: int,
user: Annotated[models.User, Depends(get_current_user)],
db: AsyncSessionDep,
) -> models.Quiz:
quiz = await models.Quiz.get(db=db, id=quiz_id)
if not quiz:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Quiz not found"
)

if quiz.created_by != user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Quiz does not belong to current user",
)

return quiz


@router.post(
"",
response_model=schemas.QuizReturn,
status_code=status.HTTP_201_CREATED,
summary="Create a new quiz",
response_description="The new created quiz",
)
async def create_quiz(db: AsyncSessionDep, quiz: schemas.QuizCreate) -> Any:
async def create_quiz(
db: AsyncSessionDep,
quiz: schemas.QuizCreate,
current_user: Annotated[models.User, Depends(get_current_user)],
) -> Any:
quiz.created_by = current_user.id
new_quiz = await models.Quiz.create(db=db, quiz=quiz)
return new_quiz

Expand Down Expand Up @@ -68,7 +94,7 @@ async def get_quiz(quiz: Annotated[models.Quiz, Depends(get_quiz_from_id)]) -> A
)
async def update_quiz(
update_data: schemas.QuizUpdate,
quiz: Annotated[models.Quiz, Depends(get_quiz_from_id)],
quiz: Annotated[models.Quiz, Depends(get_quiz_check_user)],
db: AsyncSessionDep,
) -> Any:
updated_quiz = await models.Quiz.update(db=db, current=quiz, new=update_data)
Expand All @@ -82,7 +108,7 @@ async def update_quiz(
summary="Delete quiz by id",
)
async def delete_quiz(
quiz: Annotated[models.Quiz, Depends(get_quiz_from_id)],
quiz: Annotated[models.Quiz, Depends(get_quiz_check_user)],
db: AsyncSessionDep,
) -> None:
await models.Quiz.delete(db=db, db_obj=quiz)
Expand All @@ -97,9 +123,11 @@ async def delete_quiz(
response_description="The created question",
)
async def create_question_for_quiz(
quiz_id: int, question: schemas.QuestionCreate, db: AsyncSessionDep
quiz: Annotated[models.Quiz, Depends(get_quiz_check_user)],
question: schemas.QuestionCreate,
db: AsyncSessionDep,
) -> Any:
question.quiz_id = quiz_id
question.quiz_id = quiz.id
try:
new_question = await models.Question.create(db=db, question=question)
except IntegrityError as exc:
Expand Down
12 changes: 6 additions & 6 deletions app/core/custom_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ def configure_logger():

# logger.level("INFO", color="<green>")

logging.basicConfig(handlers=[InterceptHandler()], level=0)
logging.getLogger("uvicorn.access").handlers = [InterceptHandler()]
for _log in ["uvicorn", "uvicorn.error", "fastapi"]:
_logger = logging.getLogger(_log)
_logger.handlers = [InterceptHandler()]
# logging.basicConfig(handlers=[InterceptHandler()], level=0)
# logging.getLogger("uvicorn.access").handlers = [InterceptHandler()]
# for _log in ["uvicorn", "uvicorn.error", "fastapi"]:
# _logger = logging.getLogger(_log)
# _logger.handlers = [InterceptHandler()]

logger.bind(request_id=None, method=None)
# logger.bind(request_id=None, method=None)

return logger

Expand Down
60 changes: 60 additions & 0 deletions app/core/security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from datetime import datetime, timedelta

from fastapi import HTTPException, status
from jose import ExpiredSignatureError, JWTError, jwt
from passlib.context import CryptContext
from pydantic import ValidationError

from app.core.settings import Settings, get_settings
from app.schemas import TokenPayload

settings = get_settings()

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def decode_token(token: str, settings: Settings) -> TokenPayload:
try:
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
# raises ValidationError if the payload is not valid
token_data = TokenPayload(**payload)
except ExpiredSignatureError as exc:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Token expired"
) from exc
except (JWTError, ValidationError) as exc:
# any HTTP status code 401 "UNAUTHORIZED" is supposed to also
# return a WWW-Authenticate header
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
) from exc

return token_data


def create_access_token(
subject: str | int, expires_delta: timedelta | None = None
) -> str:
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MIN)

to_encode = {"exp": expire, "sub": str(subject)}
encoded_jwt = jwt.encode(
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)

return encoded_jwt


def get_password_hash(password: str) -> str:
return pwd_context.hash(password)


def verify_password(plain_password: str, hashed_password: str) -> bool:
return pwd_context.verify(plain_password, hashed_password)
5 changes: 5 additions & 0 deletions app/core/settings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import secrets
from functools import lru_cache

from pydantic_settings import BaseSettings, SettingsConfigDict
Expand All @@ -18,6 +19,10 @@ class Settings(BaseSettings):
POSTGRES_SERVER: str = "postgres"
POSTGRES_PORT: int = 5432

SECRET_KEY: str = secrets.token_urlsafe(32)
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MIN: int = 30

model_config = SettingsConfigDict(env_file=".env")

def get_db_url(self) -> str:
Expand Down
1 change: 1 addition & 0 deletions app/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .answer_options import AnswerOption
from .question import Question
from .quiz import Quiz
from .user import User
Loading

0 comments on commit 0e9cfe2

Please sign in to comment.