Skip to content

Commit

Permalink
Move question tests to test_question.py
Browse files Browse the repository at this point in the history
  • Loading branch information
matinone committed Nov 26, 2023
1 parent 5d8ab30 commit 7feaa5f
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 87 deletions.
91 changes: 89 additions & 2 deletions app/tests/api/test_question.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,102 @@
from datetime import datetime, timedelta

import pytest
from dirty_equals import IsDatetime
from fastapi import status
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession

import app.models as models
from app.schemas import QuestionType
from app.tests.factories.question_factory import QuestionFactory
from app.tests.factories.quiz_factory import QuizFactory


@pytest.mark.parametrize("cases", ["default", "custom"])
async def test_create_question(
client: AsyncClient, db_session: AsyncSession, cases: str
):
# questions must be associated to an existing quiz
quiz = await QuizFactory.create()
question_data = {"quiz_id": quiz.id, "content": "What is the question?"}

if cases == "custom":
question_data["type"] = QuestionType.multiple_choice
question_data["points"] = 4

response = await client.post(f"/api/quiz/{quiz.id}/questions", json=question_data)

assert response.status_code == status.HTTP_201_CREATED

if cases == "default":
question_data["type"] = QuestionType.open
question_data["points"] = 1

created_question = response.json()
for key in question_data:
assert created_question[key] == question_data[key]

# created_at/updated_at should be close to the current time
for key in ["created_at", "updated_at"]:
assert created_question[key] == IsDatetime(
approx=datetime.utcnow(), delta=2, iso_string=True
)

# check quiz exists in database
db_question = await models.Question.get(db=db_session, id=created_question["id"])
assert db_question
for key in question_data:
assert question_data[key] == getattr(db_question, key)


async def test_create_question_no_quiz(client: AsyncClient, db_session: AsyncSession):
question_data = {"content": "What is the question?"}
response = await client.post("/api/quiz/123/questions", json=question_data)

assert response.status_code == status.HTTP_404_NOT_FOUND


async def test_create_question_invalid_type(
client: AsyncClient, db_session: AsyncSession
):
quiz = await QuizFactory.create()
question_data = {
"quiz_id": quiz.id,
"content": "What is the question?",
"type": "invalid",
}
response = await client.post(f"/api/quiz/{quiz.id}/questions", json=question_data)

assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY


async def test_questions(client: AsyncClient, db_session: AsyncSession):
quiz = await QuizFactory.create()
questions = await QuestionFactory.create_batch(5, quiz=quiz)

response = await client.get(f"/api/quiz/{quiz.id}/questions")

assert response.status_code == status.HTTP_200_OK

returned_questions = response.json()
assert len(returned_questions) == len(questions)

# sort questions by id so they can be iterated simultaneously
questions = sorted(questions, key=lambda d: d.id)
returned_questions = sorted(returned_questions, key=lambda d: d["id"])

for created, returned in zip(questions, returned_questions):
for key in returned:
if key in ["created_at", "updated_at"]:
assert returned[key] == IsDatetime(
approx=getattr(created, key), delta=0, iso_string=True
)
else:
assert returned[key] == getattr(created, key)


@pytest.mark.parametrize("cases", ["found", "not_found"])
async def test_get_quiz(client: AsyncClient, db_session: AsyncSession, cases: str):
async def test_get_question(client: AsyncClient, db_session: AsyncSession, cases: str):
question_id = 4
if cases == "found":
created_question = await QuestionFactory.create(id=question_id)
Expand All @@ -31,7 +116,9 @@ async def test_get_quiz(client: AsyncClient, db_session: AsyncSession, cases: st


@pytest.mark.parametrize("cases", ["found", "not_found", "partial_update", "invalid"])
async def test_update_quiz(client: AsyncClient, db_session: AsyncSession, cases: str):
async def test_update_question(
client: AsyncClient, db_session: AsyncSession, cases: str
):
question_id = 4
if cases != "not_found":
await QuestionFactory.create(
Expand Down
85 changes: 0 additions & 85 deletions app/tests/api/test_quiz.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from sqlalchemy.ext.asyncio import AsyncSession

import app.models as models
from app.schemas import QuestionType
from app.tests.factories.question_factory import QuestionFactory
from app.tests.factories.quiz_factory import QuizFactory


Expand Down Expand Up @@ -162,86 +160,3 @@ async def test_delete_quiz(client: AsyncClient, db_session: AsyncSession, cases:
# check quiz was deleted from database
db_quiz = await models.Quiz.get(db=db_session, id=quiz_id)
assert not db_quiz


@pytest.mark.parametrize("cases", ["default", "custom"])
async def test_create_question(
client: AsyncClient, db_session: AsyncSession, cases: str
):
# questions must be associated to an existing quiz
quiz = await QuizFactory.create()
question_data = {"quiz_id": quiz.id, "content": "What is the question?"}

if cases == "custom":
question_data["type"] = QuestionType.multiple_choice
question_data["points"] = 4

response = await client.post(f"/api/quiz/{quiz.id}/questions", json=question_data)

assert response.status_code == status.HTTP_201_CREATED

if cases == "default":
question_data["type"] = QuestionType.open
question_data["points"] = 1

created_question = response.json()
for key in question_data:
assert created_question[key] == question_data[key]

# created_at/updated_at should be close to the current time
for key in ["created_at", "updated_at"]:
assert created_question[key] == IsDatetime(
approx=datetime.utcnow(), delta=2, iso_string=True
)

# check quiz exists in database
db_question = await models.Question.get(db=db_session, id=created_question["id"])
assert db_question
for key in question_data:
assert question_data[key] == getattr(db_question, key)


async def test_create_question_no_quiz(client: AsyncClient, db_session: AsyncSession):
question_data = {"content": "What is the question?"}
response = await client.post("/api/quiz/123/questions", json=question_data)

assert response.status_code == status.HTTP_404_NOT_FOUND


async def test_create_question_invalid_type(
client: AsyncClient, db_session: AsyncSession
):
quiz = await QuizFactory.create()
question_data = {
"quiz_id": quiz.id,
"content": "What is the question?",
"type": "invalid",
}
response = await client.post(f"/api/quiz/{quiz.id}/questions", json=question_data)

assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY


async def test_questions(client: AsyncClient, db_session: AsyncSession):
quiz = await QuizFactory.create()
questions = await QuestionFactory.create_batch(5, quiz=quiz)

response = await client.get(f"/api/quiz/{quiz.id}/questions")

assert response.status_code == status.HTTP_200_OK

returned_questions = response.json()
assert len(returned_questions) == len(questions)

# sort questions by id so they can be iterated simultaneously
questions = sorted(questions, key=lambda d: d.id)
returned_questions = sorted(returned_questions, key=lambda d: d["id"])

for created, returned in zip(questions, returned_questions):
for key in returned:
if key in ["created_at", "updated_at"]:
assert returned[key] == IsDatetime(
approx=getattr(created, key), delta=0, iso_string=True
)
else:
assert returned[key] == getattr(created, key)

0 comments on commit 7feaa5f

Please sign in to comment.