
Modify the train_model_grpo.py code to try using the GPU's shared memory

zhouyang.xie 3 months ago
parent commit 1d5a7faf12
1 changed file with 17 additions and 4 deletions

+ 17 - 4
src/train_model_grpo.py

@@ -117,9 +117,9 @@ class ModelTrainer:
             max_seq_length=self.max_seq_length,
             load_in_4bit=self.load_in_4bit, # True: fine-tune with 4-bit quantization; False: LoRA in 16-bit. 4-bit cuts memory use roughly 4x, which makes fine-tuning feasible on a free 16GB GPU. Quantization essentially maps the weights onto a limited set of values to save memory, at the cost of about 1-2% accuracy. If you want that small extra accuracy, set this to False on a larger GPU such as an H100.
             dtype=self.dtype,
-            fast_inference = False, # Enable vLLM fast inference
+            fast_inference = True, # Enable vLLM fast inference
             max_lora_rank = lora_rank,
-            gpu_memory_utilization=0.1, # 0.6 # Reduce if out of memory
+            gpu_memory_utilization=0.6, # fraction of GPU memory vLLM may use; reduce if out of memory
         )
         
         # Add the LoRA adapter
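
The two + lines above hand generation over to vLLM: fast_inference=True switches unsloth to its vLLM-backed inference path, and gpu_memory_utilization=0.6 lets vLLM pre-allocate about 60% of VRAM (mostly for its KV cache) instead of the 10% tried before. A minimal standalone sketch of the same load call; the model name, sequence length, and rank below are placeholders, since the real script reads them from self:

    from unsloth import FastLanguageModel

    # Placeholder values for illustration; the real code takes these from self.*
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="Qwen/Qwen2.5-1.5B-Instruct",  # hypothetical model name
        max_seq_length=1024,
        load_in_4bit=True,
        fast_inference=True,           # vLLM-backed fast generation
        max_lora_rank=32,
        gpu_memory_utilization=0.6,    # fraction of VRAM vLLM reserves; lower on OOM
    )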
@@ -143,13 +143,26 @@ class ModelTrainer:
     def load_data(self, train_data_path):
         # Load the training set
         train_dataset = load_dataset("json", data_files={"train": train_data_path}, split="train")
+
+        train_loader = torch.utils.data.DataLoader(
+            train_dataset, batch_size=1, shuffle=True, pin_memory=True  # enable pin_memory (page-locked host RAM)
+        )
+
+        # train_data_path: path to the training data, in JSONL format
         return train_dataset
 
     def train(self, model, tokenizer, train_dataset):
-        print("is_bfloat16_supported()=",is_bfloat16_supported())
+        print("is_bfloat16_supported()=", is_bfloat16_supported())
+        
+        # Monitor GPU memory usage
+        print(f"Reserved memory: {torch.cuda.memory_reserved()}")
+        print(f"Allocated memory: {torch.cuda.memory_allocated()}")
+        
+        # Free unused cached GPU memory
+        torch.cuda.empty_cache()
+
         training_args = GRPOConfig(
-            use_vllm = False, # use vLLM for fast inference!
+            use_vllm = True, # use vLLM for fast inference!
             learning_rate = 5e-6,
             adam_beta1 = 0.9,
             adam_beta2 = 0.99,
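
On the pinned-memory change in load_data: pin_memory=True makes the DataLoader return batches in page-locked host RAM, which is what allows fast asynchronous host-to-GPU copies; this is host memory, not the GPU's on-chip shared memory. Note that train_loader is built but never returned, so nothing downstream consumes it yet. A hedged sketch of how pinned batches are normally moved to the device (the tensor dataset below is illustrative, not from this repo):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Illustrative data; the real input is tokenized JSONL records
    dataset = TensorDataset(torch.randn(8, 16))
    loader = DataLoader(dataset, batch_size=1, shuffle=True, pin_memory=True)

    for (batch,) in loader:
        # non_blocking=True overlaps the copy with compute only when the source is pinned
        batch = batch.cuda(non_blocking=True)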