"""Train the wind-farm health-assessment pretrain model.

Loads per-turbine data through ``DataFetcher`` for ``WIND_CODE`` over the
window [START_DATE, END_DATE], keeps turbines that have at least
``MIN_SAMPLES`` rows, trains a ``WindFarmPretrainModel`` for the farm's most
common mill type and saves it under ``MODEL_DIR/<wind_code>``.
"""
import os
import sys
import logging
from collections import Counter
from datetime import datetime
from typing import Optional, Dict, List

import pandas as pd

from database import DataFetcher
from health_pretrain import WindFarmPretrainModel

# Configure logging: console (stdout) plus a persistent train_health.log file.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('train_health.log')
    ]
)
logger = logging.getLogger(__name__)

# Configuration - the wind-farm code must match the one used in database.py.
WIND_CODE = "7V2xSuma"  # demo wind farm
START_DATE = "2024-07-01"
END_DATE = "2025-09-01"
MODEL_DIR = "health_models"
MIN_SAMPLES = 100

# Numeric codes returned by DataFetcher.get_mill_type -> mill-type names passed to the model.
MILL_TYPE_NAMES = {1: 'dfig', 2: 'direct', 3: 'semi_direct'}


def _month_range(start_date: str, end_date: str) -> List[str]:
    """Return every calendar month touched by [start_date, end_date] as 'YYYY-MM'.

    ``DataFetcher.fetch_turbine_data`` is queried one month at a time, so the
    training window has to be expanded into its constituent months.
    An empty list is returned when start_date is after end_date.
    """
    periods = pd.period_range(start=start_date, end=end_date, freq="M")
    return [period.strftime("%Y-%m") for period in periods]


def fetch_turbine_data_for_training(
    fetcher: DataFetcher,
    wind_code: str,
    turbine_code: str,
    features: Optional[list] = None,
    start_date: str = START_DATE,
    end_date: str = END_DATE,
) -> Optional[pd.DataFrame]:
    """Fetch the training data of a single turbine (parquet-backed DataFetcher).

    Args:
        fetcher: data source.
        wind_code: wind-farm code.
        turbine_code: turbine (engine) code within the farm.
        features: feature columns to load; looked up via
            ``fetcher.get_turbine_columns(wind_code)`` when None.
        start_date / end_date: inclusive time window; every month in it is fetched.

    Returns:
        The concatenated rows (restricted to the window when a ``time_stamp``
        column exists), or None when no columns/data are available or an
        unexpected error occurs (the error is logged, not raised).
    """
    try:
        logger.info(f"正在获取风机 {turbine_code} 的数据...")

        if features is None:
            features = fetcher.get_turbine_columns(wind_code)
        if not features:
            logger.warning(f"{turbine_code} 无可用特征列")
            return None

        # fetch_turbine_data takes a single month, so query every month of the
        # window instead of a single hard-coded month.
        frames = []
        for month in _month_range(start_date, end_date):
            try:
                month_df = fetcher.fetch_turbine_data(wind_code, turbine_code, month, features)
            except Exception as e:
                # One missing/broken month must not discard the turbine's other months.
                logger.warning(f"{turbine_code} {month} 数据获取失败: {e}")
                continue
            if month_df is not None and not month_df.empty:
                frames.append(month_df)

        if not frames:
            logger.warning(f"{turbine_code} 无数据")
            return None
        df = pd.concat(frames, ignore_index=True)

        # Trim to the exact window (month boundaries may overshoot the dates).
        if 'time_stamp' in df.columns:
            df['time_stamp'] = pd.to_datetime(df['time_stamp'])
            start_dt = pd.Timestamp(start_date)
            end_dt = pd.Timestamp(end_date)
            mask = (df['time_stamp'] >= start_dt) & (df['time_stamp'] <= end_dt)
            df = df.loc[mask]
            logger.info(f"时间范围过滤后数据量: {len(df)} 行")

            if df.empty:
                logger.warning(f"{turbine_code} 在指定时间范围内无数据")
                return None

        logger.info(f"获取到 {turbine_code} 数据 {len(df)} 条")
        return df

    except Exception as e:
        logger.error(f"获取 {turbine_code} 数据失败: {str(e)}", exc_info=True)
        return None


def train_windfarm_model(fetcher: DataFetcher, wind_code: str) -> bool:
    """Train and save the wind-farm model from all turbines with enough data.

    Turbines with fewer than ``MIN_SAMPLES`` rows are skipped. The model is
    trained for the most frequent known mill type among the usable turbines
    and saved to ``MODEL_DIR/<wind_code>``.

    Returns:
        True on success; False when there is no usable data, the mill type
        cannot be determined, or any step raises (the error is logged).
    """
    try:
        turbines = fetcher.get_turbines(wind_code)
        if turbines is None or turbines.empty:
            logger.error("无风机数据,无法训练")
            return False

        # get_turbine_columns depends only on the wind farm: look it up once
        # instead of once per turbine.
        features = fetcher.get_turbine_columns(wind_code)

        data_dict = {}
        valid_turbines = []
        for _, turbine_info in turbines.iterrows():
            turbine_code = turbine_info['engine_code']
            logger.info(f"正在处理风机 {turbine_code}...")

            data = fetch_turbine_data_for_training(fetcher, wind_code, turbine_code, features)
            if data is not None and len(data) >= MIN_SAMPLES:
                data_dict[turbine_code] = data
                valid_turbines.append(turbine_info)
                logger.info(f"风机 {turbine_code} 数据有效,样本数: {len(data)}")
            else:
                logger.warning(f"风机 {turbine_code} 数据不足或无效")

        if not data_dict:
            logger.error("无有效风机数据,无法训练风场模型")
            return False

        # Determine the dominant mill type among usable turbines (unknown codes ignored).
        mill_type_counts = Counter()
        for turbine_info in valid_turbines:
            mill_type_num = fetcher.get_mill_type(turbine_info['mill_type_code'])
            mill_type = MILL_TYPE_NAMES.get(mill_type_num, 'unknown')
            if mill_type != 'unknown':
                mill_type_counts[mill_type] += 1

        if not mill_type_counts:
            logger.error("无法确定风场主要机型")
            return False

        # On ties, most_common keeps first-seen order (same as max() over the items).
        main_mill_type = mill_type_counts.most_common(1)[0][0]

        logger.info(f"开始训练风场 {wind_code} ({main_mill_type})模型...")
        logger.info(f"有效风机数量: {len(data_dict)}")
        logger.info(f"总数据量: {sum(len(df) for df in data_dict.values())}")

        model = WindFarmPretrainModel(wind_code)
        model.train(data_dict, main_mill_type)

        model_dir = os.path.join(MODEL_DIR, wind_code)
        os.makedirs(model_dir, exist_ok=True)
        model.save(model_dir)

        logger.info(f"风场 {wind_code} 模型训练完成并保存到 {model_dir}")
        return True

    except Exception as e:
        logger.error(f"训练风场模型失败: {str(e)}", exc_info=True)
        return False


def main() -> bool:
    """Run the full training pipeline for WIND_CODE.

    Returns:
        True when the wind-farm model was trained and saved, False otherwise.
    """
    logger.info("=== 开始健康评估模型训练 ===")
    logger.info(f"风场: {WIND_CODE}")
    logger.info(f"时间范围: {START_DATE} 至 {END_DATE}")

    os.makedirs(MODEL_DIR, exist_ok=True)

    fetcher = DataFetcher()

    succeeded = train_windfarm_model(fetcher, WIND_CODE)
    if succeeded:
        logger.info(f"风场 {WIND_CODE} 模型训练成功")
    else:
        logger.error(f"风场 {WIND_CODE} 模型训练失败")
    return succeeded


if __name__ == "__main__":
    try:
        ok = main()
    except Exception as e:
        logger.error(f"训练流程异常终止: {str(e)}", exc_info=True)
        sys.exit(1)
    # Report training failure through the exit code so schedulers/CI can detect it.
    sys.exit(0 if ok else 1)