diff --git a/manga_translator/translators/__init__.py b/manga_translator/translators/__init__.py
index 2c9a4528..768b2c3b 100644
--- a/manga_translator/translators/__init__.py
+++ b/manga_translator/translators/__init__.py
@@ -15,7 +15,7 @@
 from .selective import SelectiveOfflineTranslator, prepare as prepare_selective_translator
 from .none import NoneTranslator
 from .original import OriginalTranslator
-from .sakura import Sakura13BTranslator
+from .sakura import SakuraTranslator
 
 OFFLINE_TRANSLATORS = {
     'offline': SelectiveOfflineTranslator,
@@ -41,7 +41,7 @@
     'gpt4': GPT4Translator,
     'none': NoneTranslator,
     'original': OriginalTranslator,
-    'sakura': Sakura13BTranslator,
+    'sakura': SakuraTranslator,
     **OFFLINE_TRANSLATORS,
 }
 translator_cache = {}
diff --git a/manga_translator/translators/keys.py b/manga_translator/translators/keys.py
index 0cfa9dfb..ef0678fc 100644
--- a/manga_translator/translators/keys.py
+++ b/manga_translator/translators/keys.py
@@ -15,10 +15,11 @@
 OPENAI_HTTP_PROXY = os.getenv('OPENAI_HTTP_PROXY') # TODO: Replace with --proxy
 OPENAI_API_BASE = os.getenv('OPENAI_API_BASE', 'https://api.openai.com/v1') #使用api-for-open-llm例子 http://127.0.0.1:8000/v1
 
-SAKURA_API_BASE = os.getenv('SAKURA_API_BASE', 'http://127.0.0.1:8080/v1') #SAKURA API地址
+# sakura
+SAKURA_API_BASE = os.getenv('SAKURA_API_BASE', 'http://127.0.0.1:8080/v1') # Sakura API base URL
+SAKURA_VERSION = os.getenv('SAKURA_VERSION', '0.9') # Sakura API version, '0.9' or '0.10'; '0.10' also loads the glossary
+SAKURA_DICT_PATH = os.getenv('SAKURA_DICT_PATH', './sakura_dict.txt') # path to the Sakura glossary file
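+#
+# Example (hypothetical values): to run against a local Sakura server with
+# the 0.10 glossary enabled, export before launching:
+#   SAKURA_API_BASE=http://127.0.0.1:8080/v1
+#   SAKURA_VERSION=0.10
+#   SAKURA_DICT_PATH=./sakura_dict.txt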
 
-CAIYUN_TOKEN = os.getenv('CAIYUN_TOKEN', '') # 彩云小译API访问令牌
-SAKURA_API_KEY = os.getenv('SAKURA_API_KEY', '')
-SAKURA_API_BASE = os.getenv('SAKURA_API_BASE', 'http://127.0.0.1:5000/v1')
\ No newline at end of file
+CAIYUN_TOKEN = os.getenv('CAIYUN_TOKEN', '') # Caiyun (彩云小译) API token
\ No newline at end of file
diff --git a/manga_translator/translators/sakura.py b/manga_translator/translators/sakura.py
index 0fee7989..1a9acc2a 100644
--- a/manga_translator/translators/sakura.py
+++ b/manga_translator/translators/sakura.py
@@ -1,4 +1,5 @@
 import re
+import os
 from venv import logger
 
 try:
@@ -7,23 +8,213 @@
 except ImportError:
     openai = None
 import asyncio
-import time
-from typing import List, Dict
+from typing import List, Dict, Callable, Optional, Tuple
 
 from .common import CommonTranslator
-from .keys import SAKURA_API_BASE, SAKURA_API_KEY
+from .keys import SAKURA_API_BASE, SAKURA_VERSION, SAKURA_DICT_PATH
+import logging
 
 
-class Sakura13BTranslator(CommonTranslator):
-    _TIMEOUT = 999 # Seconds to wait for a response from the server before retrying
-    _RETRY_ATTEMPTS = 1 # Number of times to retry an errored request before giving up
-    _TIMEOUT_RETRY_ATTEMPTS = 3 # Number of times to retry a timed out request before giving up
-    _RATELIMIT_RETRY_ATTEMPTS = 3 # Number of times to retry a ratelimited request before giving up
+class SakuraDict:
+    def __init__(self, path: str, logger: logging.Logger, version: str = "0.9") -> None:
+        self.logger = logger
+        self.dict_str = ""
+        self.version = version
+        self.path = path
+        if not os.path.exists(path):
+            if self.version == '0.10':
+                self.logger.warning(f"字典文件不存在: {path}")
+            return
+        if self.version == '0.10':
+            self.dict_str = self.get_dict_from_file(path)
+        if self.version == '0.9':
+            self.logger.info("您当前选择了Sakura 0.9版本,暂不支持术语表")
+
+    def load_galtransl_dic(self, dic_path: str):
+        """
+        Load a Galtransl-format dictionary.
+        """
+        with open(dic_path, encoding="utf8") as f:
+            dic_lines = f.readlines()
+        if len(dic_lines) == 0:
+            return
+        dic_path = os.path.abspath(dic_path)
+        dic_name = os.path.basename(dic_path)
+        normalDic_count = 0
+
+        gpt_dict = []
+        for line in dic_lines:
+            if line.startswith("\n"):
+                continue
+            elif line.startswith("\\\\") or line.startswith("//"):  # skip comment lines
+                continue
+
+            # normalize four spaces to a tab
+            line = line.replace("    ", "\t")
+
+            sp = line.rstrip("\r\n").split("\t")  # strip trailing newlines, split on tab
+            len_sp = len(sp)
+
+            if len_sp < 2:  # expect at least two fields
+                continue
+
+            src = sp[0]
+            dst = sp[1]
+            info = sp[2] if len_sp > 2 else None
+            gpt_dict.append({"src": src, "dst": dst, "info": info})
+            normalDic_count += 1
+
+        gpt_dict_text_list = []
+        for gpt in gpt_dict:
+            src = gpt['src']
+            dst = gpt['dst']
+            info = gpt['info']
+            if info:
+                single = f"{src}->{dst} #{info}"
+            else:
+                single = f"{src}->{dst}"
+            gpt_dict_text_list.append(single)
+
+        gpt_dict_raw_text = "\n".join(gpt_dict_text_list)
+        self.dict_str = gpt_dict_raw_text
+        self.logger.info(
+            f"载入 Galtransl 字典: {dic_name} {normalDic_count}普通词条"
+        )
+
+    def load_sakura_dict(self, dic_path: str):
+        """
+        Load a dictionary that is already in the standard Sakura format.
+        """
+        with open(dic_path, encoding="utf8") as f:
+            dic_lines = f.readlines()
+        if len(dic_lines) == 0:
+            return
+        dic_path = os.path.abspath(dic_path)
+        dic_name = os.path.basename(dic_path)
+        normalDic_count = 0
+
+        gpt_dict_text_list = []
+        for line in dic_lines:
+            if line.startswith("\n"):
+                continue
+            elif line.startswith("\\\\") or line.startswith("//"):  # skip comment lines
+                continue
+
+            sp = line.rstrip("\r\n").split("->")  # strip trailing newlines, split on '->'
+            len_sp = len(sp)
+
+            if len_sp < 2:  # expect at least two fields
+                continue
+
+            src = sp[0]
+            dst_info = sp[1].split("#")  # split translation and note on '#'
+            dst = dst_info[0].strip()
+            info = dst_info[1].strip() if len(dst_info) > 1 else None
+            if info:
+                single = f"{src}->{dst} #{info}"
+            else:
+                single = f"{src}->{dst}"
+            gpt_dict_text_list.append(single)
+            normalDic_count += 1
+
+        gpt_dict_raw_text = "\n".join(gpt_dict_text_list)
+        self.dict_str = gpt_dict_raw_text
+        self.logger.info(
+            f"载入标准Sakura字典: {dic_name} {normalDic_count}普通词条"
+        )
+
+    def detect_type(self, dic_path: str):
+        """
+        Detect the dictionary format.
+        """
+        with open(dic_path, encoding="utf8") as f:
+            dic_lines = f.readlines()
+        self.logger.debug(f"检测字典类型: {dic_path}")
+        if len(dic_lines) == 0:
+            return "unknown"
+
+        # check for the Galtransl format (tab- or four-space-separated)
+        is_galtransl = True
+        for line in dic_lines:
+            if line.startswith("\n"):
+                continue
+            elif line.startswith("\\\\") or line.startswith("//"):
+                continue
+
+            if "\t" not in line and "    " not in line:
+                is_galtransl = False
+                break
+
+        if is_galtransl:
+            return "galtransl"
+
+        # check for the Sakura format ('src->dst #info')
+        is_sakura = True
+        for line in dic_lines:
+            if line.startswith("\n"):
+                continue
+            elif line.startswith("\\\\") or line.startswith("//"):
+                continue
+
+            if "->" not in line:
+                is_sakura = False
+                break
+
+        if is_sakura:
+            return "sakura"
+
+        return "unknown"
+
+    def get_dict_str(self):
+        """
+        Return the glossary as a single string.
+        """
+        if self.version == '0.9':
+            self.logger.info("您当前选择了Sakura 0.9版本,暂不支持术语表")
+            return ""
+        if self.dict_str == "":
+            try:
+                self.dict_str = self.get_dict_from_file(self.path)
+                return self.dict_str
+            except Exception as e:
+                if self.version == '0.10':
+                    self.logger.warning(f"载入字典失败: {e}")
+                return ""
+        return self.dict_str
+
+    def get_dict_from_file(self, dic_path: str):
+        """
+        Load the glossary from a file, auto-detecting its format.
+        """
+        dic_type = self.detect_type(dic_path)
+        if dic_type == "galtransl":
+            self.load_galtransl_dic(dic_path)
+        elif dic_type == "sakura":
+            self.load_sakura_dict(dic_path)
+        else:
+            self.logger.warning(f"未知的字典类型: {dic_path}")
+        return self.dict_str
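+
+# Illustrative glossary inputs (hypothetical entries): SakuraDict accepts a
+# Galtransl file (tab- or four-space-separated, e.g. '恵\t惠\t名字') or a
+# Sakura file (e.g. '恵->惠 #名字'); both are normalized to the
+# 'src->dst #info' lines returned by get_dict_str().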
+
+
+class SakuraTranslator(CommonTranslator):
+
+    _TIMEOUT = 999  # seconds to wait for a server response before timing out
+    _RETRY_ATTEMPTS = 3  # retries for errored requests
+    _TIMEOUT_RETRY_ATTEMPTS = 3  # retries for timed-out requests
+    _RATELIMIT_RETRY_ATTEMPTS = 3  # retries for rate-limited requests
+    _REPEAT_DETECT_THRESHOLD = 20  # threshold for repetition detection
 
-    _CHAT_SYSTEM_TEMPLATE = (
+    _CHAT_SYSTEM_TEMPLATE_009 = (
         '你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。'
     )
+    _CHAT_SYSTEM_TEMPLATE_010 = (
+        '你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,注意不要擅自添加原文中没有的代词,也不要擅自增加或减少换行。'
+    )
 
     _LANGUAGE_CODE_MAP = {
         'CHS': 'Simplified Chinese',
@@ -32,23 +223,32 @@ class Sakura13BTranslator(CommonTranslator):
 
     def __init__(self):
         super().__init__()
-        #检测/v1是否存在
+        if "/v1" not in SAKURA_API_BASE:
+            openai.api_base = SAKURA_API_BASE + "/v1"
+        else:
+            openai.api_base = SAKURA_API_BASE
+        openai.api_key = "sk-114514"
         self.temperature = 0.3
         self.top_p = 0.3
-        self.frequency_penalty = 0.0
-        self._current_style = "normal"
+        self.frequency_penalty = 0.1
+        self._current_style = "precise"
+        self._emoji_pattern = re.compile(r'[\U00010000-\U0010ffff]')
+        self._heart_pattern = re.compile(r'❤')
+        self.sakura_dict = SakuraDict(self.get_dict_path(), self.logger, SAKURA_VERSION)
 
-    def detect_and_remove_extra_repeats(self, s: str, threshold: int = 10, remove_all=True):
-        """
-        检测字符串中是否有任何模式连续重复出现超过阈值,并在去除多余重复后返回新字符串。
-        保留一个模式的重复。
+    def get_sakura_version(self):
+        return SAKURA_VERSION
 
-        :param s: str - 待检测的字符串。
-        :param threshold: int - 连续重复模式出现的最小次数,默认为2。
-        :return: tuple - (bool, str),第一个元素表示是否有重复,第二个元素是处理后的字符串。
-        """
+    def get_dict_path(self):
+        return SAKURA_DICT_PATH
 
+    def detect_and_calculate_repeats(self, s: str, threshold: int = _REPEAT_DETECT_THRESHOLD, remove_all=True) -> Tuple[bool, str, int, str, int]:
+        """
+        Detect repeating patterns in the text and count the repetitions.
+        Returns: (repeated, text with extra repeats removed, repeat count,
+        repeated pattern, effective threshold).
+        """
         repeated = False
+        counts = []
+        count = 0
+        pattern = ''
         for pattern_length in range(1, len(s) // 2 + 1):
             i = 0
             while i < len(s) - pattern_length:
@@ -61,56 +261,190 @@ def detect_and_remove_extra_repeats(self, s: str, threshold: int = 10, remove_al
                         j += pattern_length
                     else:
                         break
+                counts.append(count)
                 if count >= threshold:
+                    self.logger.warning(f"检测到重复模式: {pattern},重复次数: {count}")
                     repeated = True
-                    # 保留一个模式的重复
                     if remove_all:
                         s = s[:i + pattern_length] + s[j:]
                     break
                 i += 1
             if repeated:
                 break
-        return repeated, s
+
+        # mode (most common value) of the observed repeat counts
+        if counts:
+            mode_count = max(set(counts), key=counts.count)
+        else:
+            mode_count = 0
+
+        # effective threshold: the larger of the default and the mode
+        actual_threshold = max(threshold, mode_count)
+
+        return repeated, s, count, pattern, actual_threshold
+
+    @staticmethod
+    def enlarge_small_kana(text, ignore=''):
+        """Convert small hiragana/katakana to their full-sized forms.
+
+        Parameters
+        ----------
+        text : str
+            Full-width hiragana or katakana string.
+        ignore : str, optional
+            Characters to exclude from the conversion.
+
+        Returns
+        -------
+        str
+            The string with small kana replaced by full-sized kana.
+
+        Examples
+        --------
+        >>> print(enlarge_small_kana('さくらきょうこ'))
+        さくらきようこ
+        >>> print(enlarge_small_kana('キュゥべえ'))
+        キユウべえ
+        """
+        SMALL_KANA = list('ぁぃぅぇぉゃゅょっァィゥェォヵヶャュョッ')
+        SMALL_KANA_NORMALIZED = list('あいうえおやゆよつアイウエオカケヤユヨツ')
+        SMALL_KANA2BIG_KANA = dict(zip(map(ord, SMALL_KANA), SMALL_KANA_NORMALIZED))
+
+        def _exclude_ignorechar(ignore, conv_map):
+            for character in map(ord, ignore):
+                del conv_map[character]
+            return conv_map
+
+        def _convert(text, conv_map):
+            return text.translate(conv_map)
+
+        def _translate(text, ignore, conv_map):
+            if ignore:
+                _conv_map = _exclude_ignorechar(ignore, conv_map.copy())
+                return _convert(text, _conv_map)
+            return _convert(text, conv_map)
+
+        return _translate(text, ignore, SMALL_KANA2BIG_KANA)
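+
+    # Sketch of the degeneration check (hypothetical strings), given an
+    # instance `t`: one two-character pattern repeated 25 times trips the
+    # default threshold of 20 and is collapsed to a single occurrence:
+    #     >>> t.detect_and_calculate_repeats('好的' * 25)
+    #     (True, '好的', 25, '好的', 20)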
 
     def _format_prompt_log(self, prompt: str) -> str:
-        return '\n'.join([
+        """
+        Format the prompt for log output.
+        """
+        gpt_dict_raw_text = self.sakura_dict.get_dict_str()
+        prompt_009 = '\n'.join([
             'System:',
-            self._CHAT_SYSTEM_TEMPLATE,
+            self._CHAT_SYSTEM_TEMPLATE_009,
             'User:',
             '将下面的日文文本翻译成中文:',
             prompt,
         ])
+        prompt_010 = '\n'.join([
+            'System:',
+            self._CHAT_SYSTEM_TEMPLATE_010,
+            'User:',
+            "根据以下术语表:",
+            gpt_dict_raw_text,
+            "将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:",
+            prompt,
+        ])
+        return prompt_009 if SAKURA_VERSION == '0.9' else prompt_010
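+
+    # Resulting 0.10 prompt layout (sketch, with one hypothetical glossary entry):
+    #     System:
+    #     <_CHAT_SYSTEM_TEMPLATE_010>
+    #     User:
+    #     根据以下术语表:
+    #     恵->惠 #名字
+    #     将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:
+    #     「日文原文」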
 
-    # str 通过/n转换为list
-    def _split_text(self, text: str) -> list:
+    def _split_text(self, text: str) -> List[str]:
+        """
+        Split a string on newlines into a list; lists pass through unchanged.
+        """
         if isinstance(text, list):
             return text
         return text.split('\n')
 
-    def check_align(self, queries: List[str], response: str) -> bool:
+    def _preprocess_queries(self, queries: List[str]) -> List[str]:
         """
-        检查原始文本(queries)与翻译后的文本(response)是否保持相同的行数。
+        Preprocess the queries: normalize small kana, strip emoji, replace
+        ❤ with ♥, and wrap each line in 「」 brackets.
+        """
+        queries = [self.enlarge_small_kana(query) for query in queries]
+        queries = [self._emoji_pattern.sub('', query) for query in queries]
+        queries = [self._heart_pattern.sub('♥', query) for query in queries]
+        queries = [f'「{query}」' for query in queries]
+        self.logger.debug(f'预处理后的查询文本:{queries}')
+        return queries
 
-        :param queries: 原始文本的列表。
-        :param response: 翻译后的文本,可能是一个字符串。
-        :return: 两者行数是否相同的布尔值。
+    async def _check_translation_quality(self, queries: List[str], response: str) -> List[str]:
+        """
+        Check the translation for repetition and line-count mismatches; retry
+        on problems and fall back to per-line translation if retries fail.
         """
-        # 确保response是列表形式
-        translated_texts = self._split_text(response) if isinstance(response, str) else response
+        async def _retry_translation(queries: List[str], check_func: Callable[[str], bool], error_message: str) -> Optional[str]:
+            styles = ["precise", "normal", "aggressive"]
+            for i in range(self._RETRY_ATTEMPTS):
+                self._set_gpt_style(styles[i % len(styles)])
+                self.logger.warning(f'{error_message} 尝试次数: {i + 1}。当前参数风格:{self._current_style}。')
+                response = await self._handle_translation_request(queries)
+                if not check_func(response):
+                    return response
+            return None
 
-        # 日志记录,而不是直接打印
-        print(f"原始文本行数: {len(queries)}, 翻译文本行数: {len(translated_texts)}")
-        logger.warning(f"原始文本行数: {len(queries)}, 翻译文本行数: {len(translated_texts)}")
+        # warn if the source text itself repeats beyond the default threshold
+        if self._detect_repeats(''.join(queries), self._REPEAT_DETECT_THRESHOLD):
+            self.logger.warning(f'请求内容本身含有超过默认阈值{self._REPEAT_DETECT_THRESHOLD}的重复内容。')
 
-        # 检查行数是否匹配
-        is_aligned = len(queries) == len(translated_texts)
-        if not is_aligned:
-            logger.warning("原始文本与翻译文本的行数不匹配。")
+        # effective threshold: the larger of the default and the highest
+        # repeat count observed in the source lines
+        actual_threshold = max(max(self._get_repeat_count(query) for query in queries), self._REPEAT_DETECT_THRESHOLD)
+
+        if self._detect_repeats(response, actual_threshold):
+            response = await _retry_translation(queries, lambda r: self._detect_repeats(r, actual_threshold), f'检测到大量重复内容(当前阈值:{actual_threshold}),疑似模型退化,重新翻译。')
+            if response is None:
+                self.logger.warning(f'疑似模型退化,尝试{self._RETRY_ATTEMPTS}次仍未解决,进行单行翻译。')
+                return await self._translate_single_lines(queries)
+
+        if not self._check_align(queries, response):
+            response = await _retry_translation(queries, lambda r: not self._check_align(queries, r), '因为检测到原文与译文行数不匹配,重新翻译。')
+            if response is None:
+                self.logger.warning(f'原文与译文行数不匹配,尝试{self._RETRY_ATTEMPTS}次仍未解决,进行单行翻译。')
+                return await self._translate_single_lines(queries)
+
+        return self._split_text(response)
+
+    def _detect_repeats(self, text: str, threshold: int = _REPEAT_DETECT_THRESHOLD) -> bool:
+        """
+        Report whether the text contains a pattern repeated at or above the threshold.
+        """
+        is_repeated, _, _, _, _ = self.detect_and_calculate_repeats(text, threshold, remove_all=False)
+        return is_repeated
+
+    def _get_repeat_count(self, text: str, threshold: int = _REPEAT_DETECT_THRESHOLD) -> int:
+        """
+        Return the repeat count of the dominant pattern in the text.
+        """
+        _, _, count, _, _ = self.detect_and_calculate_repeats(text, threshold, remove_all=False)
+        return count
+
+    def _check_align(self, queries: List[str], response: str) -> bool:
+        """
+        Check that the source text and the translation have the same number of lines.
+        """
+        translations = self._split_text(response)
+        is_aligned = len(queries) == len(translations)
+        if not is_aligned:
+            self.logger.warning(f"行数不匹配 - 原文行数: {len(queries)},译文行数: {len(translations)}")
         return is_aligned
 
+    async def _translate_single_lines(self, queries: List[str]) -> List[str]:
+        """
+        Translate the queries line by line.
+        """
+        translations = []
+        for query in queries:
+            response = await self._handle_translation_request(query)
+            if self._detect_repeats(response):
+                self.logger.warning(f"单行翻译结果存在重复内容: {response},返回原文。")
+                translations.append(query)
+            else:
+                translations.append(response)
+        return translations
+
     def _delete_quotation_mark(self, texts: List[str]) -> List[str]:
-        print(texts)
+        """
+        Strip the 「」 brackets from each line.
+        """
         new_texts = []
         for text in texts:
             text = text.strip('「」')
@@ -118,179 +452,129 @@ def _delete_quotation_mark(self, texts: List[str]) -> List[str]:
             new_texts.append(text)
         return new_texts
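+
+    # Alignment sketch (hypothetical strings): a two-line batch must come back
+    # as two lines, e.g. queries ['「おはよう」', '「すごい!」'] and response
+    # '「早上好」\n「好厉害!」'; any other line count triggers a retry.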
 
     async def _translate(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]:
-        translations = []
         self.logger.debug(f'Temperature: {self.temperature}, TopP: {self.top_p}')
-        self.logger.debug(f'Queries: {queries}')
+        self.logger.debug(f'原文: {queries}')
         text_prompt = '\n'.join(queries)
         self.logger.debug('-- Sakura Prompt --\n' + self._format_prompt_log(text_prompt) + '\n\n')
-        # 去除emoji
-        queries = [re.sub(r'[\U00010000-\U0010ffff]', '', query) for query in queries]
-        # 替换❤
-        queries = [re.sub(r'❤', '♥', query) for query in queries]
-        # 给queries的每行加上「」
-        queries = [f'「{query}」' for query in queries]
+
+        # preprocess the queries
+        queries = self._preprocess_queries(queries)
+
+        # send the translation request
         response = await self._handle_translation_request(queries)
         self.logger.debug('-- Sakura Response --\n' + response + '\n\n')
 
-        # 提取翻译结果并去除首尾空白
-        response = response.strip()
+        # check the result for repetition and line-count mismatches
+        translations = await self._check_translation_quality(queries, response)
 
-        rep_flag = self.detect_and_remove_extra_repeats(response)[0]
-        if rep_flag:
-            for i in range(self._RETRY_ATTEMPTS):
-                if self.detect_and_remove_extra_repeats(queries)[0]:
-                    self.logger.warning('Queries have repeats.')
-                    break
-                self.logger.warning(f'Re-translated because of model degradation, {i} times.')
-                self._set_gpt_style("precise")
-                self.logger.debug(f'Temperature: {self.temperature}, TopP: {self.top_p}')
-                response = await self._handle_translation_request(queries)
-                rep_flag = self.detect_and_remove_extra_repeats(response)[0]
-                if not rep_flag:
-                    break
-            if rep_flag:
-                self.logger.warning('Model degradation, try to translate single line.')
-                for query in queries:
-                    response = await self._handle_translation_request(query)
-                    translations.append(response)
-                    rep_flag = self.detect_and_remove_extra_repeats(response)[0]
-                    if rep_flag:
-                        self.logger.warning('Model degradation, fill original text')
-                        return self._delete_quotation_mark(queries)
-                return self._delete_quotation_mark(translations)
-
-        align_flag = self.check_align(queries, response)
-        if not align_flag:
-            for i in range(self._RETRY_ATTEMPTS):
-                self.logger.warning(f'Re-translated because of a mismatch in the number of lines, {i} times.')
-                self._set_gpt_style("precise")
-                self.logger.debug(f'Temperature: {self.temperature}, TopP: {self.top_p}')
-                response = await self._handle_translation_request(queries)
-                align_flag = self.check_align(queries, response)
-                if align_flag:
-                    break
-            if not align_flag:
-                self.logger.warning('Mismatch in the number of lines, try to translate single line.')
-                for query in queries:
-                    print(query)
-                    response = await self._handle_translation_request(query)
-                    translations.append(response)
-                print(translations)
-                align_flag = self.check_align(queries, translations)
-                if not align_flag:
-                    self.logger.warning('Mismatch in the number of lines, fill original text')
-                    return self._delete_quotation_mark(queries)
-                return self._delete_quotation_mark(translations)
-        translations = self._split_text(response)
-        if isinstance(translations, list):
-            return self._delete_quotation_mark(translations)
-        translations = self._split_text(response)
         return self._delete_quotation_mark(translations)
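+
+    # Fallback chain (sketch): retry the whole batch with progressively
+    # stronger anti-repetition settings ("precise" -> "normal" -> "aggressive"),
+    # then drop to line-by-line translation, and finally return the source
+    # text untouched.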
 
     async def _handle_translation_request(self, prompt: str) -> str:
-        # 翻译请求和错误处理逻辑
+        """
+        Handle a translation request, including error handling and retries.
+        """
         ratelimit_attempt = 0
         server_error_attempt = 0
         timeout_attempt = 0
         while True:
-            request_task = asyncio.create_task(self._request_translation(prompt))
-            started = time.time()
-            while not request_task.done():
-                await asyncio.sleep(0.1)
-                if time.time() - started > self._TIMEOUT + (timeout_attempt * self._TIMEOUT / 2):
-                    if timeout_attempt >= self._TIMEOUT_RETRY_ATTEMPTS:
-                        raise Exception('Sakura timeout.')
-                    timeout_attempt += 1
-                    self.logger.warn(f'Restarting request due to timeout. Attempt: {timeout_attempt}')
-                    request_task.cancel()
-                    request_task = asyncio.create_task(self._request_translation(prompt))
-                    started = time.time()
             try:
-                response = await request_task
+                request_task = asyncio.create_task(self._request_translation(prompt))
+                response = await asyncio.wait_for(request_task, timeout=self._TIMEOUT)
                 break
+            except asyncio.TimeoutError:
+                timeout_attempt += 1
+                if timeout_attempt >= self._TIMEOUT_RETRY_ATTEMPTS:
+                    raise Exception('Sakura超时。')
+                self.logger.warning(f'Sakura因超时而进行重试。尝试次数: {timeout_attempt}')
             except openai.error.RateLimitError:
                 ratelimit_attempt += 1
                 if ratelimit_attempt >= self._RATELIMIT_RETRY_ATTEMPTS:
                     raise
-                self.logger.warn(f'Restarting request due to ratelimiting by sakura servers. Attempt: {ratelimit_attempt}')
+                self.logger.warning(f'Sakura因被限速而进行重试。尝试次数: {ratelimit_attempt}')
                 await asyncio.sleep(2)
-            except openai.error.APIError:
+            except (openai.error.APIError, openai.error.APIConnectionError) as e:
                 server_error_attempt += 1
                 if server_error_attempt >= self._RETRY_ATTEMPTS:
-                    self.logger.error('Sakura server error. Returning original text.')
-                    return prompt # 返回原始文本而不是抛出异常
-                self.logger.warn(f'Restarting request due to a server error. Attempt: {server_error_attempt}')
-                await asyncio.sleep(1)
-            except openai.error.APIConnectionError:
-                server_error_attempt += 1
-                self.logger.warn(f'Restarting request due to a server connection error. Attempt: {server_error_attempt}')
-                await asyncio.sleep(1)
-            except FileNotFoundError:
-                self.logger.warn(f'Restarting request due to FileNotFoundError.')
-                await asyncio.sleep(30)
+                    self.logger.error(f'Sakura API请求失败。错误信息: {e}')
+                    return prompt
+                self.logger.warning(f'Sakura因服务器错误而进行重试。尝试次数: {server_error_attempt},错误信息: {e}')
         return response
 
     async def _request_translation(self, input_text_list) -> str:
+        """
+        Send a translation request to the Sakura API.
+        """
         if isinstance(input_text_list, list):
             raw_text = "\n".join(input_text_list)
         else:
             raw_text = input_text_list
+        raw_length = len(raw_text)
+        min_token_num = 512
+        max_token_num = max(raw_length * 2, min_token_num)
         extra_query = {
             'do_sample': False,
             'num_beams': 1,
             'repetition_penalty': 1.0,
         }
-        old_api_base = openai.api_base or ''
-        old_api_key = openai.api_key or ''
-
-        if "/v1" not in SAKURA_API_BASE:
-            openai.api_base = SAKURA_API_BASE + "/v1"
-        else:
-            openai.api_base = SAKURA_API_BASE
-        openai.api_key = SAKURA_API_KEY
-        response = await openai.ChatCompletion.acreate(
-            model="sukinishiro",
-            messages=[
+        if SAKURA_VERSION == "0.9":
+            messages = [
                 {
                     "role": "system",
-                    "content": "你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。"
+                    "content": self._CHAT_SYSTEM_TEMPLATE_009
                 },
                 {
                     "role": "user",
                     "content": f"将下面的日文文本翻译成中文:{raw_text}"
                 }
-            ],
+            ]
+        else:
+            gpt_dict_raw_text = self.sakura_dict.get_dict_str()
+            self.logger.debug(f"Sakura Dict: {gpt_dict_raw_text}")
+            messages = [
+                {
+                    "role": "system",
+                    "content": self._CHAT_SYSTEM_TEMPLATE_010
+                },
+                {
+                    "role": "user",
+                    "content": f"根据以下术语表:\n{gpt_dict_raw_text}\n将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:{raw_text}"
+                }
+            ]
+        response = await openai.ChatCompletion.acreate(
+            model="sukinishiro",
+            messages=messages,
             temperature=self.temperature,
             top_p=self.top_p,
-            max_tokens=1024,
+            max_tokens=max_token_num,
             frequency_penalty=self.frequency_penalty,
             seed=-1,
             extra_query=extra_query,
         )
-        openai.api_base = old_api_base
-        openai.api_key = old_api_key
-        # 提取并返回响应文本
         for choice in response.choices:
             if 'text' in choice:
                 return choice.text
-        # 如果没有找到包含文本的响应,返回第一个响应的内容(可能为空)
         return response.choices[0].message.content
 
     def _set_gpt_style(self, style_name: str):
+        """
+        Set the sampling style used for generation.
+        """
         if self._current_style == style_name:
             return
         self._current_style = style_name
 
         if style_name == "precise":
             temperature, top_p = 0.1, 0.3
-            frequency_penalty = 0.0
+            frequency_penalty = 0.05
         elif style_name == "normal":
             temperature, top_p = 0.3, 0.3
-            frequency_penalty = 0.15
+            frequency_penalty = 0.2
+        elif style_name == "aggressive":
+            temperature, top_p = 0.3, 0.3
+            frequency_penalty = 0.3
 
         self.temperature = temperature
         self.top_p = top_p
-        self.frequency_penalty = frequency_penalty
\ No newline at end of file
+        self.frequency_penalty = frequency_penalty
diff --git a/sakura_dict.txt b/sakura_dict.txt
new file mode 100644
index 00000000..26456abe
--- /dev/null
+++ b/sakura_dict.txt
@@ -0,0 +1,13 @@
+// Example glossary; add or modify entries as needed.
+
+安芸倫也->安艺伦也 #名字,男性,学生
+倫也->伦也 #名字,男性,学生
+安芸->安艺 #姓氏
+
+加藤恵->加藤惠 #名字,女性,学生,安芸倫也的同班同学
+恵->惠 #名字,女性,学生,安芸倫也的同班同学
+加藤->加藤 #姓氏
+
+澤村・スペンサー・英梨々->泽村・斯宾塞・英梨梨 #名字,女性,学生,同人志作者
+英梨々->英梨梨 #名字,女性,学生,同人志作者
+澤村->泽村 #姓氏
\ No newline at end of file