Jelajahi Sumber

遵循面向对象思想重构train_model_grpo.py

zhouyang.xie 3 bulan lalu
induk
komit
71cb307569
2 file diubah dengan 14 tambahan dan 18 penghapusan
  1. 1 1
      src/model_downloader.py
  2. 13 17
      src/train_model_grpo_v1.1.py

+ 1 - 1
src/model_downloader.py

@@ -11,7 +11,7 @@ from modelscope import snapshot_download
 # model_dir = snapshot_download('deepseek-ai/Janus-Pro-7B', cache_dir="../models/")
 
 # model_dir = snapshot_download('AI-ModelScope/bert-base-uncased', cache_dir="../models/")
-model_dir = snapshot_download('allenai/longformer-base-4096', cache_dir="../models/)
+model_dir = snapshot_download('allenai/longformer-base-4096', cache_dir="../models/")
 
 # 验证SDK token
 # 据模型源上传人说,模型支持华为 昇腾(Ascend) 910

+ 13 - 17
src/train_model_grpo_v1.1.py

@@ -172,28 +172,24 @@ class ModelTrainer:
 
     def semantic_correctness_reward_func(self, prompts, completions, answer, **kwargs):
         """
-        使用 BERT 计算生成答案与标准答案的语义相似度。
+        使用 Longformer 计算生成答案与标准答案的语义相似度。
         """
         responses = [completion[0]['content'] for completion in completions]
         extracted_responses = [self.extract_xml_answer(r) for r in responses]
         scores = []
         for resp, ans in zip(extracted_responses, answer):
-            # 分块处理长文本
-            resp_chunks = [resp[i:i + 500] for i in range(0, len(resp), 500)]  # 每块 500 个字符
-            ans_chunks = [ans[i:i + 500] for i in range(0, len(ans), 500)]  # 每块 500 个字符
-            chunk_similarities = []
-            for resp_chunk, ans_chunk in zip(resp_chunks, ans_chunks):
-                # 编码生成答案和标准答案
-                inputs_resp = self.tokenizer(resp_chunk, return_tensors='pt', padding=True, truncation=True, max_length=512)
-                inputs_ans = self.tokenizer(ans_chunk, return_tensors='pt', padding=True, truncation=True, max_length=512)
-                with torch.no_grad():
-                    outputs_resp = self.bert_model(**inputs_resp).last_hidden_state.mean(dim=1)  # 形状为 (1, 768)
-                    outputs_ans = self.bert_model(**inputs_ans).last_hidden_state.mean(dim=1)  # 形状为 (1, 768)
-                # 计算余弦相似度
-                similarity = self.cosine_similarity(outputs_resp.numpy(), outputs_ans.numpy())
-                chunk_similarities.append(similarity)
-            # 取所有块的平均相似度
-            scores.append(np.mean(chunk_similarities))
+            # 截断文本,确保长度不超过 4096
+            resp = self.tokenizer.decode(self.tokenizer.encode(resp, truncation=True, max_length=4096))
+            ans = self.tokenizer.decode(self.tokenizer.encode(ans, truncation=True, max_length=4096))
+            # 编码生成答案和标准答案
+            inputs_resp = self.tokenizer(resp, return_tensors='pt', padding=True, truncation=True, max_length=4096)
+            inputs_ans = self.tokenizer(ans, return_tensors='pt', padding=True, truncation=True, max_length=4096)
+            with torch.no_grad():
+                outputs_resp = self.longformer_model(**inputs_resp).last_hidden_state.mean(dim=1)  # 形状为 (1, 768)
+                outputs_ans = self.longformer_model(**inputs_ans).last_hidden_state.mean(dim=1)  # 形状为 (1, 768)
+            # 计算余弦相似度
+            similarity = self.cosine_similarity(outputs_resp.numpy(), outputs_ans.numpy())
+            scores.append(similarity)
         return scores
     
     def combined_reward_func(self, prompts, completions, answer, **kwargs):