runModel.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # runModel.py
  2. import sys
  3. import os
  4. import logging
  5. from src.config import Config
  6. from src.model_runner import ModelRunner
  7. # 设置日志
  8. logging.basicConfig(level=logging.INFO)
  9. logger = logging.getLogger(__name__)
  10. class ChatInterface:
  11. def __init__(self, model_runner):
  12. self.runner = model_runner
  13. self.chat_history = [] # 存储对话历史
  14. def start_chat(self):
  15. """启动对话交互"""
  16. print("\n========== DeepSeek 对话系统 ==========")
  17. print("输入 'exit' 或 'quit' 结束对话\n")
  18. while True:
  19. try:
  20. user_input = input("用户: ")
  21. if user_input.lower() in ["exit", "quit"]:
  22. print("对话结束。")
  23. break
  24. # 构建包含历史的完整提示
  25. full_prompt = self._build_prompt(user_input)
  26. # 生成回复(调用 generate 方法)
  27. full_response, metrics = self.runner.generate(
  28. prompts=full_prompt,
  29. max_new_tokens=Config.MAX_NEW_TOKENS,
  30. temperature=Config.TEMPERATURE
  31. )
  32. # 提取新生成的回复(去除历史部分)
  33. new_response = full_response[len(full_prompt):].strip()
  34. print(f"AI: {new_response}")
  35. # 更新对话历史(保留最近3轮对话)
  36. self._update_history(user_input, new_response)
  37. except KeyboardInterrupt:
  38. print("\n检测到中断,对话结束。")
  39. break
  40. except Exception as e:
  41. logger.error(f"对话出错: {e}")
  42. print("系统出现错误,请重新输入。")
  43. def _build_prompt(self, new_input):
  44. """构建包含历史记录的提示"""
  45. history = "\n".join([f"用户: {u}\nAI: {r}" for u, r in self.chat_history])
  46. return f"{history}\n用户: {new_input}\nAI:" if history else f"用户: {new_input}\nAI:"
  47. def _update_history(self, user_input, response):
  48. """维护对话历史(最多保留3轮)"""
  49. self.chat_history.append((user_input, response))
  50. if len(self.chat_history) > 3:
  51. self.chat_history.pop(0)
  52. def main():
  53. try:
  54. logger.info("Initializing chat system...")
  55. runner = ModelRunner()
  56. chat = ChatInterface(runner)
  57. chat.start_chat()
  58. except Exception as e:
  59. logger.error(f"An error occurred: {e}")
  60. if __name__ == "__main__":
  61. main()