- # main.py
- import sys
- import os
- import json
- import logging
- import torch
- from typing import List, Tuple, Optional
- from src.config import Config
- from src.data_processor import DataProcessor
- from src.model_trainer import ModelTrainer
- from src.model_runner import ModelRunner
# Configure application-wide logging once at import time.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)

# Module-level logger, keyed to this module's name.
logger = logging.getLogger(__name__)
class ChatInterface:
    """Interactive console front-end for a ModelRunner.

    Maintains a token-budgeted conversation history, supports a small set
    of slash commands (/clear, /exit, /hist, /image, /save), and persists
    the history to a JSON file between sessions.
    """

    def __init__(self, model_runner: ModelRunner):
        self.runner = model_runner
        # (user_input, response) pairs, oldest first.
        self.chat_history: List[Tuple[str, str]] = []
        # Path where the conversation history is persisted as JSON.
        self.history_file = "chat_history.json"

    def start_chat(self) -> None:
        """Run the read / generate / display loop until the user exits."""
        self._print_welcome()

        # Restore a previous session if a history file exists.
        if os.path.exists(self.history_file):
            self.load_history(self.history_file)
            print(f"已加载 {len(self.chat_history)} 条历史记录\n")
        while True:
            try:
                user_input = self._get_user_input()
                if not user_input:
                    continue

                # Multimodal input: "/image <path>" routes to the vision model.
                if user_input.startswith("/image"):
                    parts = user_input.split(maxsplit=1)
                    # BUGFIX: a bare "/image" previously raised IndexError.
                    if len(parts) < 2:
                        print("用法: /image <path>")
                        continue
                    response = self._handle_image_input(parts[1])
                    self._display_response(response, {})
                    continue

                # Build a context-aware prompt from recent history.
                full_prompt = self._build_context_aware_prompt(user_input)

                # Run generation and collect performance metrics.
                generated_text, metrics = self._generate_response(full_prompt)

                # Show only the newly generated portion.
                new_response = self._extract_new_response(full_prompt, generated_text)
                self._display_response(new_response, metrics)

                # Trim history to stay within the configured token budget.
                self._manage_history(user_input, new_response)
            except KeyboardInterrupt:
                self._handle_interrupt()
                break
            except Exception as e:
                self._handle_error(e)

    def _print_welcome(self) -> None:
        """Print the banner and the list of available commands."""
        print("\n========== DeepSeek 智能对话系统 ==========")
        print("输入指令:")
        print("  /clear - 清空对话历史")
        print("  /exit - 退出系统")
        print("  /hist - 查看历史记录")
        print("  /image <path> - 处理图片输入")
        print("  /save - 保存当前对话历史\n")

    def _get_user_input(self) -> str:
        """Read one line from the user and dispatch slash commands.

        Returns the raw user text, or "" when a command was handled
        in-place (the caller then skips generation for that turn).
        """
        try:
            user_input = input("用户: ").strip()
            if user_input.lower() in ["/exit", "/quit"]:
                self.save_history(self.history_file)  # Persist before exit.
                print("对话结束。")
                sys.exit(0)
            elif user_input.lower() == "/clear":
                self.chat_history.clear()
                print("已清空对话历史")
                return ""
            elif user_input.lower() == "/hist":
                self._show_history()
                return ""
            elif user_input.lower() == "/save":
                self.save_history(self.history_file)
                print(f"历史记录已保存到 {self.history_file}")
                return ""
            return user_input
        except EOFError:
            # BUGFIX: EOF (Ctrl-D / piped stdin) now saves history too,
            # matching the /exit path.
            self.save_history(self.history_file)
            sys.exit(0)

    def _build_context_aware_prompt(self, new_input: str) -> str:
        """Build a prompt from recent history, truncated to the token budget.

        History is walked newest-first so the most recent turns survive
        truncation, then re-ordered chronologically for the prompt.
        """
        context_tokens = 0
        context_lines = []

        for user, resp in reversed(self.chat_history):
            line = f"用户: {user}\nAI: {resp}"
            line_tokens = len(self.runner.tokenizer.tokenize(line))

            if context_tokens + line_tokens > Config.MAX_CONTEXT_TOKENS:
                break

            context_lines.insert(0, line)  # Keep chronological order.
            context_tokens += line_tokens

        context = "\n".join(context_lines)
        return f"{context}\n用户: {new_input}\nAI:" if context else f"用户: {new_input}\nAI:"

    def _generate_response(self, prompt: str) -> Tuple[str, dict]:
        """Run text generation and return (text, metrics).

        On failure, returns an apology string and a metrics dict carrying
        only an "error" key — never raises.
        """
        try:
            generated_text, metrics = self.runner.generate(
                prompts=prompt,
                max_new_tokens=Config.MAX_NEW_TOKENS,
                temperature=Config.TEMPERATURE,
                top_p=Config.TOP_P
            )

            # The runner may return a batch; take the first element.
            if isinstance(generated_text, list):
                return generated_text[0], metrics
            return generated_text, metrics

        except Exception as e:
            logger.error(f"生成失败: {str(e)}")
            return "抱歉,我暂时无法处理这个请求。", {"error": str(e)}

    def _extract_new_response(self, prompt: str, generated: str) -> str:
        """Extract only the newly generated text after the prompt.

        Token-level alignment is used because detokenization can make
        character-level prefix stripping inexact; falls back to a plain
        string slice if tokenization fails.
        """
        try:
            prompt_tokens = self.runner.tokenizer.encode(prompt, add_special_tokens=False)
            all_tokens = self.runner.tokenizer.encode(generated, add_special_tokens=False)
            new_tokens = all_tokens[len(prompt_tokens):]
            return self.runner.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        except Exception:  # BUGFIX: was a bare except (swallowed SystemExit etc.).
            return generated[len(prompt):].strip()

    def _display_response(self, response: str, metrics: dict) -> None:
        """Print the response and, when available, generation metrics.

        BUGFIX: metrics may be empty (image replies) or an error dict;
        all lookups now use .get() instead of raising KeyError, and a
        non-streamed response (e.g. image text) is always printed.
        """
        if Config.STREAMING and metrics:
            # Streamed tokens were already shown live; only print metrics.
            print(f"\n生成指标: {metrics.get('tokens_per_sec', 0.0):.1f}tok/s | 耗时: {metrics.get('total_time', 0.0):.2f}s\n")
        else:
            print(f"AI: {response}")
            if metrics:
                print(f"[指标] Tokens: {metrics.get('generated_tokens', 0)} | 速度: {metrics.get('tokens_per_sec', 0.0):.1f}tok/s")

        # In debug mode, also report resource usage.
        if Config.DEBUG_MODE:
            print(f"[资源] GPU Mem: {metrics.get('gpu_mem', 0):.1f}GB | CPU Mem: {metrics.get('cpu_mem', 0):.1f}GB")

    def _manage_history(self, user_input: str, response: str) -> None:
        """Append the new turn, then evict oldest turns over the token budget.

        Always keeps at least one turn so the next prompt has some context.
        """
        self.chat_history.append((user_input, response))

        total_tokens = sum(
            len(self.runner.tokenizer.tokenize(f"用户: {u} AI: {r}"))
            for u, r in self.chat_history
        )

        while total_tokens > Config.MAX_HISTORY_TOKENS and len(self.chat_history) > 1:
            removed = self.chat_history.pop(0)
            total_tokens -= len(self.runner.tokenizer.tokenize(f"用户: {removed[0]} AI: {removed[1]}"))

    def _show_history(self) -> None:
        """Print the last three turns, truncating long responses."""
        print("\n当前对话历史:")
        for idx, (user, resp) in enumerate(self.chat_history[-3:], 1):
            print(f"[{idx}] 用户: {user}")
            print(f"    AI: {resp[:80]}{'...' if len(resp)>80 else ''}")
        print()

    def save_history(self, path: str) -> None:
        """Serialize the conversation history to a JSON file."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.chat_history, f, ensure_ascii=False, indent=2)

    def load_history(self, path: str) -> None:
        """Load conversation history from a JSON file (best effort).

        BUGFIX: a corrupt or malformed file no longer crashes startup;
        the error is logged and the current history is kept unchanged.
        """
        try:
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Validate/normalize shape: a list of (user, response) pairs.
            self.chat_history = [(str(u), str(r)) for u, r in data]
        except (OSError, json.JSONDecodeError, ValueError, TypeError) as e:
            logger.error(f"加载历史失败: {str(e)}")

    def _handle_image_input(self, image_path: str) -> str:
        """Process an image input via the optional vision module."""
        if not Config.MULTIMODAL:
            return "多模态功能未启用,请检查配置。"

        try:
            from src.vision_model import VisionModel  # Optional dependency.
            vision_model = VisionModel()
            vision_output = vision_model.process(image_path)
            return f"检测到图片内容:{vision_output}"
        except ImportError:
            return "多模态模块未安装,请安装相关依赖。"
        except Exception as e:
            logger.error(f"图片处理失败: {str(e)}")
            return "图片处理失败,请重试。"

    def _handle_interrupt(self) -> None:
        """On Ctrl-C: save history and release GPU memory before exit."""
        print("\n检测到中断信号,正在安全退出...")
        self.save_history(self.history_file)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _handle_error(self, error: Exception) -> None:
        """Log an in-loop error, drop the last turn, and free GPU memory."""
        logger.error(f"对话出错: {str(error)}")
        print("系统遇到意外错误,正在恢复...")
        self.chat_history = self.chat_history[:-1]  # Discard the failed turn.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
def initialize_system() -> "ChatInterface":
    """Full system bring-up: data pipeline, training, runner, warm-up.

    BUGFIX: the original docstring was stranded mid-body (a no-op string
    statement after trainer.train()); it is now consolidated here.

    Returns:
        A ChatInterface wrapping a warmed-up ModelRunner.
    """
    # Dataset generation and preprocessing.
    logger.info("Generating and processing data...")
    data_processor = DataProcessor()
    # data_processor.generate_raw_data()
    data_processor.process_data()

    # Model training.
    logger.info("Training model...")
    trainer = ModelTrainer()
    trainer.train()

    # Inference runtime.
    logger.info("初始化模型运行器...")
    runner = ModelRunner()

    # Warm up the model so the first user request is not slow.
    if Config.RUN_WARMUP:
        logger.info("执行模型预热...")
        runner.generate("模型预热", max_new_tokens=10)

    return ChatInterface(runner)
def main():
    """Program entry point: build the system, then run the chat loop.

    Any failure during startup or the chat loop is logged as critical
    and terminates the process with exit code 1.
    """
    try:
        chat_ui = initialize_system()
        chat_ui.start_chat()
    except Exception as exc:
        logger.critical(f"系统启动失败: {str(exc)}")
        sys.exit(1)
# Run the CLI entry point only when executed as a script (not on import).
if __name__ == "__main__":
    main()