"""
文档索引服务
使用LLM分析文档并构建索引
"""
import os
import re
import json
import math
from datetime import datetime
from collections import Counter
from openai import OpenAI
from flask import current_app
from config import LLM_CONFIG, DOC_CONFIG, INDEX_CONFIG
from models import db, Document, DocumentChunk, InvertedIndex, IndexStats, QueryLog
def get_llm_config():
"""获取有效的LLM配置支持动态更新"""
user_config_file = os.path.join(os.path.dirname(__file__), 'user_config.json')
if os.path.exists(user_config_file):
with open(user_config_file, 'r', encoding='utf-8') as f:
user_config = json.load(f)
return {**LLM_CONFIG, **user_config.get('llm', {})}
return LLM_CONFIG
def get_doc_config():
"""获取有效的文档配置"""
user_config_file = os.path.join(os.path.dirname(__file__), 'user_config.json')
if os.path.exists(user_config_file):
with open(user_config_file, 'r', encoding='utf-8') as f:
user_config = json.load(f)
return {**DOC_CONFIG, **user_config.get('doc', {})}
return DOC_CONFIG
def get_index_config():
"""获取有效的索引配置"""
user_config_file = os.path.join(os.path.dirname(__file__), 'user_config.json')
if os.path.exists(user_config_file):
with open(user_config_file, 'r', encoding='utf-8') as f:
user_config = json.load(f)
return {**INDEX_CONFIG, **user_config.get('index', {})}
return INDEX_CONFIG
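# Illustrative shape of user_config.json (an assumption based on the keys read
# above and elsewhere in this module; the actual defaults live in config.py,
# which is not shown here):
# {
#     "llm":   {"api_key": "...", "api_base": "...", "model": "..."},
#     "doc":   {"chunk_size": 1000, "chunk_overlap": 100},
#     "index": {"bm25_k1": 1.5, "bm25_b": 0.75,
#               "keyword_weight": 2.0, "content_weight": 1.0}
# }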
class LLMService:
"""LLM服务封装"""
def __init__(self):
pass  # configuration is resolved lazily rather than at init time
def _get_client(self):
"""获取LLM客户端"""
config = get_llm_config()
return OpenAI(
api_key=config['api_key'],
base_url=config['api_base'],
)
def _get_config(self):
"""获取当前配置"""
return get_llm_config()
def analyze_document(self, content, title=None):
"""
分析文档,提取关键信息
Returns:
dict: {summary, keywords, topics, entities, category}
"""
prompt = f"""请分析以下文档内容,提取关键信息。
{'文档标题:' + title if title else ''}
文档内容前3000字
{content[:3000]}
请以JSON格式返回以下信息
{{
"summary": "文档摘要100-200字",
"keywords": ["关键词1", "关键词2", ...最多10个],
"topics": ["主题1", "主题2", ...最多5个],
"category": "文档分类",
"entities": {{
"persons": ["人名"],
"organizations": ["组织名"],
"locations": ["地点"],
"dates": ["日期"],
"others": ["其他实体"]
}}
}}
只返回JSON不要其他内容。"""
try:
config = self._get_config()
client = self._get_client()
response = client.chat.completions.create(
model=config['model'],
messages=[{"role": "user", "content": prompt}],
max_tokens=1000,
temperature=0.3,
)
result = response.choices[0].message.content.strip()
# Strip any markdown code fences the model may have added
result = re.sub(r'^```(?:json)?\s*', '', result)
result = re.sub(r'\s*```$', '', result)
return json.loads(result)
except Exception as e:
print(f"LLM分析失败: {e}")
return {
"summary": "",
"keywords": [],
"topics": [],
"category": "",
"entities": {}
}
def analyze_chunk(self, content):
"""
分析文档块,提取关键词
Returns:
dict: {summary, keywords, topics}
"""
prompt = f"""分析以下文本片段,提取关键信息。
文本:
{content[:1500]}
请以JSON格式返回
{{
"summary": "片段摘要50字以内",
"keywords": ["关键词", ...最多8个],
"topics": ["主题", ...最多3个]
}}
只返回JSON。"""
try:
config = self._get_config()
client = self._get_client()
response = client.chat.completions.create(
model=config['model'],
messages=[{"role": "user", "content": prompt}],
max_tokens=500,
temperature=0.3,
)
result = response.choices[0].message.content.strip()
result = re.sub(r'^```(?:json)?\s*', '', result)
result = re.sub(r'\s*```$', '', result)
return json.loads(result)
except Exception as e:
return {"summary": "", "keywords": [], "topics": []}
def process_query(self, query):
"""
处理查询,提取意图和关键词
Returns:
dict: {intent, keywords, expanded_terms, entities}
"""
prompt = f"""分析以下查询,提取搜索意图和关键词。
查询:{query}
请以JSON格式返回
{{
"intent": "查询意图(如:查找信息、比较、解释、列表等)",
"keywords": ["主要关键词", ...最多10个],
"expanded_terms": ["同义词或相关词", ...最多5个],
"entities": {{
"persons": [],
"organizations": [],
"locations": [],
"dates": [],
"others": []
}}
}}
只返回JSON。"""
try:
config = self._get_config()
client = self._get_client()
response = client.chat.completions.create(
model=config['model'],
messages=[{"role": "user", "content": prompt}],
max_tokens=500,
temperature=0.3,
)
result = response.choices[0].message.content.strip()
result = re.sub(r'^```(?:json)?\s*', '', result)
result = re.sub(r'\s*```$', '', result)
return json.loads(result)
except Exception as e:
return {
"intent": "search",
"keywords": query.split(),
"expanded_terms": [],
"entities": {}
}
class DocumentIndexer:
"""文档索引器"""
def __init__(self):
self.llm = LLMService()
def _get_chunk_config(self):
"""获取分块配置"""
config = get_doc_config()
return config['chunk_size'], config['chunk_overlap']
def index_document(self, doc_id):
"""
索引单个文档
Args:
doc_id: 文档ID
Returns:
bool: 是否成功
"""
doc = Document.query.get(doc_id)
if not doc:
return False
try:
doc.status = 'processing'
db.session.commit()
# Read the document content
content = self._read_document(doc.filepath)
if not content:
raise Exception("无法读取文档内容")
# Store the raw text
doc.content = content
doc.word_count = len(content)
# Analyze the whole document with the LLM
print(f" Analyzing document: {doc.filename}")
analysis = self.llm.analyze_document(content, doc.title)
doc.summary = analysis.get('summary', '')
doc.set_keywords(analysis.get('keywords', []))
doc.set_topics(analysis.get('topics', []))
doc.category = analysis.get('category', '')
doc.set_entities(analysis.get('entities', {}))
# Split into chunks
chunks = self._split_content(content)
doc.chunk_count = len(chunks)
# Remove stale chunks
DocumentChunk.query.filter_by(document_id=doc.id).delete()
# Index each chunk
for i, chunk_content in enumerate(chunks):
chunk = DocumentChunk(
document_id=doc.id,
chunk_index=i,
content=chunk_content,
start_char=0,
end_char=len(chunk_content)
)
# Analyze the chunk with the LLM
chunk_analysis = self.llm.analyze_chunk(chunk_content)
chunk.summary = chunk_analysis.get('summary', '')
chunk.set_keywords(chunk_analysis.get('keywords', []))
# Compute term frequencies
term_freq = self._compute_term_freq(chunk_content)
chunk.set_term_freq(term_freq)
db.session.add(chunk)
db.session.commit()
# Update the inverted index
self._update_inverted_index(doc.id)
# Mark as complete
doc.status = 'indexed'
doc.indexed_at = datetime.utcnow()
db.session.commit()
# Refresh statistics
IndexStats.get_stats().update_stats()
print(f" ✓ 文档索引完成: {doc.filename}")
return True
except Exception as e:
db.session.rollback()  # clear any failed transaction before recording the error
doc.status = 'failed'
doc.error_message = str(e)
db.session.commit()
print(f" ✗ Indexing failed: {e}")
return False
def _read_document(self, filepath):
"""读取文档内容"""
ext = os.path.splitext(filepath)[1].lower()
# Try reading as a plain text file (including files without an extension);
# skip known binary formats so they reach their dedicated parsers below
if ext not in ['.pdf', '.docx']:
try:
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
if content.strip():  # readable content found
return content
except Exception:
pass
# Handle specific formats by extension
if ext in ['.txt', '.md', '.json', '.html']:
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
return f.read()
elif ext == '.pdf':
try:
from pypdf import PdfReader
reader = PdfReader(filepath)
text = ''
for page in reader.pages:
text += page.extract_text() + '\n'
return text
except Exception:
pass
elif ext == '.docx':
try:
from docx import Document as DocxDocument
doc = DocxDocument(filepath)
return '\n'.join([p.text for p in doc.paragraphs])
except Exception:
pass
# Finally, try reading as binary and decoding
try:
with open(filepath, 'rb') as f:
content = f.read()
# Try several encodings
for encoding in ['utf-8', 'gbk', 'gb2312', 'latin-1']:
try:
return content.decode(encoding)
except UnicodeDecodeError:
continue
except Exception:
pass
return None
def _split_content(self, content):
"""
分割内容为块
Args:
content: 文档内容
Returns:
list: 内容块列表
"""
chunk_size, _ = self._get_chunk_config()
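# Note: chunk_overlap is returned by the config but not applied by this
# paragraph-based splitter; consecutive chunks do not overlap.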
chunks = []
# Split on paragraph boundaries
paragraphs = content.split('\n\n')
current_chunk = ""
for para in paragraphs:
if len(current_chunk) + len(para) < chunk_size:
current_chunk += para + '\n\n'
else:
if current_chunk.strip():
chunks.append(current_chunk.strip())
current_chunk = para + '\n\n'
if current_chunk.strip():
chunks.append(current_chunk.strip())
return chunks if chunks else [content[:chunk_size]]
def _compute_term_freq(self, content):
"""计算词频"""
# 简单分词(中英文混合)
# 中文按字符,英文按空格
terms = []
# 提取中文词汇简单按字实际可用jieba
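# Illustrative example: the input "机器学习 machine learning" tokenizes to
#   ['机器', '器学', '学习', '机', '器', '学', '习', 'machine', 'learning']
# so the returned term-frequency dict maps each of those terms to 1.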
chinese = re.findall(r'[\u4e00-\u9fff]+', content)
for text in chinese:
# 简单的双字词分割
if len(text) >= 2:
for i in range(len(text) - 1):
terms.append(text[i:i+2])
terms.extend(list(text))
# 提取英文单词
english = re.findall(r'[a-zA-Z]+', content.lower())
terms.extend(english)
# Count term occurrences
return dict(Counter(terms))
def _update_inverted_index(self, doc_id):
"""更新倒排索引"""
chunks = DocumentChunk.query.filter_by(document_id=doc_id).all()
# Collect every term and its postings for this document
index_config = get_index_config()
term_postings = {}
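# Each posting entry stored below has the shape built in this method:
#   {'doc_id': <document id>, 'chunk_id': <chunk id>, 'tf': <term frequency>, 'weight': <field weight>}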
for chunk in chunks:
# Terms from the stored term frequencies
tf = chunk.get_term_freq()
# Terms from the LLM-extracted keywords
keywords = chunk.get_keywords()
for kw in keywords:
if kw not in term_postings:
term_postings[kw] = []
term_postings[kw].append({
'doc_id': doc_id,
'chunk_id': chunk.id,
'tf': tf.get(kw, 1),
'weight': index_config['keyword_weight']
})
for term, freq in tf.items():
if term not in term_postings:
term_postings[term] = []
term_postings[term].append({
'doc_id': doc_id,
'chunk_id': chunk.id,
'tf': freq,
'weight': index_config['content_weight']
})
# Update the database
for term, postings in term_postings.items():
index = InvertedIndex.get_or_create(term)
existing = index.get_postings()
# Merge postings, dropping stale entries for this document
existing = [p for p in existing if p['doc_id'] != doc_id]
existing.extend(postings)
index.set_postings(existing)
db.session.commit()
class SearchEngine:
"""搜索引擎"""
def __init__(self):
self.llm = LLMService()
def _get_bm25_params(self):
"""获取BM25参数"""
config = get_index_config()
return config['bm25_k1'], config['bm25_b']
def search(self, query, top_k=10):
"""
Search the indexed documents.
Args:
query: query string
top_k: number of results to return
Returns:
list: search results [{doc, score, matched_chunks, matched_terms}]
"""
start_time = datetime.now()
# 1. Process the query with the LLM
print(f"Processing query: {query}")
query_analysis = self.llm.process_query(query)
keywords = query_analysis.get('keywords', [])
expanded = query_analysis.get('expanded_terms', [])
# Merge keywords and expanded terms
all_terms = keywords + expanded
print(f" 关键词: {keywords}")
print(f" 扩展词: {expanded}")
# 2. Retrieve candidates
results = self._retrieve(all_terms)
# 3. Score with BM25
scored_results = self._score_results(results, all_terms)
# 4. Sort by score
scored_results.sort(key=lambda x: x['score'], reverse=True)
# 5. Keep the top_k results
final_results = scored_results[:top_k]
retrieval_time = (datetime.now() - start_time).total_seconds()
# 6. Log the query
self._log_query(query, query_analysis, final_results, retrieval_time)
return final_results
def _retrieve(self, terms):
"""
Retrieve documents that contain the query terms.
Args:
terms: list of query terms
Returns:
dict: {doc_id: {chunks: set, terms: dict, postings: list}}
"""
results = {}
for term in terms:
# Look up the inverted index (case-insensitive substring match on the term)
matching_entries = InvertedIndex.query.filter(
InvertedIndex.term.ilike(f'%{term}%')
).all()
for idx in matching_entries:
postings = idx.get_postings()
for p in postings:
doc_id = p['doc_id']
if doc_id not in results:
results[doc_id] = {
'chunks': set(),
'terms': {},
'postings': []
}
results[doc_id]['chunks'].add(p['chunk_id'])
results[doc_id]['terms'][term] = results[doc_id]['terms'].get(term, 0) + p.get('tf', 1)
results[doc_id]['postings'].append({
'term': term,
'chunk_id': p['chunk_id'],
'tf': p.get('tf', 1),
'weight': p.get('weight', 1.0)
})
return results
def _score_results(self, results, query_terms):
"""
Score results with BM25.
Args:
results: retrieval results
query_terms: query terms
Returns:
list: [{doc, score, matched_chunks, matched_terms}]
"""
scored = []
# Average document length over the indexed corpus
total_docs = Document.query.filter_by(status='indexed').count()
if total_docs == 0:
return []
avg_doc_len = db.session.query(
db.func.avg(Document.word_count)
).filter(Document.status == 'indexed').scalar() or 1000
# BM25 parameters (identical for every document, so fetch them once)
k1, b = self._get_bm25_params()
for doc_id, data in results.items():
doc = Document.query.get(doc_id)
if not doc or doc.status != 'indexed':
continue
# BM25 scoring for this document
score = 0
doc_len = doc.word_count or 1000
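# BM25 as implemented in the loop below (a sketch of the standard formula,
# using the per-document term counts accumulated in data['terms']):
#   idf(t)  = ln((N - df + 0.5) / (df + 0.5) + 1)
#   tf_part = tf * (k1 + 1) / (tf + k1 * (1 - b + b * doc_len / avg_doc_len))
#   score   = sum over query terms of idf(t) * tf_part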
for term in query_terms:
# Document frequency from the inverted index
index = InvertedIndex.query.filter_by(term=term).first()
df = index.doc_freq if index else 1
# IDF
idf = math.log((total_docs - df + 0.5) / (df + 0.5) + 1)
# TF
tf = data['terms'].get(term, 0)
# BM25 formula
tf_component = (tf * (k1 + 1)) / (
tf + k1 * (1 - b + b * doc_len / avg_doc_len)
)
score += idf * tf_component
# Fetch the matching chunk contents
chunk_ids = list(data['chunks'])[:3]  # use at most 3 chunks
chunks = DocumentChunk.query.filter(DocumentChunk.id.in_(chunk_ids)).all()
scored.append({
'doc': doc.to_dict(),
'score': score,
'matched_chunks': [c.to_dict() for c in chunks],
'matched_terms': list(data['terms'].keys())
})
return scored
def _log_query(self, query, analysis, results, retrieval_time):
"""记录查询日志"""
log = QueryLog(
original_query=query,
processed_query=' '.join(analysis.get('keywords', [])),
expanded_terms=json.dumps(analysis.get('expanded_terms', [])),
intent=analysis.get('intent'),
entities=json.dumps(analysis.get('entities', {})),
result_count=len(results),
top_doc_ids=json.dumps([r['doc']['id'] for r in results[:5]]),
retrieval_time=retrieval_time,
total_time=retrieval_time
)
db.session.add(log)
db.session.commit()
class RAGGenerator:
"""RAG生成器"""
def __init__(self):
self.llm = LLMService()
self.search_engine = SearchEngine()
def answer(self, query, top_k=5):
"""
Answer a query with retrieval-augmented generation.
Args:
query: user query
top_k: number of documents to retrieve
Returns:
dict: {answer, sources, confidence}
"""
# 1. Retrieve relevant documents
results = self.search_engine.search(query, top_k)
if not results:
return {
'answer': 'Sorry, no relevant information was found.',
'sources': [],
'confidence': 0
}
# 2. Build the context
context_parts = []
sources = []
for i, r in enumerate(results[:3]):  # use at most 3 documents
doc = r['doc']
chunks = r['matched_chunks']
context_parts.append(f"【文档{i+1}{doc.get('title', doc['filename'])}")
for chunk in chunks[:2]: # 每个文档最多2个chunk
context_parts.append(chunk.get('content', '')[:500])
sources.append({
'id': doc['id'],
'title': doc.get('title', doc['filename']),
'score': r['score']
})
context = '\n\n'.join(context_parts)
# 3. Generate the answer with the LLM
prompt = f"""Answer the question using the reference information below. If the reference information does not contain relevant content, say so.
Question: {query}
Reference information:
{context}
Give an accurate, concise answer and cite the information sources."""
try:
llm_config = get_llm_config()
client = OpenAI(
api_key=llm_config['api_key'],
base_url=llm_config['api_base'],
)
response = client.chat.completions.create(
model=llm_config['model'],
messages=[{"role": "user", "content": prompt}],
max_tokens=1000,
temperature=0.5,
)
answer = response.choices[0].message.content
except Exception as e:
answer = f"Error while generating the answer: {e}"
# 4. Return the result
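# Confidence below is a rough heuristic: the top BM25 score divided by a fixed
# constant of 10 and capped at 1.0; it is not a calibrated probability.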
return {
'answer': answer,
'sources': sources,
'confidence': min(1.0, results[0]['score'] / 10) if results else 0
}
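
# Minimal usage sketch (illustrative only; it assumes an application context from
# the project's Flask app and at least one uploaded document with id 1, neither of
# which is set up here):
#
#   indexer = DocumentIndexer()
#   indexer.index_document(1)
#
#   engine = SearchEngine()
#   hits = engine.search("machine learning", top_k=5)
#
#   rag = RAGGenerator()
#   result = rag.answer("What does the corpus say about machine learning?")
#   print(result['answer'], result['sources'], result['confidence'])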