Sfoglia il codice sorgente

修改 train_model_grpo.py 代码:验证 GRPO 训练模型,并将训练用数据集输出为 JSONL 文件,保存至相对路径 data/backup

zhouyang.xie 3 mesi fa
parent
commit
23704681d0
1 file modificato con 3 aggiunte e 0 eliminazioni
  1. 3 0
      src/train_model_grpo.py

+ 3 - 0
src/train_model_grpo.py

@@ -56,6 +56,9 @@ def get_gsm8k_questions(split = "train") -> Dataset:
 
 dataset = get_gsm8k_questions()
 
+# 方法 1:使用 datasets 库的 to_json 方法
+dataset.to_json(os.path.join("..","data","backup", "gsm8k_dataset_for_train.jsonl"), orient="records", lines=True, force_ascii=False)
+
 # Reward functions
 def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
     responses = [completion[0]['content'] for completion in completions]