diff --git a/API.md b/API.md index 10e8014a..65cc9fcc 100644 --- a/API.md +++ b/API.md @@ -7,6 +7,9 @@ For full API documentation, see [localhost:8100/docs](localhost:8100/docs) after If you want to see the API docs before deployment, check out the [hosted docs here](https://opengpts-example-vz4y4ooboq-uc.a.run.app/docs). +In the examples below, cookies are used as a mock auth method. For production, we recommend using JWT auth. Refer to the [auth guide for production](auth.md) for more information. +When using JWT auth, you will need to include the JWT in the `Authorization` header as a Bearer token. + ## Create an Assistant First, let's use the API to create an assistant. @@ -20,7 +23,7 @@ requests.post('http://127.0.0.1:8100/assistants', json={ "public": True }, cookies= {"opengpts_user_id": "foo"}).content ``` -This is creating an assistant with name `"bar"`, with default configuration, that is public, and is associated with user `"foo"` (we are using cookies as a mock auth method). +This is creating an assistant with name `"bar"`, with default configuration, that is public, and is associated with user `"foo"`. This should return something like: diff --git a/README.md b/README.md index 03ab8719..6a9935d1 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ Because this is open source, if you do not like those architectures or want to m - [GPTs: a simple hosted version](https://opengpts-example-vz4y4ooboq-uc.a.run.app/) - [Assistants API: a getting started guide](API.md) +- [Auth: a guide for production](auth.md) ## Quickstart with Docker diff --git a/auth.md b/auth.md new file mode 100644 index 00000000..8f58efa6 --- /dev/null +++ b/auth.md @@ -0,0 +1,51 @@ +# Auth + +By default, we're using cookies as a mock auth method. It's for trying out OpenGPTs. +For production, we recommend using JWT auth, outlined below. + +## JWT Auth: Options + +There are two ways to use JWT: Local and OIDC. 
The main difference is in how the key +used to decode the JWT is obtained. For the Local method, you'll provide the decode +key as a Base64-encoded string in an environment variable. For the OIDC method, the +key is obtained from the OIDC provider automatically. + +### JWT OIDC + +If you're looking to integrate with an identity provider, OIDC is the way to go. +It will figure out the decode key for you, so you don't have to worry about it. +Just set `AUTH_TYPE=jwt_oidc` along with the issuer and audience. Audience can +be one or many - just separate them with commas. + +```bash +export AUTH_TYPE=jwt_oidc +export JWT_ISS=<issuer> +export JWT_AUD=<audience> # or <audience1>,<audience2>,... +``` + +### JWT Local + +To use JWT Local, set `AUTH_TYPE=jwt_local`. Then, set the issuer, audience, +algorithm used to sign the JWT, and the decode key in Base64 format. + +```bash +export AUTH_TYPE=jwt_local +export JWT_ISS=<issuer> +export JWT_AUD=<audience> +export JWT_ALG=<algorithm> # e.g. ES256 +export JWT_DECODE_KEY_B64=<base64_decode_key> +``` + +Base64 is used for the decode key because handling multiline strings in environment +variables is error-prone. Base64 makes it a one-liner, easy to paste in and use. 
+ + +## Making Requests + +To make authenticated requests, include the JWT in the `Authorization` header as a Bearer token: + +``` +Authorization: Bearer <token> +``` + + diff --git a/backend/app/api/assistants.py b/backend/app/api/assistants.py index 1667c5f4..dda15581 100644 --- a/backend/app/api/assistants.py +++ b/backend/app/api/assistants.py @@ -5,7 +5,8 @@ from pydantic import BaseModel, Field import app.storage as storage -from app.schema import Assistant, OpengptsUserId +from app.auth.handlers import AuthedUser +from app.schema import Assistant router = APIRouter() @@ -24,9 +25,9 @@ class AssistantPayload(BaseModel): @router.get("/") -async def list_assistants(opengpts_user_id: OpengptsUserId) -> List[Assistant]: +async def list_assistants(user: AuthedUser) -> List[Assistant]: """List all assistants for the current user.""" - return await storage.list_assistants(opengpts_user_id) + return await storage.list_assistants(user["user_id"]) @router.get("/public/") @@ -43,11 +44,11 @@ async def list_public_assistants( @router.get("/{aid}") async def get_assistant( - opengpts_user_id: OpengptsUserId, + user: AuthedUser, aid: AssistantID, ) -> Assistant: """Get an assistant by ID.""" - assistant = await storage.get_assistant(opengpts_user_id, aid) + assistant = await storage.get_assistant(user["user_id"], aid) if not assistant: raise HTTPException(status_code=404, detail="Assistant not found") return assistant @@ -55,12 +56,12 @@ async def get_assistant( @router.post("") async def create_assistant( - opengpts_user_id: OpengptsUserId, + user: AuthedUser, payload: AssistantPayload, ) -> Assistant: """Create an assistant.""" return await storage.put_assistant( - opengpts_user_id, + user["user_id"], str(uuid4()), name=payload.name, config=payload.config, @@ -70,13 +71,13 @@ async def create_assistant( @router.put("/{aid}") async def upsert_assistant( - opengpts_user_id: OpengptsUserId, + user: AuthedUser, aid: AssistantID, payload: AssistantPayload, ) -> Assistant: """Create 
or update an assistant.""" return await storage.put_assistant( - opengpts_user_id, + user["user_id"], aid, name=payload.name, config=payload.config, diff --git a/backend/app/api/runs.py b/backend/app/api/runs.py index e6b1525b..11ab5758 100644 --- a/backend/app/api/runs.py +++ b/backend/app/api/runs.py @@ -13,7 +13,7 @@ from sse_starlette import EventSourceResponse from app.agent import agent -from app.schema import OpengptsUserId +from app.auth.handlers import AuthedUser from app.storage import get_assistant, get_thread from app.stream import astream_messages, to_sse @@ -30,14 +30,12 @@ class CreateRunPayload(BaseModel): config: Optional[RunnableConfig] = None -async def _run_input_and_config( - payload: CreateRunPayload, opengpts_user_id: OpengptsUserId -): - thread = await get_thread(opengpts_user_id, payload.thread_id) +async def _run_input_and_config(payload: CreateRunPayload, user_id: str): + thread = await get_thread(user_id, payload.thread_id) if not thread: raise HTTPException(status_code=404, detail="Thread not found") - assistant = await get_assistant(opengpts_user_id, str(thread["assistant_id"])) + assistant = await get_assistant(user_id, str(thread["assistant_id"])) if not assistant: raise HTTPException(status_code=404, detail="Assistant not found") @@ -46,7 +44,7 @@ async def _run_input_and_config( "configurable": { **assistant["config"]["configurable"], **((payload.config or {}).get("configurable") or {}), - "user_id": opengpts_user_id, + "user_id": user_id, "thread_id": str(thread["thread_id"]), "assistant_id": str(assistant["assistant_id"]), }, @@ -67,11 +65,11 @@ async def _run_input_and_config( @router.post("") async def create_run( payload: CreateRunPayload, - opengpts_user_id: OpengptsUserId, + user: AuthedUser, background_tasks: BackgroundTasks, ): """Create a run.""" - input_, config = await _run_input_and_config(payload, opengpts_user_id) + input_, config = await _run_input_and_config(payload, user["user_id"]) 
background_tasks.add_task(agent.ainvoke, input_, config) return {"status": "ok"} # TODO add a run id @@ -79,10 +77,10 @@ async def create_run( @router.post("/stream") async def stream_run( payload: CreateRunPayload, - opengpts_user_id: OpengptsUserId, + user: AuthedUser, ): """Create a run.""" - input_, config = await _run_input_and_config(payload, opengpts_user_id) + input_, config = await _run_input_and_config(payload, user["user_id"]) return EventSourceResponse(to_sse(astream_messages(agent, input_, config))) diff --git a/backend/app/api/threads.py b/backend/app/api/threads.py index 639cd993..646a061c 100644 --- a/backend/app/api/threads.py +++ b/backend/app/api/threads.py @@ -1,3 +1,4 @@ +import asyncio from typing import Annotated, Any, Dict, List, Sequence, Union from uuid import uuid4 @@ -6,7 +7,8 @@ from pydantic import BaseModel, Field import app.storage as storage -from app.schema import OpengptsUserId, Thread +from app.auth.handlers import AuthedUser +from app.schema import Thread router = APIRouter() @@ -28,46 +30,61 @@ class ThreadPostRequest(BaseModel): @router.get("/") -async def list_threads(opengpts_user_id: OpengptsUserId) -> List[Thread]: +async def list_threads(user: AuthedUser) -> List[Thread]: """List all threads for the current user.""" - return await storage.list_threads(opengpts_user_id) + return await storage.list_threads(user["user_id"]) @router.get("/{tid}/state") async def get_thread_state( - opengpts_user_id: OpengptsUserId, + user: AuthedUser, tid: ThreadID, ): """Get state for a thread.""" - return await storage.get_thread_state(opengpts_user_id, tid) + thread, state = await asyncio.gather( + storage.get_thread(user["user_id"], tid), + storage.get_thread_state(user["user_id"], tid), + ) + if not thread: + raise HTTPException(status_code=404, detail="Thread not found") + return state @router.post("/{tid}/state") async def add_thread_state( - opengpts_user_id: OpengptsUserId, + user: AuthedUser, tid: ThreadID, payload: 
ThreadPostRequest, ): """Add state to a thread.""" - return await storage.update_thread_state(opengpts_user_id, tid, payload.values) + thread = await storage.get_thread(user["user_id"], tid) + if not thread: + raise HTTPException(status_code=404, detail="Thread not found") + return await storage.update_thread_state(user["user_id"], tid, payload.values) @router.get("/{tid}/history") async def get_thread_history( - opengpts_user_id: OpengptsUserId, + user: AuthedUser, tid: ThreadID, ): """Get all past states for a thread.""" - return await storage.get_thread_history(opengpts_user_id, tid) + thread, history = await asyncio.gather( + storage.get_thread(user["user_id"], tid), + storage.get_thread_history(user["user_id"], tid), + ) + if not thread: + raise HTTPException(status_code=404, detail="Thread not found") + return history @router.get("/{tid}") async def get_thread( - opengpts_user_id: OpengptsUserId, + user: AuthedUser, tid: ThreadID, ) -> Thread: """Get a thread by ID.""" - thread = await storage.get_thread(opengpts_user_id, tid) + thread = await storage.get_thread(user["user_id"], tid) if not thread: raise HTTPException(status_code=404, detail="Thread not found") return thread @@ -75,12 +92,12 @@ async def get_thread( @router.post("") async def create_thread( - opengpts_user_id: OpengptsUserId, + user: AuthedUser, thread_put_request: ThreadPutRequest, ) -> Thread: """Create a thread.""" return await storage.put_thread( - opengpts_user_id, + user["user_id"], str(uuid4()), assistant_id=thread_put_request.assistant_id, name=thread_put_request.name, @@ -89,13 +106,13 @@ async def create_thread( @router.put("/{tid}") async def upsert_thread( - opengpts_user_id: OpengptsUserId, + user: AuthedUser, tid: ThreadID, thread_put_request: ThreadPutRequest, ) -> Thread: """Update a thread.""" return await storage.put_thread( - opengpts_user_id, + user["user_id"], tid, assistant_id=thread_put_request.assistant_id, name=thread_put_request.name, diff --git 
a/backend/app/auth/__init__.py b/backend/app/auth/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/app/auth/handlers.py b/backend/app/auth/handlers.py new file mode 100644 index 00000000..630d45ff --- /dev/null +++ b/backend/app/auth/handlers.py @@ -0,0 +1,120 @@ +from abc import ABC, abstractmethod +from functools import lru_cache +from typing import Annotated + +import jwt +import requests +from fastapi import Depends, HTTPException, Request +from fastapi.security.http import HTTPBearer + +import app.storage as storage +from app.auth.settings import AuthType, settings +from app.schema import User + + +class AuthHandler(ABC): + @abstractmethod + async def __call__(self, request: Request) -> User: + """Auth handler that returns a user object or raises an HTTPException.""" + + +class NOOPAuth(AuthHandler): + _default_sub = "static-default-user-id" + + async def __call__(self, request: Request) -> User: + sub = request.cookies.get("opengpts_user_id") or self._default_sub + user, _ = await storage.get_or_create_user(sub) + return user + + +class JWTAuthBase(AuthHandler): + async def __call__(self, request: Request) -> User: + http_bearer = await HTTPBearer()(request) + token = http_bearer.credentials + + try: + payload = self.decode_token(token, self.get_decode_key(token)) + except jwt.PyJWTError as e: + raise HTTPException(status_code=401, detail=str(e)) + + user, _ = await storage.get_or_create_user(payload["sub"]) + return user + + @abstractmethod + def decode_token(self, token: str, decode_key: str) -> dict: + ... + + @abstractmethod + def get_decode_key(self, token: str) -> str: + ... 
+ + +class JWTAuthLocal(JWTAuthBase): + """Auth handler that uses a hardcoded decode key from env.""" + + def decode_token(self, token: str, decode_key: str) -> dict: + return jwt.decode( + token, + decode_key, + issuer=settings.jwt_local.iss, + audience=settings.jwt_local.aud, + algorithms=[settings.jwt_local.alg.upper()], + options={"require": ["exp", "iss", "aud", "sub"]}, + ) + + def get_decode_key(self, token: str) -> str: + return settings.jwt_local.decode_key + + +class JWTAuthOIDC(JWTAuthBase): + """Auth handler that uses OIDC discovery to get the decode key.""" + + def decode_token(self, token: str, decode_key: str) -> dict: + alg = self._decode_complete_unverified(token)["header"]["alg"] + return jwt.decode( + token, + decode_key, + issuer=settings.jwt_oidc.iss, + audience=settings.jwt_oidc.aud, + algorithms=[alg.upper()], + options={"require": ["exp", "iss", "aud", "sub"]}, + ) + + def get_decode_key(self, token: str) -> str: + unverified = self._decode_complete_unverified(token) + issuer = unverified["payload"].get("iss") + kid = unverified["header"].get("kid") + return self._get_jwk_client(issuer).get_signing_key(kid).key + + @lru_cache + def _decode_complete_unverified(self, token: str) -> dict: + return jwt.api_jwt.decode_complete(token, options={"verify_signature": False}) + + @lru_cache + def _get_jwk_client(self, issuer: str) -> jwt.PyJWKClient: + """ + lru_cache ensures a single instance of PyJWKClient per issuer. This is + so that we can take advantage of jwks caching (and invalidation) handled + by PyJWKClient. 
+ """ + url = issuer.rstrip("/") + "/.well-known/openid-configuration" + config = requests.get(url).json() + return jwt.PyJWKClient(config["jwks_uri"], cache_jwk_set=True) + + +@lru_cache(maxsize=1) +def get_auth_handler() -> AuthHandler: + if settings.auth_type == AuthType.JWT_LOCAL: + return JWTAuthLocal() + elif settings.auth_type == AuthType.JWT_OIDC: + return JWTAuthOIDC() + return NOOPAuth() + + +async def auth_user( + request: Request, auth_handler: AuthHandler = Depends(get_auth_handler) +): + return await auth_handler(request) + + +AuthedUser = Annotated[User, Depends(auth_user)] diff --git a/backend/app/auth/settings.py b/backend/app/auth/settings.py new file mode 100644 index 00000000..f41a260d --- /dev/null +++ b/backend/app/auth/settings.py @@ -0,0 +1,71 @@ +import os +from base64 import b64decode +from enum import Enum +from typing import Optional, Union + +from pydantic import BaseSettings, root_validator, validator + + +class AuthType(Enum): + NOOP = "noop" + JWT_LOCAL = "jwt_local" + JWT_OIDC = "jwt_oidc" + + +class JWTSettingsBase(BaseSettings): + iss: str + aud: Union[str, list[str]] + + @validator("aud", pre=True, always=True) + def set_aud(cls, v, values) -> Union[str, list[str]]: + return v.split(",") if "," in v else v + + class Config: + env_prefix = "jwt_" + + +class JWTSettingsLocal(JWTSettingsBase): + decode_key_b64: str + decode_key: str = None + alg: str + + @validator("decode_key", pre=True, always=True) + def set_decode_key(cls, v, values): + """ + Key may be a multiline string (e.g. in the case of a public key), so to + be able to set it from env, we set it as a base64 encoded string and + decode it here. + """ + return b64decode(values["decode_key_b64"]).decode("utf-8") + + +class JWTSettingsOIDC(JWTSettingsBase): + ... 
+ + +class Settings(BaseSettings): + auth_type: AuthType + jwt_local: Optional[JWTSettingsLocal] = None + jwt_oidc: Optional[JWTSettingsOIDC] = None + + @root_validator(pre=True) + def check_jwt_settings(cls, values): + auth_type = values.get("auth_type") + if auth_type == AuthType.JWT_LOCAL and values.get("jwt_local") is None: + raise ValueError( + "jwt local settings must be set when auth type is jwt_local." + ) + if auth_type == AuthType.JWT_OIDC and values.get("jwt_oidc") is None: + raise ValueError( + "jwt oidc settings must be set when auth type is jwt_oidc." + ) + return values + + +auth_type = AuthType(os.getenv("AUTH_TYPE", AuthType.NOOP.value).lower()) +kwargs = {"auth_type": auth_type} +if auth_type == AuthType.JWT_LOCAL: + kwargs["jwt_local"] = JWTSettingsLocal() +elif auth_type == AuthType.JWT_OIDC: + kwargs["jwt_oidc"] = JWTSettingsOIDC() +settings = Settings(**kwargs) diff --git a/backend/app/lifespan.py b/backend/app/lifespan.py index 1b685671..8e15f139 100644 --- a/backend/app/lifespan.py +++ b/backend/app/lifespan.py @@ -19,6 +19,9 @@ async def _init_connection(conn) -> None: decoder=orjson.loads, schema="pg_catalog", ) + await conn.set_type_codec( + "uuid", encoder=lambda v: str(v), decoder=lambda v: v, schema="pg_catalog" + ) @asynccontextmanager diff --git a/backend/app/schema.py b/backend/app/schema.py index 0c0e5923..4b5153a3 100644 --- a/backend/app/schema.py +++ b/backend/app/schema.py @@ -1,15 +1,22 @@ from datetime import datetime -from typing import Annotated, Optional -from uuid import UUID +from typing import Optional -from fastapi import Cookie from typing_extensions import TypedDict +class User(TypedDict): + user_id: str + """The ID of the user.""" + sub: str + """The sub of the user (from a JWT token).""" + created_at: datetime + """The time the user was created.""" + + class Assistant(TypedDict): """Assistant model.""" - assistant_id: UUID + assistant_id: str """The ID of the assistant.""" user_id: str """The ID of the user that 
owns the assistant.""" @@ -24,25 +31,13 @@ class Assistant(TypedDict): class Thread(TypedDict): - thread_id: UUID + thread_id: str """The ID of the thread.""" user_id: str """The ID of the user that owns the thread.""" - assistant_id: Optional[UUID] + assistant_id: Optional[str] """The assistant that was used in conjunction with this thread.""" name: str """The name of the thread.""" updated_at: datetime """The last time the thread was updated.""" - - -OpengptsUserId = Annotated[ - str, - Cookie( - description=( - "A cookie that identifies the user. This is not an authentication " - "mechanism that should be used in an actual production environment that " - "contains sensitive information." - ) - ), -] diff --git a/backend/app/server.py b/backend/app/server.py index 554af070..a8978da4 100644 --- a/backend/app/server.py +++ b/backend/app/server.py @@ -4,9 +4,12 @@ import orjson from fastapi import FastAPI, Form, UploadFile +from fastapi.exceptions import HTTPException from fastapi.staticfiles import StaticFiles +import app.storage as storage from app.api import router as api_router +from app.auth.handlers import AuthedUser from app.lifespan import lifespan from app.upload import ingest_runnable @@ -23,9 +26,24 @@ @app.post("/ingest", description="Upload files to the given assistant.") -def ingest_files(files: list[UploadFile], config: str = Form(...)) -> None: +async def ingest_files( + files: list[UploadFile], user: AuthedUser, config: str = Form(...) 
+) -> None: """Ingest a list of files.""" config = orjson.loads(config) + + assistant_id = config["configurable"].get("assistant_id") + if assistant_id is not None: + assistant = await storage.get_assistant(user["user_id"], assistant_id) + if assistant is None: + raise HTTPException(status_code=404, detail="Assistant not found.") + + thread_id = config["configurable"].get("thread_id") + if thread_id is not None: + thread = await storage.get_thread(user["user_id"], thread_id) + if thread is None: + raise HTTPException(status_code=404, detail="Thread not found.") + return ingest_runnable.batch([file.file for file in files], config) @@ -39,7 +57,7 @@ async def health() -> dict: if os.path.exists(ui_dir): app.mount("", StaticFiles(directory=ui_dir, html=True), name="ui") else: - logger.warning("No UI directory found, serving API only.") + logger.warn("No UI directory found, serving API only.") if __name__ == "__main__": import uvicorn diff --git a/backend/app/storage.py b/backend/app/storage.py index 3c694aed..5fa86a86 100644 --- a/backend/app/storage.py +++ b/backend/app/storage.py @@ -5,7 +5,7 @@ from app.agent import AgentType, get_agent_executor from app.lifespan import get_pg_pool -from app.schema import Assistant, Thread +from app.schema import Assistant, Thread, User async def list_assistants(user_id: str) -> List[Assistant]: @@ -160,3 +160,14 @@ async def put_thread( "name": name, "updated_at": updated_at, } + + +async def get_or_create_user(sub: str) -> tuple[User, bool]: + """Returns a tuple of the user and a boolean indicating whether the user was created.""" + async with get_pg_pool().acquire() as conn: + if user := await conn.fetchrow('SELECT * FROM "user" WHERE sub = $1', sub): + return user, False + user = await conn.fetchrow( + 'INSERT INTO "user" (sub) VALUES ($1) RETURNING *', sub + ) + return user, True diff --git a/backend/migrations/000003_create_user.down.sql b/backend/migrations/000003_create_user.down.sql new file mode 100644 index 
00000000..66c5acad --- /dev/null +++ b/backend/migrations/000003_create_user.down.sql @@ -0,0 +1,9 @@ +ALTER TABLE assistant + DROP CONSTRAINT fk_assistant_user_id, + ALTER COLUMN user_id TYPE VARCHAR USING (user_id::text); + +ALTER TABLE thread + DROP CONSTRAINT fk_thread_user_id, + ALTER COLUMN user_id TYPE VARCHAR USING (user_id::text); + +DROP TABLE IF EXISTS "user"; \ No newline at end of file diff --git a/backend/migrations/000003_create_user.up.sql b/backend/migrations/000003_create_user.up.sql new file mode 100644 index 00000000..45612f9c --- /dev/null +++ b/backend/migrations/000003_create_user.up.sql @@ -0,0 +1,25 @@ +CREATE TABLE IF NOT EXISTS "user" ( + user_id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + sub VARCHAR(255) UNIQUE NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT (CURRENT_TIMESTAMP AT TIME ZONE 'UTC') +); + +INSERT INTO "user" (user_id, sub) +SELECT DISTINCT user_id::uuid, user_id +FROM assistant +WHERE user_id IS NOT NULL +ON CONFLICT (user_id) DO NOTHING; + +INSERT INTO "user" (user_id, sub) +SELECT DISTINCT user_id::uuid, user_id +FROM thread +WHERE user_id IS NOT NULL +ON CONFLICT (user_id) DO NOTHING; + +ALTER TABLE assistant + ALTER COLUMN user_id TYPE UUID USING (user_id::UUID), + ADD CONSTRAINT fk_assistant_user_id FOREIGN KEY (user_id) REFERENCES "user"(user_id); + +ALTER TABLE thread + ALTER COLUMN user_id TYPE UUID USING (user_id::UUID), + ADD CONSTRAINT fk_thread_user_id FOREIGN KEY (user_id) REFERENCES "user"(user_id); \ No newline at end of file diff --git a/backend/poetry.lock b/backend/poetry.lock index 58ebd11e..3fc44196 100644 --- a/backend/poetry.lock +++ b/backend/poetry.lock @@ -4168,4 +4168,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.9.0,<3.12" -content-hash = "4147c6ecd1944b56165dfa51c86bd4f05b1074a4cf88cbb79d9b4db9a1fdc3f7" +content-hash = "027126e2f06254070ba7fbcc6244eb5eb08eaf7caa4130f773abadb24ec584db" diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 
d99ebbb0..3f763819 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -43,6 +43,7 @@ pgvector = "^0.2.5" psycopg2-binary = "^2.9.9" asyncpg = "^0.29.0" langchain-core = "^0.1.39" +pyjwt = {extras = ["crypto"], version = "^2.8.0"} [tool.poetry.group.dev.dependencies] uvicorn = "^0.23.2" diff --git a/backend/tests/unit_tests/app/helpers.py b/backend/tests/unit_tests/app/helpers.py new file mode 100644 index 00000000..1eb46ff5 --- /dev/null +++ b/backend/tests/unit_tests/app/helpers.py @@ -0,0 +1,13 @@ +from contextlib import asynccontextmanager + +from httpx import AsyncClient +from typing_extensions import AsyncGenerator + + +@asynccontextmanager +async def get_client() -> AsyncGenerator[AsyncClient, None]: + """Get the app.""" + from app.server import app + + async with AsyncClient(app=app, base_url="http://test") as ac: + yield ac diff --git a/backend/tests/unit_tests/app/test_app.py b/backend/tests/unit_tests/app/test_app.py index 3979c1e5..ba54f280 100644 --- a/backend/tests/unit_tests/app/test_app.py +++ b/backend/tests/unit_tests/app/test_app.py @@ -1,21 +1,11 @@ """Test the server and client together.""" -from contextlib import asynccontextmanager from typing import Optional, Sequence from uuid import uuid4 import asyncpg -from httpx import AsyncClient -from typing_extensions import AsyncGenerator - -@asynccontextmanager -async def get_client() -> AsyncGenerator[AsyncClient, None]: - """Get the app.""" - from app.server import app - - async with AsyncClient(app=app, base_url="http://test") as ac: - yield ac +from tests.unit_tests.app.helpers import get_client def _project(d: dict, *, exclude_keys: Optional[Sequence[str]]) -> dict: @@ -48,24 +38,24 @@ async def test_list_and_create_assistants(pool: asyncpg.pool.Pool) -> None: headers=headers, ) assert response.status_code == 200 - assert _project(response.json(), exclude_keys=["updated_at"]) == { + assert _project(response.json(), exclude_keys=["updated_at", "user_id"]) == { "assistant_id": 
aid, "config": {}, "name": "bobby", "public": False, - "user_id": "1", } async with pool.acquire() as conn: assert len(await conn.fetch("SELECT * FROM assistant;")) == 1 response = await client.get("/assistants/", headers=headers) - assert [_project(d, exclude_keys=["updated_at"]) for d in response.json()] == [ + assert [ + _project(d, exclude_keys=["updated_at", "user_id"]) for d in response.json() + ] == [ { "assistant_id": aid, "config": {}, "name": "bobby", "public": False, - "user_id": "1", } ] @@ -75,12 +65,11 @@ async def test_list_and_create_assistants(pool: asyncpg.pool.Pool) -> None: headers=headers, ) - assert _project(response.json(), exclude_keys=["updated_at"]) == { + assert _project(response.json(), exclude_keys=["updated_at", "user_id"]) == { "assistant_id": aid, "config": {}, "name": "bobby", "public": False, - "user_id": "1", } # Check not visible to other users @@ -117,29 +106,12 @@ async def test_threads() -> None: response = await client.get("/threads/", headers=headers) assert response.status_code == 200 - assert [_project(d, exclude_keys=["updated_at"]) for d in response.json()] == [ - { - "assistant_id": aid, - "name": "bobby", - "thread_id": tid, - "user_id": "1", - } - ] - - # Test a bad requests - response = await client.put( - f"/threads/{tid}", - json={"name": "bobby", "assistant_id": aid}, - ) - assert response.status_code == 422 + assert [ + _project(d, exclude_keys=["updated_at", "user_id"]) for d in response.json() + ] == [{"assistant_id": aid, "name": "bobby", "thread_id": tid}] response = await client.put( f"/threads/{tid}", headers={"Cookie": "opengpts_user_id=2"}, ) assert response.status_code == 422 - - response = await client.get( - "/threads/", - ) - assert response.status_code == 422 diff --git a/backend/tests/unit_tests/app/test_auth.py b/backend/tests/unit_tests/app/test_auth.py new file mode 100644 index 00000000..a3d0e449 --- /dev/null +++ b/backend/tests/unit_tests/app/test_auth.py @@ -0,0 +1,108 @@ +from base64 import 
b64encode +from datetime import datetime, timedelta, timezone +from typing import Optional +from unittest.mock import MagicMock, patch + +import jwt + +from app.auth.handlers import AuthedUser, get_auth_handler +from app.auth.settings import ( + AuthType, + JWTSettingsLocal, + JWTSettingsOIDC, +) +from app.auth.settings import ( + settings as auth_settings, +) +from app.server import app +from tests.unit_tests.app.helpers import get_client + + +@app.get("/me") +async def me(user: AuthedUser) -> dict: + return user + + +def _create_jwt( + key: str, alg: str, payload: dict, headers: Optional[dict] = None +) -> str: + return jwt.encode(payload, key, algorithm=alg, headers=headers) + + +async def test_noop(): + get_auth_handler.cache_clear() + auth_settings.auth_type = AuthType.NOOP + sub = "user_noop" + + async with get_client() as client: + response = await client.get("/me", cookies={"opengpts_user_id": sub}) + assert response.status_code == 200 + assert response.json()["sub"] == sub + + +async def test_jwt_local(): + get_auth_handler.cache_clear() + auth_settings.auth_type = AuthType.JWT_LOCAL + key = "key" + auth_settings.jwt_local = JWTSettingsLocal( + alg="HS256", + iss="issuer", + aud="audience", + decode_key_b64=b64encode(key.encode("utf-8")), + ) + sub = "user_jwt_local" + + token = _create_jwt( + key=key, + alg=auth_settings.jwt_local.alg, + payload={ + "sub": sub, + "iss": auth_settings.jwt_local.iss, + "aud": auth_settings.jwt_local.aud, + "exp": datetime.now(timezone.utc) + timedelta(days=1), + }, + ) + + async with get_client() as client: + response = await client.get("/me", headers={"Authorization": f"Bearer {token}"}) + assert response.status_code == 200 + assert response.json()["sub"] == sub + + # Test invalid token + async with get_client() as client: + response = await client.get("/me", headers={"Authorization": "Bearer xyz"}) + assert response.status_code == 401 + + +async def test_jwt_oidc(): + get_auth_handler.cache_clear() + 
auth_settings.auth_type = AuthType.JWT_OIDC + auth_settings.jwt_oidc = JWTSettingsOIDC(iss="issuer", aud="audience") + sub = "user_jwt_oidc" + key = "key" + alg = "HS256" + + token = _create_jwt( + key=key, + alg=alg, + payload={ + "sub": sub, + "iss": auth_settings.jwt_oidc.iss, + "aud": auth_settings.jwt_oidc.aud, + "exp": datetime.now(timezone.utc) + timedelta(days=1), + }, + headers={"kid": "kid", "alg": alg}, + ) + + mock_jwk_client = MagicMock() + mock_jwk_client.get_signing_key.return_value = MagicMock(key=key) + + with patch( + "app.auth.handlers.JWTAuthOIDC._get_jwk_client", return_value=mock_jwk_client + ): + async with get_client() as client: + response = await client.get( + "/me", headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code == 200 + assert response.json()["sub"] == sub diff --git a/backend/tests/unit_tests/conftest.py b/backend/tests/unit_tests/conftest.py index 1520a429..4d21da0d 100644 --- a/backend/tests/unit_tests/conftest.py +++ b/backend/tests/unit_tests/conftest.py @@ -5,9 +5,13 @@ import asyncpg import pytest +from app.auth.settings import AuthType +from app.auth.settings import settings as auth_settings from app.lifespan import get_pg_pool, lifespan from app.server import app +auth_settings.auth_type = AuthType.NOOP + # Temporary handling of environment variables for testing os.environ["OPENAI_API_KEY"] = "test" diff --git a/frontend/src/main.tsx b/frontend/src/main.tsx index 4a41abeb..f635343d 100644 --- a/frontend/src/main.tsx +++ b/frontend/src/main.tsx @@ -7,10 +7,27 @@ import { StrictMode } from "react"; import { QueryClient, QueryClientProvider } from "react-query"; import { NotFound } from "./components/NotFound.tsx"; -if (document.cookie.indexOf("user_id") === -1) { - document.cookie = `opengpts_user_id=${uuidv4()}; path=/; SameSite=Lax`; +function getCookie(name: string) { + const cookie = document.cookie + .split("; ") + .find((row) => row.startsWith(`${name}=`)); + return cookie ? 
cookie.split("=")[1] : null; } +document.addEventListener("DOMContentLoaded", () => { + const userId = + localStorage.getItem("opengpts_user_id") || + getCookie("opengpts_user_id") || + uuidv4(); + + // Push the user id to localStorage in any case to make it stable + localStorage.setItem("opengpts_user_id", userId); + // Ensure the cookie is always set (for both new and returning users) + const weekInMilliseconds = 7 * 24 * 60 * 60 * 1000; + const expires = new Date(Date.now() + weekInMilliseconds).toUTCString(); + document.cookie = `opengpts_user_id=${userId}; path=/; expires=${expires}; SameSite=Lax;`; +}); + const queryClient = new QueryClient(); ReactDOM.createRoot(document.getElementById("root")!).render(