瀏覽代碼

验证GRPO 训练的模型推理

zhouyang.xie 3 月之前
父節點
當前提交
d3713ffe46
共有 2 個文件被更改,包括 5 次插入5 次删除
  1. 1 1
      src/inference.py
  2. 4 4
      src/train_model_grpo.py

+ 1 - 1
src/inference.py

@@ -49,7 +49,7 @@ class ModelInference:
 
 if __name__ == "__main__":
     # 配置参数
-    model_path = os.path.join('..', 'models', 'trained', 'DeepSeek-R1-Distill-Qwen-1.5B')
+    model_path = os.path.join('..', 'models', 'trained', 'DeepSeek-R1-Distill-Qwen-1.5B-GRPO')
     max_seq_length = 2048
     dtype = torch.float16
     load_in_4bit = True

+ 4 - 4
src/train_model_grpo.py

@@ -162,9 +162,9 @@ class ModelTrainer:
             fp16 = not is_bfloat16_supported(),
             per_device_train_batch_size = 1,
             gradient_accumulation_steps = 1, # Increase to 4 for smoother training
-            num_generations = 4, # Decrease if out of memory
-            max_prompt_length = 256,
-            max_completion_length = 200,
+            num_generations = 4, # 每次生成 4 个输出
+            max_prompt_length = 256, # 输入提示的最大长度
+            max_completion_length = 200, # 生成内容的最大长度
             # num_train_epochs = 1, # Set to 1 for a full training run
             max_steps = 20,  # 250
             save_steps = 20, # 250
@@ -203,7 +203,7 @@ if __name__ == "__main__":
     # 配置参数
     model_name = os.path.join('..', 'models', 'pretrained', 'DeepSeek-R1-Distill-Qwen-1.5B')
     # model_name: 预训练模型的路径
-    max_seq_length = 2048  # 最大序列长度
+    max_seq_length = 6000  # 单次会话(single session) 的最大 token 长度,一个token大约3-4 字节(Byte)
     dtype = torch.float16  # 数据类型
     load_in_4bit = True  # 是否以4位精度加载模型
     lora_rank=16