
Adjust train_model_grpo.py to increase vLLM's VRAM utilization

zhouyang.xie 4 months ago
parent commit 1cfcf6522b
2 changed files with 4 additions and 4 deletions
  1. README.MD (+1 -1)
  2. src/train_model_grpo.py (+3 -3)

+ 1 - 1
README.MD

@@ -1232,7 +1232,7 @@ logging.basicConfig(level=logging.INFO)
   **Linux environment**
   Installation via pip install triton is supported
   
-5. pip install scikit-learn vllm  (note: vllm is not supported on Windows)
+5. pip install scikit-learn vllm addict  (note: vllm is not supported on Windows)
  
 
 

+ 3 - 3
src/train_model_grpo.py

@@ -119,7 +119,7 @@ class ModelTrainer:
             dtype=self.dtype,
             fast_inference = True, # Enable vLLM fast inference
             max_lora_rank = lora_rank,
-            gpu_memory_utilization=0.6,# Reduce if out of memory
+            gpu_memory_utilization=0.8,# Reduce if out of memory
         )
         
         # Add the LoRA adapter
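
For context, the hunk above sits inside the model-loading call. Below is a minimal sketch of how the changed parameter fits in, assuming the surrounding code uses unsloth's FastLanguageModel.from_pretrained (the fast_inference and max_lora_rank arguments suggest this); only the lines shown in the diff are confirmed, the rest are placeholders taken from elsewhere in this commit:

```python
from unsloth import FastLanguageModel

# Hypothetical reconstruction of the call this hunk edits; values not
# visible in the diff are placeholders from elsewhere in this commit.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="../models/pretrained/DeepSeek-R1-Distill-Qwen-1.5B",
    max_seq_length=1024,
    dtype=None,                   # placeholder; the script passes self.dtype
    load_in_4bit=True,
    fast_inference=True,          # enable vLLM fast inference
    max_lora_rank=32,             # lora_rank in the script
    gpu_memory_utilization=0.8,   # fraction of VRAM vLLM may reserve
)
```

On a 24 GB card, for example, 0.8 reserves roughly 19.2 GB for vLLM's weights and KV cache versus about 14.4 GB at 0.6, which matches the commit's stated goal; the companion changes below shrink the training side to compensate.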
@@ -162,7 +162,7 @@ class ModelTrainer:
             fp16 = not is_bfloat16_supported(),
             per_device_train_batch_size = 1,
             gradient_accumulation_steps = 1, # Increase to 4 for smoother training
-            num_generations = 8, # Decrease if out of memory
+            num_generations = 4, # Decrease if out of memory
             max_prompt_length = 256,
             max_completion_length = 200,
             # num_train_epochs = 1, # Set to 1 for a full training run
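
num_generations is the number of completions GRPO samples per prompt to compute group-relative advantages, so halving it from 8 to 4 roughly halves the per-step generation and KV-cache load, offsetting the larger vLLM reservation above. A minimal sketch of the enclosing config, assuming trl's GRPOConfig (the visible field names match that API); output_dir and other unshown arguments are placeholders:

```python
from trl import GRPOConfig
from unsloth import is_bfloat16_supported

# Sketch of the training config this hunk edits; field names assume trl's
# GRPOConfig, and values not shown in the diff are placeholders.
training_args = GRPOConfig(
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,   # increase to 4 for smoother training
    num_generations=4,               # completions sampled per prompt
    max_prompt_length=256,
    max_completion_length=200,
    output_dir="outputs",            # placeholder; not visible in the diff
)
```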
@@ -203,7 +203,7 @@ if __name__ == "__main__":
     # Configuration parameters
     model_name = os.path.join('..', 'models', 'pretrained', 'DeepSeek-R1-Distill-Qwen-1.5B')
     # model_name: path to the pretrained model
-    max_seq_length = 2048  # maximum sequence length
+    max_seq_length = 1024 # 2048  # maximum sequence length
     dtype = torch.float16  # data type
     load_in_4bit = True  # whether to load the model in 4-bit precision
     lora_rank=32
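
Halving max_seq_length from 2048 to 1024 is consistent with the GRPO limits above: a prompt capped at 256 tokens plus a completion capped at 200 needs at most 456 positions, so 1024 still leaves ample headroom while cutting KV-cache and activation memory. A quick sanity check using only values visible in this commit:

```python
# Sanity check: the reduced sequence budget still covers prompt + completion,
# using only values visible in this commit.
max_seq_length = 1024
max_prompt_length = 256
max_completion_length = 200

needed = max_prompt_length + max_completion_length    # 456 tokens
assert needed <= max_seq_length, "sequence budget too small"
print(f"headroom: {max_seq_length - needed} tokens")  # -> headroom: 568 tokens
```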