فهرست منبع

遵循面向对象思想重构train_model_grpo_v1.1.py 去掉分布式及相应环境变量设置代码

zhouyang.xie 3 ماه پیش
والد
کامیت
1bef66edba
1فایلهای تغییر یافته به همراه49 افزوده شده و 46 حذف شده
  1. 49 46
      src/train_model_grpo_v1.1.py

+ 49 - 46
src/train_model_grpo_v1.1.py

@@ -347,49 +347,52 @@ class ModelTrainer:
         return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
 
 if __name__ == "__main__":
-    # 加载配置文件
-    config = load_config(f"../conf/conf_train.yaml")
-
-    # 设置环境变量
-    """
-        # 多机多卡
-        # export RANK=0  # 第一台机器的 rank
-        # export WORLD_SIZE=4  # 总共有 4 台机器
-        # export MASTER_ADDR=<主节点 IP>
-        # export MASTER_PORT=12345
-
-    """
-    # 单机多卡
-    os.environ['RANK'] = '0'
-    os.environ['WORLD_SIZE'] = '1'
-    os.environ['MASTER_ADDR'] = 'localhost'
-    os.environ['MASTER_PORT'] = '12345'
-
-    # 根据操作系统选择后端
-    backend = 'gloo' if os.name == 'nt' else 'nccl'
-
-    # 使用文件初始化方法  2025-3-11 成功验证支持windows
-    init_method = f'env://'  # env://  # 文件路径需要所有进程都能访问
-    dist.init_process_group(backend=backend, init_method=init_method)
-
-    print(f"Initialized distributed training with backend: {backend}")
-
-    # 初始化 ModelTrainer
-    trainer = ModelTrainer(config)
-
-    # 加载模型和分词器
-    model, tokenizer = trainer.load_model()
-
-    # 加载数据集
-    train_dataset = trainer.load_data(config.train_data_path)
-
-    # 训练模型
-    model = trainer.train(model, tokenizer, train_dataset)
-
-    # 保存模型
-    trainer.save_model(model, tokenizer, config.save_path)
-
-    # 确保进程组被销毁
-    if dist.is_initialized():
-        dist.destroy_process_group()
-    print("Training completed.")
+    try:
+        # 加载配置文件
+        config = load_config(f"../conf/conf_train.yaml")
+
+        # 设置环境变量
+        """
+            # 多机多卡
+            # export RANK=0  # 第一台机器的 rank
+            # export WORLD_SIZE=4  # 总共有 4 台机器
+            # export MASTER_ADDR=<主节点 IP>
+            # export MASTER_PORT=12345
+
+            # 单机多卡
+            os.environ['RANK'] = '0'
+            os.environ['WORLD_SIZE'] = '1'
+            os.environ['MASTER_ADDR'] = 'localhost'
+            os.environ['MASTER_PORT'] = '12345'
+
+            # 根据操作系统选择后端
+            backend = 'gloo' if os.name == 'nt' else 'nccl'
+            # 使用文件初始化方法  2025-3-11 成功验证支持windows
+            init_method = f'env://'  # env://  # 文件路径需要所有进程都能访问
+            dist.init_process_group(backend=backend, init_method=init_method)
+            print(f"Initialized distributed training with backend: {backend}")
+        """
+
+        # 初始化 ModelTrainer
+        trainer = ModelTrainer(config)
+
+        # 加载模型和分词器
+        model, tokenizer = trainer.load_model()
+
+        # 加载数据集
+        train_dataset = trainer.load_data(config.train_data_path)
+
+        # 训练模型
+        model = trainer.train(model, tokenizer, train_dataset)
+
+        # 保存模型
+        trainer.save_model(model, tokenizer, config.save_path)
+        
+        print("Training completed.")
+    except Exception as e:
+        print("exception \n ",e)
+    finally:
+        # # 确保进程组被销毁
+        # if dist.is_initialized():
+        #     dist.destroy_process_group()
+        print("end")