V1.0.0: 基于索引的知识检索系统
核心功能:
- 文档索引:使用LLM分析提取关键词/摘要/主题/实体
- 查询处理:LLM分析查询意图并扩展关键词
- BM25检索:基于倒排索引的相关性排序
- RAG问答:检索增强生成

技术栈:
- Flask + SQLAlchemy
- OpenAI API兼容LLM
- BM25算法

特点: 不依赖向量模型和向量库
This commit is contained in:
655
services.py
Normal file
655
services.py
Normal file
@@ -0,0 +1,655 @@
|
||||
"""
|
||||
文档索引服务
|
||||
使用LLM分析文档并构建索引
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import math
|
||||
from datetime import datetime
|
||||
from collections import Counter
|
||||
from openai import OpenAI
|
||||
from flask import current_app
|
||||
|
||||
from config import LLM_CONFIG, DOC_CONFIG, INDEX_CONFIG
|
||||
from models import db, Document, DocumentChunk, InvertedIndex, IndexStats
|
||||
|
||||
|
||||
class LLMService:
    """Wrapper around an OpenAI-compatible chat-completions client.

    Exposes three structured-extraction calls (whole document, document
    chunk, search query). Each call asks the model for a JSON object and
    falls back to a fixed default dict when the request or the parsing
    fails, so callers can always use ``.get`` on the result.
    """

    def __init__(self):
        self.client = OpenAI(
            api_key=LLM_CONFIG['api_key'],
            base_url=LLM_CONFIG['api_base'],
        )
        self.model = LLM_CONFIG['model']
        self.max_tokens = LLM_CONFIG['max_tokens']
        self.temperature = LLM_CONFIG['temperature']

    @staticmethod
    def _parse_json_response(text):
        """Parse a model reply into a dict.

        Strips an optional leading Markdown fence (with or without a language
        tag such as ``json``) and an optional trailing fence before decoding.

        Args:
            text: Raw reply text; ``None`` is treated as an empty string.

        Returns:
            dict: The decoded JSON object.

        Raises:
            ValueError: If the text is not valid JSON (``json.JSONDecodeError``
                is a ``ValueError``) or decodes to something other than an
                object.
        """
        cleaned = (text or '').strip()
        # Models frequently wrap JSON in ```json ... ``` or plain ``` ... ```.
        cleaned = re.sub(r'^```[a-zA-Z]*\s*', '', cleaned)
        cleaned = re.sub(r'\s*```$', '', cleaned)

        parsed = json.loads(cleaned)
        if not isinstance(parsed, dict):
            # Callers rely on dict access; a list/scalar reply is unusable.
            raise ValueError(
                f"expected a JSON object, got {type(parsed).__name__}"
            )
        return parsed

    def _complete_json(self, prompt, max_tokens):
        """Send a single-turn prompt and return the reply parsed as a dict.

        Raises whatever the client raises, plus ``ValueError`` from
        ``_parse_json_response``; the public methods catch these.
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=0.3,
        )
        return self._parse_json_response(response.choices[0].message.content)

    def analyze_document(self, content, title=None):
        """Extract summary, keywords, topics, category and entities.

        Only the first 3000 characters of ``content`` are sent to the model.

        Args:
            content: Full document text.
            title: Optional document title included in the prompt.

        Returns:
            dict: ``{summary, keywords, topics, category, entities}``; empty
            values for every key when the call or the parsing fails.
        """
        prompt = f"""请分析以下文档内容,提取关键信息。

{'文档标题:' + title if title else ''}

文档内容(前3000字):
{content[:3000]}

请以JSON格式返回以下信息:
{{
"summary": "文档摘要(100-200字)",
"keywords": ["关键词1", "关键词2", ...最多10个],
"topics": ["主题1", "主题2", ...最多5个],
"category": "文档分类",
"entities": {{
"persons": ["人名"],
"organizations": ["组织名"],
"locations": ["地点"],
"dates": ["日期"],
"others": ["其他实体"]
}}
}}

只返回JSON,不要其他内容。"""

        try:
            return self._complete_json(prompt, max_tokens=1000)
        except Exception as e:
            # Best-effort: indexing continues with empty metadata.
            print(f"LLM分析失败: {e}")
            return {
                "summary": "",
                "keywords": [],
                "topics": [],
                "category": "",
                "entities": {}
            }

    def analyze_chunk(self, content):
        """Extract a short summary, keywords and topics from one chunk.

        Only the first 1500 characters of ``content`` are sent to the model.

        Returns:
            dict: ``{summary, keywords, topics}``; empty values when the call
            or the parsing fails.
        """
        prompt = f"""分析以下文本片段,提取关键信息。

文本:
{content[:1500]}

请以JSON格式返回:
{{
"summary": "片段摘要(50字以内)",
"keywords": ["关键词", ...最多8个],
"topics": ["主题", ...最多3个]
}}

只返回JSON。"""

        try:
            return self._complete_json(prompt, max_tokens=500)
        except Exception as e:
            # Report the failure instead of swallowing it silently.
            print(f"LLM分块分析失败: {e}")
            return {"summary": "", "keywords": [], "topics": []}

    def process_query(self, query):
        """Analyse a search query into intent, keywords and expansions.

        Returns:
            dict: ``{intent, keywords, expanded_terms, entities}``. On failure
            the intent is ``"search"`` and the keywords are the
            whitespace-separated tokens of ``query``.
        """
        prompt = f"""分析以下查询,提取搜索意图和关键词。

查询:{query}

请以JSON格式返回:
{{
"intent": "查询意图(如:查找信息、比较、解释、列表等)",
"keywords": ["主要关键词", ...最多10个],
"expanded_terms": ["同义词或相关词", ...最多5个],
"entities": {{
"persons": [],
"organizations": [],
"locations": [],
"dates": [],
"others": []
}}
}}

只返回JSON。"""

        try:
            return self._complete_json(prompt, max_tokens=500)
        except Exception as e:
            # Fall back to plain whitespace tokenisation so search still works.
            print(f"LLM查询处理失败: {e}")
            return {
                "intent": "search",
                "keywords": query.split(),
                "expanded_terms": [],
                "entities": {}
            }
|
||||
|
||||
|
||||
class DocumentIndexer:
    """Builds the searchable index for stored documents.

    For one document it reads the file, asks the LLM for document-level
    metadata, splits the text into chunks, stores per-chunk keywords and term
    frequencies, and merges the chunk terms into the inverted index.
    """

    def __init__(self):
        self.llm = LLMService()
        self.chunk_size = DOC_CONFIG['chunk_size']
        # NOTE(review): chunk_overlap is loaded but not applied by
        # _split_content; chunks do not overlap.
        self.chunk_overlap = DOC_CONFIG['chunk_overlap']

    def index_document(self, doc_id):
        """Index a single document.

        Args:
            doc_id: Primary key of the ``Document`` row.

        Returns:
            bool: ``True`` on success; ``False`` when the document does not
            exist or any step fails (the row is then marked ``failed`` with
            the error message).
        """
        doc = Document.query.get(doc_id)
        if not doc:
            return False

        try:
            doc.status = 'processing'
            db.session.commit()

            # Read the raw text from disk.
            content = self._read_document(doc.filepath)
            if not content:
                raise Exception("无法读取文档内容")

            # Keep the original text on the row.
            doc.content = content
            doc.word_count = len(content)

            # Document-level LLM analysis.
            print(f" 正在分析文档: {doc.filename}")
            analysis = self.llm.analyze_document(content, doc.title)

            doc.summary = analysis.get('summary', '')
            doc.set_keywords(analysis.get('keywords', []))
            doc.set_topics(analysis.get('topics', []))
            doc.category = analysis.get('category', '')
            doc.set_entities(analysis.get('entities', {}))

            # Chunking, with each chunk's real character span in `content`.
            chunks = self._split_content(content)
            spans = self._locate_chunks(content, chunks)
            doc.chunk_count = len(chunks)

            # Drop chunks left over from a previous indexing run.
            DocumentChunk.query.filter_by(document_id=doc.id).delete()

            for i, (chunk_content, (start, end)) in enumerate(zip(chunks, spans)):
                chunk = DocumentChunk(
                    document_id=doc.id,
                    chunk_index=i,
                    content=chunk_content,
                    start_char=start,
                    end_char=end
                )

                # Chunk-level LLM analysis.
                chunk_analysis = self.llm.analyze_chunk(chunk_content)
                chunk.summary = chunk_analysis.get('summary', '')
                chunk.set_keywords(chunk_analysis.get('keywords', []))

                # Raw term frequencies for BM25.
                term_freq = self._compute_term_freq(chunk_content)
                chunk.set_term_freq(term_freq)

                db.session.add(chunk)

            db.session.commit()

            self._update_inverted_index(doc.id)

            doc.status = 'indexed'
            doc.indexed_at = datetime.utcnow()
            db.session.commit()

            IndexStats.get_stats().update_stats()

            print(f" ✓ 文档索引完成: {doc.filename}")
            return True

        except Exception as e:
            # A failed flush/commit leaves the session unusable until it is
            # rolled back; without this the status update below would raise.
            db.session.rollback()
            doc.status = 'failed'
            doc.error_message = str(e)
            db.session.commit()
            print(f" ✗ 索引失败: {e}")
            return False

    def _read_document(self, filepath):
        """Return the text of a document, or ``None`` if it cannot be read.

        Plain-text formats (.txt/.md/.json/.html) are read as UTF-8 with
        undecodable bytes dropped. PDF and DOCX are parsed with the optional
        ``pypdf`` / ``python-docx`` packages; a missing package or a parse
        error is reported and yields ``None``. Unknown extensions yield
        ``None``.
        """
        ext = os.path.splitext(filepath)[1].lower()

        if ext in ['.txt', '.md', '.json', '.html']:
            with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
                return f.read()

        elif ext == '.pdf':
            try:
                from pypdf import PdfReader
                reader = PdfReader(filepath)
                # Guard against a None page text so one empty page does not
                # discard the whole document.
                return ''.join(
                    f"{page.extract_text() or ''}\n" for page in reader.pages
                )
            except Exception as e:
                print(f"PDF解析失败: {e}")

        elif ext == '.docx':
            try:
                from docx import Document as DocxDocument
                docx_file = DocxDocument(filepath)
                return '\n'.join(p.text for p in docx_file.paragraphs)
            except Exception as e:
                print(f"DOCX解析失败: {e}")

        return None

    def _split_content(self, content):
        """Split text into chunks of roughly ``self.chunk_size`` characters.

        Paragraphs (separated by a blank line) are packed greedily into a
        chunk until adding the next one would reach ``chunk_size``. A single
        paragraph longer than ``chunk_size`` is cut into ``chunk_size``-sized
        pieces so that no chunk grows without bound.

        Args:
            content: Document text.

        Returns:
            list: Non-empty, stripped chunk strings; if nothing survives
            stripping, a single element holding the first ``chunk_size``
            characters of ``content``.
        """
        chunks = []
        current_chunk = ""

        for para in content.split('\n\n'):
            if len(current_chunk) + len(para) < self.chunk_size:
                current_chunk += para + '\n\n'
                continue

            if current_chunk.strip():
                chunks.append(current_chunk.strip())

            # Hard-split an oversized paragraph (skipped for a non-positive
            # chunk_size, which would otherwise never terminate).
            while self.chunk_size > 0 and len(para) > self.chunk_size:
                piece = para[:self.chunk_size].strip()
                if piece:
                    chunks.append(piece)
                para = para[self.chunk_size:]

            current_chunk = para + '\n\n'

        if current_chunk.strip():
            chunks.append(current_chunk.strip())

        return chunks if chunks else [content[:self.chunk_size]]

    @staticmethod
    def _locate_chunks(content, chunks):
        """Return the ``(start, end)`` character span of each chunk.

        Chunks are searched for left to right, each search starting where the
        previous chunk ended. If a chunk cannot be found, its span is placed
        directly after the previous one.
        """
        spans = []
        cursor = 0
        for text in chunks:
            start = content.find(text, cursor)
            if start < 0:
                start = cursor
            end = start + len(text)
            spans.append((start, end))
            cursor = end
        return spans

    def _compute_term_freq(self, content):
        """Count terms in a chunk with a simple mixed Chinese/English tokenizer.

        Each run of CJK characters contributes all of its overlapping
        two-character pairs plus every single character; ASCII letter runs
        are lower-cased and counted as words.

        Returns:
            dict: term -> occurrence count.
        """
        terms = []

        for run in re.findall(r'[\u4e00-\u9fff]+', content):
            # Overlapping bigrams (none for a single-character run).
            terms.extend(run[i:i + 2] for i in range(len(run) - 1))
            terms.extend(run)

        terms.extend(re.findall(r'[a-zA-Z]+', content.lower()))

        return dict(Counter(terms))

    def _update_inverted_index(self, doc_id):
        """Merge the postings of one document's chunks into the inverted index.

        For every chunk, LLM keywords are added with the configured keyword
        weight and every counted term with the content weight. Existing
        postings of this document are replaced for each term seen now.

        NOTE(review): terms that appeared in an earlier version of the
        document but not in the current one are not visited here, so their
        old postings for this document stay in the index.
        """
        chunks = DocumentChunk.query.filter_by(document_id=doc_id).all()

        term_postings = {}

        for chunk in chunks:
            tf = chunk.get_term_freq()
            keywords = chunk.get_keywords()

            for kw in keywords:
                term_postings.setdefault(kw, []).append({
                    'doc_id': doc_id,
                    'chunk_id': chunk.id,
                    # A keyword the tokenizer never produced still counts once.
                    'tf': tf.get(kw, 1),
                    'weight': INDEX_CONFIG['keyword_weight']
                })

            for term, freq in tf.items():
                term_postings.setdefault(term, []).append({
                    'doc_id': doc_id,
                    'chunk_id': chunk.id,
                    'tf': freq,
                    'weight': INDEX_CONFIG['content_weight']
                })

        for term, postings in term_postings.items():
            index = InvertedIndex.get_or_create(term)

            # Replace this document's previous postings for the term.
            merged = [p for p in index.get_postings() if p['doc_id'] != doc_id]
            merged.extend(postings)

            index.set_postings(merged)

        db.session.commit()
|
||||
|
||||
|
||||
class SearchEngine:
    """Keyword search over the inverted index with BM25 ranking."""

    def __init__(self):
        self.llm = LLMService()
        self.k1 = INDEX_CONFIG['bm25_k1']
        self.b = INDEX_CONFIG['bm25_b']

    @staticmethod
    def _normalize_terms(*term_lists):
        """Flatten term lists into unique, stripped, non-empty strings.

        Tolerates ``None``, a bare string, and non-string items, since the
        lists come from model-generated JSON. First-seen order is kept.
        """
        seen = set()
        terms = []
        for term_list in term_lists:
            if isinstance(term_list, str):
                term_list = [term_list]
            if not isinstance(term_list, (list, tuple)):
                continue
            for raw in term_list:
                if not isinstance(raw, str):
                    continue
                term = raw.strip()
                if term and term not in seen:
                    seen.add(term)
                    terms.append(term)
        return terms

    @staticmethod
    def _escape_like(term):
        """Escape LIKE wildcards so ``term`` is matched literally.

        The backslash is used as the escape character and is escaped first.
        """
        return (
            term.replace('\\', '\\\\')
            .replace('%', '\\%')
            .replace('_', '\\_')
        )

    def search(self, query, top_k=10):
        """Search indexed documents.

        Args:
            query: Query string.
            top_k: Maximum number of results to return.

        Returns:
            list: Result dicts ``{doc, score, matched_chunks, matched_terms}``
            sorted by descending score.
        """
        start_time = datetime.now()

        # 1. Let the LLM extract keywords and expansions.
        print(f"处理查询: {query}")
        query_analysis = self.llm.process_query(query)
        if not isinstance(query_analysis, dict):
            query_analysis = {}

        keywords = query_analysis.get('keywords', [])
        expanded = query_analysis.get('expanded_terms', [])

        # Duplicates would otherwise be retrieved and scored twice.
        all_terms = self._normalize_terms(keywords, expanded)
        if not all_terms:
            # The model produced nothing usable; fall back to raw tokens.
            all_terms = self._normalize_terms(query.split())

        print(f" 关键词: {keywords}")
        print(f" 扩展词: {expanded}")

        # 2. Candidate retrieval.
        results = self._retrieve(all_terms)

        # 3. BM25 scoring.
        scored_results = self._score_results(results, all_terms)

        # 4. Rank.
        scored_results.sort(key=lambda x: x['score'], reverse=True)

        # 5. Truncate.
        final_results = scored_results[:top_k]

        retrieval_time = (datetime.now() - start_time).total_seconds()

        # 6. Persist a query log entry.
        self._log_query(query, query_analysis, final_results, retrieval_time)

        return final_results

    def _retrieve(self, terms):
        """Collect postings for every index term containing a query term.

        Matching is a case-insensitive substring match on the index term.

        Args:
            terms: Query terms.

        Returns:
            dict: ``{doc_id: {chunks: set, terms: {query_term: summed tf},
            postings: [ {term, chunk_id, tf, weight} ]}}``.
        """
        results = {}

        for term in terms:
            if not term:
                # An empty term would become '%%' and match the whole index.
                continue

            pattern = f"%{self._escape_like(term)}%"
            matches = InvertedIndex.query.filter(
                InvertedIndex.term.ilike(pattern, escape='\\')
            ).all()

            for idx in matches:
                for p in idx.get_postings():
                    doc_id = p['doc_id']

                    entry = results.setdefault(doc_id, {
                        'chunks': set(),
                        'terms': {},
                        'postings': []
                    })

                    entry['chunks'].add(p['chunk_id'])
                    entry['terms'][term] = entry['terms'].get(term, 0) + p.get('tf', 1)
                    entry['postings'].append({
                        'term': term,
                        'chunk_id': p['chunk_id'],
                        'tf': p.get('tf', 1),
                        'weight': p.get('weight', 1.0)
                    })

        return results

    def _score_results(self, results, query_terms):
        """Score retrieved documents with BM25.

        Args:
            results: Output of ``_retrieve``.
            query_terms: Query terms.

        Returns:
            list: ``[{doc, score, matched_chunks, matched_terms}]`` (unsorted).
            ``matched_chunks`` holds at most three chunks, strongest first by
            summed ``tf * weight`` of their postings.
        """
        if not results:
            return []

        total_docs = Document.query.filter_by(status='indexed').count()
        if total_docs == 0:
            return []

        # AVG() may come back as Decimal on some backends; float keeps the
        # arithmetic below compatible with the float BM25 parameters.
        avg_doc_len = float(
            db.session.query(
                db.func.avg(Document.word_count)
            ).filter(Document.status == 'indexed').scalar() or 1000
        )

        # IDF depends only on the term, so look it up once per term instead
        # of once per (document, term).
        idf_by_term = {}
        for term in query_terms:
            if term in idf_by_term:
                continue
            index = InvertedIndex.query.filter_by(term=term).first()
            df = index.doc_freq if index and index.doc_freq is not None else 1
            idf_by_term[term] = math.log((total_docs - df + 0.5) / (df + 0.5) + 1)

        scored = []

        for doc_id, data in results.items():
            doc = Document.query.get(doc_id)
            if not doc or doc.status != 'indexed':
                continue

            score = 0
            doc_len = doc.word_count or 1000
            length_norm = self.k1 * (1 - self.b + self.b * doc_len / avg_doc_len)

            for term in query_terms:
                tf = data['terms'].get(term, 0)
                if not tf:
                    # A zero tf contributes nothing; skipping also avoids a
                    # 0/0 when k1 is configured as 0.
                    continue
                score += idf_by_term[term] * (tf * (self.k1 + 1)) / (tf + length_norm)

            # Pick the three strongest chunks rather than an arbitrary three
            # from the set.
            strength = {cid: 0 for cid in data['chunks']}
            for p in data.get('postings', []):
                cid = p['chunk_id']
                strength[cid] = strength.get(cid, 0) + p.get('tf', 1) * p.get('weight', 1.0)
            chunk_ids = sorted(strength, key=lambda cid: (-strength[cid], cid))[:3]

            chunks = DocumentChunk.query.filter(DocumentChunk.id.in_(chunk_ids)).all()
            chunks.sort(key=lambda c: chunk_ids.index(c.id))

            scored.append({
                'doc': doc.to_dict(),
                'score': score,
                'matched_chunks': [c.to_dict() for c in chunks],
                'matched_terms': list(data['terms'].keys())
            })

        return scored

    def _log_query(self, query, analysis, results, retrieval_time):
        """Persist one query-log row for a finished search."""
        # QueryLog is not among the names this module imports at the top, so
        # the original call failed with NameError on every search.
        # NOTE(review): assumes QueryLog is defined in `models` next to the
        # other ORM classes - confirm.
        from models import QueryLog

        keywords = analysis.get('keywords') or []

        log = QueryLog(
            original_query=query,
            processed_query=' '.join(str(k) for k in keywords),
            expanded_terms=json.dumps(analysis.get('expanded_terms', [])),
            intent=analysis.get('intent'),
            entities=json.dumps(analysis.get('entities', {})),
            result_count=len(results),
            top_doc_ids=json.dumps([r['doc']['id'] for r in results[:5]]),
            retrieval_time=retrieval_time,
            total_time=retrieval_time
        )
        db.session.add(log)
        db.session.commit()
|
||||
|
||||
|
||||
class RAGGenerator:
    """Retrieval-augmented answer generation on top of ``SearchEngine``."""

    def __init__(self):
        self.llm = LLMService()
        self.search_engine = SearchEngine()

    @staticmethod
    def _build_context(results, max_docs=3, max_chunks=2, max_chars=500):
        """Assemble the prompt context and the source list.

        Args:
            results: Search results (``{doc, score, matched_chunks}`` dicts).
            max_docs: Number of leading results to use.
            max_chunks: Chunks taken per document.
            max_chars: Characters kept per chunk.

        Returns:
            tuple: ``(context, sources)`` where ``context`` is the joined text
            and ``sources`` is a list of ``{id, title, score}`` dicts.
        """
        context_parts = []
        sources = []

        for i, r in enumerate(results[:max_docs]):
            doc = r['doc']
            # Fall back to the filename when the title is missing *or* empty;
            # dict.get's default alone does not cover a stored None.
            title = doc.get('title') or doc.get('filename') or ''

            context_parts.append(f"【文档{i+1}】{title}")

            for chunk in (r.get('matched_chunks') or [])[:max_chunks]:
                context_parts.append((chunk.get('content') or '')[:max_chars])

            sources.append({
                'id': doc['id'],
                'title': title,
                'score': r['score']
            })

        return '\n\n'.join(context_parts), sources

    def answer(self, query, top_k=5):
        """Answer a question from retrieved documents.

        Args:
            query: User question.
            top_k: Number of documents to retrieve.

        Returns:
            dict: ``{answer, sources, confidence}``. ``confidence`` is the top
            score divided by 10, capped at 1.0; when nothing is retrieved the
            answer is a fixed "not found" message with confidence 0. If the
            LLM call fails, ``answer`` carries the error text.
        """
        # 1. Retrieve.
        results = self.search_engine.search(query, top_k)

        if not results:
            return {
                'answer': '抱歉,没有找到相关信息。',
                'sources': [],
                'confidence': 0
            }

        # 2. Build the context from the top 3 documents, 2 chunks each.
        context, sources = self._build_context(results)

        # 3. Generate.
        prompt = f"""基于以下参考信息回答问题。如果参考信息中没有相关内容,请说明。

问题:{query}

参考信息:
{context}

请给出准确、简洁的回答,并标注信息来源。"""

        try:
            # Reuse the already-configured client instead of building a new
            # one on every call.
            response = self.llm.client.chat.completions.create(
                model=self.llm.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=1000,
                temperature=0.5,
            )
            answer = response.choices[0].message.content

        except Exception as e:
            answer = f"生成回答时出错: {e}"

        # 4. Package.
        return {
            'answer': answer,
            'sources': sources,
            'confidence': min(1.0, results[0]['score'] / 10)
        }
|
||||
Reference in New Issue
Block a user