ソースを参照

遵循面向对象思想重构train_model_grpo.py

zhouyang.xie 2 ヶ月 前
コミット
77168b22a2
3 ファイル変更6 行追加6 行削除
  1. 2 2
      conf/conf_train.yaml
  2. 0 0
      src/train_model_grpo_original.py
  3. 4 4
      src/train_model_grpo_v1.1.py

+ 2 - 2
conf/conf_train.yaml

@@ -1,6 +1,6 @@
 # 模型配置
 model_name: "../models/pretrained/DeepSeek-R1-Distill-Qwen-1.5B"
-max_seq_length: 6144  # 单次会话的最大 token 长度
+max_seq_length: 6144  # 2048 单次会话的最大 token 长度
 dtype: "float16"  # 数据类型,可选 "float16" 或 "bfloat16"
 load_in_4bit: True  # 是否以4位精度加载模型
 fast_inference: False # Enable vLLM fast inference
@@ -19,7 +19,7 @@ optim: "adamw_8bit"  # 优化器类型
 logging_steps: 1  # 日志记录步数
 per_device_train_batch_size: 1  # 每个设备的训练批次大小
 gradient_accumulation_steps: 1  # 梯度累积步数
-num_generations: 8  # 每次生成的输出个数
+num_generations: 8  # 8 每次生成的输出个数
 max_prompt_length: 256  # 输入提示的最大长度
 max_completion_length: 200  # 生成内容的最大长度
 num_train_epochs: 1  # 训练轮数

+ 0 - 0
src/train_model_grpo.py → src/train_model_grpo_original.py


+ 4 - 4
src/train_model_grpo_v1.1.py

@@ -9,8 +9,8 @@ from trl import GRPOConfig, GRPOTrainer
 from datasets import load_dataset
 from conf_train import Config, load_config  # 导入配置文件
 import re
-# from transformers import BertTokenizer, BertModel
-from transformers import LongformerTokenizer, LongformerModel
+# from transformers import BertTokenizer, BertModel  # 分词模型最大支持 512 个token
+from transformers import LongformerTokenizer, LongformerModel # 分词模型最大支持 4096 个token
 import numpy as np
 
 class ModelTrainer:
@@ -128,8 +128,8 @@ class ModelTrainer:
             model=model,
             processing_class=tokenizer,
             reward_funcs=[
-                # self.xmlcount_reward_func,
-                # self.soft_format_reward_func,
+                self.xmlcount_reward_func,
+                self.soft_format_reward_func,
                 # self.strict_format_reward_func,
                 # self.int_reward_func,
                 # self.correctness_reward_func,