import py3langid as langid

from .common import *
from .baidu import BaiduTranslator
from .google import GoogleTranslator
from .youdao import YoudaoTranslator
from .deepl import DeeplTranslator
from .papago import PapagoTranslator
from .caiyun import CaiyunTranslator
from .chatgpt import GPT3Translator, GPT35TurboTranslator, GPT4Translator
from .nllb import NLLBTranslator, NLLBBigTranslator
from .sugoi import JparacrawlTranslator, JparacrawlBigTranslator, SugoiTranslator
from .m2m100 import M2M100Translator, M2M100BigTranslator
from .selective import SelectiveOfflineTranslator, prepare as prepare_selective_translator
from .none import NoneTranslator
from .original import OriginalTranslator
from .sakura import SakuraTranslator

# Translators that run locally and need downloaded model weights.
OFFLINE_TRANSLATORS = {
    'offline': SelectiveOfflineTranslator,
    'nllb': NLLBTranslator,
    'nllb_big': NLLBBigTranslator,
    'sugoi': SugoiTranslator,
    'jparacrawl': JparacrawlTranslator,
    'jparacrawl_big': JparacrawlBigTranslator,
    'm2m100': M2M100Translator,
    'm2m100_big': M2M100BigTranslator,
}

# Every selectable translator, keyed by its CLI identifier.
TRANSLATORS = {
    'google': GoogleTranslator,
    'youdao': YoudaoTranslator,
    'baidu': BaiduTranslator,
    'deepl': DeeplTranslator,
    'papago': PapagoTranslator,
    'caiyun': CaiyunTranslator,
    'gpt3': GPT3Translator,
    'gpt3.5': GPT35TurboTranslator,
    'gpt4': GPT4Translator,
    'none': NoneTranslator,
    'original': OriginalTranslator,
    'sakura': SakuraTranslator,
    **OFFLINE_TRANSLATORS,
}

# Cache of already-instantiated translators, keyed by identifier.
translator_cache = {}


def get_translator(key: str, *args, **kwargs) -> CommonTranslator:
    """Return a cached translator instance for *key*, creating it on first use.

    Raises:
        ValueError: if *key* is not a registered translator identifier.
    """
    if key not in TRANSLATORS:
        raise ValueError(f'Could not find translator for: "{key}". '
                         f'Choose from the following: {",".join(TRANSLATORS)}')
    # Membership test instead of truthiness: a falsy-but-valid instance must
    # not be silently re-created.
    if key not in translator_cache:
        translator_cache[key] = TRANSLATORS[key](*args, **kwargs)
    return translator_cache[key]

prepare_selective_translator(get_translator)

# TODO: Refactor
class TranslatorChain():
    def __init__(self, string: str):
        """
        Parses string in form 'trans1:lang1;trans2:lang2' into chains,
        which will be executed one after another when passed to the dispatch function.
        """
        if not string:
            raise Exception('Invalid translator chain')
        self.chain = []
        self.target_lang = None
        for g in string.split(';'):
            trans, lang = g.split(':')
            if trans not in TRANSLATORS:
                raise ValueError(f'Invalid choice: {trans} (choose from {", ".join(map(repr, TRANSLATORS))})')
            if lang not in VALID_LANGUAGES:
                raise ValueError(f'Invalid choice: {lang} (choose from {", ".join(map(repr, VALID_LANGUAGES))})')
            self.chain.append((trans, lang))
        self.translators, self.langs = list(zip(*self.chain))

    def has_offline(self) -> bool:
        """
        Returns True if the chain contains offline translators.
        """
        return any(translator in OFFLINE_TRANSLATORS for translator in self.translators)

    def __eq__(self, __o: object) -> bool:
        # Allow comparing a chain directly against a translator key string.
        if isinstance(__o, str):
            return __o == self.translators[0]
        # BUG FIX: was `super.__eq__(self, __o)` — a call on the `super`
        # builtin type itself, not on the parent class.
        return NotImplemented

async def prepare(chain: TranslatorChain):
    """Validate language support and pre-download every translator in *chain*."""
    for key, tgt_lang in chain.chain:
        translator = get_translator(key)
        translator.supports_languages('auto', tgt_lang, fatal=True)
        if isinstance(translator, OfflineTranslator):
            await translator.download()

# TODO: Optionally take in strings instead of TranslatorChain for simplicity
async def dispatch(chain: TranslatorChain, queries: List[str], use_mtpe: bool = False, args = None, device: str = 'cpu') -> List[str]:
    """Run *queries* through the translator chain and return the translations."""
    if not queries:
        return queries

    if chain.target_lang is not None:
        # Single-target mode: pick the chain link whose language matches the
        # detected source language, falling back to the first translator.
        text_lang = ISO_639_1_TO_VALID_LANGUAGES.get(langid.classify('\n'.join(queries))[0])
        translator = None
        for key, lang in chain.chain:
            if text_lang == lang:
                translator = get_translator(key)
                break
        if translator is None:
            # BUG FIX: was `get_translator(chain.langs[0])`, which looked up a
            # language code in the translator registry and always raised.
            translator = get_translator(chain.translators[0])
        if isinstance(translator, OfflineTranslator):
            await translator.load('auto', chain.target_lang, device)
        translator.parse_args(args)
        queries = await translator.translate('auto', chain.target_lang, queries, use_mtpe)
        return queries
    if args is not None:
        args['translations'] = {}
    for key, tgt_lang in chain.chain:
        translator = get_translator(key)
        if isinstance(translator, OfflineTranslator):
            await translator.load('auto', tgt_lang, device)
        translator.parse_args(args)
        queries = await translator.translate('auto', tgt_lang, queries, use_mtpe)
        if args is not None:
            args['translations'][tgt_lang] = queries
    return queries
import os
from dotenv import load_dotenv

# Pull credential/endpoint overrides from a local .env file, if present.
load_dotenv()

# Baidu translate credentials
BAIDU_APP_ID = os.getenv('BAIDU_APP_ID', '')
BAIDU_SECRET_KEY = os.getenv('BAIDU_SECRET_KEY', '')

# Youdao translate credentials
YOUDAO_APP_KEY = os.getenv('YOUDAO_APP_KEY', '')
YOUDAO_SECRET_KEY = os.getenv('YOUDAO_SECRET_KEY', '')

# DeepL auth key
DEEPL_AUTH_KEY = os.getenv('DEEPL_AUTH_KEY', '')

# OpenAI
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY', '')
OPENAI_HTTP_PROXY = os.getenv('OPENAI_HTTP_PROXY')  # TODO: Replace with --proxy
# For api-for-open-llm style servers use e.g. http://127.0.0.1:8000/v1
OPENAI_API_BASE = os.getenv('OPENAI_API_BASE', 'https://api.openai.com/v1')

# Sakura
SAKURA_API_BASE = os.getenv('SAKURA_API_BASE', 'http://127.0.0.1:8080/v1')  # Sakura API endpoint
SAKURA_VERSION = os.getenv('SAKURA_VERSION', '0.10')  # API version: '0.9' or '0.10'; 0.10 enables the terminology table
SAKURA_DICT_PATH = os.getenv('SAKURA_DICT_PATH', './sakura_dict.txt')  # terminology table path

# Caiyun API access token
CAIYUN_TOKEN = os.getenv('CAIYUN_TOKEN', '')
import re
import os
import asyncio
import time
import logging
from typing import List, Dict

try:
    import openai
    import openai.error
except ImportError:
    openai = None


class SakuraDict():
    """Terminology-table loader for Sakura v0.10+.

    Accepts both Galtransl-style tab-separated dictionaries
    (``src<TAB>dst[<TAB>note]``) and standard Sakura-style dictionaries
    (``src->dst [#note]``), normalising either into the raw text block that
    is embedded into the translation prompt.
    """

    def __init__(self, path: str):
        self.path = path
        # Raw prompt text (str) once a file is loaded; {} while empty.
        self.dict = {}
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        # BUG FIX: the dictionary file was never read anywhere, so the
        # terminology table always came up empty. Load it eagerly if present.
        if path and os.path.isfile(path):
            self.get_dict_from_file(path)

    @staticmethod
    def _iter_entries(dic_lines):
        """Yield candidate entry lines: skip blanks and comment lines, strip newlines."""
        for line in dic_lines:
            if line.startswith("\n"):
                continue
            if line.startswith("\\\\") or line.startswith("//"):  # comment line
                continue
            yield line.rstrip("\r\n")

    def load_galtransl_dic(self, dic_path: str):
        """Load a Galtransl-style dictionary (``src<TAB>dst[<TAB>note]``)."""
        with open(dic_path, encoding="utf8") as f:
            dic_lines = f.readlines()
        if len(dic_lines) == 0:
            return
        dic_name = os.path.basename(os.path.abspath(dic_path))

        gpt_dict_text_list = []
        for line in self._iter_entries(dic_lines):
            # Four consecutive spaces are accepted as a tab separator.
            sp = line.replace("    ", "\t").split("\t")
            if len(sp) < 2:  # need at least src and dst
                continue
            src, dst = sp[0], sp[1]
            info = sp[2] if len(sp) > 2 else None
            gpt_dict_text_list.append(f"{src}->{dst} #{info}" if info else f"{src}->{dst}")

        normalDic_count = len(gpt_dict_text_list)
        self.dict = "\n".join(gpt_dict_text_list)
        self.logger.info(
            f"载入 Galtransl 字典: {dic_name} {normalDic_count}普通词条"
        )

    def load_sakura_dict(self, dic_path: str):
        """Load a standard Sakura-style dictionary (``src->dst [#note]``)."""
        with open(dic_path, encoding="utf8") as f:
            dic_lines = f.readlines()
        if len(dic_lines) == 0:
            return
        dic_name = os.path.basename(os.path.abspath(dic_path))

        gpt_dict_text_list = []
        for line in self._iter_entries(dic_lines):
            # BUG FIX: this used to split on TAB (the Galtransl separator) and
            # required >= 2 fields, so every well-formed `src->dst` line of an
            # actual Sakura dictionary was dropped. Sakura lines are already in
            # the prompt format — keep them as-is.
            if "->" not in line:
                continue
            gpt_dict_text_list.append(line)

        normalDic_count = len(gpt_dict_text_list)
        self.dict = "\n".join(gpt_dict_text_list)
        self.logger.info(
            f"载入标准Sakura字典: {dic_name} {normalDic_count}普通词条"
        )

    def detect_type(self, dic_path: str):
        """Classify the dictionary file as 'galtransl', 'sakura' or 'unknown'."""
        with open(dic_path, encoding="utf8") as f:
            dic_lines = f.readlines()
        if len(dic_lines) == 0:
            return "unknown"
        entries = list(self._iter_entries(dic_lines))
        # Galtransl entries are tab separated; Sakura entries use `->`.
        if all("\t" in line for line in entries):
            return "galtransl"
        if all("->" in line for line in entries):
            return "sakura"
        return "unknown"

    def get_dict(self):
        """Return the raw terminology text, warning when nothing is loaded."""
        if not self.dict:
            self.logger.warning("字典为空")
        return self.dict

    def get_dict_from_file(self, dic_path: str):
        """Load *dic_path* using the detected format and return the dict text."""
        dic_type = self.detect_type(dic_path)
        if dic_type == "galtransl":
            self.load_galtransl_dic(dic_path)
        elif dic_type == "sakura":
            self.load_sakura_dict(dic_path)
        else:
            self.logger.warning(f"未知的字典类型: {dic_path}")
        return self.get_dict()
# Project-local imports for the translator below (the module header above
# carries only stdlib / openai imports).
from .common import CommonTranslator
from .keys import SAKURA_API_BASE, SAKURA_VERSION, SAKURA_DICT_PATH


class SakuraTranslator(CommonTranslator):
    """JPN→CHS translator backed by a local Sakura LLM exposed through an
    OpenAI-compatible chat API (prompt formats for model versions 0.9/0.10)."""

    _TIMEOUT = 999                 # seconds to wait for a server response
    _RETRY_ATTEMPTS = 1            # retries on generic request errors
    _TIMEOUT_RETRY_ATTEMPTS = 3    # retries on timeouts
    _RATELIMIT_RETRY_ATTEMPTS = 3  # retries when rate limited

    _CHAT_SYSTEM_TEMPLATE_009 = (
        '你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。'
    )
    _CHAT_SYSTEM_TEMPLATE_010 = (
        '你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,注意不要擅自添加原文中没有的代词,也不要擅自增加或减少换行。'
    )

    _LANGUAGE_CODE_MAP = {
        'CHS': 'Simplified Chinese',
        'JPN': 'Japanese'
    }

    def __init__(self):
        super().__init__()
        if openai is None:
            # Fail with a clear message instead of an AttributeError below.
            raise ImportError('The "openai" package is required for the Sakura translator.')
        # Normalise the endpoint so it always ends with the /v1 API root.
        if "/v1" not in SAKURA_API_BASE:
            openai.api_base = SAKURA_API_BASE + "/v1"
        else:
            openai.api_base = SAKURA_API_BASE
        # Local servers ignore the key but the client library requires one.
        openai.api_key = "sk-114514"
        self.temperature = 0.3
        self.top_p = 0.3
        self.frequency_penalty = 0.1
        self._current_style = "normal"
        self._emoji_pattern = re.compile(r'[\U00010000-\U0010ffff]')
        self._heart_pattern = re.compile(r'❤')
        self.sakura_dict = SakuraDict(self.get_dict_path())

    def get_sakura_version(self):
        """Return the configured Sakura model/API version string."""
        return SAKURA_VERSION

    def get_dict_path(self):
        """Return the configured terminology-table path."""
        return SAKURA_DICT_PATH

    def detect_and_remove_extra_repeats(self, s: str, threshold: int = 20, remove_all=True):
        """Detect any substring pattern repeated consecutively at least
        *threshold* times; return ``(repeated, cleaned)`` where *cleaned*
        keeps a single occurrence of the pattern when *remove_all* is set.
        """
        repeated = False
        for pattern_length in range(1, len(s) // 2 + 1):
            i = 0
            while i < len(s) - pattern_length:
                pattern = s[i:i + pattern_length]
                count = 1
                j = i + pattern_length
                while j <= len(s) - pattern_length:
                    if s[j:j + pattern_length] == pattern:
                        count += 1
                        j += pattern_length
                    else:
                        break
                if count >= threshold:
                    repeated = True
                    if remove_all:
                        s = s[:i + pattern_length] + s[j:]
                    break
                i += 1
            if repeated:
                break
        return repeated, s

    def _format_prompt_log(self, prompt: str) -> str:
        """Render the prompt roughly as it is sent to the model, for debugging."""
        # BUG FIX: referenced the removed `_CHAT_SYSTEM_TEMPLATE` attribute;
        # pick the template matching the configured Sakura version instead.
        if SAKURA_VERSION == "0.9":
            system_template = self._CHAT_SYSTEM_TEMPLATE_009
        else:
            system_template = self._CHAT_SYSTEM_TEMPLATE_010
        return '\n'.join([
            'System:',
            system_template,
            'User:',
            '将下面的日文文本翻译成中文:',
            prompt,
        ])

    def _split_text(self, text: str) -> List[str]:
        """
        将字符串按换行符分割为列表。
        """
        if isinstance(text, list):
            return text
        return text.split('\n')

    def _preprocess_queries(self, queries: List[str]) -> List[str]:
        """
        预处理查询文本,去除emoji,替换特殊字符,并添加「」标记。
        """
        queries = [self._emoji_pattern.sub('', query) for query in queries]
        queries = [self._heart_pattern.sub('♥', query) for query in queries]
        queries = [f'「{query}」' for query in queries]
        return queries

    async def _check_translation_quality(self, queries: List[str], response: str) -> List[str]:
        """
        检查翻译结果的质量,包括重复和行数对齐问题,如果存在问题则尝试重新翻译或返回原始文本。
        """
        rep_flag = self._detect_repeats(response)
        if rep_flag:
            for i in range(self._RETRY_ATTEMPTS):
                # If the *input* itself is repetitive, a retry cannot help.
                if self._detect_repeats(''.join(queries)):
                    self.logger.warning('Queries have repeats.')
                    break
                self.logger.warning(f'Re-translating due to model degradation, attempt: {i + 1}')
                self._set_gpt_style("precise")
                response = await self._handle_translation_request(queries)
                rep_flag = self._detect_repeats(response)
                if not rep_flag:
                    break
            if rep_flag:
                self.logger.warning('Model degradation, translating single lines.')
                return await self._translate_single_lines(queries)

        align_flag = self._check_align(queries, response)
        if not align_flag:
            for i in range(self._RETRY_ATTEMPTS):
                self.logger.warning(f'Re-translating due to mismatched lines, attempt: {i + 1}')
                self._set_gpt_style("precise")
                response = await self._handle_translation_request(queries)
                align_flag = self._check_align(queries, response)
                if align_flag:
                    break
            if not align_flag:
                self.logger.warning('Mismatched lines, translating single lines.')
                return await self._translate_single_lines(queries)

        return self._split_text(response)

    def _detect_repeats(self, text: str, threshold: int = 20) -> bool:
        """
        检测文本中是否存在重复模式。
        """
        # BUG FIX: previously compared the cleaned text with itself
        # (`text != text`, always False), which disabled degradation
        # detection entirely. Use the flag the helper already returns.
        repeated, _ = self.detect_and_remove_extra_repeats(text, threshold)
        return repeated

    def _check_align(self, queries: List[str], response: str) -> bool:
        """
        检查原始文本和翻译结果的行数是否对齐。
        """
        translations = self._split_text(response)
        is_aligned = len(queries) == len(translations)
        if not is_aligned:
            self.logger.warning(f"Mismatched lines - Queries: {len(queries)}, Translations: {len(translations)}")
        return is_aligned

    async def _translate_single_lines(self, queries: List[str]) -> List[str]:
        """
        逐行翻译查询文本。
        """
        translations = []
        for query in queries:
            response = await self._handle_translation_request(query)
            if self._detect_repeats(response):
                self.logger.warning('Model degradation, using original text.')
                translations.append(query)
            else:
                translations.append(response)
        return translations

    def _delete_quotation_mark(self, texts: List[str]) -> List[str]:
        """
        删除文本中的「」标记。
        """
        return [text.strip('「」') for text in texts]

    async def _translate(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]:
        self.logger.debug(f'Temperature: {self.temperature}, TopP: {self.top_p}')
        self.logger.debug(f'Queries: {queries}')
        text_prompt = '\n'.join(queries)
        self.logger.debug('-- Sakura Prompt --\n' + self._format_prompt_log(text_prompt) + '\n\n')

        # 预处理查询文本
        queries = self._preprocess_queries(queries)

        # 发送翻译请求
        response = await self._handle_translation_request(queries)
        self.logger.debug('-- Sakura Response --\n' + response + '\n\n')

        # 检查翻译结果是否存在重复或行数不匹配的问题
        translations = await self._check_translation_quality(queries, response)

        return self._delete_quotation_mark(translations)

    async def _handle_translation_request(self, prompt: str) -> str:
        """
        处理翻译请求,包括错误处理和重试逻辑。
        """
        ratelimit_attempt = 0
        server_error_attempt = 0
        timeout_attempt = 0
        while True:
            try:
                request_task = asyncio.create_task(self._request_translation(prompt))
                response = await asyncio.wait_for(request_task, timeout=self._TIMEOUT)
                break
            except asyncio.TimeoutError:
                timeout_attempt += 1
                if timeout_attempt >= self._TIMEOUT_RETRY_ATTEMPTS:
                    raise Exception('Sakura timeout.')
                self.logger.warning(f'Restarting request due to timeout. Attempt: {timeout_attempt}')
            except openai.error.RateLimitError:
                ratelimit_attempt += 1
                if ratelimit_attempt >= self._RATELIMIT_RETRY_ATTEMPTS:
                    raise
                self.logger.warning(f'Restarting request due to ratelimiting by sakura servers. Attempt: {ratelimit_attempt}')
                await asyncio.sleep(2)
            except (openai.error.APIError, openai.error.APIConnectionError) as e:
                server_error_attempt += 1
                if server_error_attempt >= self._RETRY_ATTEMPTS:
                    # Best-effort fallback: hand the original text back rather
                    # than aborting the whole pipeline.
                    self.logger.error(f'Sakura server error: {str(e)}. Returning original text.')
                    return prompt
                self.logger.warning(f'Restarting request due to server error. Attempt: {server_error_attempt}')

        return response

    async def _request_translation(self, input_text_list) -> str:
        """
        向Sakura API发送翻译请求。
        """
        if isinstance(input_text_list, list):
            raw_text = "\n".join(input_text_list)
        else:
            raw_text = input_text_list
        extra_query = {
            'do_sample': False,
            'num_beams': 1,
            'repetition_penalty': 1.0,
        }
        if SAKURA_VERSION == "0.9":
            messages = [
                {
                    "role": "system",
                    "content": f"{self._CHAT_SYSTEM_TEMPLATE_009}"
                },
                {
                    "role": "user",
                    "content": f"将下面的日文文本翻译成中文:{raw_text}"
                }
            ]
        else:
            gpt_dict_raw_text = self.sakura_dict.get_dict()
            self.logger.debug(f"Sakura Dict: {gpt_dict_raw_text}")
            messages = [
                {
                    "role": "system",
                    "content": f"{self._CHAT_SYSTEM_TEMPLATE_010}"
                },
                {
                    "role": "user",
                    # BUG FIX: was `"..." + {gpt_dict_raw_text} + "..."` —
                    # concatenating a str with a set literal raises TypeError.
                    "content": f"根据以下术语表:\n{gpt_dict_raw_text}\n将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:{raw_text}"
                }
            ]
        response = await openai.ChatCompletion.acreate(
            model="sukinishiro",
            messages=messages,
            temperature=self.temperature,
            top_p=self.top_p,
            max_tokens=1024,
            frequency_penalty=self.frequency_penalty,
            seed=-1,
            extra_query=extra_query,
        )
        # Completion-style servers return `text`; chat-style servers return
        # `message.content`.
        for choice in response.choices:
            if 'text' in choice:
                return choice.text

        return response.choices[0].message.content

    def _set_gpt_style(self, style_name: str):
        """
        设置GPT的生成风格。
        """
        if self._current_style == style_name:
            return
        self._current_style = style_name
        if style_name == "precise":
            temperature, top_p = 0.1, 0.3
            frequency_penalty = 0.1
        else:
            # BUG FIX: an unknown style name previously left the locals unbound
            # (UnboundLocalError); "normal" is now the explicit fallback.
            temperature, top_p = 0.3, 0.3
            frequency_penalty = 0.15

        self.temperature = temperature
        self.top_p = top_p
        self.frequency_penalty = frequency_penalty
gpt['dst'] + info = gpt['info'] if "info" in gpt.keys() else None + if info: + single = f"{src}->{dst} #{info}" + else: + single = f"{src}->{dst}" + gpt_dict_text_list.append(single) + + gpt_dict_raw_text = "\n".join(gpt_dict_text_list) + self.dict = gpt_dict_raw_text + self.logger.info( + f"载入 Galtransl 字典: {dic_name} {normalDic_count}普通词条" + ) - _TIMEOUT = 999 # Seconds to wait for a response from the server before retrying - _RETRY_ATTEMPTS = 1 # Number of times to retry an errored request before giving up - _TIMEOUT_RETRY_ATTEMPTS = 3 # Number of times to retry a timed out request before giving up - _RATELIMIT_RETRY_ATTEMPTS = 3 # Number of times to retry a ratelimited request before giving up + def load_sakura_dict(self, dic_path: str): + """ + 直接载入标准的Sakura字典。 + """ + + with open(dic_path, encoding="utf8") as f: + dic_lines = f.readlines() + if len(dic_lines) == 0: + return + dic_path = os.path.abspath(dic_path) + dic_name = os.path.basename(dic_path) + normalDic_count = 0 + + gpt_dict_text_list = [] + for line in dic_lines: + if line.startswith("\n"): + continue + elif line.startswith("\\\\") or line.startswith("//"): # 注释行跳过 + continue + + sp = line.rstrip("\r\n").split("\t") # 去多余换行符,Tab分割 + len_sp = len(sp) + + if len_sp < 2: # 至少是2个元素 + continue + + src = sp[0] + dst = sp[1] + info = sp[2] if len_sp > 2 else None + if info: + single = f"{src}->{dst} #{info}" + else: + single = f"{src}->{dst}" + gpt_dict_text_list.append(single) + normalDic_count += 1 + + gpt_dict_raw_text = "\n".join(gpt_dict_text_list) + self.dict = gpt_dict_raw_text + self.logger.info( + f"载入标准Sakura字典: {dic_name} {normalDic_count}普通词条" + ) + + def detect_type(self, dic_path: str): + """ + 检测字典类型。 + """ + with open(dic_path, encoding="utf8") as f: + dic_lines = f.readlines() + if len(dic_lines) == 0: + return "unknown" + + # 判断是否为Galtransl字典 + is_galtransl = True + for line in dic_lines: + if line.startswith("\n"): + continue + elif line.startswith("\\\\") or line.startswith("//"): + 
continue + + if "\t" not in line: + is_galtransl = False + break + + if is_galtransl: + return "galtransl" + + # 判断是否为Sakura字典 + is_sakura = True + for line in dic_lines: + if line.startswith("\n"): + continue + elif line.startswith("\\\\") or line.startswith("//"): + continue + + if "->" not in line: + is_sakura = False + break + + if is_sakura: + return "sakura" + + return "unknown" + + def get_dict(self): + """ + 获取字典内容。 + """ + if self.dict == {}: + self.logger.warning("字典为空") + return {} + return self.dict + + def get_dict_from_file(self, dic_path: str): + """ + 从文件载入字典。 + """ + dic_type = self.detect_type(dic_path) + if dic_type == "galtransl": + self.load_galtransl_dic(dic_path) + elif dic_type == "sakura": + self.load_sakura_dict(dic_path) + else: + self.logger.warning(f"未知的字典类型: {dic_path}") + return self.get_dict() - _CHAT_SYSTEM_TEMPLATE = ( + +class SakuraTranslator(CommonTranslator): + + _TIMEOUT = 999 # 等待服务器响应的超时时间(秒) + _RETRY_ATTEMPTS = 1 # 请求出错时的重试次数 + _TIMEOUT_RETRY_ATTEMPTS = 3 # 请求超时时的重试次数 + _RATELIMIT_RETRY_ATTEMPTS = 3 # 请求被限速时的重试次数 + + _CHAT_SYSTEM_TEMPLATE_009 = ( '你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。' ) + _CHAT_SYSTEM_TEMPLATE_010 = ( + '你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,注意不要擅自添加原文中没有的代词,也不要擅自增加或减少换行。' + ) _LANGUAGE_CODE_MAP = { 'CHS': 'Simplified Chinese', @@ -32,22 +205,30 @@ class Sakura13BTranslator(CommonTranslator): def __init__(self): super().__init__() - #检测/v1是否存在 + if "/v1" not in SAKURA_API_BASE: + openai.api_base = SAKURA_API_BASE + "/v1" + else: + openai.api_base = SAKURA_API_BASE + openai.api_key = "sk-114514" self.temperature = 0.3 self.top_p = 0.3 - self.frequency_penalty = 0.0 + self.frequency_penalty = 0.1 self._current_style = "normal" + self._emoji_pattern = re.compile(r'[\U00010000-\U0010ffff]') + self._heart_pattern = re.compile(r'❤') + self.sakura_dict = SakuraDict(self.get_dict_path()) + + def get_sakura_version(self): + return SAKURA_VERSION - def 
detect_and_remove_extra_repeats(self, s: str, threshold: int = 10, remove_all=True): + def get_dict_path(self): + return SAKURA_DICT_PATH + + def detect_and_remove_extra_repeats(self, s: str, threshold: int = 20, remove_all=True): """ 检测字符串中是否有任何模式连续重复出现超过阈值,并在去除多余重复后返回新字符串。 保留一个模式的重复。 - - :param s: str - 待检测的字符串。 - :param threshold: int - 连续重复模式出现的最小次数,默认为2。 - :return: tuple - (bool, str),第一个元素表示是否有重复,第二个元素是处理后的字符串。 """ - repeated = False for pattern_length in range(1, len(s) // 2 + 1): i = 0 @@ -63,7 +244,6 @@ def detect_and_remove_extra_repeats(self, s: str, threshold: int = 10, remove_al break if count >= threshold: repeated = True - # 保留一个模式的重复 if remove_all: s = s[:i + pattern_length] + s[j:] break @@ -81,36 +261,93 @@ def _format_prompt_log(self, prompt: str) -> str: prompt, ]) - # str 通过/n转换为list - def _split_text(self, text: str) -> list: + def _split_text(self, text: str) -> List[str]: + """ + 将字符串按换行符分割为列表。 + """ if isinstance(text, list): return text return text.split('\n') - def check_align(self, queries: List[str], response: str) -> bool: + def _preprocess_queries(self, queries: List[str]) -> List[str]: """ - 检查原始文本(queries)与翻译后的文本(response)是否保持相同的行数。 + 预处理查询文本,去除emoji,替换特殊字符,并添加「」标记。 + """ + queries = [self._emoji_pattern.sub('', query) for query in queries] + queries = [self._heart_pattern.sub('♥', query) for query in queries] + queries = [f'「{query}」' for query in queries] + return queries - :param queries: 原始文本的列表。 - :param response: 翻译后的文本,可能是一个字符串。 - :return: 两者行数是否相同的布尔值。 + async def _check_translation_quality(self, queries: List[str], response: str) -> List[str]: """ - # 确保response是列表形式 - translated_texts = self._split_text(response) if isinstance(response, str) else response + 检查翻译结果的质量,包括重复和行数对齐问题,如果存在问题则尝试重新翻译或返回原始文本。 + """ + rep_flag = self._detect_repeats(response) + if rep_flag: + for i in range(self._RETRY_ATTEMPTS): + if self._detect_repeats(''.join(queries)): + self.logger.warning('Queries have repeats.') + break + 
self.logger.warning(f'Re-translating due to model degradation, attempt: {i + 1}') + self._set_gpt_style("precise") + response = await self._handle_translation_request(queries) + rep_flag = self._detect_repeats(response) + if not rep_flag: + break + if rep_flag: + self.logger.warning('Model degradation, translating single lines.') + return await self._translate_single_lines(queries) + + align_flag = self._check_align(queries, response) + if not align_flag: + for i in range(self._RETRY_ATTEMPTS): + self.logger.warning(f'Re-translating due to mismatched lines, attempt: {i + 1}') + self._set_gpt_style("precise") + response = await self._handle_translation_request(queries) + align_flag = self._check_align(queries, response) + if align_flag: + break + if not align_flag: + self.logger.warning('Mismatched lines, translating single lines.') + return await self._translate_single_lines(queries) - # 日志记录,而不是直接打印 - print(f"原始文本行数: {len(queries)}, 翻译文本行数: {len(translated_texts)}") - logger.warning(f"原始文本行数: {len(queries)}, 翻译文本行数: {len(translated_texts)}") + return self._split_text(response) - # 检查行数是否匹配 - is_aligned = len(queries) == len(translated_texts) - if not is_aligned: - logger.warning("原始文本与翻译文本的行数不匹配。") + def _detect_repeats(self, text: str, threshold: int = 20) -> bool: + """ + 检测文本中是否存在重复模式。 + """ + _, text = self.detect_and_remove_extra_repeats(text, threshold) + return text != text + def _check_align(self, queries: List[str], response: str) -> bool: + """ + 检查原始文本和翻译结果的行数是否对齐。 + """ + translations = self._split_text(response) + is_aligned = len(queries) == len(translations) + if not is_aligned: + self.logger.warning(f"Mismatched lines - Queries: {len(queries)}, Translations: {len(translations)}") return is_aligned + async def _translate_single_lines(self, queries: List[str]) -> List[str]: + """ + 逐行翻译查询文本。 + """ + translations = [] + for query in queries: + response = await self._handle_translation_request(query) + if self._detect_repeats(response): + 
self.logger.warning('Model degradation, using original text.') + translations.append(query) + else: + translations.append(response) + return translations + def _delete_quotation_mark(self, texts: List[str]) -> List[str]: - print(texts) + """ + 删除文本中的「」标记。 + """ new_texts = [] for text in texts: text = text.strip('「」') @@ -118,120 +355,59 @@ def _delete_quotation_mark(self, texts: List[str]) -> List[str]: return new_texts async def _translate(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]: - translations = [] self.logger.debug(f'Temperature: {self.temperature}, TopP: {self.top_p}') self.logger.debug(f'Queries: {queries}') text_prompt = '\n'.join(queries) self.logger.debug('-- Sakura Prompt --\n' + self._format_prompt_log(text_prompt) + '\n\n') - # 去除emoji - queries = [re.sub(r'[\U00010000-\U0010ffff]', '', query) for query in queries] - # 替换❤ - queries = [re.sub(r'❤', '♥', query) for query in queries] - # 给queries的每行加上「」 - queries = [f'「{query}」' for query in queries] + + # 预处理查询文本 + queries = self._preprocess_queries(queries) + + # 发送翻译请求 response = await self._handle_translation_request(queries) self.logger.debug('-- Sakura Response --\n' + response + '\n\n') - # 提取翻译结果并去除首尾空白 - response = response.strip() + # 检查翻译结果是否存在重复或行数不匹配的问题 + translations = await self._check_translation_quality(queries, response) - rep_flag = self.detect_and_remove_extra_repeats(response)[0] - if rep_flag: - for i in range(self._RETRY_ATTEMPTS): - if self.detect_and_remove_extra_repeats(queries)[0]: - self.logger.warning('Queries have repeats.') - break - self.logger.warning(f'Re-translated because of model degradation, {i} times.') - self._set_gpt_style("precise") - self.logger.debug(f'Temperature: {self.temperature}, TopP: {self.top_p}') - response = await self._handle_translation_request(queries) - rep_flag = self.detect_and_remove_extra_repeats(response)[0] - if not rep_flag: - break - if rep_flag: - self.logger.warning('Model degradation, try to translate single 
line.') - for query in queries: - response = await self._handle_translation_request(query) - translations.append(response) - rep_flag = self.detect_and_remove_extra_repeats(response)[0] - if rep_flag: - self.logger.warning('Model degradation, fill original text') - return self._delete_quotation_mark(queries) - return self._delete_quotation_mark(translations) - - align_flag = self.check_align(queries, response) - if not align_flag: - for i in range(self._RETRY_ATTEMPTS): - self.logger.warning(f'Re-translated because of a mismatch in the number of lines, {i} times.') - self._set_gpt_style("precise") - self.logger.debug(f'Temperature: {self.temperature}, TopP: {self.top_p}') - response = await self._handle_translation_request(queries) - align_flag = self.check_align(queries, response) - if align_flag: - break - if not align_flag: - self.logger.warning('Mismatch in the number of lines, try to translate single line.') - for query in queries: - print(query) - response = await self._handle_translation_request(query) - translations.append(response) - print(translations) - align_flag = self.check_align(queries, translations) - if not align_flag: - self.logger.warning('Mismatch in the number of lines, fill original text') - return self._delete_quotation_mark(queries) - return self._delete_quotation_mark(translations) - translations = self._split_text(response) - if isinstance(translations, list): - return self._delete_quotation_mark(translations) - translations = self._split_text(response) return self._delete_quotation_mark(translations) async def _handle_translation_request(self, prompt: str) -> str: - # 翻译请求和错误处理逻辑 + """ + 处理翻译请求,包括错误处理和重试逻辑。 + """ ratelimit_attempt = 0 server_error_attempt = 0 timeout_attempt = 0 while True: - request_task = asyncio.create_task(self._request_translation(prompt)) - started = time.time() - while not request_task.done(): - await asyncio.sleep(0.1) - if time.time() - started > self._TIMEOUT + (timeout_attempt * self._TIMEOUT / 2): - if 
timeout_attempt >= self._TIMEOUT_RETRY_ATTEMPTS: - raise Exception('Sakura timeout.') - timeout_attempt += 1 - self.logger.warn(f'Restarting request due to timeout. Attempt: {timeout_attempt}') - request_task.cancel() - request_task = asyncio.create_task(self._request_translation(prompt)) - started = time.time() try: - response = await request_task + request_task = asyncio.create_task(self._request_translation(prompt)) + response = await asyncio.wait_for(request_task, timeout=self._TIMEOUT) break + except asyncio.TimeoutError: + timeout_attempt += 1 + if timeout_attempt >= self._TIMEOUT_RETRY_ATTEMPTS: + raise Exception('Sakura timeout.') + self.logger.warning(f'Restarting request due to timeout. Attempt: {timeout_attempt}') except openai.error.RateLimitError: ratelimit_attempt += 1 if ratelimit_attempt >= self._RATELIMIT_RETRY_ATTEMPTS: raise - self.logger.warn(f'Restarting request due to ratelimiting by sakura servers. Attempt: {ratelimit_attempt}') + self.logger.warning(f'Restarting request due to ratelimiting by sakura servers. Attempt: {ratelimit_attempt}') await asyncio.sleep(2) - except openai.error.APIError: + except (openai.error.APIError, openai.error.APIConnectionError) as e: server_error_attempt += 1 if server_error_attempt >= self._RETRY_ATTEMPTS: - self.logger.error('Sakura server error. Returning original text.') - return prompt # 返回原始文本而不是抛出异常 - self.logger.warn(f'Restarting request due to a server error. Attempt: {server_error_attempt}') - await asyncio.sleep(1) - except openai.error.APIConnectionError: - server_error_attempt += 1 - self.logger.warn(f'Restarting request due to a server connection error. Attempt: {server_error_attempt}') - await asyncio.sleep(1) - except FileNotFoundError: - self.logger.warn(f'Restarting request due to FileNotFoundError.') - await asyncio.sleep(30) + self.logger.error(f'Sakura server error: {str(e)}. Returning original text.') + return prompt + self.logger.warning(f'Restarting request due to server error. 
Attempt: {server_error_attempt}') return response async def _request_translation(self, input_text_list) -> str: + """ + 向Sakura API发送翻译请求。 + """ if isinstance(input_text_list, list): raw_text = "\n".join(input_text_list) else: @@ -241,26 +417,33 @@ async def _request_translation(self, input_text_list) -> str: 'num_beams': 1, 'repetition_penalty': 1.0, } - old_api_base = openai.api_base or '' - old_api_key = openai.api_key or '' - - if "/v1" not in SAKURA_API_BASE: - openai.api_base = SAKURA_API_BASE + "/v1" + if SAKURA_VERSION == "0.9": + messages=[ + { + "role": "system", + "content": f"{self._CHAT_SYSTEM_TEMPLATE_009}" + }, + { + "role": "user", + "content": f"将下面的日文文本翻译成中文:{raw_text}" + } + ] else: - openai.api_base = SAKURA_API_BASE - openai.api_key = SAKURA_API_KEY + gpt_dict_raw_text = self.sakura_dict.get_dict() + self.logger.debug(f"Sakura Dict: {gpt_dict_raw_text}") + messages=[ + { + "role": "system", + "content": f"{self._CHAT_SYSTEM_TEMPLATE_010}" + }, + { + "role": "user", + "content": f"根据以下术语表:\n" + {gpt_dict_raw_text} + "\n" + f"将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:{raw_text}" + } + ] response = await openai.ChatCompletion.acreate( model="sukinishiro", - messages=[ - { - "role": "system", - "content": "你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。" - }, - { - "role": "user", - "content": f"将下面的日文文本翻译成中文:{raw_text}" - } - ], + messages=messages, temperature=self.temperature, top_p=self.top_p, max_tokens=1024, @@ -268,29 +451,27 @@ async def _request_translation(self, input_text_list) -> str: seed=-1, extra_query=extra_query, ) - openai.api_base = old_api_base - openai.api_key = old_api_key - # 提取并返回响应文本 for choice in response.choices: if 'text' in choice: return choice.text - # 如果没有找到包含文本的响应,返回第一个响应的内容(可能为空) - return response.choices[0].message.content def _set_gpt_style(self, style_name: str): + """ + 设置GPT的生成风格。 + """ if self._current_style == style_name: return self._current_style = style_name if style_name == "precise": 
temperature, top_p = 0.1, 0.3 - frequency_penalty = 0.0 + frequency_penalty = 0.1 elif style_name == "normal": temperature, top_p = 0.3, 0.3 frequency_penalty = 0.15 self.temperature = temperature self.top_p = top_p - self.frequency_penalty = frequency_penalty \ No newline at end of file + self.frequency_penalty = frequency_penalty diff --git a/sakura_dict.txt b/sakura_dict.txt new file mode 100644 index 00000000..26456abe --- /dev/null +++ b/sakura_dict.txt @@ -0,0 +1,13 @@ +// 示例字典,可自行添加或修改 + +安芸倫也->安艺伦也 #名字,男性,学生 +倫也->伦也 #名字,男性,学生 +安芸->安艺 #姓氏 + +加藤恵->加藤惠 #名字,女性,学生,安芸倫也的同班同学 +恵->惠 #名字,女性,学生,安芸倫也的同班同学 +加藤->加藤 #姓氏 + +澤村・スペンサー・英梨々->泽村・斯宾塞・英梨梨 #名字,女性,学生,同人志作者 +英梨々->英梨梨 #名字,女性,学生,同人志作者 +澤村->泽村 #姓氏 \ No newline at end of file From 46be2ca690785d54536749e0f9d2d03bad750391 Mon Sep 17 00:00:00 2001 From: PiDanShouRouZhouXD <38401147+PiDanShouRouZhouXD@users.noreply.github.com> Date: Sat, 13 Apr 2024 04:06:39 +0800 Subject: [PATCH 2/8] fix --- .../__init__ - \345\211\257\346\234\254.py" | 124 ----- .../keys - \345\211\257\346\234\254.py" | 25 - manga_translator/translators/keys.py | 2 +- .../sakura - \345\211\257\346\234\254.py" | 477 ------------------ manga_translator/translators/sakura.py | 64 ++- 5 files changed, 42 insertions(+), 650 deletions(-) delete mode 100644 "manga_translator/translators/__init__ - \345\211\257\346\234\254.py" delete mode 100644 "manga_translator/translators/keys - \345\211\257\346\234\254.py" delete mode 100644 "manga_translator/translators/sakura - \345\211\257\346\234\254.py" diff --git "a/manga_translator/translators/__init__ - \345\211\257\346\234\254.py" "b/manga_translator/translators/__init__ - \345\211\257\346\234\254.py" deleted file mode 100644 index ad29cdf0..00000000 --- "a/manga_translator/translators/__init__ - \345\211\257\346\234\254.py" +++ /dev/null @@ -1,124 +0,0 @@ -import py3langid as langid - -from .common import * -from .baidu import BaiduTranslator -from .google import GoogleTranslator -from .youdao import 
YoudaoTranslator -from .deepl import DeeplTranslator -from .papago import PapagoTranslator -from .caiyun import CaiyunTranslator -from .chatgpt import GPT3Translator, GPT35TurboTranslator, GPT4Translator -from .nllb import NLLBTranslator, NLLBBigTranslator -from .sugoi import JparacrawlTranslator, JparacrawlBigTranslator, SugoiTranslator -from .m2m100 import M2M100Translator, M2M100BigTranslator -from .selective import SelectiveOfflineTranslator, prepare as prepare_selective_translator -from .none import NoneTranslator -from .original import OriginalTranslator -from .sakura import SakuraTranslator - -OFFLINE_TRANSLATORS = { - 'offline': SelectiveOfflineTranslator, - 'nllb': NLLBTranslator, - 'nllb_big': NLLBBigTranslator, - 'sugoi': SugoiTranslator, - 'jparacrawl': JparacrawlTranslator, - 'jparacrawl_big': JparacrawlBigTranslator, - 'm2m100': M2M100Translator, - 'm2m100_big': M2M100BigTranslator, -} - -TRANSLATORS = { - 'google': GoogleTranslator, - 'youdao': YoudaoTranslator, - 'baidu': BaiduTranslator, - 'deepl': DeeplTranslator, - 'papago': PapagoTranslator, - 'caiyun': CaiyunTranslator, - 'gpt3': GPT3Translator, - 'gpt3.5': GPT35TurboTranslator, - 'gpt4': GPT4Translator, - 'none': NoneTranslator, - 'original': OriginalTranslator, - 'sakura': SakuraTranslator, - **OFFLINE_TRANSLATORS, -} -translator_cache = {} - -def get_translator(key: str, *args, **kwargs) -> CommonTranslator: - if key not in TRANSLATORS: - raise ValueError(f'Could not find translator for: "{key}". Choose from the following: %s' % ','.join(TRANSLATORS)) - if not translator_cache.get(key): - translator = TRANSLATORS[key] - translator_cache[key] = translator(*args, **kwargs) - return translator_cache[key] - -prepare_selective_translator(get_translator) - -# TODO: Refactor -class TranslatorChain(): - def __init__(self, string: str): - """ - Parses string in form 'trans1:lang1;trans2:lang2' into chains, - which will be executed one after another when passed to the dispatch function. 
- """ - if not string: - raise Exception('Invalid translator chain') - self.chain = [] - self.target_lang = None - for g in string.split(';'): - trans, lang = g.split(':') - if trans not in TRANSLATORS: - raise ValueError(f'Invalid choice: %s (choose from %s)' % (trans, ', '.join(map(repr, TRANSLATORS)))) - if lang not in VALID_LANGUAGES: - raise ValueError(f'Invalid choice: %s (choose from %s)' % (lang, ', '.join(map(repr, VALID_LANGUAGES)))) - self.chain.append((trans, lang)) - self.translators, self.langs = list(zip(*self.chain)) - - def has_offline(self) -> bool: - """ - Returns True if the chain contains offline translators. - """ - return any(translator in OFFLINE_TRANSLATORS for translator in self.translators) - - def __eq__(self, __o: object) -> bool: - if type(__o) is str: - return __o == self.translators[0] - return super.__eq__(self, __o) - -async def prepare(chain: TranslatorChain): - for key, tgt_lang in chain.chain: - translator = get_translator(key) - translator.supports_languages('auto', tgt_lang, fatal=True) - if isinstance(translator, OfflineTranslator): - await translator.download() - -# TODO: Optionally take in strings instead of TranslatorChain for simplicity -async def dispatch(chain: TranslatorChain, queries: List[str], use_mtpe: bool = False, args = None, device: str = 'cpu') -> List[str]: - if not queries: - return queries - - if chain.target_lang is not None: - text_lang = ISO_639_1_TO_VALID_LANGUAGES.get(langid.classify('\n'.join(queries))[0]) - translator = None - for key, lang in chain.chain: - if text_lang == lang: - translator = get_translator(key) - break - if translator is None: - translator = get_translator(chain.langs[0]) - if isinstance(translator, OfflineTranslator): - await translator.load('auto', chain.target_lang, device) - translator.parse_args(args) - queries = await translator.translate('auto', chain.target_lang, queries, use_mtpe) - return queries - if args is not None: - args['translations'] = {} - for key, tgt_lang in 
chain.chain: - translator = get_translator(key) - if isinstance(translator, OfflineTranslator): - await translator.load('auto', tgt_lang, device) - translator.parse_args(args) - queries = await translator.translate('auto', tgt_lang, queries, use_mtpe) - if args is not None: - args['translations'][tgt_lang] = queries - return queries diff --git "a/manga_translator/translators/keys - \345\211\257\346\234\254.py" "b/manga_translator/translators/keys - \345\211\257\346\234\254.py" deleted file mode 100644 index c5efd17f..00000000 --- "a/manga_translator/translators/keys - \345\211\257\346\234\254.py" +++ /dev/null @@ -1,25 +0,0 @@ -import os -from dotenv import load_dotenv -load_dotenv() - -# baidu -BAIDU_APP_ID = os.getenv('BAIDU_APP_ID', '') #你的appid -BAIDU_SECRET_KEY = os.getenv('BAIDU_SECRET_KEY', '') #你的密钥 -# youdao -YOUDAO_APP_KEY = os.getenv('YOUDAO_APP_KEY', '') # 应用ID -YOUDAO_SECRET_KEY = os.getenv('YOUDAO_SECRET_KEY', '') # 应用秘钥 -# deepl -DEEPL_AUTH_KEY = os.getenv('DEEPL_AUTH_KEY', '') #YOUR_AUTH_KEY -# openai -OPENAI_API_KEY = os.getenv('OPENAI_API_KEY', '') -OPENAI_HTTP_PROXY = os.getenv('OPENAI_HTTP_PROXY') # TODO: Replace with --proxy - -OPENAI_API_BASE = os.getenv('OPENAI_API_BASE', 'https://api.openai.com/v1') #使用api-for-open-llm例子 http://127.0.0.1:8000/v1 - -# sakura -SAKURA_API_BASE = os.getenv('SAKURA_API_BASE', 'http://127.0.0.1:8080/v1') #SAKURA API地址 -SAKURA_VERSION = os.getenv('SAKURA_VERSION', '0.10') #SAKURA API版本,可选值:0.9、0.10,选择1.0则会加载术语表。 -SAKURA_DICT_PATH = os.getenv('SAKURA_DICT_PATH', './sakura_dict.txt') #SAKURA 术语表路径 - - -CAIYUN_TOKEN = os.getenv('CAIYUN_TOKEN', '') # 彩云小译API访问令牌 \ No newline at end of file diff --git a/manga_translator/translators/keys.py b/manga_translator/translators/keys.py index 9ab19542..ef0678fc 100644 --- a/manga_translator/translators/keys.py +++ b/manga_translator/translators/keys.py @@ -18,7 +18,7 @@ # sakura SAKURA_API_BASE = os.getenv('SAKURA_API_BASE', 'http://127.0.0.1:8080/v1') #SAKURA API地址 
-SAKURA_VERSION = os.getenv('SAKURA_VERSION', '0.10') #SAKURA API版本,可选值:0.9、0.10,选择1.0则会加载术语表。 +SAKURA_VERSION = os.getenv('SAKURA_VERSION', '0.9') #SAKURA API版本,可选值:0.9、0.10,选择0.10则会加载术语表。 SAKURA_DICT_PATH = os.getenv('SAKURA_DICT_PATH', './sakura_dict.txt') #SAKURA 术语表路径 diff --git "a/manga_translator/translators/sakura - \345\211\257\346\234\254.py" "b/manga_translator/translators/sakura - \345\211\257\346\234\254.py" deleted file mode 100644 index 983694f5..00000000 --- "a/manga_translator/translators/sakura - \345\211\257\346\234\254.py" +++ /dev/null @@ -1,477 +0,0 @@ -from distutils.cygwinccompiler import get_versions -import re -import os -from venv import logger - -from httpx import get - -try: - import openai - import openai.error -except ImportError: - openai = None -import asyncio -import time -from typing import List, Dict - -from .common import CommonTranslator -from .keys import SAKURA_API_BASE, SAKURA_VERSION, SAKURA_DICT_PATH - -import logging - -class SakuraDict(): - def __init__(self, path: str): - self.path = path - self.dict = {} - self.logger = logging.getLogger(__name__) - self.logger.setLevel(logging.INFO) - - def load_galtransl_dic(self, dic_path: str): - """ - 载入Galtransl词典。 - """ - - with open(dic_path, encoding="utf8") as f: - dic_lines = f.readlines() - if len(dic_lines) == 0: - return - dic_path = os.path.abspath(dic_path) - dic_name = os.path.basename(dic_path) - normalDic_count = 0 - - gpt_dict = [] - for line in dic_lines: - if line.startswith("\n"): - continue - elif line.startswith("\\\\") or line.startswith("//"): # 注释行跳过 - continue - - # 四个空格换成Tab - line = line.replace(" ", "\t") - - sp = line.rstrip("\r\n").split("\t") # 去多余换行符,Tab分割 - len_sp = len(sp) - - if len_sp < 2: # 至少是2个元素 - continue - - src = sp[0] - dst = sp[1] - info = sp[2] if len_sp > 2 else None - gpt_dict.append({"src": src, "dst": dst, "info": info}) - normalDic_count += 1 - - gpt_dict_text_list = [] - for gpt in gpt_dict: - src = gpt['src'] - dst = gpt['dst'] - 
info = gpt['info'] if "info" in gpt.keys() else None - if info: - single = f"{src}->{dst} #{info}" - else: - single = f"{src}->{dst}" - gpt_dict_text_list.append(single) - - gpt_dict_raw_text = "\n".join(gpt_dict_text_list) - self.dict = gpt_dict_raw_text - self.logger.info( - f"载入 Galtransl 字典: {dic_name} {normalDic_count}普通词条" - ) - - def load_sakura_dict(self, dic_path: str): - """ - 直接载入标准的Sakura字典。 - """ - - with open(dic_path, encoding="utf8") as f: - dic_lines = f.readlines() - if len(dic_lines) == 0: - return - dic_path = os.path.abspath(dic_path) - dic_name = os.path.basename(dic_path) - normalDic_count = 0 - - gpt_dict_text_list = [] - for line in dic_lines: - if line.startswith("\n"): - continue - elif line.startswith("\\\\") or line.startswith("//"): # 注释行跳过 - continue - - sp = line.rstrip("\r\n").split("\t") # 去多余换行符,Tab分割 - len_sp = len(sp) - - if len_sp < 2: # 至少是2个元素 - continue - - src = sp[0] - dst = sp[1] - info = sp[2] if len_sp > 2 else None - if info: - single = f"{src}->{dst} #{info}" - else: - single = f"{src}->{dst}" - gpt_dict_text_list.append(single) - normalDic_count += 1 - - gpt_dict_raw_text = "\n".join(gpt_dict_text_list) - self.dict = gpt_dict_raw_text - self.logger.info( - f"载入标准Sakura字典: {dic_name} {normalDic_count}普通词条" - ) - - def detect_type(self, dic_path: str): - """ - 检测字典类型。 - """ - with open(dic_path, encoding="utf8") as f: - dic_lines = f.readlines() - if len(dic_lines) == 0: - return "unknown" - - # 判断是否为Galtransl字典 - is_galtransl = True - for line in dic_lines: - if line.startswith("\n"): - continue - elif line.startswith("\\\\") or line.startswith("//"): - continue - - if "\t" not in line: - is_galtransl = False - break - - if is_galtransl: - return "galtransl" - - # 判断是否为Sakura字典 - is_sakura = True - for line in dic_lines: - if line.startswith("\n"): - continue - elif line.startswith("\\\\") or line.startswith("//"): - continue - - if "->" not in line: - is_sakura = False - break - - if is_sakura: - return "sakura" - - 
return "unknown" - - def get_dict(self): - """ - 获取字典内容。 - """ - if self.dict == {}: - self.logger.warning("字典为空") - return {} - return self.dict - - def get_dict_from_file(self, dic_path: str): - """ - 从文件载入字典。 - """ - dic_type = self.detect_type(dic_path) - if dic_type == "galtransl": - self.load_galtransl_dic(dic_path) - elif dic_type == "sakura": - self.load_sakura_dict(dic_path) - else: - self.logger.warning(f"未知的字典类型: {dic_path}") - return self.get_dict() - - -class SakuraTranslator(CommonTranslator): - - _TIMEOUT = 999 # 等待服务器响应的超时时间(秒) - _RETRY_ATTEMPTS = 1 # 请求出错时的重试次数 - _TIMEOUT_RETRY_ATTEMPTS = 3 # 请求超时时的重试次数 - _RATELIMIT_RETRY_ATTEMPTS = 3 # 请求被限速时的重试次数 - - _CHAT_SYSTEM_TEMPLATE_009 = ( - '你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。' - ) - _CHAT_SYSTEM_TEMPLATE_010 = ( - '你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,注意不要擅自添加原文中没有的代词,也不要擅自增加或减少换行。' - ) - - _LANGUAGE_CODE_MAP = { - 'CHS': 'Simplified Chinese', - 'JPN': 'Japanese' - } - - def __init__(self): - super().__init__() - if "/v1" not in SAKURA_API_BASE: - openai.api_base = SAKURA_API_BASE + "/v1" - else: - openai.api_base = SAKURA_API_BASE - openai.api_key = "sk-114514" - self.temperature = 0.3 - self.top_p = 0.3 - self.frequency_penalty = 0.1 - self._current_style = "normal" - self._emoji_pattern = re.compile(r'[\U00010000-\U0010ffff]') - self._heart_pattern = re.compile(r'❤') - self.sakura_dict = SakuraDict(self.get_dict_path()) - - def get_sakura_version(self): - return SAKURA_VERSION - - def get_dict_path(self): - return SAKURA_DICT_PATH - - def detect_and_remove_extra_repeats(self, s: str, threshold: int = 20, remove_all=True): - """ - 检测字符串中是否有任何模式连续重复出现超过阈值,并在去除多余重复后返回新字符串。 - 保留一个模式的重复。 - """ - repeated = False - for pattern_length in range(1, len(s) // 2 + 1): - i = 0 - while i < len(s) - pattern_length: - pattern = s[i:i + pattern_length] - count = 1 - j = i + pattern_length - while j <= len(s) - pattern_length: - if s[j:j + pattern_length] == pattern: - 
count += 1 - j += pattern_length - else: - break - if count >= threshold: - repeated = True - if remove_all: - s = s[:i + pattern_length] + s[j:] - break - i += 1 - if repeated: - break - return repeated, s - - def _format_prompt_log(self, prompt: str) -> str: - return '\n'.join([ - 'System:', - self._CHAT_SYSTEM_TEMPLATE, - 'User:', - '将下面的日文文本翻译成中文:', - prompt, - ]) - - def _split_text(self, text: str) -> List[str]: - """ - 将字符串按换行符分割为列表。 - """ - if isinstance(text, list): - return text - return text.split('\n') - - def _preprocess_queries(self, queries: List[str]) -> List[str]: - """ - 预处理查询文本,去除emoji,替换特殊字符,并添加「」标记。 - """ - queries = [self._emoji_pattern.sub('', query) for query in queries] - queries = [self._heart_pattern.sub('♥', query) for query in queries] - queries = [f'「{query}」' for query in queries] - return queries - - async def _check_translation_quality(self, queries: List[str], response: str) -> List[str]: - """ - 检查翻译结果的质量,包括重复和行数对齐问题,如果存在问题则尝试重新翻译或返回原始文本。 - """ - rep_flag = self._detect_repeats(response) - if rep_flag: - for i in range(self._RETRY_ATTEMPTS): - if self._detect_repeats(''.join(queries)): - self.logger.warning('Queries have repeats.') - break - self.logger.warning(f'Re-translating due to model degradation, attempt: {i + 1}') - self._set_gpt_style("precise") - response = await self._handle_translation_request(queries) - rep_flag = self._detect_repeats(response) - if not rep_flag: - break - if rep_flag: - self.logger.warning('Model degradation, translating single lines.') - return await self._translate_single_lines(queries) - - align_flag = self._check_align(queries, response) - if not align_flag: - for i in range(self._RETRY_ATTEMPTS): - self.logger.warning(f'Re-translating due to mismatched lines, attempt: {i + 1}') - self._set_gpt_style("precise") - response = await self._handle_translation_request(queries) - align_flag = self._check_align(queries, response) - if align_flag: - break - if not align_flag: - 
self.logger.warning('Mismatched lines, translating single lines.') - return await self._translate_single_lines(queries) - - return self._split_text(response) - - def _detect_repeats(self, text: str, threshold: int = 20) -> bool: - """ - 检测文本中是否存在重复模式。 - """ - _, text = self.detect_and_remove_extra_repeats(text, threshold) - return text != text - - def _check_align(self, queries: List[str], response: str) -> bool: - """ - 检查原始文本和翻译结果的行数是否对齐。 - """ - translations = self._split_text(response) - is_aligned = len(queries) == len(translations) - if not is_aligned: - self.logger.warning(f"Mismatched lines - Queries: {len(queries)}, Translations: {len(translations)}") - return is_aligned - - async def _translate_single_lines(self, queries: List[str]) -> List[str]: - """ - 逐行翻译查询文本。 - """ - translations = [] - for query in queries: - response = await self._handle_translation_request(query) - if self._detect_repeats(response): - self.logger.warning('Model degradation, using original text.') - translations.append(query) - else: - translations.append(response) - return translations - - def _delete_quotation_mark(self, texts: List[str]) -> List[str]: - """ - 删除文本中的「」标记。 - """ - new_texts = [] - for text in texts: - text = text.strip('「」') - new_texts.append(text) - return new_texts - - async def _translate(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]: - self.logger.debug(f'Temperature: {self.temperature}, TopP: {self.top_p}') - self.logger.debug(f'Queries: {queries}') - text_prompt = '\n'.join(queries) - self.logger.debug('-- Sakura Prompt --\n' + self._format_prompt_log(text_prompt) + '\n\n') - - # 预处理查询文本 - queries = self._preprocess_queries(queries) - - # 发送翻译请求 - response = await self._handle_translation_request(queries) - self.logger.debug('-- Sakura Response --\n' + response + '\n\n') - - # 检查翻译结果是否存在重复或行数不匹配的问题 - translations = await self._check_translation_quality(queries, response) - - return self._delete_quotation_mark(translations) - - async 
def _handle_translation_request(self, prompt: str) -> str: - """ - 处理翻译请求,包括错误处理和重试逻辑。 - """ - ratelimit_attempt = 0 - server_error_attempt = 0 - timeout_attempt = 0 - while True: - try: - request_task = asyncio.create_task(self._request_translation(prompt)) - response = await asyncio.wait_for(request_task, timeout=self._TIMEOUT) - break - except asyncio.TimeoutError: - timeout_attempt += 1 - if timeout_attempt >= self._TIMEOUT_RETRY_ATTEMPTS: - raise Exception('Sakura timeout.') - self.logger.warning(f'Restarting request due to timeout. Attempt: {timeout_attempt}') - except openai.error.RateLimitError: - ratelimit_attempt += 1 - if ratelimit_attempt >= self._RATELIMIT_RETRY_ATTEMPTS: - raise - self.logger.warning(f'Restarting request due to ratelimiting by sakura servers. Attempt: {ratelimit_attempt}') - await asyncio.sleep(2) - except (openai.error.APIError, openai.error.APIConnectionError) as e: - server_error_attempt += 1 - if server_error_attempt >= self._RETRY_ATTEMPTS: - self.logger.error(f'Sakura server error: {str(e)}. Returning original text.') - return prompt - self.logger.warning(f'Restarting request due to server error. 
Attempt: {server_error_attempt}') - - return response - - async def _request_translation(self, input_text_list) -> str: - """ - 向Sakura API发送翻译请求。 - """ - if isinstance(input_text_list, list): - raw_text = "\n".join(input_text_list) - else: - raw_text = input_text_list - extra_query = { - 'do_sample': False, - 'num_beams': 1, - 'repetition_penalty': 1.0, - } - if SAKURA_VERSION == "0.9": - messages=[ - { - "role": "system", - "content": f"{self._CHAT_SYSTEM_TEMPLATE_009}" - }, - { - "role": "user", - "content": f"将下面的日文文本翻译成中文:{raw_text}" - } - ] - else: - gpt_dict_raw_text = self.sakura_dict.get_dict() - self.logger.debug(f"Sakura Dict: {gpt_dict_raw_text}") - messages=[ - { - "role": "system", - "content": f"{self._CHAT_SYSTEM_TEMPLATE_010}" - }, - { - "role": "user", - "content": f"根据以下术语表:\n" + {gpt_dict_raw_text} + "\n" + f"将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:{raw_text}" - } - ] - response = await openai.ChatCompletion.acreate( - model="sukinishiro", - messages=messages, - temperature=self.temperature, - top_p=self.top_p, - max_tokens=1024, - frequency_penalty=self.frequency_penalty, - seed=-1, - extra_query=extra_query, - ) - # 提取并返回响应文本 - for choice in response.choices: - if 'text' in choice: - return choice.text - - return response.choices[0].message.content - - def _set_gpt_style(self, style_name: str): - """ - 设置GPT的生成风格。 - """ - if self._current_style == style_name: - return - self._current_style = style_name - if style_name == "precise": - temperature, top_p = 0.1, 0.3 - frequency_penalty = 0.1 - elif style_name == "normal": - temperature, top_p = 0.3, 0.3 - frequency_penalty = 0.15 - - self.temperature = temperature - self.top_p = top_p - self.frequency_penalty = frequency_penalty diff --git a/manga_translator/translators/sakura.py b/manga_translator/translators/sakura.py index 983694f5..0ed0a8f6 100644 --- a/manga_translator/translators/sakura.py +++ b/manga_translator/translators/sakura.py @@ -11,7 +11,6 @@ except ImportError: openai = None import asyncio 
-import time from typing import List, Dict from .common import CommonTranslator @@ -20,11 +19,11 @@ import logging class SakuraDict(): - def __init__(self, path: str): + def __init__(self, path: str, logger: logging.Logger): + self.logger = logger self.path = path - self.dict = {} - self.logger = logging.getLogger(__name__) - self.logger.setLevel(logging.INFO) + self.dict_str = "" + self.dict_str = self.get_dict_from_file(path) def load_galtransl_dic(self, dic_path: str): """ @@ -73,7 +72,7 @@ def load_galtransl_dic(self, dic_path: str): gpt_dict_text_list.append(single) gpt_dict_raw_text = "\n".join(gpt_dict_text_list) - self.dict = gpt_dict_raw_text + self.dict_str = gpt_dict_raw_text self.logger.info( f"载入 Galtransl 字典: {dic_name} {normalDic_count}普通词条" ) @@ -85,6 +84,9 @@ def load_sakura_dict(self, dic_path: str): with open(dic_path, encoding="utf8") as f: dic_lines = f.readlines() + + self.logger.debug(f"载入Sakura字典: {dic_path}") + self.logger.debug(f"载入Sakura字典: {dic_lines}") if len(dic_lines) == 0: return dic_path = os.path.abspath(dic_path) @@ -98,15 +100,16 @@ def load_sakura_dict(self, dic_path: str): elif line.startswith("\\\\") or line.startswith("//"): # 注释行跳过 continue - sp = line.rstrip("\r\n").split("\t") # 去多余换行符,Tab分割 + sp = line.rstrip("\r\n").split("->") # 去多余换行符,->分割 len_sp = len(sp) if len_sp < 2: # 至少是2个元素 continue src = sp[0] - dst = sp[1] - info = sp[2] if len_sp > 2 else None + dst_info = sp[1].split("#") # 使用#分割目标和信息 + dst = dst_info[0].strip() + info = dst_info[1].strip() if len(dst_info) > 1 else None if info: single = f"{src}->{dst} #{info}" else: @@ -115,8 +118,8 @@ def load_sakura_dict(self, dic_path: str): normalDic_count += 1 gpt_dict_raw_text = "\n".join(gpt_dict_text_list) - self.dict = gpt_dict_raw_text - self.logger.info( + self.dict_str = gpt_dict_raw_text + self.logger.debug( f"载入标准Sakura字典: {dic_name} {normalDic_count}普通词条" ) @@ -126,6 +129,7 @@ def detect_type(self, dic_path: str): """ with open(dic_path, encoding="utf8") as 
f: dic_lines = f.readlines() + self.logger.debug(f"检测字典类型: {dic_path}") if len(dic_lines) == 0: return "unknown" @@ -137,7 +141,7 @@ def detect_type(self, dic_path: str): elif line.startswith("\\\\") or line.startswith("//"): continue - if "\t" not in line: + if "\t" not in line and " " not in line: is_galtransl = False break @@ -161,14 +165,14 @@ def detect_type(self, dic_path: str): return "unknown" - def get_dict(self): + def get_dict_str(self): """ 获取字典内容。 """ - if self.dict == {}: + if self.dict_str == "": self.logger.warning("字典为空") - return {} - return self.dict + return "" + return self.dict_str def get_dict_from_file(self, dic_path: str): """ @@ -181,7 +185,7 @@ def get_dict_from_file(self, dic_path: str): self.load_sakura_dict(dic_path) else: self.logger.warning(f"未知的字典类型: {dic_path}") - return self.get_dict() + return self.get_dict_str() class SakuraTranslator(CommonTranslator): @@ -216,7 +220,7 @@ def __init__(self): self._current_style = "normal" self._emoji_pattern = re.compile(r'[\U00010000-\U0010ffff]') self._heart_pattern = re.compile(r'❤') - self.sakura_dict = SakuraDict(self.get_dict_path()) + self.sakura_dict = SakuraDict(self.get_dict_path(), self.logger) def get_sakura_version(self): return SAKURA_VERSION @@ -253,14 +257,28 @@ def detect_and_remove_extra_repeats(self, s: str, threshold: int = 20, remove_al return repeated, s def _format_prompt_log(self, prompt: str) -> str: - return '\n'.join([ + """ + 格式化日志输出的提示文本。 + """ + gpt_dict_raw_text = self.sakura_dict.get_dict_str() + prompt_009 = '\n'.join([ 'System:', - self._CHAT_SYSTEM_TEMPLATE, + self._CHAT_SYSTEM_TEMPLATE_009, 'User:', '将下面的日文文本翻译成中文:', prompt, ]) - + prompt_010 = '\n'.join([ + 'System:', + self._CHAT_SYSTEM_TEMPLATE_010, + 'User:', + "根据以下术语表:", + gpt_dict_raw_text, + "将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:", + prompt, + ]) + return prompt_009 if SAKURA_VERSION == '0.0.9' else prompt_010 + def _split_text(self, text: str) -> List[str]: """ 将字符串按换行符分割为列表。 @@ -429,7 +447,7 @@ async def 
_request_translation(self, input_text_list) -> str: } ] else: - gpt_dict_raw_text = self.sakura_dict.get_dict() + gpt_dict_raw_text = self.sakura_dict.get_dict_str() self.logger.debug(f"Sakura Dict: {gpt_dict_raw_text}") messages=[ { @@ -438,7 +456,7 @@ async def _request_translation(self, input_text_list) -> str: }, { "role": "user", - "content": f"根据以下术语表:\n" + {gpt_dict_raw_text} + "\n" + f"将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:{raw_text}" + "content": f"根据以下术语表:\n{gpt_dict_raw_text}\n将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:{raw_text}" } ] response = await openai.ChatCompletion.acreate( From 2853533c2f68dc6c9de63a9a8426c6e9877ced76 Mon Sep 17 00:00:00 2001 From: PiDanShouRouZhouXD <38401147+PiDanShouRouZhouXD@users.noreply.github.com> Date: Mon, 15 Apr 2024 02:08:48 +0800 Subject: [PATCH 3/8] Modified the translation retry logic to better handle the degradation of the Sakura model; changed all logging outputs to Chinese. --- manga_translator/translators/sakura.py | 140 +++++++++++++------------ 1 file changed, 72 insertions(+), 68 deletions(-) diff --git a/manga_translator/translators/sakura.py b/manga_translator/translators/sakura.py index 0ed0a8f6..0992be9b 100644 --- a/manga_translator/translators/sakura.py +++ b/manga_translator/translators/sakura.py @@ -1,3 +1,4 @@ +from calendar import c from distutils.cygwinccompiler import get_versions import re import os @@ -11,19 +12,23 @@ except ImportError: openai = None import asyncio -from typing import List, Dict +from typing import List, Dict, Callable, Tuple from .common import CommonTranslator from .keys import SAKURA_API_BASE, SAKURA_VERSION, SAKURA_DICT_PATH import logging + class SakuraDict(): def __init__(self, path: str, logger: logging.Logger): self.logger = logger self.path = path self.dict_str = "" - self.dict_str = self.get_dict_from_file(path) + if SAKURA_VERSION == '0.10': + self.dict_str = self.get_dict_from_file(path) + if SAKURA_VERSION == '0.9': + self.logger.info("您当前选择了Sakura 0.9版本,暂不支持术语表") def 
load_galtransl_dic(self, dic_path: str): """ @@ -85,8 +90,6 @@ def load_sakura_dict(self, dic_path: str): with open(dic_path, encoding="utf8") as f: dic_lines = f.readlines() - self.logger.debug(f"载入Sakura字典: {dic_path}") - self.logger.debug(f"载入Sakura字典: {dic_lines}") if len(dic_lines) == 0: return dic_path = os.path.abspath(dic_path) @@ -119,7 +122,7 @@ def load_sakura_dict(self, dic_path: str): gpt_dict_raw_text = "\n".join(gpt_dict_text_list) self.dict_str = gpt_dict_raw_text - self.logger.debug( + self.logger.info( f"载入标准Sakura字典: {dic_name} {normalDic_count}普通词条" ) @@ -191,9 +194,10 @@ def get_dict_from_file(self, dic_path: str): class SakuraTranslator(CommonTranslator): _TIMEOUT = 999 # 等待服务器响应的超时时间(秒) - _RETRY_ATTEMPTS = 1 # 请求出错时的重试次数 + _RETRY_ATTEMPTS = 3 # 请求出错时的重试次数 _TIMEOUT_RETRY_ATTEMPTS = 3 # 请求超时时的重试次数 _RATELIMIT_RETRY_ATTEMPTS = 3 # 请求被限速时的重试次数 + _REPEAT_DETECT_THRESHOLD = 11 # 重复检测的阈值 _CHAT_SYSTEM_TEMPLATE_009 = ( '你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。' @@ -228,10 +232,10 @@ def get_sakura_version(self): def get_dict_path(self): return SAKURA_DICT_PATH - def detect_and_remove_extra_repeats(self, s: str, threshold: int = 20, remove_all=True): + def detect_and_caculate_repeats(self, s: str, threshold: int = _REPEAT_DETECT_THRESHOLD, remove_all=True) -> Tuple[bool, str, int, str]: """ - 检测字符串中是否有任何模式连续重复出现超过阈值,并在去除多余重复后返回新字符串。 - 保留一个模式的重复。 + 检测文本中是否存在重复模式,并计算重复次数。 + 返回值: (是否重复, 去除重复后的文本, 重复次数, 重复模式) """ repeated = False for pattern_length in range(1, len(s) // 2 + 1): @@ -247,6 +251,7 @@ def detect_and_remove_extra_repeats(self, s: str, threshold: int = 20, remove_al else: break if count >= threshold: + self.logger.warning(f"检测到重复模式: {pattern},重复次数: {count}") repeated = True if remove_all: s = s[:i + pattern_length] + s[j:] @@ -254,7 +259,7 @@ def detect_and_remove_extra_repeats(self, s: str, threshold: int = 20, remove_al i += 1 if repeated: break - return repeated, s + return repeated, s, count, pattern def 
_format_prompt_log(self, prompt: str) -> str: """ @@ -277,8 +282,8 @@ def _format_prompt_log(self, prompt: str) -> str: "将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:", prompt, ]) - return prompt_009 if SAKURA_VERSION == '0.0.9' else prompt_010 - + return prompt_009 if SAKURA_VERSION == '0.9' else prompt_010 + def _split_text(self, text: str) -> List[str]: """ 将字符串按换行符分割为列表。 @@ -300,43 +305,39 @@ async def _check_translation_quality(self, queries: List[str], response: str) -> """ 检查翻译结果的质量,包括重复和行数对齐问题,如果存在问题则尝试重新翻译或返回原始文本。 """ - rep_flag = self._detect_repeats(response) - if rep_flag: + async def _retry_translation(queries: List[str], check_func: Callable[[str], bool], error_message: str) -> str: + styles = ["precise", "normal", "aggressive", ] for i in range(self._RETRY_ATTEMPTS): - if self._detect_repeats(''.join(queries)): - self.logger.warning('Queries have repeats.') - break - self.logger.warning(f'Re-translating due to model degradation, attempt: {i + 1}') - self._set_gpt_style("precise") + self._set_gpt_style(styles[i]) + self.logger.warning(f'{error_message} 尝试次数: {i + 1}。当前参数风格:{self._current_style}。') response = await self._handle_translation_request(queries) - rep_flag = self._detect_repeats(response) - if not rep_flag: - break - if rep_flag: - self.logger.warning('Model degradation, translating single lines.') - return await self._translate_single_lines(queries) + if not check_func(response): + return response + return None - align_flag = self._check_align(queries, response) - if not align_flag: - for i in range(self._RETRY_ATTEMPTS): - self.logger.warning(f'Re-translating due to mismatched lines, attempt: {i + 1}') - self._set_gpt_style("precise") - response = await self._handle_translation_request(queries) - align_flag = self._check_align(queries, response) - if align_flag: - break - if not align_flag: - self.logger.warning('Mismatched lines, translating single lines.') + if self._detect_repeats(response): + if self._detect_repeats(''.join(queries)): + 
self.logger.warning('请求内容本身含有超过阈值的重复内容。') + else: + response = await _retry_translation(queries, self._detect_repeats, f'因为检测到大量重复内容(当前阈值:{self._REPEAT_DETECT_THRESHOLD}),疑似模型退化,重新翻译。') + if response is None: + self.logger.warning(f'疑似模型退化,尝试{self._RETRY_ATTEMPTS}次仍未解决,进行单行翻译。') + return await self._translate_single_lines(queries) + + if not self._check_align(queries, response): + response = await _retry_translation(queries, lambda r: not self._check_align(queries, r), '因为检测到原文与译文行数不匹配,重新翻译。') + if response is None: + self.logger.warning(f'原文与译文行数不匹配,尝试{self._RETRY_ATTEMPTS}次仍未解决,进行单行翻译。') return await self._translate_single_lines(queries) return self._split_text(response) - def _detect_repeats(self, text: str, threshold: int = 20) -> bool: + def _detect_repeats(self, text: str, threshold: int = _REPEAT_DETECT_THRESHOLD) -> bool: """ 检测文本中是否存在重复模式。 """ - _, text = self.detect_and_remove_extra_repeats(text, threshold) - return text != text + is_repeated, text, count, pattern = self.detect_and_caculate_repeats(text, threshold, remove_all=False) + return is_repeated def _check_align(self, queries: List[str], response: str) -> bool: """ @@ -345,7 +346,7 @@ def _check_align(self, queries: List[str], response: str) -> bool: translations = self._split_text(response) is_aligned = len(queries) == len(translations) if not is_aligned: - self.logger.warning(f"Mismatched lines - Queries: {len(queries)}, Translations: {len(translations)}") + self.logger.warning(f"行数不匹配 - 原文行数: {len(queries)},译文行数: {len(translations)}") return is_aligned async def _translate_single_lines(self, queries: List[str]) -> List[str]: @@ -356,7 +357,7 @@ async def _translate_single_lines(self, queries: List[str]) -> List[str]: for query in queries: response = await self._handle_translation_request(query) if self._detect_repeats(response): - self.logger.warning('Model degradation, using original text.') + self.logger.warning(f"单行翻译结果存在重复内容: {response},返回原文。") translations.append(query) else: 
translations.append(response) @@ -374,7 +375,7 @@ def _delete_quotation_mark(self, texts: List[str]) -> List[str]: async def _translate(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]: self.logger.debug(f'Temperature: {self.temperature}, TopP: {self.top_p}') - self.logger.debug(f'Queries: {queries}') + self.logger.debug(f'原文: {queries}') text_prompt = '\n'.join(queries) self.logger.debug('-- Sakura Prompt --\n' + self._format_prompt_log(text_prompt) + '\n\n') @@ -405,20 +406,20 @@ async def _handle_translation_request(self, prompt: str) -> str: except asyncio.TimeoutError: timeout_attempt += 1 if timeout_attempt >= self._TIMEOUT_RETRY_ATTEMPTS: - raise Exception('Sakura timeout.') - self.logger.warning(f'Restarting request due to timeout. Attempt: {timeout_attempt}') + raise Exception('Sakura超时。') + self.logger.warning(f'Sakura因超时而进行重试。尝试次数: {timeout_attempt}') except openai.error.RateLimitError: ratelimit_attempt += 1 if ratelimit_attempt >= self._RATELIMIT_RETRY_ATTEMPTS: raise - self.logger.warning(f'Restarting request due to ratelimiting by sakura servers. Attempt: {ratelimit_attempt}') + self.logger.warning(f'Sakura因被限速而进行重试。尝试次数: {ratelimit_attempt}') await asyncio.sleep(2) except (openai.error.APIError, openai.error.APIConnectionError) as e: server_error_attempt += 1 if server_error_attempt >= self._RETRY_ATTEMPTS: - self.logger.error(f'Sakura server error: {str(e)}. Returning original text.') + self.logger.error(f'Sakura API请求失败。错误信息: {e}') return prompt - self.logger.warning(f'Restarting request due to server error. 
Attempt: {server_error_attempt}') + self.logger.warning(f'Sakura因服务器错误而进行重试。尝试次数: {server_error_attempt},错误信息: {e}') return response @@ -436,29 +437,29 @@ async def _request_translation(self, input_text_list) -> str: 'repetition_penalty': 1.0, } if SAKURA_VERSION == "0.9": - messages=[ - { - "role": "system", - "content": f"{self._CHAT_SYSTEM_TEMPLATE_009}" - }, - { - "role": "user", - "content": f"将下面的日文文本翻译成中文:{raw_text}" - } - ] + messages = [ + { + "role": "system", + "content": f"{self._CHAT_SYSTEM_TEMPLATE_009}" + }, + { + "role": "user", + "content": f"将下面的日文文本翻译成中文:{raw_text}" + } + ] else: gpt_dict_raw_text = self.sakura_dict.get_dict_str() self.logger.debug(f"Sakura Dict: {gpt_dict_raw_text}") - messages=[ - { - "role": "system", - "content": f"{self._CHAT_SYSTEM_TEMPLATE_010}" - }, - { - "role": "user", - "content": f"根据以下术语表:\n{gpt_dict_raw_text}\n将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:{raw_text}" - } - ] + messages = [ + { + "role": "system", + "content": f"{self._CHAT_SYSTEM_TEMPLATE_010}" + }, + { + "role": "user", + "content": f"根据以下术语表:\n{gpt_dict_raw_text}\n将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:{raw_text}" + } + ] response = await openai.ChatCompletion.acreate( model="sukinishiro", messages=messages, @@ -488,7 +489,10 @@ def _set_gpt_style(self, style_name: str): frequency_penalty = 0.1 elif style_name == "normal": temperature, top_p = 0.3, 0.3 - frequency_penalty = 0.15 + frequency_penalty = 0.2 + elif style_name == "aggressive": + temperature, top_p = 0.1, 0.3 + frequency_penalty = 0.3 self.temperature = temperature self.top_p = top_p From 488983344d9c0e5ac0ad6a18665999af4b62e783 Mon Sep 17 00:00:00 2001 From: PiDanShouRouZhouXD <38401147+PiDanShouRouZhouXD@users.noreply.github.com> Date: Mon, 15 Apr 2024 02:32:58 +0800 Subject: [PATCH 4/8] Added a degradation threshold determination method that takes the maximum value between the mode of the original text repetition and a predefined value (20). 
--- manga_translator/translators/sakura.py | 41 +++++++++++++++++++------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/manga_translator/translators/sakura.py b/manga_translator/translators/sakura.py index 0992be9b..b172f14c 100644 --- a/manga_translator/translators/sakura.py +++ b/manga_translator/translators/sakura.py @@ -172,6 +172,9 @@ def get_dict_str(self): """ 获取字典内容。 """ + if SAKURA_VERSION == '0.9': + self.logger.info("您当前选择了Sakura 0.9版本,暂不支持术语表") + return "" if self.dict_str == "": self.logger.warning("字典为空") return "" @@ -197,7 +200,7 @@ class SakuraTranslator(CommonTranslator): _RETRY_ATTEMPTS = 3 # 请求出错时的重试次数 _TIMEOUT_RETRY_ATTEMPTS = 3 # 请求超时时的重试次数 _RATELIMIT_RETRY_ATTEMPTS = 3 # 请求被限速时的重试次数 - _REPEAT_DETECT_THRESHOLD = 11 # 重复检测的阈值 + _REPEAT_DETECT_THRESHOLD = 20 # 重复检测的阈值 _CHAT_SYSTEM_TEMPLATE_009 = ( '你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。' @@ -238,6 +241,7 @@ def detect_and_caculate_repeats(self, s: str, threshold: int = _REPEAT_DETECT_TH 返回值: (是否重复, 去除重复后的文本, 重复次数, 重复模式) """ repeated = False + counts = [] for pattern_length in range(1, len(s) // 2 + 1): i = 0 while i < len(s) - pattern_length: @@ -250,6 +254,7 @@ def detect_and_caculate_repeats(self, s: str, threshold: int = _REPEAT_DETECT_TH j += pattern_length else: break + counts.append(count) if count >= threshold: self.logger.warning(f"检测到重复模式: {pattern},重复次数: {count}") repeated = True @@ -259,7 +264,17 @@ def detect_and_caculate_repeats(self, s: str, threshold: int = _REPEAT_DETECT_TH i += 1 if repeated: break - return repeated, s, count, pattern + + # 计算重复次数的众数 + if counts: + mode_count = max(set(counts), key=counts.count) + else: + mode_count = 0 + + # 根据默认阈值和众数计算实际阈值 + actual_threshold = max(threshold, mode_count) + + return repeated, s, count, pattern, actual_threshold def _format_prompt_log(self, prompt: str) -> str: """ @@ -315,14 +330,18 @@ async def _retry_translation(queries: List[str], check_func: Callable[[str], boo return response 
return None - if self._detect_repeats(response): - if self._detect_repeats(''.join(queries)): - self.logger.warning('请求内容本身含有超过阈值的重复内容。') - else: - response = await _retry_translation(queries, self._detect_repeats, f'因为检测到大量重复内容(当前阈值:{self._REPEAT_DETECT_THRESHOLD}),疑似模型退化,重新翻译。') - if response is None: - self.logger.warning(f'疑似模型退化,尝试{self._RETRY_ATTEMPTS}次仍未解决,进行单行翻译。') - return await self._translate_single_lines(queries) + # 检查请求内容是否含有超过默认阈值的重复内容 + if self._detect_repeats(''.join(queries), self._REPEAT_DETECT_THRESHOLD): + self.logger.warning(f'请求内容本身含有超过默认阈值{self._REPEAT_DETECT_THRESHOLD}的重复内容。') + + # 根据译文众数和默认阈值计算实际阈值 + _, _, _, _, actual_threshold = self.detect_and_caculate_repeats(response) + + if self._detect_repeats(response, actual_threshold): + response = await _retry_translation(queries, lambda r: self._detect_repeats(r, actual_threshold), f'检测到大量重复内容(当前阈值:{actual_threshold}),疑似模型退化,重新翻译。') + if response is None: + self.logger.warning(f'疑似模型退化,尝试{self._RETRY_ATTEMPTS}次仍未解决,进行单行翻译。') + return await self._translate_single_lines(queries) if not self._check_align(queries, response): response = await _retry_translation(queries, lambda r: not self._check_align(queries, r), '因为检测到原文与译文行数不匹配,重新翻译。') @@ -336,7 +355,7 @@ def _detect_repeats(self, text: str, threshold: int = _REPEAT_DETECT_THRESHOLD) """ 检测文本中是否存在重复模式。 """ - is_repeated, text, count, pattern = self.detect_and_caculate_repeats(text, threshold, remove_all=False) + is_repeated, text, count, pattern, actual_threshold = self.detect_and_caculate_repeats(text, threshold, remove_all=False) return is_repeated def _check_align(self, queries: List[str], response: str) -> bool: From 24a47789d652c247da8c895f08832e56fb3de660 Mon Sep 17 00:00:00 2001 From: PiDanShouRouZhouXD <38401147+PiDanShouRouZhouXD@users.noreply.github.com> Date: Mon, 15 Apr 2024 15:12:01 +0800 Subject: [PATCH 5/8] fix repeat caculate --- manga_translator/translators/sakura.py | 60 ++++++++++++++++++++++++-- 1 file changed, 56 
insertions(+), 4 deletions(-) diff --git a/manga_translator/translators/sakura.py b/manga_translator/translators/sakura.py index b172f14c..ab77f9f1 100644 --- a/manga_translator/translators/sakura.py +++ b/manga_translator/translators/sakura.py @@ -224,7 +224,7 @@ def __init__(self): self.temperature = 0.3 self.top_p = 0.3 self.frequency_penalty = 0.1 - self._current_style = "normal" + self._current_style = "precise" self._emoji_pattern = re.compile(r'[\U00010000-\U0010ffff]') self._heart_pattern = re.compile(r'❤') self.sakura_dict = SakuraDict(self.get_dict_path(), self.logger) @@ -276,6 +276,49 @@ def detect_and_caculate_repeats(self, s: str, threshold: int = _REPEAT_DETECT_TH return repeated, s, count, pattern, actual_threshold + @staticmethod + def enlarge_small_kana(text, ignore=''): + """将小写平假名或片假名转换为普通大小 + + 参数 + ---------- + text : str + 全角平假名或片假名字符串。 + ignore : str, 可选 + 转换时要忽略的字符。 + + 返回 + ------ + str + 平假名或片假名字符串,小写假名已转换为大写 + + 示例 + -------- + >>> print(enlarge_small_kana('さくらきょうこ')) + さくらきようこ + >>> print(enlarge_small_kana('キュゥべえ')) + キユウべえ + """ + SMALL_KANA = list('ぁぃぅぇぉゃゅょっァィゥェォヵヶャュョッ') + SMALL_KANA_NORMALIZED = list('あいうえおやゆよつアイウエオカケヤユヨツ') + SMALL_KANA2BIG_KANA = dict(zip(map(ord, SMALL_KANA), SMALL_KANA_NORMALIZED)) + + def _exclude_ignorechar(ignore, conv_map): + for character in map(ord, ignore): + del conv_map[character] + return conv_map + + def _convert(text, conv_map): + return text.translate(conv_map) + + def _translate(text, ignore, conv_map): + if ignore: + _conv_map = _exclude_ignorechar(ignore, conv_map.copy()) + return _convert(text, _conv_map) + return _convert(text, conv_map) + + return _translate(text, ignore, SMALL_KANA2BIG_KANA) + def _format_prompt_log(self, prompt: str) -> str: """ 格式化日志输出的提示文本。 @@ -311,9 +354,11 @@ def _preprocess_queries(self, queries: List[str]) -> List[str]: """ 预处理查询文本,去除emoji,替换特殊字符,并添加「」标记。 """ + queries = [self.enlarge_small_kana(query) for query in queries] queries = [self._emoji_pattern.sub('', query) 
for query in queries] queries = [self._heart_pattern.sub('♥', query) for query in queries] queries = [f'「{query}」' for query in queries] + self.logger.debug(f'预处理后的查询文本:{queries}') return queries async def _check_translation_quality(self, queries: List[str], response: str) -> List[str]: @@ -335,7 +380,7 @@ async def _retry_translation(queries: List[str], check_func: Callable[[str], boo self.logger.warning(f'请求内容本身含有超过默认阈值{self._REPEAT_DETECT_THRESHOLD}的重复内容。') # 根据译文众数和默认阈值计算实际阈值 - _, _, _, _, actual_threshold = self.detect_and_caculate_repeats(response) + actual_threshold = max(max(self._get_repeat_count(query) for query in queries), self._REPEAT_DETECT_THRESHOLD) if self._detect_repeats(response, actual_threshold): response = await _retry_translation(queries, lambda r: self._detect_repeats(r, actual_threshold), f'检测到大量重复内容(当前阈值:{actual_threshold}),疑似模型退化,重新翻译。') @@ -358,6 +403,13 @@ def _detect_repeats(self, text: str, threshold: int = _REPEAT_DETECT_THRESHOLD) is_repeated, text, count, pattern, actual_threshold = self.detect_and_caculate_repeats(text, threshold, remove_all=False) return is_repeated + def _get_repeat_count(self, text: str, threshold: int = _REPEAT_DETECT_THRESHOLD) -> bool: + """ + 计算文本中重复模式的次数。 + """ + is_repeated, text, count, pattern, actual_threshold = self.detect_and_caculate_repeats(text, threshold, remove_all=False) + return count + def _check_align(self, queries: List[str], response: str) -> bool: """ 检查原始文本和翻译结果的行数是否对齐。 @@ -505,12 +557,12 @@ def _set_gpt_style(self, style_name: str): self._current_style = style_name if style_name == "precise": temperature, top_p = 0.1, 0.3 - frequency_penalty = 0.1 + frequency_penalty = 0.05 elif style_name == "normal": temperature, top_p = 0.3, 0.3 frequency_penalty = 0.2 elif style_name == "aggressive": - temperature, top_p = 0.1, 0.3 + temperature, top_p = 0.3, 0.3 frequency_penalty = 0.3 self.temperature = temperature From 6d6b8d47180cf15dce7f35314d9d262490dcae2a Mon Sep 17 00:00:00 2001 From: 
PiDanShouRouZhouXD <38401147+PiDanShouRouZhouXD@users.noreply.github.com> Date: Tue, 16 Apr 2024 00:04:38 +0800 Subject: [PATCH 6/8] fix --- manga_translator/translators/sakura.py | 27 ++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/manga_translator/translators/sakura.py b/manga_translator/translators/sakura.py index ab77f9f1..d4bae717 100644 --- a/manga_translator/translators/sakura.py +++ b/manga_translator/translators/sakura.py @@ -21,13 +21,19 @@ class SakuraDict(): - def __init__(self, path: str, logger: logging.Logger): + def __init__(self, path: str, logger: logging.Logger, version: str = "0.9") -> None: self.logger = logger - self.path = path self.dict_str = "" - if SAKURA_VERSION == '0.10': + self.version = version + if not os.path.exists(path): + if self.version == '0.10': + self.logger.warning(f"字典文件不存在: {path}") + return + else: + self.path = path + if self.version == '0.10': self.dict_str = self.get_dict_from_file(path) - if SAKURA_VERSION == '0.9': + if self.version == '0.9': self.logger.info("您当前选择了Sakura 0.9版本,暂不支持术语表") def load_galtransl_dic(self, dic_path: str): @@ -172,12 +178,17 @@ def get_dict_str(self): """ 获取字典内容。 """ - if SAKURA_VERSION == '0.9': + if self.version == '0.9': self.logger.info("您当前选择了Sakura 0.9版本,暂不支持术语表") return "" if self.dict_str == "": - self.logger.warning("字典为空") - return "" + try: + self.dict_str = self.get_dict_from_file(self.path) + return self.dict_str + except Exception as e: + if self.version == '0.10': + self.logger.warning(f"载入字典失败: {e}") + return "" return self.dict_str def get_dict_from_file(self, dic_path: str): @@ -227,7 +238,7 @@ def __init__(self): self._current_style = "precise" self._emoji_pattern = re.compile(r'[\U00010000-\U0010ffff]') self._heart_pattern = re.compile(r'❤') - self.sakura_dict = SakuraDict(self.get_dict_path(), self.logger) + self.sakura_dict = SakuraDict(self.get_dict_path(), self.logger, SAKURA_VERSION) def get_sakura_version(self): return 
SAKURA_VERSION From 3a95a707c3282a21f1ca35b5f926e7407f47039a Mon Sep 17 00:00:00 2001 From: PiDanShouRouZhouXD <38401147+PiDanShouRouZhouXD@users.noreply.github.com> Date: Tue, 16 Apr 2024 00:26:19 +0800 Subject: [PATCH 7/8] minor fix --- manga_translator/translators/sakura.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/manga_translator/translators/sakura.py b/manga_translator/translators/sakura.py index d4bae717..0dfeeacd 100644 --- a/manga_translator/translators/sakura.py +++ b/manga_translator/translators/sakura.py @@ -1,11 +1,7 @@ -from calendar import c -from distutils.cygwinccompiler import get_versions import re import os from venv import logger -from httpx import get - try: import openai import openai.error @@ -275,16 +271,16 @@ def detect_and_caculate_repeats(self, s: str, threshold: int = _REPEAT_DETECT_TH i += 1 if repeated: break - + # 计算重复次数的众数 if counts: mode_count = max(set(counts), key=counts.count) else: mode_count = 0 - + # 根据默认阈值和众数计算实际阈值 actual_threshold = max(threshold, mode_count) - + return repeated, s, count, pattern, actual_threshold @staticmethod @@ -327,7 +323,7 @@ def _translate(text, ignore, conv_map): _conv_map = _exclude_ignorechar(ignore, conv_map.copy()) return _convert(text, _conv_map) return _convert(text, conv_map) - + return _translate(text, ignore, SMALL_KANA2BIG_KANA) def _format_prompt_log(self, prompt: str) -> str: @@ -389,10 +385,10 @@ async def _retry_translation(queries: List[str], check_func: Callable[[str], boo # 检查请求内容是否含有超过默认阈值的重复内容 if self._detect_repeats(''.join(queries), self._REPEAT_DETECT_THRESHOLD): self.logger.warning(f'请求内容本身含有超过默认阈值{self._REPEAT_DETECT_THRESHOLD}的重复内容。') - - # 根据译文众数和默认阈值计算实际阈值 + + # 根据译文众数和默认阈值计算实际阈值 actual_threshold = max(max(self._get_repeat_count(query) for query in queries), self._REPEAT_DETECT_THRESHOLD) - + if self._detect_repeats(response, actual_threshold): response = await _retry_translation(queries, lambda r: self._detect_repeats(r, 
actual_threshold), f'检测到大量重复内容(当前阈值:{actual_threshold}),疑似模型退化,重新翻译。') if response is None: From 6b3df6bbcff2b2d4b135f4ada5509637863630e0 Mon Sep 17 00:00:00 2001 From: PiDanShouRouZhouXD <38401147+PiDanShouRouZhouXD@users.noreply.github.com> Date: Wed, 17 Apr 2024 22:24:20 +0800 Subject: [PATCH 8/8] Add dynamic max_token --- manga_translator/translators/sakura.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/manga_translator/translators/sakura.py b/manga_translator/translators/sakura.py index 0dfeeacd..1a9acc2a 100644 --- a/manga_translator/translators/sakura.py +++ b/manga_translator/translators/sakura.py @@ -509,6 +509,9 @@ async def _request_translation(self, input_text_list) -> str: raw_text = "\n".join(input_text_list) else: raw_text = input_text_list + raw_lenth = len(raw_text) + max_lenth = 512 + max_token_num = max(raw_lenth*2, max_lenth) extra_query = { 'do_sample': False, 'num_beams': 1, @@ -543,7 +546,7 @@ async def _request_translation(self, input_text_list) -> str: messages=messages, temperature=self.temperature, top_p=self.top_p, - max_tokens=1024, + max_tokens=max_token_num, frequency_penalty=self.frequency_penalty, seed=-1, extra_query=extra_query,