"""
|
||
文档索引服务
|
||
使用LLM分析文档并构建索引
|
||
"""
|
||
|
||
import os
import re
import json
import math
from datetime import datetime
from collections import Counter
from openai import OpenAI
from flask import current_app

from config import LLM_CONFIG, DOC_CONFIG, INDEX_CONFIG
from models import db, Document, DocumentChunk, InvertedIndex, IndexStats, QueryLog


def get_llm_config():
    """Return the effective LLM configuration (user overrides applied dynamically)."""
    user_config_file = os.path.join(os.path.dirname(__file__), 'user_config.json')
    if os.path.exists(user_config_file):
        with open(user_config_file, 'r', encoding='utf-8') as f:
            user_config = json.load(f)
        return {**LLM_CONFIG, **user_config.get('llm', {})}
    return LLM_CONFIG


def get_doc_config():
    """Return the effective document-processing configuration."""
    user_config_file = os.path.join(os.path.dirname(__file__), 'user_config.json')
    if os.path.exists(user_config_file):
        with open(user_config_file, 'r', encoding='utf-8') as f:
            user_config = json.load(f)
        return {**DOC_CONFIG, **user_config.get('doc', {})}
    return DOC_CONFIG


def get_index_config():
    """Return the effective index configuration."""
    user_config_file = os.path.join(os.path.dirname(__file__), 'user_config.json')
    if os.path.exists(user_config_file):
        with open(user_config_file, 'r', encoding='utf-8') as f:
            user_config = json.load(f)
        return {**INDEX_CONFIG, **user_config.get('index', {})}
    return INDEX_CONFIG
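
# A minimal sketch of the optional user_config.json these getters look for. The
# shape below is an assumption for illustration (the key names match how the
# merged configs are used elsewhere in this module; the values are placeholders):
#   {
#     "llm":   {"api_key": "sk-...", "api_base": "https://api.example.com/v1", "model": "gpt-4o-mini"},
#     "doc":   {"chunk_size": 800, "chunk_overlap": 100},
#     "index": {"bm25_k1": 1.5, "bm25_b": 0.75, "keyword_weight": 2.0, "content_weight": 1.0}
#   }
# Each section overrides the matching defaults from config.py.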


class LLMService:
    """Thin wrapper around the LLM API."""

    def __init__(self):
        pass  # Configuration is resolved per call, not at construction time.

    def _get_client(self):
        """Build an OpenAI-compatible client from the current configuration."""
        config = get_llm_config()
        return OpenAI(
            api_key=config['api_key'],
            base_url=config['api_base'],
        )

    def _get_config(self):
        """Return the current LLM configuration."""
        return get_llm_config()

    def analyze_document(self, content, title=None):
        """
        Analyze a document and extract its key information.

        Returns:
            dict: {summary, keywords, topics, entities, category}
        """
        prompt = f"""Analyze the following document and extract its key information.

{'Document title: ' + title if title else ''}

Document content (first 3000 characters):
{content[:3000]}

Return the following information as JSON:
{{
    "summary": "document summary (100-200 characters)",
    "keywords": ["keyword1", "keyword2", ... up to 10],
    "topics": ["topic1", "topic2", ... up to 5],
    "category": "document category",
    "entities": {{
        "persons": ["person names"],
        "organizations": ["organization names"],
        "locations": ["locations"],
        "dates": ["dates"],
        "others": ["other entities"]
    }}
}}

Return only the JSON, nothing else."""

        try:
            config = self._get_config()
            client = self._get_client()
            response = client.chat.completions.create(
                model=config['model'],
                messages=[{"role": "user", "content": prompt}],
                max_tokens=1000,
                temperature=0.3,
            )

            result = response.choices[0].message.content.strip()
            # Strip a possible markdown code fence around the JSON
            result = re.sub(r'^```json\s*', '', result)
            result = re.sub(r'\s*```$', '', result)

            return json.loads(result)

        except Exception as e:
            print(f"LLM analysis failed: {e}")
            return {
                "summary": "",
                "keywords": [],
                "topics": [],
                "category": "",
                "entities": {}
            }

    def analyze_chunk(self, content):
        """
        Analyze a document chunk and extract its keywords.

        Returns:
            dict: {summary, keywords, topics}
        """
        prompt = f"""Analyze the following text fragment and extract its key information.

Text:
{content[:1500]}

Return as JSON:
{{
    "summary": "fragment summary (within 50 characters)",
    "keywords": ["keyword", ... up to 8],
    "topics": ["topic", ... up to 3]
}}

Return only the JSON."""

        try:
            config = self._get_config()
            client = self._get_client()
            response = client.chat.completions.create(
                model=config['model'],
                messages=[{"role": "user", "content": prompt}],
                max_tokens=500,
                temperature=0.3,
            )

            result = response.choices[0].message.content.strip()
            result = re.sub(r'^```json\s*', '', result)
            result = re.sub(r'\s*```$', '', result)

            return json.loads(result)

        except Exception:
            return {"summary": "", "keywords": [], "topics": []}

    def process_query(self, query):
        """
        Process a query, extracting intent and keywords.

        Returns:
            dict: {intent, keywords, expanded_terms, entities}
        """
        prompt = f"""Analyze the following query and extract the search intent and keywords.

Query: {query}

Return as JSON:
{{
    "intent": "query intent (e.g. look up information, compare, explain, list)",
    "keywords": ["main keyword", ... up to 10],
    "expanded_terms": ["synonyms or related terms", ... up to 5],
    "entities": {{
        "persons": [],
        "organizations": [],
        "locations": [],
        "dates": [],
        "others": []
    }}
}}

Return only the JSON."""

        try:
            config = self._get_config()
            client = self._get_client()
            response = client.chat.completions.create(
                model=config['model'],
                messages=[{"role": "user", "content": prompt}],
                max_tokens=500,
                temperature=0.3,
            )

            result = response.choices[0].message.content.strip()
            result = re.sub(r'^```json\s*', '', result)
            result = re.sub(r'\s*```$', '', result)

            return json.loads(result)

        except Exception:
            return {
                "intent": "search",
                "keywords": query.split(),
                "expanded_terms": [],
                "entities": {}
            }
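
# Hypothetical usage sketch for LLMService (assumes a valid API key and base URL
# in LLM_CONFIG or user_config.json; `text` stands for any document string):
#   llm = LLMService()
#   info = llm.analyze_document(text, title="Quarterly report")
#   print(info["summary"], info["keywords"])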


class DocumentIndexer:
    """Builds the index for uploaded documents."""

    def __init__(self):
        self.llm = LLMService()

    def _get_chunk_config(self):
        """Return the chunking configuration (size, overlap)."""
        config = get_doc_config()
        return config['chunk_size'], config['chunk_overlap']

    def index_document(self, doc_id):
        """
        Index a single document.

        Args:
            doc_id: document ID

        Returns:
            bool: True on success
        """
        doc = Document.query.get(doc_id)
        if not doc:
            return False

        try:
            doc.status = 'processing'
            db.session.commit()

            # Read the document content
            content = self._read_document(doc.filepath)
            if not content:
                raise Exception("Unable to read document content")

            # Store the raw text
            doc.content = content
            doc.word_count = len(content)  # character count, used as document length

            # Analyze the whole document with the LLM
            print(f"  Analyzing document: {doc.filename}")
            analysis = self.llm.analyze_document(content, doc.title)

            doc.summary = analysis.get('summary', '')
            doc.set_keywords(analysis.get('keywords', []))
            doc.set_topics(analysis.get('topics', []))
            doc.category = analysis.get('category', '')
            doc.set_entities(analysis.get('entities', {}))

            # Split into chunks
            chunks = self._split_content(content)
            doc.chunk_count = len(chunks)

            # Remove chunks from any previous indexing run
            DocumentChunk.query.filter_by(document_id=doc.id).delete()

            # Index each chunk
            for i, chunk_content in enumerate(chunks):
                chunk = DocumentChunk(
                    document_id=doc.id,
                    chunk_index=i,
                    content=chunk_content,
                    start_char=0,
                    end_char=len(chunk_content)
                )

                # Analyze the chunk with the LLM
                chunk_analysis = self.llm.analyze_chunk(chunk_content)
                chunk.summary = chunk_analysis.get('summary', '')
                chunk.set_keywords(chunk_analysis.get('keywords', []))

                # Compute term frequencies
                term_freq = self._compute_term_freq(chunk_content)
                chunk.set_term_freq(term_freq)

                db.session.add(chunk)

            db.session.commit()

            # Update the inverted index
            self._update_inverted_index(doc.id)

            # Mark as done
            doc.status = 'indexed'
            doc.indexed_at = datetime.utcnow()
            db.session.commit()

            # Refresh index statistics
            IndexStats.get_stats().update_stats()

            print(f"  ✓ Document indexed: {doc.filename}")
            return True

        except Exception as e:
            doc.status = 'failed'
            doc.error_message = str(e)
            db.session.commit()
            print(f"  ✗ Indexing failed: {e}")
            return False

    def _read_document(self, filepath):
        """Read the document content from disk."""
        ext = os.path.splitext(filepath)[1].lower()

        # Try reading as plain text (this also covers files without an extension),
        # but skip formats that need a dedicated parser.
        if ext not in ('.pdf', '.docx'):
            try:
                with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
                    content = f.read()
                if content.strip():  # readable content found
                    return content
            except Exception:
                pass

        # Handle specific formats by extension
        if ext in ['.txt', '.md', '.json', '.html']:
            with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
                return f.read()

        elif ext == '.pdf':
            try:
                from pypdf import PdfReader
                reader = PdfReader(filepath)
                text = ''
                for page in reader.pages:
                    text += page.extract_text() + '\n'
                return text
            except Exception:
                pass

        elif ext == '.docx':
            try:
                from docx import Document as DocxDocument
                doc = DocxDocument(filepath)
                return '\n'.join([p.text for p in doc.paragraphs])
            except Exception:
                pass

        # Last resort: read raw bytes and try several encodings
        try:
            with open(filepath, 'rb') as f:
                content = f.read()
            for encoding in ['utf-8', 'gbk', 'gb2312', 'latin-1']:
                try:
                    return content.decode(encoding)
                except UnicodeDecodeError:
                    continue
        except Exception:
            pass

        return None

    def _split_content(self, content):
        """
        Split the content into chunks.

        Args:
            content: document text

        Returns:
            list: list of content chunks
        """
        chunk_size, _ = self._get_chunk_config()  # chunk_overlap is currently unused
        chunks = []

        # Split on paragraph boundaries
        paragraphs = content.split('\n\n')

        current_chunk = ""
        for para in paragraphs:
            if len(current_chunk) + len(para) < chunk_size:
                current_chunk += para + '\n\n'
            else:
                if current_chunk.strip():
                    chunks.append(current_chunk.strip())
                current_chunk = para + '\n\n'

        if current_chunk.strip():
            chunks.append(current_chunk.strip())

        return chunks if chunks else [content[:chunk_size]]
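
    # Worked example of the greedy paragraph packing above (illustrative numbers):
    # with chunk_size=1000 and three paragraphs of ~400 characters each, the first
    # two paragraphs are packed into one chunk (~800 chars) and the third starts a
    # new chunk, giving chunks of roughly 800 and 400 characters.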

    def _compute_term_freq(self, content):
        """Compute term frequencies for a chunk."""
        # Simple tokenization for mixed Chinese/English text:
        # Chinese runs become character bigrams plus single characters,
        # English is split on word boundaries.
        terms = []

        # Extract Chinese runs (naive bigrams; a real tokenizer such as jieba
        # could be used instead)
        chinese = re.findall(r'[\u4e00-\u9fff]+', content)
        for text in chinese:
            # Simple bigram segmentation
            if len(text) >= 2:
                for i in range(len(text) - 1):
                    terms.append(text[i:i+2])
            terms.extend(list(text))

        # Extract English words
        english = re.findall(r'[a-zA-Z]+', content.lower())
        terms.extend(english)

        # Count term frequencies
        return dict(Counter(terms))
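
    # Example of the tokenization above: _compute_term_freq("机器学习 is fun") returns
    # {'机器': 1, '器学': 1, '学习': 1, '机': 1, '器': 1, '学': 1, '习': 1, 'is': 1, 'fun': 1}.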

    def _update_inverted_index(self, doc_id):
        """Update the inverted index for one document."""
        chunks = DocumentChunk.query.filter_by(document_id=doc_id).all()

        # Collect all terms and their postings
        term_postings = {}

        index_config = get_index_config()

        for chunk in chunks:
            # Terms from the term-frequency map
            tf = chunk.get_term_freq()

            # Terms from the LLM-extracted keywords
            keywords = chunk.get_keywords()

            for kw in keywords:
                if kw not in term_postings:
                    term_postings[kw] = []
                term_postings[kw].append({
                    'doc_id': doc_id,
                    'chunk_id': chunk.id,
                    'tf': tf.get(kw, 1),
                    'weight': index_config['keyword_weight']
                })

            for term, freq in tf.items():
                if term not in term_postings:
                    term_postings[term] = []
                term_postings[term].append({
                    'doc_id': doc_id,
                    'chunk_id': chunk.id,
                    'tf': freq,
                    'weight': index_config['content_weight']
                })

        # Write back to the database
        for term, postings in term_postings.items():
            index = InvertedIndex.get_or_create(term)
            existing = index.get_postings()

            # Merge postings, dropping stale entries for this document
            existing = [p for p in existing if p['doc_id'] != doc_id]
            existing.extend(postings)

            index.set_postings(existing)

        db.session.commit()
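
# Hypothetical usage sketch for DocumentIndexer (assumes a Flask application
# context is active and a Document row with this ID already exists; `app` refers
# to whatever Flask app object the project creates):
#   with app.app_context():
#       ok = DocumentIndexer().index_document(doc_id=1)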


class SearchEngine:
    """Keyword search engine over the inverted index."""

    def __init__(self):
        self.llm = LLMService()

    def _get_bm25_params(self):
        """Return the BM25 parameters (k1, b)."""
        config = get_index_config()
        return config['bm25_k1'], config['bm25_b']

    def search(self, query, top_k=10):
        """
        Search the indexed documents.

        Args:
            query: query string
            top_k: number of results to return

        Returns:
            list: search results [{doc, score, matched_chunks, matched_terms}]
        """
        start_time = datetime.now()

        # 1. Process the query with the LLM
        print(f"Processing query: {query}")
        query_analysis = self.llm.process_query(query)

        keywords = query_analysis.get('keywords', [])
        expanded = query_analysis.get('expanded_terms', [])

        # Merge keywords and expansion terms
        all_terms = keywords + expanded

        print(f"  Keywords: {keywords}")
        print(f"  Expanded terms: {expanded}")

        # 2. Retrieve candidate documents
        results = self._retrieve(all_terms)

        # 3. Score with BM25
        scored_results = self._score_results(results, all_terms)

        # 4. Sort by score
        scored_results.sort(key=lambda x: x['score'], reverse=True)

        # 5. Keep the top_k results
        final_results = scored_results[:top_k]

        retrieval_time = (datetime.now() - start_time).total_seconds()

        # 6. Log the query
        self._log_query(query, query_analysis, final_results, retrieval_time)

        return final_results

    def _retrieve(self, terms):
        """
        Retrieve documents containing the query terms.

        Args:
            terms: list of query terms

        Returns:
            dict: {doc_id: {chunks: set, terms: dict, postings: list}}
        """
        results = {}

        for term in terms:
            # Look the term up in the inverted index (substring match)
            index = InvertedIndex.query.filter(
                InvertedIndex.term.ilike(f'%{term}%')
            ).all()

            for idx in index:
                postings = idx.get_postings()

                for p in postings:
                    doc_id = p['doc_id']

                    if doc_id not in results:
                        results[doc_id] = {
                            'chunks': set(),
                            'terms': {},
                            'postings': []
                        }

                    results[doc_id]['chunks'].add(p['chunk_id'])
                    results[doc_id]['terms'][term] = results[doc_id]['terms'].get(term, 0) + p.get('tf', 1)
                    results[doc_id]['postings'].append({
                        'term': term,
                        'chunk_id': p['chunk_id'],
                        'tf': p.get('tf', 1),
                        'weight': p.get('weight', 1.0)
                    })

        return results

    def _score_results(self, results, query_terms):
        """
        Score retrieved documents with BM25.

        Args:
            results: retrieval results from _retrieve
            query_terms: query terms

        Returns:
            list: [{doc, score, matched_chunks, matched_terms}]
        """
        scored = []

        # Average document length over the indexed collection
        total_docs = Document.query.filter_by(status='indexed').count()
        if total_docs == 0:
            return []

        avg_doc_len = db.session.query(
            db.func.avg(Document.word_count)
        ).filter(Document.status == 'indexed').scalar() or 1000

        for doc_id, data in results.items():
            doc = Document.query.get(doc_id)
            if not doc or doc.status != 'indexed':
                continue

            # BM25 scoring
            k1, b = self._get_bm25_params()
            score = 0
            doc_len = doc.word_count or 1000

            for term in query_terms:
                # Document frequency from the inverted index
                index = InvertedIndex.query.filter_by(term=term).first()
                df = index.doc_freq if index else 1

                # IDF
                idf = math.log((total_docs - df + 0.5) / (df + 0.5) + 1)

                # TF
                tf = data['terms'].get(term, 0)

                # BM25 term score
                tf_component = (tf * (k1 + 1)) / (
                    tf + k1 * (1 - b + b * doc_len / avg_doc_len)
                )

                score += idf * tf_component

            # Fetch the matched chunk contents
            chunk_ids = list(data['chunks'])[:3]  # at most 3 chunks
            chunks = DocumentChunk.query.filter(DocumentChunk.id.in_(chunk_ids)).all()

            scored.append({
                'doc': doc.to_dict(),
                'score': score,
                'matched_chunks': [c.to_dict() for c in chunks],
                'matched_terms': list(data['terms'].keys())
            })

        return scored
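
    # Worked example of the BM25 term score above (illustrative numbers): with
    # k1=1.5, b=0.75, total_docs=10, df=2, tf=3 and doc_len == avg_doc_len,
    #   idf          = ln((10 - 2 + 0.5) / (2 + 0.5) + 1) = ln(4.4) ≈ 1.48
    #   tf_component = 3 * (1.5 + 1) / (3 + 1.5 * (1 - 0.75 + 0.75 * 1)) = 7.5 / 4.5 ≈ 1.67
    # so this term contributes about 1.48 * 1.67 ≈ 2.47 to the document score.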

    def _log_query(self, query, analysis, results, retrieval_time):
        """Record the query in the query log."""
        log = QueryLog(
            original_query=query,
            processed_query=' '.join(analysis.get('keywords', [])),
            expanded_terms=json.dumps(analysis.get('expanded_terms', [])),
            intent=analysis.get('intent'),
            entities=json.dumps(analysis.get('entities', {})),
            result_count=len(results),
            top_doc_ids=json.dumps([r['doc']['id'] for r in results[:5]]),
            retrieval_time=retrieval_time,
            total_time=retrieval_time
        )
        db.session.add(log)
        db.session.commit()
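
# Hypothetical usage sketch for SearchEngine (assumes a Flask application context
# and at least one indexed document):
#   with app.app_context():
#       for hit in SearchEngine().search("machine learning", top_k=5):
#           print(hit['score'], hit['doc']['filename'], hit['matched_terms'])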


class RAGGenerator:
    """Retrieval-augmented generation (RAG) answerer."""

    def __init__(self):
        self.llm = LLMService()
        self.search_engine = SearchEngine()

    def answer(self, query, top_k=5):
        """
        Answer a question with RAG.

        Args:
            query: user query
            top_k: number of documents to retrieve

        Returns:
            dict: {answer, sources, confidence}
        """
        # 1. Retrieve relevant documents
        results = self.search_engine.search(query, top_k)

        if not results:
            return {
                'answer': 'Sorry, no relevant information was found.',
                'sources': [],
                'confidence': 0
            }

        # 2. Build the context
        context_parts = []
        sources = []

        for i, r in enumerate(results[:3]):  # use at most 3 documents
            doc = r['doc']
            chunks = r['matched_chunks']

            context_parts.append(f"[Document {i+1}] {doc.get('title', doc['filename'])}")

            for chunk in chunks[:2]:  # at most 2 chunks per document
                context_parts.append(chunk.get('content', '')[:500])

            sources.append({
                'id': doc['id'],
                'title': doc.get('title', doc['filename']),
                'score': r['score']
            })

        context = '\n\n'.join(context_parts)

        # 3. Generate the answer with the LLM
        prompt = f"""Answer the question using the reference information below. If the references do not contain relevant content, say so.

Question: {query}

Reference information:
{context}

Give an accurate, concise answer and cite the information sources."""

        try:
            llm_config = get_llm_config()
            client = OpenAI(
                api_key=llm_config['api_key'],
                base_url=llm_config['api_base'],
            )

            response = client.chat.completions.create(
                model=llm_config['model'],
                messages=[{"role": "user", "content": prompt}],
                max_tokens=1000,
                temperature=0.5,
            )

            answer = response.choices[0].message.content

        except Exception as e:
            answer = f"Error while generating the answer: {e}"

        # 4. Return the result
        return {
            'answer': answer,
            'sources': sources,
            'confidence': min(1.0, results[0]['score'] / 10) if results else 0
        }
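

# Hypothetical end-to-end sketch (assumes the Flask app and database are set up
# elsewhere in the project; `app` is that application object):
#   with app.app_context():
#       result = RAGGenerator().answer("What does the Q3 report say about revenue?")
#       print(result['answer'])
#       for src in result['sources']:
#           print(' -', src['title'], src['score'])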