diff --git a/mathtranslate/chatgpt.py b/mathtranslate/chatgpt.py new file mode 100644 index 0000000..ee83cdd --- /dev/null +++ b/mathtranslate/chatgpt.py @@ -0,0 +1,88 @@ +from .config import config +import json +import requests +from urllib import parse +import openai +import sys +import time +import re + +class GPTTranslator: + def __init__(self): + self.baseURL=parse.urlparse(config.openai_api_endpoint)._replace(path='/v1') #make sure using ${host}/v1/chat api + self.model = config.openai_model_name + self.key = config.openai_api_key + self.client=openai.OpenAI(api_key=self.key,base_url=self.baseURL.geturl()) + + + def format_prompt(self, text, language_to, language_from): + PROMPT_PROTOTYPE = 'As an academic expert with specialized knowledge in various fields, please provide a proficient and precise translation from {} to {} of the academic text between 🔤 and 🔠. It is crucial to maintain the original phrase or sentence and ensure accuracy while utilizing the appropriate language. Please provide only the translated result without any additional explanation or punctuation. Please remove \"🔤\" and \"🔠\". Do not modify or delete any word contains \"/XMATHX_\" such as /XMATHX_0, /XMATHX_1, /XMATHX_3_4. The text is as follows: 🔤{}🔠' + #prompt prototype changed from https://github.com/windingwind/zotero-pdf-translate + SYSTEM_PROMPT_PROTOTYPE = 'You are an academic translator with specialized knowledge in various fields, please provide a proficient and precise translation from {} to {} of the academic text enclosed in 🔤 and 🔠.Please provide only the translated result without any additional explanation or punctuation. Do not modify or delete any word contains "/XMATHX_" such as /XMATHX_0, /XMATHX_1, /XMATHX_3_4. 
' + return {'system':SYSTEM_PROMPT_PROTOTYPE.format(language_from,language_to),'user':PROMPT_PROTOTYPE.format(language_from,language_to,text)} + + def get_server_errormsg(self,error): + try : + return error.response.json()['error']['message'] + except Exception : + return error.message + + def find_all_mathmask(self,text): + mask_pattern=re.compile(r'/XMATHX(_[0-9]+)+') + masks = set([i.group() for i in re.finditer(pattern=mask_pattern,string=text)]) + return masks + + def is_gpt_output_valid(self,masks,text_translated): + masks_translated = self.find_all_mathmask(text_translated) + return (masks_translated==masks) and ('🔤' not in text_translated) and ('🔠' not in text_translated) + + def is_text_all_mask(self,masks,text): + for mask in masks: + text = text.replace(mask,'') + return text.strip() == '' + + + def call_openai_api(self,prompt): + messages= [{ + "role":"system", + "content": prompt['system'] + }, + { + "role": "user", + "content": prompt['user'] + }] + try: + return self.client.chat.completions.create(model=self.model,temperature=1,messages=messages) + except openai.RateLimitError as e: + print('API rate limit exceeded, retry after 15s') + time.sleep(15) + return self.call_openai_api(prompt) + except openai.InternalServerError as e: + print('Api server failed({}). retry after 30s.'.format(self.get_server_errormsg(e))) + time.sleep(30) + return self.call_openai_api(prompt) + except (openai.PermissionDeniedError,openai.AuthenticationError) as e: + print('OpenAI api Authentication failed ({}). please check your api setting by:\n translate_tex --setgpt'.format(self.get_server_errormsg(e))) + sys.exit(-1) + except openai.APIError as e: + print('Api requests failed with error:{}. 
please check your service status'.format(e.message)) + raise e + + + + + def translate(self, text, language_to, language_from): + masks = self.find_all_mathmask(text) + if self.is_text_all_mask(masks,text): + return text + text = text.lstrip('\n') + while True: + result = self.call_openai_api(self.format_prompt(text, language_to, language_from)) + content_translated = result.choices[0].message.content + if self.is_gpt_output_valid(masks,content_translated): + if content_translated.startswith('"') and content_translated.endswith('"'): + content_translated=content_translated.lstrip('"').rstrip('"') + #remove unexpect " " added by gpt + return content_translated + + diff --git a/mathtranslate/config.py b/mathtranslate/config.py index 99f5226..65812a5 100644 --- a/mathtranslate/config.py +++ b/mathtranslate/config.py @@ -13,6 +13,9 @@ class Config: default_threads_path = 'DEFAULT_THREADS' tencent_secret_id_path = 'TENCENT_ID' tencent_secret_key_path = 'TENCENT_KEY' + openai_model_name_path = 'OPENAI_MODEL' + openai_api_endpoint_path = 'OPENAI_URL' + openai_api_key_path = 'OPENAI_KEY' default_engine_default = 'google' default_language_from_default = 'en' @@ -22,8 +25,11 @@ class Config: default_threads_default = 0 tencent_secret_id_default = None tencent_secret_key_default = None + openai_model_name_default = 'gpt-3.5-turbo' + openai_api_endpoint_default = 'https://api.openai.com' + openai_api_key_default = None - math_code = 'XMATHX' + math_code = '/XMATHX' #better for gpt to understand log_file = f'{app_dir}/translate_log' raw_mularg_command_list = [('textcolor', 2, (1, ))] mularg_command_list = [('textcolor', 2, (1, ))] @@ -62,6 +68,9 @@ def load(self): self.default_loading_dir = self.read_variable(self.default_loading_dir_path, self.default_loading_dir_default) self.default_saving_dir = self.read_variable(self.default_saving_dir_path, self.default_saving_dir_default) self.default_threads = int(self.read_variable(self.default_threads_path, 
self.default_threads_default)) + self.openai_model_name = self.read_variable(self.openai_model_name_path,self.openai_model_name_default) + self.openai_api_endpoint = self.read_variable(self.openai_api_endpoint_path,self.openai_api_endpoint_default) + self.openai_api_key = self.read_variable(self.openai_api_key_path,self.openai_api_key_default) if not os.path.exists(self.default_loading_dir): self.default_loading_dir = self.default_loading_dir_default if not os.path.exists(self.default_saving_dir): diff --git a/mathtranslate/translate.py b/mathtranslate/translate.py index 7f860a2..bbbe19f 100644 --- a/mathtranslate/translate.py +++ b/mathtranslate/translate.py @@ -30,8 +30,11 @@ def __init__(self, engine, language_to, language_from): elif engine == 'tencent': from mathtranslate.tencent import Translator translator = Translator() + elif engine== 'gpt': + from mathtranslate.chatgpt import GPTTranslator as Translator + translator = Translator() else: - assert False, "engine must be google or tencent" + assert False, "engine must be [google,tencent,gpt]" self.translator = translator self.language_to = language_to self.language_from = language_from diff --git a/mathtranslate/utils.py b/mathtranslate/utils.py index 5504847..8d9586d 100644 --- a/mathtranslate/utils.py +++ b/mathtranslate/utils.py @@ -85,7 +85,7 @@ def check_update(require_updated=True): def add_arguments(parser): - parser.add_argument("-engine", default=config.default_engine, help=f'translation engine, avaiable options include google and tencent. default is {config.default_engine}') + parser.add_argument("-engine", default=config.default_engine, help=f'translation engine, avaiable options include [google, tencent, gpt]. 
default is {config.default_engine}') parser.add_argument("-from", default=config.default_language_from, dest='l_from', help=f'language from, default is {config.default_language_from}') parser.add_argument("-to", default=config.default_language_to, dest='l_to', help=f'language to, default is {config.default_language_to}') parser.add_argument("-threads", default=config.default_threads, type=int, help='threads for tencent translation, default is auto') @@ -96,6 +96,7 @@ def add_arguments(parser): parser.add_argument("--setdefault", action='store_true', help='set default translation engine and languages') parser.add_argument("--debug", action='store_true', help='Debug options for developers') parser.add_argument("--nocache", action='store_true', help='Debug options for developers') + parser.add_argument("--setgpt",action='store_true',help='set baseUrl,apiKey and model name of your GPT service') def process_options(options): @@ -110,8 +111,22 @@ def process_options(options): print('secretKey:', config.tencent_secret_key) sys.exit() + if options.setgpt: + print('OpenAI api base URL: (leave empty for default {})'.format(config.openai_api_endpoint_default)) + config.set_variable(config.openai_api_endpoint_path, config.openai_api_endpoint_default) + print('OpenAI api key (something like sk-xxx...):') + config.set_variable(config.openai_api_key_path, config.openai_api_key_default) + print('ChatGPT model name: (leave empty for default {})'.format(config.openai_model_name_default)) + config.set_variable(config.openai_model_name_path,config.openai_model_name_default) + print('saved!') + config.load() + print('Base URL:', config.openai_api_endpoint) + print('Api key:', config.openai_api_key) + print('Model Name:', config.openai_model_name) + sys.exit() + if options.setdefault: - print('Translation engine (google or tencent, default google)') + print('Translation engine [google,tencent,gpt], default google)') config.set_variable(config.default_engine_path, 
config.default_engine_default) print('Translation language from (default en)') config.set_variable(config.default_language_from_path, config.default_language_from_default) @@ -148,6 +163,12 @@ def process_options(options): elif options.threads > 1: options.threads = 1 print('tencent engine does not support multi-threading, set to 1') + elif options.engine == 'gpt': + hasgptkey = (config.openai_api_key is not None) + if not hasgptkey: + print('Please setup api info for openAI api first by') + print('translate_tex --setgpt') + sys.exit() if options.threads < 0: print('threads must be a non-zero integer number (>=0 where 0 means auto), set to auto') diff --git a/setup.py b/setup.py index a77d274..adcb741 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,8 @@ "requests", "regex", "tqdm", - "appdata" + "appdata", + "openai" ], classifiers=[ "Programming Language :: Python :: 3",