Преглед на файлове

修改 train_model_grpo.py 文件：开启 vLLM，观察能否解决损失值始终为 0 且无变化的问题

zhouyang.xie преди 3 месеца
родител
ревизия
ce572ff677
променени са 1 файла, в които са добавени 21 реда и са изтрити 13 реда
  1. +21 -13
      src/train_model_grpo.py

+ 21 - 13
src/train_model_grpo.py

@@ -1,5 +1,6 @@
 import os
 import torch
+import torch.distributed as dist
 from unsloth import FastLanguageModel
 from unsloth import is_bfloat16_supported
 from trl import SFTTrainer, GRPOConfig, GRPOTrainer
@@ -236,18 +237,25 @@ if __name__ == "__main__":
     train_data_path = os.path.join('..', 'data', 'processed', 'train.jsonl')
     # train_data_path: 训练数据路径
 
-    # 初始化 ModelTrainer
-    trainer = ModelTrainer(model_name, max_seq_length, dtype, load_in_4bit,lora_rank)
-    
-    # 加载模型和分词器
-    model, tokenizer = trainer.load_model()
-
-    # 加载数据集
-    train_dataset = trainer.load_data(train_data_path)
+    try:
+        # 初始化进程组
+        dist.init_process_group(backend='nccl', init_method='env://')
+        # 初始化 ModelTrainer
+        trainer = ModelTrainer(model_name, max_seq_length, dtype, load_in_4bit,lora_rank)
+        
+        # 加载模型和分词器
+        model, tokenizer = trainer.load_model()
 
-    # 训练模型
-    model = trainer.train(model, tokenizer, train_dataset)
+        # 加载数据集
+        train_dataset = trainer.load_data(train_data_path)
 
-    # 保存模型
-    save_path = os.path.join('..', 'models', 'trained', 'DeepSeek-R1-Distill-Qwen-1.5B-GRPO')
-    trainer.save_model(model, tokenizer, save_path)
+        # 训练模型
+        model = trainer.train(model, tokenizer, train_dataset)
+
+        # 保存模型
+        save_path = os.path.join('..', 'models', 'trained', 'DeepSeek-R1-Distill-Qwen-1.5B-GRPO')
+        trainer.save_model(model, tokenizer, save_path)
+    finally:
+        # 确保进程组被销毁
+        if dist.is_initialized():
+            dist.destroy_process_group()