Quellcode durchsuchen

修改train_model_grpo.py代码,验证GRPO训练模型,使用显卡共享内存 pin_memory=True 验证失败——内存

zhouyang.xie vor 3 Monaten
Ursprung
Commit
e67953b407
1 geänderte Dateien mit 11 neuen und 3 gelöschten Zeilen
  1. 11 3
      src/train_model_grpo.py

+ 11 - 3
src/train_model_grpo.py

@@ -122,6 +122,14 @@ class ModelTrainer:
             gpu_memory_utilization=0.6, # 0.6 # Reduce if out of memory
         )
         
+        # 将模型移动到设备上
+        model = model.to_empty(device='cuda')  # 使用 to_empty 而不是 to
+
+        # 初始化模型的权重
+        for param in model.parameters():
+            if param.is_meta:
+                param.data = torch.randn_like(param)  # 随机初始化
+
         # 添加 LoRA 适配器
         model = FastLanguageModel.get_peft_model(
             model,
@@ -144,9 +152,9 @@ class ModelTrainer:
         # 加载训练集和测试集
         train_dataset = load_dataset("json", data_files={"train": train_data_path}, split="train")
 
-        # train_loader = torch.utils.data.DataLoader(
-        #     train_dataset, batch_size=1, shuffle=True, pin_memory=True  # 启用 pin_memory  2025年3月7日未能验证通过
-        # )
+        train_loader = torch.utils.data.DataLoader(
+            train_dataset, batch_size=1, shuffle=True, pin_memory=True  # 启用 pin_memory  2025年3月7日未能验证通过
+        )
 
         # train_data_path: 训练数据路径,格式为 JSONL
         return train_dataset