main.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import sys
  2. import os
  3. import logging
  4. from src.config import Config
  5. from src.data_processor import DataProcessor
  6. from src.model_trainer import ModelTrainer
  7. from src.model_runner import ModelRunner
  8. # 设置日志
  9. logging.basicConfig(level=logging.INFO)
  10. logger = logging.getLogger(__name__)
  11. class ChatInterface:
  12. def __init__(self, model_runner):
  13. self.runner = model_runner
  14. self.chat_history = [] # 存储对话历史
  15. def start_chat(self):
  16. """启动对话交互"""
  17. print("\n========== DeepSeek 对话系统 ==========")
  18. print("输入 'exit' 或 'quit' 结束对话\n")
  19. while True:
  20. try:
  21. user_input = input("用户: ")
  22. if user_input.lower() in ["exit", "quit"]:
  23. print("对话结束。")
  24. break
  25. # 生成回复
  26. full_prompt = self._build_prompt(user_input)
  27. response = self.runner.generate_text(
  28. prompt=full_prompt,
  29. max_length=Config.MAX_LENGTH, # 适当增加生成长度
  30. temperature=Config.TEMPERATURE # 提高创造性
  31. )
  32. # 提取新生成的回复(去除历史重复)
  33. new_response = response[len(full_prompt):].strip()
  34. print(f"AI: {new_response}")
  35. # 更新对话历史(保留最近3轮对话避免过长)
  36. self._update_history(user_input, new_response)
  37. except KeyboardInterrupt:
  38. print("\n检测到中断,对话结束。")
  39. break
  40. except Exception as e:
  41. logger.error(f"对话出错: {e}")
  42. print("系统出现错误,请重新输入。")
  43. def _build_prompt(self, new_input):
  44. """构建包含历史记录的提示"""
  45. history = "\n".join([f"用户: {u}\nAI: {r}" for u, r in self.chat_history])
  46. return f"{history}\n用户: {new_input}\nAI:" if history else f"用户: {new_input}\nAI:"
  47. def _update_history(self, user_input, response):
  48. """维护对话历史(最多保留3轮)"""
  49. self.chat_history.append((user_input, response))
  50. if len(self.chat_history) > 3:
  51. self.chat_history.pop(0)
  52. def main():
  53. try:
  54. # 数据集生成与处理
  55. logger.info("Generating and processing data...")
  56. data_processor = DataProcessor()
  57. data_processor.process_data()
  58. # 模型训练
  59. logger.info("Training model...")
  60. trainer = ModelTrainer()
  61. trainer.train()
  62. # 初始化对话系统
  63. logger.info("Initializing chat system...")
  64. runner = ModelRunner()
  65. chat = ChatInterface(runner)
  66. # 启动对话
  67. chat.start_chat()
  68. except Exception as e:
  69. logger.error(f"An error occurred: {e}")
  70. if __name__ == "__main__":
  71. main()