|
@@ -347,49 +347,52 @@ class ModelTrainer:
|
|
|
return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: load the training configuration, then run the
    # load-model -> load-data -> train -> save pipeline through ModelTrainer.
    #
    # Distributed setup is intentionally NOT performed in this block. An
    # earlier revision set the variables below and called
    # dist.init_process_group(); that code is disabled in this revision and
    # is summarised here only as a reference:
    #
    #   Multi-node, multi-GPU (export on every machine before launching):
    #     export RANK=0            # rank of the first machine
    #     export WORLD_SIZE=4      # 4 machines in total
    #     export MASTER_ADDR=<master node IP>
    #     export MASTER_PORT=12345
    #
    #   Single-node, multi-GPU:
    #     RANK='0', WORLD_SIZE='1', MASTER_ADDR='localhost', MASTER_PORT='12345'
    #     backend = 'gloo' if os.name == 'nt' else 'nccl'
    #     dist.init_process_group(backend=backend, init_method='env://')
    try:
        # Load the configuration file.
        # NOTE(review): relative path -- presumably resolved against the
        # current working directory inside load_config; confirm.
        config = load_config("../conf/conf_train.yaml")

        # Build the trainer from the loaded configuration.
        trainer = ModelTrainer(config)

        # Load the model and its tokenizer.
        model, tokenizer = trainer.load_model()

        # Load the training dataset.
        train_dataset = trainer.load_data(config.train_data_path)

        # Train the model.
        model = trainer.train(model, tokenizer, train_dataset)

        # Save the trained model and tokenizer.
        trainer.save_model(model, tokenizer, config.save_path)

        print("Training completed.")
    except Exception as e:
        print("exception \n ", e)
        # Re-raise so the full traceback is shown and the process exits with
        # a non-zero status; otherwise a failed run looks like a success to
        # whatever launched this script.
        raise
    finally:
        # No process-group teardown here: this block does not create a
        # process group, so there is nothing for it to destroy.
        print("end")
|