|
@@ -9,7 +9,8 @@ from trl import GRPOConfig, GRPOTrainer
|
|
|
from datasets import load_dataset
|
|
|
from conf_train import Config, load_config # 导入配置文件
|
|
|
import re
|
|
|
-from transformers import BertTokenizer, BertModel
|
|
|
+# from transformers import BertTokenizer, BertModel
|
|
|
+from transformers import LongformerTokenizer, LongformerModel
|
|
|
import numpy as np
|
|
|
|
|
|
class ModelTrainer:
|
|
@@ -27,8 +28,8 @@ class ModelTrainer:
|
|
|
self.lora_rank = config.lora_rank
|
|
|
self.gpu_memory_utilization = config.gpu_memory_utilization
|
|
|
# 初始化 BERT 模型和分词器
|
|
|
- self.tokenizer = BertTokenizer.from_pretrained(f'../models/allenai/longformer-base-4096')
|
|
|
- self.bert_model = BertModel.from_pretrained(f'../models/allenai/longformer-base-4096')
|
|
|
+ self.tokenizer = LongformerTokenizer.from_pretrained(f'../models/allenai/longformer-base-4096')
|
|
|
+ self.bert_model = LongformerModel.from_pretrained(f'../models/allenai/longformer-base-4096')
|
|
|
|
|
|
def load_model(self):
|
|
|
"""
|