Prechádzať zdrojové kódy

修改 风电机组数据集格式

zhouyang.xie pred 2 mesiacmi
rodič
commit
39ae535b0d
2 zmenené súbory: 7 pridaní a 7 odobraní
  1. 2 2
      src/generate_data.py
  2. 5 5
      src/train_model_grpo_v1.1.py

+ 2 - 2
src/generate_data.py

@@ -35,8 +35,8 @@ class DataGenerator:
                 "answer":   "".join([case_data[2],"\n",case_data[3],"\n",case_data[4]]),
                 "prompt": [
                     {
-                        "content": f"\nRespond in the following format:\n<think>\n {case_data[2]} \n</think>\n<answer>\n {case_data[3]}  {case_data[4]} \n</answer>\n",
-                        # "content": f"\nRespond in the following format:\n<think>\n ... \n</think>\n<answer>\n ... \n</answer>\n",
+                        "content": f"\nRespond in the following format:\n<reasoning>\n {case_data[2]} \n</reasoning>\n<answer>\n {case_data[3]}  {case_data[4]} \n</answer>\n",
+                        # "content": f"\nRespond in the following format:\n<reasoning>\n ... \n</reasoning>\n<answer>\n ... \n</answer>\n",
                         "role": "system"
                     },
                     {

+ 5 - 5
src/train_model_grpo_v1.1.py

@@ -252,7 +252,7 @@ class ModelTrainer:
         responses = [completion[0]["content"] for completion in completions]
         scores = []
         for response in responses:
-            reasoning_match = re.search(r"<think>\n(.+?)\n</think>", response, re.DOTALL)
+            reasoning_match = re.search(r"<reasoning>\n(.+?)\n</reasoning>", response, re.DOTALL)
             if reasoning_match:
                 reasoning_content = reasoning_match.group(1).strip()
                 # 简单检查推理内容是否包含关键词
@@ -283,9 +283,9 @@ class ModelTrainer:
         :return: XML 标签的完整性得分
         """
         count = 0.0
-        if text.count("<think>\n") == 1:
+        if text.count("<reasoning>\n") == 1:
             count += 0.125
-        if text.count("\n</think>\n") == 1:
+        if text.count("\n</reasoning>\n") == 1:
             count += 0.125
         if text.count("\n<answer>\n") == 1:
             count += 0.125
@@ -312,7 +312,7 @@ class ModelTrainer:
         :param completions: 模型生成的补全内容
         :return: 符合软格式要求的得分列表
         """
-        pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
+        pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
         responses = [completion[0]["content"] for completion in completions]
         matches = [re.match(pattern, r) for r in responses]
         return [0.5 if match else 0.0 for match in matches]
@@ -324,7 +324,7 @@ class ModelTrainer:
         :param completions: 模型生成的补全内容
         :return: 符合严格格式要求的得分列表
         """
-        pattern = r"^<think>\n(.+?)\n</think>\n<answer>\n(.+?)\n</answer>\n$"
+        pattern = r"^<reasoning>\n(.+?)\n</reasoning>\n<answer>\n(.+?)\n</answer>\n$"
         responses = [completion[0]["content"] for completion in completions]
         scores = []
         for response in responses: