# train_temp.py
# Trains one MSET temperature model per turbine and channel and saves each model to disk.
  1. import numpy as np
  2. import pandas as pd
  3. from sqlalchemy import create_engine
  4. from Temp_Diag import MSET_Temp
  5. import os
  6. # ——— 配置 ———
  7. windCode = "WOF046400029" # 张崾先:WOF091200030 七台河:WOF046400029
  8. start, end = "2023-10-02 00:00:00", "2024-10-02 00:00:00" # 张崾先:2023-10-20 00:00:00~2024-10-20 00:00:00 七台河:2023-10-02 00:00:00~2024-10-02 00:00:00
  9. model_root = "models"
  10. channels = [
  11. 'main_bearing_temperature',
  12. 'gearbox_oil_temperature',
  13. 'generatordrive_end_bearing_temperature',
  14. 'generatornon_drive_end_bearing_temperature'
  15. ]
  16. engine = create_engine(
  17. "mysql+pymysql://root:admin123456@106.120.102.238:10336/energy_data_prod"
  18. )
  19. # ——————————
# Disabled alternative: fetch every turbine number of the wind farm that has
# data inside the training window.
# def list_turbines() -> list[str]:
#     sql = f"""
#     SELECT DISTINCT wind_turbine_number
#     FROM {windCode}_minute
#     WHERE time_stamp BETWEEN '{start}' AND '{end}'
#     """
#     return pd.read_sql(sql, engine)['wind_turbine_number'].tolist()
  28. def list_turbines() -> list[str]:
  29. # 直接返回待训练的风机编号列表
  30. return ['WOG01344']
  31. def fetch_channel_data(turbine: str, channel: str) -> np.ndarray:
  32. sql = f"""
  33. SELECT {channel}
  34. FROM {windCode}_minute
  35. WHERE wind_turbine_number = '{turbine}'
  36. AND time_stamp BETWEEN '{start}' AND '{end}'
  37. ORDER BY time_stamp ASC
  38. """
  39. df = pd.read_sql(sql, engine).dropna(subset=[channel])
  40. print(f"[TRAIN] 风机 {turbine} 通道 {channel} 拉取 {len(df)} 条")
  41. return df[channel].values.reshape(-1,1)
  42. if __name__ == "__main__":
  43. turbines = list_turbines()
  44. print("[TRAIN] 本次训练风机列表:", turbines)
  45. for turbine in turbines:
  46. for ch in channels:
  47. data = fetch_channel_data(turbine, ch)
  48. if data.shape[0] < 65:
  49. print(f"[TRAIN] {turbine}-{ch} 样本不足,跳过")
  50. continue
  51. model = MSET_Temp(windCode, [turbine], start, end)
  52. model.feature_weight = np.ones((data.shape[1],))
  53. model.alpha = 0.1; model.beta = 0.1
  54. if model.genDLMatrix(data, dataSize4D=60, dataSize4L=5) != 0:
  55. print(f"[TRAIN] {turbine}-{ch} D/L 构建失败")
  56. continue
  57. out_dir = os.path.join(model_root, windCode, turbine)
  58. os.makedirs(out_dir, exist_ok=True)
  59. path = os.path.join(out_dir, f"{ch}.pkl")
  60. model.save_model(path)
  61. print(f"[TRAIN] 已保存模型:{path}")
  62. # 列出所有模型文件
  63. print("\n[TRAIN] 最终模型文件列表:")
  64. for root, _, files in os.walk(model_root):
  65. for fn in files:
  66. print(" ", os.path.join(root, fn))