瀏覽代碼

换用github jwjohns/unsloth-GRPO-qwen2.5 验证GRPO训练模型

zhouyang.xie 4 月之前
父節點
當前提交
8603d51a1c
共有 3 個文件被更改,包括 4 次插入3 次删除
  1. 2 2
      conf/conf_train.yaml
  2. 2 1
      src/model_downloader.py
  3. 0 0
      src/qwen_notebook_clone.py

+ 2 - 2
conf/conf_train.yaml

@@ -1,11 +1,11 @@
 # 模型配置
 model_name: "../models/pretrained/DeepSeek-R1-Distill-Qwen-1.5B"
-max_seq_length: 768  # 2048 单次会话的最大 token 长度
+max_seq_length: 512  # 2048 单次会话的最大 token 长度
 dtype: "float16"  # 数据类型,可选 "float16" 或 "bfloat16"
 load_in_4bit: True  # 是否以4位精度加载模型
 fast_inference: False # Enable vLLM fast inference
 lora_rank: 128  # LoRA 的 rank 值 Choose any number>0!suggested 8,16,32,64,128
-gpu_memory_utilization: 0.6 # GPU VRAM 占用率
+gpu_memory_utilization: 0.85 # GPU VRAM 占用率
 
 # 训练配置
 use_vllm: False # use vLLM for fast inference!

+ 2 - 1
src/model_downloader.py

@@ -11,7 +11,8 @@ from modelscope import snapshot_download
 # model_dir = snapshot_download('deepseek-ai/Janus-Pro-7B', cache_dir="../models/")
 
 # model_dir = snapshot_download('AI-ModelScope/bert-base-uncased', cache_dir="../models/")
-model_dir = snapshot_download('allenai/longformer-base-4096', cache_dir="../models/")
+# model_dir = snapshot_download('allenai/longformer-base-4096', cache_dir="../models/")
+model_dir = snapshot_download('Qwen/Qwen2.5-3B-Instruct', cache_dir="../models/")
 
 # 验证SDK token
 # 据模型源上传人说,模型支持华为 昇腾(Ascend) 910

+ 0 - 0
src/qwen_notebook_clone.py