1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374 |
- # runModel.py
- import sys
- import os
- import logging
- from src.config import Config
- from src.model_runner import ModelRunner
- # 设置日志
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
- class ChatInterface:
- def __init__(self, model_runner):
- self.runner = model_runner
- self.chat_history = [] # 存储对话历史
-
- def start_chat(self):
- """启动对话交互"""
- print("\n========== DeepSeek 对话系统 ==========")
- print("输入 'exit' 或 'quit' 结束对话\n")
-
- while True:
- try:
- user_input = input("用户: ")
- if user_input.lower() in ["exit", "quit"]:
- print("对话结束。")
- break
-
- # 构建包含历史的完整提示
- full_prompt = self._build_prompt(user_input)
-
- # 生成回复(调用 generate 方法)
- full_response, metrics = self.runner.generate(
- prompts=full_prompt,
- max_new_tokens=Config.MAX_NEW_TOKENS,
- temperature=Config.TEMPERATURE
- )
-
- # 提取新生成的回复(去除历史部分)
- new_response = full_response[len(full_prompt):].strip()
- print(f"AI: {new_response}")
-
- # 更新对话历史(保留最近3轮对话)
- self._update_history(user_input, new_response)
-
- except KeyboardInterrupt:
- print("\n检测到中断,对话结束。")
- break
- except Exception as e:
- logger.error(f"对话出错: {e}")
- print("系统出现错误,请重新输入。")
-
- def _build_prompt(self, new_input):
- """构建包含历史记录的提示"""
- history = "\n".join([f"用户: {u}\nAI: {r}" for u, r in self.chat_history])
- return f"{history}\n用户: {new_input}\nAI:" if history else f"用户: {new_input}\nAI:"
-
- def _update_history(self, user_input, response):
- """维护对话历史(最多保留3轮)"""
- self.chat_history.append((user_input, response))
- if len(self.chat_history) > 3:
- self.chat_history.pop(0)
- def main():
- try:
- logger.info("Initializing chat system...")
- runner = ModelRunner()
- chat = ChatInterface(runner)
- chat.start_chat()
- except Exception as e:
- logger.error(f"An error occurred: {e}")
- if __name__ == "__main__":
- main()
|