Browse Source

更换unsloth grpo的训练数据集并验证

zhouyang.xie 3 months ago
parent
commit
97fe68c387

File diff suppressed because it is too large
+ 0 - 0
data/backup/train_windturbine_old.jsonl


File diff suppressed because it is too large
+ 0 - 0
data/processed/train.jsonl


File diff suppressed because it is too large
+ 0 - 0
data/processed/train_windturbine_old.jsonl


+ 21 - 4
src/generate_data.py

@@ -24,10 +24,27 @@ class DataGenerator:
             # print("case_data[2]  ->",case_data[2])
             # print("case_data[3]  ->",case_data[3])
             # print("case_data[4]  ->",case_data[4])
-            processed_data.append({
-                "text": f"<human>: {case_data[1]}\n<bot>: {case_data[2]}\n{case_data[3]}\n{case_data[4]}",
-                "metadata": {"source": f"wind_turbine_fault_cases {case_data[0]}"}
-            })
+            # processed_data.append({
+            #     "text": f"<human>: {case_data[1]}\n<bot>: {case_data[2]}\n{case_data[3]}\n{case_data[4]}",
+            #     "metadata": {"source": f"wind_turbine_fault_cases {case_data[0]}"}
+            # })
+
+            # 当 human 和 bot 都取到后,拼装一个新的 JSON
+            target_data = {
+                "question": case_data[1]+" 请予以故障诊断?",
+                "answer":   "".join([case_data[2],"\n",case_data[3],"\n",case_data[4]]),
+                "prompt": [
+                    {
+                        "content": "\nRespond in the following format:\n<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n",
+                        "role": "system"
+                    },
+                    {
+                        "content": case_data[1]+" 请予以故障诊断?",
+                        "role": "user"
+                    }
+                ]
+            }
+            processed_data.append(target_data)
         return processed_data
 
     def split_data(self, data):

+ 12 - 4
src/train_model_grpo.py

@@ -153,11 +153,19 @@ class ModelTrainer:
 
     def load_data(self, train_data_path):
         # 加载训练集和测试集
-        train_dataset = load_dataset("json", data_files={"train": train_data_path}, split="train")
+        data = load_dataset("json", data_files={"train": train_data_path}, split="train")
 
-        train_loader = torch.utils.data.DataLoader(
-            train_dataset, batch_size=1, shuffle=True, pin_memory=True  # 启用 pin_memory  2025年3月7日未能验证通过
-        )
+        train_dataset = data.map(lambda x: { # type: ignore
+        'prompt': [
+            {'role': 'system', 'content': SYSTEM_PROMPT},
+            {'role': 'user', 'content': x['question']}
+            ],
+            'answer': extract_hash_answer(x['answer'])
+        }) # type: ignore
+
+        # train_loader = torch.utils.data.DataLoader(
+        #     train_dataset, batch_size=1, shuffle=True, pin_memory=True  # 启用 pin_memory  2025年3月7日未能验证通过
+        # )
 
         # train_data_path: 训练数据路径,格式为 JSONL
         return train_dataset

Some files were not shown because too many files changed in this diff