diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index 41fefb4..c5cc80e 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -10,7 +10,7 @@ services: - ..:/workspace:cached,z command: sleep infinity environment: - DATABASE_URL: mysql+pymysql://root:rootpassword@ispyb/ispyb_build + DATABASE_URL: mysql+aiomysql://root:rootpassword@ispyb/ispyb_build OTEL_COLLECTOR_URL: http://jaeger:4317 ispyb: diff --git a/pyproject.toml b/pyproject.toml index 20df7c2..6300f83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "starlette", "strawberry-graphql[asgi]", "sqlalchemy", - "pymysql", + "aiomysql", "opentelemetry-api", "opentelemetry-sdk", "opentelemetry-semantic-conventions", diff --git a/src/graph_energy_scan/__main__.py b/src/graph_energy_scan/__main__.py index 191eb12..eef86bc 100644 --- a/src/graph_energy_scan/__main__.py +++ b/src/graph_energy_scan/__main__.py @@ -7,14 +7,14 @@ from strawberry.printer import print_schema from graph_energy_scan.database import create_session -from graph_energy_scan.graphql import EnergyScan, Session +from graph_energy_scan.graphql import EnergyScan, Query, Session from graph_energy_scan.telemetry import setup_telemetry from . import __version__ __all__ = ["main"] -SCHEMA = Schema(types=[Session, EnergyScan], enable_federation_2=True) +SCHEMA = Schema(Query, types=[Session, EnergyScan], enable_federation_2=True) @click.group(invoke_without_command=True) diff --git a/src/graph_energy_scan/database.py b/src/graph_energy_scan/database.py index b7fa3b0..abb2077 100644 --- a/src/graph_energy_scan/database.py +++ b/src/graph_energy_scan/database.py @@ -2,21 +2,21 @@ from contextlib import asynccontextmanager from typing import AsyncGenerator, Optional, cast -from sqlalchemy import create_engine -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -SESSION: Optional[Session] = None -SESSION_SET = Event() +SESSION_MAKER: Optional[async_sessionmaker] = None +SESSION_CREATED = Event() def create_session(url: str): - global SESSION - engine = create_engine(url) - SESSION = Session(engine) - SESSION_SET.set() + global SESSION_MAKER + engine = create_async_engine(url) + SESSION_MAKER = async_sessionmaker(engine) + SESSION_CREATED.set() @asynccontextmanager -async def current_session() -> AsyncGenerator[Session, None]: - await SESSION_SET.wait() - yield cast(Session, SESSION) +async def current_session() -> AsyncGenerator[AsyncSession, None]: + await SESSION_CREATED.wait() + async with cast(async_sessionmaker, SESSION_MAKER)() as session: + yield session diff --git a/src/graph_energy_scan/graphql.py b/src/graph_energy_scan/graphql.py index cf92d20..76c2b11 100644 --- a/src/graph_energy_scan/graphql.py +++ b/src/graph_energy_scan/graphql.py @@ -81,4 +81,6 @@ async def energy_scans(self) -> list[EnergyScan]: stmt = select(models.EnergyScan).where( models.EnergyScan.sessionId == self.id ) - return [EnergyScan.from_model(model) for model in session.scalars(stmt)] + return [ + EnergyScan.from_model(model) for model in await session.scalars(stmt) + ]