From b1fab0f497dce048a8e7a9e8542c8a94ea4b6a89 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 17 Apr 2024 17:05:18 -0700 Subject: [PATCH] Fix retrieval bot --- backend/app/retrieval.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/backend/app/retrieval.py b/backend/app/retrieval.py index 4a713129..52c8266e 100644 --- a/backend/app/retrieval.py +++ b/backend/app/retrieval.py @@ -1,4 +1,4 @@ -import json +from uuid import uuid4 from langchain_core.language_models.base import LanguageModelLike from langchain_core.messages import AIMessage, HumanMessage, SystemMessage @@ -69,35 +69,38 @@ async def get_search_query(messages): conversation = "\n".join(convo) prompt = await search_prompt.ainvoke({"conversation": conversation}) response = await llm.ainvoke(prompt) - return response.content + return response async def invoke_retrieval(messages): if len(messages) == 1: human_input = messages[-1].content return AIMessage( content="", - additional_kwargs={ - "function_call": { + tool_calls=[ + { + "id": uuid4().hex, "name": "retrieval", - "arguments": json.dumps({"query": human_input}), + "args": {"query": human_input}, } - }, + ], ) else: search_query = await get_search_query.ainvoke(messages) return AIMessage( + id=search_query.id, content="", - additional_kwargs={ - "function_call": { + tool_calls=[ + { + "id": uuid4().hex, "name": "retrieval", - "arguments": json.dumps({"query": search_query}), + "args": {"query": search_query.content}, } - }, + ], ) async def retrieve(messages): - params = messages[-1].additional_kwargs["function_call"] - query = json.loads(params["arguments"])["query"] + params = messages[-1].tool_calls[0] + query = params["args"]["query"] response = await retriever.ainvoke(query) msg = LiberalFunctionMessage(name="retrieval", content=response) return msg