Переглянути джерело

修改运行模型推理代码-增加输出速度和耗时统计

zhouyang.xie 3 місяців тому
батько
коміт
4ffa330fe2

+ 298 - 69
README.MD

@@ -3,6 +3,191 @@
 <div align=center><img src="./resource/人工智能知识库-电力行业风力发电领域.png"></div>
 
 #  大语言模型介绍
+## 基础概念
+
+### 蒸馏
+在机器学习的语境中,**蒸馏(Knowledge Distillation)** 是一种模型压缩技术,其核心思想是将一个复杂的大模型(称为**教师模型**)中的“知识”迁移到一个更轻量的小模型(称为**学生模型**)中,使得小模型能以更低的计算成本实现接近甚至超越原大模型的性能。
+
+---
+
+#### **蒸馏的核心含义**
+1. **知识迁移**  
+   教师模型的“知识”不仅指其参数,更包括:
+   - **输出层知识**:模型对输入数据的预测概率分布(软标签,Soft Labels)。
+   - **中间层知识**:隐藏层的特征表示(如注意力权重、激活值)。
+   - **推理逻辑**:复杂任务的解题步骤或生成策略(如代码生成、数学推导)。
+
+2. **训练方式**  
+   学生模型通过以下方式学习教师模型的知识:
+   - **模仿输出**:在训练时,学生模型不仅学习真实标签(硬标签),还学习教师模型输出的概率分布(软标签),后者包含更丰富的类别间关系信息。  
+     *例如:教师模型判断“猫”的概率为0.8,“狗”为0.15,“狐狸”为0.05,学生模型会学习这种更细粒度的概率分布。*
+   - **特征对齐**:通过匹配教师模型和学生模型的中间层特征(如隐藏状态),强制学生模型学习相似的内部表征。
+
+3. **目的**  
+   - **压缩模型**:将大模型(如千亿参数)压缩为小模型(如百亿参数),便于部署到资源受限的环境(手机、嵌入式设备等)。
+   - **提升效率**:减少推理时的计算量和内存占用,同时保持高性能。
+   - **知识泛化**:通过教师模型的指导,学生模型可能泛化到未见过的新任务或数据。
+
+---
+
+#### **蒸馏 vs 传统训练**
+| **对比维度**       | **传统训练**               | **知识蒸馏**                     |
+|--------------------|---------------------------|---------------------------------|
+| **训练目标**       | 直接拟合真实标签(硬标签) | 同时拟合真实标签 + 教师模型的软标签 |
+| **知识来源**       | 仅训练数据                | 训练数据 + 教师模型的隐性知识      |
+| **模型复杂度**     | 学生模型独立训练           | 学生模型受教师模型指导            |
+| **效果**           | 依赖数据量和模型容量       | 小模型可逼近大模型性能            |
+
+---
+
+#### **蒸馏的典型应用场景**
+1. **模型轻量化**:将 GPT-4、Qwen-72B 等超大模型压缩为适合手机或边缘设备的版本(如 DeepSeek-R1)。
+2. **加速推理**:减少服务端模型的响应延迟和计算成本。
+3. **隐私保护**:用蒸馏模型替代原始大模型,避免敏感数据直接输入大模型。
+4. **领域适配**:通过教师模型的领域知识(如医疗、法律),快速训练专用小模型。
+
+---
+
+#### **DeepSeek-R1 中的蒸馏意义**
+在 DeepSeek-R1-Distill-Qwen-32B 中,蒸馏技术可能用于:
+- 将 Qwen 系列大模型的复杂能力(如代码生成、多轮对话)迁移到更小的模型架构中。
+- 通过软标签和中间层对齐,保留原模型的逻辑推理和泛化能力。
+- 实现低资源环境下的高效部署(如企业级私有化部署)。
+
+简而言之,**蒸馏是通过“师生传承”让轻量模型继承重量级模型智慧的技术**,是平衡性能与效率的关键手段。
+
+**示例:**
+
+---
+
+DeepSeek-R1-Distill-Qwen-32B 是基于知识蒸馏技术对原模型(如 Qwen-32B 或其他大规模模型)进行优化的产物。虽然具体细节需以官方文档为准,但根据知识蒸馏的常见方法和模型优化目标,可以推测其蒸馏内容可能包括以下方面:
+
+---
+
+#### 1. **模型架构精简**
+   - **参数压缩**:通过减少模型层数、隐藏层维度或注意力头数,降低参数量(如从千亿级压缩到百亿级)。
+   - **结构简化**:移除冗余模块或替换为更高效的组件(如简化注意力机制)。
+
+---
+
+#### 2. **知识迁移策略**
+   - **输出层蒸馏(Logits Distillation)**:对齐学生模型与教师模型的输出概率分布(软标签),保留语义理解和生成能力。
+   - **中间层特征匹配**:通过匹配教师模型和学生模型的中间层特征(如注意力权重、隐藏状态),提升学生模型的表征能力。
+   - **多任务蒸馏**:在通用文本生成、代码理解、数学推理等多任务上迁移教师模型的能力。
+
+---
+
+#### 3. **训练数据优化**
+   - **合成数据增强**:利用教师模型生成高质量合成数据(如问答对、解题步骤),补充训练集。
+   - **数据筛选**:基于教师模型的置信度或复杂度评分,筛选高质量、多样化的训练样本。
+
+---
+
+#### 4. **推理效率提升**
+   - **计算加速**:通过降低计算精度(如 FP16 量化)、优化矩阵运算,减少推理延迟。
+   - **内存优化**:采用动态内存管理或缓存策略,降低显存占用。
+
+---
+
+#### 5. **领域知识保留**
+   - **垂直领域适配**:针对代码生成、数学推理、多语言理解等场景,保留教师模型的领域专精能力。
+   - **少样本学习**:通过蒸馏增强模型在少样本或零样本任务中的泛化性。
+
+---
+
+#### 6. **训练策略改进**
+   - **渐进式蒸馏**:分阶段迁移知识(如先通用能力后领域能力),避免性能损失。
+   - **对抗训练**:引入对抗样本提升鲁棒性,或结合强化学习优化生成质量。
+
+---
+
+#### 7. **模型对齐与安全**
+   - **价值观对齐**:通过蒸馏传递符合伦理的内容生成约束,减少有害输出。
+   - **安全护栏(Safety Guardrails)**:继承教师模型的安全过滤机制,增强可控性。
+
+---
+
+#### 典型应用场景(推测)
+- **端侧部署**:优化后的模型可能适用于边缘设备或低资源环境。
+- **低成本推理**:减少云计算依赖,适合大规模商业化应用。
+- **多任务服务**:保留原模型的多领域能力,支持问答、代码、数学等综合场景。
+
+---
+
+如需更准确的信息,建议参考深度求索(DeepSeek)的官方技术报告或开源文档,以获取蒸馏方法、实验对比等细节。
+
+
+## 1. **Timer**
+   - **含义**:Timer 是一种用于优化模型推理效率的技术,通常用于减少模型的计算时间或延迟。
+   - **核心思想**:
+     - 通过动态调整模型的计算资源(如跳过某些层或模块),在保证性能的同时加速推理。
+     - 适用于实时性要求高的场景(如对话系统、推荐系统)。
+   - **应用场景**:
+     - 减少大模型的推理延迟。
+     - 在边缘设备上部署轻量级模型。
+   - **示例**:
+     - 在 Transformer 模型中,根据输入复杂度动态跳过某些注意力头或层。
+
+---
+
+## 2. **COT(Chain-of-Thought)**
+   - **含义**:COT 是一种提示(Prompting)技术,通过引导模型生成中间推理步骤,提升复杂任务(如数学推理、逻辑推理)的性能。
+   - **核心思想**:
+     - 让模型像人类一样“逐步思考”,生成中间推理过程,而不是直接输出最终答案。
+     - 特别适合需要多步推理的任务。
+   - **应用场景**:
+     - 数学问题求解。
+     - 逻辑推理任务。
+     - 复杂问答系统。
+   - **示例**:
+     - 输入:“如果小明有5个苹果,吃了2个,又买了3个,他现在有多少个苹果?”
+     - 模型输出:“小明原来有5个苹果,吃了2个,剩下3个。又买了3个,所以现在有6个苹果。”
+
+---
+
+## 3. **RAG(Retrieval-Augmented Generation)**
+   - **含义**:RAG 是一种结合检索(Retrieval)和生成(Generation)的技术,通过从外部知识库中检索相关信息来增强生成模型的能力。
+   - **核心思想**:
+     - 在生成答案之前,先从大规模知识库(如维基百科)中检索相关文档或段落。
+     - 将检索到的信息与输入问题结合,生成更准确、可靠的答案。
+   - **应用场景**:
+     - 开放域问答(Open-Domain QA)。
+     - 知识密集型任务(如事实核查、文档生成)。
+   - **示例**:
+     - 输入:“谁发明了电话?”
+     - 模型检索到相关文档:“电话是由亚历山大·格拉汉姆·贝尔发明的。”
+     - 模型生成答案:“电话是由亚历山大·格拉汉姆·贝尔发明的。”
+
+---
+
+## 4. **Fine-tuning(微调)**
+   - **含义**:Fine-tuning 是一种迁移学习技术,通过在大规模预训练模型的基础上,使用特定任务的数据进行进一步训练,使模型适应特定任务。
+   - **核心思想**:
+     - 预训练模型(如 GPT、BERT)在大规模通用数据上学习通用语言表示。
+     - 微调阶段使用特定任务的数据(如情感分类、命名实体识别)调整模型参数,使其在特定任务上表现更好。
+   - **应用场景**:
+     - 领域适配(如医疗、法律)。
+     - 特定任务优化(如文本分类、机器翻译)。
+   - **示例**:
+     - 使用 BERT 模型在情感分析数据集(如 IMDb)上进行微调,用于电影评论的情感分类。
+
+**技术对比**
+| **技术**       | **目标**                          | **核心方法**                              | **典型应用场景**                  |
+|----------------|-----------------------------------|------------------------------------------|----------------------------------|
+| **Timer**      | 优化推理效率                      | 动态调整计算资源                          | 实时对话、边缘计算                |
+| **COT**        | 提升复杂任务推理能力              | 生成中间推理步骤                          | 数学推理、逻辑推理                |
+| **RAG**        | 增强生成模型的准确性              | 结合检索和生成                            | 开放域问答、知识密集型任务        |
+| **Fine-tuning**| 适应特定任务                      | 在预训练模型基础上进行任务特定训练          | 领域适配、文本分类                |
+
+---
+
+**结合使用示例**
+- **RAG + COT**:在开放域问答中,先检索相关知识,再通过 Chain-of-Thought 生成推理步骤和最终答案。
+- **Fine-tuning + Timer**:在特定领域(如医疗)微调模型后,使用 Timer 技术优化推理效率,便于实时应用。
+
+这些技术各有侧重,但可以结合使用,以构建更强大、高效的 NLP 系统。
+
+
 ## 架构
     MoE(Mixture of Experts)架构
     一、原理
@@ -104,93 +289,137 @@ trainer.train()
 
 ---
 
-### 2. **GRPO(Gradient Reversal-based Policy Optimization,基于梯度反转的策略优化)**
+### 2. **GRPO(Group Relative Policy Optimization,组相对策略优化)**
 
-#### **定义**
-GRPO 是一种用于强化学习(RL)或策略优化的技术。它通过反转梯度方向来优化策略,从而在复杂任务中实现更好的性能。
+关于 **DeepSeek 的 GRPO(Group Relative Policy Optimization)**,目前公开的技术细节有限(截至2023年10月,DeepSeek 尚未发布官方论文或完整代码)。但根据命名推测,GRPO 可能是一种**基于分组的相对策略优化方法**,结合了强化学习(RL)中的策略优化思想,并引入“分组”与“相对比较”机制以提升训练效率。以下是对其可能设计逻辑的推测及一个简化示例代码框架。
 
-#### **作用**
-- **策略优化**:在强化学习中优化策略模型。
-- **稳定性提升**:通过梯度反转,避免训练过程中的不稳定性。
+---
 
-#### **实现步骤**
-1. **定义策略模型**:
-   - 使用 DeepSeek-R1 作为策略模型。
-2. **定义奖励函数**:
-   - 根据任务设计奖励函数。
-3. **梯度反转**:
-   - 在反向传播时反转梯度方向,以优化策略。
-4. **训练模型**:
-   - 使用强化学习算法(如 PPO、A2C)训练模型。
+#### **GRPO 的核心概念(推测)**
+1. **分组(Grouping)**  
+   - 将智能体的策略划分为多个组(例如不同策略参数或行为模式),通过组间交互或竞争提升探索效率。
+   - 可能应用场景:多智能体协作、策略多样性增强。
+
+2. **相对策略优化(Relative Policy Optimization)**  
+   - 通过组间策略表现的**相对比较**(而非绝对奖励)调整策略参数,类似竞赛机制。
+   - 优势:减少对绝对奖励值的依赖,增强鲁棒性。
+
+3. **优化目标**  
+   - 最大化组间相对优势,同时控制策略更新的稳定性(类似 PPO 的剪切机制)。
+
+---
+
+#### **示例代码框架(Python + PyTorch)**
+以下是一个简化的 GRPO 实现示例,基于单智能体环境(如 CartPole)的假设性设计:
 
-#### **代码示例**
 ```python
 import torch
 import torch.nn as nn
 import torch.optim as optim
-
-# 定义策略模型
-class PolicyModel(nn.Module):
-    def __init__(self):
-        super(PolicyModel, self).__init__()
-        self.fc = nn.Linear(128, 2)  # 假设输入维度为 128,输出为 2 个动作
-
-    def forward(self, x):
-        return torch.softmax(self.fc(x), dim=-1)
-
-# 定义 GRPO 优化器
-class GRPOptimizer:
-    def __init__(self, model, lr=1e-3):
-        self.model = model
-        self.optimizer = optim.Adam(model.parameters(), lr=lr)
-
-    def step(self, loss):
-        # 反向传播
-        loss.backward()
-        
-        # 梯度反转
-        for param in self.model.parameters():
-            param.grad = -param.grad  # 反转梯度
-        
-        # 更新参数
-        self.optimizer.step()
-        self.optimizer.zero_grad()
-
-# 训练循环
-policy_model = PolicyModel()
-grp_optimizer = GRPOptimizer(policy_model)
-
-for epoch in range(10):
-    state = torch.randn(1, 128)  # 假设状态维度为 128
-    action_probs, _ = policy_model(state)
-    action = torch.argmax(action_probs).item()
+from torch.distributions import Categorical
+import gym
+
+# 定义策略网络
+class Policy(nn.Module):
+    def __init__(self, state_dim, action_dim):
+        super().__init__()
+        self.fc = nn.Sequential(
+            nn.Linear(state_dim, 64),
+            nn.ReLU(),
+            nn.Linear(64, action_dim)
+        )
+    
+    def forward(self, state):
+        return self.fc(state)
+
+# GRPO 算法核心
+class GRPO:
+    def __init__(self, state_dim, action_dim, n_groups=3, lr=1e-3, clip_epsilon=0.2):
+        self.n_groups = n_groups
+        self.policies = [Policy(state_dim, action_dim) for _ in range(n_groups)]
+        self.optimizers = [optim.Adam(policy.parameters(), lr=lr) for policy in self.policies]
+        self.clip_epsilon = clip_epsilon
+
+    def update(self, group_data):
+        # group_data: 每个组的轨迹数据 (states, actions, advantages)
+        for group_id in range(self.n_groups):
+            states, actions, advantages = group_data[group_id]
+            probs_old = Categorical(logits=self.policies[group_id](states)).log_prob(actions).detach()
+            
+            # 计算当前策略概率
+            probs_new = Categorical(logits=self.policies[group_id](states)).log_prob(actions)
+            
+            # 计算相对优势(假设其他组的平均优势为基线)
+            other_avg_advantage = torch.mean(torch.stack(
+                [torch.mean(adv) for adv in advantages if adv is not advantages]
+            ))
+            relative_advantages = advantages - other_avg_advantage
+            
+            # 策略损失(类似 PPO 的剪切目标)
+            ratio = torch.exp(probs_new - probs_old)
+            clipped_ratio = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon)
+            loss = -torch.min(ratio * relative_advantages, clipped_ratio * relative_advantages).mean()
+            
+            # 更新策略
+            self.optimizers[group_id].zero_grad()
+            loss.backward()
+            self.optimizers[group_id].step()
+
+# 训练循环(示例)
+env = gym.make('CartPole-v1')
+state_dim = env.observation_space.shape[0]
+action_dim = env.action_space.n
+grpo = GRPO(state_dim, action_dim, n_groups=3)
+
+for episode in range(1000):
+    group_data = {i: {"states": [], "actions": [], "advantages": []} for i in range(grpo.n_groups)}
     
-    # 假设奖励为 1(实际任务中需根据环境计算)
-    reward = 1
-    loss = -torch.log(action_probs[0, action]) * reward  # 策略梯度损失
+    # 并行收集各组轨迹数据
+    for group_id in range(grpo.n_groups):
+        state = env.reset()
+        done = False
+        while not done:
+            action_logits = grpo.policies[group_id](torch.FloatTensor(state))
+            action = Categorical(logits=action_logits).sample().item()
+            next_state, reward, done, _ = env.step(action)
+            
+            group_data[group_id]["states"].append(state)
+            group_data[group_id]["actions"].append(action)
+            group_data[group_id]["advantages"].append(reward)  # 简化优势计算
+            
+            state = next_state
     
-    # GRPO 优化
-    grp_optimizer.step(loss)
+    # 更新策略
+    grpo.update(group_data)
 ```
 
 ---
 
-### 3. **SFT 和 GRPO 的区别**
-| **特性**       | **SFT(监督微调)**                  | **GRPO(基于梯度反转的策略优化)** |
-|----------------|------------------------------------|------------------------------------|
-| **应用场景**   | 监督学习任务(如分类、生成)         | 强化学习任务(如策略优化)         |
-| **数据需求**   | 需要标注数据                        | 需要环境和奖励信号                 |
-| **优化目标**   | 最小化监督损失(如交叉熵)           | 最大化累积奖励                     |
-| **技术核心**   | 微调预训练模型                      | 梯度反转和策略优化                 |
+#### **关键设计解释**
+1. **分组策略**  
+   - 维护多个策略组(`self.policies`),每组独立与环境交互并收集数据。
+   - 通过组间优势的**相对比较**(如其他组的平均优势作为基线)调整更新幅度。
+
+2. **相对优势计算**  
+   - 在损失函数中,使用组间相对优势替代传统绝对优势值,鼓励策略间的竞争或协作。
+
+3. **PPO 剪切机制**  
+   - 保留 PPO 的剪切目标(`torch.clamp`),确保策略更新稳定。
+
+---
+
+#### **应用场景(假设)**
+- **多策略探索**:通过组间差异增强探索能力,避免局部最优。
+- **多智能体系统**:扩展为多智能体协作/竞争场景(如游戏 AI)。
+- **鲁棒性训练**:相对优势减少对奖励绝对值的敏感度。
 
 ---
 
-### 4. **总结**
-- **SFT** 是一种监督微调技术,用于将预训练模型适配到特定任务。
-- **GRPO** 是一种基于梯度反转的策略优化技术,常用于强化学习任务。
-- 在 DeepSeek-R1 中,SFT 和 GRPO 可以结合使用,以在不同任务中实现最佳性能。
+#### **注意事项**
+- 上述代码为**假设性实现**,真实 GRPO 可能涉及更复杂的组间交互机制(如知识共享、动态分组)。
+- 实际应用需结合具体任务调整优势计算、分组策略等模块。
 
-如果你有更多关于 DeepSeek-R1 或相关技术的问题,欢迎继续提问!
+如需准确实现,建议等待 DeepSeek 官方技术公开后参考其论文或代码库。
 
 
 #  大语言模型(LLM)训练与推理

+ 226 - 53
main.py

@@ -1,87 +1,260 @@
+# main.py
 import sys
 import os
+import json
 import logging
+import torch
+from typing import List, Tuple, Optional
 from src.config import Config
 from src.data_processor import DataProcessor
 from src.model_trainer import ModelTrainer
 from src.model_runner import ModelRunner
 
 # 设置日志
-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    level=logging.INFO
+)
 logger = logging.getLogger(__name__)
 
 class ChatInterface:
-    def __init__(self, model_runner):
+    def __init__(self, model_runner: ModelRunner):
         self.runner = model_runner
-        self.chat_history = []  # 存储对话历史
-    
+        self.chat_history: List[Tuple[str, str]] = []  # 对话历史
+        self.history_file = "chat_history.json"  # 历史记录保存路径
+
     def start_chat(self):
-        """启动对话交互"""
-        print("\n========== DeepSeek 对话系统 ==========")
-        print("输入 'exit' 或 'quit' 结束对话\n")
+        """启动增强版对话交互"""
+        self._print_welcome()
         
+        # 加载历史记录(如果存在)
+        if os.path.exists(self.history_file):
+            self.load_history(self.history_file)
+            print(f"已加载 {len(self.chat_history)} 条历史记录\n")
+
         while True:
             try:
-                user_input = input("用户: ")
-                if user_input.lower() in ["exit", "quit"]:
-                    print("对话结束。")
-                    break
+                user_input = self._get_user_input()
+                if not user_input:
+                    continue
                 
-                # 生成回复
-                full_prompt = self._build_prompt(user_input)
-                response = self.runner.generate_text(
-                    prompt=full_prompt,
-                    max_length=Config.MAX_LENGTH,  # 适当增加生成长度
-                    temperature=Config.TEMPERATURE  # 提高创造性
-                )
+                # 处理多模态输入(如图片)
+                if user_input.startswith("/image"):
+                    image_path = user_input.split(maxsplit=1)[1]
+                    response = self._handle_image_input(image_path)
+                    self._display_response(response, {})
+                    continue
                 
-                # 提取新生成的回复(去除历史重复)
-                new_response = response[len(full_prompt):].strip()
-                print(f"AI: {new_response}")
+                # 构建上下文感知提示
+                full_prompt = self._build_context_aware_prompt(user_input)
                 
-                # 更新对话历史(保留最近3轮对话避免过长)
-                self._update_history(user_input, new_response)
+                # 执行生成并处理指标
+                generated_text, metrics = self._generate_response(full_prompt)
                 
+                # 提取并显示新生成内容
+                new_response = self._extract_new_response(full_prompt, generated_text)
+                self._display_response(new_response, metrics)
+                
+                # 智能管理对话历史
+                self._manage_history(user_input, new_response)
+
             except KeyboardInterrupt:
-                print("\n检测到中断,对话结束。")
+                self._handle_interrupt()
                 break
             except Exception as e:
-                logger.error(f"对话出错: {e}")
-                print("系统出现错误,请重新输入。")
+                self._handle_error(e)
+
+    def _print_welcome(self):
+        """打印欢迎信息"""
+        print("\n========== DeepSeek 智能对话系统 ==========")
+        print("输入指令:")
+        print("  /clear  - 清空对话历史")
+        print("  /exit   - 退出系统")
+        print("  /hist   - 查看历史记录")
+        print("  /image <path> - 处理图片输入")
+        print("  /save   - 保存当前对话历史\n")
+
+    def _get_user_input(self) -> str:
+        """获取并预处理用户输入"""
+        try:
+            user_input = input("用户: ").strip()
+            if user_input.lower() in ["/exit", "/quit"]:
+                self.save_history(self.history_file)  # 退出前保存历史
+                print("对话结束。")
+                sys.exit(0)
+            elif user_input.lower() == "/clear":
+                self.chat_history.clear()
+                print("已清空对话历史")
+                return ""
+            elif user_input.lower() == "/hist":
+                self._show_history()
+                return ""
+            elif user_input.lower() == "/save":
+                self.save_history(self.history_file)
+                print(f"历史记录已保存到 {self.history_file}")
+                return ""
+            return user_input
+        except EOFError:
+            sys.exit(0)
+
+    def _build_context_aware_prompt(self, new_input: str) -> str:
+        """构建上下文感知提示(带自动截断)"""
+        context_tokens = 0
+        context_lines = []
+        
+        # 逆向遍历历史记录,构建不超过最大上下文长度的提示
+        for user, resp in reversed(self.chat_history):
+            line = f"用户: {user}\nAI: {resp}"
+            line_tokens = len(self.runner.tokenizer.tokenize(line))
+            
+            if context_tokens + line_tokens > Config.MAX_CONTEXT_TOKENS:
+                break
+                
+            context_lines.insert(0, line)  # 保持时间顺序
+            context_tokens += line_tokens
+        
+        context = "\n".join(context_lines)
+        return f"{context}\n用户: {new_input}\nAI:" if context else f"用户: {new_input}\nAI:"
+
+    def _generate_response(self, prompt: str) -> Tuple[str, dict]:
+        """执行文本生成并收集指标"""
+        try:
+            # 调用修改后的 generate 方法
+            generated_text, metrics = self.runner.generate(
+                prompts=prompt,
+                max_new_tokens=Config.MAX_NEW_TOKENS,
+                temperature=Config.TEMPERATURE,
+                top_p=Config.TOP_P
+            )
+            
+            # 处理批量返回结果
+            if isinstance(generated_text, list):
+                return generated_text[0], metrics
+            return generated_text, metrics
+        
+        except Exception as e:
+            logger.error(f"生成失败: {str(e)}")
+            return "抱歉,我暂时无法处理这个请求。", {"error": str(e)}
+
+    def _extract_new_response(self, prompt: str, generated: str) -> str:
+        """精确提取新生成内容(处理tokenizer差异)"""
+        try:
+            # 使用tokenizer对齐方式处理
+            prompt_tokens = self.runner.tokenizer.encode(prompt, add_special_tokens=False)
+            all_tokens = self.runner.tokenizer.encode(generated, add_special_tokens=False)
+            new_tokens = all_tokens[len(prompt_tokens):]
+            return self.runner.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
+        except:
+            return generated[len(prompt):].strip()
+
+    def _display_response(self, response: str, metrics: dict):
+        """增强型结果显示"""
+        if Config.STREAMING:
+            # 流式输出已实时显示,此处仅打印指标
+            print(f"\n生成指标: {metrics['tokens_per_sec']:.1f}tok/s | 耗时: {metrics['total_time']:.2f}s\n")
+        else:
+            # 非流式完整显示
+            print(f"AI: {response}")
+            print(f"[指标] Tokens: {metrics['generated_tokens']} | 速度: {metrics['tokens_per_sec']:.1f}tok/s")
+            
+            # 调试模式下显示资源使用情况
+            if Config.DEBUG_MODE:
+                print(f"[资源] GPU Mem: {metrics.get('gpu_mem', 0):.1f}GB | CPU Mem: {metrics.get('cpu_mem', 0):.1f}GB")
+
+    def _manage_history(self, user_input: str, response: str):
+        """智能历史管理(基于Token数)"""
+        self.chat_history.append((user_input, response))
+        
+        # 计算总token数
+        total_tokens = sum(
+            len(self.runner.tokenizer.tokenize(f"用户: {u} AI: {r}"))
+            for u, r in self.chat_history
+        )
+        
+        # 动态保留历史(至少保留1轮,最多保留配置上限)
+        while total_tokens > Config.MAX_HISTORY_TOKENS and len(self.chat_history) > 1:
+            removed = self.chat_history.pop(0)
+            total_tokens -= len(self.runner.tokenizer.tokenize(f"用户: {removed[0]} AI: {removed[1]}"))
+
+    def _show_history(self):
+        """显示优化后的历史记录"""
+        print("\n当前对话历史:")
+        for idx, (user, resp) in enumerate(self.chat_history[-3:], 1):
+            print(f"[{idx}] 用户: {user}")
+            print(f"    AI: {resp[:80]}{'...' if len(resp)>80 else ''}")
+        print()
+
+    def save_history(self, path: str):
+        """保存对话历史到文件"""
+        with open(path, 'w', encoding='utf-8') as f:
+            json.dump(self.chat_history, f, ensure_ascii=False, indent=2)
+
+    def load_history(self, path: str):
+        """从文件加载对话历史"""
+        with open(path, 'r', encoding='utf-8') as f:
+            self.chat_history = json.load(f)
+
+    def _handle_image_input(self, image_path: str) -> str:
+        """处理图片输入"""
+        if not Config.MULTIMODAL:
+            return "多模态功能未启用,请检查配置。"
+        
+        try:
+            from src.vision_model import VisionModel  # 假设有独立的视觉模型模块
+            vision_model = VisionModel()
+            vision_output = vision_model.process(image_path)
+            return f"检测到图片内容:{vision_output}"
+        except ImportError:
+            return "多模态模块未安装,请安装相关依赖。"
+        except Exception as e:
+            logger.error(f"图片处理失败: {str(e)}")
+            return "图片处理失败,请重试。"
+
+    def _handle_interrupt(self):
+        """处理中断信号"""
+        print("\n检测到中断信号,正在安全退出...")
+        self.save_history(self.history_file)  # 保存历史
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    def _handle_error(self, error: Exception):
+        """增强错误处理"""
+        logger.error(f"对话出错: {str(error)}")
+        print("系统遇到意外错误,正在恢复...")
+        self.chat_history = self.chat_history[:-1]  # 移除最后一轮问题
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+def initialize_system():
+    # 数据集生成与处理
+    logger.info("Generating and processing data...")
+    data_processor = DataProcessor()
+    # data_processor.generate_raw_data()
+    data_processor.process_data()
+
+    # 模型训练
+    logger.info("Training model...")
+    trainer = ModelTrainer()
+    trainer.train()
+
+    """系统初始化流程"""
+    logger.info("初始化模型运行器...")
+    runner = ModelRunner()
     
-    def _build_prompt(self, new_input):
-        """构建包含历史记录的提示"""
-        history = "\n".join([f"用户: {u}\nAI: {r}" for u, r in self.chat_history])
-        return f"{history}\n用户: {new_input}\nAI:" if history else f"用户: {new_input}\nAI:"
+    # 预热模型
+    if Config.RUN_WARMUP:
+        logger.info("执行模型预热...")
+        runner.generate("模型预热", max_new_tokens=10)
     
-    def _update_history(self, user_input, response):
-        """维护对话历史(最多保留3轮)"""
-        self.chat_history.append((user_input, response))
-        if len(self.chat_history) > 3:
-            self.chat_history.pop(0)
+    return ChatInterface(runner)
 
 def main():
     try:
-        # 数据集生成与处理
-        logger.info("Generating and processing data...")
-        data_processor = DataProcessor()
-        data_processor.process_data()
-
-        # 模型训练
-        logger.info("Training model...")
-        trainer = ModelTrainer()
-        trainer.train()
-
-        # 初始化对话系统
-        logger.info("Initializing chat system...")
-        runner = ModelRunner()
-        chat = ChatInterface(runner)
-        
-        # 启动对话
+        chat = initialize_system()
         chat.start_chat()
-
     except Exception as e:
-        logger.error(f"An error occurred: {e}")
+        logger.critical(f"系统启动失败: {str(e)}")
+        sys.exit(1)
 
 if __name__ == "__main__":
     main()

+ 11 - 14
runModel.py

@@ -1,9 +1,8 @@
+# runModel.py
 import sys
 import os
 import logging
 from src.config import Config
-from src.data_processor import DataProcessor
-from src.model_trainer import ModelTrainer
 from src.model_runner import ModelRunner
 
 # 设置日志
@@ -27,19 +26,21 @@ class ChatInterface:
                     print("对话结束。")
                     break
                 
-                # 生成回复
+                # 构建包含历史的完整提示
                 full_prompt = self._build_prompt(user_input)
-                response = self.runner.generate_text(
-                    prompt=full_prompt,
-                    max_length=Config.MAX_LENGTH,  # 适当增加生成长度
-                    temperature=Config.TEMPERATURE  # 提高创造性
+                
+                # 生成回复(调用 generate 方法)
+                full_response, metrics = self.runner.generate(
+                    prompts=full_prompt,
+                    max_new_tokens=Config.MAX_NEW_TOKENS,
+                    temperature=Config.TEMPERATURE
                 )
                 
-                # 提取新生成的回复(去除历史重复
-                new_response = response[len(full_prompt):].strip()
+                # 提取新生成的回复(去除历史部分
+                new_response = full_response[len(full_prompt):].strip()
                 print(f"AI: {new_response}")
                 
-                # 更新对话历史(保留最近3轮对话避免过长
+                # 更新对话历史(保留最近3轮对话)
                 self._update_history(user_input, new_response)
                 
             except KeyboardInterrupt:
@@ -62,14 +63,10 @@ class ChatInterface:
 
 def main():
     try:
-        # 初始化对话系统
         logger.info("Initializing chat system...")
         runner = ModelRunner()
         chat = ChatInterface(runner)
-        
-        # 启动对话
         chat.start_chat()
-
     except Exception as e:
         logger.error(f"An error occurred: {e}")
 

BIN
src/__pycache__/config.cpython-39.pyc


BIN
src/__pycache__/model_runner.cpython-39.pyc


BIN
src/__pycache__/model_trainer.cpython-39.pyc


+ 18 - 4
src/config.py

@@ -9,8 +9,8 @@ class Config:
     
     # 模型路径
     PRETRAINED_MODEL_DIR = os.path.join("models", "pretrained","DeepSeek-R1-Distill-Qwen-1dot5B")  # 预训练模型目录
-    # TRAINED_MODEL_DIR = os.path.join("models", "trained","DeepSeek-R1-Distill-Qwen-1dot5B")        # 训练后模型保存目录
-    TRAINED_MODEL_DIR = os.path.join("models", "trained","DeepSeek-R1-Distill-Qwen-1dot5B-WindTurbine")        # 训练后模型保存目录
+    TRAINED_MODEL_DIR = os.path.join("models", "trained","DeepSeek-R1-Distill-Qwen-1dot5B")        # 训练后模型保存目录
+    # TRAINED_MODEL_DIR = os.path.join("models", "trained","DeepSeek-R1-Distill-Qwen-1dot5B-WindTurbine")        # 训练后模型保存目录
 
     # PRETRAINED_MODEL_DIR = os.path.join("models", "pretrained","DeepSeek-R1-Distill-Qwen-7B")  # 预训练模型目录
     # TRAINED_MODEL_DIR = os.path.join("models", "trained","DeepSeek-R1-Distill-Qwen-7B")        # 训练后模型保存目录
@@ -22,8 +22,11 @@ class Config:
     # TRAINED_MODEL_DIR = os.path.join("models", "trained","DeepSeek-R1-Distill-Llama-8B")        # 训练后模型保存目录
     RUN_MODEL_DIR=TRAINED_MODEL_DIR  # TRAINED_MODEL_DIR
 
-    DEVICE_STR="cpu"  # "cuda" if torch.cuda.is_available() else "cpu"
+    # 设备配置
+    DEVICE_STR="cpu"  # 
+    # DEVICE_STR="cuda" if torch.cuda.is_available() else "cpu"
     DEVICE=torch.device(DEVICE_STR)
+    TORCH_DTYPE = torch.bfloat16 if DEVICE_STR == "cuda" else torch.float32
     
     # 训练参数
     USE_FP16 = True if DEVICE_STR=="cuda" else False # 是否使用混合精度训练,在CPU上无法有效使用FP16,且可能引发梯度问题,须关闭混合精度训练,即将USE_FP16设为False
@@ -44,4 +47,15 @@ class Config:
     MAX_INPUT_LENGTH = 2048  # 输入序列的最大长度
     MAX_NEW_TOKENS = 200  # 生成文本的最大长度(新增部分)范围200~512
     MAX_LENGTH = MAX_INPUT_LENGTH + MAX_NEW_TOKENS  # 总长度
-    TEMPERATURE = 0.7 # 0.7  # 控制生成文本的随机性。较低的温度(如 0.2)会使生成结果更确定,较高的温度(如 1.0)会使生成结果更多样化。
+    # 上下文管理
+    MAX_CONTEXT_TOKENS = 2048    # 最大上下文token数
+    MAX_HISTORY_TOKENS = 1024    # 历史记录最大token数
+
+    TEMPERATURE = 0.7 # 0.7  # 控制生成文本的随机性。较低的温度(如 0.2)会使生成结果更确定,较高的温度(如 1.0)会使生成结果更多样化。
+    DO_SAMPLE = True
+
+    # 系统功能
+    RUN_WARMUP = True           # 是否启用预热
+    STREAMING = False  # 是否启用流式输出
+    BATCH_ENABLED = True  # 是否启用批处理
+    MODEL_LOAD_KWARGS = {}  # 额外加载参数

+ 199 - 53
src/model_runner.py

@@ -1,76 +1,222 @@
-#  model_runner.py
+# model_runner.py
 import torch
-from torch.quantization import quantize_dynamic
-from transformers import AutoTokenizer, AutoModelForCausalLM
+import time
+import psutil  # 新增系统监控
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
 from src.config import Config
 import logging
-
+from typing import Union, List, Tuple
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 class ModelRunner:
     def __init__(self):
-        # 修正设备设置
-        self.device =Config.DEVICE # torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        logger.info(f"Using device: {self.device}")
-        
-        self.modelId=Config.RUN_MODEL_DIR
-        logger.info(f"model id: {self.modelId}")
+        self.device = Config.DEVICE
+        self.modelId = Config.RUN_MODEL_DIR
+        self._init_resources()
+        self._load_model()
+        self._init_streamer()
+
+    def _init_resources(self):
+        """初始化资源监控基线"""
+        self.process = psutil.Process()
+        self.initial_cpu_mem = self.process.memory_info().rss / 1024**3  # GB
+        if torch.cuda.is_available():
+            torch.cuda.reset_peak_memory_stats()
+            self.initial_gpu_mem = torch.cuda.memory_allocated() / 1024**3  # GB
 
+    def _load_model(self):
+        """加载模型和分词器"""
         try:
-            # 加载训练后的模型和分词器
-            self.tokenizer = AutoTokenizer.from_pretrained(self.modelId)
-            self.model = AutoModelForCausalLM.from_pretrained(self.modelId,torch_dtype= Config.TORCH_DTYPE)
-            # self.model = AutoModelForCausalLM.from_pretrained(self.modelId,torch_dtype= Config.TORCH_DTYPE,device_map="auto")
+            logger.info(f"Loading model from {self.modelId} on {self.device}")
+            
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                self.modelId,
+                padding_side="left" if Config.BATCH_ENABLED else "right"
+            )
             
-            if Config.DEVICE_STR=="cuda" and torch.cuda.device_count() > 1:
-               self.model = torch.nn.DataParallel(self.model)  # 多 GPU 并行
+            # 自动设备映射优化
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.modelId,
+                torch_dtype=Config.TORCH_DTYPE,
+                device_map="auto" if Config.DEVICE_STR == "cuda" else None,
+                **Config.MODEL_LOAD_KWARGS
+            )
+
+            # 多GPU并行
+            if Config.DEVICE_STR == "cuda" and torch.cuda.device_count() > 1:
+                self.model = torch.nn.DataParallel(self.model)
+                logger.info(f"Parallelized across {torch.cuda.device_count()} GPUs")
 
-            # 确保分词器有 pad_token
+            # 分词器配置
             if self.tokenizer.pad_token is None:
                 self.tokenizer.pad_token = self.tokenizer.eos_token
-                        
-            # 将模型移动到设备
-            self.model.to(self.device)
-            logger.info("Trained model loaded successfully.")
+
+            self.model.eval()
+            logger.info("Model loaded in evaluation mode")
+
         except Exception as e:
-            logger.error(f"Failed to load model: {e}")
+            logger.error(f"Model loading failed: {str(e)}")
             raise
 
-    def generate_text(self, prompt, max_length=None, temperature=None, do_sample=True):
+    def _init_streamer(self):
+        """初始化流式输出组件"""
+        self.streamer = TextStreamer(
+            self.tokenizer,
+            skip_prompt=True,
+            skip_special_tokens=True
+        ) if Config.STREAMING else None
+
+    def _log_resource_usage(self):
+        """记录资源使用情况"""
+        metrics = {}
+        # CPU内存
+        metrics["cpu_mem"] = self.process.memory_info().rss / 1024**3 - self.initial_cpu_mem
+        
+        # GPU内存
+        if torch.cuda.is_available():
+            metrics["gpu_mem"] = (
+                torch.cuda.max_memory_allocated() / 1024**3 - self.initial_gpu_mem
+            )
+        
+        logger.debug(f"Resource Usage: {metrics}")
+
+    def generate(
+        self, 
+        prompts: Union[str, List[str]],
+        **generation_kwargs
+    ) -> Tuple[Union[str, List[str]], dict]:
+        """
+        支持单条/批量文本生成
+        :param prompts: 输入文本或文本列表
+        :return: (生成文本, 性能指标)
+        """
+        metrics = {"total_tokens": 0}
         try:
-            self.model.eval()
+            total_start = time.perf_counter()
             
-            # 编码输入并生成 attention_mask
-            inputs = self.tokenizer(
-                prompt, 
-                return_tensors="pt", 
-                padding=True, 
-                truncation=True
-            ).to(self.device)
+            # 批处理预处理
+            is_batch = isinstance(prompts, list)
+            inputs = self._preprocess(prompts, metrics)
             
-            # 设置生成参数
-            max_length = max_length or Config.MAX_TOKENS
-            max_new_tokens = max_length or Config.MAX_NEW_TOKENS
-            temperature = temperature or Config.TEMPERATURE
-
-            # 生成文本
-            outputs = self.model.generate(
-                inputs.input_ids,
-                attention_mask=inputs.attention_mask,  # 显式传递 attention_mask
-                max_new_tokens=max_new_tokens,
-                temperature=temperature,
-                top_k=Config.TOP_K,
-                top_p=Config.TOP_P,
-                num_beams=Config.NUM_BEAMS,
-                do_sample=do_sample,
-                pad_token_id=self.tokenizer.pad_token_id  # 显式设置 pad_token_id
-            )
+            # 生成参数
+            gen_params = self._build_generation_params(inputs, generation_kwargs)
+            
+            # 执行生成
+            outputs = self._execute_generation(gen_params, metrics)
             
-            # 解码生成结果
-            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-            return generated_text
+            # 后处理
+            results = self._postprocess(outputs, inputs, is_batch, metrics)
+            
+            # 最终指标计算
+            metrics["total_time"] = time.perf_counter() - total_start
+            metrics["throughput"] = metrics["total_tokens"] / metrics["total_time"]
+            
+            self._log_performance(metrics)
+            self._log_resource_usage()
+            
+            return results, metrics
+
         except Exception as e:
-            logger.error(f"Generation failed: {e}")
-            raise
+            logger.error(f"Generation error: {str(e)}")
+            self._log_resource_usage()
+            raise
+
+    def _preprocess(self, prompts: Union[str, List[str]], metrics: dict):
+        """预处理输入"""
+        start = time.perf_counter()
+        
+        inputs = self.tokenizer(
+            prompts,
+            return_tensors="pt",
+            padding=Config.BATCH_ENABLED,
+            truncation=True,
+            max_length=Config.MAX_INPUT_LENGTH
+        ).to(self.device)
+        
+        metrics["preprocess_time"] = time.perf_counter() - start
+        return inputs
+
+    def _build_generation_params(self, inputs, generation_kwargs):
+        """构建生成参数"""
+        params = {
+            "input_ids": inputs.input_ids,
+            "attention_mask": inputs.attention_mask,
+            "max_new_tokens": generation_kwargs.get("max_new_tokens", Config.MAX_NEW_TOKENS),
+            "temperature": generation_kwargs.get("temperature", Config.TEMPERATURE),
+            "top_p": generation_kwargs.get("top_p", Config.TOP_P),
+            "num_beams": generation_kwargs.get("num_beams", Config.NUM_BEAMS),
+            "do_sample": generation_kwargs.get("do_sample", Config.DO_SAMPLE),
+            "pad_token_id": self.tokenizer.pad_token_id,
+            "streamer": self.streamer if Config.STREAMING else None,
+        }
+        return {k: v for k, v in params.items() if v is not None}
+
+    def _execute_generation(self, gen_params, metrics):
+        """执行生成并记录指标"""
+        start = time.perf_counter()
+        
+        with torch.inference_mode():
+            outputs = self.model.generate(**gen_params)
+        
+        metrics["inference_time"] = time.perf_counter() - start
+        metrics["total_tokens"] = outputs.shape[-1] - gen_params["input_ids"].shape[-1]
+        return outputs
+
+    def _postprocess(self, outputs, inputs, is_batch, metrics):
+        """后处理输出"""
+        start = time.perf_counter()
+        
+        # 批量解码
+        results = []
+        for i in range(outputs.shape[0]):
+            # 跳过输入部分
+            generated = outputs[i, inputs.input_ids.shape[1]:]
+            try:
+                text = self.tokenizer.decode(
+                    generated,
+                    skip_special_tokens=True
+                )
+                results.append(text.strip())
+            except Exception as e:
+                logger.warning(f"Decoding error for sample {i}: {str(e)}")
+                results.append("")
+        
+        metrics["postprocess_time"] = time.perf_counter() - start
+        return results if is_batch else results[0]
+
+    def _log_performance(self, metrics):
+        """记录性能日志"""
+        log_msg = (
+            f"Performance Summary || "
+            f"Total: {metrics['total_time']:.2f}s | "
+            f"Inference: {metrics['inference_time']:.2f}s | "
+            f"Tokens: {metrics['total_tokens']} | "
+            f"Speed: {metrics['throughput']:.1f} tok/s"
+        )
+        if Config.BATCH_ENABLED:
+            log_msg += f" | Batch size: {metrics.get('batch_size', 1)}"
+        logger.info(log_msg)
+
+# 配置文件示例(src/config.py)
+    
+if __name__ == "__main__":
+    # 使用示例
+    runner = ModelRunner()
+    
+    # 单条生成
+    single_prompt = "DeepSeek-R1的主要技术优势是"
+    text, metrics = runner.generate(single_prompt)
+    print(f"\nSingle Generation:\n{text}")
+    
+    # 批量生成
+    batch_prompts = [
+        "解释量子计算的基本原理",
+        "写一首关于人工智能的诗",
+        "如何快速学习PyTorch?"
+    ]
+    texts, metrics = runner.generate(batch_prompts)
+    print("\nBatch Results:")
+    for i, text in enumerate(texts):
+        print(f"{i+1}. {text[:100]}...")

+ 76 - 0
src/model_runner_v1.py

@@ -0,0 +1,76 @@
+#  model_runner.py
+import torch
+from torch.quantization import quantize_dynamic
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from src.config import Config
+import logging
+
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class ModelRunner:
+    def __init__(self):
+        # 修正设备设置
+        self.device =Config.DEVICE # torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        logger.info(f"Using device: {self.device}")
+        
+        self.modelId=Config.RUN_MODEL_DIR
+        logger.info(f"model id: {self.modelId}")
+
+        try:
+            # 加载训练后的模型和分词器
+            self.tokenizer = AutoTokenizer.from_pretrained(self.modelId)
+            self.model = AutoModelForCausalLM.from_pretrained(self.modelId,torch_dtype= Config.TORCH_DTYPE)
+            # self.model = AutoModelForCausalLM.from_pretrained(self.modelId,torch_dtype= Config.TORCH_DTYPE,device_map="auto")
+            
+            if Config.DEVICE_STR=="cuda" and torch.cuda.device_count() > 1:
+               self.model = torch.nn.DataParallel(self.model)  # 多 GPU 并行
+
+            # 确保分词器有 pad_token
+            if self.tokenizer.pad_token is None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+                        
+            # 将模型移动到设备
+            self.model.to(self.device)
+            logger.info("Trained model loaded successfully.")
+        except Exception as e:
+            logger.error(f"Failed to load model: {e}")
+            raise
+
+    def generate_text(self, prompt, max_length=None, temperature=None, do_sample=True):
+        try:
+            self.model.eval()
+            
+            # 编码输入并生成 attention_mask
+            inputs = self.tokenizer(
+                prompt, 
+                return_tensors="pt", 
+                padding=True, 
+                truncation=True
+            ).to(self.device)
+            
+            # 设置生成参数
+            max_length = max_length or Config.MAX_TOKENS
+            max_new_tokens = max_length or Config.MAX_NEW_TOKENS
+            temperature = temperature or Config.TEMPERATURE
+
+            # 生成文本
+            outputs = self.model.generate(
+                inputs.input_ids,
+                attention_mask=inputs.attention_mask,  # 显式传递 attention_mask
+                max_new_tokens=max_new_tokens,
+                temperature=temperature,
+                top_k=Config.TOP_K,
+                top_p=Config.TOP_P,
+                num_beams=Config.NUM_BEAMS,
+                do_sample=do_sample,
+                pad_token_id=self.tokenizer.pad_token_id  # 显式设置 pad_token_id
+            )
+            
+            # 解码生成结果
+            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            return generated_text
+        except Exception as e:
+            logger.error(f"Generation failed: {e}")
+            raise

+ 1 - 2
trainModel.py

@@ -3,7 +3,6 @@ import os
 import logging
 from src.data_processor import DataProcessor
 from src.model_trainer import ModelTrainer
-from src.model_runner import ModelRunner
 
 # 设置日志
 logging.basicConfig(level=logging.INFO)
@@ -14,7 +13,7 @@ def main():
         # 数据集生成与处理
         logger.info("Generating and processing data...")
         data_processor = DataProcessor()
-        data_processor.generate_raw_data()
+        # data_processor.generate_raw_data()
         data_processor.process_data()
 
         # 模型训练