Introduce auth #287

Merged
merged 33 commits into from
Apr 12, 2024
1fd391a
Revert changes to .py files from commit #264
mkorpela Apr 4, 2024
29726de
Merge branch 'langchain-ai:main' into main
bakar-io Apr 5, 2024
22042fd
Merge branch 'langchain-ai:main' into main
bakar-io Apr 7, 2024
3453ed4
Merge branch 'langchain-ai:main' into main
bakar-io Apr 9, 2024
f997382
Merge branch 'langchain-ai:main' into main
bakar-io Apr 9, 2024
ffcbd99
Convert UUIDs to str.
bakar-io Apr 7, 2024
6202984
Create user table.
bakar-io Apr 7, 2024
aa0d193
Add pyjwt dependency.
bakar-io Apr 7, 2024
b4d23f4
Add auth settings.
bakar-io Apr 8, 2024
8766b1d
Add auth handlers.
bakar-io Apr 8, 2024
b2eee3d
Require auth on main endpoints.
bakar-io Apr 8, 2024
68471ff
Require auth on the ingest endpoint.
bakar-io Apr 8, 2024
f8bccb0
Refactor tests to account for new auth mechanism.
bakar-io Apr 8, 2024
918624f
Retain data between up and down migrations.
bakar-io Apr 8, 2024
909f9a6
Minor fix.
bakar-io Apr 8, 2024
fc4a52b
NOOPAuth: more persistent and backwards compatible
mkorpela Apr 8, 2024
49c2296
Replace hardcoded algorithms when decoding OIDC JWTs.
bakar-io Apr 8, 2024
f35de5b
Enable indicating a list of audiences for JWT auth.
bakar-io Apr 8, 2024
81fad80
Add auth guide and add references to it in other guides.
bakar-io Apr 8, 2024
dac82d4
Simplify a storage method (get_or_create_user).
bakar-io Apr 9, 2024
efa8e94
Minor naming change.
bakar-io Apr 9, 2024
d83dad9
poetry lock --no-update
bakar-io Apr 9, 2024
f6e0f73
Update documentation on auth.
bakar-io Apr 9, 2024
238267d
Make alg configurable for JWTAuthLocal.
bakar-io Apr 9, 2024
6f38105
Minor change in readme.
bakar-io Apr 10, 2024
6712f38
Minor spacing fix in migration files.
bakar-io Apr 10, 2024
2ebb27c
langchain.document_loaders => langchain_community.document_loaders
mkorpela Apr 10, 2024
6e30450
Explain how to make authenticated requests when using JWT auth.
bakar-io Apr 10, 2024
d716ba4
Add auth tests.
bakar-io Apr 10, 2024
8f0c5fe
Add created_at column to the user table.
bakar-io Apr 11, 2024
74add18
Check thread ownership more efficiently.
bakar-io Apr 12, 2024
30ad78c
Merge remote-tracking branch 'upstream/main' into introduce-auth
bakar-io Apr 12, 2024
8e1a151
Minor fix - return docstring.
bakar-io Apr 12, 2024
5 changes: 4 additions & 1 deletion API.md
Original file line number Diff line number Diff line change
@@ -7,6 +7,9 @@ For full API documentation, see [localhost:8100/docs](localhost:8100/docs) after

If you want to see the API docs before deployment, check out the [hosted docs here](https://opengpts-example-vz4y4ooboq-uc.a.run.app/docs).

In the examples below, cookies are used as a mock auth method. For production, we recommend using JWT auth. Refer to the [auth guide for production](auth.md) for more information.
When using JWT auth, you will need to include the JWT in the `Authorization` header as a Bearer token.

## Create an Assistant

First, let's use the API to create an assistant.
@@ -20,7 +23,7 @@ requests.post('http://127.0.0.1:8100/assistants', json={
"public": True
}, cookies= {"opengpts_user_id": "foo"}).content
```
This is creating an assistant with name `"bar"`, with default configuration, that is public, and is associated with user `"foo"` (we are using cookies as a mock auth method).
This is creating an assistant with name `"bar"`, with default configuration, that is public, and is associated with user `"foo"`.

This should return something like:

1 change: 1 addition & 0 deletions README.md
@@ -31,6 +31,7 @@ Because this is open source, if you do not like those architectures or want to m

- [GPTs: a simple hosted version](https://opengpts-example-vz4y4ooboq-uc.a.run.app/)
- [Assistants API: a getting started guide](API.md)
- [Auth: a guide for production](auth.md)

## Quickstart with Docker

51 changes: 51 additions & 0 deletions auth.md
@@ -0,0 +1,51 @@
# Auth

By default, we're using cookies as a mock auth method. It's for trying out OpenGPTs.
For production, we recommend using JWT auth, outlined below.

## JWT Auth: Options

There are two ways to use JWT: Local and OIDC. The main difference is in how the key
used to decode the JWT is obtained. For the Local method, you'll provide the decode
key as a Base64-encoded string in an environment variable. For the OIDC method, the
key is obtained from the OIDC provider automatically.

### JWT OIDC

If you're looking to integrate with an identity provider, OIDC is the way to go.
It will figure out the decode key for you, so you don't have to worry about it.
Just set `AUTH_TYPE=jwt_oidc` along with the issuer and audience. Audience can
be one or many - just separate them with commas.

```bash
export AUTH_TYPE=jwt_oidc
export JWT_ISS=<issuer>
export JWT_AUD=<audience> # or <audience1>,<audience2>,...
```
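As a rough illustration of the comma-separated audience format described above, the value of `JWT_AUD` could be split into a list along these lines. The helper name `parse_audiences` is hypothetical and is not taken from this PR; the actual settings code may parse the value differently.

```python
from typing import List


def parse_audiences(raw: str) -> List[str]:
    """Split a comma-separated JWT_AUD value into a list of audiences.

    Hypothetical helper for illustration only.
    """
    # Strip surrounding whitespace and drop empty entries (e.g. a trailing comma).
    return [aud.strip() for aud in raw.split(",") if aud.strip()]


single = parse_audiences("audience1")
several = parse_audiences("audience1, audience2,audience3")
```

A single audience yields a one-element list, so downstream code can treat both cases the same way.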

### JWT Local

To use JWT Local, set `AUTH_TYPE=jwt_local`. Then, set the issuer, audience,
algorithm used to sign the JWT, and the decode key in Base64 format.

```bash
export AUTH_TYPE=jwt_local
export JWT_ISS=<issuer>
export JWT_AUD=<audience>
export JWT_ALG=<algorithm> # e.g. ES256
export JWT_DECODE_KEY_B64=<base64_decode_key>
```

Base64 is used for the decode key because handling multiline strings in environment
variables is error-prone. Base64 makes it a one-liner, easy to paste in and use.
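To show why Base64 helps here, the sketch below turns a multiline PEM-style key into a single-line string suitable for `JWT_DECODE_KEY_B64` and then recovers the original text. The function names and the placeholder key body are assumptions made for this example; they are not part of the PR.

```python
import base64


def encode_key_for_env(pem_text: str) -> str:
    """Return a single-line Base64 string for a multiline key (hypothetical helper)."""
    return base64.b64encode(pem_text.encode("utf-8")).decode("ascii")


def decode_key_from_env(value: str) -> str:
    """Recover the original key text from its Base64 form (hypothetical helper)."""
    return base64.b64decode(value).decode("utf-8")


# Placeholder text standing in for a real public key; not a valid key.
example_pem = (
    "-----BEGIN PUBLIC KEY-----\n"
    "PLACEHOLDER-KEY-MATERIAL\n"
    "-----END PUBLIC KEY-----\n"
)

encoded = encode_key_for_env(example_pem)
restored = decode_key_from_env(encoded)
```

The encoded value contains no newlines, so it can be pasted after `export JWT_DECODE_KEY_B64=` on one line.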


## Making Requests

To make authenticated requests, include the JWT in the `Authorization` header as a Bearer token:

```
Authorization: Bearer <JWT>
```
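A minimal sketch of attaching the header from Python, using only the standard library so that nothing is sent over the network. The URL is the local address used in API.md; the function name `build_authed_request` and the token value are placeholders chosen for this example.

```python
import urllib.request


def build_authed_request(url: str, token: str) -> urllib.request.Request:
    """Build (but do not send) a request that carries the JWT as a Bearer token."""
    return urllib.request.Request(
        url,
        headers={"Authorization": f"Bearer {token}"},
    )


# "<JWT>" is a placeholder; substitute a token issued for your deployment.
request = build_authed_request("http://127.0.0.1:8100/assistants", "<JWT>")
auth_header = request.get_header("Authorization")
```

The same header dictionary can be passed as `headers=` to `requests.post` or `requests.get` in place of the `cookies=` argument used in the mock-auth examples.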


19 changes: 10 additions & 9 deletions backend/app/api/assistants.py
@@ -5,7 +5,8 @@
from pydantic import BaseModel, Field

import app.storage as storage
from app.schema import Assistant, OpengptsUserId
from app.auth.handlers import AuthedUser
from app.schema import Assistant

router = APIRouter()

@@ -24,9 +25,9 @@ class AssistantPayload(BaseModel):


@router.get("/")
async def list_assistants(opengpts_user_id: OpengptsUserId) -> List[Assistant]:
async def list_assistants(user: AuthedUser) -> List[Assistant]:
"""List all assistants for the current user."""
return await storage.list_assistants(opengpts_user_id)
return await storage.list_assistants(user["user_id"])


@router.get("/public/")
@@ -43,24 +44,24 @@ async def list_public_assistants(

@router.get("/{aid}")
async def get_assistant(
opengpts_user_id: OpengptsUserId,
user: AuthedUser,
aid: AssistantID,
) -> Assistant:
"""Get an assistant by ID."""
assistant = await storage.get_assistant(opengpts_user_id, aid)
assistant = await storage.get_assistant(user["user_id"], aid)
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
return assistant


@router.post("")
async def create_assistant(
opengpts_user_id: OpengptsUserId,
user: AuthedUser,
payload: AssistantPayload,
) -> Assistant:
"""Create an assistant."""
return await storage.put_assistant(
opengpts_user_id,
user["user_id"],
str(uuid4()),
name=payload.name,
config=payload.config,
@@ -70,13 +71,13 @@ async def create_assistant(

@router.put("/{aid}")
async def upsert_assistant(
opengpts_user_id: OpengptsUserId,
user: AuthedUser,
aid: AssistantID,
payload: AssistantPayload,
) -> Assistant:
"""Create or update an assistant."""
return await storage.put_assistant(
opengpts_user_id,
user["user_id"],
aid,
name=payload.name,
config=payload.config,
20 changes: 9 additions & 11 deletions backend/app/api/runs.py
@@ -13,7 +13,7 @@
from sse_starlette import EventSourceResponse

from app.agent import agent
from app.schema import OpengptsUserId
from app.auth.handlers import AuthedUser
from app.storage import get_assistant, get_thread
from app.stream import astream_messages, to_sse

@@ -30,14 +30,12 @@ class CreateRunPayload(BaseModel):
config: Optional[RunnableConfig] = None


async def _run_input_and_config(
payload: CreateRunPayload, opengpts_user_id: OpengptsUserId
):
thread = await get_thread(opengpts_user_id, payload.thread_id)
async def _run_input_and_config(payload: CreateRunPayload, user_id: str):
thread = await get_thread(user_id, payload.thread_id)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")

assistant = await get_assistant(opengpts_user_id, str(thread["assistant_id"]))
assistant = await get_assistant(user_id, str(thread["assistant_id"]))
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")

@@ -46,7 +44,7 @@ async def _run_input_and_config(
"configurable": {
**assistant["config"]["configurable"],
**((payload.config or {}).get("configurable") or {}),
"user_id": opengpts_user_id,
"user_id": user_id,
"thread_id": str(thread["thread_id"]),
"assistant_id": str(assistant["assistant_id"]),
},
@@ -67,22 +65,22 @@ async def _run_input_and_config(
@router.post("")
async def create_run(
payload: CreateRunPayload,
opengpts_user_id: OpengptsUserId,
user: AuthedUser,
background_tasks: BackgroundTasks,
):
"""Create a run."""
input_, config = await _run_input_and_config(payload, opengpts_user_id)
input_, config = await _run_input_and_config(payload, user["user_id"])
background_tasks.add_task(agent.ainvoke, input_, config)
return {"status": "ok"} # TODO add a run id


@router.post("/stream")
async def stream_run(
payload: CreateRunPayload,
opengpts_user_id: OpengptsUserId,
user: AuthedUser,
):
"""Create a run."""
input_, config = await _run_input_and_config(payload, opengpts_user_id)
input_, config = await _run_input_and_config(payload, user["user_id"])

return EventSourceResponse(to_sse(astream_messages(agent, input_, config)))

47 changes: 32 additions & 15 deletions backend/app/api/threads.py
@@ -1,3 +1,4 @@
import asyncio
from typing import Annotated, Any, Dict, List, Sequence, Union
from uuid import uuid4

@@ -6,7 +7,8 @@
from pydantic import BaseModel, Field

import app.storage as storage
from app.schema import OpengptsUserId, Thread
from app.auth.handlers import AuthedUser
from app.schema import Thread

router = APIRouter()

@@ -28,59 +30,74 @@ class ThreadPostRequest(BaseModel):


@router.get("/")
async def list_threads(opengpts_user_id: OpengptsUserId) -> List[Thread]:
async def list_threads(user: AuthedUser) -> List[Thread]:
"""List all threads for the current user."""
return await storage.list_threads(opengpts_user_id)
return await storage.list_threads(user["user_id"])


@router.get("/{tid}/state")
async def get_thread_state(
opengpts_user_id: OpengptsUserId,
user: AuthedUser,
tid: ThreadID,
):
"""Get state for a thread."""
return await storage.get_thread_state(opengpts_user_id, tid)
thread, state = await asyncio.gather(
storage.get_thread(user["user_id"], tid),
storage.get_thread_state(user["user_id"], tid),
)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
return state


@router.post("/{tid}/state")
async def add_thread_state(
opengpts_user_id: OpengptsUserId,
user: AuthedUser,
tid: ThreadID,
payload: ThreadPostRequest,
):
"""Add state to a thread."""
return await storage.update_thread_state(opengpts_user_id, tid, payload.values)
thread = await storage.get_thread(user["user_id"], tid)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
return await storage.update_thread_state(user["user_id"], tid, payload.values)


@router.get("/{tid}/history")
async def get_thread_history(
opengpts_user_id: OpengptsUserId,
user: AuthedUser,
tid: ThreadID,
):
"""Get all past states for a thread."""
return await storage.get_thread_history(opengpts_user_id, tid)
thread, history = await asyncio.gather(
storage.get_thread(user["user_id"], tid),
storage.get_thread_history(user["user_id"], tid),
)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
return history


@router.get("/{tid}")
async def get_thread(
opengpts_user_id: OpengptsUserId,
user: AuthedUser,
tid: ThreadID,
) -> Thread:
"""Get a thread by ID."""
thread = await storage.get_thread(opengpts_user_id, tid)
thread = await storage.get_thread(user["user_id"], tid)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
return thread


@router.post("")
async def create_thread(
opengpts_user_id: OpengptsUserId,
user: AuthedUser,
thread_put_request: ThreadPutRequest,
) -> Thread:
"""Create a thread."""
return await storage.put_thread(
opengpts_user_id,
user["user_id"],
str(uuid4()),
assistant_id=thread_put_request.assistant_id,
name=thread_put_request.name,
@@ -89,13 +106,13 @@ async def create_thread(

@router.put("/{tid}")
async def upsert_thread(
opengpts_user_id: OpengptsUserId,
user: AuthedUser,
tid: ThreadID,
thread_put_request: ThreadPutRequest,
) -> Thread:
"""Update a thread."""
return await storage.put_thread(
opengpts_user_id,
user["user_id"],
tid,
assistant_id=thread_put_request.assistant_id,
name=thread_put_request.name,
Empty file added backend/app/auth/__init__.py
Empty file.
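The ownership check that this PR adds to `get_thread_state` and `get_thread_history` in backend/app/api/threads.py fetches the thread and its data concurrently, then rejects the request if the thread lookup comes back empty. The sketch below reproduces that pattern in isolation. The in-memory dictionaries, the `NotFound` exception, and the `get_state_checked` name are stand-ins invented for this example; the real handlers call `app.storage` and raise FastAPI's `HTTPException`.

```python
import asyncio
from typing import Any, Dict, Optional, Tuple

# Hypothetical in-memory data standing in for the storage layer.
THREADS: Dict[Tuple[str, str], Dict[str, Any]] = {
    ("alice", "t1"): {"thread_id": "t1", "user_id": "alice"},
}
STATES: Dict[str, Dict[str, Any]] = {"t1": {"values": ["hello"]}}


class NotFound(Exception):
    """Stand-in for HTTPException(status_code=404)."""


async def get_thread(user_id: str, thread_id: str) -> Optional[Dict[str, Any]]:
    # Returns the thread only if it belongs to the given user.
    return THREADS.get((user_id, thread_id))


async def get_thread_state(user_id: str, thread_id: str) -> Optional[Dict[str, Any]]:
    # Looks up state by thread ID alone, so ownership must be checked separately.
    return STATES.get(thread_id)


async def get_state_checked(user_id: str, thread_id: str) -> Optional[Dict[str, Any]]:
    # Run both lookups concurrently, then discard the state if the user
    # does not own the thread.
    thread, state = await asyncio.gather(
        get_thread(user_id, thread_id),
        get_thread_state(user_id, thread_id),
    )
    if not thread:
        raise NotFound("Thread not found")
    return state


owner_state = asyncio.run(get_state_checked("alice", "t1"))
```

With `asyncio.gather`, the state lookup still runs for a thread the caller does not own, but its result is never returned. The trade-off is one possibly wasted read in exchange for not waiting on two sequential round trips in the common case.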