# main.py
"""FastAPI service exposing temperature-SPRT and vibration auto-diagnosis endpoints."""
import glob
import json
import os
import shutil
from pathlib import Path
from typing import List, Optional, Union

import pandas as pd
import uvicorn
from fastapi import FastAPI, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, model_validator

from autodiag_class import Auto_diag
from Temp_Diag import MSET_Temp

# NOTE(review): json, shutil, UploadFile, Field, Optional and Union are not used
# in this module; kept in case other code relies on them being imported here.

app = FastAPI(root_path="/api/diag", title=" Diagnosis API")

# Global model registry: {windCode: {turbine: {channel: model, ...}, ...}, ...}
# Populated once at startup by load_all_models().
MODEL_STORE: dict[str, dict[str, dict[str, MSET_Temp]]] = {}

# English channel (DataFrame column) name -> Chinese display name.
cn_map = {
    'main_bearing_temperature': '主轴承温度',
    'gearbox_oil_temperature': '齿轮箱油温',
    'generatordrive_end_bearing_temperature': '发电机驱动端轴承温度',
    'generatornon_drive_end_bearing_temperature': '发电机非驱动端轴承温度',
}

# Channel column -> response key used by /SPRT/trend (order defines response key order).
TREND_CHANNELS = {
    'main_bearing_temperature': 'main_bearing',
    'gearbox_oil_temperature': 'gearbox_oil',
    'generatordrive_end_bearing_temperature': 'generator_drive_end',
    'generatornon_drive_end_bearing_temperature': 'generator_nondrive_end',
}

# SPRT scores strictly greater than this are reported with status "危险" (danger).
SPRT_DANGER_THRESHOLD = 0.99

# Timestamp format used in API responses.
TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

# Diagnosis type (URL path segment) -> Auto_diag method name.
AUTODIAG_METHODS = {
    "Unbalance": "Unbalance_diag",        # unbalance diagnosis
    "Misalignment": "Misalignment_diag",  # misalignment diagnosis
    "Looseness": "Looseness_diag",        # looseness diagnosis
    "Bearing": "Bearing_diag",            # bearing diagnosis
    "Gear": "Gear_diag",                  # gear diagnosis
}


class TemperatureInput(BaseModel):
    """Request body shared by the temperature endpoints."""

    windCode: str                      # wind farm code
    windTurbineNumberList: List[str]   # turbine numbers; a bare string is accepted and wrapped
    startTime: str
    endTime: str

    @model_validator(mode='before')
    @classmethod
    def ensure_list(cls, v):
        """Wrap a single turbine number string into a one-element list.

        Non-dict input is passed through untouched so pydantic reports a
        normal validation error instead of this hook crashing.
        """
        if isinstance(v, dict) and isinstance(v.get('windTurbineNumberList'), str):
            # Copy rather than mutate the caller's dict.
            v = {**v, 'windTurbineNumberList': [v['windTurbineNumberList']]}
        return v


class TemperatureThresholdInput(TemperatureInput):
    """Temperature request body with 1-based pagination parameters."""

    pageNo: int
    pageSize: int


@app.on_event("startup")
def load_all_models():
    """Load every model file models/<windCode>/<turbine>/<channel>.pkl into MODEL_STORE."""
    for model_path in glob.glob(os.path.join("models", "*", "*", "*.pkl")):
        # pathlib handles either separator, unlike splitting on os.sep.
        path = Path(model_path)
        wind_code, turbine = path.parts[-3], path.parts[-2]
        # NOTE(review): load_model presumably unpickles the file — only deploy trusted model files.
        model = MSET_Temp.load_model(model_path)
        MODEL_STORE.setdefault(wind_code, {}).setdefault(turbine, {})[path.stem] = model
    print("模型加载完成:", {k: list(v.keys()) for k, v in MODEL_STORE.items()})


def _turbine_models_for(wind_code: str) -> dict[str, dict[str, MSET_Temp]]:
    """Return the {turbine: {channel: model}} map for *wind_code*.

    Raises:
        HTTPException: 404 when no models were loaded for this wind farm.
    """
    turbine_models = MODEL_STORE.get(wind_code)
    if turbine_models is None:
        raise HTTPException(404, f"无模型:{wind_code}")
    return turbine_models


def _load_turbine_frame(wind_code: str, turbine: str, start_time: str, end_time: str) -> pd.DataFrame:
    """Fetch one turbine's data for the time window via MSET_Temp.

    When the frame is non-empty its 'time_stamp' column is parsed to datetime;
    an empty frame is returned as-is so callers can skip it.
    """
    analyzer = MSET_Temp(wind_code, [turbine], start_time, end_time)
    df = analyzer._get_data_by_filter()
    if not df.empty:
        df['time_stamp'] = pd.to_datetime(df['time_stamp'])
    return df


def _channel_series(df: pd.DataFrame, column: str):
    """Return (formatted timestamps, values reshaped to (n, 1)) for rows where neither is null."""
    sub = df[['time_stamp', column]].dropna()
    timestamps = sub['time_stamp'].dt.strftime(TIME_FORMAT).tolist()
    values = sub[column].values.reshape(-1, 1)
    return timestamps, values


@app.post("/temperature/threshold")
async def route_threshold(inp: TemperatureThresholdInput):
    """
    Score each temperature channel with SPRT and return paginated per-sample records.

    Input:
    {
        "windCode": "WOF091200030",
        "windTurbineNumberList": ["WOG01355"],
        "startTime": "2024-06-01 00:00",
        "endTime": "2024-06-05 01:00",
        "pageNo": 1,
        "pageSize": 10
    }
    Output:
    {
        "data": {
            "type": "temperature_threshold",
            "records": [
                {
                    "wind_turbine_number": "WOG01355",
                    "time_stamp": "2024-06-01 00:05:00",
                    "temp_channel": "主轴承温度",
                    "SPRT_score": 0.12,
                    "status": "正常"
                },
                ...
            ],
            "totalSize": 42
        },
        "code": 200,
        "message": "success"
    }

    Turbines without loaded models, empty data windows, missing columns and
    channels without a model are skipped. Raises HTTP 404 for an unknown windCode.
    """
    turbine_models = _turbine_models_for(inp.windCode)

    records = []
    for turbine in inp.windTurbineNumberList:
        channel_models = turbine_models.get(turbine)
        if channel_models is None:
            continue
        df = _load_turbine_frame(inp.windCode, turbine, inp.startTime, inp.endTime)
        if df.empty:
            continue
        for column, channel_cn in cn_map.items():
            model = channel_models.get(column)
            if model is None or column not in df.columns:
                continue
            timestamps, values = _channel_series(df, column)
            scores = model.predict_SPRT(values, decisionGroup=1)
            # Assumes predict_SPRT yields one score per input row — TODO confirm.
            for ts, score in zip(timestamps, scores):
                records.append({
                    "wind_turbine_number": turbine,
                    "time_stamp": ts,
                    "temp_channel": channel_cn,
                    "SPRT_score": score,
                    "status": "危险" if score > SPRT_DANGER_THRESHOLD else "正常",
                })

    # 1-based pagination. Clamp so a non-positive pageNo or negative pageSize
    # cannot produce negative slice indices that wrap around the list.
    page_no = max(inp.pageNo, 1)
    page_size = max(inp.pageSize, 0)
    start = (page_no - 1) * page_size
    return {
        "data": {
            "type": "temperature_threshold",
            "records": records[start:start + page_size],
            "totalSize": len(records),
        },
        "code": 200,
        "message": "success",
    }


@app.post("/SPRT/trend")
async def route_trend(inp: TemperatureInput):
    """
    Return the SPRT score series of each temperature channel.

    Input:
    {
        "windCode": "WOF091200030",
        "windTurbineNumberList": ["WOG01355"],
        "startTime": "2024-06-01 00:00",
        "endTime": "2024-06-05 01:00"
    }
    Output:
    {
        "data": {
            "type": "SPRT_trend",
            "main_bearing": {"timestamps": [...], "values": [...]},
            "gearbox_oil": {"timestamps": [...], "values": [...]},
            "generator_drive_end": {"timestamps": [...], "values": [...]},
            "generator_nondrive_end": {"timestamps": [...], "values": [...]}
        },
        "code": 200,
        "message": "success"
    }

    Only ONE turbine is reported: the last one in windTurbineNumberList that has
    loaded models and non-empty data. Channels with no column or no model (and the
    whole response, when no turbine qualifies) get empty timestamps/values.
    Raises HTTP 404 for an unknown windCode.
    """
    turbine_models = _turbine_models_for(inp.windCode)

    # Pre-fill every channel so the response is well-formed even when no
    # turbine yields data (previously this raised UnboundLocalError -> 500).
    ch_data = {key: {"timestamps": [], "values": []} for key in TREND_CHANNELS.values()}

    # The last turbine with data wins, so scan from the end and stop at the first
    # hit instead of computing (and discarding) every earlier turbine.
    for turbine in reversed(inp.windTurbineNumberList):
        channel_models = turbine_models.get(turbine)
        if channel_models is None:
            continue
        df = _load_turbine_frame(inp.windCode, turbine, inp.startTime, inp.endTime)
        if df.empty:
            continue
        for column, key in TREND_CHANNELS.items():
            model = channel_models.get(column)
            if model is None or column not in df.columns:
                # Keep the empty placeholder so timestamps and values stay aligned.
                continue
            timestamps, values = _channel_series(df, column)
            ch_data[key] = {
                "timestamps": timestamps,
                "values": model.predict_SPRT(values, decisionGroup=1),
            }
        break

    return {
        "data": {
            "type": "SPRT_trend",
            **ch_data,
        },
        "code": 200,
        "message": "success",
    }


class AutoDiagInput(BaseModel):
    """Request body for /autodiag/{autodiagType}."""

    ids: List[int]      # data record ids
    windCode: str       # wind farm code
    engine_code: str    # wind turbine code
    autodiagType: str   # diagnosis type (the URL path parameter is what is actually used)

    @model_validator(mode='before')
    @classmethod
    def convert_ids(cls, values):
        """Wrap a single integer id into a one-element list."""
        if isinstance(values, dict) and isinstance(values.get('ids'), int):
            # Copy rather than mutate the caller's dict.
            values = {**values, 'ids': [values['ids']]}
        return values


class DiagnosisResult(BaseModel):
    """Shape of an aggregated diagnosis result (not referenced by the routes in this module)."""

    status_codes: List[int]  # status code for each id
    max_status: int          # maximum of all status codes
    count_0: int             # number of status code 0
    count_1: int             # number of status code 1
    count_2: int             # number of status code 2


@app.post("/autodiag/{autodiagType}")
async def perform_diagnosis(autodiagType: str, input_data: AutoDiagInput):
    """
    Run one automatic vibration diagnosis.

    Args:
        autodiagType: diagnosis type; one of the keys of AUTODIAG_METHODS.
        input_data: ids, windCode and engine_code identifying the data to analyse.

    Returns:
        The Auto_diag method's result wrapped in a JSONResponse, or a 200 response
        carrying {"code": 400 | 405, "message": ...} for the two known
        "cannot diagnose" ValueErrors.

    Raises:
        HTTPException: 400 for an unknown type or any other ValueError,
            500 for any other failure.
    """
    method_name = AUTODIAG_METHODS.get(autodiagType)
    if method_name is None:
        raise HTTPException(status_code=400, detail="非可用的诊断类型")

    try:
        autodiag = Auto_diag(input_data.ids, input_data.windCode, input_data.engine_code)
        # Call directly: a non-callable attribute now fails loudly (500) instead of
        # silently returning a null body.
        result = getattr(autodiag, method_name)()
        # Kept inside the try so serialization errors are reported as before.
        return JSONResponse(content=result)
    except ValueError as e:
        message = str(e)
        # Known "cannot diagnose" conditions are reported in-band with HTTP 200.
        if "Can not perform gearbox diagnosis" in message:
            return JSONResponse(status_code=200, content={"code": 400, "message": message})
        if "当前采集频率不适合进行诊断分析" in message:
            return JSONResponse(status_code=200, content={"code": 405, "message": message})
        raise HTTPException(status_code=400, detail=message) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    # Placed last so the whole module is defined before the server starts.
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)