train_health.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. import os
  2. import sys
  3. import pandas as pd
  4. from datetime import datetime
  5. from typing import Optional, Dict
  6. from database import DataFetcher
  7. from health_pretrain import WindFarmPretrainModel
  8. import logging
  9. # 配置日志
  10. logging.basicConfig(
  11. level=logging.INFO,
  12. format='%(asctime)s - %(levelname)s - %(message)s',
  13. handlers=[
  14. logging.StreamHandler(sys.stdout),
  15. logging.FileHandler('train_health.log')
  16. ]
  17. )
  18. logger = logging.getLogger(__name__)
  19. # 配置 - 使用与database.py一致的风场代码
  20. WIND_CODE = "7V2xSuma" # 使用demo风场
  21. START_DATE = "2024-07-01"
  22. END_DATE = "2025-09-01"
  23. MODEL_DIR = "health_models"
  24. MIN_SAMPLES = 100
  25. def fetch_turbine_data_for_training(fetcher: DataFetcher, wind_code: str, turbine_code: str) -> Optional[pd.DataFrame]:
  26. """从parquet文件获取单个风机的训练数据"""
  27. try:
  28. logger.info(f"正在获取风机 {turbine_code} 的数据...")
  29. # 使用 DataFetcher 的正确方法名和参数
  30. # 注意:fetch_turbine_data 需要月份参数,但我们有开始和结束日期
  31. # 这里需要获取所有可用的特征列
  32. features = fetcher.get_turbine_columns(wind_code)
  33. if not features:
  34. logger.warning(f"{turbine_code} 无可用特征列")
  35. return None
  36. # 由于 fetch_turbine_data 是按月份查询的,我们需要处理日期范围
  37. # 这里简化处理:获取所有数据,然后在本地过滤
  38. df = fetcher.fetch_turbine_data(wind_code, turbine_code, "2024-07", features)
  39. if df is None or df.empty:
  40. logger.warning(f"{turbine_code} 无数据")
  41. return None
  42. # 在本地进行时间范围过滤
  43. if 'time_stamp' in df.columns:
  44. df['time_stamp'] = pd.to_datetime(df['time_stamp'])
  45. start_dt = pd.Timestamp(START_DATE)
  46. end_dt = pd.Timestamp(END_DATE)
  47. mask = (df['time_stamp'] >= start_dt) & (df['time_stamp'] <= end_dt)
  48. df = df.loc[mask]
  49. logger.info(f"时间范围过滤后数据量: {len(df)} 行")
  50. if df.empty:
  51. logger.warning(f"{turbine_code} 在指定时间范围内无数据")
  52. return None
  53. logger.info(f"获取到 {turbine_code} 数据 {len(df)} 条")
  54. return df
  55. except Exception as e:
  56. logger.error(f"获取 {turbine_code} 数据失败: {str(e)}")
  57. return None
  58. def train_windfarm_model(fetcher: DataFetcher, wind_code: str) -> bool:
  59. """训练风场模型 - 适配parquet数据源"""
  60. try:
  61. # 获取风场所有风机
  62. turbines = fetcher.get_turbines(wind_code)
  63. if turbines.empty:
  64. logger.error("无风机数据,无法训练")
  65. return False
  66. # 获取所有风机数据
  67. data_dict = {}
  68. valid_turbines = []
  69. for idx, turbine_info in turbines.iterrows():
  70. turbine_code = turbine_info['engine_code']
  71. logger.info(f"正在处理风机 {turbine_code}...")
  72. # 从parquet获取数据
  73. data = fetch_turbine_data_for_training(fetcher, wind_code, turbine_code)
  74. if data is not None and len(data) >= MIN_SAMPLES:
  75. data_dict[turbine_code] = data
  76. valid_turbines.append(turbine_info)
  77. logger.info(f"风机 {turbine_code} 数据有效,样本数: {len(data)}")
  78. else:
  79. logger.warning(f"风机 {turbine_code} 数据不足或无效")
  80. if not data_dict:
  81. logger.error("无有效风机数据,无法训练风场模型")
  82. return False
  83. # 确定主要机型
  84. mill_type_counts = {}
  85. for turbine_info in valid_turbines:
  86. mill_type_num = fetcher.get_mill_type(turbine_info['mill_type_code'])
  87. mill_type = {1: 'dfig', 2: 'direct', 3: 'semi_direct'}.get(mill_type_num, 'unknown')
  88. if mill_type != 'unknown':
  89. mill_type_counts[mill_type] = mill_type_counts.get(mill_type, 0) + 1
  90. if not mill_type_counts:
  91. logger.error("无法确定风场主要机型")
  92. return False
  93. main_mill_type = max(mill_type_counts.items(), key=lambda x: x[1])[0]
  94. # 训练模型
  95. logger.info(f"开始训练风场 {wind_code} ({main_mill_type})模型...")
  96. logger.info(f"有效风机数量: {len(data_dict)}")
  97. logger.info(f"总数据量: {sum(len(df) for df in data_dict.values())}")
  98. model = WindFarmPretrainModel(wind_code)
  99. model.train(data_dict, main_mill_type)
  100. # 保存模型 - 确保目录存在
  101. model_dir = os.path.join(MODEL_DIR, wind_code)
  102. os.makedirs(model_dir, exist_ok=True) # 确保目录存在
  103. model.save(model_dir)
  104. logger.info(f"风场 {wind_code} 模型训练完成并保存到 {model_dir}")
  105. return True
  106. except Exception as e:
  107. logger.error(f"训练风场模型失败: {str(e)}", exc_info=True)
  108. return False
  109. def main():
  110. """主训练流程"""
  111. logger.info("=== 开始健康评估模型训练 ===")
  112. logger.info(f"风场: {WIND_CODE}")
  113. logger.info(f"时间范围: {START_DATE} 至 {END_DATE}")
  114. # 确保模型目录存在
  115. os.makedirs(MODEL_DIR, exist_ok=True)
  116. # 初始化数据获取器
  117. fetcher = DataFetcher()
  118. # 训练风场模型
  119. if train_windfarm_model(fetcher, WIND_CODE):
  120. logger.info(f"风场 {WIND_CODE} 模型训练成功")
  121. else:
  122. logger.error(f"风场 {WIND_CODE} 模型训练失败")
  123. if __name__ == "__main__":
  124. try:
  125. main()
  126. except Exception as e:
  127. logger.error(f"训练流程异常终止: {str(e)}", exc_info=True)
  128. sys.exit(1)