# train_temp.py
#
# Training script: for each turbine and each temperature channel, pull the
# channel's minute data from the database, build an MSET_Temp model and save
# it as <model_root>/<windCode>/<turbine>/<channel>.pkl.
  1. import numpy as np
  2. import pandas as pd
  3. from sqlalchemy import create_engine
  4. from Temp_Diag import MSET_Temp
  5. import os
  6. # ——— 配置 ———
  7. windCode = "WOF046400029" # 张崾先:WOF091200030 七台河:WOF046400029
  8. start, end = "2023-10-02 00:00:00", "2024-10-02 00:00:00" # 张崾先:2023-10-20 00:00:00~2024-10-20 00:00:00 七台河:2023-10-02 00:00:00~2024-10-02 00:00:00
  9. model_root = "models"
  10. channels = [
  11. 'main_bearing_temperature',
  12. 'gearbox_oil_temperature',
  13. 'generatordrive_end_bearing_temperature',
  14. 'generatornon_drive_end_bearing_temperature'
  15. ]
  16. engine = create_engine(
  17. #"mysql+pymysql://root:admin123456@106.120.102.238:10336/energy_data_prod"
  18. "mysql+pymysql://root:admin123456@192.168.50.235:30306/energy_data_prod"
  19. )
  20. # ——————————
# Alternative implementation (disabled): query the database for every turbine
# number of the wind farm that has data inside the training window.
# def list_turbines() -> list[str]:
#     sql = f"""
#         SELECT DISTINCT wind_turbine_number
#         FROM {windCode}_minute
#         WHERE time_stamp BETWEEN '{start}' AND '{end}'
#     """
#     return pd.read_sql(sql, engine)['wind_turbine_number'].tolist()
  29. def list_turbines() -> list[str]:
  30. # 直接返回待训练的风机编号列表
  31. return ['WOG01344']
  32. def fetch_channel_data(turbine: str, channel: str) -> np.ndarray:
  33. sql = f"""
  34. SELECT {channel}
  35. FROM {windCode}_minute
  36. WHERE wind_turbine_number = '{turbine}'
  37. AND time_stamp BETWEEN '{start}' AND '{end}'
  38. ORDER BY time_stamp ASC
  39. """
  40. df = pd.read_sql(sql, engine).dropna(subset=[channel])
  41. print(f"[TRAIN] 风机 {turbine} 通道 {channel} 拉取 {len(df)} 条")
  42. return df[channel].values.reshape(-1,1)
  43. if __name__ == "__main__":
  44. turbines = list_turbines()
  45. print("[TRAIN] 本次训练风机列表:", turbines)
  46. for turbine in turbines:
  47. for ch in channels:
  48. data = fetch_channel_data(turbine, ch)
  49. if data.shape[0] < 65:
  50. print(f"[TRAIN] {turbine}-{ch} 样本不足,跳过")
  51. continue
  52. model = MSET_Temp(windCode, [turbine], start, end)
  53. model.feature_weight = np.ones((data.shape[1],))
  54. model.alpha = 0.1; model.beta = 0.1
  55. if model.genDLMatrix(data, dataSize4D=60, dataSize4L=5) != 0:
  56. print(f"[TRAIN] {turbine}-{ch} D/L 构建失败")
  57. continue
  58. out_dir = os.path.join(model_root, windCode, turbine)
  59. os.makedirs(out_dir, exist_ok=True)
  60. path = os.path.join(out_dir, f"{ch}.pkl")
  61. model.save_model(path)
  62. print(f"[TRAIN] 已保存模型:{path}")
  63. # 列出所有模型文件
  64. print("\n[TRAIN] 最终模型文件列表:")
  65. for root, _, files in os.walk(model_root):
  66. for fn in files:
  67. print(" ", os.path.join(root, fn))