From 9cd55706c45b1cee13582d67e3f7e9d979b2bf0a Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Sun, 12 Nov 2023 16:53:20 -0500 Subject: [PATCH] Add some validation to endpoints (#35) Add a bit more validation to endpoints --- backend/app/server.py | 45 +++++++++++++++++------- backend/tests/unit_tests/app/test_app.py | 19 ++++++++++ 2 files changed, 51 insertions(+), 13 deletions(-) diff --git a/backend/app/server.py b/backend/app/server.py index 5cdfde69..b171b0f9 100644 --- a/backend/app/server.py +++ b/backend/app/server.py @@ -1,12 +1,13 @@ from pathlib import Path -from typing import Optional +from typing import Annotated, Optional import orjson -from fastapi import FastAPI, Form, Request, UploadFile +from fastapi import Cookie, FastAPI, Form, Request, UploadFile from fastapi.staticfiles import StaticFiles from gizmo_agent import agent, ingest_runnable from langchain.schema.runnable import RunnableConfig from langserve import add_routes +from typing_extensions import TypedDict from app.storage import ( get_thread_messages, @@ -26,7 +27,8 @@ def attach_user_id_to_config( - config: RunnableConfig, request: Request + config: RunnableConfig, + request: Request, ) -> RunnableConfig: config["configurable"]["user_id"] = request.cookies["opengpts_user_id"] return config @@ -48,9 +50,9 @@ def ingest_endpoint(files: list[UploadFile], config: str = Form(...)): @app.get("/assistants/") -def list_assistants_endpoint(req: Request): +def list_assistants_endpoint(opengpts_user_id: Annotated[str, Cookie()]): """List all assistants for the current user.""" - return list_assistants(req.cookies["opengpts_user_id"]) + return list_assistants(opengpts_user_id) @app.get("/assistants/public/") @@ -60,10 +62,20 @@ def list_public_assistants_endpoint(shared_id: Optional[str] = None): ) +class AssistantPayload(TypedDict): + name: str + config: dict + public: bool + + @app.put("/assistants/{aid}") -def put_assistant_endpoint(req: Request, aid: str, payload: dict): +def put_assistant_endpoint( + aid: str, + payload: AssistantPayload, + opengpts_user_id: Annotated[str, Cookie()], +): return put_assistant( - req.cookies["opengpts_user_id"], + opengpts_user_id, aid, name=payload["name"], config=payload["config"], @@ -72,19 +84,26 @@ def put_assistant_endpoint(req: Request, aid: str, payload: dict): @app.get("/threads/") -def list_threads_endpoint(req: Request): - return list_threads(req.cookies["opengpts_user_id"]) +def list_threads_endpoint(opengpts_user_id: Annotated[str, Cookie()]): + return list_threads(opengpts_user_id) @app.get("/threads/{tid}/messages") -def get_thread_messages_endpoint(req: Request, tid: str): - return get_thread_messages(req.cookies["opengpts_user_id"], tid) +def get_thread_messages_endpoint(opengpts_user_id: Annotated[str, Cookie()], tid: str): + return get_thread_messages(opengpts_user_id, tid) + + +class ThreadPayload(TypedDict): + name: str + assistant_id: str @app.put("/threads/{tid}") -def put_thread_endpoint(req: Request, tid: str, payload: dict): +def put_thread_endpoint( + opengpts_user_id: Annotated[str, Cookie()], tid: str, payload: ThreadPayload +): return put_thread( - req.cookies["opengpts_user_id"], + opengpts_user_id, tid, assistant_id=payload["assistant_id"], name=payload["name"], diff --git a/backend/tests/unit_tests/app/test_app.py b/backend/tests/unit_tests/app/test_app.py index 3619dfd2..c38d156f 100644 --- a/backend/tests/unit_tests/app/test_app.py +++ b/backend/tests/unit_tests/app/test_app.py @@ -60,6 +60,7 @@ async def test_list_and_create_assistants(redis_client: RedisType) -> None: headers=headers, ) assert response.status_code == 200 + assert response.json() == [] # Create an assistant @@ -140,3 +141,21 @@ async def test_threads(redis_client: RedisType) -> None: "thread_id": "1", } ] + + # Test a bad requests + response = await client.put( + "/threads/1", + json={"name": "bobby", "assistant_id": "bobby"}, + ) + assert response.status_code == 422 + + response = await client.put( + "/threads/1", + headers={"Cookie": "opengpts_user_id=2"}, + ) + assert response.status_code == 422 + + response = await client.get( + "/threads/", + ) + assert response.status_code == 422