diff --git a/Makefile b/Makefile index 83ba0144..054598af 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ .PHONY: start start: - cd backend && poetry run langgraph up -c ../langgraph.json -d ../compose.override.yml + cd backend && poetry run langgraph up -c ../langgraph.json -d ../compose.override.yml --postgres-uri 'postgres://postgres:postgres@langgraph-postgres:5432/postgres?sslmode=disable' --verbose diff --git a/backend/Makefile b/backend/Makefile index 3f24ce37..58d4dd63 100644 --- a/backend/Makefile +++ b/backend/Makefile @@ -16,14 +16,10 @@ TEST_FILE ?= tests/unit_tests/ start: poetry run uvicorn app.server:app --reload --port 8100 --log-config log_config.json -migrate: - migrate -database postgres://$(POSTGRES_USER):$(POSTGRES_PASSWORD)@$(POSTGRES_HOST):$(POSTGRES_PORT)/$(POSTGRES_DB)?sslmode=disable -path ./migrations up - test: # We need to update handling of env variables for tests YDC_API_KEY=placeholder OPENAI_API_KEY=placeholder poetry run pytest $(TEST_FILE) - test_watch: # We need to update handling of env variables for tests YDC_API_KEY=placeholder OPENAI_API_KEY=placeholder poetry run ptw . -- $(TEST_FILE) diff --git a/backend/app/api/assistants.py b/backend/app/api/assistants.py index b6f5ff0c..effb6a1e 100644 --- a/backend/app/api/assistants.py +++ b/backend/app/api/assistants.py @@ -1,5 +1,4 @@ from typing import Annotated, List -from uuid import uuid4 from fastapi import APIRouter, HTTPException, Path from pydantic import BaseModel, Field @@ -52,23 +51,22 @@ async def create_assistant( payload: AssistantPayload, ) -> Assistant: """Create an assistant.""" - return await storage.put_assistant( + return await storage.create_assistant( user["user_id"], - str(uuid4()), name=payload.name, config=payload.config, public=payload.public, ) -@router.put("/{aid}") +@router.patch("/{aid}") async def upsert_assistant( user: AuthedUser, aid: AssistantID, payload: AssistantPayload, ) -> Assistant: """Create or update an assistant.""" - return await storage.put_assistant( + return await storage.patch_assistant( user["user_id"], aid, name=payload.name, diff --git a/backend/app/api/runs.py b/backend/app/api/runs.py index d0a35132..5b3c0dde 100644 --- a/backend/app/api/runs.py +++ b/backend/app/api/runs.py @@ -1,4 +1,3 @@ -import json import pathlib from typing import Any, Dict, Optional, Sequence, Union from uuid import UUID @@ -13,7 +12,7 @@ from sse_starlette import EventSourceResponse from app.auth.handlers import AuthedUser -from app.lifespan import get_langserve +from app.lifespan import get_api_client from app.storage import get_assistant, get_thread router = APIRouter() @@ -33,7 +32,7 @@ async def create_run(payload: CreateRunPayload, user: AuthedUser): thread = await get_thread(user["user_id"], payload.thread_id) if not thread: raise HTTPException(status_code=404, detail="Thread not found") - return await get_langserve().runs.create( + return await get_api_client().runs.create( payload.thread_id, thread["assistant_id"], input=payload.input, @@ -53,21 +52,27 @@ async def stream_run( assistant = await get_assistant(user["user_id"], thread["assistant_id"]) if not assistant: raise HTTPException(status_code=404, detail="Assistant not found") + interrupt_before = ( + ["action"] + if assistant["config"]["configurable"].get( + "type==agent/interrupt_before_action" + ) + else None + ) return EventSourceResponse( ( - {"event": e.event, "data": json.dumps(e.data)} - async for e in get_langserve().runs.stream( + { + "event": "data" if e.event.startswith("messages") else e.event, + 
"data": orjson.dumps(e.data).decode(), + } + async for e in get_api_client().runs.stream( payload.thread_id, thread["assistant_id"], input=payload.input, config=payload.config, stream_mode="messages", - interrupt_before=["action"] - if assistant["config"]["configurable"].get( - "type==agent/interrupt_before_action" - ) - else None, + interrupt_before=interrupt_before, ) ) ) diff --git a/backend/app/api/threads.py b/backend/app/api/threads.py index 3edba78b..165cf5f8 100644 --- a/backend/app/api/threads.py +++ b/backend/app/api/threads.py @@ -1,5 +1,4 @@ from typing import Annotated, Any, Dict, List, Optional, Sequence, Union -from uuid import uuid4 from fastapi import APIRouter, HTTPException, Path from langchain.schema.messages import AnyMessage @@ -136,22 +135,21 @@ async def create_thread( thread_put_request: ThreadPutRequest, ) -> Thread: """Create a thread.""" - return await storage.put_thread( + return await storage.create_thread( user["user_id"], - str(uuid4()), assistant_id=thread_put_request.assistant_id, name=thread_put_request.name, ) -@router.put("/{tid}") +@router.patch("/{tid}") async def upsert_thread( user: AuthedUser, tid: ThreadID, thread_put_request: ThreadPutRequest, ) -> Thread: """Update a thread.""" - return await storage.put_thread( + return await storage.patch_thread( user["user_id"], tid, assistant_id=thread_put_request.assistant_id, diff --git a/backend/app/auth/handlers.py b/backend/app/auth/handlers.py index 630d45ff..09356a4b 100644 --- a/backend/app/auth/handlers.py +++ b/backend/app/auth/handlers.py @@ -7,7 +7,6 @@ from fastapi import Depends, HTTPException, Request from fastapi.security.http import HTTPBearer -import app.storage as storage from app.auth.settings import AuthType, settings from app.schema import User @@ -23,8 +22,7 @@ class NOOPAuth(AuthHandler): async def __call__(self, request: Request) -> User: sub = request.cookies.get("opengpts_user_id") or self._default_sub - user, _ = await storage.get_or_create_user(sub) - return user + return User(user_id=sub, sub=sub) class JWTAuthBase(AuthHandler): @@ -37,8 +35,7 @@ async def __call__(self, request: Request) -> User: except jwt.PyJWTError as e: raise HTTPException(status_code=401, detail=str(e)) - user, _ = await storage.get_or_create_user(payload["sub"]) - return user + return User(user_id=payload["sub"], sub=payload["sub"]) @abstractmethod def decode_token(self, token: str, decode_key: str) -> dict: diff --git a/backend/app/lifespan.py b/backend/app/lifespan.py index ed4bbfc0..27a86734 100644 --- a/backend/app/lifespan.py +++ b/backend/app/lifespan.py @@ -1,42 +1,17 @@ import os from contextlib import asynccontextmanager -import asyncpg -import orjson import structlog from fastapi import FastAPI -from langgraph_sdk.client import LangServeClient, get_client +from langgraph_sdk.client import LangGraphClient, get_client -_pg_pool = None _langserve = None -def get_pg_pool() -> asyncpg.pool.Pool: - return _pg_pool - - -def get_langserve() -> LangServeClient: +def get_api_client() -> LangGraphClient: return _langserve -async def _init_connection(conn) -> None: - await conn.set_type_codec( - "json", - encoder=lambda v: orjson.dumps(v).decode(), - decoder=orjson.loads, - schema="pg_catalog", - ) - await conn.set_type_codec( - "jsonb", - encoder=lambda v: orjson.dumps(v).decode(), - decoder=orjson.loads, - schema="pg_catalog", - ) - await conn.set_type_codec( - "uuid", encoder=lambda v: str(v), decoder=lambda v: v, schema="pg_catalog" - ) - - @asynccontextmanager async def lifespan(app: FastAPI): 
structlog.configure( @@ -52,15 +27,9 @@ async def lifespan(app: FastAPI): cache_logger_on_first_use=True, ) - global _pg_pool, _langserve + global _langserve - _pg_pool = await asyncpg.create_pool( - os.environ["POSTGRES_URI"], - init=_init_connection, - ) _langserve = get_client(url=os.environ["LANGGRAPH_URL"]) yield - await _pg_pool.close() await _langserve.http.client.aclose() - _pg_pool = None _langserve = None diff --git a/backend/app/schema.py b/backend/app/schema.py index 3ae6e595..a5c82744 100644 --- a/backend/app/schema.py +++ b/backend/app/schema.py @@ -9,8 +9,6 @@ class User(TypedDict): """The ID of the user.""" sub: str """The sub of the user (from a JWT token).""" - created_at: datetime - """The time the user was created.""" class Assistant(TypedDict): diff --git a/backend/app/storage.py b/backend/app/storage.py index ab91e9db..f874aa5f 100644 --- a/backend/app/storage.py +++ b/backend/app/storage.py @@ -1,15 +1,16 @@ from typing import Any, Dict, List, Optional, Sequence, Union +from fastapi import HTTPException from langchain_core.messages import AnyMessage from langchain_core.runnables import RunnableConfig -from app.lifespan import get_langserve, get_pg_pool -from app.schema import Assistant, Thread, User +from app.lifespan import get_api_client +from app.schema import Assistant, Thread async def list_assistants(user_id: str) -> List[Assistant]: """List all assistants for the current user.""" - assistants = await get_langserve().assistants.search( + assistants = await get_api_client().assistants.search( metadata={"user_id": user_id}, limit=100 ) return [ @@ -25,7 +26,7 @@ async def list_assistants(user_id: str) -> List[Assistant]: async def get_assistant(user_id: str, assistant_id: str) -> Optional[Assistant]: """Get an assistant by ID.""" - assistant = await get_langserve().assistants.get(assistant_id) + assistant = await get_api_client().assistants.get(assistant_id) if assistant["metadata"].get("user_id") != user_id and not assistant[ "metadata" ].get("public"): @@ -41,7 +42,7 @@ async def get_assistant(user_id: str, assistant_id: str) -> Optional[Assistant]: async def list_public_assistants() -> List[Assistant]: """List all the public assistants.""" - assistants = await get_langserve().assistants.search(metadata={"public": True}) + assistants = await get_api_client().assistants.search(metadata={"public": True}) return [ Assistant( assistant_id=a["assistant_id"], @@ -53,10 +54,10 @@ async def list_public_assistants() -> List[Assistant]: ] -async def put_assistant( - user_id: str, assistant_id: str, *, name: str, config: dict, public: bool = False +async def create_assistant( + user_id: str, *, name: str, config: dict, public: bool = False ) -> Assistant: - """Modify an assistant. + """Create an assistant. Args: user_id: The user ID. @@ -68,8 +69,7 @@ async def put_assistant( Returns: return the assistant model if no exception is raised. """ - assistant = await get_langserve().assistants.upsert( - assistant_id, + assistant = await get_api_client().assistants.create( config["configurable"]["type"], config, metadata={"user_id": user_id, "public": public, "name": name}, @@ -84,19 +84,48 @@ async def put_assistant( ) +async def patch_assistant( + user_id: str, assistant_id: str, *, name: str, config: dict, public: bool = False +) -> Assistant: + """Patch an assistant. + + Args: + user_id: The user ID. + assistant_id: The assistant ID. + name: The assistant name. + config: The assistant config. + public: Whether the assistant is public. 
+
+    Returns:
+        return the assistant model if no exception is raised.
+    """
+    assistant = await get_api_client().assistants.update(
+        assistant_id,
+        graph_id=config["configurable"]["type"],
+        config=config,
+        metadata={"user_id": user_id, "public": public, "name": name},
+    )
+    return Assistant(
+        assistant_id=assistant["assistant_id"],
+        updated_at=assistant["updated_at"],
+        config=assistant["config"],
+        name=name,
+        public=public,
+        user_id=user_id,
+    )
+
+
 async def delete_assistant(user_id: str, assistant_id: str) -> None:
     """Delete an assistant by ID."""
-    async with get_pg_pool().acquire() as conn:
-        await conn.execute(
-            "DELETE FROM assistant WHERE assistant_id = $1 AND user_id = $2",
-            assistant_id,
-            user_id,
-        )
+    assistant = await get_api_client().assistants.get(assistant_id)
+    if assistant["metadata"].get("user_id") != user_id:
+        raise HTTPException(status_code=404, detail="Assistant not found")
+    await get_api_client().assistants.delete(assistant_id)


 async def list_threads(user_id: str) -> List[Thread]:
     """List all threads for the current user."""
-    threads = await get_langserve().threads.search(
+    threads = await get_api_client().threads.search(
         metadata={"user_id": user_id}, limit=100
     )

@@ -115,7 +144,7 @@
 async def get_thread(user_id: str, thread_id: str) -> Optional[Thread]:
     """Get a thread by ID."""
-    thread = await get_langserve().threads.get(thread_id)
+    thread = await get_api_client().threads.get(thread_id)
     if thread["metadata"].get("user_id") != user_id:
         return None
     else:
@@ -131,7 +160,7 @@
 async def get_thread_state(*, user_id: str, thread_id: str, assistant: Assistant):
     """Get state for a thread."""
-    return await get_langserve().threads.get_state(thread_id)
+    return await get_api_client().threads.get_state(thread_id)


 async def update_thread_state(
@@ -145,7 +174,7 @@
     # thread_id (str) must be passed to update_state() instead of config
     # (dict) so that default configs are applied in LangGraph API.
thread_id = config["configurable"]["thread_id"] - return await get_langserve().threads.update_state(thread_id, values) + return await get_api_client().threads.update_state(thread_id, values) async def patch_thread_state( @@ -153,19 +182,34 @@ async def patch_thread_state( metadata: Dict[str, Any], ): """Patch state of a thread.""" - return await get_langserve().threads.patch_state(config, metadata) + return await get_api_client().threads.patch_state(config, metadata) async def get_thread_history(*, user_id: str, thread_id: str, assistant: Assistant): """Get the history of a thread.""" - return await get_langserve().threads.get_history(thread_id) + return await get_api_client().threads.get_history(thread_id) -async def put_thread( +async def create_thread(user_id: str, *, assistant_id: str, name: str) -> Thread: + """Modify a thread.""" + thread = await get_api_client().threads.create( + metadata={"user_id": user_id, "assistant_id": assistant_id, "name": name}, + ) + return Thread( + thread_id=thread["thread_id"], + user_id=thread["metadata"].pop("user_id"), + assistant_id=thread["metadata"].pop("assistant_id"), + name=thread["metadata"].pop("name"), + updated_at=thread["updated_at"], + metadata=thread["metadata"], + ) + + +async def patch_thread( user_id: str, thread_id: str, *, assistant_id: str, name: str ) -> Thread: """Modify a thread.""" - thread = await get_langserve().threads.upsert( + thread = await get_api_client().threads.update( thread_id, metadata={"user_id": user_id, "assistant_id": assistant_id, "name": name}, ) @@ -181,19 +225,7 @@ async def put_thread( async def delete_thread(user_id: str, thread_id: str): """Delete a thread by ID.""" - await get_langserve().threads.delete(thread_id) - - -async def get_or_create_user(sub: str) -> tuple[User, bool]: - """Returns a tuple of the user and a boolean indicating whether the user was created.""" - async with get_pg_pool().acquire() as conn: - if user := await conn.fetchrow('SELECT * FROM "user" WHERE sub = $1', sub): - return user, False - if user := await conn.fetchrow( - 'INSERT INTO "user" (sub) VALUES ($1) ON CONFLICT (sub) DO NOTHING RETURNING *', - sub, - ): - return user, True - if user := await conn.fetchrow('SELECT * FROM "user" WHERE sub = $1', sub): - return user, False - raise RuntimeError("User creation failed.") + thread = await get_api_client().threads.get(thread_id) + if thread["metadata"].get("user_id") != user_id: + raise HTTPException(status_code=404, detail="Thread not found") + await get_api_client().threads.delete(thread_id) diff --git a/backend/migrations/000001_create_extensions_and_first_tables.down.sql b/backend/migrations/000001_create_extensions_and_first_tables.down.sql deleted file mode 100644 index 08c8d5d5..00000000 --- a/backend/migrations/000001_create_extensions_and_first_tables.down.sql +++ /dev/null @@ -1,3 +0,0 @@ -DROP TABLE IF EXISTS thread; -DROP TABLE IF EXISTS assistant; -DROP TABLE IF EXISTS checkpoints; diff --git a/backend/migrations/000001_create_extensions_and_first_tables.up.sql b/backend/migrations/000001_create_extensions_and_first_tables.up.sql deleted file mode 100644 index cb395a74..00000000 --- a/backend/migrations/000001_create_extensions_and_first_tables.up.sql +++ /dev/null @@ -1,24 +0,0 @@ -CREATE EXTENSION IF NOT EXISTS vector; -CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; - -CREATE TABLE IF NOT EXISTS assistant ( - assistant_id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), - user_id VARCHAR(255) NOT NULL, - name VARCHAR(255) NOT NULL, - config JSON NOT NULL, - updated_at 
TIMESTAMP WITH TIME ZONE DEFAULT (CURRENT_TIMESTAMP AT TIME ZONE 'UTC'), - public BOOLEAN NOT NULL -); - -CREATE TABLE IF NOT EXISTS thread ( - thread_id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), - assistant_id UUID REFERENCES assistant(assistant_id) ON DELETE SET NULL, - user_id VARCHAR(255) NOT NULL, - name VARCHAR(255) NOT NULL, - updated_at TIMESTAMP WITH TIME ZONE DEFAULT (CURRENT_TIMESTAMP AT TIME ZONE 'UTC') -); - -CREATE TABLE IF NOT EXISTS checkpoints ( - thread_id TEXT PRIMARY KEY, - checkpoint BYTEA -); \ No newline at end of file diff --git a/backend/migrations/000002_checkpoints_update_schema.down.sql b/backend/migrations/000002_checkpoints_update_schema.down.sql deleted file mode 100644 index c8a249eb..00000000 --- a/backend/migrations/000002_checkpoints_update_schema.down.sql +++ /dev/null @@ -1,5 +0,0 @@ -ALTER TABLE checkpoints - DROP CONSTRAINT IF EXISTS checkpoints_pkey, - ADD PRIMARY KEY (thread_id), - DROP COLUMN IF EXISTS thread_ts, - DROP COLUMN IF EXISTS parent_ts; diff --git a/backend/migrations/000002_checkpoints_update_schema.up.sql b/backend/migrations/000002_checkpoints_update_schema.up.sql deleted file mode 100644 index 9ddd077f..00000000 --- a/backend/migrations/000002_checkpoints_update_schema.up.sql +++ /dev/null @@ -1,11 +0,0 @@ -ALTER TABLE checkpoints - ADD COLUMN IF NOT EXISTS thread_ts TIMESTAMPTZ, - ADD COLUMN IF NOT EXISTS parent_ts TIMESTAMPTZ; - -UPDATE checkpoints - SET thread_ts = CURRENT_TIMESTAMP AT TIME ZONE 'UTC' -WHERE thread_ts IS NULL; - -ALTER TABLE checkpoints - DROP CONSTRAINT IF EXISTS checkpoints_pkey, - ADD PRIMARY KEY (thread_id, thread_ts) diff --git a/backend/migrations/000003_create_user.down.sql b/backend/migrations/000003_create_user.down.sql deleted file mode 100644 index 66c5acad..00000000 --- a/backend/migrations/000003_create_user.down.sql +++ /dev/null @@ -1,9 +0,0 @@ -ALTER TABLE assistant - DROP CONSTRAINT fk_assistant_user_id, - ALTER COLUMN user_id TYPE VARCHAR USING (user_id::text); - -ALTER TABLE thread - DROP CONSTRAINT fk_thread_user_id, - ALTER COLUMN user_id TYPE VARCHAR USING (user_id::text); - -DROP TABLE IF EXISTS "user"; \ No newline at end of file diff --git a/backend/migrations/000003_create_user.up.sql b/backend/migrations/000003_create_user.up.sql deleted file mode 100644 index bf0ae7bc..00000000 --- a/backend/migrations/000003_create_user.up.sql +++ /dev/null @@ -1,25 +0,0 @@ -CREATE TABLE IF NOT EXISTS "user" ( - user_id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - sub VARCHAR(255) UNIQUE NOT NULL, - created_at TIMESTAMP WITH TIME ZONE DEFAULT (CURRENT_TIMESTAMP AT TIME ZONE 'UTC') -); - -INSERT INTO "user" (user_id, sub) -SELECT DISTINCT user_id::uuid, user_id -FROM assistant -WHERE user_id IS NOT NULL -ON CONFLICT (user_id) DO NOTHING; - -INSERT INTO "user" (user_id, sub) -SELECT DISTINCT user_id::uuid, user_id -FROM thread -WHERE user_id IS NOT NULL -ON CONFLICT (user_id) DO NOTHING; - -ALTER TABLE assistant - ALTER COLUMN user_id TYPE UUID USING (user_id::UUID), - ADD CONSTRAINT fk_assistant_user_id FOREIGN KEY (user_id) REFERENCES "user"(user_id); - -ALTER TABLE thread - ALTER COLUMN user_id TYPE UUID USING (user_id::UUID), - ADD CONSTRAINT fk_thread_user_id FOREIGN KEY (user_id) REFERENCES "user"(user_id); diff --git a/backend/migrations/000004_add_metadata_to_thread.down.sql b/backend/migrations/000004_add_metadata_to_thread.down.sql deleted file mode 100644 index 106fd0ba..00000000 --- a/backend/migrations/000004_add_metadata_to_thread.down.sql +++ /dev/null @@ -1,2 +0,0 @@ -ALTER TABLE 
thread -DROP COLUMN metadata; \ No newline at end of file diff --git a/backend/migrations/000004_add_metadata_to_thread.up.sql b/backend/migrations/000004_add_metadata_to_thread.up.sql deleted file mode 100644 index d0394582..00000000 --- a/backend/migrations/000004_add_metadata_to_thread.up.sql +++ /dev/null @@ -1,9 +0,0 @@ -ALTER TABLE thread -ADD COLUMN metadata JSONB; - -UPDATE thread -SET metadata = json_build_object( - 'assistant_type', (SELECT config->'configurable'->>'type' - FROM assistant - WHERE assistant.assistant_id = thread.assistant_id) -); \ No newline at end of file diff --git a/backend/poetry.lock b/backend/poetry.lock index 39a5ae88..c8f0246a 100644 --- a/backend/poetry.lock +++ b/backend/poetry.lock @@ -182,63 +182,6 @@ files = [ {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, ] -[[package]] -name = "asyncpg" -version = "0.29.0" -description = "An asyncio PostgreSQL driver" -optional = false -python-versions = ">=3.8.0" -files = [ - {file = "asyncpg-0.29.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:72fd0ef9f00aeed37179c62282a3d14262dbbafb74ec0ba16e1b1864d8a12169"}, - {file = "asyncpg-0.29.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:52e8f8f9ff6e21f9b39ca9f8e3e33a5fcdceaf5667a8c5c32bee158e313be385"}, - {file = "asyncpg-0.29.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9e6823a7012be8b68301342ba33b4740e5a166f6bbda0aee32bc01638491a22"}, - {file = "asyncpg-0.29.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:746e80d83ad5d5464cfbf94315eb6744222ab00aa4e522b704322fb182b83610"}, - {file = "asyncpg-0.29.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ff8e8109cd6a46ff852a5e6bab8b0a047d7ea42fcb7ca5ae6eaae97d8eacf397"}, - {file = "asyncpg-0.29.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:97eb024685b1d7e72b1972863de527c11ff87960837919dac6e34754768098eb"}, - {file = "asyncpg-0.29.0-cp310-cp310-win32.whl", hash = "sha256:5bbb7f2cafd8d1fa3e65431833de2642f4b2124be61a449fa064e1a08d27e449"}, - {file = "asyncpg-0.29.0-cp310-cp310-win_amd64.whl", hash = "sha256:76c3ac6530904838a4b650b2880f8e7af938ee049e769ec2fba7cd66469d7772"}, - {file = "asyncpg-0.29.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d4900ee08e85af01adb207519bb4e14b1cae8fd21e0ccf80fac6aa60b6da37b4"}, - {file = "asyncpg-0.29.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a65c1dcd820d5aea7c7d82a3fdcb70e096f8f70d1a8bf93eb458e49bfad036ac"}, - {file = "asyncpg-0.29.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b52e46f165585fd6af4863f268566668407c76b2c72d366bb8b522fa66f1870"}, - {file = "asyncpg-0.29.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc600ee8ef3dd38b8d67421359779f8ccec30b463e7aec7ed481c8346decf99f"}, - {file = "asyncpg-0.29.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:039a261af4f38f949095e1e780bae84a25ffe3e370175193174eb08d3cecab23"}, - {file = "asyncpg-0.29.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6feaf2d8f9138d190e5ec4390c1715c3e87b37715cd69b2c3dfca616134efd2b"}, - {file = "asyncpg-0.29.0-cp311-cp311-win32.whl", hash = "sha256:1e186427c88225ef730555f5fdda6c1812daa884064bfe6bc462fd3a71c4b675"}, - {file = "asyncpg-0.29.0-cp311-cp311-win_amd64.whl", hash = "sha256:cfe73ffae35f518cfd6e4e5f5abb2618ceb5ef02a2365ce64f132601000587d3"}, - {file = "asyncpg-0.29.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = 
"sha256:6011b0dc29886ab424dc042bf9eeb507670a3b40aece3439944006aafe023178"}, - {file = "asyncpg-0.29.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b544ffc66b039d5ec5a7454667f855f7fec08e0dfaf5a5490dfafbb7abbd2cfb"}, - {file = "asyncpg-0.29.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d84156d5fb530b06c493f9e7635aa18f518fa1d1395ef240d211cb563c4e2364"}, - {file = "asyncpg-0.29.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54858bc25b49d1114178d65a88e48ad50cb2b6f3e475caa0f0c092d5f527c106"}, - {file = "asyncpg-0.29.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:bde17a1861cf10d5afce80a36fca736a86769ab3579532c03e45f83ba8a09c59"}, - {file = "asyncpg-0.29.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:37a2ec1b9ff88d8773d3eb6d3784dc7e3fee7756a5317b67f923172a4748a175"}, - {file = "asyncpg-0.29.0-cp312-cp312-win32.whl", hash = "sha256:bb1292d9fad43112a85e98ecdc2e051602bce97c199920586be83254d9dafc02"}, - {file = "asyncpg-0.29.0-cp312-cp312-win_amd64.whl", hash = "sha256:2245be8ec5047a605e0b454c894e54bf2ec787ac04b1cb7e0d3c67aa1e32f0fe"}, - {file = "asyncpg-0.29.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0009a300cae37b8c525e5b449233d59cd9868fd35431abc470a3e364d2b85cb9"}, - {file = "asyncpg-0.29.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5cad1324dbb33f3ca0cd2074d5114354ed3be2b94d48ddfd88af75ebda7c43cc"}, - {file = "asyncpg-0.29.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:012d01df61e009015944ac7543d6ee30c2dc1eb2f6b10b62a3f598beb6531548"}, - {file = "asyncpg-0.29.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:000c996c53c04770798053e1730d34e30cb645ad95a63265aec82da9093d88e7"}, - {file = "asyncpg-0.29.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:e0bfe9c4d3429706cf70d3249089de14d6a01192d617e9093a8e941fea8ee775"}, - {file = "asyncpg-0.29.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:642a36eb41b6313ffa328e8a5c5c2b5bea6ee138546c9c3cf1bffaad8ee36dd9"}, - {file = "asyncpg-0.29.0-cp38-cp38-win32.whl", hash = "sha256:a921372bbd0aa3a5822dd0409da61b4cd50df89ae85150149f8c119f23e8c408"}, - {file = "asyncpg-0.29.0-cp38-cp38-win_amd64.whl", hash = "sha256:103aad2b92d1506700cbf51cd8bb5441e7e72e87a7b3a2ca4e32c840f051a6a3"}, - {file = "asyncpg-0.29.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5340dd515d7e52f4c11ada32171d87c05570479dc01dc66d03ee3e150fb695da"}, - {file = "asyncpg-0.29.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e17b52c6cf83e170d3d865571ba574577ab8e533e7361a2b8ce6157d02c665d3"}, - {file = "asyncpg-0.29.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f100d23f273555f4b19b74a96840aa27b85e99ba4b1f18d4ebff0734e78dc090"}, - {file = "asyncpg-0.29.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:48e7c58b516057126b363cec8ca02b804644fd012ef8e6c7e23386b7d5e6ce83"}, - {file = "asyncpg-0.29.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f9ea3f24eb4c49a615573724d88a48bd1b7821c890c2effe04f05382ed9e8810"}, - {file = "asyncpg-0.29.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8d36c7f14a22ec9e928f15f92a48207546ffe68bc412f3be718eedccdf10dc5c"}, - {file = "asyncpg-0.29.0-cp39-cp39-win32.whl", hash = "sha256:797ab8123ebaed304a1fad4d7576d5376c3a006a4100380fb9d517f0b59c1ab2"}, - {file = "asyncpg-0.29.0-cp39-cp39-win_amd64.whl", hash = "sha256:cce08a178858b426ae1aa8409b5cc171def45d4293626e7aa6510696d46decd8"}, - {file = "asyncpg-0.29.0.tar.gz", hash = 
"sha256:d1c49e1f44fffafd9a55e1a9b101590859d881d639ea2922516f5d9c512d354e"}, -] - -[package.dependencies] -async-timeout = {version = ">=4.0.3", markers = "python_version < \"3.12.0\""} - -[package.extras] -docs = ["Sphinx (>=5.3.0,<5.4.0)", "sphinx-rtd-theme (>=1.2.2)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"] -test = ["flake8 (>=6.1,<7.0)", "uvloop (>=0.15.3)"] - [[package]] name = "attrs" version = "23.2.0" @@ -863,24 +806,24 @@ files = [ test = ["pytest (>=6)"] [[package]] -name = "fastapi" -version = "0.103.2" +name = "fastapi-slim" +version = "0.111.0" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "fastapi-0.103.2-py3-none-any.whl", hash = "sha256:3270de872f0fe9ec809d4bd3d4d890c6d5cc7b9611d721d6438f9dacc8c4ef2e"}, - {file = "fastapi-0.103.2.tar.gz", hash = "sha256:75a11f6bfb8fc4d2bec0bd710c2d5f2829659c0e8c0afd5560fdda6ce25ec653"}, + {file = "fastapi_slim-0.111.0-py3-none-any.whl", hash = "sha256:6e4b04a555496e5a2590031fcae3ef8e364ad4901b340033e2e1d8136471aca2"}, + {file = "fastapi_slim-0.111.0.tar.gz", hash = "sha256:100720e4362ec4de97dee83a579b970e79fb5bf48073b37c9ce9b0e63dda4bec"}, ] [package.dependencies] -anyio = ">=3.7.1,<4.0.0" pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0" -starlette = ">=0.27.0,<0.28.0" -typing-extensions = ">=4.5.0" +starlette = ">=0.37.2,<0.38.0" +typing-extensions = ">=4.8.0" [package.extras] -all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +all = ["email_validator (>=2.0.0)", "fastapi-cli (>=0.0.2)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +standard = ["email_validator (>=2.0.0)", "fastapi-cli (>=0.0.2)", "httpx (>=0.23.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "python-multipart (>=0.0.7)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] [[package]] name = "feedparser" @@ -1999,13 +1942,13 @@ uuid6 = ">=2024.1.12,<2025.0.0" [[package]] name = "langgraph-cli" -version = "0.1.21" -description = "" +version = "0.1.36" +description = "CLI for interacting with LangGraph API" optional = false python-versions = "<4.0.0,>=3.9.0" files = [ - {file = "langgraph_cli-0.1.21-py3-none-any.whl", hash = "sha256:3e8d6df554bf8054ff7edd2a0a5a79854a8debbe642da2f07d3ecd5bee651267"}, - {file = "langgraph_cli-0.1.21.tar.gz", hash = "sha256:83af68c923425fbf25be6db76b412581f5bb93b4322d0c1df18ff2178f427415"}, + {file = "langgraph_cli-0.1.36-py3-none-any.whl", hash = "sha256:aa6cc4fc4d76e8235d5d30b33fe8049fca4982581a0fcc1c95cd177e3da77b2d"}, + {file = "langgraph_cli-0.1.36.tar.gz", hash = "sha256:a5e914fe4d7e48419ff0c569b5b792da59663884dfd1c7656c87310a28a55d05"}, ] [package.dependencies] @@ -2013,13 +1956,13 @@ click = ">=8.1.7,<9.0.0" [[package]] name = "langgraph-sdk" -version = "0.1.10" +version = "0.1.21" description = "" optional = false python-versions = "<4.0.0,>=3.9.0" files 
= [ - {file = "langgraph_sdk-0.1.10-py3-none-any.whl", hash = "sha256:66eaf85583deced783c2ae9f43548f4c0467a443f8d316745e335380343d7634"}, - {file = "langgraph_sdk-0.1.10.tar.gz", hash = "sha256:f4131c57e55d1ad82c23f09f9d187f74d219f01e56fc017928aea34ed15d9f6e"}, + {file = "langgraph_sdk-0.1.21-py3-none-any.whl", hash = "sha256:6b8e121efe5d6500d60002ed0e61bff6ce1f340c486af1834ba0df6a36e0f242"}, + {file = "langgraph_sdk-0.1.21.tar.gz", hash = "sha256:69b614d3b1d73d712088ad9216d3b3f79e543c038b912b59277b11807b9bd3e1"}, ] [package.dependencies] @@ -3674,30 +3617,32 @@ sqlcipher = ["sqlcipher3_binary"] [[package]] name = "sse-starlette" -version = "1.8.2" +version = "2.1.2" description = "SSE plugin for Starlette" optional = false python-versions = ">=3.8" files = [ - {file = "sse_starlette-1.8.2-py3-none-any.whl", hash = "sha256:70cc7ef5aca4abe8a25dec1284cce4fe644dd7bf0c406d3e852e516092b7f849"}, - {file = "sse_starlette-1.8.2.tar.gz", hash = "sha256:e0f9b8dec41adc092a0a6e0694334bd3cfd3084c44c497a6ebc1fb4bdd919acd"}, + {file = "sse_starlette-2.1.2-py3-none-any.whl", hash = "sha256:af7fbd2b307befcf59130ab5b9a8ad67d06c4dd76cf98d0b46ec6500286bdc89"}, + {file = "sse_starlette-2.1.2.tar.gz", hash = "sha256:b93035678d5c2c4a94bd34d3d5803636a32ee898943e68120796da9bb3bbb073"}, ] [package.dependencies] anyio = "*" -fastapi = "*" starlette = "*" uvicorn = "*" +[package.extras] +examples = ["fastapi"] + [[package]] name = "starlette" -version = "0.27.0" +version = "0.37.2" description = "The little ASGI library that shines." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "starlette-0.27.0-py3-none-any.whl", hash = "sha256:918416370e846586541235ccd38a474c08b80443ed31c578a418e2209b3eef91"}, - {file = "starlette-0.27.0.tar.gz", hash = "sha256:6a6b0d042acb8d469a01eba54e9cda6cbd24ac602c4cd016723117d6a7e73b75"}, + {file = "starlette-0.37.2-py3-none-any.whl", hash = "sha256:6fe59f29268538e5d0d182f2791a479a0c64638e6935d1c6989e63fb2699c6ee"}, + {file = "starlette-0.37.2.tar.gz", hash = "sha256:9af890290133b79fc3db55474ade20f6220a364a0402e0b556e7cd5e1e093823"}, ] [package.dependencies] @@ -3705,7 +3650,7 @@ anyio = ">=3.4.0,<5" typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} [package.extras] -full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyaml"] +full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"] [[package]] name = "structlog" @@ -4467,4 +4412,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.9.0,<3.12" -content-hash = "682c7736440532200f35e5d400c86b7568894b1d045c05d7a8be42171759d56c" +content-hash = "5cff3ebc38cdfc15aadf17e33db955600529f979f49a468a963668dfc606f0fb" diff --git a/backend/pyproject.toml b/backend/pyproject.toml index ac41fdc8..6e686b65 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -8,13 +8,10 @@ packages = [{include = "app"}] [tool.poetry.dependencies] python = "^3.9.0,<3.12" -sse-starlette = "^1.6.5" +sse-starlette = ">=2.1.0" tomli-w = "^1.0.0" -uvicorn = "^0.23.2" -fastapi = "^0.103.2" -# Uncomment if you need to work from a development branch -# This will only work for local development though! 
-# langchain = { git = "git@github.com:langchain-ai/langchain.git/", branch = "nc/subclass-runnable-binding" , subdirectory = "libs/langchain"} +uvicorn = ">=0.23.2" +fastapi-slim = ">=0.103.2" orjson = ">=3.9.10" python-multipart = "^0.0.6" langchain = ">=0.2.0" @@ -37,7 +34,6 @@ httpx = { version = ">=0.25.2", extras = ["socks"] } unstructured = {extras = ["doc", "docx"], version = "^0.12.5"} pgvector = "^0.2.5" psycopg2-binary = "^2.9.9" -asyncpg = "^0.29.0" pyjwt = {extras = ["crypto"], version = "^2.8.0"} langchain-anthropic = ">=0.1.8" structlog = "^24.1.0" diff --git a/backend/tests/unit_tests/app/test_app.py b/backend/tests/unit_tests/app/test_app.py index b84f887f..e5f6f4a4 100644 --- a/backend/tests/unit_tests/app/test_app.py +++ b/backend/tests/unit_tests/app/test_app.py @@ -3,8 +3,6 @@ from typing import Optional, Sequence from uuid import uuid4 -import asyncpg - from tests.unit_tests.app.helpers import get_client @@ -14,14 +12,11 @@ def _project(d: dict, *, exclude_keys: Optional[Sequence[str]]) -> dict: return {k: v for k, v in d.items() if k not in _exclude} -async def test_list_and_create_assistants(pool: asyncpg.pool.Pool) -> None: +async def test_list_and_create_assistants() -> None: """Test list and create assistants.""" headers = {"Cookie": "opengpts_user_id=1"} aid = str(uuid4()) - async with pool.acquire() as conn: - assert len(await conn.fetch("SELECT * FROM assistant;")) == 0 - async with get_client() as client: response = await client.get( "/assistants/", @@ -44,8 +39,6 @@ async def test_list_and_create_assistants(pool: asyncpg.pool.Pool) -> None: "name": "bobby", "public": False, } - async with pool.acquire() as conn: - assert len(await conn.fetch("SELECT * FROM assistant;")) == 1 response = await client.get("/assistants/", headers=headers) assert [ diff --git a/backend/tests/unit_tests/conftest.py b/backend/tests/unit_tests/conftest.py index 4d21da0d..4ee8ccfb 100644 --- a/backend/tests/unit_tests/conftest.py +++ b/backend/tests/unit_tests/conftest.py @@ -1,85 +1,16 @@ import asyncio import os -import subprocess -import asyncpg import pytest from app.auth.settings import AuthType from app.auth.settings import settings as auth_settings -from app.lifespan import get_pg_pool, lifespan -from app.server import app auth_settings.auth_type = AuthType.NOOP # Temporary handling of environment variables for testing os.environ["OPENAI_API_KEY"] = "test" -TEST_DB = "test" -assert os.environ["POSTGRES_DB"] != TEST_DB, "Test and main database conflict." 
-os.environ["POSTGRES_DB"] = TEST_DB - - -async def _get_conn() -> asyncpg.Connection: - return await asyncpg.connect( - user=os.environ["POSTGRES_USER"], - password=os.environ["POSTGRES_PASSWORD"], - host=os.environ["POSTGRES_HOST"], - port=os.environ["POSTGRES_PORT"], - database="postgres", - ) - - -async def _create_test_db() -> None: - """Check if the test database exists and create it if it doesn't.""" - conn = await _get_conn() - exists = await conn.fetchval("SELECT 1 FROM pg_database WHERE datname=$1", TEST_DB) - if not exists: - await conn.execute(f'CREATE DATABASE "{TEST_DB}"') - await conn.close() - - -async def _drop_test_db() -> None: - """Check if the test database exists and if so, drop it.""" - conn = await _get_conn() - exists = await conn.fetchval("SELECT 1 FROM pg_database WHERE datname=$1", TEST_DB) - if exists: - await conn.execute(f'DROP DATABASE "{TEST_DB}" WITH (FORCE)') - await conn.close() - - -def _migrate_test_db() -> None: - subprocess.run(["make", "migrate"], check=True) - - -@pytest.fixture(scope="session") -async def pool(): - await _drop_test_db() # In case previous test session was abruptly terminated - await _create_test_db() - _migrate_test_db() - async with lifespan(app): - yield get_pg_pool() - await _drop_test_db() - - -@pytest.fixture(scope="function", autouse=True) -async def clear_test_db(pool): - """Truncate all tables before each test.""" - async with pool.acquire() as conn: - query = """ - DO - $$ - DECLARE - r RECORD; - BEGIN - FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public') LOOP - EXECUTE 'TRUNCATE TABLE ' || quote_ident(r.tablename) || ' CASCADE;'; - END LOOP; - END - $$; - """ - await conn.execute(query) - @pytest.fixture(scope="session") def event_loop(request): diff --git a/compose.override.yml b/compose.override.yml index 60d02161..9e22b1cd 100644 --- a/compose.override.yml +++ b/compose.override.yml @@ -1,26 +1,26 @@ +volumes: + langgraph-data: + driver: local services: langgraph-api: environment: PGVECTOR_URI: "postgresql+psycopg2://postgres:postgres@langgraph-postgres:5432/postgres?sslmode=disable" langgraph-postgres: image: pgvector/pgvector:pg16 - postgres-setup: - image: migrate/migrate - depends_on: - langgraph-postgres: - condition: service_healthy + restart: on-failure + ports: + - "5433:5432" + environment: + POSTGRES_DB: postgres + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres volumes: - - ./backend/migrations:/migrations - env_file: - - .env - command: - [ - "-path", - "/migrations", - "-database", - "postgres://postgres:postgres@langgraph-postgres:5432/postgres?sslmode=disable", - "up", - ] + - langgraph-data:/var/lib/postgresql/data + healthcheck: + test: pg_isready -U postgres + start_period: 10s + timeout: 1s + retries: 5 backend: container_name: opengpts-backend pull_policy: build @@ -29,14 +29,13 @@ services: ports: - "8100:8000" # Backend is accessible on localhost:8100 depends_on: - postgres-setup: - condition: service_completed_successfully + langgraph-postgres: + condition: service_healthy env_file: - .env volumes: - ./backend:/backend environment: - POSTGRES_URI: "postgres://postgres:postgres@langgraph-postgres:5432/postgres?sslmode=disable" PGVECTOR_URI: "postgresql+psycopg2://postgres:postgres@langgraph-postgres:5432/postgres?sslmode=disable" LANGGRAPH_URL: "http://langgraph-api:8000" command: diff --git a/frontend/src/hooks/useChatList.ts b/frontend/src/hooks/useChatList.ts index 2c76b065..8853e2a1 100644 --- a/frontend/src/hooks/useChatList.ts +++ b/frontend/src/hooks/useChatList.ts 
@@ -61,7 +61,7 @@ export function useChatList(): ChatListProps { const updateChat = useCallback( async (thread_id: string, name: string, assistant_id: string | null) => { const response = await fetch(`/threads/${thread_id}`, { - method: "PUT", + method: "PATCH", body: JSON.stringify({ assistant_id, name }), headers: { "Content-Type": "application/json", diff --git a/frontend/src/hooks/useConfigList.ts b/frontend/src/hooks/useConfigList.ts index 548ce6a1..82ec627e 100644 --- a/frontend/src/hooks/useConfigList.ts +++ b/frontend/src/hooks/useConfigList.ts @@ -73,7 +73,7 @@ export function useConfigList(): ConfigListProps { const confResponse = await fetch( assistantId ? `/api/assistants/${assistantId}` : "/api/assistants", { - method: assistantId ? "PUT" : "POST", + method: assistantId ? "PATCH" : "POST", body: JSON.stringify({ name, config, public: isPublic }), headers: { "Content-Type": "application/json", diff --git a/frontend/src/hooks/useStreamState.tsx b/frontend/src/hooks/useStreamState.tsx index f79aa9c9..dafb24fd 100644 --- a/frontend/src/hooks/useStreamState.tsx +++ b/frontend/src/hooks/useStreamState.tsx @@ -40,6 +40,7 @@ export function useStreamState(): StreamStateProps { body: JSON.stringify({ input, thread_id, config }), openWhenHidden: true, onmessage(msg) { + console.log(msg); if (msg.event === "data") { const messages = JSON.parse(msg.data); setCurrent((current) => ({ diff --git a/tools/redis_to_postgres/Dockerfile b/tools/redis_to_postgres/Dockerfile deleted file mode 100644 index 7055882a..00000000 --- a/tools/redis_to_postgres/Dockerfile +++ /dev/null @@ -1,8 +0,0 @@ -FROM langchain/open-gpts:latest - -RUN poetry add redis==5.0.1 - -COPY migrate_data.py . - -# Run database schema migrations and then migrate data -ENTRYPOINT sh -c "make migrate && python migrate_data.py" \ No newline at end of file diff --git a/tools/redis_to_postgres/README.md b/tools/redis_to_postgres/README.md deleted file mode 100644 index a1d45677..00000000 --- a/tools/redis_to_postgres/README.md +++ /dev/null @@ -1,11 +0,0 @@ -OpenGPTs previously used Redis for data persistence, but has since switched to Postgres. If you have data in Redis that you would like to migrate to Postgres, follow the instructions below. - -Navigate to the `tools/redis_to_postgres` directory and ensure that the environment variables in the docker-compose file are set correctly for your Redis and Postgres instances. Then, run the following command to perform the migration: - -```shell -docker compose up --build --abort-on-container-exit -``` - -This will run database schema migrations for Postgres and then copy data from Redis to Postgres. Eventually all containers will be stopped. - -Note: if you were not using Redis locally and instead were using a remote Redis instance (for example on AWS), you can simply set the `REDIS_URL` environment variable to the remote instance's address, remove the `redis` service from the docker-compose file, and run the same command as above. \ No newline at end of file diff --git a/tools/redis_to_postgres/docker-compose.yml b/tools/redis_to_postgres/docker-compose.yml deleted file mode 100644 index b3b962d8..00000000 --- a/tools/redis_to_postgres/docker-compose.yml +++ /dev/null @@ -1,23 +0,0 @@ -version: "3" - -services: - redis: - image: redis/redis-stack-server:latest - ports: - - "6380:6379" - volumes: - - ./../../redis-volume:/data - data-migrator: - build: - context: . 
- depends_on: - - redis - network_mode: "host" - environment: - REDIS_URL: "redis://localhost:6380" - POSTGRES_HOST: "localhost" - POSTGRES_PORT: 5433 - POSTGRES_DB: postgres - POSTGRES_USER: postgres - POSTGRES_PASSWORD: postgres - OPENAI_API_KEY: ... diff --git a/tools/redis_to_postgres/migrate_data.py b/tools/redis_to_postgres/migrate_data.py deleted file mode 100644 index 84cfd9a0..00000000 --- a/tools/redis_to_postgres/migrate_data.py +++ /dev/null @@ -1,282 +0,0 @@ -import asyncio -import json -import logging -import os -import pickle -import struct -import uuid -from collections import defaultdict -from datetime import datetime -from typing import Any, Iterator, Optional - -import asyncpg -import orjson -from langchain.utilities.redis import get_client -from langchain_core.runnables import ConfigurableFieldSpec, RunnableConfig -from langgraph.checkpoint import BaseCheckpointSaver -from langgraph.checkpoint.base import ( - Checkpoint, - empty_checkpoint, -) -from redis.client import Redis as RedisType - -from app.checkpoint import PostgresCheckpoint -from app.lifespan import get_pg_pool, lifespan -from app.server import app - -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" -) -logger = logging.getLogger(__name__) - -redis_client: RedisType = get_client(os.environ["REDIS_URL"], socket_keepalive=True) - -thread_hash_keys = ["assistant_id", "name", "updated_at"] -assistant_hash_keys = ["name", "config", "updated_at", "public"] -embedding_hash_keys = ["namespace", "source", "content_vector", "title", "content"] -public_user_id = "eef39817-c173-4eb6-8be4-f77cf37054fb" - - -def keys(match: str) -> Iterator[str]: - cursor = 0 - while True: - cursor, keys = redis_client.scan(cursor=cursor, match=match, count=100) - for key in keys: - yield key.decode("utf-8") - if cursor == 0: - break - - -def load(keys: list[str], values: list[bytes]) -> dict: - return {k: orjson.loads(v) if v is not None else None for k, v in zip(keys, values)} - - -class RedisCheckpoint(BaseCheckpointSaver): - class Config: - arbitrary_types_allowed = True - - @property - def config_specs(self) -> list[ConfigurableFieldSpec]: - return [ - ConfigurableFieldSpec( - id="user_id", - annotation=Optional[str], - name="User ID", - description=None, - default=None, - is_shared=True, - ), - ConfigurableFieldSpec( - id="thread_id", - annotation=Optional[str], - name="Thread ID", - description=None, - default=None, - is_shared=True, - ), - ] - - def _dump(self, mapping: dict[str, Any]) -> dict: - return { - k: pickle.dumps(v) if v is not None else None for k, v in mapping.items() - } - - def _load(self, mapping: dict[bytes, bytes]) -> dict: - return { - k.decode(): pickle.loads(v) if v is not None else None - for k, v in mapping.items() - } - - def _hash_key(self, config: RunnableConfig) -> str: - user_id = config["configurable"]["user_id"] - thread_id = config["configurable"]["thread_id"] - return f"opengpts:{user_id}:thread:{thread_id}:checkpoint" - - def get(self, config: RunnableConfig) -> Checkpoint | None: - value = self._load(redis_client.hgetall(self._hash_key(config))) - if value.get("v") == 1: - # langgraph version 1 - return value - elif value.get("__pregel_version") == 1: - # permchain version 1 - value.pop("__pregel_version") - value.pop("__pregel_ts") - checkpoint = empty_checkpoint() - if value.get("messages"): - checkpoint["channel_values"] = {"__root__": value["messages"][1]} - else: - checkpoint["channel_values"] = {} - for key in checkpoint["channel_values"]: - 
checkpoint["channel_versions"][key] = 1 - return checkpoint - else: - # unknown version - return None - - def put(self, config: RunnableConfig, checkpoint: Checkpoint) -> None: - return redis_client.hmset(self._hash_key(config), self._dump(checkpoint)) - - -async def migrate_assistants(conn: asyncpg.Connection) -> None: - logger.info("Migrating assistants.") - - for key in keys("opengpts:*:assistant:*"): - parts = key.split(":") - user_id, assistant_id = parts[1], parts[3] - if user_id == public_user_id: - continue - - values = redis_client.hmget(key, *assistant_hash_keys) - assistant = load(assistant_hash_keys, values) if any(values) else None - if assistant is not None: - await conn.execute( - ( - "INSERT INTO assistant (assistant_id, user_id, name, config, updated_at, public) " - "VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (assistant_id) DO UPDATE SET " - "user_id = EXCLUDED.user_id, name = EXCLUDED.name, config = EXCLUDED.config, " - "updated_at = EXCLUDED.updated_at, public = EXCLUDED.public;" - ), - assistant_id, - user_id, - assistant["name"], - assistant["config"], - datetime.fromisoformat(assistant["updated_at"]), - assistant["public"], - ) - logger.info(f"Migrated assistant {assistant_id} for user {user_id}.") - - -async def migrate_threads(conn: asyncpg.Connection) -> None: - logger.info("Migrating threads.") - - for key in keys("opengpts:*:thread:*"): - if key.endswith(":checkpoint"): - continue - - parts = key.split(":") - user_id, thread_id = parts[1], parts[3] - - values = redis_client.hmget(key, *thread_hash_keys) - thread = load(thread_hash_keys, values) if any(values) else None - if thread is not None: - await conn.execute( - ( - "INSERT INTO thread (thread_id, assistant_id, user_id, name, updated_at) " - "VALUES ($1, $2, $3, $4, $5) ON CONFLICT (thread_id) DO UPDATE SET " - "assistant_id = EXCLUDED.assistant_id, user_id = EXCLUDED.user_id, " - "name = EXCLUDED.name, updated_at = EXCLUDED.updated_at;" - ), - thread_id, - thread["assistant_id"], - user_id, - thread["name"], - datetime.fromisoformat(thread["updated_at"]), - ) - logger.info(f"Migrated thread {thread_id} for user {user_id}.") - - -async def migrate_checkpoints() -> None: - logger.info("Migrating checkpoints.") - - redis_checkpoint = RedisCheckpoint() - postgres_checkpoint = PostgresCheckpoint() - - for key in keys("opengpts:*:thread:*:checkpoint"): - parts = key.split(":") - user_id, thread_id = parts[1], parts[3] - config = {"configurable": {"user_id": user_id, "thread_id": thread_id}} - checkpoint = redis_checkpoint.get(config) - if checkpoint: - if checkpoint.get("channel_values", {}).get("__root__"): - checkpoint["channel_values"]["__root__"] = [ - msg.__class__(**msg.__dict__) - for msg in checkpoint["channel_values"]["__root__"] - ] - await postgres_checkpoint.aput(config, checkpoint) - logger.info( - f"Migrated checkpoint for thread {thread_id} for user {user_id}." - ) - - -async def migrate_embeddings(conn: asyncpg.Connection) -> None: - logger.info("Migrating embeddings.") - - custom_ids = defaultdict(lambda: str(uuid.uuid4())) - - def _get_custom_id(doc: dict) -> str: - """custom_id is unique for each namespace.""" - return custom_ids[doc["namespace"]] - - def _redis_to_postgres_vector(binary_data: bytes) -> list[float]: - """Deserialize binary data to a list of floats.""" - assert len(binary_data) == 4 * 1536, "Invalid binary data length." 
- format_str = "<" + "1536f" - return list(struct.unpack(format_str, binary_data)) - - def _load_doc(values: list) -> Optional[str]: - doc = {} - for k, v in zip(embedding_hash_keys, values): - if k == "content_vector": - doc[k] = _redis_to_postgres_vector(v) - else: - doc[k] = v.decode() if v is not None else None - return doc - - def _get_cmetadata(doc: dict) -> str: - return json.dumps( - { - "source": doc["source"] if doc["source"] else None, - "namespace": doc["namespace"], - "title": doc["title"], - } - ) - - def _get_document(doc: dict) -> str: - """Sanitize the content by replacing null bytes.""" - return doc["content"].replace("\x00", "x") - - def _get_embedding(doc: dict) -> str: - return str(doc["content_vector"]) - - default_collection = await conn.fetchrow( - "SELECT uuid FROM langchain_pg_collection WHERE name = $1;", "langchain" - ) - assert ( - default_collection is not None - ), "Default collection not found in the database." - - for key in keys("doc:*"): - values = redis_client.hmget(key, *embedding_hash_keys) - doc = _load_doc(values) - await conn.execute( - ( - "INSERT INTO langchain_pg_embedding (document, collection_id, cmetadata, custom_id, embedding, uuid) " - "VALUES ($1, $2, $3, $4, $5, $6);" - ), - _get_document(doc), - default_collection["uuid"], - _get_cmetadata(doc), - _get_custom_id(doc), - _get_embedding(doc), - str(uuid.uuid4()), - ) - logger.info(f"Migrated embedding for namespace {doc['namespace']}.") - - -async def migrate_data(): - logger.info("Starting to migrate data from Redis to Postgres.") - async with get_pg_pool().acquire() as conn, conn.transaction(): - await migrate_assistants(conn) - await migrate_threads(conn) - await migrate_checkpoints() - await migrate_embeddings(conn) - logger.info("Data was migrated successfully.") - - -async def main(): - async with lifespan(app): - await migrate_data() - - -if __name__ == "__main__": - asyncio.run(main())