From 4053d55c8b8979720ea72a942f8b602d54334dd3 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Sat, 11 Nov 2023 21:40:53 -0500 Subject: [PATCH] x --- backend/app/server.py | 7 +- backend/tests/unit_tests/app/__init__.py | 0 backend/tests/unit_tests/app/test_app.py | 99 ++++++++++++++++++++++++ backend/tests/unit_tests/conftest.py | 6 ++ 4 files changed, 111 insertions(+), 1 deletion(-) create mode 100644 backend/tests/unit_tests/app/__init__.py create mode 100644 backend/tests/unit_tests/app/test_app.py create mode 100644 backend/tests/unit_tests/conftest.py diff --git a/backend/app/server.py b/backend/app/server.py index dd7411f1..60b55595 100644 --- a/backend/app/server.py +++ b/backend/app/server.py @@ -6,6 +6,7 @@ from gizmo_agent import agent, ingest_runnable from langchain.schema.runnable import RunnableConfig from langserve import add_routes +from pathlib import Path from app.storage import ( get_thread_messages, @@ -20,6 +21,9 @@ FEATURED_PUBLIC_ASSISTANTS = [] +# Get root of app, used to point to directory containing static files +ROOT = Path(__file__).parent.parent + def attach_user_id_to_config( config: RunnableConfig, request: Request @@ -44,6 +48,7 @@ def ingest_endpoint(files: list[UploadFile], config: str = Form(...)): @app.get("/assistants/") def list_assistants_endpoint(req: Request): + """List all assistants for the current user.""" return list_assistants(req.cookies["opengpts_user_id"]) @@ -85,7 +90,7 @@ def put_thread_endpoint(req: Request, tid: str, payload: dict): ) -app.mount("", StaticFiles(directory="ui", html=True), name="ui") +app.mount("", StaticFiles(directory=str(ROOT / "ui"), html=True), name="ui") if __name__ == "__main__": import uvicorn diff --git a/backend/tests/unit_tests/app/__init__.py b/backend/tests/unit_tests/app/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/tests/unit_tests/app/test_app.py b/backend/tests/unit_tests/app/test_app.py new file mode 100644 index 00000000..a787b0a2 --- /dev/null +++ b/backend/tests/unit_tests/app/test_app.py @@ -0,0 +1,99 @@ +"""Test the server and client together.""" + +import os +from contextlib import asynccontextmanager + +import pytest +from httpx import AsyncClient +from langchain.utilities.redis import get_client as _get_redis_client +from redis.client import Redis as RedisType +from typing_extensions import AsyncGenerator + + +@asynccontextmanager +async def get_client() -> AsyncGenerator[AsyncClient, None]: + """Get the app.""" + from app.server import app + + async with AsyncClient(app=app, base_url="http://test") as ac: + yield ac + + +@pytest.fixture(scope="function") +def redis_client() -> RedisType: + """Get a redis client -- and clear it before the test!""" + redis_url = os.environ.get("REDIS_URL") + if "localhost" not in redis_url: + raise ValueError( + "This test is only intended to be run against a local redis instance" + ) + + if not redis_url.endswith("/3"): + raise ValueError( + "This test is only intended to be run against a local redis instance. " + "For testing purposes this is expected to be database #3 (arbitrary)." + ) + + client = _get_redis_client(redis_url) + client.flushdb() + try: + yield client + finally: + client.close() + + +@pytest.mark.asyncio +async def test_list_and_create_assistants(redis_client: RedisType) -> None: + """Test list and create assistants.""" + headers = {"Cookie": "opengpts_user_id=1"} + assert sorted(redis_client.keys()) == [] + async with get_client() as client: + response = await client.get( + "/assistants/", + headers=headers, + ) + assert response.status_code == 200 + assert response.json() == [] + + response = await client.put( + "/assistants/bobby", + json={"name": "bobby", "config": {}, "public": False}, + headers=headers, + ) + assert response.status_code == 200 + json_response = response.json() + assert "updated_at" in json_response + del json_response["updated_at"] + + assert json_response == { + "assistant_id": "bobby", + "config": {}, + "name": "bobby", + "public": False, + "user_id": "1", + } + assert sorted(redis_client.keys()) == [ + b"opengpts:1:assistant:bobby", + b"opengpts:1:assistants", + ] + + assistant_info = redis_client.hgetall("opengpts:1:assistant:bobby") + del assistant_info[b"updated_at"] + assert assistant_info == { + b"assistant_id": b'"bobby"', + b"config": b"{}", + b"name": b'"bobby"', + b"public": b"false", + b"user_id": b'"1"', + } + + +@pytest.mark.asyncio +async def test_list_threads() -> None: + """Test listing threads.""" + async with get_client() as client: + response = await client.get( + "/threads/", headers={"Cookie": "opengpts_user_id=1"} + ) + assert response.status_code == 200 + assert response.json() == [] diff --git a/backend/tests/unit_tests/conftest.py b/backend/tests/unit_tests/conftest.py new file mode 100644 index 00000000..169452f9 --- /dev/null +++ b/backend/tests/unit_tests/conftest.py @@ -0,0 +1,6 @@ +import os + +# Temporary handling of environment variables for testing +os.environ["REDIS_URL"] = "redis://localhost:6379/3" +os.environ["OPENAI_API_KEY"] = "test" +os.environ["YDC_API_KEY"] = "test"