Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Jun 18, 2024
1 parent f9a10b2 commit 41c87b6
Show file tree
Hide file tree
Showing 29 changed files with 152 additions and 706 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.PHONY: start

start:
cd backend && poetry run langgraph up -c ../langgraph.json -d ../compose.override.yml
cd backend && poetry run langgraph up -c ../langgraph.json -d ../compose.override.yml --postgres-uri 'postgres://postgres:postgres@langgraph-postgres:5432/postgres?sslmode=disable' --verbose
4 changes: 0 additions & 4 deletions backend/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,10 @@ TEST_FILE ?= tests/unit_tests/
start:
poetry run uvicorn app.server:app --reload --port 8100 --log-config log_config.json

migrate:
migrate -database postgres://$(POSTGRES_USER):$(POSTGRES_PASSWORD)@$(POSTGRES_HOST):$(POSTGRES_PORT)/$(POSTGRES_DB)?sslmode=disable -path ./migrations up

test:
# We need to update handling of env variables for tests
YDC_API_KEY=placeholder OPENAI_API_KEY=placeholder poetry run pytest $(TEST_FILE)


test_watch:
# We need to update handling of env variables for tests
YDC_API_KEY=placeholder OPENAI_API_KEY=placeholder poetry run ptw . -- $(TEST_FILE)
Expand Down
8 changes: 3 additions & 5 deletions backend/app/api/assistants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Annotated, List
from uuid import uuid4

from fastapi import APIRouter, HTTPException, Path
from pydantic import BaseModel, Field
Expand Down Expand Up @@ -52,23 +51,22 @@ async def create_assistant(
payload: AssistantPayload,
) -> Assistant:
"""Create an assistant."""
return await storage.put_assistant(
return await storage.create_assistant(
user["user_id"],
str(uuid4()),
name=payload.name,
config=payload.config,
public=payload.public,
)


@router.put("/{aid}")
@router.patch("/{aid}")
async def upsert_assistant(
user: AuthedUser,
aid: AssistantID,
payload: AssistantPayload,
) -> Assistant:
"""Create or update an assistant."""
return await storage.put_assistant(
return await storage.patch_assistant(
user["user_id"],
aid,
name=payload.name,
Expand Down
25 changes: 15 additions & 10 deletions backend/app/api/runs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import pathlib
from typing import Any, Dict, Optional, Sequence, Union
from uuid import UUID
Expand All @@ -13,7 +12,7 @@
from sse_starlette import EventSourceResponse

from app.auth.handlers import AuthedUser
from app.lifespan import get_langserve
from app.lifespan import get_api_client
from app.storage import get_assistant, get_thread

router = APIRouter()
Expand All @@ -33,7 +32,7 @@ async def create_run(payload: CreateRunPayload, user: AuthedUser):
thread = await get_thread(user["user_id"], payload.thread_id)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
return await get_langserve().runs.create(
return await get_api_client().runs.create(
payload.thread_id,
thread["assistant_id"],
input=payload.input,
Expand All @@ -53,21 +52,27 @@ async def stream_run(
assistant = await get_assistant(user["user_id"], thread["assistant_id"])
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
interrupt_before = (
["action"]
if assistant["config"]["configurable"].get(
"type==agent/interrupt_before_action"
)
else None
)

return EventSourceResponse(
(
{"event": e.event, "data": json.dumps(e.data)}
async for e in get_langserve().runs.stream(
{
"event": "data" if e.event.startswith("messages") else e.event,
"data": orjson.dumps(e.data).decode(),
}
async for e in get_api_client().runs.stream(
payload.thread_id,
thread["assistant_id"],
input=payload.input,
config=payload.config,
stream_mode="messages",
interrupt_before=["action"]
if assistant["config"]["configurable"].get(
"type==agent/interrupt_before_action"
)
else None,
interrupt_before=interrupt_before,
)
)
)
Expand Down
8 changes: 3 additions & 5 deletions backend/app/api/threads.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Annotated, Any, Dict, List, Optional, Sequence, Union
from uuid import uuid4

from fastapi import APIRouter, HTTPException, Path
from langchain.schema.messages import AnyMessage
Expand Down Expand Up @@ -136,22 +135,21 @@ async def create_thread(
thread_put_request: ThreadPutRequest,
) -> Thread:
"""Create a thread."""
return await storage.put_thread(
return await storage.create_thread(
user["user_id"],
str(uuid4()),
assistant_id=thread_put_request.assistant_id,
name=thread_put_request.name,
)


@router.put("/{tid}")
@router.patch("/{tid}")
async def upsert_thread(
user: AuthedUser,
tid: ThreadID,
thread_put_request: ThreadPutRequest,
) -> Thread:
"""Update a thread."""
return await storage.put_thread(
return await storage.patch_thread(
user["user_id"],
tid,
assistant_id=thread_put_request.assistant_id,
Expand Down
7 changes: 2 additions & 5 deletions backend/app/auth/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from fastapi import Depends, HTTPException, Request
from fastapi.security.http import HTTPBearer

import app.storage as storage
from app.auth.settings import AuthType, settings
from app.schema import User

Expand All @@ -23,8 +22,7 @@ class NOOPAuth(AuthHandler):

async def __call__(self, request: Request) -> User:
sub = request.cookies.get("opengpts_user_id") or self._default_sub
user, _ = await storage.get_or_create_user(sub)
return user
return User(user_id=sub, sub=sub)


class JWTAuthBase(AuthHandler):
Expand All @@ -37,8 +35,7 @@ async def __call__(self, request: Request) -> User:
except jwt.PyJWTError as e:
raise HTTPException(status_code=401, detail=str(e))

user, _ = await storage.get_or_create_user(payload["sub"])
return user
return User(user_id=payload["sub"], sub=payload["sub"])

@abstractmethod
def decode_token(self, token: str, decode_key: str) -> dict:
Expand Down
37 changes: 3 additions & 34 deletions backend/app/lifespan.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,17 @@
import os
from contextlib import asynccontextmanager

import asyncpg
import orjson
import structlog
from fastapi import FastAPI
from langgraph_sdk.client import LangServeClient, get_client
from langgraph_sdk.client import LangGraphClient, get_client

_pg_pool = None
_langserve = None


def get_pg_pool() -> asyncpg.pool.Pool:
return _pg_pool


def get_langserve() -> LangServeClient:
def get_api_client() -> LangGraphClient:
return _langserve


async def _init_connection(conn) -> None:
await conn.set_type_codec(
"json",
encoder=lambda v: orjson.dumps(v).decode(),
decoder=orjson.loads,
schema="pg_catalog",
)
await conn.set_type_codec(
"jsonb",
encoder=lambda v: orjson.dumps(v).decode(),
decoder=orjson.loads,
schema="pg_catalog",
)
await conn.set_type_codec(
"uuid", encoder=lambda v: str(v), decoder=lambda v: v, schema="pg_catalog"
)


@asynccontextmanager
async def lifespan(app: FastAPI):
structlog.configure(
Expand All @@ -52,15 +27,9 @@ async def lifespan(app: FastAPI):
cache_logger_on_first_use=True,
)

global _pg_pool, _langserve
global _langserve

_pg_pool = await asyncpg.create_pool(
os.environ["POSTGRES_URI"],
init=_init_connection,
)
_langserve = get_client(url=os.environ["LANGGRAPH_URL"])
yield
await _pg_pool.close()
await _langserve.http.client.aclose()
_pg_pool = None
_langserve = None
2 changes: 0 additions & 2 deletions backend/app/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ class User(TypedDict):
"""The ID of the user."""
sub: str
"""The sub of the user (from a JWT token)."""
created_at: datetime
"""The time the user was created."""


class Assistant(TypedDict):
Expand Down
Loading

0 comments on commit 41c87b6

Please sign in to comment.