Przeglądaj źródła

遵循面向对象思想重构 train_model_grpo_v1.1.py,去掉分布式及相应环境变量设置代码

zhouyang.xie 2 miesięcy temu
rodzic
commit
91bb641d42

Plik diff jest za duży
+ 0 - 0
data/processed/train.jsonl


+ 1 - 0
src/fine_tune_model.py

@@ -1,3 +1,4 @@
+import os
 import torch
 from unsloth import FastLanguageModel
 

+ 1 - 1
src/generate_data.py

@@ -35,7 +35,7 @@ class DataGenerator:
                 "answer":   "".join([case_data[2],"\n",case_data[3],"\n",case_data[4]]),
                 "prompt": [
                     {
-                        "content": "\nRespond in the following format:\n<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n",
+                        "content": "\nRespond in the following format:\n<thinking>\n...\n</thinking>\n<answer>\n...\n</answer>\n",
                         "role": "system"
                     },
                     {

+ 0 - 0
src/train_model_grpo_v2.py → src/train_model_grpo_v0.py


+ 5 - 5
src/train_model_grpo_v1.1.py

@@ -252,7 +252,7 @@ class ModelTrainer:
         responses = [completion[0]["content"] for completion in completions]
         scores = []
         for response in responses:
-            reasoning_match = re.search(r"<reasoning>\n(.+?)\n</reasoning>", response, re.DOTALL)
+            reasoning_match = re.search(r"<thinking>\n(.+?)\n</thinking>", response, re.DOTALL)
             if reasoning_match:
                 reasoning_content = reasoning_match.group(1).strip()
                 # 简单检查推理内容是否包含关键词
@@ -283,9 +283,9 @@ class ModelTrainer:
         :return: XML 标签的完整性得分
         """
         count = 0.0
-        if text.count("<reasoning>\n") == 1:
+        if text.count("<thinking>\n") == 1:
             count += 0.125
-        if text.count("\n</reasoning>\n") == 1:
+        if text.count("\n</thinking>\n") == 1:
             count += 0.125
         if text.count("\n<answer>\n") == 1:
             count += 0.125
@@ -312,7 +312,7 @@ class ModelTrainer:
         :param completions: 模型生成的补全内容
         :return: 符合软格式要求的得分列表
         """
-        pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
+        pattern = r"<thinking>.*?</thinking>\s*<answer>.*?</answer>"
         responses = [completion[0]["content"] for completion in completions]
         matches = [re.match(pattern, r) for r in responses]
         return [0.5 if match else 0.0 for match in matches]
@@ -324,7 +324,7 @@ class ModelTrainer:
         :param completions: 模型生成的补全内容
         :return: 符合严格格式要求的得分列表
         """
-        pattern = r"^<reasoning>\n(.+?)\n</reasoning>\n<answer>\n(.+?)\n</answer>\n$"
+        pattern = r"^<thinking>\n(.+?)\n</thinking>\n<answer>\n(.+?)\n</answer>\n$"
         responses = [completion[0]["content"] for completion in completions]
         scores = []
         for response in responses:

+ 1 - 1
src/train_model_sft.py

@@ -116,5 +116,5 @@ if __name__ == "__main__":
     model = trainer.train(model, tokenizer, train_dataset)
 
     # 保存模型
-    save_path = os.path.join('..', 'models', 'trained', 'DeepSeek-R1-Distill-Qwen-1.5B')
+    save_path = os.path.join('..', 'models', 'pretrained', 'DeepSeek-R1-Distill-Qwen-1.5B-SFT')
     trainer.save_model(model, tokenizer, save_path)

Niektóre pliki nie zostały wyświetlone z powodu dużej ilości zmienionych plików