Commit 82b85d9

greed is good
nfcampos committed Jan 26, 2024
1 parent 87c7631 commit 82b85d9
Showing 3 changed files with 97 additions and 132 deletions.
backend/app/api/runs.py (14 additions & 87 deletions)
@@ -1,35 +1,27 @@
 import asyncio
 import json
-from typing import AsyncIterator, Sequence
-from uuid import uuid4
+from typing import Sequence
 
 import langsmith.client
-import orjson
 from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
 from fastapi.exceptions import RequestValidationError
 from langchain.pydantic_v1 import ValidationError
-from langchain.schema.messages import AnyMessage, FunctionMessage
-from langchain.schema.output import ChatGeneration
-from langchain.schema.runnable import RunnableConfig
-from langserve.callbacks import AsyncEventAggregatorCallback
+from langchain_core.messages import AnyMessage
+from langchain_core.runnables import RunnableConfig
 from langserve.schema import FeedbackCreateRequest
-from langserve.serialization import WellKnownLCSerializer
-from langserve.server import _get_base_run_id_as_str, _unpack_input
+from langserve.server import _unpack_input
 from langsmith.utils import tracing_is_enabled
 from pydantic import BaseModel, Field
 from sse_starlette import EventSourceResponse
 
 from app.agent import agent
 from app.schema import OpengptsUserId
-from app.storage import get_assistant, get_thread_messages, public_user_id
-from app.stream import StreamMessagesHandler
+from app.storage import get_assistant, public_user_id
+from app.stream import astream_messages, to_sse
 
 router = APIRouter()
 
 
-_serializer = WellKnownLCSerializer()
-
-
 class CreateRunPayload(BaseModel):
     """Payload for creating a run."""
 
@@ -43,16 +35,13 @@ async def _run_input_and_config(request: Request, opengpts_user_id: OpengptsUser
         body = await request.json()
     except json.JSONDecodeError:
         raise RequestValidationError(errors=["Invalid JSON body"])
-    assistant, public_assistant, state = await asyncio.gather(
+    assistant, public_assistant = await asyncio.gather(
         asyncio.get_running_loop().run_in_executor(
             None, get_assistant, opengpts_user_id, body["assistant_id"]
         ),
         asyncio.get_running_loop().run_in_executor(
             None, get_assistant, public_user_id, body["assistant_id"]
         ),
-        asyncio.get_running_loop().run_in_executor(
-            None, get_thread_messages, opengpts_user_id, body["thread_id"]
-        ),
     )
     assistant = assistant or public_assistant
     if not assistant:
@@ -71,94 +60,32 @@ async def _run_input_and_config(request: Request, opengpts_user_id: OpengptsUser
     except ValidationError as e:
         raise RequestValidationError(e.errors(), body=body)
 
-    return input_, config, state["messages"]
+    return input_, config
 
 
 @router.post("")
 async def create_run(
-    request: Request,
     payload: CreateRunPayload,  # for openapi docs
+    request: Request,
     opengpts_user_id: OpengptsUserId,
     background_tasks: BackgroundTasks,
 ):
     """Create a run."""
-    input_, config, messages = await _run_input_and_config(request, opengpts_user_id)
+    input_, config = await _run_input_and_config(request, opengpts_user_id)
     background_tasks.add_task(agent.ainvoke, input_, config)
     return {"status": "ok"}  # TODO add a run id
 
 
 @router.post("/stream")
 async def stream_run(
-    request: Request,
     payload: CreateRunPayload,  # for openapi docs
+    request: Request,
     opengpts_user_id: OpengptsUserId,
 ):
     """Create a run."""
-    input_, config, messages = await _run_input_and_config(request, opengpts_user_id)
-    streamer = StreamMessagesHandler(messages + input_)
-    event_aggregator = AsyncEventAggregatorCallback()
-    config["callbacks"] = [streamer, event_aggregator]
-
-    # Call the runnable in streaming mode,
-    # add each chunk to the output stream
-    async def consume_astream() -> None:
-        try:
-            async for chunk in agent.astream(input_, config):
-                # await streamer.send_stream.send(chunk)
-                # hack: function messages aren't generated by chat model
-                # so the callback handler doesn't know about them
-                if chunk.get("action"):
-                    message = chunk["action"]
-                    if isinstance(message, FunctionMessage):
-                        streamer.output[uuid4()] = ChatGeneration(message=message)
-        except Exception as e:
-            await streamer.send_stream.send(e)
-        finally:
-            await streamer.send_stream.aclose()
-
-    # Start the runnable in the background
-    task = asyncio.create_task(consume_astream())
-
-    # Consume the stream into an EventSourceResponse
-    async def _stream() -> AsyncIterator[dict]:
-        has_sent_metadata = False
-
-        async for chunk in streamer.receive_stream:
-            if isinstance(chunk, BaseException):
-                yield {
-                    "event": "error",
-                    # Do not expose the error message to the client since
-                    # the message may contain sensitive information.
-                    # We'll add client side errors for validation as well.
-                    "data": orjson.dumps(
-                        {"status_code": 500, "message": "Internal Server Error"}
-                    ).decode(),
-                }
-                raise chunk
-            else:
-                if not has_sent_metadata and event_aggregator.callback_events:
-                    yield {
-                        "event": "metadata",
-                        "data": orjson.dumps(
-                            {"run_id": _get_base_run_id_as_str(event_aggregator)}
-                        ).decode(),
-                    }
-                    has_sent_metadata = True
-
-                yield {
-                    # EventSourceResponse expects a string for data
-                    # so after serializing into bytes, we decode into utf-8
-                    # to get a string.
-                    "data": _serializer.dumps(chunk).decode("utf-8"),
-                    "event": "data",
-                }
-
-        # Send an end event to signal the end of the stream
-        yield {"event": "end"}
-        # Wait for the runnable to finish
-        await task
-
-    return EventSourceResponse(_stream())
+    input_, config = await _run_input_and_config(request, opengpts_user_id)
+
+    return EventSourceResponse(to_sse(astream_messages(agent, input_, config)))
 
 
 @router.get("/input_schema")
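
Note: a minimal sketch (not part of this commit) of how a client might consume the reworked /runs/stream endpoint. It assumes the router is mounted under /runs, the backend listens on localhost:8100, and the user is identified by an opengpts_user_id cookie; the payload values are made up and the SSE parsing is deliberately naive.

import asyncio
import json

import httpx


async def consume_stream() -> None:
    payload = {
        "assistant_id": "my-assistant",  # made-up IDs for illustration
        "thread_id": "my-thread",
        "input": [{"type": "human", "content": "hello"}],
    }
    cookies = {"opengpts_user_id": "user-1"}  # assumed cookie name
    async with httpx.AsyncClient(
        base_url="http://localhost:8100", cookies=cookies
    ) as client:
        async with client.stream("POST", "/runs/stream", json=payload) as resp:
            event = None
            async for line in resp.aiter_lines():
                if line.startswith("event:"):
                    event = line[len("event:"):].strip()
                elif line.startswith("data:"):
                    data = line[len("data:"):].strip()
                    if event == "metadata":
                        # The first event carries the run id
                        print("run_id:", json.loads(data)["run_id"])
                    elif event == "data":
                        # Each data event is the full message list so far
                        print("messages:", data[:80])
                elif not line and event == "end":
                    # A blank line ends an SSE event; "end" signals completion
                    break


asyncio.run(consume_stream())
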
backend/app/storage.py (6 additions & 1 deletion)
@@ -1,6 +1,7 @@
 import os
 from datetime import datetime
 from typing import List, Sequence
+from app.stream import map_chunk_to_msg
 
 import orjson
 from langchain.schema.messages import AnyMessage
@@ -155,7 +156,11 @@ def get_thread_messages(user_id: str, thread_id: str):
     app = get_agent_executor([], AgentType.GPT_35_TURBO, "")
     checkpoint = app.checkpointer.get(config) or empty_checkpoint()
     with ChannelsManager(app.channels, checkpoint) as channels:
-        return {"messages": channels[MESSAGES_CHANNEL_NAME].get()}
+        return {
+            "messages": [
+                map_chunk_to_msg(msg) for msg in channels[MESSAGES_CHANNEL_NAME].get()
+            ]
+        }
 
 
 def post_thread_messages(user_id: str, thread_id: str, messages: Sequence[AnyMessage]):
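
Note: get_thread_messages now runs map_chunk_to_msg over the channel contents because the messages channel can hold partially accumulated *MessageChunk objects written during streaming; converting them back keeps thread history as plain messages. A minimal sketch of that conversion, assuming the in-repo import resolves:

from langchain_core.messages import AIMessage, AIMessageChunk

from app.stream import map_chunk_to_msg

# While a model streams, chunks are accumulated with "+"...
stored = AIMessageChunk(content="Hello, ") + AIMessageChunk(content="world")

# ...and map_chunk_to_msg converts the accumulated chunk back into its
# plain message counterpart when the thread history is read.
msg = map_chunk_to_msg(stored)
assert isinstance(msg, AIMessage)
assert msg.content == "Hello, world"
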
backend/app/stream.py (77 additions & 44 deletions)
@@ -1,10 +1,6 @@
-import math
-from typing import Any, Dict, Optional, Sequence, Union
-from uuid import UUID
+from typing import AsyncIterator, Optional, Sequence, Union
 
-from anyio import create_memory_object_stream
-from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema.messages import (
+from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
     BaseMessage,
@@ -15,50 +11,50 @@
     FunctionMessageChunk,
     HumanMessage,
     HumanMessageChunk,
+    AnyMessage,
 )
-from langchain.schema.output import ChatGenerationChunk, GenerationChunk
+from langchain_core.runnables import Runnable, RunnableConfig
+from langserve.serialization import WellKnownLCSerializer
+import orjson
 
+MessagesStream = AsyncIterator[Union[list[AnyMessage], str]]
+
 
-class StreamMessagesHandler(BaseCallbackHandler):
-    def __init__(self, messages: Sequence[BaseMessage]) -> None:
-        self.messages = messages
-        self.output: Dict[UUID, ChatGenerationChunk] = {}
-        send_stream, receive_stream = create_memory_object_stream(math.inf)
-        self.send_stream = send_stream
-        self.receive_stream = receive_stream
 
-    def on_llm_new_token(
-        self,
-        token: str,
-        *,
-        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
-        run_id: UUID,
-        **kwargs: Any,
-    ) -> Any:
-        # If this is being called for a non-Chat Model run, convert to AIMessage
-        if chunk is None:
-            chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
-        # If we get something we don't know how to handle, ignore it
-        if not (
-            isinstance(chunk, ChatGenerationChunk)
-            or isinstance(chunk, BaseMessageChunk)
+async def astream_messages(
+    app: Runnable, input: Sequence[AnyMessage], config: RunnableConfig
+) -> MessagesStream:
+    """Stream messages from the runnable."""
+    root_run_id: Optional[str] = None
+    last_messages_list: Optional[list[AnyMessage]] = None
+    last_stream_run_id: Optional[str] = None
+
+    async for event in app.astream_events(
+        input, config, version="v1", output_keys=["__root__"]
+    ):
+        if event["event"] == "on_chain_start" and not root_run_id:
+            root_run_id = event["run_id"]
+
+            yield root_run_id
+        elif event["event"] == "on_chain_stream" and event["run_id"] == root_run_id:
+            last_messages_list = event["data"]["chunk"]["__root__"]
+
+            yield last_messages_list
+        elif (
+            event["event"] == "on_chat_model_stream" and last_messages_list is not None
         ):
-            return
-        # Convert messages to ChatGenerationChunks (workaround for old langchain)
-        if isinstance(chunk, BaseMessageChunk):
-            chunk = ChatGenerationChunk(message=chunk)
-        # Accumulate the output (ChatGenerationChunk implements __add__)
-        if not self.output.get(run_id):
-            self.output[run_id] = chunk
-        else:
-            self.output[run_id] += chunk
-        # Send the messages to the stream
-        self.send_stream.send_nowait(
-            (
-                self.messages
-                + [map_chunk_to_msg(chunk.message) for chunk in self.output.values()]
+            is_new_stream_run = (
+                last_stream_run_id is None or last_stream_run_id != event["run_id"]
             )
-        )
+            is_diff_msg_type = last_messages_list and type(  # noqa: E721
+                last_messages_list[-1]
+            ) != type(event["data"]["chunk"])
+            if is_new_stream_run or is_diff_msg_type:
+                last_stream_run_id = event["run_id"]
+                last_messages_list.append(event["data"]["chunk"])
+            else:
+                last_messages_list[-1] = last_messages_list[-1] + event["data"]["chunk"]
+
+            yield last_messages_list
 
 
 def map_chunk_to_msg(chunk: BaseMessageChunk) -> BaseMessage:
@@ -75,3 +71,40 @@ def map_chunk_to_msg(chunk: BaseMessageChunk) -> BaseMessage:
         return ChatMessage(**args)
     else:
         raise ValueError(f"Unknown chunk type: {chunk}")
+
+
+_serializer = WellKnownLCSerializer()
+
+
+async def to_sse(messages_stream: MessagesStream) -> AsyncIterator[dict]:
+    """Consume the stream into an EventSourceResponse"""
+    try:
+        async for chunk in messages_stream:
+            # EventSourceResponse expects a string for data
+            # so after serializing into bytes, we decode into utf-8
+            # to get a string.
+            if isinstance(chunk, str):
+                yield {
+                    "event": "metadata",
+                    "data": orjson.dumps({"run_id": chunk}).decode(),
+                }
+            else:
+                yield {
+                    "event": "data",
+                    "data": _serializer.dumps(
+                        [map_chunk_to_msg(msg) for msg in chunk]
+                    ).decode(),
+                }
+    except Exception:
+        yield {
+            "event": "error",
+            # Do not expose the error message to the client since
+            # the message may contain sensitive information.
+            # We'll add client side errors for validation as well.
+            "data": orjson.dumps(
+                {"status_code": 500, "message": "Internal Server Error"}
+            ).decode(),
+        }
+
+    # Send an end event to signal the end of the stream
+    yield {"event": "end"}
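
Note: taken together, astream_messages yields the root run id first (as a string) and then progressively updated message lists, and to_sse maps those onto metadata, data, and end SSE events. A small sketch of that contract, driving to_sse with a hand-written stream; the run id is made up:

import asyncio

from langchain_core.messages import AIMessageChunk

from app.stream import to_sse


async def fake_stream():
    # astream_messages yields the run id first, then message lists.
    yield "run-123"  # made-up run id for illustration
    yield [AIMessageChunk(content="hel")]
    yield [AIMessageChunk(content="hello!")]


async def main() -> None:
    async for event in to_sse(fake_stream()):
        # Expected event sequence: metadata, data, data, end.
        print(event["event"], event.get("data", ""))


asyncio.run(main())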
