api_tempdiag.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. # main.py
import glob
import logging
import os
from typing import List

import pandas as pd
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, model_validator

from Temp_Diag import MSET_Temp
  9. app = FastAPI(title="Temperature Diagnosis API")
  10. # 全局模型缓存:{ windCode: { channel_name: MSET_Temp, … }, … }
  11. MODEL_STORE: dict[str, dict[str, MSET_Temp]] = {}
  12. # 英文→中文映射
  13. cn_map = {
  14. 'main_bearing_temperature': '主轴承温度',
  15. 'gearbox_oil_temperature': '齿轮箱油温',
  16. 'generatordrive_end_bearing_temperature': '发电机驱动端轴承温度',
  17. 'generatornon_drive_end_bearing_temperature': '发电机非驱动端轴承温度'
  18. }
  19. class TemperatureInput(BaseModel):
  20. windCode: str
  21. windTurbineNumberList: List[str]
  22. startTime: str
  23. endTime: str
  24. @model_validator(mode='before')
  25. def ensure_list(cls, v):
  26. raw = v.get('windTurbineNumberList')
  27. if isinstance(raw, str):
  28. v['windTurbineNumberList'] = [raw]
  29. return v
  30. class TemperatureThresholdInput(TemperatureInput):
  31. pageNo: int
  32. pageSize: int
  33. @app.on_event("startup")
  34. def load_all_models():
  35. for f in glob.glob("models/*/*.pkl"):
  36. wc = os.path.basename(os.path.dirname(f)) # 取目录名 → windCode
  37. ch = os.path.splitext(os.path.basename(f))[0] # 取文件名(去 .pkl)→ channel
  38. MODEL_STORE.setdefault(wc, {})[ch] = MSET_Temp.load_model(f)
  39. print("模型加载完成:", {k: list(v.keys()) for k,v in MODEL_STORE.items()})
  40. @app.post("/temperature/threshold")
  41. async def route_threshold(inp: TemperatureThresholdInput):
  42. try:
  43. analyzer = MSET_Temp(inp.windCode, inp.windTurbineNumberList, inp.startTime, inp.endTime)
  44. df = analyzer._get_data_by_filter()
  45. if df.empty:
  46. return {"data":{"type":"temperature_threshold","records":[],"totalSize":0},"code":200,"message":"success"}
  47. df['time_stamp'] = pd.to_datetime(df['time_stamp'])
  48. records = []
  49. for eng, cn in cn_map.items():
  50. if eng not in df.columns: continue
  51. sub = df[['time_stamp', eng]].dropna()
  52. arr = sub[eng].values.reshape(-1,1)
  53. ts = sub['time_stamp'].dt.strftime("%Y-%m-%d %H:%M:%S").tolist()
  54. model = MODEL_STORE.get(inp.windCode,{}).get(eng)
  55. if not model: continue
  56. flags = model.predict_SPRT(arr, decisionGroup=1)
  57. for i, sc in enumerate(flags):
  58. records.append({
  59. "time_stamp": ts[i],
  60. "temp_channel": cn,
  61. "SPRT_score": sc,
  62. "status": "危险" if sc>0.99 else "正常"
  63. })
  64. total = len(records)
  65. start = (inp.pageNo-1)*inp.pageSize
  66. end = start+inp.pageSize
  67. return {"data":{"type":"temperature_threshold",
  68. "records":records[start:end],
  69. "totalSize":total},"code":200,"message":"success"}
  70. except Exception as e:
  71. return JSONResponse(status_code=500, content={"code":500,"message":"analysis failed","detail":str(e)})
  72. @app.post("/SPRT/trend")
  73. async def route_trend(inp: TemperatureInput):
  74. try:
  75. analyzer = MSET_Temp(inp.windCode, inp.windTurbineNumberList, inp.startTime, inp.endTime)
  76. df = analyzer._get_data_by_filter()
  77. df['time_stamp'] = pd.to_datetime(df['time_stamp'])
  78. result = {}
  79. for eng, key in {
  80. 'main_bearing_temperature':'main_bearing',
  81. 'gearbox_oil_temperature':'gearbox_oil',
  82. 'generatordrive_end_bearing_temperature':'generator_drive_end',
  83. 'generatornon_drive_end_bearing_temperature':'generator_nondrive_end'
  84. }.items():
  85. if eng not in df.columns:
  86. result[key] = {"timestamps":[],"values":[]}
  87. continue
  88. sub = df[['time_stamp',eng]].dropna()
  89. arr = sub[eng].values.reshape(-1,1)
  90. ts = sub['time_stamp'].dt.strftime("%Y-%m-%d %H:%M:%S").tolist()
  91. model = MODEL_STORE.get(inp.windCode,{}).get(eng)
  92. vals = model.predict_SPRT(arr, decisionGroup=1) if model else []
  93. result[key] = {"timestamps": ts, "values": vals}
  94. return {"data":{"type":"SPRT_trend",**result},"code":200,"message":"success"}
  95. except Exception as e:
  96. return JSONResponse(status_code=500, content={"code":500,"message":"analysis failed","detail":str(e)})
  97. if __name__ == "__main__":
  98. import uvicorn
  99. uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)