|
@@ -9,8 +9,8 @@ from trl import GRPOConfig, GRPOTrainer
|
|
|
from datasets import load_dataset
|
|
|
from conf_train import Config, load_config # 导入配置文件
|
|
|
import re
|
|
|
-# from transformers import BertTokenizer, BertModel
|
|
|
-from transformers import LongformerTokenizer, LongformerModel
|
|
|
+# from transformers import BertTokenizer, BertModel # BERT tokenizer/model supports at most 512 tokens
|
|
|
+from transformers import LongformerTokenizer, LongformerModel # Longformer tokenizer/model supports at most 4096 tokens
|
|
|
import numpy as np
|
|
|
|
|
|
class ModelTrainer:
|
|
@@ -128,8 +128,8 @@ class ModelTrainer:
|
|
|
model=model,
|
|
|
processing_class=tokenizer,
|
|
|
reward_funcs=[
|
|
|
- # self.xmlcount_reward_func,
|
|
|
- # self.soft_format_reward_func,
|
|
|
+ self.xmlcount_reward_func,
|
|
|
+ self.soft_format_reward_func,
|
|
|
# self.strict_format_reward_func,
|
|
|
# self.int_reward_func,
|
|
|
# self.correctness_reward_func,
|