Jelajahi Sumber

修改train_model_grpo.py文件-开启vLLM 观察能否解决损失率值0并且无变化问题

zhouyang.xie 3 bulan lalu
induk
melakukan
a8baf58623
1 mengubah file dengan 9 tambahan dan 8 penghapusan
  1. 9 8
      src/train_model_grpo.py

+ 9 - 8
src/train_model_grpo.py

@@ -241,10 +241,10 @@ if __name__ == "__main__":
     try:
         # 设置环境变量
         # 单机多卡
-        os.environ['RANK'] = '0' # 第一张卡的 rank
-        os.environ['WORLD_SIZE'] = '1'  # 总共有 1 张卡
-        os.environ['MASTER_ADDR'] = 'localhost'
-        os.environ['MASTER_PORT'] = '12345'
+        # os.environ['RANK'] = '0' # 第一张卡的 rank
+        # os.environ['WORLD_SIZE'] = '1'  # 总共有 1 张卡
+        # os.environ['MASTER_ADDR'] = 'localhost'
+        # os.environ['MASTER_PORT'] = '12345'
         # 多机多卡
         # export RANK=0  # 第一台机器的 rank
         # export WORLD_SIZE=4  # 总共有 4 台机器
@@ -252,7 +252,7 @@ if __name__ == "__main__":
         # export MASTER_PORT=12345
 
         # 初始化进程组
-        dist.init_process_group(backend='nccl', init_method='env://')
+        # dist.init_process_group(backend='nccl', init_method='env://')
         # 初始化 ModelTrainer
         trainer = ModelTrainer(model_name, max_seq_length, dtype, load_in_4bit,lora_rank)
         
@@ -269,6 +269,7 @@ if __name__ == "__main__":
         save_path = os.path.join('..', 'models', 'trained', 'DeepSeek-R1-Distill-Qwen-1.5B-GRPO')
         trainer.save_model(model, tokenizer, save_path)
     finally:
-        # 确保进程组被销毁
-        if dist.is_initialized():
-            dist.destroy_process_group()
+        # # 确保进程组被销毁
+        # if dist.is_initialized():
+        #     dist.destroy_process_group()
+        print("train finally")