runModel.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import sys
  2. import os
  3. import logging
  4. from src.config import Config
  5. from src.data_processor import DataProcessor
  6. from src.model_trainer import ModelTrainer
  7. from src.model_runner import ModelRunner
  8. # 设置日志
  9. logging.basicConfig(level=logging.INFO)
  10. logger = logging.getLogger(__name__)
  11. class ChatInterface:
  12. def __init__(self, model_runner):
  13. self.runner = model_runner
  14. self.chat_history = [] # 存储对话历史
  15. def start_chat(self):
  16. """启动对话交互"""
  17. print("\n========== DeepSeek 对话系统 ==========")
  18. print("输入 'exit' 或 'quit' 结束对话\n")
  19. while True:
  20. try:
  21. user_input = input("用户: ")
  22. if user_input.lower() in ["exit", "quit"]:
  23. print("对话结束。")
  24. break
  25. # 生成回复
  26. full_prompt = self._build_prompt(user_input)
  27. response = self.runner.generate_text(
  28. prompt=full_prompt,
  29. max_length=Config.MAX_LENGTH, # 适当增加生成长度
  30. temperature=Config.TEMPERATURE # 提高创造性
  31. )
  32. # 提取新生成的回复(去除历史重复)
  33. new_response = response[len(full_prompt):].strip()
  34. print(f"AI: {new_response}")
  35. # 更新对话历史(保留最近3轮对话避免过长)
  36. self._update_history(user_input, new_response)
  37. except KeyboardInterrupt:
  38. print("\n检测到中断,对话结束。")
  39. break
  40. except Exception as e:
  41. logger.error(f"对话出错: {e}")
  42. print("系统出现错误,请重新输入。")
  43. def _build_prompt(self, new_input):
  44. """构建包含历史记录的提示"""
  45. history = "\n".join([f"用户: {u}\nAI: {r}" for u, r in self.chat_history])
  46. return f"{history}\n用户: {new_input}\nAI:" if history else f"用户: {new_input}\nAI:"
  47. def _update_history(self, user_input, response):
  48. """维护对话历史(最多保留3轮)"""
  49. self.chat_history.append((user_input, response))
  50. if len(self.chat_history) > 3:
  51. self.chat_history.pop(0)
  52. def main():
  53. try:
  54. # 初始化对话系统
  55. logger.info("Initializing chat system...")
  56. runner = ModelRunner()
  57. chat = ChatInterface(runner)
  58. # 启动对话
  59. chat.start_chat()
  60. except Exception as e:
  61. logger.error(f"An error occurred: {e}")
  62. if __name__ == "__main__":
  63. main()