# main.py
import sys
import os
import json
import logging
from typing import List, Tuple

import torch

from src.config import Config
from src.data_processor import DataProcessor
from src.model_trainer import ModelTrainer
from src.model_runner import ModelRunner

# Set up logging
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO
)
logger = logging.getLogger(__name__)

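# NOTE: this module assumes src/config.py exposes at least the attributes
# referenced below. A minimal sketch of that contract (illustrative values,
# not the project's actual configuration):
#
#     class Config:
#         MAX_CONTEXT_TOKENS = 2048   # token budget for history in a prompt
#         MAX_HISTORY_TOKENS = 4096   # token budget before old turns are trimmed
#         MAX_NEW_TOKENS = 512        # cap on newly generated tokens
#         TEMPERATURE = 0.7           # sampling temperature
#         TOP_P = 0.9                 # nucleus-sampling threshold
#         STREAMING = True            # generate() streams output live
#         DEBUG_MODE = False          # print GPU/CPU memory metrics
#         MULTIMODAL = False          # enable /image handling
#         RUN_WARMUP = True           # run a short warmup generation at startup
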
class ChatInterface:
    def __init__(self, model_runner: ModelRunner):
        self.runner = model_runner
        self.chat_history: List[Tuple[str, str]] = []  # conversation history
        self.history_file = "chat_history.json"  # where history is persisted

    def start_chat(self):
        """Run the enhanced interactive chat loop."""
        self._print_welcome()

        # Load persisted history, if any
        if os.path.exists(self.history_file):
            self.load_history(self.history_file)
            print(f"Loaded {len(self.chat_history)} history entries\n")

        while True:
            try:
                user_input = self._get_user_input()
                if not user_input:
                    continue

                # Handle multimodal input (e.g. images)
                if user_input.startswith("/image"):
                    parts = user_input.split(maxsplit=1)
                    if len(parts) < 2:
                        print("Usage: /image <path>")
                        continue
                    response = self._handle_image_input(parts[1])
                    self._display_response(response, {})
                    continue

                # Build a context-aware prompt
                full_prompt = self._build_context_aware_prompt(user_input)

                # Run generation and collect metrics
                generated_text, metrics = self._generate_response(full_prompt)

                # Extract and display only the newly generated content
                new_response = self._extract_new_response(full_prompt, generated_text)
                self._display_response(new_response, metrics)

                # Trim conversation history to its token budget
                self._manage_history(user_input, new_response)

            except KeyboardInterrupt:
                self._handle_interrupt()
                break
            except Exception as e:
                self._handle_error(e)

    def _print_welcome(self):
        """Print the welcome banner."""
        print("\n========== DeepSeek Chat System ==========")
        print("Commands:")
        print("  /clear - clear conversation history")
        print("  /exit  - quit")
        print("  /hist  - show recent history")
        print("  /image - process an image input")
        print("  /save  - save the current conversation history\n")

    def _get_user_input(self) -> str:
        """Read and preprocess user input."""
        try:
            user_input = input("User: ").strip()
            if user_input.lower() in ["/exit", "/quit"]:
                self.save_history(self.history_file)  # persist history before exiting
                print("Chat ended.")
                sys.exit(0)
            elif user_input.lower() == "/clear":
                self.chat_history.clear()
                print("Conversation history cleared")
                return ""
            elif user_input.lower() == "/hist":
                self._show_history()
                return ""
            elif user_input.lower() == "/save":
                self.save_history(self.history_file)
                print(f"History saved to {self.history_file}")
                return ""
            return user_input
        except EOFError:
            sys.exit(0)

    def _build_context_aware_prompt(self, new_input: str) -> str:
        """Build a context-aware prompt, auto-truncated to the token budget."""
        context_tokens = 0
        context_lines = []

        # Walk the history backwards so the most recent turns are kept
        for user, resp in reversed(self.chat_history):
            line = f"User: {user}\nAI: {resp}"
            line_tokens = len(self.runner.tokenizer.tokenize(line))
            if context_tokens + line_tokens > Config.MAX_CONTEXT_TOKENS:
                break
            context_lines.insert(0, line)  # restore chronological order
            context_tokens += line_tokens

        context = "\n".join(context_lines)
        return f"{context}\nUser: {new_input}\nAI:" if context else f"User: {new_input}\nAI:"

    def _generate_response(self, prompt: str) -> Tuple[str, dict]:
        """Run text generation and collect metrics."""
        try:
            generated_text, metrics = self.runner.generate(
                prompts=prompt,
                max_new_tokens=Config.MAX_NEW_TOKENS,
                temperature=Config.TEMPERATURE,
                top_p=Config.TOP_P
            )
            # generate() may return a batch; unwrap single-prompt results
            if isinstance(generated_text, list):
                return generated_text[0], metrics
            return generated_text, metrics
        except Exception as e:
            logger.error(f"Generation failed: {str(e)}")
            return "Sorry, I can't handle this request right now.", {"error": str(e)}

    def _extract_new_response(self, prompt: str, generated: str) -> str:
        """Extract only the newly generated content (robust to tokenizer quirks)."""
        try:
            # Align on token boundaries rather than raw characters
            prompt_tokens = self.runner.tokenizer.encode(prompt, add_special_tokens=False)
            all_tokens = self.runner.tokenizer.encode(generated, add_special_tokens=False)
            new_tokens = all_tokens[len(prompt_tokens):]
            return self.runner.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        except Exception:
            # Fallback: character-level slicing
            return generated[len(prompt):].strip()

    def _display_response(self, response: str, metrics: dict):
        """Display a response together with its generation metrics."""
        if not metrics:
            # No generation metrics (e.g. image handling): just show the response
            print(f"AI: {response}")
            return

        if Config.STREAMING:
            # Streaming output was already shown live; only print the metrics
            print(f"\n[metrics] {metrics.get('tokens_per_sec', 0.0):.1f} tok/s | time: {metrics.get('total_time', 0.0):.2f}s\n")
        else:
            # Non-streaming: show the full response
            print(f"AI: {response}")
            print(f"[metrics] tokens: {metrics.get('generated_tokens', 0)} | speed: {metrics.get('tokens_per_sec', 0.0):.1f} tok/s")

        # Show resource usage in debug mode
        if Config.DEBUG_MODE:
            print(f"[resources] GPU mem: {metrics.get('gpu_mem', 0):.1f} GB | CPU mem: {metrics.get('cpu_mem', 0):.1f} GB")

    def _manage_history(self, user_input: str, response: str):
        """Token-budget-based history management."""
        self.chat_history.append((user_input, response))

        # Total token count across the whole history
        total_tokens = sum(
            len(self.runner.tokenizer.tokenize(f"User: {u} AI: {r}"))
            for u, r in self.chat_history
        )

        # Trim oldest turns first; always keep at least one turn
        while total_tokens > Config.MAX_HISTORY_TOKENS and len(self.chat_history) > 1:
            removed = self.chat_history.pop(0)
            total_tokens -= len(self.runner.tokenizer.tokenize(f"User: {removed[0]} AI: {removed[1]}"))

    def _show_history(self):
        """Show the most recent turns of the conversation."""
        print("\nRecent conversation history (last 3 turns):")
        for idx, (user, resp) in enumerate(self.chat_history[-3:], 1):
            print(f"[{idx}] User: {user}")
            print(f"    AI: {resp[:80]}{'...' if len(resp) > 80 else ''}")
        print()

    def save_history(self, path: str):
        """Persist the conversation history to a JSON file."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.chat_history, f, ensure_ascii=False, indent=2)

    def load_history(self, path: str):
        """Load conversation history from a JSON file."""
        with open(path, 'r', encoding='utf-8') as f:
            # JSON stores pairs as lists; convert back to tuples
            self.chat_history = [tuple(turn) for turn in json.load(f)]

    def _handle_image_input(self, image_path: str) -> str:
        """Handle an image input."""
        if not Config.MULTIMODAL:
            return "Multimodal support is not enabled; check the configuration."
        try:
            from src.vision_model import VisionModel  # assumes a separate vision module
            vision_model = VisionModel()
            vision_output = vision_model.process(image_path)
            return f"Detected image content: {vision_output}"
        except ImportError:
            return "The multimodal module is not installed; install the required dependencies."
        except Exception as e:
            logger.error(f"Image processing failed: {str(e)}")
            return "Image processing failed; please try again."

    def _handle_interrupt(self):
        """Handle an interrupt signal."""
        print("\nInterrupt received, shutting down safely...")
        self.save_history(self.history_file)  # persist history
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _handle_error(self, error: Exception):
        """Enhanced error handling."""
        logger.error(f"Chat error: {str(error)}")
        print("The system hit an unexpected error and is recovering...")
        self.chat_history = self.chat_history[:-1]  # drop the last turn
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def initialize_system():
    """System initialization pipeline."""
    # Dataset generation and processing
    logger.info("Generating and processing data...")
    data_processor = DataProcessor()
    # data_processor.generate_raw_data()
    data_processor.process_data()

    # Model training
    logger.info("Training model...")
    trainer = ModelTrainer()
    trainer.train()

    logger.info("Initializing model runner...")
    runner = ModelRunner()

    # Warm up the model
    if Config.RUN_WARMUP:
        logger.info("Warming up the model...")
        runner.generate("model warmup", max_new_tokens=10)

    return ChatInterface(runner)


def main():
    try:
        chat = initialize_system()
        chat.start_chat()
    except Exception as e:
        logger.critical(f"System startup failed: {str(e)}")
        sys.exit(1)


if __name__ == "__main__":
    main()
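
# NOTE: ChatInterface relies on this inferred ModelRunner contract (the actual
# implementation lives in src/model_runner.py and may differ):
#   - runner.tokenizer: a Hugging Face-style tokenizer exposing tokenize(),
#     encode(), and decode()
#   - runner.generate(prompts, max_new_tokens, temperature, top_p) returns
#     (text_or_list_of_texts, metrics), where metrics may include
#     'generated_tokens', 'tokens_per_sec', 'total_time', 'gpu_mem', 'cpu_mem'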