|
@@ -155,14 +155,6 @@ class ModelTrainer:
|
|
|
# 加载训练集和测试集
|
|
|
data = load_dataset("json", data_files={"train": train_data_path}, split="train")
|
|
|
|
|
|
- train_dataset = data.map(lambda x: { # type: ignore
|
|
|
- 'prompt': [
|
|
|
- {'role': 'system', 'content': SYSTEM_PROMPT},
|
|
|
- {'role': 'user', 'content': x['question']}
|
|
|
- ],
|
|
|
- 'answer': extract_hash_answer(x['answer'])
|
|
|
- }) # type: ignore
|
|
|
-
|
|
|
# train_loader = torch.utils.data.DataLoader(
|
|
|
# train_dataset, batch_size=1, shuffle=True, pin_memory=True # 启用 pin_memory 2025年3月7日未能验证通过
|
|
|
# )
|