data_loader.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import os
  2. from pathlib import Path
  3. from typing import List, Optional
  4. import pandas as pd
  5. import pyarrow.parquet as pq
  6. from config import PARQUET_ROOT
  7. def list_model_types() -> List[str]:
  8. """返回 PARQUET_ROOT 下所有机型文件夹名称。"""
  9. return [p.name for p in PARQUET_ROOT.iterdir() if p.is_dir()]
  10. def list_turbines(model_name: str) -> List[dict]:
  11. """
  12. 返回某机型下所有风机信息列表。
  13. 每条记录: {model_name, farm_name, turbine_name, path}
  14. """
  15. records = []
  16. model_root = PARQUET_ROOT / model_name
  17. if not model_root.exists():
  18. return records
  19. for farm_dir in model_root.iterdir():
  20. if not farm_dir.is_dir():
  21. continue
  22. for pq_file in farm_dir.glob("*.parquet"):
  23. if pq_file.name.startswith("._"):
  24. continue
  25. records.append({
  26. "model_name": model_name,
  27. "farm_name": farm_dir.name,
  28. "turbine_name": pq_file.stem,
  29. "path": pq_file,
  30. })
  31. return records
  32. def load_turbine(
  33. path: Path,
  34. required_cols: List[str],
  35. optional_cols: Optional[List[str]] = None,
  36. ) -> Optional[pd.DataFrame]:
  37. """
  38. 读取单台风机 parquet 文件,带测点缺失容错。
  39. - required_cols: 任意一列缺失则跳过该文件,返回 None
  40. - optional_cols: 存在则读取,不存在则忽略
  41. - 返回仅包含实际存在列的 DataFrame,已对 required_cols 做 dropna
  42. """
  43. try:
  44. pq_file = pq.ParquetFile(path)
  45. schema_names = set(pq_file.schema.names)
  46. except Exception as e:
  47. print(f"[WARN] 无法读取 schema {path.name}: {e}")
  48. return None
  49. missing_required = [c for c in required_cols if c not in schema_names]
  50. if missing_required:
  51. print(f"[SKIP] {path.name} 缺少必要测点: {missing_required}")
  52. return None
  53. cols_to_read = list(required_cols)
  54. if optional_cols:
  55. cols_to_read += [c for c in optional_cols if c in schema_names]
  56. try:
  57. df = pq_file.read(columns=cols_to_read).to_pandas()
  58. except Exception as e:
  59. print(f"[WARN] 读取数据失败 {path.name}: {e}")
  60. return None
  61. # 转换数值类型,跳过时间戳列,过滤必要列的空值
  62. numeric_cols = [c for c in cols_to_read if c != "data_time"]
  63. for col in numeric_cols:
  64. df[col] = pd.to_numeric(df[col], errors='coerce')
  65. df = df.dropna(subset=required_cols)
  66. if df.empty:
  67. print(f"[SKIP] {path.name} 必要测点全为空值")
  68. return None
  69. return df
  70. def load_model_type(
  71. model_name: str,
  72. required_cols: List[str],
  73. optional_cols: Optional[List[str]] = None,
  74. ) -> pd.DataFrame:
  75. """
  76. 聚合某机型下所有风机数据为一个 DataFrame(用于训练)。
  77. 自动附加 farm_name / turbine_name 列便于溯源。
  78. """
  79. frames = []
  80. for rec in list_turbines(model_name):
  81. df = load_turbine(rec["path"], required_cols, optional_cols)
  82. if df is None:
  83. continue
  84. df = df.copy()
  85. df["farm_name"] = rec["farm_name"]
  86. df["turbine_name"] = rec["turbine_name"]
  87. frames.append(df)
  88. if not frames:
  89. print(f"[WARN] 机型 {model_name} 无有效数据(所需列: {required_cols})")
  90. return pd.DataFrame()
  91. result = pd.concat(frames, ignore_index=True)
  92. print(f"[INFO] 机型 {model_name} 加载完成,共 {len(result)} 行,{len(frames)} 台风机")
  93. return result