# main.py
"""FastAPI service exposing SPRT-based temperature diagnosis endpoints.

Models are loaded once at startup from ``models/<windCode>/<channel>.pkl`` into
``MODEL_STORE`` and reused by the two POST routes defined below.
"""
import glob
import logging
import os
from typing import List, Optional

import numpy as np
import pandas as pd
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel, model_validator

from Temp_Diag import MSET_Temp

logger = logging.getLogger(__name__)

app = FastAPI(title="Temperature Diagnosis API")

# Global model cache: { windCode: { channel_name: MSET_Temp, ... }, ... }
MODEL_STORE: dict[str, dict[str, MSET_Temp]] = {}

# English channel column -> Chinese display name
cn_map = {
    'main_bearing_temperature': '主轴承温度',
    'gearbox_oil_temperature': '齿轮箱油温',
    'generatordrive_end_bearing_temperature': '发电机驱动端轴承温度',
    'generatornon_drive_end_bearing_temperature': '发电机非驱动端轴承温度',
}

# English channel column -> key used in the /SPRT/trend response payload
_TREND_KEYS = {
    'main_bearing_temperature': 'main_bearing',
    'gearbox_oil_temperature': 'gearbox_oil',
    'generatordrive_end_bearing_temperature': 'generator_drive_end',
    'generatornon_drive_end_bearing_temperature': 'generator_nondrive_end',
}

_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

# A score strictly above this value is reported as "危险" (danger).
_DANGER_THRESHOLD = 0.99


class TemperatureInput(BaseModel):
    """Request body shared by both diagnosis endpoints."""

    windCode: str
    windTurbineNumberList: List[str]
    startTime: str
    endTime: str

    @model_validator(mode='before')
    @classmethod
    def ensure_list(cls, v):
        """Wrap a bare string ``windTurbineNumberList`` into a one-item list.

        Non-dict input is passed through untouched so that Pydantic's own
        validation reports the problem. The incoming dict is copied rather
        than mutated.
        """
        if isinstance(v, dict):
            raw = v.get('windTurbineNumberList')
            if isinstance(raw, str):
                v = {**v, 'windTurbineNumberList': [raw]}
        return v


class TemperatureThresholdInput(TemperatureInput):
    """Request body for ``/temperature/threshold``; adds 1-based paging."""

    pageNo: int
    pageSize: int


# NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in favour
# of lifespan handlers; kept here to leave the application wiring unchanged.
@app.on_event("startup")
def load_all_models():
    """Populate ``MODEL_STORE`` from every ``models/<windCode>/<channel>.pkl``."""
    # NOTE(review): the .pkl extension suggests ``load_model`` unpickles these
    # files; only trusted files should be placed under ``models/``.
    for f in glob.glob("models/*/*.pkl"):
        wc = os.path.basename(os.path.dirname(f))       # directory name -> windCode
        ch = os.path.splitext(os.path.basename(f))[0]   # file name without .pkl -> channel
        MODEL_STORE.setdefault(wc, {})[ch] = MSET_Temp.load_model(f)
    print("模型加载完成:", {k: list(v.keys()) for k, v in MODEL_STORE.items()})


def _score_channel(
    df: pd.DataFrame, wind_code: str, channel: str
) -> tuple[List[str], Optional[List[float]]]:
    """Score one temperature channel with its cached model.

    Args:
        df: Frame with a datetime ``time_stamp`` column and a ``channel`` column.
        wind_code: Key into ``MODEL_STORE`` selecting the wind farm's models.
        channel: Column name in ``df`` and model key within the wind farm.

    Returns:
        ``(timestamps, scores)`` where ``timestamps`` are the formatted time
        stamps of the non-null rows. ``scores`` is ``None`` when no model is
        cached for this wind farm/channel, otherwise a list of plain floats.
    """
    sub = df[['time_stamp', channel]].dropna()
    timestamps = sub['time_stamp'].dt.strftime(_TIME_FORMAT).tolist()

    model = MODEL_STORE.get(wind_code, {}).get(channel)
    if model is None:
        return timestamps, None
    if sub.empty:
        # Nothing to score; do not hand a zero-row array to the model.
        return timestamps, []

    values = sub[channel].to_numpy().reshape(-1, 1)
    raw_scores = model.predict_SPRT(values, decisionGroup=1)
    # Normalise to plain Python floats so JSON encoding does not depend on the
    # container/dtype the model returns (assumes numeric scores — confirm
    # against MSET_Temp.predict_SPRT).
    return timestamps, np.asarray(raw_scores, dtype=float).ravel().tolist()


def _failure_response(exc: Exception) -> JSONResponse:
    """Build the 500 payload returned by both routes when analysis fails."""
    return JSONResponse(
        status_code=500,
        content={"code": 500, "message": "analysis failed", "detail": str(exc)},
    )


# NOTE(review): both handlers are ``async def`` yet perform blocking data
# access and model inference, which holds the event loop for the duration of
# the request. Declaring them with plain ``def`` would let FastAPI run them in
# its thread pool — left unchanged here to keep the callables' interface.
@app.post("/temperature/threshold")
async def route_threshold(inp: TemperatureThresholdInput):
    """Return one page of per-sample SPRT scores with a status label.

    Every channel in ``cn_map`` that exists in the data and has a cached model
    contributes one record per non-null sample. ``totalSize`` is the number of
    records before paging. Any exception yields a 500 JSON response.
    """
    try:
        analyzer = MSET_Temp(inp.windCode, inp.windTurbineNumberList,
                             inp.startTime, inp.endTime)
        df = analyzer._get_data_by_filter()

        records = []
        if not df.empty:
            df['time_stamp'] = pd.to_datetime(df['time_stamp'])
            for eng, cn in cn_map.items():
                if eng not in df.columns:
                    continue
                timestamps, scores = _score_channel(df, inp.windCode, eng)
                if scores is None:
                    continue
                records.extend(
                    {
                        "time_stamp": ts,
                        "temp_channel": cn,
                        "SPRT_score": sc,
                        "status": "危险" if sc > _DANGER_THRESHOLD else "正常",
                    }
                    for ts, sc in zip(timestamps, scores)
                )

        # Clamp paging so zero/negative input cannot yield negative slice
        # bounds, which would silently return rows counted from the end.
        page_no = max(inp.pageNo, 1)
        page_size = max(inp.pageSize, 0)
        start = (page_no - 1) * page_size
        end = start + page_size
        return {
            "data": {
                "type": "temperature_threshold",
                "records": records[start:end],
                "totalSize": len(records),
            },
            "code": 200,
            "message": "success",
        }
    except Exception as e:
        logger.exception("temperature threshold analysis failed")
        return _failure_response(e)


@app.post("/SPRT/trend")
async def route_trend(inp: TemperatureInput):
    """Return the SPRT score series for each known temperature channel.

    Each key of ``_TREND_KEYS`` is always present in the payload. A channel
    missing from the data (or an empty data set) yields empty lists; a channel
    without a cached model yields its timestamps with empty ``values``.
    Any exception yields a 500 JSON response.
    """
    try:
        analyzer = MSET_Temp(inp.windCode, inp.windTurbineNumberList,
                             inp.startTime, inp.endTime)
        df = analyzer._get_data_by_filter()

        has_rows = not df.empty
        if has_rows:
            df['time_stamp'] = pd.to_datetime(df['time_stamp'])

        result = {}
        for eng, key in _TREND_KEYS.items():
            if not has_rows or eng not in df.columns:
                result[key] = {"timestamps": [], "values": []}
                continue
            timestamps, scores = _score_channel(df, inp.windCode, eng)
            result[key] = {
                "timestamps": timestamps,
                "values": scores if scores is not None else [],
            }

        return {"data": {"type": "SPRT_trend", **result}, "code": 200, "message": "success"}
    except Exception as e:
        logger.exception("SPRT trend analysis failed")
        return _failure_response(e)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)