@@ -56,6 +56,9 @@ def get_gsm8k_questions(split = "train") -> Dataset:
dataset = get_gsm8k_questions()
+# Method 1: back up the dataset as JSON Lines via the datasets library's to_json method (NOTE(review): path is relative to the CWD — confirm ../data/backup exists).
+dataset.to_json(os.path.join("..","data","backup", "gsm8k_dataset_for_train.jsonl"), orient="records", lines=True, force_ascii=False)
+
# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
responses = [completion[0]['content'] for completion in completions]