|
@@ -154,7 +154,8 @@ class ModelTrainer:
|
|
|
|
|
|
def load_data(self, train_data_path):
|
|
|
# 加载训练集和测试集
|
|
|
- train_dataset = load_dataset("json", data_files={"train": train_data_path}, split="train")
|
|
|
+ with open(train_data_path, 'r') as f:
|
|
|
+ train_dataset = load_dataset("json", data_files={"train": train_data_path}, split="train")
|
|
|
|
|
|
# train_data_path: 训练数据路径,格式为 JSONL
|
|
|
return train_dataset
|