Skip to content

Commit

Permalink
refactor rewriting query (#69)
Browse files Browse the repository at this point in the history
Signed-off-by: ChengZi <chen.zhang@zilliz.com>
  • Loading branch information
zc277584121 authored Aug 17, 2023
1 parent 7ea19a9 commit 97d5660
Showing 1 changed file with 48 additions and 21 deletions.
69 changes: 48 additions & 21 deletions src_towhee/pipelines/search/rewrite_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,60 @@

from utils import get_llm_op # pylint: disable=C0413

ZH_PRONOUN_LIST = ['他', '她', '它', '者', '这', ]

EN_PRONOUN_LIST = ['he', 'his', 'him', 'she', 'her', 'it', 'they', 'them', 'their', 'both', 'former', 'latter',
'this', 'these', 'that']

PRONOUN_LIST = ZH_PRONOUN_LIST + EN_PRONOUN_LIST


class RewriteQuery:
'''
Replace third-person pronouns with words from historical dialogue
'''

def __init__(self, config):
self._llm_op = get_llm_op(config)

def __call__(self, question: str, history: list = []): # pylint: disable=W0102
if self._contain_pron(question):
prompt = self._build_prompt(question, history)
raw_ret = self._llm_op(prompt)
new_question = self._parse_raw_ret(raw_ret, question)
return new_question
else:
return question

def _contain_pron(self, question: str):
for pron in PRONOUN_LIST:
if pron in question.lower():
return True
return False

def _build_prompt(self, question: str, history: list = []): # pylint: disable=W0102
output_str = ''
for qa in history:
output_str += f'Q: {qa[0]}\n'
output_str += f'A: {qa[1]}\n'
history_str = output_str.strip()
prompt = REWRITE_TEMP.format(question=question, history_str=history_str)
return [({'question': prompt})]

def _parse_raw_ret(self, raw_ret: str, question: str):
try:
raw_ret = raw_ret.replace('&gt;', '>')
new_question = raw_ret.split('=> OUTPUT QUESTION: ')[1].split('-------------------')[0]
except: # pylint: disable=W0702
new_question = question
return new_question

def build_prompt(question: str, history: list = []): # pylint: disable=W0102
if not history:
history_str = ''
output_str = ''
for qa in history:
output_str += f'Q: {qa[0]}\n'
output_str += f'A: {qa[1]}\n'
history_str = output_str.strip()
prompt = REWRITE_TEMP.format(question=question, history_str=history_str)
return [({'question': prompt})]

def parse_raw_ret(raw_ret, question):
try:
new_question = raw_ret.split('=> OUTPUT QUESTION: ')[1]
except: # pylint: disable=W0702
new_question = question
return new_question

def custom_pipeline(config):
llm_op = get_llm_op(config)
chat = AutoPipes.pipeline('osschat-search', config=config)
p = (
pipe.input('question', 'history', 'project')
.map(('question', 'history'), 'prompt', build_prompt)
.map('prompt', 'new_question', llm_op)
.map(('new_question', 'question'), 'new_question', parse_raw_ret)
.map(('question', 'history'), 'new_question', RewriteQuery(config))
.map(('new_question', 'history', 'project'), 'answer', chat)
.output('new_question', 'answer')
)
Expand Down

0 comments on commit 97d5660

Please sign in to comment.