Przeglądaj źródła

遵循面向对象思想重构train_model_grpo.py

zhouyang.xie 4 miesięcy temu
rodzic
commit
6529570532
1 zmieniony plik z 4 dodaniami i 3 usunięciami
  1. 4 3
      src/train_model_grpo_v1.1.py

+ 4 - 3
src/train_model_grpo_v1.1.py

@@ -9,7 +9,8 @@ from trl import GRPOConfig, GRPOTrainer
 from datasets import load_dataset
 from conf_train import Config, load_config  # 导入配置文件
 import re
-from transformers import BertTokenizer, BertModel
+# from transformers import BertTokenizer, BertModel
+from transformers import LongformerTokenizer, LongformerModel
 import numpy as np
 
 class ModelTrainer:
@@ -27,8 +28,8 @@ class ModelTrainer:
         self.lora_rank = config.lora_rank
         self.gpu_memory_utilization = config.gpu_memory_utilization
         # 初始化 BERT 模型和分词器
-        self.tokenizer = BertTokenizer.from_pretrained(f'../models/allenai/longformer-base-4096')
-        self.bert_model = BertModel.from_pretrained(f'../models/allenai/longformer-base-4096')
+        self.tokenizer = LongformerTokenizer.from_pretrained(f'../models/allenai/longformer-base-4096')
+        self.bert_model = LongformerModel.from_pretrained(f'../models/allenai/longformer-base-4096')
 
     def load_model(self):
         """