|
@@ -56,6 +56,9 @@ def get_gsm8k_questions(split = "train") -> Dataset:
|
|
|
|
|
|
dataset = get_gsm8k_questions()
|
|
dataset = get_gsm8k_questions()
|
|
|
|
|
|
|
|
+# Method 1: use the datasets library's to_json method
|
|
|
|
+dataset.to_json(os.path.join("..","data","backup", "gsm8k_dataset_for_train.jsonl"), orient="records", lines=True, force_ascii=False)
|
|
|
|
+
|
|
# Reward functions
|
|
# Reward functions
|
|
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
|
|
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
|
|
responses = [completion[0]['content'] for completion in completions]
|
|
responses = [completion[0]['content'] for completion in completions]
|