- 添加 pdf_to_images:将PDF页面转为图像
- 添加 extract_text_from_image:使用视觉模型OCR识别图像文字
- 检测扫描版PDF并自动切换OCR模式
- 支持 glm-4.6v 等视觉模型识别图像中的文字
- 进度提示显示OCR识别过程
513 lines
18 KiB
Python
513 lines
18 KiB
Python
"""
|
||
PDF翻译服务模块
|
||
"""
|
||
|
||
import base64
import hashlib
import io
import json
import os
import re
import threading
import time
from datetime import datetime, timedelta

from flask import current_app
from openai import OpenAI
from PIL import Image
from pypdf import PdfReader

# pdf2image 用于将PDF转为图像
try:
    from pdf2image import convert_from_path
    PDF_TO_IMAGE_AVAILABLE = True
except ImportError:
    PDF_TO_IMAGE_AVAILABLE = False
|
||
|
||
# ==================== LLM客户端 ====================
|
||
class TranslationService:
    """Translate English PDFs into Chinese Markdown via an OpenAI-compatible chat API.

    Text is extracted with pypdf. When a PDF has (almost) no extractable text
    and the configured model looks vision-capable, pages are rendered with
    pdf2image and OCR'd by the model before being translated.
    """

    # Substrings that mark a model name as vision-capable (matched case-insensitively).
    VISION_MODEL_MARKERS = (
        'vision', 'vlm', 'glm-4v', 'glm-4.6v', 'gpt-4-vision', 'gpt-4o', 'qwen-vl', 'claude-3',
    )

    # Patterns used by _clean_text, compiled once per process.
    _EXCESS_NEWLINES = re.compile(r'\n{3,}')
    _REPEATED_SPACES = re.compile(r' {2,}')
    _CONTROL_CHARS = re.compile(r'[\x00-\x08\x0b\x0c\x0e-\x1f]')

    def __init__(self, config):
        """Create the API client.

        Args:
            config: mapping with an 'LLM_CONFIG' dict. This class reads its
                'api_key', 'api_base', 'model', 'max_tokens', 'timeout' and
                'chunk_size' entries.
        """
        self.config = config
        self.llm_config = config['LLM_CONFIG']
        self.client = OpenAI(
            api_key=self.llm_config['api_key'],
            base_url=self.llm_config['api_base'],
        )

    def translate_text(self, text, instruction=None):
        """Translate one chunk of English text into Chinese.

        Args:
            text: English source text.
            instruction: optional user-supplied translation requirements; when
                given they replace the default "output only the translation" wording.

        Returns:
            The stripped model output, or ``text`` unchanged when the model
            returns nothing or the request raises. This is deliberate
            best-effort behavior: errors are printed, not propagated.
        """
        system_prompt = """你是一个专业的英译中翻译专家。请遵循以下规则:
1. 保持原文的格式和段落结构
2. 专业术语保持准确性,必要时保留英文原文
3. 语言流畅自然,符合中文表达习惯
4. 不要添加任何解释或注释,只输出翻译结果"""

        user_prompt = f"""请将以下英文翻译成中文。直接输出中文翻译,不要解释。

英文内容:
{text}"""

        if instruction:
            user_prompt = f"""请将以下英文翻译成中文。
用户翻译要求:{instruction}

英文内容:
{text}"""

        try:
            response = self.client.chat.completions.create(
                model=self.llm_config['model'],
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                max_tokens=self.llm_config['max_tokens'],
                temperature=0.3,
                timeout=self.llm_config['timeout'],
            )

            content = response.choices[0].message.content
            if content and content.strip():
                return content.strip()
            return text

        except Exception as e:
            print(f"翻译错误: {e}")
            return text

    def extract_pdf_text(self, pdf_path):
        """Extract and clean the text layer of every page.

        Returns:
            List of ``{'page': 1-based page number, 'text': cleaned text}``
            for pages whose text layer is not blank. Blank pages are skipped.
        """
        reader = PdfReader(pdf_path)
        pages = []
        for number, page in enumerate(reader.pages, start=1):
            # Guard against a None result so one odd page cannot abort extraction.
            text = page.extract_text() or ''
            if text.strip():
                pages.append({'page': number, 'text': self._clean_text(text)})
        return pages

    def _clean_text(self, text):
        """Collapse 3+ newlines to 2, squeeze repeated spaces, drop control chars, strip."""
        text = self._EXCESS_NEWLINES.sub('\n\n', text)
        text = self._REPEATED_SPACES.sub(' ', text)
        text = self._CONTROL_CHARS.sub('', text)
        return text.strip()

    def is_vision_model(self):
        """Return True if the configured model name contains a known vision-model marker."""
        # `or ''` also covers an explicit None stored in the config.
        model = (self.llm_config.get('model') or '').lower()
        return any(marker in model for marker in self.VISION_MODEL_MARKERS)

    def pdf_to_images(self, pdf_path, max_pages=None, dpi=200):
        """Render the first ``max_pages`` pages (all pages when falsy) to PIL images.

        Note that every rendered page is held in memory at once. For whole
        documents, prefer the page-by-page path used by
        extract_text_from_scanned_pdf.

        Returns:
            ``(images, None)`` on success, ``(None, error_message)`` on failure.
        """
        if not PDF_TO_IMAGE_AVAILABLE:
            return None, "pdf2image未安装,无法处理扫描版PDF。请安装: pip install pdf2image"

        try:
            total_pages = len(PdfReader(pdf_path).pages)
            pages_to_convert = min(max_pages, total_pages) if max_pages else total_pages
            images = convert_from_path(
                pdf_path,
                first_page=1,
                last_page=pages_to_convert,
                dpi=dpi,
                fmt='jpeg'
            )
            return images, None
        except Exception as e:
            return None, f"PDF转图像失败: {str(e)}"

    def _render_page(self, pdf_path, page_number, dpi=200):
        """Render a single 1-based page to a PIL image, so only one page is in memory."""
        images = convert_from_path(
            pdf_path,
            first_page=page_number,
            last_page=page_number,
            dpi=dpi,
            fmt='jpeg'
        )
        if not images:
            raise ValueError(f"第{page_number}页渲染结果为空")
        return images[0]

    def extract_text_from_image(self, image):
        """OCR a PIL image with the configured vision model.

        Returns:
            ``(text, None)`` on success. ``(None, message)`` if the model is
            not vision-capable. ``('', message)`` if the request fails.
        """
        if not self.is_vision_model():
            return None, "当前模型不是视觉模型,无法识别图像文字"

        try:
            if image.mode not in ('RGB', 'L'):
                # JPEG cannot encode alpha/palette images; normalise before saving.
                image = image.convert('RGB')
            buffered = io.BytesIO()
            image.save(buffered, format="JPEG")
            img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')

            # Multimodal request: text instruction plus the page as a data URL.
            response = self.client.chat.completions.create(
                model=self.llm_config['model'],
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": "请识别并提取这张图片中的所有文字内容。只输出提取的文字,不要添加任何解释或说明。保持原有的段落和格式。"
                            },
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/jpeg;base64,{img_base64}"
                                }
                            }
                        ]
                    }
                ],
                max_tokens=self.llm_config['max_tokens'],
                temperature=0.1,
                timeout=self.llm_config['timeout'],
            )

            content = response.choices[0].message.content
            return content.strip() if content else '', None

        except Exception as e:
            return '', f"视觉模型识别失败: {str(e)}"

    def extract_text_from_scanned_pdf(self, pdf_path, progress_callback=None):
        """OCR every page of a scanned PDF with the vision model, one page at a time.

        Args:
            pdf_path: input PDF.
            progress_callback: optional ``(progress, total, message)`` callable.
                Progress covers 0-50 across the pages.

        Returns:
            ``(pages, None)``, where each page is ``{'page', 'text', 'error'}``.
            Returns ``([], message)`` when rendering is unavailable, the PDF
            cannot be opened, or every page failed (the first page's error is
            reported so the real cause is not lost).
        """
        if not PDF_TO_IMAGE_AVAILABLE:
            return [], "pdf2image未安装,无法处理扫描版PDF。请安装: pip install pdf2image"

        try:
            total = len(PdfReader(pdf_path).pages)
        except Exception as e:
            return [], f"PDF转图像失败: {str(e)}"

        pages_text = []
        for page_number in range(1, total + 1):
            if progress_callback:
                progress_callback(int(page_number / total * 50), total, f"OCR识别第{page_number}页...")

            try:
                image = self._render_page(pdf_path, page_number)
            except Exception as e:
                text, err = '', f"PDF转图像失败: {str(e)}"
            else:
                text, err = self.extract_text_from_image(image)
                image.close()  # release the pixel buffer before rendering the next page

            pages_text.append({
                'page': page_number,
                'text': '' if err else (text or ''),
                'error': err
            })

        if pages_text and all(page['error'] for page in pages_text):
            return [], pages_text[0]['error']

        return pages_text, None

    def chunk_text(self, text, max_size=2000):
        """Split text into chunks of at most max_size characters.

        Paragraphs (separated by blank lines) are packed greedily into chunks
        and joined by blank lines. Blank paragraphs are dropped. A paragraph
        longer than max_size is broken at line breaks or spaces (hard-cut as a
        last resort), so no chunk exceeds the limit.

        Args:
            text: text to split. Empty or whitespace-only text yields ``[]``.
            max_size: maximum characters per chunk (values below 1 count as 1).

        Returns:
            List of non-empty chunk strings.
        """
        max_size = max(1, int(max_size))
        chunks = []
        current = ''
        for paragraph in (text or '').split('\n\n'):
            paragraph = paragraph.strip()
            if not paragraph:
                continue
            for piece in self._split_oversized(paragraph, max_size):
                candidate = f"{current}\n\n{piece}" if current else piece
                if len(candidate) <= max_size:
                    current = candidate
                else:
                    # piece alone always fits, so current is non-empty here.
                    chunks.append(current)
                    current = piece
        if current:
            chunks.append(current)
        return chunks

    @staticmethod
    def _split_oversized(paragraph, max_size):
        """Break one paragraph into pieces of at most max_size characters."""
        if len(paragraph) <= max_size:
            return [paragraph]

        def wrap_line(line):
            # Cut at the last space inside the limit so words stay intact;
            # fall back to a hard cut for unbroken runs (URLs, CJK text, ...).
            while len(line) > max_size:
                cut = line.rfind(' ', 1, max_size + 1)
                if cut == -1:
                    cut = max_size
                yield line[:cut].rstrip()
                line = line[cut:].lstrip()
            yield line

        pieces = []
        current = ''
        for line in paragraph.split('\n'):
            for part in wrap_line(line):
                if not part:
                    continue
                candidate = f"{current}\n{part}" if current else part
                if len(candidate) <= max_size:
                    current = candidate
                else:
                    pieces.append(current)
                    current = part
        if current:
            pieces.append(current)
        return pieces

    def translate_pdf(self, pdf_path, output_path, instruction=None, progress_callback=None):
        """Translate a PDF and write the result as Markdown.

        Falls back to vision-model OCR when the text layer is (nearly) empty.

        Args:
            pdf_path: input PDF path.
            output_path: Markdown output path.
            instruction: optional user translation requirements.
            progress_callback: optional ``(progress, total, message)`` callable.
                Extraction/OCR covers 0-50 and translation covers 50-100.

        Returns:
            ``{'total_pages', 'total_chunks', 'output_path'}``.

        Raises:
            ValueError: no text could be extracted, or OCR failed.
        """
        def report(progress, total, message):
            if progress_callback:
                progress_callback(progress, total, message)

        # Try the regular text layer first.
        pages = self.extract_pdf_text(pdf_path)
        total_pages = len(pages)
        total_text = sum(len(p['text']) for p in pages)

        # No usable text layer: fall back to vision-model OCR if possible.
        if total_pages == 0 or total_text < 10:
            if not (self.is_vision_model() and PDF_TO_IMAGE_AVAILABLE):
                error_msg = "PDF无法提取文本内容。可能原因:\n1. PDF是扫描版(图像形式)\n2. 当前大模型不是视觉模型,无法识别图像文字\n\n如需处理扫描版PDF,请配置视觉大模型(如 glm-4.6v、gpt-4-vision)"
                report(0, 0, error_msg)
                raise ValueError(error_msg)

            report(0, 0, "检测到扫描版PDF,使用视觉模型OCR...")
            pages, error = self.extract_text_from_scanned_pdf(pdf_path, progress_callback)
            if error:
                raise ValueError(error)

            total_pages = len(pages)
            total_text = sum(len(p['text']) for p in pages)
            if total_text < 10:
                raise ValueError("视觉模型OCR未能提取到有效文字内容")

            report(50, total_pages, "OCR完成,开始翻译...")

        report(50, total_pages, "开始翻译...")

        chunk_size = self.llm_config.get('chunk_size', 2000)
        translated_pages = []
        total_chunks = 0

        for page_index, page_data in enumerate(pages):
            chunks = self.chunk_text(page_data['text'], chunk_size)
            total_chunks += len(chunks)

            translated_chunks = []
            for i, chunk in enumerate(chunks):
                translated_chunks.append(self.translate_text(chunk, instruction))
                if progress_callback:
                    # Pages finished so far plus the fraction of this page; maps onto 50-100.
                    done = page_index + (i + 1) / len(chunks)
                    report(50 + int(done / total_pages * 50), total_pages, f"翻译第{page_data['page']}页")

            if chunks:
                translated = '\n\n'.join(translated_chunks)
            elif page_data.get('error'):
                # Make OCR failures visible instead of emitting a silently blank page.
                translated = f"[OCR识别失败: {page_data['error']}]"
            else:
                translated = ''

            translated_pages.append({
                'page': page_data['page'],
                'original': page_data['text'],
                'translated': translated
            })

        self._save_output(translated_pages, output_path)
        report(100, total_pages, "翻译完成")

        return {
            'total_pages': total_pages,
            'total_chunks': total_chunks,
            'output_path': output_path
        }

    def _save_output(self, pages, output_path):
        """Write the translated pages to output_path as UTF-8 Markdown, one section per page."""
        parts = ["# 英文PDF中文翻译\n\n> 自动翻译生成\n\n---\n\n"]
        for page in pages:
            parts.append(f"## 第 {page['page']} 页\n\n")
            parts.append(page['translated'] + "\n\n---\n\n")

        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(''.join(parts))

    def save_comparison(self, pages, output_path):
        """Write a side-by-side (original + translation) Markdown file."""
        parts = ["# 英文PDF翻译对比\n\n---\n\n"]
        for page in pages:
            original = page['original']
            # Use a fence longer than any backtick run in the text so it cannot close early.
            fence = '```'
            while fence in original:
                fence += '`'
            parts.append(f"## 第 {page['page']} 页\n\n")
            parts.append(f"### 原文\n\n{fence}\n{original}\n{fence}\n\n")
            parts.append("### 译文\n\n" + page['translated'] + "\n\n---\n\n")

        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(''.join(parts))
|
||
|
||
|
||
# ==================== 缓存服务 ====================
|
||
class CacheService:
    """File-based cache of translation results, stored as ``<file_hash>.md`` in cache_dir."""

    def __init__(self, cache_dir, expire_days=30):
        """
        Args:
            cache_dir: directory holding cache files (created if missing).
            expire_days: age in days after which get_cache evicts an entry.
        """
        self.cache_dir = cache_dir
        self.expire_days = expire_days

        # exist_ok avoids the exists()/makedirs() race when several workers start at once.
        os.makedirs(cache_dir, exist_ok=True)

    def _cache_path(self, file_hash):
        """Return the cache file path for file_hash."""
        return os.path.join(self.cache_dir, f"{file_hash}.md")

    def compute_hash(self, file_content):
        """Return the MD5 hex digest of file_content (bytes); used as a cache key only."""
        return hashlib.md5(file_content).hexdigest()

    def get_cache(self, file_hash, db_model=None):
        """
        Look up a cached translation, evicting it if it has expired.

        Args:
            file_hash: key produced by compute_hash.
            db_model: optional model class. If given, the record filtered by
                ``file_hash`` has ``increment_hit()`` called on a hit.

        Returns:
            Path of the cache file, or None if missing or expired.
        """
        cache_file = self._cache_path(file_hash)

        try:
            mtime = os.path.getmtime(cache_file)
        except OSError:
            # Missing (or evicted concurrently): treat as a miss.
            return None

        if datetime.now() - datetime.fromtimestamp(mtime) > timedelta(days=self.expire_days):
            try:
                os.remove(cache_file)
            except FileNotFoundError:
                pass  # another worker already evicted it
            return None

        # Update the hit counter.
        if db_model:
            cache_record = db_model.query.filter_by(file_hash=file_hash).first()
            if cache_record:
                cache_record.increment_hit()

        return cache_file

    def save_cache(self, file_hash, content):
        """
        Write content to the cache atomically and return the cache file path.

        The content goes to a per-process/per-thread temp file first and is
        then moved into place with os.replace. Concurrent readers therefore
        never observe a partially written cache file.
        """
        cache_file = self._cache_path(file_hash)
        tmp_file = f"{cache_file}.{os.getpid()}.{threading.get_ident()}.tmp"
        try:
            with open(tmp_file, 'w', encoding='utf-8') as f:
                f.write(content)
            os.replace(tmp_file, cache_file)
        except BaseException:
            try:
                os.remove(tmp_file)
            except OSError:
                pass
            raise
        return cache_file

    def check_cache_exists(self, file_hash):
        """Return True if a cache file exists for file_hash (expiry is not checked)."""
        return os.path.exists(self._cache_path(file_hash))
|
||
|
||
|
||
# ==================== 异步翻译任务 ====================
|
||
class TranslationTask:
    """In-process registry of background PDF translation jobs.

    Task state lives in the class-level ``tasks`` dict, guarded by ``lock``.
    Each job runs on its own thread. When an app and translation_id are
    supplied, the job also mirrors status and progress into the
    ``Translation`` DB row.
    """

    tasks = {}  # task_id -> task state dict
    lock = threading.Lock()

    @classmethod
    def create_task(cls, task_id, pdf_path, output_path, config, instruction=None, translation_id=None, app=None):
        """Register a task and start translating it on a background thread.

        Args:
            task_id: key under which the task is stored in ``tasks``.
            pdf_path: input PDF.
            output_path: Markdown file the translation is written to.
            config: service config used when ``app`` is None. With an app, the
                LLM config is re-read via ``admin.get_llm_config()``.
            instruction: optional user translation requirements.
            translation_id: optional ``Translation`` primary key to keep in sync
                (only used together with ``app``).
            app: optional Flask app providing the context for config/DB access.

        Returns:
            ``task_id``. Any exception inside the worker marks the task 'failed';
            otherwise it ends as 'completed'.
        """
        task = {
            'id': task_id,
            'status': 'pending',
            'progress': 0,
            'message': '等待开始',
            'output_path': output_path,
            'error': None,
            'started_at': None,
            'completed_at': None,
            'translation_id': translation_id,
        }

        with cls.lock:
            cls.tasks[task_id] = task

        def set_task(**fields):
            # All task mutations go through the lock so readers see consistent state.
            with cls.lock:
                task.update(fields)

        def update_db(**fields):
            # Best-effort mirror into the Translation row; a DB hiccup must not
            # change the outcome of the translation itself.
            if not (app and translation_id):
                return
            with app.app_context():
                from models import db, Translation
                try:
                    trans = Translation.query.get(translation_id)
                    if trans:
                        for name, value in fields.items():
                            setattr(trans, name, value)
                        db.session.commit()
                except Exception as db_error:
                    db.session.rollback()
                    print(f"[翻译任务] 数据库更新失败: {db_error}")

        def progress_callback(progress, total, message):
            set_task(progress=progress, message=message)
            update_db(progress=progress)

        def run_translation():
            try:
                # Use a separate name: assigning to `config` here would make it local
                # to this function and raise UnboundLocalError whenever app is None.
                effective_config = config
                if app:
                    # Re-read the LLM config so admin changes apply to new tasks.
                    with app.app_context():
                        from admin import get_llm_config
                        effective_config = {'LLM_CONFIG': get_llm_config()}

                service = TranslationService(effective_config)
                set_task(status='processing', started_at=datetime.now().isoformat())
                print(f"[翻译任务] 开始翻译,使用配置: {effective_config.get('LLM_CONFIG', {}).get('api_base', '未知')}")
                update_db(status='processing')

                result = service.translate_pdf(
                    pdf_path, output_path, instruction, progress_callback
                )
            except Exception as e:
                set_task(status='failed', error=str(e), message=f'翻译失败: {e}')
                update_db(status='failed', error_message=str(e))
                return

            set_task(
                status='completed',
                progress=100,
                message='翻译完成',
                completed_at=datetime.now().isoformat(),
                result=result,
            )
            update_db(status='completed', progress=100, completed_at=datetime.now())

        thread = threading.Thread(target=run_translation)
        thread.start()

        return task_id

    @classmethod
    def get_task(cls, task_id):
        """Return the live task dict for task_id, or None if unknown."""
        with cls.lock:
            return cls.tasks.get(task_id)

    @classmethod
    def update_task(cls, task_id, **kwargs):
        """Merge kwargs into an existing task; unknown task ids are ignored."""
        with cls.lock:
            if task_id in cls.tasks:
                cls.tasks[task_id].update(kwargs)