373 lines
13 KiB
Python
373 lines
13 KiB
Python
"""
|
||
PDF翻译服务模块
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
import time
|
||
import hashlib
|
||
import threading
|
||
from datetime import datetime, timedelta
|
||
from pypdf import PdfReader
|
||
from openai import OpenAI
|
||
from flask import current_app
|
||
|
||
# ==================== LLM客户端 ====================
|
||
class TranslationService:
    """Translate the text of an English PDF into Chinese with a chat-completion API.

    Text is extracted page by page, each page is split into chunks, every
    chunk is sent to the configured model, and the translated pages are
    written out as a Markdown document.
    """

    def __init__(self, config):
        """Create the API client from ``config['LLM_CONFIG']``.

        Args:
            config: Mapping with an ``LLM_CONFIG`` entry. This class reads the
                keys ``api_key``, ``api_base``, ``model``, ``max_tokens``,
                ``timeout`` and ``chunk_size`` from it.
        """
        self.config = config
        self.llm_config = config['LLM_CONFIG']
        self.client = OpenAI(
            api_key=self.llm_config['api_key'],
            base_url=self.llm_config['api_base'],
        )

    def translate_text(self, text, instruction=None):
        """Translate one piece of English text into Chinese.

        This is deliberately best-effort: if the request fails or the model
        returns an empty answer, the untranslated ``text`` is returned so that
        a single bad chunk does not abort a whole document.

        Args:
            text: Text to translate.
            instruction: Optional user-supplied translation requirements that
                are inserted into the prompt.

        Returns:
            The stripped translation, or ``text`` unchanged on failure.
        """
        system_prompt = (
            "你是一个专业的英译中翻译专家。请遵循以下规则:\n"
            "1. 保持原文的格式和段落结构\n"
            "2. 专业术语保持准确性,必要时保留英文原文\n"
            "3. 语言流畅自然,符合中文表达习惯\n"
            "4. 不要添加任何解释或注释,只输出翻译结果"
        )

        if instruction:
            user_prompt = (
                "请将以下英文翻译成中文。\n"
                f"用户翻译要求:{instruction}\n"
                "\n"
                "英文内容:\n"
                f"{text}"
            )
        else:
            user_prompt = (
                "请将以下英文翻译成中文。直接输出中文翻译,不要解释。\n"
                "\n"
                "英文内容:\n"
                f"{text}"
            )

        try:
            response = self.client.chat.completions.create(
                model=self.llm_config['model'],
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                max_tokens=self.llm_config['max_tokens'],
                temperature=0.3,
                timeout=self.llm_config['timeout'],
            )
            content = response.choices[0].message.content
        except Exception as e:
            # Thread-level best effort: report and fall back to the source text.
            print(f"翻译错误: {e}")
            return text

        if content and content.strip():
            return content.strip()
        return text

    def extract_pdf_text(self, pdf_path):
        """Extract the cleaned text of every non-empty page.

        Args:
            pdf_path: Path of the PDF to read.

        Returns:
            A list of ``{'page': <1-based page number>, 'text': <cleaned text>}``
            dicts. Pages whose cleaned text is empty are left out.
        """
        reader = PdfReader(pdf_path)
        pages = []
        for page_number, page in enumerate(reader.pages, start=1):
            # Guard against an extractor that yields None for image-only pages.
            raw_text = page.extract_text() or ""
            cleaned = self._clean_text(raw_text)
            if cleaned:
                pages.append({'page': page_number, 'text': cleaned})
        return pages

    def _clean_text(self, text):
        """Normalise extracted text.

        Collapses runs of three or more newlines to a blank line, collapses
        runs of spaces to one space, removes ASCII control characters (tab,
        newline and carriage return are kept) and strips the ends.
        """
        import re

        text = re.sub(r'\n{3,}', '\n\n', text)
        text = re.sub(r' {2,}', ' ', text)
        text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f]', '', text)
        return text.strip()

    @staticmethod
    def _split_oversized(paragraph, max_size):
        """Break a paragraph longer than ``max_size`` into pieces that fit.

        Line breaks are preferred as cut points, then spaces; a hard cut is the
        last resort. A paragraph that already fits is returned as-is.
        """
        if max_size <= 0 or len(paragraph) <= max_size:
            return [paragraph]

        pieces = []
        buffer = ""
        for line in paragraph.split('\n'):
            while len(line) > max_size:
                if buffer:
                    pieces.append(buffer)
                    buffer = ""
                cut = line.rfind(' ', 0, max_size + 1)
                if cut <= 0:
                    cut = max_size  # no usable space: hard cut
                pieces.append(line[:cut].rstrip())
                line = line[cut:].lstrip()
            candidate = f"{buffer}\n{line}" if buffer else line
            if len(candidate) <= max_size:
                buffer = candidate
            else:
                pieces.append(buffer)
                buffer = line
        if buffer:
            pieces.append(buffer)
        return pieces

    def chunk_text(self, text, max_size=2000):
        """Split text into chunks of at most ``max_size`` characters.

        Paragraphs (separated by a blank line) are packed greedily into a
        chunk. A single paragraph that is longer than ``max_size`` is split
        further so that no request exceeds the limit. Empty chunks are never
        returned, so empty input yields an empty list.

        Args:
            text: Text to split.
            max_size: Maximum chunk length in characters.

        Returns:
            A list of non-empty chunk strings.
        """
        chunks = []
        current = ""

        for para in text.split('\n\n'):
            for piece in self._split_oversized(para, max_size):
                if len(current) + len(piece) < max_size:
                    current += piece + '\n\n'
                else:
                    if current.strip():
                        chunks.append(current.strip())
                    current = piece + '\n\n'

        if current.strip():
            chunks.append(current.strip())

        return chunks

    @staticmethod
    def _overall_progress(page_index, chunk_index, chunk_count, total_pages):
        """Return the whole-document progress percentage (0-100).

        Args:
            page_index: 0-based index of the page being translated.
            chunk_index: 0-based index of the chunk just finished on that page.
            chunk_count: Number of chunks on that page.
            total_pages: Number of pages being translated.
        """
        pages_done = page_index + (chunk_index + 1) / chunk_count
        # Multiply before dividing to avoid float results such as 28.999...
        return int(pages_done * 100 / total_pages)

    def translate_pdf(self, pdf_path, output_path, instruction=None, progress_callback=None):
        """Translate a PDF and write the translation to ``output_path``.

        Args:
            pdf_path: Input PDF path.
            output_path: Path of the Markdown file to write.
            instruction: Optional user translation requirements.
            progress_callback: Optional callable invoked as
                ``progress_callback(percent, total_pages, message)``.

        Returns:
            A dict with ``total_pages``, ``total_chunks`` and ``output_path``.

        Raises:
            ValueError: If fewer than 10 characters of text can be extracted
                (for example a scanned, empty or encrypted PDF).
        """
        pages = self.extract_pdf_text(pdf_path)
        total_pages = len(pages)

        # Refuse documents without translatable content.
        total_text = sum(len(p['text']) for p in pages)
        if total_pages == 0 or total_text < 10:
            error_msg = "PDF无法提取文本内容。可能原因:\n1. PDF是扫描版(图像形式),需要OCR处理\n2. PDF为空或加密\n请使用包含可提取文本的PDF文件。"
            if progress_callback:
                progress_callback(0, 0, error_msg)
            raise ValueError(error_msg)

        if progress_callback:
            progress_callback(0, total_pages, "开始翻译...")

        translated_pages = []
        total_chunks = 0

        for page_index, page_data in enumerate(pages):
            chunks = self.chunk_text(page_data['text'], self.llm_config['chunk_size'])
            total_chunks += len(chunks)

            translated_chunks = []
            for chunk_index, chunk in enumerate(chunks):
                translated_chunks.append(self.translate_text(chunk, instruction))

                if progress_callback:
                    # Count the pages already finished, not just the current one.
                    progress = self._overall_progress(
                        page_index, chunk_index, len(chunks), total_pages
                    )
                    progress_callback(progress, total_pages, f"翻译第{page_data['page']}页")

            translated_pages.append({
                'page': page_data['page'],
                'original': page_data['text'],
                'translated': '\n\n'.join(translated_chunks),
            })

        # Persist the result before reporting completion.
        self._save_output(translated_pages, output_path)

        if progress_callback:
            progress_callback(100, total_pages, "翻译完成")

        return {
            'total_pages': total_pages,
            'total_chunks': total_chunks,
            'output_path': output_path,
        }

    def _save_output(self, pages, output_path):
        """Write the translated pages to ``output_path`` as UTF-8 Markdown.

        Args:
            pages: Iterable of dicts with ``page`` and ``translated`` keys.
            output_path: Destination file path (overwritten if present).
        """
        parts = ["# 英文PDF中文翻译\n\n> 自动翻译生成\n\n---\n\n"]
        for page in pages:
            parts.append(f"## 第 {page['page']} 页\n\n")
            parts.append(page['translated'] + "\n\n---\n\n")

        with open(output_path, 'w', encoding='utf-8') as f:
            f.write("".join(parts))

    def save_comparison(self, pages, output_path):
        """Write a side-by-side (original + translation) Markdown file.

        Args:
            pages: Iterable of dicts with ``page``, ``original`` and
                ``translated`` keys.
            output_path: Destination file path (overwritten if present).
        """
        parts = ["# 英文PDF翻译对比\n\n---\n\n"]
        for page in pages:
            parts.append(f"## 第 {page['page']} 页\n\n")
            parts.append("### 原文\n\n```\n" + page['original'] + "\n```\n\n")
            parts.append("### 译文\n\n" + page['translated'] + "\n\n---\n\n")

        with open(output_path, 'w', encoding='utf-8') as f:
            f.write("".join(parts))
|
||
|
||
|
||
# ==================== 缓存服务 ====================
|
||
class CacheService:
    """File-based cache of translation results, keyed by a content hash.

    Each entry is stored as ``<cache_dir>/<file_hash>.md`` and expires
    ``expire_days`` days after its last modification time.
    """

    def __init__(self, cache_dir, expire_days=30):
        """Remember the cache location and make sure the directory exists.

        Args:
            cache_dir: Directory that holds the cache files.
            expire_days: Age in days after which an entry is discarded.
        """
        self.cache_dir = cache_dir
        self.expire_days = expire_days

        # exist_ok avoids a FileExistsError when two workers start together.
        os.makedirs(cache_dir, exist_ok=True)

    def _cache_path(self, file_hash):
        """Return the path of the cache file for ``file_hash``."""
        return os.path.join(self.cache_dir, f"{file_hash}.md")

    def compute_hash(self, file_content):
        """Return the hex MD5 digest of ``file_content`` (bytes).

        The digest is only used as a cache key; it is kept as MD5 so that
        existing cache files remain addressable.
        """
        return hashlib.md5(file_content).hexdigest()

    def get_cache(self, file_hash, db_model=None):
        """Look up a cache entry, evicting it if it has expired.

        Args:
            file_hash: Key returned by ``compute_hash``.
            db_model: Optional model class; when given, the record found with
                ``db_model.query.filter_by(file_hash=...)`` has its
                ``increment_hit()`` called on a cache hit.

        Returns:
            The cache file path, or ``None`` if there is no valid entry.
        """
        cache_file = self._cache_path(file_hash)

        try:
            modified_at = datetime.fromtimestamp(os.path.getmtime(cache_file))
        except OSError:
            # Missing (or unreadable) entry counts as a cache miss.
            return None

        # Expired entries are removed and reported as a miss.
        if datetime.now() - modified_at > timedelta(days=self.expire_days):
            try:
                os.remove(cache_file)
            except FileNotFoundError:
                pass  # already evicted by another worker
            return None

        # Record the hit.
        if db_model:
            cache_record = db_model.query.filter_by(file_hash=file_hash).first()
            if cache_record:
                cache_record.increment_hit()

        return cache_file

    def save_cache(self, file_hash, content):
        """Store ``content`` (text) under ``file_hash`` and return the file path."""
        cache_file = self._cache_path(file_hash)
        with open(cache_file, 'w', encoding='utf-8') as f:
            f.write(content)
        return cache_file

    def check_cache_exists(self, file_hash):
        """Return whether a cache file exists for ``file_hash`` (expiry is not checked)."""
        return os.path.exists(self._cache_path(file_hash))
|
||
|
||
|
||
# ==================== 异步翻译任务 ====================
|
||
class TranslationTask:
    """Launcher and in-memory registry for background translation jobs.

    Every job runs in its own thread and publishes its state in a plain dict
    stored in ``tasks``.
    NOTE(review): nothing in this class ever removes entries from ``tasks``.
    """

    tasks = {}  # task_id -> task state dict
    lock = threading.Lock()  # guards ``tasks`` and writes to the task dicts

    @classmethod
    def create_task(cls, task_id, pdf_path, output_path, config, instruction=None, translation_id=None, app=None):
        """Register a translation job and start it in a new thread.

        Args:
            task_id: Key under which the task state is stored.
            pdf_path: Input PDF path.
            output_path: Path of the translation file to write.
            config: Service configuration, used when ``app`` is not given.
            instruction: Optional user translation requirements.
            translation_id: Optional primary key of the ``Translation`` row to
                keep in sync (only used together with ``app``).
            app: Optional Flask application. When given, the LLM configuration
                is loaded with ``admin.get_llm_config()`` inside an app
                context and database updates are performed.

        Returns:
            ``task_id``.
        """
        task = {
            'id': task_id,
            'status': 'pending',
            'progress': 0,
            'message': '等待开始',
            'output_path': output_path,
            'error': None,
            'started_at': None,
            'completed_at': None,
            'translation_id': translation_id,
        }

        with cls.lock:
            cls.tasks[task_id] = task

        def set_state(**fields):
            # Single place where the worker thread mutates the shared dict.
            with cls.lock:
                task.update(fields)

        def update_record(**fields):
            # Mirror state into the database row, when one is tracked.
            if not (app and translation_id):
                return
            with app.app_context():
                from models import db, Translation
                trans = Translation.query.get(translation_id)
                if trans:
                    for column, value in fields.items():
                        setattr(trans, column, value)
                    db.session.commit()

        def progress_callback(progress, total, message):
            set_state(progress=progress, message=message)
            update_record(progress=progress)

        def run_translation():
            # Everything, including setup, runs inside the try so that a
            # failure is recorded on the task instead of leaving it "pending".
            try:
                # A separate name is required: assigning to ``config`` here
                # would make it local and unbound whenever ``app`` is missing.
                effective_config = config
                if app:
                    # Load the current LLM configuration dynamically.
                    with app.app_context():
                        from admin import get_llm_config
                        effective_config = {'LLM_CONFIG': get_llm_config()}

                service = TranslationService(effective_config)
                set_state(status='processing', started_at=datetime.now().isoformat())

                print(f"[翻译任务] 开始翻译,使用配置: {effective_config.get('LLM_CONFIG', {}).get('api_base', '未知')}")

                update_record(status='processing')

                result = service.translate_pdf(
                    pdf_path, output_path, instruction, progress_callback
                )
                set_state(
                    status='completed',
                    progress=100,
                    message='翻译完成',
                    completed_at=datetime.now().isoformat(),
                    result=result,
                )
                update_record(status='completed', progress=100, completed_at=datetime.now())

            except Exception as e:
                set_state(status='failed', error=str(e), message=f'翻译失败: {e}')
                update_record(status='failed', error_message=str(e))

        thread = threading.Thread(target=run_translation)
        thread.start()

        return task_id

    @classmethod
    def get_task(cls, task_id):
        """Return the state dict of ``task_id``, or ``None`` if it is unknown.

        The returned dict is the live object that the worker thread updates.
        """
        with cls.lock:
            return cls.tasks.get(task_id)

    @classmethod
    def update_task(cls, task_id, **kwargs):
        """Merge ``kwargs`` into the state of ``task_id``; unknown ids are ignored."""
        with cls.lock:
            if task_id in cls.tasks:
                cls.tasks[task_id].update(kwargs)