From 97d56600c2ba927580f87a37ddf55d943acac153 Mon Sep 17 00:00:00 2001
From: ChengZi
Date: Thu, 17 Aug 2023 19:14:12 +0800
Subject: [PATCH] refactor rewriting query (#69)

Signed-off-by: ChengZi
---
NOTE(review): this copy of the patch had lost its line breaks and
indentation; the layout below is reconstructed. The positions of the blank
context lines and the leading whitespace of every hunk line are inferred so
that they satisfy the hunk header counts (33 old / 60 new lines) - confirm
against the repository before applying. The author e-mail address was absent
from the copy and is not restored. In _parse_raw_ret, '&gt;' is restored on
the assumption that an HTML entity had been decoded (the copy read
replace('>', '>'), which is a no-op), and a stray line break found inside
the '=> OUTPUT QUESTION: ' literal was removed.

 src_towhee/pipelines/search/rewrite_query.py | 69 ++++++++++++++------
 1 file changed, 48 insertions(+), 21 deletions(-)

diff --git a/src_towhee/pipelines/search/rewrite_query.py b/src_towhee/pipelines/search/rewrite_query.py
index 916be96..ba2da26 100644
--- a/src_towhee/pipelines/search/rewrite_query.py
+++ b/src_towhee/pipelines/search/rewrite_query.py
@@ -8,33 +8,60 @@ from utils import get_llm_op  # pylint: disable=C0413
 
 
 
+ZH_PRONOUN_LIST = ['他', '她', '它', '者', '这', ]
+
+EN_PRONOUN_LIST = ['he', 'his', 'him', 'she', 'her', 'it', 'they', 'them', 'their', 'both', 'former', 'latter',
+                   'this', 'these', 'that']
+
+PRONOUN_LIST = ZH_PRONOUN_LIST + EN_PRONOUN_LIST
+
+
+class RewriteQuery:
+    '''
+    Replace third-person pronouns with words from historical dialogue
+    '''
+
+    def __init__(self, config):
+        self._llm_op = get_llm_op(config)
+
+    def __call__(self, question: str, history: list = []):  # pylint: disable=W0102
+        if self._contain_pron(question):
+            prompt = self._build_prompt(question, history)
+            raw_ret = self._llm_op(prompt)
+            new_question = self._parse_raw_ret(raw_ret, question)
+            return new_question
+        else:
+            return question
+
+    def _contain_pron(self, question: str):
+        for pron in PRONOUN_LIST:
+            if pron in question.lower():
+                return True
+        return False
+
+    def _build_prompt(self, question: str, history: list = []):  # pylint: disable=W0102
+        output_str = ''
+        for qa in history:
+            output_str += f'Q: {qa[0]}\n'
+            output_str += f'A: {qa[1]}\n'
+        history_str = output_str.strip()
+        prompt = REWRITE_TEMP.format(question=question, history_str=history_str)
+        return [({'question': prompt})]
+
+    def _parse_raw_ret(self, raw_ret: str, question: str):
+        try:
+            raw_ret = raw_ret.replace('&gt;', '>')
+            new_question = raw_ret.split('=> OUTPUT QUESTION: ')[1].split('-------------------')[0]
+        except:  # pylint: disable=W0702
+            new_question = question
+        return new_question
-def build_prompt(question: str, history: list = []):  # pylint: disable=W0102
-    if not history:
-        history_str = ''
-    output_str = ''
-    for qa in history:
-        output_str += f'Q: {qa[0]}\n'
-        output_str += f'A: {qa[1]}\n'
-    history_str = output_str.strip()
-    prompt = REWRITE_TEMP.format(question=question, history_str=history_str)
-    return [({'question': prompt})]
-
-def parse_raw_ret(raw_ret, question):
-    try:
-        new_question = raw_ret.split('=> OUTPUT QUESTION: ')[1]
-    except:  # pylint: disable=W0702
-        new_question = question
-    return new_question
 
 
 def custom_pipeline(config):
-    llm_op = get_llm_op(config)
     chat = AutoPipes.pipeline('osschat-search', config=config)
     p = (
         pipe.input('question', 'history', 'project')
-        .map(('question', 'history'), 'prompt', build_prompt)
-        .map('prompt', 'new_question', llm_op)
-        .map(('new_question', 'question'), 'new_question', parse_raw_ret)
+        .map(('question', 'history'), 'new_question', RewriteQuery(config))
         .map(('new_question', 'history', 'project'), 'answer', chat)
         .output('new_question', 'answer')
     )