瀏覽代碼

遵循面向对象思想重构train_model_grpo.py

zhouyang.xie 4 月之前
父節點
當前提交
d97b2ed3a6
共有 1 個文件被更改,包括 11 次插入7 次删除
  1. 11 7
      src/train_model_grpo_v1.1.py

+ 11 - 7
src/train_model_grpo_v1.1.py

@@ -159,16 +159,20 @@ class ModelTrainer:
     def cosine_similarity(vec1, vec2):
         """
         计算两个向量的余弦相似度。
-        """
+        :param vec1: 第一个向量,形状为 (1, 768)
+        :param vec2: 第二个向量,形状为 (1, 768)
+        :return: 余弦相似度
+        """
+        # 将 (1, 768) 的矩阵转换为 (768,) 的一维向量
+        vec1 = vec1.squeeze()  # 形状从 (1, 768) 变为 (768,)
+        vec2 = vec2.squeeze()  # 形状从 (1, 768) 变为 (768,)
+        print(f"vec1 shape: {vec1.shape}, vec2 shape: {vec2.shape}")
+        # 计算余弦相似度
         return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
 
     def semantic_correctness_reward_func(self, prompts, completions, answer, **kwargs):
         """
         使用 BERT 计算生成答案与标准答案的语义相似度。
-        :param prompts: 输入提示
-        :param completions: 模型生成的补全内容
-        :param answer: 标准答案
-        :return: 语义相似度得分列表
         """
         responses = [completion[0]['content'] for completion in completions]
         extracted_responses = [self.extract_xml_answer(r) for r in responses]
@@ -178,8 +182,8 @@ class ModelTrainer:
             inputs_resp = self.tokenizer(resp, return_tensors='pt', padding=True, truncation=True)
             inputs_ans = self.tokenizer(ans, return_tensors='pt', padding=True, truncation=True)
             with torch.no_grad():
-                outputs_resp = self.bert_model(**inputs_resp).last_hidden_state.mean(dim=1)
-                outputs_ans = self.bert_model(**inputs_ans).last_hidden_state.mean(dim=1)
+                outputs_resp = self.bert_model(**inputs_resp).last_hidden_state.mean(dim=1)  # 形状为 (1, 768)
+                outputs_ans = self.bert_model(**inputs_ans).last_hidden_state.mean(dim=1)  # 形状为 (1, 768)
             # 计算余弦相似度
             similarity = self.cosine_similarity(outputs_resp.numpy(), outputs_ans.numpy())
             scores.append(similarity)