Skip to content

Commit

Permalink
Test.
Browse files Browse the repository at this point in the history
  • Loading branch information
bakar-io committed Apr 22, 2024
1 parent 2410f9a commit 17a10e9
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 35 deletions.
1 change: 0 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ jobs:
sudo dpkg -i golang-migrate.deb && rm golang-migrate.deb
- name: Run tests
env:
STORAGE_TYPE: postgres
POSTGRES_HOST: localhost
POSTGRES_PORT: 5432
POSTGRES_DB: postgres
Expand Down
3 changes: 1 addition & 2 deletions backend/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ migrate:

test:
# We need to update handling of env variables for tests
STORAGE_TYPE=postgres YDC_API_KEY=placeholder OPENAI_API_KEY=placeholder poetry run pytest -s $(TEST_FILE)

STORAGE_TYPE=postgres YDC_API_KEY=placeholder OPENAI_API_KEY=placeholder poetry run pytest $(TEST_FILE)

test_watch:
# We need to update handling of env variables for tests
Expand Down
53 changes: 21 additions & 32 deletions backend/tests/unit_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import logging
import os
import subprocess

Expand All @@ -10,36 +9,30 @@
from app.auth.settings import settings as auth_settings
from app.lifespan import get_pg_pool, lifespan
from app.server import app
from app.storage.settings import settings as storage_settings

auth_settings.auth_type = AuthType.NOOP

# Temporary handling of environment variables for testing
os.environ["OPENAI_API_KEY"] = "test"

TEST_DB = "test"
assert os.environ["POSTGRES_DB"] != TEST_DB, "Test and main database conflict."
os.environ["POSTGRES_DB"] = TEST_DB


logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
assert storage_settings.postgres.db != TEST_DB, "Test and main database conflict."
storage_settings.postgres.db = TEST_DB


async def _get_conn() -> asyncpg.Connection:
return await asyncpg.connect(
user=os.environ["POSTGRES_USER"],
password=os.environ["POSTGRES_PASSWORD"],
host=os.environ["POSTGRES_HOST"],
port=os.environ["POSTGRES_PORT"],
user=storage_settings.postgres.user,
password=storage_settings.postgres.password,
host=storage_settings.postgres.host,
port=storage_settings.postgres.port,
database="postgres",
)


async def _create_test_db() -> None:
"""Check if the test database exists and create it if it doesn't."""
logger.info("Creating test database")
conn = await _get_conn()
exists = await conn.fetchval("SELECT 1 FROM pg_database WHERE datname=$1", TEST_DB)
if not exists:
Expand All @@ -49,7 +42,6 @@ async def _create_test_db() -> None:

async def _drop_test_db() -> None:
"""Check if the test database exists and if so, drop it."""
logger.info("Dropping test database")
conn = await _get_conn()
exists = await conn.fetchval("SELECT 1 FROM pg_database WHERE datname=$1", TEST_DB)
if exists:
Expand All @@ -58,25 +50,22 @@ async def _drop_test_db() -> None:


def _migrate_test_db() -> None:
logger.info("Migrating test database")
# Run subprocess and capture output and errors
result = subprocess.run(
["make", "migrate"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
subprocess.run(
[
"migrate",
"-database",
(
f"postgres://{storage_settings.postgres.user}:{storage_settings.postgres.password}"
f"@{storage_settings.postgres.host}:{storage_settings.postgres.port}"
f"/{storage_settings.postgres.db}?sslmode=disable"
),
"-path",
"./migrations/postgres",
"up",
],
check=True,
)

# Log standard output and errors
if result.stdout:
logger.info("Subprocess output: %s", result.stdout)
if result.stderr:
logger.error("Subprocess error: %s", result.stderr)

# Check if the subprocess exited with a non-zero exit code
if result.returncode != 0:
logger.error("Subprocess failed with return code %s", result.returncode)
raise subprocess.CalledProcessError(
result.returncode, result.args, output=result.stdout, stderr=result.stderr
)


@pytest.fixture(scope="session")
async def pool():
Expand Down

0 comments on commit 17a10e9

Please sign in to comment.