瀏覽代碼

遵循面向对象思想重构train_model_grpo.py

zhouyang.xie 2 月之前
父節點
當前提交
cd5cc43486
共有 2 個文件被更改,包括 5 次插入3 次删除
  1. 3 1
      src/model_downloader.py
  2. 2 2
      src/train_model_grpo_v1.1.py

+ 3 - 1
src/model_downloader.py

@@ -6,10 +6,12 @@ from modelscope import snapshot_download
 # model_dir = snapshot_download('deepseek-ai/DeepSeek-R1-Distill-Llama-70B', cache_dir="../models/")
 # model_dir = snapshot_download('deepseek-ai/DeepSeek-R1-Distill-Qwen-32B', cache_dir="../models/")
 # model_dir = snapshot_download('deepseek-ai/DeepSeek-R1-Distill-Llama-8B', cache_dir="../models/")
-model_dir = snapshot_download('deepseek-ai/DeepSeek-R1-Distill-Qwen-7B', cache_dir="../models/")
+# model_dir = snapshot_download('deepseek-ai/DeepSeek-R1-Distill-Qwen-7B', cache_dir="../models/")
 # model_dir = snapshot_download('deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', cache_dir="../models/")
 # model_dir = snapshot_download('deepseek-ai/Janus-Pro-7B', cache_dir="../models/")
 
+model_dir = snapshot_download('AI-ModelScope/bert-base-uncased', cache_dir="../models/")
+
 # 验证SDK token
 # 据模型源上传人说,模型支持华为 昇腾(Ascend) 910
 # from modelscope.hub.api import HubApi

+ 2 - 2
src/train_model_grpo_v1.1.py

@@ -27,8 +27,8 @@ class ModelTrainer:
         self.lora_rank = config.lora_rank
         self.gpu_memory_utilization = config.gpu_memory_utilization
         # 初始化 BERT 模型和分词器
-        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
-        self.bert_model = BertModel.from_pretrained('bert-base-uncased')
+        self.tokenizer = BertTokenizer.from_pretrained(f'../models/AI-ModelScope/bert-base-uncased')
+        self.bert_model = BertModel.from_pretrained(f'../models/AI-ModelScope/bert-base-uncased')
 
     def load_model(self):
         """