123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114 |
- # main.py
import glob
import logging
import os
from typing import List

import pandas as pd
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, model_validator

from Temp_Diag import MSET_Temp
app = FastAPI(title="Temperature Diagnosis API")

# Global model cache populated by the startup hook, laid out as
#   {windCode: {channel_name: MSET_Temp, ...}, ...}
MODEL_STORE: dict[str, dict[str, MSET_Temp]] = {}

# English data-column name of each temperature channel -> Chinese display
# name that the threshold endpoint reports as "temp_channel".
cn_map: dict[str, str] = {
    'main_bearing_temperature': '主轴承温度',
    'gearbox_oil_temperature': '齿轮箱油温',
    'generatordrive_end_bearing_temperature': '发电机驱动端轴承温度',
    'generatornon_drive_end_bearing_temperature': '发电机非驱动端轴承温度',
}
class TemperatureInput(BaseModel):
    """Request body shared by the temperature-diagnosis endpoints.

    windCode selects the wind farm (and its cached models);
    windTurbineNumberList names the turbines to query and also accepts a
    single bare string; startTime/endTime bound the query window and are
    passed through unchanged to MSET_Temp.
    """
    windCode: str
    windTurbineNumberList: List[str]
    startTime: str
    endTime: str

    @model_validator(mode='before')
    @classmethod
    def ensure_list(cls, v):
        """Wrap a bare-string ``windTurbineNumberList`` into a one-element list.

        In ``mode='before'`` the raw input can be anything (a JSON array body,
        an existing model instance, ...). Non-dict input is returned unchanged
        so pydantic reports its normal validation error instead of this hook
        raising ``AttributeError`` (which would surface as a 500, not a 422).
        """
        if not isinstance(v, dict):
            return v
        raw = v.get('windTurbineNumberList')
        if isinstance(raw, str):
            # Build a new dict rather than mutating the caller's object.
            v = {**v, 'windTurbineNumberList': [raw]}
        return v
class TemperatureThresholdInput(TemperatureInput):
    """Threshold-query body: TemperatureInput plus 1-based pagination."""

    # 1-based page index. Values below 1 would give a negative slice start in
    # the threshold route and silently return the wrong page, so they are
    # rejected as a validation error instead.
    pageNo: int = Field(ge=1)
    # Number of records per page; must be positive.
    pageSize: int = Field(ge=1)
@app.on_event("startup")
def load_all_models():
    """Fill MODEL_STORE from every ``models/<windCode>/<channel>.pkl`` file.

    The parent directory name is the wind-farm code and the file name without
    its extension is the channel name. Each file is loaded with
    ``MSET_Temp.load_model``. The glob pattern is relative, so it resolves
    against the process's current working directory.
    """
    # NOTE(review): if load_model unpickles the file, only trusted model files
    # should ever be placed under models/.
    for pkl_path in glob.glob("models/*/*.pkl"):
        wind_code = os.path.basename(os.path.dirname(pkl_path))
        channel, _ext = os.path.splitext(os.path.basename(pkl_path))
        # Load first so a failing file never leaves an empty farm entry behind.
        loaded = MSET_Temp.load_model(pkl_path)
        MODEL_STORE.setdefault(wind_code, {})[channel] = loaded
    summary = {code: list(models.keys()) for code, models in MODEL_STORE.items()}
    print("模型加载完成:", summary)
@app.post("/temperature/threshold")
async def route_threshold(inp: TemperatureThresholdInput):
    """Return paginated per-sample SPRT scores for each temperature channel.

    Fetches the requested turbines' data for [startTime, endTime] through
    MSET_Temp, scores every channel in ``cn_map`` that has both a data column
    and a cached model for ``inp.windCode``, labels each score above 0.99 as
    "危险" (danger) and the rest as "正常" (normal), and returns the
    ``pageNo``/``pageSize`` slice of all records plus the total record count.

    Any exception is logged and answered with HTTP 500 carrying the error
    text in ``detail``.
    """
    # NOTE(review): data loading and prediction are synchronous calls inside an
    # ``async`` handler, so they block the event loop while running. Declaring
    # the route with plain ``def`` (threadpool) would avoid that, but only if
    # the shared MSET_Temp models are safe to call concurrently; confirm first.
    try:
        analyzer = MSET_Temp(inp.windCode, inp.windTurbineNumberList, inp.startTime, inp.endTime)
        df = analyzer._get_data_by_filter()
        if df.empty:
            return {"data": {"type": "temperature_threshold", "records": [], "totalSize": 0},
                    "code": 200, "message": "success"}
        df['time_stamp'] = pd.to_datetime(df['time_stamp'])
        farm_models = MODEL_STORE.get(inp.windCode, {})
        records = []
        for eng, cn in cn_map.items():
            model = farm_models.get(eng)
            # Skip channels without a data column or a trained model before
            # doing any per-channel data preparation.
            if eng not in df.columns or model is None:
                continue
            sub = df[['time_stamp', eng]].dropna()
            arr = sub[eng].values.reshape(-1, 1)
            ts = sub['time_stamp'].dt.strftime("%Y-%m-%d %H:%M:%S").tolist()
            scores = model.predict_SPRT(arr, decisionGroup=1)
            # predict_SPRT's return type is not visible here; if it is a numpy /
            # pandas container, convert to native Python values so the JSON
            # encoder (which runs outside this try block) can serialise them.
            if hasattr(scores, "tolist"):
                scores = scores.tolist()
            for i, sc in enumerate(scores):
                records.append({
                    "time_stamp": ts[i],
                    "temp_channel": cn,
                    "SPRT_score": sc,
                    "status": "危险" if sc > 0.99 else "正常",
                })
        total = len(records)
        start = (inp.pageNo - 1) * inp.pageSize
        end = start + inp.pageSize
        return {"data": {"type": "temperature_threshold",
                         "records": records[start:end],
                         "totalSize": total},
                "code": 200, "message": "success"}
    except Exception as e:
        # Keep the traceback server-side; the client only receives str(e).
        logging.getLogger(__name__).exception("temperature threshold analysis failed")
        return JSONResponse(status_code=500, content={"code": 500, "message": "analysis failed", "detail": str(e)})
@app.post("/SPRT/trend")
async def route_trend(inp: TemperatureInput):
    """Return the SPRT score series of each temperature channel.

    For every channel the response holds ``{"timestamps": [...], "values": [...]}``.
    A channel whose column is absent (or when no data is returned at all)
    gets two empty lists. A channel with data but no cached model keeps its
    timestamps and gets empty values.

    Any exception is logged and answered with HTTP 500 carrying the error
    text in ``detail``.
    """
    # Data column name -> key used for that channel in the response payload.
    channel_keys = {
        'main_bearing_temperature': 'main_bearing',
        'gearbox_oil_temperature': 'gearbox_oil',
        'generatordrive_end_bearing_temperature': 'generator_drive_end',
        'generatornon_drive_end_bearing_temperature': 'generator_nondrive_end',
    }
    # NOTE(review): same blocking-in-async concern as the threshold route.
    try:
        analyzer = MSET_Temp(inp.windCode, inp.windTurbineNumberList, inp.startTime, inp.endTime)
        df = analyzer._get_data_by_filter()
        has_data = not df.empty
        # An empty frame may lack the time_stamp column entirely, so only
        # convert when there is data (mirrors the threshold route's guard).
        if has_data:
            df['time_stamp'] = pd.to_datetime(df['time_stamp'])
        farm_models = MODEL_STORE.get(inp.windCode, {})
        result = {}
        for eng, key in channel_keys.items():
            if not has_data or eng not in df.columns:
                result[key] = {"timestamps": [], "values": []}
                continue
            sub = df[['time_stamp', eng]].dropna()
            arr = sub[eng].values.reshape(-1, 1)
            ts = sub['time_stamp'].dt.strftime("%Y-%m-%d %H:%M:%S").tolist()
            model = farm_models.get(eng)
            vals = model.predict_SPRT(arr, decisionGroup=1) if model is not None else []
            # Convert numpy/pandas containers to plain lists so the JSON
            # encoder, which runs after this function returns, accepts them.
            if hasattr(vals, "tolist"):
                vals = vals.tolist()
            result[key] = {"timestamps": ts, "values": vals}
        return {"data": {"type": "SPRT_trend", **result}, "code": 200, "message": "success"}
    except Exception as e:
        # Keep the traceback server-side; the client only receives str(e).
        logging.getLogger(__name__).exception("SPRT trend analysis failed")
        return JSONResponse(status_code=500, content={"code": 500, "message": "analysis failed", "detail": str(e)})
if __name__ == "__main__":
    # Local development entry point. The import string "main:app" assumes
    # this file is importable as the module `main`.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("main:app", **server_options)
|