diff --git a/backend/app/server.py b/backend/app/server.py
index a1f09ede..ca1e63de 100644
--- a/backend/app/server.py
+++ b/backend/app/server.py
@@ -2,6 +2,7 @@
 import json
 from pathlib import Path
 from typing import Annotated, AsyncIterator, Optional, Sequence
+from uuid import uuid4
 
 from fastapi.exceptions import RequestValidationError
 import orjson
@@ -10,7 +11,8 @@
 from gizmo_agent import agent, ingest_runnable
 from langserve.callbacks import AsyncEventAggregatorCallback
 from langchain.pydantic_v1 import ValidationError
-from langchain.schema.messages import AnyMessage
+from langchain.schema.messages import AnyMessage, FunctionMessage
+from langchain.schema.output import ChatGeneration
 from langchain.schema.runnable import RunnableConfig
 from langserve import add_routes
 from langserve.server import _get_base_run_id_as_str, _unpack_input
@@ -116,6 +118,11 @@ async def consume_astream() -> None:
         try:
             async for chunk in agent.astream(input_, config):
                 await streamer.send_stream.send(chunk)
+                # hack: function messages aren't generated by the chat model,
+                # so the callback handler doesn't know about them
+                message = chunk["messages"][-1]
+                if isinstance(message, FunctionMessage):
+                    streamer.output[uuid4()] = ChatGeneration(message=message)
         except Exception as e:
             await streamer.send_stream.send(e)
         finally:
diff --git a/backend/app/stream.py b/backend/app/stream.py
index 0cda436d..96af0af6 100644
--- a/backend/app/stream.py
+++ b/backend/app/stream.py
@@ -69,6 +69,8 @@ def on_llm_new_token(
 
 
 def map_chunk_to_msg(chunk: BaseMessageChunk) -> BaseMessage:
+    if not isinstance(chunk, BaseMessageChunk):
+        return chunk
     args = {k: v for k, v in chunk.__dict__.items() if k != "type"}
     if isinstance(chunk, HumanMessageChunk):
         return HumanMessage(**args)
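
Note on the server.py change: FunctionMessages are produced by tool execution, not by the chat model, so they never pass through the LLM callbacks that AsyncEventAggregatorCallback aggregates into streamer.output. The hack records them manually under a fresh UUID so they appear alongside the model generations. Below is a minimal standalone sketch of that bookkeeping; the chunk dicts and the output mapping are toy stand-ins for the agent stream and streamer.output, not the real server objects:

    from uuid import uuid4

    from langchain.schema.messages import AIMessage, FunctionMessage
    from langchain.schema.output import ChatGeneration

    output = {}  # stand-in for streamer.output

    # Toy stream: one model message, then one tool/function result.
    stream = [
        {"messages": [AIMessage(content="Let me look that up.")]},
        {"messages": [FunctionMessage(name="search", content="42")]},
    ]

    for chunk in stream:
        message = chunk["messages"][-1]
        if isinstance(message, FunctionMessage):
            # Wrap in ChatGeneration so the entry matches the shape the
            # callback aggregator stores for real model outputs.
            output[uuid4()] = ChatGeneration(message=message)

    assert len(output) == 1  # only the FunctionMessage was recorded

The stream.py guard serves the matching purpose on the way out: map_chunk_to_msg previously assumed every item was a BaseMessageChunk, but a FunctionMessage is already a complete message, so it now passes through unchanged instead of being converted.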