"""Train the wind-farm health-assessment pretrain model for one wind farm.

Reads minute-level data for every turbine of ``WIND_CODE`` within
[``START_DATE``, ``END_DATE``]. It keeps turbines with at least ``MIN_SAMPLES``
rows, picks the most common drive-train type among them, then trains and saves
a ``WindFarmPretrainModel`` under ``MODEL_DIR/<wind_code>``.
"""
import os
import sys
import logging
from collections import Counter
from datetime import datetime
from typing import Optional

import pandas as pd

from database import DataFetcher
from health_pretrain import WindFarmPretrainModel

# Configure logging: every record goes to stdout and to a log file.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        # Explicit UTF-8 so the Chinese log messages can be written even on
        # hosts whose locale encoding cannot represent them.
        logging.FileHandler('train_health.log', encoding='utf-8'),
    ],
)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Known wind farm codes:
#   Zhangyaoxian (张崾先): WOF091200030
#   Qitaihe      (七台河): WOF046400029
#   Nuomuhong    (诺木洪): WOF093400005
WIND_CODE = "WOF093400005"
# Known training windows:
#   Zhangyaoxian: 2023-10-20 00:00:00 ~ 2024-10-20 00:00:00
#   Qitaihe:      2023-10-02 00:00:00 ~ 2024-10-02 00:00:00
#   Nuomuhong:    2023-12-01 00:00:00 ~ 2024-05-30 23:50:00
# No trailing whitespace: these strings are bound into the SQL BETWEEN clause.
START_DATE = "2023-12-01 00:00:00"
END_DATE = "2024-05-30 23:59:59"
MODEL_DIR = "health_models"
MIN_SAMPLES = 100  # minimum number of rows a turbine needs to be used for training

# Wind farms whose minute table name contains a hyphen and therefore must be
# quoted with backticks in the SQL; all other farms use "<wind_code>_minute".
_SPECIAL_MINUTE_TABLES = {
    "WOF093400005": "`WOF093400005-WOB000001_minute`",
}

# Maps the numeric code returned by DataFetcher.get_mill_type to a drive-train
# type name; any other code is treated as unknown and ignored.
_MILL_TYPE_NAMES = {1: 'dfig', 2: 'direct', 3: 'semi_direct'}


def _minute_table_name(wind_code: str) -> str:
    """Return the SQL table name (quoted where needed) holding *wind_code*'s minute data."""
    return _SPECIAL_MINUTE_TABLES.get(wind_code, f"{wind_code}_minute")


def fetch_turbine_data(fetcher: DataFetcher, wind_code: str, turbine_code: str) -> Optional[pd.DataFrame]:
    """Fetch one turbine's minute-level rows between START_DATE and END_DATE.

    Args:
        fetcher: Data access object; its ``data_engine`` is handed to ``pd.read_sql``.
        wind_code: Wind farm code, used to choose the minute table.
        turbine_code: Value matched against the ``wind_turbine_number`` column.

    Returns:
        The queried rows. Returns ``None`` in any of these cases: the farm
        reports no columns, the query returns no rows, or any exception is
        raised. In the exception case the error is logged with its traceback.
    """
    try:
        # Used only as an availability check for this wind farm's data.
        columns = fetcher.get_turbine_columns(wind_code)
        if not columns:
            logger.warning(f"{turbine_code} 无可用数据列")
            return None

        # Table names cannot be bound as parameters; the filter values below are.
        table = _minute_table_name(wind_code)
        query = f"""
            SELECT * FROM {table}
            WHERE `wind_turbine_number` = %s
            AND `time_stamp` BETWEEN %s AND %s
        """

        logger.info(f"正在获取风机 {turbine_code} 数据...")
        df = pd.read_sql(
            query,
            fetcher.data_engine,
            params=(turbine_code, START_DATE, END_DATE),
        )

        if df.empty:
            logger.warning(f"{turbine_code} 无数据")
            return None

        # Previously a bare print() of the whole frame; keep it at DEBUG only.
        logger.debug("数据项\n%s", df)
        logger.info(f"获取到 {turbine_code} 数据 {len(df)} 条")
        return df

    except Exception as e:
        logger.error(f"获取 {turbine_code} 数据失败: {str(e)}", exc_info=True)
        return None


def _dominant_mill_type(fetcher: DataFetcher, turbine_rows: list) -> Optional[str]:
    """Return the most frequent known mill type among *turbine_rows*, or None if none is known."""
    counts = Counter()
    for turbine_info in turbine_rows:
        mill_type = _MILL_TYPE_NAMES.get(fetcher.get_mill_type(turbine_info['mill_type_code']))
        if mill_type is not None:
            counts[mill_type] += 1
    if not counts:
        return None
    # Ties resolve to the type counted first, same as max() over insertion order.
    return counts.most_common(1)[0][0]


def train_windfarm_model(
    fetcher: DataFetcher,
    wind_code: str,
    turbines: pd.DataFrame
) -> bool:
    """Train and save the wind-farm pretrain model.

    Args:
        fetcher: Data access object used for turbine data and mill-type lookup.
        wind_code: Wind farm code; the model is saved to ``MODEL_DIR/<wind_code>``.
        turbines: One row per turbine. Must contain the ``engine_code`` and
            ``mill_type_code`` columns.

    Returns:
        True when a model was trained and saved. Returns False in any of these
        cases: no turbine has at least MIN_SAMPLES rows, no known mill type is
        found, or an exception occurs. The exception is logged with its traceback.
    """
    try:
        # Collect data for every turbine with enough samples.
        data_dict = {}
        valid_turbines = []
        for _, turbine_info in turbines.iterrows():
            turbine_code = turbine_info['engine_code']
            data = fetch_turbine_data(fetcher, wind_code, turbine_code)
            if data is not None and len(data) >= MIN_SAMPLES:
                data_dict[turbine_code] = data
                valid_turbines.append(turbine_info)

        if not data_dict:
            logger.error("无有效风机数据,无法训练风场模型")
            return False

        # The model is trained for the most common drive-train type on the farm.
        main_mill_type = _dominant_mill_type(fetcher, valid_turbines)
        if main_mill_type is None:
            logger.error("无法确定风场主要机型")
            return False

        logger.info(f"开始训练风场 {wind_code} ({main_mill_type})模型...")
        model = WindFarmPretrainModel(wind_code)
        model.train(data_dict, main_mill_type)

        model_dir = os.path.join(MODEL_DIR, wind_code)
        model.save(model_dir)
        logger.info(f"风场 {wind_code} 模型训练完成并保存")
        return True

    except Exception as e:
        logger.error(f"训练风场模型失败: {str(e)}", exc_info=True)
        return False


def main() -> bool:
    """Run the full training pipeline for WIND_CODE.

    Returns:
        True if the model was trained and saved, otherwise False.
    """
    logger.info("=== 开始健康评估模型训练 ===")
    logger.info(f"风场: {WIND_CODE}")
    logger.info(f"时间范围: {START_DATE} 至 {END_DATE}")

    fetcher = DataFetcher()

    logger.info("获取风机列表...")
    turbines = fetcher.get_turbines(WIND_CODE)
    if turbines is None or turbines.empty:
        logger.error("无风机数据,终止训练")
        return False

    logger.info(f"共发现 {len(turbines)} 台风机")

    # Make sure the output directory exists before training.
    model_dir = os.path.join(MODEL_DIR, WIND_CODE)
    os.makedirs(model_dir, exist_ok=True)

    if train_windfarm_model(fetcher, WIND_CODE, turbines):
        logger.info(f"风场 {WIND_CODE} 模型训练成功")
        return True

    logger.error(f"风场 {WIND_CODE} 模型训练失败")
    return False


if __name__ == "__main__":
    try:
        succeeded = main()
    except Exception as e:
        logger.error(f"训练流程异常终止: {str(e)}", exc_info=True)
        sys.exit(1)
    # Exit non-zero so schedulers/CI notice a failed training run.
    if not succeeded:
        sys.exit(1)