ソースを参照

修改train_model_grpo.py文件-开启vLLM 观察能否解决损失率值0并且无变化问题

zhouyang.xie 3 ヶ月 前
コミット
e8e87e7975
1 ファイル変更12 行追加0 行削除
  1. 12 0
      src/train_model_grpo.py

+ 12 - 0
src/train_model_grpo.py

@@ -238,6 +238,18 @@ if __name__ == "__main__":
     # train_data_path: 训练数据路径
 
     try:
+        # 设置环境变量
+        # 单机多卡
+        os.environ['RANK'] = '0' # 第一张卡的 rank
+        os.environ['WORLD_SIZE'] = '1'  # 总共有 1 张卡
+        os.environ['MASTER_ADDR'] = 'localhost'
+        os.environ['MASTER_PORT'] = '12345'
+        # 多机多卡
+        # export RANK=0  # 第一台机器的 rank
+        # export WORLD_SIZE=4  # 总共有 4 台机器
+        # export MASTER_ADDR=<主节点 IP>
+        # export MASTER_PORT=12345
+
         # 初始化进程组
         dist.init_process_group(backend='nccl', init_method='env://')
         # 初始化 ModelTrainer