|
@@ -144,9 +144,9 @@ class ModelTrainer:
|
|
|
# 加载训练集和测试集
|
|
|
train_dataset = load_dataset("json", data_files={"train": train_data_path}, split="train")
|
|
|
|
|
|
- train_loader = torch.utils.data.DataLoader(
|
|
|
- train_dataset, batch_size=1, shuffle=True, pin_memory=True # 启用 pin_memory
|
|
|
- )
|
|
|
+ # train_loader = torch.utils.data.DataLoader(
|
|
|
+ # train_dataset, batch_size=1, shuffle=True, pin_memory=True # 启用 pin_memory 2025年3月7日未能验证通过
|
|
|
+ # )
|
|
|
|
|
|
# train_data_path: 训练数据路径,格式为 JSONL
|
|
|
return train_dataset
|