|
@@ -153,7 +153,7 @@ class ModelTrainer:
|
|
|
|
|
|
def load_data(self, train_data_path):
|
|
|
# 加载训练集和测试集
|
|
|
- data = load_dataset("json", data_files={"train": train_data_path}, split="train")
|
|
|
+ train_dataset = load_dataset("json", data_files={"train": train_data_path}, split="train")
|
|
|
|
|
|
# train_loader = torch.utils.data.DataLoader(
|
|
|
# train_dataset, batch_size=1, shuffle=True, pin_memory=True # 启用 pin_memory 2025年3月7日未能验证通过
|