Skip to content

Commit

Permalink
Add some validation to endpoints (#35)
Browse files Browse the repository at this point in the history
Add a bit more validation to endpoints
  • Loading branch information
eyurtsev authored Nov 12, 2023
1 parent 84b3fe5 commit 9cd5570
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 13 deletions.
45 changes: 32 additions & 13 deletions backend/app/server.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from pathlib import Path
from typing import Optional
from typing import Annotated, Optional

import orjson
from fastapi import FastAPI, Form, Request, UploadFile
from fastapi import Cookie, FastAPI, Form, Request, UploadFile
from fastapi.staticfiles import StaticFiles
from gizmo_agent import agent, ingest_runnable
from langchain.schema.runnable import RunnableConfig
from langserve import add_routes
from typing_extensions import TypedDict

from app.storage import (
get_thread_messages,
Expand All @@ -26,7 +27,8 @@


def attach_user_id_to_config(
config: RunnableConfig, request: Request
config: RunnableConfig,
request: Request,
) -> RunnableConfig:
config["configurable"]["user_id"] = request.cookies["opengpts_user_id"]
return config
Expand All @@ -48,9 +50,9 @@ def ingest_endpoint(files: list[UploadFile], config: str = Form(...)):


@app.get("/assistants/")
def list_assistants_endpoint(req: Request):
def list_assistants_endpoint(opengpts_user_id: Annotated[str, Cookie()]):
"""List all assistants for the current user."""
return list_assistants(req.cookies["opengpts_user_id"])
return list_assistants(opengpts_user_id)


@app.get("/assistants/public/")
Expand All @@ -60,10 +62,20 @@ def list_public_assistants_endpoint(shared_id: Optional[str] = None):
)


class AssistantPayload(TypedDict):
name: str
config: dict
public: bool


@app.put("/assistants/{aid}")
def put_assistant_endpoint(req: Request, aid: str, payload: dict):
def put_assistant_endpoint(
aid: str,
payload: AssistantPayload,
opengpts_user_id: Annotated[str, Cookie()],
):
return put_assistant(
req.cookies["opengpts_user_id"],
opengpts_user_id,
aid,
name=payload["name"],
config=payload["config"],
Expand All @@ -72,19 +84,26 @@ def put_assistant_endpoint(req: Request, aid: str, payload: dict):


@app.get("/threads/")
def list_threads_endpoint(req: Request):
return list_threads(req.cookies["opengpts_user_id"])
def list_threads_endpoint(opengpts_user_id: Annotated[str, Cookie()]):
return list_threads(opengpts_user_id)


@app.get("/threads/{tid}/messages")
def get_thread_messages_endpoint(req: Request, tid: str):
return get_thread_messages(req.cookies["opengpts_user_id"], tid)
def get_thread_messages_endpoint(opengpts_user_id: Annotated[str, Cookie()], tid: str):
return get_thread_messages(opengpts_user_id, tid)


class ThreadPayload(TypedDict):
name: str
assistant_id: str


@app.put("/threads/{tid}")
def put_thread_endpoint(req: Request, tid: str, payload: dict):
def put_thread_endpoint(
opengpts_user_id: Annotated[str, Cookie()], tid: str, payload: ThreadPayload
):
return put_thread(
req.cookies["opengpts_user_id"],
opengpts_user_id,
tid,
assistant_id=payload["assistant_id"],
name=payload["name"],
Expand Down
19 changes: 19 additions & 0 deletions backend/tests/unit_tests/app/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ async def test_list_and_create_assistants(redis_client: RedisType) -> None:
headers=headers,
)
assert response.status_code == 200

assert response.json() == []

# Create an assistant
Expand Down Expand Up @@ -140,3 +141,21 @@ async def test_threads(redis_client: RedisType) -> None:
"thread_id": "1",
}
]

# Test a bad requests
response = await client.put(
"/threads/1",
json={"name": "bobby", "assistant_id": "bobby"},
)
assert response.status_code == 422

response = await client.put(
"/threads/1",
headers={"Cookie": "opengpts_user_id=2"},
)
assert response.status_code == 422

response = await client.get(
"/threads/",
)
assert response.status_code == 422

0 comments on commit 9cd5570

Please sign in to comment.