train_health.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. import os
  2. import sys
  3. import pandas as pd
  4. from datetime import datetime
  5. from typing import Optional
  6. from database import DataFetcher
  7. from health_pretrain import WindFarmPretrainModel
  8. import logging
  9. # 配置日志
  10. logging.basicConfig(
  11. level=logging.INFO,
  12. format='%(asctime)s - %(levelname)s - %(message)s',
  13. handlers=[
  14. logging.StreamHandler(sys.stdout),
  15. logging.FileHandler('train_health.log')
  16. ]
  17. )
  18. logger = logging.getLogger(__name__)
  19. # 配置
  20. WIND_CODE = "WOF093400005" # 张崾先:WOF091200030 七台河:WOF046400029 诺木洪:WOF093400005
  21. START_DATE = "2023-12-01 00:00:00 "# 张崾先:2023-10-20 00:00:00~2024-10-20 00:00:00 七台河:2023-10-02 00:00:00~2024-10-02 00:00:00
  22. END_DATE = "2024-05-30 23:59:59" #诺木洪 2023-12-01 00:00:00 ~ 2024-05-30 23:50:00
  23. MODEL_DIR = "health_models"
  24. MIN_SAMPLES = 100 # 最小训练样本数
  25. def fetch_turbine_data(fetcher: DataFetcher, wind_code: str, turbine_code: str) -> Optional[pd.DataFrame]:
  26. """获取单个风机的完整训练数据"""
  27. try:
  28. # 获取所有可用列
  29. columns = fetcher.get_turbine_columns(wind_code)
  30. if not columns:
  31. logger.warning(f"{turbine_code} 无可用数据列")
  32. return None
  33. special_wind_farms = {
  34. "WOF093400005": f"`{wind_code}-WOB000001_minute`" # 加上反引号
  35. }
  36. # 根据风场编号获取表名,特殊风场用反引号,其他风场不加反引号
  37. table = special_wind_farms.get(wind_code, f"{wind_code}_minute")
  38. # 构建查询 - 使用参数化查询防止SQL注入
  39. query = f"""
  40. SELECT *
  41. FROM {table}
  42. WHERE `wind_turbine_number` = %s
  43. AND `time_stamp` BETWEEN %s AND %s
  44. """
  45. # 执行查询
  46. logger.info(f"正在获取风机 {turbine_code} 数据...")
  47. df = pd.read_sql(
  48. query,
  49. fetcher.data_engine,
  50. params=(turbine_code, START_DATE, END_DATE)
  51. )
  52. if df.empty:
  53. logger.warning(f"{turbine_code} 无数据")
  54. return None
  55. print("数据项",df)
  56. logger.info(f"获取到 {turbine_code} 数据 {len(df)} 条")
  57. return df
  58. except Exception as e:
  59. logger.error(f"获取 {turbine_code} 数据失败: {str(e)}")
  60. return None
  61. def train_windfarm_model(
  62. fetcher: DataFetcher,
  63. wind_code: str,
  64. turbines: pd.DataFrame
  65. ) -> bool:
  66. """训练风场模型"""
  67. try:
  68. # 获取所有风机数据
  69. data_dict = {}
  70. valid_turbines = []
  71. for idx, turbine_info in turbines.iterrows():
  72. turbine_code = turbine_info['engine_code']
  73. data = fetch_turbine_data(fetcher, wind_code, turbine_code)
  74. if data is not None and len(data) >= MIN_SAMPLES:
  75. data_dict[turbine_code] = data
  76. valid_turbines.append(turbine_info)
  77. if not data_dict:
  78. logger.error("无有效风机数据,无法训练风场模型")
  79. return False
  80. # 确定主要机型(取出现次数最多的机型)
  81. mill_type_counts = {}
  82. for turbine_info in valid_turbines:
  83. mill_type_num = fetcher.get_mill_type(turbine_info['mill_type_code'])
  84. mill_type = {1: 'dfig', 2: 'direct', 3: 'semi_direct'}.get(mill_type_num, 'unknown')
  85. if mill_type != 'unknown':
  86. mill_type_counts[mill_type] = mill_type_counts.get(mill_type, 0) + 1
  87. if not mill_type_counts:
  88. logger.error("无法确定风场主要机型")
  89. return False
  90. main_mill_type = max(mill_type_counts.items(), key=lambda x: x[1])[0]
  91. # 训练模型
  92. logger.info(f"开始训练风场 {wind_code} ({main_mill_type})模型...")
  93. model = WindFarmPretrainModel(wind_code)
  94. model.train(data_dict, main_mill_type)
  95. # 保存模型
  96. model_dir = os.path.join(MODEL_DIR, wind_code)
  97. model.save(model_dir)
  98. logger.info(f"风场 {wind_code} 模型训练完成并保存")
  99. return True
  100. except Exception as e:
  101. logger.error(f"训练风场模型失败: {str(e)}", exc_info=True)
  102. return False
  103. def main():
  104. """主训练流程"""
  105. logger.info("=== 开始健康评估模型训练 ===")
  106. logger.info(f"风场: {WIND_CODE}")
  107. logger.info(f"时间范围: {START_DATE} 至 {END_DATE}")
  108. # 初始化数据获取器
  109. fetcher = DataFetcher()
  110. # 获取风场下所有风机
  111. logger.info("获取风机列表...")
  112. turbines = fetcher.get_turbines(WIND_CODE)
  113. if turbines.empty:
  114. logger.error("无风机数据,终止训练")
  115. return
  116. logger.info(f"共发现 {len(turbines)} 台风机")
  117. # 创建模型目录
  118. model_dir = os.path.join(MODEL_DIR, WIND_CODE)
  119. os.makedirs(model_dir, exist_ok=True)
  120. # 训练风场模型
  121. if train_windfarm_model(fetcher, WIND_CODE, turbines):
  122. logger.info(f"风场 {WIND_CODE} 模型训练成功")
  123. else:
  124. logger.error(f"风场 {WIND_CODE} 模型训练失败")
  125. if __name__ == "__main__":
  126. try:
  127. main()
  128. except Exception as e:
  129. logger.error(f"训练流程异常终止: {str(e)}", exc_info=True)
  130. sys.exit(1)