-
Notifications
You must be signed in to change notification settings - Fork 859
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
15 changed files
with
404 additions
and
332 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from fastapi import APIRouter | ||
|
||
from app.api.assistants import router as assistants_router | ||
from app.api.runs import router as runs_router | ||
from app.api.threads import router as threads_router | ||
|
||
# Top-level API router: each resource router is mounted under its own URL
# prefix and its routes are tagged with the resource name.
router = APIRouter()
router.include_router(assistants_router, prefix="/assistants", tags=["assistants"])
router.include_router(runs_router, prefix="/runs", tags=["runs"])
router.include_router(threads_router, prefix="/threads", tags=["threads"])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from typing import Annotated, List, Optional | ||
|
||
from fastapi import APIRouter, Cookie, Path, Query | ||
from pydantic import BaseModel, Field | ||
|
||
import app.storage as storage | ||
from app.schema import Assistant, AssistantWithoutUserId, OpengptsUserId | ||
|
||
router = APIRouter() | ||
|
||
# Assistant IDs that the public listing endpoint always requests from
# storage, in addition to any explicitly shared assistant ID.
FEATURED_PUBLIC_ASSISTANTS = [
    "ba721964-b7e4-474c-b817-fb089d94dc5f",
    "dc3ec482-aafc-4d90-8a1a-afb9b2876cde",
]
|
||
|
||
@router.get("/")
def list_assistants(opengpts_user_id: OpengptsUserId) -> List[AssistantWithoutUserId]:
    """Return the assistants stored for the requesting user.

    The lookup is delegated to ``storage.list_assistants`` using the
    caller's ``opengpts_user_id``.
    """
    owned_assistants = storage.list_assistants(opengpts_user_id)
    return owned_assistants
|
||
|
||
@router.get("/public/")
def list_public_assistants(
    shared_id: Annotated[
        Optional[str], Query(description="ID of a publicly shared assistant.")
    ] = None,
) -> List[AssistantWithoutUserId]:
    """Return the public assistants.

    The featured assistant IDs are always requested; when ``shared_id`` is
    provided (and non-empty) it is appended to that list before the lookup
    in ``storage.list_public_assistants``.
    """
    # Copy so the module-level list of featured IDs is never mutated.
    requested_ids = list(FEATURED_PUBLIC_ASSISTANTS)
    if shared_id:
        requested_ids.append(shared_id)
    return storage.list_public_assistants(requested_ids)
|
||
|
||
class AssistantPayload(BaseModel):
    """Request body used to create or update an assistant."""

    # Required fields.
    name: str = Field(..., description="The name of the assistant.")
    config: dict = Field(..., description="The assistant config.")
    # Optional; assistants are private unless explicitly marked public.
    public: bool = Field(False, description="Whether the assistant is public.")
|
||
|
||
# Path-parameter type for routes that address a single assistant.
AssistantID = Annotated[
    str,
    Path(description="The ID of the assistant."),
]
|
||
|
||
@router.put("/{aid}")
def put_assistant(
    opengpts_user_id: OpengptsUserId,
    aid: AssistantID,
    payload: AssistantPayload,
) -> Assistant:
    """Create or update an assistant.

    Args:
        opengpts_user_id: Identity of the requesting user. Declared with the
            shared ``OpengptsUserId`` alias, like every other endpoint, so
            the user-id parameter is defined in a single place.
            NOTE(review): assumes ``OpengptsUserId`` is the same cookie
            parameter as the previous ``Annotated[str, Cookie()]`` - confirm
            against ``app.schema``.
        aid: ID of the assistant, taken from the URL path.
        payload: Name, config and public flag to store for the assistant.

    Returns:
        The value returned by ``storage.put_assistant``.
    """
    return storage.put_assistant(
        opengpts_user_id,
        aid,
        name=payload.name,
        config=payload.config,
        public=payload.public,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
import asyncio | ||
import json | ||
from typing import AsyncIterator, Sequence | ||
from uuid import uuid4 | ||
|
||
import langsmith.client | ||
import orjson | ||
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request | ||
from fastapi.exceptions import RequestValidationError | ||
from gizmo_agent import agent | ||
from langchain.pydantic_v1 import ValidationError | ||
from langchain.schema.messages import AnyMessage, FunctionMessage | ||
from langchain.schema.output import ChatGeneration | ||
from langchain.schema.runnable import RunnableConfig | ||
from langserve.callbacks import AsyncEventAggregatorCallback | ||
from langserve.schema import FeedbackCreateRequest | ||
from langserve.serialization import WellKnownLCSerializer | ||
from langserve.server import _get_base_run_id_as_str, _unpack_input | ||
from langsmith.utils import tracing_is_enabled | ||
from pydantic import BaseModel | ||
from sse_starlette import EventSourceResponse | ||
|
||
from app.schema import OpengptsUserId | ||
from app.storage import get_assistant, get_thread_messages, public_user_id | ||
from app.stream import StreamMessagesHandler | ||
|
||
router = APIRouter() | ||
|
||
|
||
_serializer = WellKnownLCSerializer() | ||
|
||
|
||
class AgentInput(BaseModel):
    """Input passed to the agent for a run."""

    # Messages supplied by the caller for this run.
    messages: Sequence[AnyMessage]
|
||
|
||
class CreateRunPayload(BaseModel):
    """Request body for creating a run."""

    # Assistant to run and the thread the run belongs to.
    assistant_id: str
    thread_id: str
    # Whether the caller asked for a streamed response.
    stream: bool
    # TODO make optional
    input: AgentInput
|
||
|
||
@router.post("")
async def create_run(
    request: Request,
    payload: CreateRunPayload,  # for openapi docs
    opengpts_user_id: OpengptsUserId,
    background_tasks: BackgroundTasks,
):
    """Create a run of an assistant on a thread.

    The JSON body is read directly from ``request``; ``payload`` is declared
    only so the body is described in the OpenAPI docs.

    The user's assistant, the public assistant with the same ID, and the
    thread's messages are fetched concurrently in the default executor. The
    user's own assistant takes precedence over the public one.

    Returns:
        When ``stream`` is truthy, an ``EventSourceResponse`` that emits a
        ``metadata`` event (once, carrying the run ID), one ``data`` event
        per chunk, an ``error`` event if the run fails, and a final ``end``
        event. Otherwise the agent is scheduled as a background task and
        ``{"status": "ok"}`` is returned.

    Raises:
        RequestValidationError: If the body is not valid JSON or the input
            does not validate against the agent's input schema.
        HTTPException: 404 when no matching assistant is found.
    """
    try:
        body = await request.json()
    except json.JSONDecodeError as exc:
        # Chain the cause so the decoding failure stays visible in logs.
        raise RequestValidationError(errors=["Invalid JSON body"]) from exc

    loop = asyncio.get_running_loop()
    assistant, public_assistant, state = await asyncio.gather(
        loop.run_in_executor(
            None, get_assistant, opengpts_user_id, body["assistant_id"]
        ),
        loop.run_in_executor(
            None, get_assistant, public_user_id, body["assistant_id"]
        ),
        loop.run_in_executor(
            None, get_thread_messages, opengpts_user_id, body["thread_id"]
        ),
    )
    assistant = assistant or public_assistant
    if not assistant:
        raise HTTPException(status_code=404, detail="Assistant not found")

    assistant_config = assistant["config"]
    config: RunnableConfig = {
        **assistant_config,
        "configurable": {
            # Tolerate stored configs without a "configurable" section
            # instead of failing the request with a KeyError.
            **(assistant_config.get("configurable") or {}),
            "user_id": opengpts_user_id,
            "thread_id": body["thread_id"],
            "assistant_id": body["assistant_id"],
        },
    }
    try:
        input_ = _unpack_input(agent.get_input_schema(config).validate(body["input"]))
    except ValidationError as e:
        raise RequestValidationError(e.errors(), body=body) from e

    if body["stream"]:
        streamer = StreamMessagesHandler(state["messages"] + input_["messages"])
        event_aggregator = AsyncEventAggregatorCallback()
        config["callbacks"] = [streamer, event_aggregator]

        # Call the runnable in streaming mode,
        # add each chunk to the output stream
        async def consume_astream() -> None:
            try:
                async for chunk in agent.astream(input_, config):
                    await streamer.send_stream.send(chunk)
                    # hack: function messages aren't generated by chat model
                    # so the callback handler doesn't know about them
                    message = chunk["messages"][-1]
                    if isinstance(message, FunctionMessage):
                        streamer.output[uuid4()] = ChatGeneration(message=message)
            except Exception as e:
                # Forward the failure to the consumer, which reports it to
                # the client as an "error" event.
                await streamer.send_stream.send(e)
            finally:
                await streamer.send_stream.aclose()

        # Start the runnable in the background
        task = asyncio.create_task(consume_astream())

        # Consume the stream into an EventSourceResponse
        async def _stream() -> AsyncIterator[dict]:
            has_sent_metadata = False

            async for chunk in streamer.receive_stream:
                if isinstance(chunk, BaseException):
                    yield {
                        "event": "error",
                        # Do not expose the error message to the client since
                        # the message may contain sensitive information.
                        # We'll add client side errors for validation as well.
                        "data": orjson.dumps(
                            {"status_code": 500, "message": "Internal Server Error"}
                        ).decode(),
                    }
                    raise chunk
                else:
                    if not has_sent_metadata and event_aggregator.callback_events:
                        yield {
                            "event": "metadata",
                            "data": orjson.dumps(
                                {"run_id": _get_base_run_id_as_str(event_aggregator)}
                            ).decode(),
                        }
                        has_sent_metadata = True

                    yield {
                        # EventSourceResponse expects a string for data
                        # so after serializing into bytes, we decode into utf-8
                        # to get a string.
                        "data": _serializer.dumps(chunk).decode("utf-8"),
                        "event": "data",
                    }

            # Send an end event to signal the end of the stream
            yield {"event": "end"}
            # Wait for the runnable to finish
            await task

        return EventSourceResponse(_stream())
    else:
        background_tasks.add_task(agent.ainvoke, input_, config)
        return {"status": "ok"}  # TODO add a run id
|
||
|
||
@router.get("/input_schema")
async def input_schema() -> dict:
    """Return the agent's input schema as a dict."""
    schema_model = agent.get_input_schema()
    return schema_model.schema()
|
||
|
||
@router.get("/output_schema")
async def output_schema() -> dict:
    """Return the agent's output schema as a dict."""
    schema_model = agent.get_output_schema()
    return schema_model.schema()
|
||
|
||
@router.get("/config_schema")
async def config_schema() -> dict:
    """Return the agent's config schema as a dict."""
    schema_model = agent.config_schema()
    return schema_model.schema()
|
||
|
||
# The feedback endpoint is only registered when tracing is enabled.
if tracing_is_enabled():
    langsmith_client = langsmith.client.Client()

    @router.post("/feedback")
    def create_run_feedback(feedback_create_req: FeedbackCreateRequest) -> dict:
        """
        Submit feedback for a single run to langsmith.

        A successful response only means the feedback was submitted; it
        does not guarantee that langsmith recorded it. The server may
        silently reject requests that are unauthenticated or invalid.
        """
        feedback_fields = {
            "score": feedback_create_req.score,
            "value": feedback_create_req.value,
            "comment": feedback_create_req.comment,
            "source_info": {"from_langserve": True},
        }
        langsmith_client.create_feedback(
            feedback_create_req.run_id,
            feedback_create_req.key,
            **feedback_fields,
        )
        return {"status": "ok"}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from typing import Annotated, List | ||
|
||
from fastapi import APIRouter, Path | ||
from pydantic import BaseModel, Field | ||
|
||
import app.storage as storage | ||
from app.schema import OpengptsUserId, Thread, ThreadWithoutUserId | ||
|
||
router = APIRouter() | ||
|
||
|
||
# Path-parameter type for routes that address a single thread.
ThreadID = Annotated[
    str,
    Path(description="The ID of the thread."),
]
|
||
|
||
class ThreadPutRequest(BaseModel):
    """Request body for the thread PUT endpoint."""

    # Both fields are required.
    name: str = Field(..., description="The name of the thread.")
    assistant_id: str = Field(..., description="The ID of the assistant to use.")
|
||
|
||
@router.get("/")
def list_threads_endpoint(
    opengpts_user_id: OpengptsUserId,
) -> List[ThreadWithoutUserId]:
    """Return the threads stored for the requesting user.

    The lookup is delegated to ``storage.list_threads``.
    """
    user_threads = storage.list_threads(opengpts_user_id)
    return user_threads
|
||
|
||
@router.get("/{tid}/messages")
def get_thread_messages_endpoint(
    opengpts_user_id: OpengptsUserId,
    tid: ThreadID,
):
    """Return the messages of thread ``tid`` for the requesting user.

    The result of ``storage.get_thread_messages`` is returned unchanged.
    """
    thread_messages = storage.get_thread_messages(opengpts_user_id, tid)
    return thread_messages
|
||
|
||
@router.put("/{tid}")
def put_thread_endpoint(
    opengpts_user_id: OpengptsUserId,
    tid: ThreadID,
    thread_put_request: ThreadPutRequest,
) -> Thread:
    """Store thread ``tid`` for the requesting user.

    The thread's name and assistant ID are taken from the request body and
    passed to ``storage.put_thread``, whose result is returned.
    """
    thread_fields = {
        "assistant_id": thread_put_request.assistant_id,
        "name": thread_put_request.name,
    }
    return storage.put_thread(opengpts_user_id, tid, **thread_fields)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.