-
Notifications
You must be signed in to change notification settings - Fork 859
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #45 from langchain-ai/nc/permchain
WIP Use permchain agent executor
- Loading branch information
Showing
38 changed files
with
1,565 additions
and
1,389 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from fastapi import APIRouter | ||
|
||
from app.api.assistants import router as assistants_router | ||
from app.api.runs import router as runs_router | ||
from app.api.threads import router as threads_router | ||
|
||
# Top-level API router: each resource's sub-router is mounted under its own
# URL prefix and tagged with the matching name.
router = APIRouter()

router.include_router(assistants_router, prefix="/assistants", tags=["assistants"])
router.include_router(runs_router, prefix="/runs", tags=["runs"])
router.include_router(threads_router, prefix="/threads", tags=["threads"])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from typing import Annotated, List, Optional | ||
from uuid import uuid4 | ||
|
||
from fastapi import APIRouter, HTTPException, Path, Query | ||
from pydantic import BaseModel, Field | ||
|
||
import app.storage as storage | ||
from app.schema import Assistant, AssistantWithoutUserId, OpengptsUserId | ||
|
||
# Router on which the assistant endpoints of this module are registered.
router = APIRouter()
|
||
# Assistant IDs that are always included in the lookup performed by the
# public-assistants listing, regardless of any caller-supplied shared ID.
FEATURED_PUBLIC_ASSISTANTS: List[str] = [
    "ba721964-b7e4-474c-b817-fb089d94dc5f",
    "dc3ec482-aafc-4d90-8a1a-afb9b2876cde",
]
|
||
|
||
class AssistantPayload(BaseModel):
    """Request body accepted when creating or updating an assistant."""

    # Required fields.
    name: str = Field(
        ...,
        description="The name of the assistant.",
    )
    config: dict = Field(
        ...,
        description="The assistant config.",
    )
    # Optional; assistants are private unless stated otherwise.
    public: bool = Field(
        default=False,
        description="Whether the assistant is public.",
    )
|
||
|
||
# Path-parameter type shared by the routes that address a single assistant.
AssistantID = Annotated[
    str,
    Path(description="The ID of the assistant."),
]
|
||
|
||
@router.get("/")
def list_assistants(opengpts_user_id: OpengptsUserId) -> List[AssistantWithoutUserId]:
    """Return the assistants that storage holds for the requesting user."""
    user_assistants = storage.list_assistants(opengpts_user_id)
    return user_assistants
|
||
|
||
@router.get("/public/")
def list_public_assistants(
    shared_id: Annotated[
        Optional[str], Query(description="ID of a publicly shared assistant.")
    ] = None,
) -> List[AssistantWithoutUserId]:
    """Return the featured public assistants, plus ``shared_id`` when given."""
    # Work on a copy so the module-level list of featured IDs is never mutated.
    assistant_ids = list(FEATURED_PUBLIC_ASSISTANTS)
    if shared_id:
        assistant_ids.append(shared_id)
    return storage.list_public_assistants(assistant_ids)
|
||
|
||
@router.get("/{aid}")
def get_asistant(
    opengpts_user_id: OpengptsUserId,
    aid: AssistantID,
) -> Assistant:
    """Fetch one of the user's assistants by ID.

    Raises:
        HTTPException: 404 when storage returns nothing for ``aid``.
    """
    found = storage.get_assistant(opengpts_user_id, aid)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Assistant not found")
|
||
|
||
@router.post("")
def create_assistant(
    opengpts_user_id: OpengptsUserId,
    payload: AssistantPayload,
) -> Assistant:
    """Store a new assistant under a freshly generated UUID."""
    new_assistant_id = str(uuid4())
    return storage.put_assistant(
        opengpts_user_id,
        new_assistant_id,
        name=payload.name,
        config=payload.config,
        public=payload.public,
    )
|
||
|
||
@router.put("/{aid}")
def upsert_assistant(
    opengpts_user_id: OpengptsUserId,
    aid: AssistantID,
    payload: AssistantPayload,
) -> Assistant:
    """Store the payload as the assistant with ID ``aid`` (create or update)."""
    stored = storage.put_assistant(
        opengpts_user_id,
        aid,
        name=payload.name,
        config=payload.config,
        public=payload.public,
    )
    return stored
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,213 @@ | ||
import asyncio | ||
import json | ||
from typing import AsyncIterator, Sequence | ||
from uuid import uuid4 | ||
|
||
import langsmith.client | ||
import orjson | ||
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request | ||
from fastapi.exceptions import RequestValidationError | ||
from gizmo_agent import agent | ||
from langchain.pydantic_v1 import ValidationError | ||
from langchain.schema.messages import AnyMessage, FunctionMessage | ||
from langchain.schema.output import ChatGeneration | ||
from langchain.schema.runnable import RunnableConfig | ||
from langserve.callbacks import AsyncEventAggregatorCallback | ||
from langserve.schema import FeedbackCreateRequest | ||
from langserve.serialization import WellKnownLCSerializer | ||
from langserve.server import _get_base_run_id_as_str, _unpack_input | ||
from langsmith.utils import tracing_is_enabled | ||
from pydantic import BaseModel, Field | ||
from sse_starlette import EventSourceResponse | ||
|
||
from app.schema import OpengptsUserId | ||
from app.storage import get_assistant, get_thread_messages, public_user_id | ||
from app.stream import StreamMessagesHandler | ||
|
||
# Router on which the run endpoints of this module are registered.
router = APIRouter()


# Serializer used to encode streamed chunks before they are sent as SSE data.
_serializer = WellKnownLCSerializer()
|
||
|
||
class AgentInput(BaseModel):
    """Input handed to the agent for a run."""

    # New messages for the run; defaults to an empty list.
    messages: Sequence[AnyMessage] = Field(
        default_factory=list,
    )
|
||
|
||
class CreateRunPayload(BaseModel):
    """Request body for starting a run."""

    # Which assistant to run and which thread the run belongs to.
    assistant_id: str
    thread_id: str
    # Optional; defaults to an empty AgentInput.
    input: AgentInput = Field(
        default_factory=AgentInput,
    )
|
||
|
||
async def _run_input_and_config(request: Request, opengpts_user_id: OpengptsUserId):
    """Build the agent input, runnable config and prior messages for a run.

    Reads the raw JSON body, looks up the assistant (the caller's own record
    first, then the one stored under ``public_user_id``) together with the
    thread's stored state, and validates the ``input`` field against the
    agent's input schema.

    Args:
        request: Incoming request whose JSON body carries ``assistant_id``,
            ``thread_id`` and, optionally, ``input``.
        opengpts_user_id: ID of the user making the request.

    Returns:
        A ``(input, config, messages)`` tuple: the validated agent input, the
        assistant's config extended with ``user_id``/``thread_id``/
        ``assistant_id`` under ``configurable``, and the ``"messages"`` entry
        of the thread state.

    Raises:
        RequestValidationError: If the body is not valid JSON or ``input``
            fails schema validation.
        HTTPException: 404 if no assistant is found for either user.
    """
    try:
        body = await request.json()
    except json.JSONDecodeError as err:
        raise RequestValidationError(errors=["Invalid JSON body"]) from err

    assistant_id = body["assistant_id"]
    thread_id = body["thread_id"]

    # Run the three storage lookups in the default executor so they proceed
    # concurrently and do not block the event loop.
    loop = asyncio.get_running_loop()
    assistant, public_assistant, state = await asyncio.gather(
        loop.run_in_executor(None, get_assistant, opengpts_user_id, assistant_id),
        loop.run_in_executor(None, get_assistant, public_user_id, assistant_id),
        loop.run_in_executor(None, get_thread_messages, opengpts_user_id, thread_id),
    )

    # The caller's own assistant takes precedence over a public one.
    assistant = assistant or public_assistant
    if not assistant:
        raise HTTPException(status_code=404, detail="Assistant not found")

    assistant_config = assistant["config"]
    config: RunnableConfig = {
        **assistant_config,
        "configurable": {
            # A stored config is an arbitrary dict and may have no
            # "configurable" section; treat that as empty instead of failing.
            **(assistant_config.get("configurable") or {}),
            "user_id": opengpts_user_id,
            "thread_id": thread_id,
            "assistant_id": assistant_id,
        },
    }

    try:
        # "input" is optional in CreateRunPayload, so an absent key is
        # validated as an empty input rather than raising KeyError.
        input_ = _unpack_input(
            agent.get_input_schema(config).validate(body.get("input", {}))
        )
    except ValidationError as err:
        raise RequestValidationError(err.errors(), body=body) from err

    return input_, config, state["messages"]
|
||
|
||
@router.post("")
async def create_run(
    request: Request,
    payload: CreateRunPayload,  # for openapi docs
    opengpts_user_id: OpengptsUserId,
    background_tasks: BackgroundTasks,
):
    """Schedule an agent run as a background task and acknowledge at once."""
    # The thread's existing messages are returned as well but not needed here.
    run_input, run_config, _ = await _run_input_and_config(request, opengpts_user_id)
    background_tasks.add_task(agent.ainvoke, run_input, run_config)
    return {"status": "ok"}  # TODO add a run id
|
||
|
||
@router.post("/stream")
async def stream_run(
    request: Request,
    payload: CreateRunPayload,  # for openapi docs
    opengpts_user_id: OpengptsUserId,
):
    """Start a run and stream its output back as server-sent events.

    Emits a ``metadata`` event carrying the run ID once callback events are
    available, a ``data`` event per streamed chunk, an ``error`` event if the
    run raises, and a final ``end`` event.
    """
    input_, config, messages = await _run_input_and_config(request, opengpts_user_id)
    streamer = StreamMessagesHandler(messages + input_["messages"])
    event_aggregator = AsyncEventAggregatorCallback()
    config["callbacks"] = [streamer, event_aggregator]

    async def consume_astream() -> None:
        """Run the agent in streaming mode, forwarding every chunk to the streamer."""
        try:
            async for chunk in agent.astream(input_, config):
                await streamer.send_stream.send(chunk)
                # hack: function messages aren't generated by chat model
                # so the callback handler doesn't know about them
                chunk_messages = chunk["messages"]
                if chunk_messages and isinstance(chunk_messages[-1], FunctionMessage):
                    streamer.output[uuid4()] = ChatGeneration(
                        message=chunk_messages[-1]
                    )
        except Exception as exc:
            # Forward the failure so the consumer can report it to the client.
            await streamer.send_stream.send(exc)
        finally:
            await streamer.send_stream.aclose()

    # Kick off the agent run in the background.
    task = asyncio.create_task(consume_astream())

    async def _stream() -> AsyncIterator[dict]:
        """Turn items received from the streamer into SSE event dicts."""
        metadata_sent = False

        async for item in streamer.receive_stream:
            if isinstance(item, BaseException):
                # Do not expose the error message to the client since
                # the message may contain sensitive information.
                # We'll add client side errors for validation as well.
                error_body = orjson.dumps(
                    {"status_code": 500, "message": "Internal Server Error"}
                ).decode()
                yield {"event": "error", "data": error_body}
                raise item

            if not metadata_sent and event_aggregator.callback_events:
                run_id_body = orjson.dumps(
                    {"run_id": _get_base_run_id_as_str(event_aggregator)}
                ).decode()
                yield {"event": "metadata", "data": run_id_body}
                metadata_sent = True

            # EventSourceResponse expects a string for data, so the serialized
            # bytes are decoded into a utf-8 string.
            yield {
                "data": _serializer.dumps(item).decode("utf-8"),
                "event": "data",
            }

        # Signal the end of the stream, then wait for the agent run to finish.
        yield {"event": "end"}
        await task

    return EventSourceResponse(_stream())
|
||
|
||
@router.get("/input_schema")
async def input_schema() -> dict:
    """Expose the schema of the agent's input model."""
    schema_model = agent.get_input_schema()
    return schema_model.schema()
|
||
|
||
@router.get("/output_schema")
async def output_schema() -> dict:
    """Expose the schema of the agent's output model."""
    schema_model = agent.get_output_schema()
    return schema_model.schema()
|
||
|
||
@router.get("/config_schema")
async def config_schema() -> dict:
    """Expose the schema of the agent's config model."""
    schema_model = agent.config_schema()
    return schema_model.schema()
|
||
|
||
# The feedback endpoint is only registered when tracing is enabled.
if tracing_is_enabled():
    langsmith_client = langsmith.client.Client()

    @router.post("/feedback")
    def create_run_feedback(feedback_create_req: FeedbackCreateRequest) -> dict:
        """
        Forward feedback about a single run to langsmith.

        A successful response only means the feedback was submitted; it does
        not guarantee that langsmith recorded it. The server may silently
        reject requests that are unauthenticated or invalid.
        """
        req = feedback_create_req
        langsmith_client.create_feedback(
            req.run_id,
            req.key,
            score=req.score,
            value=req.value,
            comment=req.comment,
            source_info={"from_langserve": True},
        )
        return {"status": "ok"}
Oops, something went wrong.