Przeglądaj źródła

遵循面向对象思想重构train_model_grpo.py

zhouyang.xie 2 miesięcy temu
rodzic
commit
95988ed459
1 zmienionych plików z 103 dodań i 8 usunięć
  1. 103 8
      src/train_model_grpo_v1.1.py

+ 103 - 8
src/train_model_grpo_v1.1.py

@@ -9,6 +9,8 @@ from trl import GRPOConfig, GRPOTrainer
 from datasets import load_dataset
 from conf_train import Config, load_config  # 导入配置文件
 import re
+from transformers import BertTokenizer, BertModel
+import numpy as np
 
 class ModelTrainer:
     def __init__(self, config: Config):
@@ -24,6 +26,9 @@ class ModelTrainer:
         self.fast_inference = config.fast_inference
         self.lora_rank = config.lora_rank
         self.gpu_memory_utilization = config.gpu_memory_utilization
+        # 初始化 BERT 模型和分词器
+        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+        self.bert_model = BertModel.from_pretrained('bert-base-uncased')
 
     def load_model(self):
         """
@@ -122,11 +127,15 @@ class ModelTrainer:
             model=model,
             processing_class=tokenizer,
             reward_funcs=[
-                self.xmlcount_reward_func,
-                self.soft_format_reward_func,
+                # self.xmlcount_reward_func,
+                # self.soft_format_reward_func,
+                # self.strict_format_reward_func,
+                # self.int_reward_func,
+                # self.correctness_reward_func,
                 self.strict_format_reward_func,
-                self.int_reward_func,
-                self.correctness_reward_func,
+                self.semantic_correctness_reward_func,
+                self.reasoning_quality_reward_func,
+                self.combined_reward_func,
             ],
             args=training_args,
             train_dataset=train_dataset,
@@ -145,6 +154,80 @@ class ModelTrainer:
         model.save_pretrained(save_path)
         tokenizer.save_pretrained(save_path)
         print(f"Model saved to {save_path}")
+    
+    @staticmethod
+    def cosine_similarity(vec1, vec2):
+        """
+        计算两个向量的余弦相似度。
+        """
+        return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
+
+    def semantic_correctness_reward_func(self, prompts, completions, answer, **kwargs):
+        """
+        使用 BERT 计算生成答案与标准答案的语义相似度。
+        :param prompts: 输入提示
+        :param completions: 模型生成的补全内容
+        :param answer: 标准答案
+        :return: 语义相似度得分列表
+        """
+        responses = [completion[0]['content'] for completion in completions]
+        extracted_responses = [self.extract_xml_answer(r) for r in responses]
+        scores = []
+        for resp, ans in zip(extracted_responses, answer):
+            # 编码生成答案和标准答案
+            inputs_resp = self.tokenizer(resp, return_tensors='pt', padding=True, truncation=True)
+            inputs_ans = self.tokenizer(ans, return_tensors='pt', padding=True, truncation=True)
+            with torch.no_grad():
+                outputs_resp = self.bert_model(**inputs_resp).last_hidden_state.mean(dim=1)
+                outputs_ans = self.bert_model(**inputs_ans).last_hidden_state.mean(dim=1)
+            # 计算余弦相似度
+            similarity = self.cosine_similarity(outputs_resp.numpy(), outputs_ans.numpy())
+            scores.append(similarity)
+        return scores
+    
+    def combined_reward_func(self, prompts, completions, answer, **kwargs):
+        """
+        综合多个奖励函数,动态调整权重。
+        :param prompts: 输入提示
+        :param completions: 模型生成的补全内容
+        :param answer: 标准答案
+        :return: 综合得分列表
+        """
+        # 计算各奖励函数的得分
+        format_score = self.strict_format_reward_func(completions)
+        semantic_score = self.semantic_correctness_reward_func(prompts, completions, answer)
+        correctness_score = self.correctness_reward_func(prompts, completions, answer)
+
+        # 动态调整权重
+        combined_scores = []
+        for fs, ss, cs in zip(format_score, semantic_score, correctness_score):
+            if cs == 2.0:  # 答案完全正确
+                combined_scores.append(fs * 0.2 + ss * 0.3 + cs * 0.5)
+            else:  # 答案不完全正确
+                combined_scores.append(fs * 0.4 + ss * 0.4 + cs * 0.2)
+        return combined_scores
+    
+    @staticmethod
+    def reasoning_quality_reward_func(completions, **kwargs):
+        """
+        检查推理过程的质量。
+        :param completions: 模型生成的补全内容
+        :return: 推理过程质量得分列表
+        """
+        responses = [completion[0]["content"] for completion in completions]
+        scores = []
+        for response in responses:
+            reasoning_match = re.search(r"<reasoning>\n(.+?)\n</reasoning>", response, re.DOTALL)
+            if reasoning_match:
+                reasoning_content = reasoning_match.group(1).strip()
+                # 简单检查推理内容是否包含关键词
+                if "诊断" in reasoning_content and "原因" in reasoning_content:
+                    scores.append(1.0)
+                else:
+                    scores.append(0.5)
+            else:
+                scores.append(0.0)
+        return scores
 
     @staticmethod
     def extract_xml_answer(text: str) -> str:
@@ -202,14 +285,26 @@ class ModelTrainer:
     @staticmethod
     def strict_format_reward_func(completions, **kwargs):
         """
-        检查补全内容是否符合严格格式要求
+        检查响应是否符合严格的 XML 格式要求,并确保标签内容非空
         :param completions: 模型生成的补全内容
         :return: 符合严格格式要求的得分列表
         """
-        pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
+        pattern = r"^<reasoning>\n(.+?)\n</reasoning>\n<answer>\n(.+?)\n</answer>\n$"
         responses = [completion[0]["content"] for completion in completions]
-        matches = [re.match(pattern, r) for r in responses]
-        return [0.5 if match else 0.0 for match in matches]
+        scores = []
+        for response in responses:
+            match = re.match(pattern, response, re.DOTALL)
+            if match:
+                reasoning_content = match.group(1).strip()
+                answer_content = match.group(2).strip()
+                # 检查内容是否非空
+                if reasoning_content and answer_content:
+                    scores.append(1.0)  # 格式和内容均符合要求
+                else:
+                    scores.append(0.5)  # 格式符合但内容为空
+            else:
+                scores.append(0.0)  # 格式不符合
+        return scores
 
     @staticmethod
     def int_reward_func(completions, **kwargs):