@@ -0,0 +1,277 @@
+# model_rag.py
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+import docx
+import faiss
+import numpy as np
+import torch
+from PIL import Image
+from sentence_transformers import SentenceTransformer
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+@dataclass
+class DocumentChunk:
+    """A chunk of document text plus its metadata and optional embedding."""
+    content: str
+    metadata: Dict[str, Any]
+    embedding: Optional[np.ndarray] = None
+
+class DocumentProcessor(ABC):
+    """Abstract base class for document processors."""
+
+    @abstractmethod
+    def extract_text(self, file_path: str) -> List[DocumentChunk]:
+        pass
+
+    @abstractmethod
+    def extract_images(self, file_path: str) -> List[Image.Image]:
+        pass
+
+class WordDocumentProcessor(DocumentProcessor):
+    """Word (.docx) document processor."""
+
+    def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50):
+        # Note: chunk_size is measured in characters, chunk_overlap in words
+        # (see _split_text below).
+        self.chunk_size = chunk_size
+        self.chunk_overlap = chunk_overlap
+
+    def extract_text(self, file_path: str) -> List[DocumentChunk]:
+        """Extract text from a Word document and split it into chunks."""
+        doc = docx.Document(file_path)
+        full_text = []
+
+        # Collect the text of every non-empty paragraph
+        for para in doc.paragraphs:
+            if para.text.strip():
+                full_text.append(para.text)
+
+        # Merge into a single string
+        full_text = "\n".join(full_text)
+
+        # Split the text into chunks
+        chunks = self._split_text(full_text)
+        return [DocumentChunk(content=chunk, metadata={"source": file_path}) for chunk in chunks]
+
+    def extract_images(self, file_path: str) -> List[Image.Image]:
+        """Extract images from a Word document (stubbed out; see the sketch below)."""
+        # Note: python-docx does not expose embedded images directly, so this
+        # is left as a stub. Real projects typically use another library such
+        # as docx2python, or read the .docx as a zip archive, as sketched in
+        # the next method.
+        return []
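+
+    def _extract_images_via_zip(self, file_path: str) -> List[Image.Image]:
+        # A minimal sketch of the zip-based approach mentioned above, not
+        # wired into the pipeline. It assumes embedded media live under
+        # word/media/ inside the .docx archive, which is the usual layout.
+        import io
+        import zipfile
+
+        images = []
+        with zipfile.ZipFile(file_path) as archive:
+            for name in archive.namelist():
+                if name.startswith("word/media/"):
+                    try:
+                        images.append(Image.open(io.BytesIO(archive.read(name))))
+                    except Exception:
+                        # Skip non-image media such as embedded OLE objects
+                        continue
+        return images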
+
+    def _split_text(self, text: str) -> List[str]:
+        """Simple whitespace-based text splitter with word-level overlap."""
+        words = text.split()
+        chunks = []
+        current_chunk = []
+        current_length = 0
+
+        for word in words:
+            if current_length + len(word) + 1 > self.chunk_size and current_chunk:
+                chunks.append(" ".join(current_chunk))
+                # Carry the last chunk_overlap words over into the next chunk
+                current_chunk = current_chunk[-self.chunk_overlap:]
+                current_length = sum(len(w) + 1 for w in current_chunk)
+
+            current_chunk.append(word)
+            current_length += len(word) + 1
+
+        if current_chunk:
+            chunks.append(" ".join(current_chunk))
+
+        return chunks
+
+class EmbeddingModel:
+    """Wrapper around a sentence-transformers text embedding model."""
+
+    def __init__(self, model_name: str = "../models/AI-ModelScope/all-mpnet-base-v2"):
+        # Load the local embedding model, falling back to a smaller one
+        try:
+            self.model = SentenceTransformer(model_name)
+        except Exception as e:
+            print(f"Error loading model: {e}")
+            print("Trying smaller model...")
+            # Note: all-MiniLM-L6-v2 produces 384-dim vectors, so the
+            # VectorStore dimension must match the model actually loaded.
+            self.model = SentenceTransformer("all-MiniLM-L6-v2")
+
+    def embed_text(self, texts: List[str]) -> np.ndarray:
+        """Generate embedding vectors for a list of texts."""
+        return self.model.encode(texts, convert_to_numpy=True)
+
+class VectorStore:
+    """FAISS-backed vector store."""
+
+    def __init__(self, dimension: int = 768):
+        self.dimension = dimension
+        self.index = faiss.IndexFlatL2(dimension)
+        self.chunks = []
+
+    def add_chunks(self, chunks: List[DocumentChunk], embeddings: np.ndarray):
+        """Add document chunks together with their embedding vectors."""
+        # Make sure the dimensions match; FAISS expects contiguous float32
+        assert embeddings.shape[1] == self.dimension
+        self.index.add(np.ascontiguousarray(embeddings, dtype=np.float32))
+        self.chunks.extend(chunks)
+
+    def search(self, query_embedding: np.ndarray, k: int = 5) -> List[DocumentChunk]:
+        """Return the k most similar document chunks."""
+        distances, indices = self.index.search(
+            np.ascontiguousarray(query_embedding, dtype=np.float32), k
+        )
+        # FAISS pads the result with -1 when fewer than k vectors are indexed
+        return [self.chunks[i] for i in indices[0] if i != -1]
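+
+def build_cosine_index(embeddings: np.ndarray) -> faiss.Index:
+    # A minimal sketch, not used by the pipeline above: IndexFlatL2 ranks by
+    # Euclidean distance, whereas sentence embeddings are often compared by
+    # cosine similarity. L2-normalizing the vectors and using an inner-product
+    # index gives cosine ranking instead.
+    index = faiss.IndexFlatIP(embeddings.shape[1])
+    vectors = np.ascontiguousarray(embeddings.astype(np.float32))
+    faiss.normalize_L2(vectors)  # in-place; inner product now equals cosine
+    index.add(vectors)
+    return index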
+
+class Retriever:
+    """Retriever that fetches the chunks most relevant to a query."""
+
+    def __init__(self, vector_store: VectorStore, embedding_model: EmbeddingModel, max_chunks: int = 3, max_chunk_length: int = 1000):
+        self.vector_store = vector_store
+        self.embedding_model = embedding_model
+        self.max_chunks = max_chunks
+        self.max_chunk_length = max_chunk_length
+
+    def retrieve(self, query: str) -> List[DocumentChunk]:
+        """Retrieve the document chunks most relevant to the query."""
+        query_embedding = self.embedding_model.embed_text([query])
+        chunks = self.vector_store.search(query_embedding, k=self.max_chunks)
+
+        # Cap the length of each chunk. Build new DocumentChunk objects for
+        # truncated content so the chunks stored in the index stay intact.
+        result = []
+        for chunk in chunks:
+            if len(chunk.content) > self.max_chunk_length:
+                chunk = DocumentChunk(
+                    content=chunk.content[:self.max_chunk_length] + "...",
+                    metadata=chunk.metadata,
+                    embedding=chunk.embedding,
+                )
+            result.append(chunk)
+
+        return result
+
+class Generator:
+    """Answer generator backed by a local causal language model."""
+
+    def __init__(self, model_path: str):
+        # Load the local model
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            torch_dtype=torch.float16,
+            device_map="auto"
+        )
+
+    def generate(self, prompt: str, max_new_tokens: int = 200) -> str:
+        """Generate an answer for the given prompt."""
+        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+
+        # Length of the tokenized prompt
+        input_length = inputs.input_ids.shape[1]
+
+        # Keep the total length within the model's context window
+        max_possible_tokens = self.model.config.max_position_embeddings
+        max_new_tokens = min(max_new_tokens, max_possible_tokens - input_length)
+
+        if max_new_tokens <= 0:
+            return "The input is too long to generate an answer. Please shorten your question or context."
+
+        outputs = self.model.generate(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            temperature=0.7,
+            top_p=0.9,
+            do_sample=True
+        )
+        # Decode only the newly generated tokens, skipping the prompt
+        return self.tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
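+
+def build_chat_prompt(tokenizer, user_message: str) -> str:
+    # A minimal sketch, not used by Generator above: chat-tuned checkpoints
+    # such as the DeepSeek-R1 distills usually expect their chat template
+    # rather than a raw completion prompt. This assumes the tokenizer ships
+    # a chat template (most recent instruct models do).
+    messages = [{"role": "user", "content": user_message}]
+    return tokenizer.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )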
+
+class RAGPipeline:
+    """End-to-end RAG pipeline."""
+
+    def __init__(self, model_path: str, max_context_length: int = 4000):
+        self.document_processor = WordDocumentProcessor()
+        try:
+            # Try to load the embedding model
+            self.embedding_model = EmbeddingModel()
+        except Exception as e:
+            print(f"Failed to load embedding model: {e}")
+            raise
+
+        # Size the vector store to whatever embedding model was loaded,
+        # since the fallback model has a different output dimension
+        dimension = self.embedding_model.model.get_sentence_embedding_dimension()
+        self.vector_store = VectorStore(dimension=dimension)
+        self.retriever = Retriever(self.vector_store, self.embedding_model)
+
+        try:
+            # Load the local generator model
+            self.generator = Generator(model_path)
+            self.max_context_length = max_context_length  # maximum context length in tokens
+        except Exception as e:
+            print(f"Failed to load generator model: {e}")
+            raise
+
+    def index_documents(self, document_paths: List[str]):
+        """Index the given documents."""
+        all_chunks = []
+
+        for path in document_paths:
+            if path.endswith('.docx'):
+                chunks = self.document_processor.extract_text(path)
+                all_chunks.extend(chunks)
+
+        # Nothing to embed; avoid adding an empty batch to the index
+        if not all_chunks:
+            return
+
+        # Generate embedding vectors
+        texts = [chunk.content for chunk in all_chunks]
+        embeddings = self.embedding_model.embed_text(texts)
+
+        # Attach embeddings and add everything to the vector store
+        for chunk, embedding in zip(all_chunks, embeddings):
+            chunk.embedding = embedding
+
+        self.vector_store.add_chunks(all_chunks, embeddings)
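+
+    def save_index(self, path: str):
+        # A minimal persistence sketch (hypothetical helper, not called
+        # anywhere): writes the FAISS index to disk so documents need not be
+        # re-embedded on every run. The chunk texts and metadata would need
+        # to be serialized separately (e.g. as JSON).
+        faiss.write_index(self.vector_store.index, path)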
+
+    def query(self, question: str, max_new_tokens: int = 512) -> str:
+        """Query the knowledge base and generate an answer."""
+        # Retrieve the relevant document chunks
+        relevant_chunks = self.retriever.retrieve(question)
+
+        # Build the context, keeping it within the maximum token budget
+        context_parts = []
+        current_length = 0
+
+        for chunk in relevant_chunks:
+            chunk_length = len(self.generator.tokenizer.tokenize(chunk.content))
+            if current_length + chunk_length > self.max_context_length:
+                break
+            context_parts.append(chunk.content)
+            current_length += chunk_length
+
+        context = "\n\n".join(context_parts)
+
+        prompt = f"""Answer the question based on the following context. If you do not know the answer, say that you do not know.
+
+Context:
+{context}
+
+Question: {question}
+Answer:"""
+
+        # Generate the answer
+        return self.generator.generate(prompt, max_new_tokens=max_new_tokens)
+
+# Usage example
+if __name__ == "__main__":
+    # Configure the model path and document paths
+    MODEL_PATH = "../models/trained/DeepSeek-R1-Distill-Qwen-1.5B-GRPO"
+    DOCUMENT_PATHS = ["../data/knowledgebase/岭门风电场机组功率曲线治理与发电性能改善服务—数据分析报告-20240711V.docx"]  # replace with actual document paths
+
+    # Initialize the RAG system
+    rag = RAGPipeline(MODEL_PATH, max_context_length=3000)
+
+    # Index the documents
+    print("Indexing documents...")
+    rag.index_documents(DOCUMENT_PATHS)
+    print("Documents indexed successfully.")
+
+    # Interactive Q&A loop
+    print("RAG system ready. Type 'exit' to quit.")
+    while True:
+        question = input("\nQuestion: ")
+        if question.lower() == 'exit':
+            break
+
+        answer = rag.query(question, max_new_tokens=512)
+        print("\nAnswer:", answer)