|
@@ -1,5 +1,6 @@
|
|
|
import os
|
|
|
import torch
|
|
|
+import torch.distributed as dist
|
|
|
from unsloth import FastLanguageModel
|
|
|
from unsloth import is_bfloat16_supported
|
|
|
from trl import SFTTrainer, GRPOConfig, GRPOTrainer
|
|
@@ -236,18 +237,25 @@ if __name__ == "__main__":
|
|
|
train_data_path = os.path.join('..', 'data', 'processed', 'train.jsonl')
|
|
|
# train_data_path: 训练数据路径
|
|
|
|
|
|
- # 初始化 ModelTrainer
|
|
|
- trainer = ModelTrainer(model_name, max_seq_length, dtype, load_in_4bit,lora_rank)
|
|
|
-
|
|
|
- # 加载模型和分词器
|
|
|
- model, tokenizer = trainer.load_model()
|
|
|
-
|
|
|
- # 加载数据集
|
|
|
- train_dataset = trainer.load_data(train_data_path)
|
|
|
+ try:
|
|
|
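+ # Initialize the distributed process group (NCCL backend; rendezvous settings are read from environment variables via env://)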
|
|
|
+ dist.init_process_group(backend='nccl', init_method='env://')
|
|
|
+ # Initialize the ModelTrainer
|
|
|
+ trainer = ModelTrainer(model_name, max_seq_length, dtype, load_in_4bit,lora_rank)
|
|
|
+
|
|
|
+ # Load the model and tokenizer
|
|
|
+ model, tokenizer = trainer.load_model()
|
|
|
|
|
|
- # 训练模型
|
|
|
- model = trainer.train(model, tokenizer, train_dataset)
|
|
|
+ # Load the training dataset
|
|
|
+ train_dataset = trainer.load_data(train_data_path)
|
|
|
|
|
|
- # 保存模型
|
|
|
- save_path = os.path.join('..', 'models', 'trained', 'DeepSeek-R1-Distill-Qwen-1.5B-GRPO')
|
|
|
- trainer.save_model(model, tokenizer, save_path)
|
|
|
+ # Train the model
|
|
|
+ model = trainer.train(model, tokenizer, train_dataset)
|
|
|
+
|
|
|
+ # Save the trained model and tokenizer
|
|
|
+ save_path = os.path.join('..', 'models', 'trained', 'DeepSeek-R1-Distill-Qwen-1.5B-GRPO')
|
|
|
+ trainer.save_model(model, tokenizer, save_path)
|
|
|
+ finally:
|
|
|
+ # Ensure the process group is destroyed even if an earlier step raised
|
|
|
+ if dist.is_initialized():
|
|
|
+ dist.destroy_process_group()
|