# runModel.py
import sys
import os
import logging

from src.config import Config
from src.model_runner import ModelRunner

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ChatInterface:
    """Interactive console chat loop that keeps a short rolling history.

    Each turn is sent to ``model_runner.generate`` as one text prompt made of
    the retained history plus the new user message.
    """

    def __init__(self, model_runner, max_history: int = 3):
        """Create a chat session.

        Args:
            model_runner: Object exposing ``generate(prompts=...,
                max_new_tokens=..., temperature=...)``. It is unpacked as a
                ``(response, metrics)`` pair by ``start_chat``.
            max_history: Maximum number of past (user, AI) turns replayed in
                each prompt. Defaults to 3.
        """
        self.runner = model_runner
        self.max_history = max_history
        # Conversation history as (user_input, ai_response) tuples, oldest first.
        self.chat_history = []

    def start_chat(self) -> None:
        """Run the read/generate/print loop until the user leaves.

        The loop ends on ``exit`` / ``quit`` (case-insensitive), on Ctrl-C,
        or when stdin reaches end-of-file. Any other exception raised while
        handling a turn is logged and the user is asked to try again.
        """
        print("\n========== DeepSeek 对话系统 ==========")
        print("输入 'exit' 或 'quit' 结束对话\n")

        while True:
            try:
                user_input = input("用户: ").strip()

                if user_input.lower() in ("exit", "quit"):
                    print("对话结束。")
                    break

                # Nothing to send to the model; just prompt again.
                if not user_input:
                    continue

                # Build the full prompt, including the retained history.
                full_prompt = self._build_prompt(user_input)

                # Generate the reply. The second return value (metrics) is
                # not used by this interface.
                full_response, _metrics = self.runner.generate(
                    prompts=full_prompt,
                    max_new_tokens=Config.MAX_NEW_TOKENS,
                    temperature=Config.TEMPERATURE,
                )

                # Keep only the newly generated text.
                # NOTE(review): this assumes the runner returns a string that
                # echoes the prompt verbatim as its prefix -- confirm against
                # ModelRunner.generate.
                new_response = full_response[len(full_prompt):].strip()
                print(f"AI: {new_response}")

                # Record this turn, trimming to the most recent turns.
                self._update_history(user_input, new_response)

            except (KeyboardInterrupt, EOFError):
                # EOFError must end the loop as well: once stdin is closed,
                # input() raises it on every call, so treating it as a
                # recoverable error would spin forever.
                print("\n检测到中断,对话结束。")
                break
            except Exception as e:
                # Top-level boundary of the interactive loop: log with the
                # traceback and keep the session alive.
                logger.exception("对话出错: %s", e)
                print("系统出现错误,请重新输入。")

    def _build_prompt(self, new_input: str) -> str:
        """Return the prompt text: retained history followed by the new turn.

        Each past turn is rendered as a user line and an AI line; the new
        turn ends with an open ``AI:`` marker for the model to complete.
        """
        turns = [f"用户: {u}\nAI: {r}" for u, r in self.chat_history]
        turns.append(f"用户: {new_input}\nAI:")
        return "\n".join(turns)

    def _update_history(self, user_input: str, response: str) -> None:
        """Append a turn and drop the oldest ones beyond the history limit."""
        self.chat_history.append((user_input, response))
        # Falls back to 3 for instances created without running __init__.
        limit = getattr(self, "max_history", 3)
        excess = len(self.chat_history) - limit
        if excess > 0:
            # Delete in place so existing references to the list stay valid.
            del self.chat_history[:excess]


def main() -> None:
    """Create the model runner and start the interactive chat.

    Any exception raised during start-up or chatting is logged with its
    traceback and not re-raised.
    """
    try:
        logger.info("Initializing chat system...")
        runner = ModelRunner()
        chat = ChatInterface(runner)
        chat.start_chat()
    except Exception as e:
        logger.exception("An error occurred: %s", e)


if __name__ == "__main__":
    main()