Skip to content

Commit

Permalink
Reorganize api
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Nov 20, 2023
1 parent 1a6b46c commit 9391065
Show file tree
Hide file tree
Showing 15 changed files with 404 additions and 332 deletions.
23 changes: 23 additions & 0 deletions backend/app/api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from fastapi import APIRouter

from app.api.assistants import router as assistants_router
from app.api.runs import router as runs_router
from app.api.threads import router as threads_router

# Top-level API router: each feature area's router is mounted under its own
# URL prefix and given a matching tag.
router = APIRouter()

router.include_router(assistants_router, prefix="/assistants", tags=["assistants"])
router.include_router(runs_router, prefix="/runs", tags=["runs"])
router.include_router(threads_router, prefix="/threads", tags=["threads"])
59 changes: 59 additions & 0 deletions backend/app/api/assistants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import Annotated, List, Optional

from fastapi import APIRouter, Cookie, Path, Query
from pydantic import BaseModel, Field

import app.storage as storage
from app.schema import Assistant, AssistantWithoutUserId, OpengptsUserId

# Router for the assistant endpoints defined in this module.
router = APIRouter()

# Assistant IDs that are always passed to `storage.list_public_assistants` by
# the public listing endpoint, in addition to any `shared_id` from the caller.
FEATURED_PUBLIC_ASSISTANTS = [
    "ba721964-b7e4-474c-b817-fb089d94dc5f",
    "dc3ec482-aafc-4d90-8a1a-afb9b2876cde",
]


@router.get("/")
def list_assistants(opengpts_user_id: OpengptsUserId) -> List[AssistantWithoutUserId]:
    """List the assistants that belong to the current user."""
    # The user is identified solely by the `opengpts_user_id` cookie.
    user_assistants = storage.list_assistants(opengpts_user_id)
    return user_assistants


@router.get("/public/")
def list_public_assistants(
    shared_id: Annotated[
        Optional[str], Query(description="ID of a publicly shared assistant.")
    ] = None,
) -> List[AssistantWithoutUserId]:
    """List the featured public assistants, plus `shared_id` when provided."""
    # Work on a copy so the module-level featured list is never mutated.
    assistant_ids = list(FEATURED_PUBLIC_ASSISTANTS)
    if shared_id:
        assistant_ids.append(shared_id)
    return storage.list_public_assistants(assistant_ids)


class AssistantPayload(BaseModel):
    """Request body for creating or updating an assistant."""

    # Required: display name of the assistant.
    name: str = Field(..., description="The name of the assistant.")
    # Required: free-form configuration dict, stored with the assistant.
    config: dict = Field(..., description="The assistant config.")
    # Optional: defaults to a non-public assistant.
    public: bool = Field(default=False, description="Whether the assistant is public.")


# Annotated type for the assistant ID path parameter; the `Path` metadata
# carries the parameter description.
AssistantID = Annotated[str, Path(description="The ID of the assistant.")]


@router.put("/{aid}")
def put_assistant(
    # Use the shared cookie alias, as every other endpoint does, instead of an
    # ad-hoc `Annotated[str, Cookie()]`; the cookie name and type are unchanged.
    opengpts_user_id: OpengptsUserId,
    aid: AssistantID,
    payload: AssistantPayload,
) -> Assistant:
    """Create or update an assistant."""
    # The assistant is stored under the requesting user's ID and the `aid`
    # taken from the URL path; the remaining fields come from the request body.
    return storage.put_assistant(
        opengpts_user_id,
        aid,
        name=payload.name,
        config=payload.config,
        public=payload.public,
    )
200 changes: 200 additions & 0 deletions backend/app/api/runs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import asyncio
import json
from typing import AsyncIterator, Sequence
from uuid import uuid4

import langsmith.client
import orjson
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from gizmo_agent import agent
from langchain.pydantic_v1 import ValidationError
from langchain.schema.messages import AnyMessage, FunctionMessage
from langchain.schema.output import ChatGeneration
from langchain.schema.runnable import RunnableConfig
from langserve.callbacks import AsyncEventAggregatorCallback
from langserve.schema import FeedbackCreateRequest
from langserve.serialization import WellKnownLCSerializer
from langserve.server import _get_base_run_id_as_str, _unpack_input
from langsmith.utils import tracing_is_enabled
from pydantic import BaseModel
from sse_starlette import EventSourceResponse

from app.schema import OpengptsUserId
from app.storage import get_assistant, get_thread_messages, public_user_id
from app.stream import StreamMessagesHandler

# Router for the run endpoints defined in this module.
router = APIRouter()


# Shared serializer used to encode streamed chunks for the SSE "data" events.
_serializer = WellKnownLCSerializer()


class AgentInput(BaseModel):
    """An input into an agent."""

    # The messages passed to the agent as its input.
    messages: Sequence[AnyMessage]


class CreateRunPayload(BaseModel):
    """Payload for creating a run."""

    # ID of the assistant whose config is used for the run.
    assistant_id: str
    # ID of the thread whose stored messages the run builds on.
    thread_id: str
    # Whether to stream the run as server-sent events or run it in the background.
    stream: bool
    # TODO make optional
    input: AgentInput


@router.post("")
async def create_run(
    request: Request,
    payload: CreateRunPayload,  # for openapi docs
    opengpts_user_id: OpengptsUserId,
    background_tasks: BackgroundTasks,
):
    """Create a run.

    Runs the assistant's agent on the given input. When `stream` is true the
    response is a server-sent event stream of `metadata`, `data`, `error` and
    `end` events; otherwise the run is scheduled in the background and
    `{"status": "ok"}` is returned immediately.
    """
    # The raw body is re-read (rather than using `payload`) so that `input`
    # can be validated against the agent's own input schema further down.
    try:
        body = await request.json()
    except json.JSONDecodeError as e:
        raise RequestValidationError(errors=["Invalid JSON body"]) from e

    # Fetch the user's assistant, the public assistant with the same ID and
    # the thread's stored messages concurrently in the default executor.
    loop = asyncio.get_running_loop()
    assistant, public_assistant, state = await asyncio.gather(
        loop.run_in_executor(
            None, get_assistant, opengpts_user_id, body["assistant_id"]
        ),
        loop.run_in_executor(
            None, get_assistant, public_user_id, body["assistant_id"]
        ),
        loop.run_in_executor(
            None, get_thread_messages, opengpts_user_id, body["thread_id"]
        ),
    )
    # Prefer the caller's own assistant; fall back to the public one.
    assistant = assistant or public_assistant
    if not assistant:
        raise HTTPException(status_code=404, detail="Assistant not found")

    assistant_config = assistant["config"]
    config: RunnableConfig = {
        **assistant_config,
        "configurable": {
            # An assistant may have been saved without a "configurable"
            # section; treat that as empty instead of failing with KeyError.
            **(assistant_config.get("configurable") or {}),
            "user_id": opengpts_user_id,
            "thread_id": body["thread_id"],
            "assistant_id": body["assistant_id"],
        },
    }
    try:
        input_ = _unpack_input(agent.get_input_schema(config).validate(body["input"]))
    except ValidationError as e:
        raise RequestValidationError(e.errors(), body=body) from e

    if body["stream"]:
        # The handler is seeded with the thread's existing messages plus the
        # new input messages.
        streamer = StreamMessagesHandler(state["messages"] + input_["messages"])
        event_aggregator = AsyncEventAggregatorCallback()
        config["callbacks"] = [streamer, event_aggregator]

        # Call the runnable in streaming mode,
        # add each chunk to the output stream
        async def consume_astream() -> None:
            try:
                async for chunk in agent.astream(input_, config):
                    await streamer.send_stream.send(chunk)
                    # hack: function messages aren't generated by chat model
                    # so the callback handler doesn't know about them
                    message = chunk["messages"][-1]
                    if isinstance(message, FunctionMessage):
                        streamer.output[uuid4()] = ChatGeneration(message=message)
            except Exception as e:
                # Forward the failure to the consumer so it can emit an
                # "error" event and re-raise it.
                await streamer.send_stream.send(e)
            finally:
                await streamer.send_stream.aclose()

        # Start the runnable in the background
        task = asyncio.create_task(consume_astream())

        # Consume the stream into an EventSourceResponse
        async def _stream() -> AsyncIterator[dict]:
            has_sent_metadata = False

            async for chunk in streamer.receive_stream:
                if isinstance(chunk, BaseException):
                    yield {
                        "event": "error",
                        # Do not expose the error message to the client since
                        # the message may contain sensitive information.
                        # We'll add client side errors for validation as well.
                        "data": orjson.dumps(
                            {"status_code": 500, "message": "Internal Server Error"}
                        ).decode(),
                    }
                    raise chunk
                else:
                    # Emit the run ID once, as soon as callback events exist.
                    if not has_sent_metadata and event_aggregator.callback_events:
                        yield {
                            "event": "metadata",
                            "data": orjson.dumps(
                                {"run_id": _get_base_run_id_as_str(event_aggregator)}
                            ).decode(),
                        }
                        has_sent_metadata = True

                    yield {
                        # EventSourceResponse expects a string for data
                        # so after serializing into bytes, we decode into utf-8
                        # to get a string.
                        "data": _serializer.dumps(chunk).decode("utf-8"),
                        "event": "data",
                    }

            # Send an end event to signal the end of the stream
            yield {"event": "end"}
            # Wait for the runnable to finish
            await task

        return EventSourceResponse(_stream())
    else:
        background_tasks.add_task(agent.ainvoke, input_, config)
        return {"status": "ok"}  # TODO add a run id


@router.get("/input_schema")
async def input_schema() -> dict:
    """Return the input schema of the runnable."""
    # Called without a config, so this is the agent's default input schema.
    schema_model = agent.get_input_schema()
    return schema_model.schema()


@router.get("/output_schema")
async def output_schema() -> dict:
    """Return the output schema of the runnable."""
    # Called without a config, so this is the agent's default output schema.
    schema_model = agent.get_output_schema()
    return schema_model.schema()


@router.get("/config_schema")
async def config_schema() -> dict:
    """Return the config schema of the runnable."""
    schema_model = agent.config_schema()
    return schema_model.schema()


if tracing_is_enabled():
    # The client and the feedback endpoint are only created when
    # `tracing_is_enabled()` reports true at import time.
    langsmith_client = langsmith.client.Client()

    @router.post("/feedback")
    def create_run_feedback(feedback_create_req: FeedbackCreateRequest) -> dict:
        """Send feedback on an individual run to langsmith.

        Note that a successful response means that feedback was successfully
        submitted. It does not guarantee that the feedback is recorded by
        langsmith. Requests may be silently rejected if they are
        unauthenticated or invalid by the server.
        """
        req = feedback_create_req
        # Marks the feedback as having come through langserve.
        origin = {"from_langserve": True}

        langsmith_client.create_feedback(
            req.run_id,
            req.key,
            score=req.score,
            value=req.value,
            comment=req.comment,
            source_info=origin,
        )

        return {"status": "ok"}
51 changes: 51 additions & 0 deletions backend/app/api/threads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Annotated, List

from fastapi import APIRouter, Path
from pydantic import BaseModel, Field

import app.storage as storage
from app.schema import OpengptsUserId, Thread, ThreadWithoutUserId

# Router for the thread endpoints defined in this module.
router = APIRouter()


# Annotated type for the thread ID path parameter; the `Path` metadata
# carries the parameter description.
ThreadID = Annotated[str, Path(description="The ID of the thread.")]


class ThreadPutRequest(BaseModel):
    """Request body for the thread PUT endpoint."""

    # Required: display name of the thread.
    name: str = Field(..., description="The name of the thread.")
    # Required: ID of the assistant associated with the thread.
    assistant_id: str = Field(..., description="The ID of the assistant to use.")


@router.get("/")
def list_threads_endpoint(
    opengpts_user_id: OpengptsUserId,
) -> List[ThreadWithoutUserId]:
    """List the threads that belong to the current user."""
    # The user is identified solely by the `opengpts_user_id` cookie.
    user_threads = storage.list_threads(opengpts_user_id)
    return user_threads


@router.get("/{tid}/messages")
def get_thread_messages_endpoint(
    opengpts_user_id: OpengptsUserId,
    tid: ThreadID,
):
    """Return the messages stored for one of the current user's threads."""
    # The lookup is keyed by both the user ID and the thread ID.
    thread_messages = storage.get_thread_messages(opengpts_user_id, tid)
    return thread_messages


@router.put("/{tid}")
def put_thread_endpoint(
    opengpts_user_id: OpengptsUserId,
    tid: ThreadID,
    thread_put_request: ThreadPutRequest,
) -> Thread:
    """Store a thread under the given ID for the current user."""
    # The thread ID comes from the URL path; name and assistant come from the body.
    body = thread_put_request
    stored_thread = storage.put_thread(
        opengpts_user_id,
        tid,
        assistant_id=body.assistant_id,
        name=body.name,
    )
    return stored_thread
14 changes: 14 additions & 0 deletions backend/app/schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from datetime import datetime
from typing import Annotated

from fastapi import Cookie
from typing_extensions import TypedDict


Expand Down Expand Up @@ -41,3 +43,15 @@ class Thread(ThreadWithoutUserId):

user_id: str
"""The ID of the user that owns the thread."""


# Annotated type for endpoint parameters that read the user's ID from a cookie.
# As the description below states, this is an identifier only and must not be
# treated as authentication.
OpengptsUserId = Annotated[
    str,
    Cookie(
        description=(
            "A cookie that identifies the user. This is not an authentication "
            "mechanism that should be used in an actual production environment that "
            "contains sensitive information."
        )
    ),
]
Loading

0 comments on commit 9391065

Please sign in to comment.