diff --git a/app/tests/api/test_question.py b/app/tests/api/test_question.py index c6925ca..9ece9ff 100644 --- a/app/tests/api/test_question.py +++ b/app/tests/api/test_question.py @@ -1,6 +1,7 @@ from datetime import datetime, timedelta import pytest +from dirty_equals import IsDatetime from fastapi import status from httpx import AsyncClient from sqlalchemy.ext.asyncio import AsyncSession @@ -8,10 +9,94 @@ import app.models as models from app.schemas import QuestionType from app.tests.factories.question_factory import QuestionFactory +from app.tests.factories.quiz_factory import QuizFactory + + +@pytest.mark.parametrize("cases", ["default", "custom"]) +async def test_create_question( + client: AsyncClient, db_session: AsyncSession, cases: str +): + # questions must be associated to an existing quiz + quiz = await QuizFactory.create() + question_data = {"quiz_id": quiz.id, "content": "What is the question?"} + + if cases == "custom": + question_data["type"] = QuestionType.multiple_choice + question_data["points"] = 4 + + response = await client.post(f"/api/quiz/{quiz.id}/questions", json=question_data) + + assert response.status_code == status.HTTP_201_CREATED + + if cases == "default": + question_data["type"] = QuestionType.open + question_data["points"] = 1 + + created_question = response.json() + for key in question_data: + assert created_question[key] == question_data[key] + + # created_at/updated_at should be close to the current time + for key in ["created_at", "updated_at"]: + assert created_question[key] == IsDatetime( + approx=datetime.utcnow(), delta=2, iso_string=True + ) + + # check quiz exists in database + db_question = await models.Question.get(db=db_session, id=created_question["id"]) + assert db_question + for key in question_data: + assert question_data[key] == getattr(db_question, key) + + +async def test_create_question_no_quiz(client: AsyncClient, db_session: AsyncSession): + question_data = {"content": "What is the question?"} + response = await client.post("/api/quiz/123/questions", json=question_data) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + +async def test_create_question_invalid_type( + client: AsyncClient, db_session: AsyncSession +): + quiz = await QuizFactory.create() + question_data = { + "quiz_id": quiz.id, + "content": "What is the question?", + "type": "invalid", + } + response = await client.post(f"/api/quiz/{quiz.id}/questions", json=question_data) + + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + +async def test_questions(client: AsyncClient, db_session: AsyncSession): + quiz = await QuizFactory.create() + questions = await QuestionFactory.create_batch(5, quiz=quiz) + + response = await client.get(f"/api/quiz/{quiz.id}/questions") + + assert response.status_code == status.HTTP_200_OK + + returned_questions = response.json() + assert len(returned_questions) == len(questions) + + # sort questions by id so they can be iterated simultaneously + questions = sorted(questions, key=lambda d: d.id) + returned_questions = sorted(returned_questions, key=lambda d: d["id"]) + + for created, returned in zip(questions, returned_questions): + for key in returned: + if key in ["created_at", "updated_at"]: + assert returned[key] == IsDatetime( + approx=getattr(created, key), delta=0, iso_string=True + ) + else: + assert returned[key] == getattr(created, key) @pytest.mark.parametrize("cases", ["found", "not_found"]) -async def test_get_quiz(client: AsyncClient, db_session: AsyncSession, cases: str): +async def test_get_question(client: AsyncClient, db_session: AsyncSession, cases: str): question_id = 4 if cases == "found": created_question = await QuestionFactory.create(id=question_id) @@ -31,7 +116,9 @@ async def test_get_quiz(client: AsyncClient, db_session: AsyncSession, cases: st @pytest.mark.parametrize("cases", ["found", "not_found", "partial_update", "invalid"]) -async def test_update_quiz(client: AsyncClient, db_session: AsyncSession, cases: str): +async def test_update_question( + client: AsyncClient, db_session: AsyncSession, cases: str +): question_id = 4 if cases != "not_found": await QuestionFactory.create( diff --git a/app/tests/api/test_quiz.py b/app/tests/api/test_quiz.py index fa12963..afd602c 100644 --- a/app/tests/api/test_quiz.py +++ b/app/tests/api/test_quiz.py @@ -7,8 +7,6 @@ from sqlalchemy.ext.asyncio import AsyncSession import app.models as models -from app.schemas import QuestionType -from app.tests.factories.question_factory import QuestionFactory from app.tests.factories.quiz_factory import QuizFactory @@ -162,86 +160,3 @@ async def test_delete_quiz(client: AsyncClient, db_session: AsyncSession, cases: # check quiz was deleted from database db_quiz = await models.Quiz.get(db=db_session, id=quiz_id) assert not db_quiz - - -@pytest.mark.parametrize("cases", ["default", "custom"]) -async def test_create_question( - client: AsyncClient, db_session: AsyncSession, cases: str -): - # questions must be associated to an existing quiz - quiz = await QuizFactory.create() - question_data = {"quiz_id": quiz.id, "content": "What is the question?"} - - if cases == "custom": - question_data["type"] = QuestionType.multiple_choice - question_data["points"] = 4 - - response = await client.post(f"/api/quiz/{quiz.id}/questions", json=question_data) - - assert response.status_code == status.HTTP_201_CREATED - - if cases == "default": - question_data["type"] = QuestionType.open - question_data["points"] = 1 - - created_question = response.json() - for key in question_data: - assert created_question[key] == question_data[key] - - # created_at/updated_at should be close to the current time - for key in ["created_at", "updated_at"]: - assert created_question[key] == IsDatetime( - approx=datetime.utcnow(), delta=2, iso_string=True - ) - - # check quiz exists in database - db_question = await models.Question.get(db=db_session, id=created_question["id"]) - assert db_question - for key in question_data: - assert question_data[key] == getattr(db_question, key) - - -async def test_create_question_no_quiz(client: AsyncClient, db_session: AsyncSession): - question_data = {"content": "What is the question?"} - response = await client.post("/api/quiz/123/questions", json=question_data) - - assert response.status_code == status.HTTP_404_NOT_FOUND - - -async def test_create_question_invalid_type( - client: AsyncClient, db_session: AsyncSession -): - quiz = await QuizFactory.create() - question_data = { - "quiz_id": quiz.id, - "content": "What is the question?", - "type": "invalid", - } - response = await client.post(f"/api/quiz/{quiz.id}/questions", json=question_data) - - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - - -async def test_questions(client: AsyncClient, db_session: AsyncSession): - quiz = await QuizFactory.create() - questions = await QuestionFactory.create_batch(5, quiz=quiz) - - response = await client.get(f"/api/quiz/{quiz.id}/questions") - - assert response.status_code == status.HTTP_200_OK - - returned_questions = response.json() - assert len(returned_questions) == len(questions) - - # sort questions by id so they can be iterated simultaneously - questions = sorted(questions, key=lambda d: d.id) - returned_questions = sorted(returned_questions, key=lambda d: d["id"]) - - for created, returned in zip(questions, returned_questions): - for key in returned: - if key in ["created_at", "updated_at"]: - assert returned[key] == IsDatetime( - approx=getattr(created, key), delta=0, iso_string=True - ) - else: - assert returned[key] == getattr(created, key)