| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157 |
- import os
- import sys
- import pandas as pd
- from datetime import datetime
- from typing import Optional, Dict
- from database import DataFetcher
- from health_pretrain import WindFarmPretrainModel
- import logging
# Logging configuration: every record goes both to stdout and to a log file
# in the current working directory.
# The file handler is opened as UTF-8 explicitly: the messages this script
# logs are Chinese, and the platform default encoding (used when no encoding
# is given) is not guaranteed to be able to represent them.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('train_health.log', encoding='utf-8'),
    ],
)
logger = logging.getLogger(__name__)
# Run configuration.
# The wind-farm code is meant to stay in sync with the one used in database.py.
WIND_CODE: str = "7V2xSuma"  # demo wind farm
# Inclusive time window applied to the fetched rows.
START_DATE: str = "2024-07-01"
END_DATE: str = "2025-09-01"
# Root directory under which one sub-directory per wind farm is written.
MODEL_DIR: str = "health_models"
# Minimum number of rows a turbine must provide to take part in training.
MIN_SAMPLES: int = 100
def fetch_turbine_data_for_training(
    fetcher: DataFetcher,
    wind_code: str,
    turbine_code: str,
) -> Optional[pd.DataFrame]:
    """Load the training data of one turbine for the configured time window.

    ``fetcher.fetch_turbine_data`` is queried one month at a time, so every
    calendar month between ``START_DATE`` and ``END_DATE`` is requested and
    the monthly results are concatenated. (Previously only ``"2024-07"`` was
    requested, which silently dropped every later month of the window.)

    Args:
        fetcher: Data source providing ``get_turbine_columns`` and
            ``fetch_turbine_data``.
        wind_code: Wind-farm code passed through to the fetcher.
        turbine_code: Code of the turbine whose rows are requested.

    Returns:
        The concatenated rows, restricted to ``[START_DATE, END_DATE]`` when a
        ``time_stamp`` column is present; ``None`` when no feature columns or
        no rows are available, or when an unexpected error occurs (the error
        is logged, not raised).
    """
    try:
        logger.info(f"正在获取风机 {turbine_code} 的数据...")

        # The fetcher needs the list of feature columns to read.
        features = fetcher.get_turbine_columns(wind_code)
        if not features:
            logger.warning(f"{turbine_code} 无可用特征列")
            return None

        start_dt = pd.Timestamp(START_DATE)
        end_dt = pd.Timestamp(END_DATE)

        # One request per calendar month ("YYYY-MM") covering the window.
        # NOTE(review): assumes monthly results do not overlap — confirm
        # against DataFetcher.fetch_turbine_data.
        monthly_frames = []
        for month in pd.period_range(start=start_dt, end=end_dt, freq="M"):
            month_key = month.strftime("%Y-%m")
            try:
                month_df = fetcher.fetch_turbine_data(
                    wind_code, turbine_code, month_key, features
                )
            except Exception as e:
                # Best effort: one unreadable month must not discard the
                # remaining months of this turbine.
                logger.warning(
                    f"获取 {turbine_code} {month_key} 数据失败: {str(e)}"
                )
                continue
            if month_df is not None and not month_df.empty:
                monthly_frames.append(month_df)

        if not monthly_frames:
            logger.warning(f"{turbine_code} 无数据")
            return None

        df = pd.concat(monthly_frames, ignore_index=True)

        # Trim to the exact (inclusive) window; month granularity is coarser
        # than the configured start/end dates.
        if 'time_stamp' in df.columns:
            df['time_stamp'] = pd.to_datetime(df['time_stamp'])
            mask = (df['time_stamp'] >= start_dt) & (df['time_stamp'] <= end_dt)
            df = df.loc[mask]
            logger.info(f"时间范围过滤后数据量: {len(df)} 行")

        if df.empty:
            logger.warning(f"{turbine_code} 在指定时间范围内无数据")
            return None

        logger.info(f"获取到 {turbine_code} 数据 {len(df)} 条")
        return df

    except Exception as e:
        logger.error(f"获取 {turbine_code} 数据失败: {str(e)}")
        return None
-
def train_windfarm_model(fetcher: DataFetcher, wind_code: str) -> bool:
    """Train and persist the wind-farm level model from parquet-backed data.

    Steps: list the farm's turbines, load each turbine's data, keep the
    turbines with at least ``MIN_SAMPLES`` rows, pick the most frequent known
    mill type among them, train a ``WindFarmPretrainModel`` on all kept data
    and save it under ``MODEL_DIR/<wind_code>``.

    Args:
        fetcher: Data source providing ``get_turbines`` and ``get_mill_type``.
        wind_code: Wind-farm code to train for.

    Returns:
        ``True`` when the model was trained and saved; ``False`` when there is
        no usable data, no known mill type, or any exception occurred (the
        exception is logged with its traceback).
    """
    # Numeric mill-type code -> name; codes outside this table are ignored.
    mill_type_names = {1: 'dfig', 2: 'direct', 3: 'semi_direct'}

    try:
        turbines = fetcher.get_turbines(wind_code)
        if turbines.empty:
            logger.error("无风机数据,无法训练")
            return False

        # Collect the per-turbine frames that are large enough to train on.
        frames_by_turbine = {}
        usable_rows = []
        for _, row in turbines.iterrows():
            code = row['engine_code']
            logger.info(f"正在处理风机 {code}...")

            frame = fetch_turbine_data_for_training(fetcher, wind_code, code)
            if frame is None or len(frame) < MIN_SAMPLES:
                logger.warning(f"风机 {code} 数据不足或无效")
                continue

            frames_by_turbine[code] = frame
            usable_rows.append(row)
            logger.info(f"风机 {code} 数据有效,样本数: {len(frame)}")

        if not frames_by_turbine:
            logger.error("无有效风机数据,无法训练风场模型")
            return False

        # Count the known mill types among the usable turbines.
        type_tally = {}
        for row in usable_rows:
            type_name = mill_type_names.get(fetcher.get_mill_type(row['mill_type_code']))
            if type_name is None:
                continue
            type_tally[type_name] = type_tally.get(type_name, 0) + 1

        if not type_tally:
            logger.error("无法确定风场主要机型")
            return False

        # Most frequent type; ties resolve to the type encountered first.
        dominant_type = max(type_tally, key=type_tally.get)
        total_rows = sum(len(frame) for frame in frames_by_turbine.values())

        logger.info(f"开始训练风场 {wind_code} ({dominant_type})模型...")
        logger.info(f"有效风机数量: {len(frames_by_turbine)}")
        logger.info(f"总数据量: {total_rows}")

        farm_model = WindFarmPretrainModel(wind_code)
        farm_model.train(frames_by_turbine, dominant_type)

        # The target directory must exist before the model is saved.
        target_dir = os.path.join(MODEL_DIR, wind_code)
        os.makedirs(target_dir, exist_ok=True)
        farm_model.save(target_dir)

        logger.info(f"风场 {wind_code} 模型训练完成并保存到 {target_dir}")
        return True

    except Exception as exc:
        logger.error(f"训练风场模型失败: {str(exc)}", exc_info=True)
        return False
def main():
    """Run the health-model training flow for the configured wind farm.

    Logs the run configuration, makes sure ``MODEL_DIR`` exists, creates a
    ``DataFetcher`` and trains the model for ``WIND_CODE``, logging whether
    training succeeded.
    """
    logger.info("=== 开始健康评估模型训练 ===")
    logger.info(f"风场: {WIND_CODE}")
    logger.info(f"时间范围: {START_DATE} 至 {END_DATE}")

    # The model root directory must exist before training starts.
    os.makedirs(MODEL_DIR, exist_ok=True)

    data_source = DataFetcher()
    trained = train_windfarm_model(data_source, WIND_CODE)

    if not trained:
        logger.error(f"风场 {WIND_CODE} 模型训练失败")
        return
    logger.info(f"风场 {WIND_CODE} 模型训练成功")
# Script entry point: run the training flow; any uncaught exception is logged
# with its traceback and the process exits with status 1.
if __name__ == "__main__":
    try:
        main()
    except Exception as exc:
        logger.error(f"训练流程异常终止: {exc!s}", exc_info=True)
        raise SystemExit(1)
|