Skip to content

Commit

Permalink
Fix retrieval bot
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Apr 18, 2024
1 parent bb498e3 commit b1fab0f
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions backend/app/retrieval.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import json
from uuid import uuid4

from langchain_core.language_models.base import LanguageModelLike
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
Expand Down Expand Up @@ -69,35 +69,38 @@ async def get_search_query(messages):
conversation = "\n".join(convo)
prompt = await search_prompt.ainvoke({"conversation": conversation})
response = await llm.ainvoke(prompt)
return response.content
return response

async def invoke_retrieval(messages):
if len(messages) == 1:
human_input = messages[-1].content
return AIMessage(
content="",
additional_kwargs={
"function_call": {
tool_calls=[
{
"id": uuid4().hex,
"name": "retrieval",
"arguments": json.dumps({"query": human_input}),
"args": {"query": human_input},
}
},
],
)
else:
search_query = await get_search_query.ainvoke(messages)
return AIMessage(
id=search_query.id,
content="",
additional_kwargs={
"function_call": {
tool_calls=[
{
"id": uuid4().hex,
"name": "retrieval",
"arguments": json.dumps({"query": search_query}),
"args": {"query": search_query.content},
}
},
],
)

async def retrieve(messages):
params = messages[-1].additional_kwargs["function_call"]
query = json.loads(params["arguments"])["query"]
params = messages[-1].tool_calls[0]
query = params["args"]["query"]
response = await retriever.ainvoke(query)
msg = LiberalFunctionMessage(name="retrieval", content=response)
return msg
Expand Down

0 comments on commit b1fab0f

Please sign in to comment.