123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- import sys
- import os
- import logging
- from src.config import Config
- from src.data_processor import DataProcessor
- from src.model_trainer import ModelTrainer
- from src.model_runner import ModelRunner
- # 设置日志
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
- class ChatInterface:
- def __init__(self, model_runner):
- self.runner = model_runner
- self.chat_history = [] # 存储对话历史
-
- def start_chat(self):
- """启动对话交互"""
- print("\n========== DeepSeek 对话系统 ==========")
- print("输入 'exit' 或 'quit' 结束对话\n")
-
- while True:
- try:
- user_input = input("用户: ")
- if user_input.lower() in ["exit", "quit"]:
- print("对话结束。")
- break
-
- # 生成回复
- full_prompt = self._build_prompt(user_input)
- response = self.runner.generate_text(
- prompt=full_prompt,
- max_length=Config.MAX_LENGTH, # 适当增加生成长度
- temperature=Config.TEMPERATURE # 提高创造性
- )
-
- # 提取新生成的回复(去除历史重复)
- new_response = response[len(full_prompt):].strip()
- print(f"AI: {new_response}")
-
- # 更新对话历史(保留最近3轮对话避免过长)
- self._update_history(user_input, new_response)
-
- except KeyboardInterrupt:
- print("\n检测到中断,对话结束。")
- break
- except Exception as e:
- logger.error(f"对话出错: {e}")
- print("系统出现错误,请重新输入。")
-
- def _build_prompt(self, new_input):
- """构建包含历史记录的提示"""
- history = "\n".join([f"用户: {u}\nAI: {r}" for u, r in self.chat_history])
- return f"{history}\n用户: {new_input}\nAI:" if history else f"用户: {new_input}\nAI:"
-
- def _update_history(self, user_input, response):
- """维护对话历史(最多保留3轮)"""
- self.chat_history.append((user_input, response))
- if len(self.chat_history) > 3:
- self.chat_history.pop(0)
- def main():
- try:
- # 数据集生成与处理
- logger.info("Generating and processing data...")
- data_processor = DataProcessor()
- data_processor.process_data()
- # 模型训练
- logger.info("Training model...")
- trainer = ModelTrainer()
- trainer.train()
- # 初始化对话系统
- logger.info("Initializing chat system...")
- runner = ModelRunner()
- chat = ChatInterface(runner)
-
- # 启动对话
- chat.start_chat()
- except Exception as e:
- logger.error(f"An error occurred: {e}")
- if __name__ == "__main__":
- main()
|