MSET_Temp.py

import math
import numpy as np
import pandas as pd
from sqlalchemy import text
from sklearn.neighbors import BallTree

from app.config import dataBase
from app.database import get_engine

class MSET_Temp:
    """
    Temperature trend / threshold analysis based on MSET + SPRT.

    Query conditions are determined by the wind_turbine_number column and the
    time_stamp range. The SPRT threshold is fixed at 0.99, and calcSPRT
    outputs scores in [-1, 1].
    """

    def __init__(self, windCode: str, windTurbineNumberList: list[str], startTime: str, endTime: str):
        """
        :param windCode: turbine type / unit code, used to build the table name,
                         e.g. "WOG01312" -> table "WOG01312_minute"
        :param windTurbineNumberList: list of wind_turbine_number values to query
        :param startTime: start time as a string, format "YYYY-MM-DD HH:MM"
        :param endTime: end time as a string, format "YYYY-MM-DD HH:MM"
        """
        self.windCode = windCode.strip()
        self.windTurbineNumberList = windTurbineNumberList
        # Stored as-is; the query expects second resolution.
        self.startTime = startTime
        self.endTime = endTime
        # D/L matrices and derived state
        self.matrixD = None
        self.matrixL = None
        self.healthyResidual = None
        self.normalDataBallTree = None

    def _get_data_by_filter(self) -> pd.DataFrame:
        """
        Batch query by the wind_turbine_number column and the time_stamp range;
        returns a single DataFrame sorted by time_stamp ascending.
        """
        table_name = f"{self.windCode}_minute"
        engine = get_engine(dataBase.DATA_DB)

        # Build the SQL fragment for the wind_turbine_number list: ('WT1','WT2',...)
        turbines = ",".join(f"'{wt.strip()}'" for wt in self.windTurbineNumberList)
        sql = text(f"""
            SELECT *
            FROM {table_name}
            WHERE wind_turbine_number IN ({turbines})
              AND time_stamp BETWEEN :start AND :end
            ORDER BY time_stamp ASC
        """)

        df = pd.read_sql(sql, engine, params={"start": self.startTime, "end": self.endTime})
        return df

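    # For illustration only: with windCode="WOG01312" and turbines ["WT1", "WT2"]
    # (the placeholder codes from the comment above), the rendered statement is
    # roughly
    #   SELECT * FROM WOG01312_minute
    #   WHERE wind_turbine_number IN ('WT1','WT2')
    #     AND time_stamp BETWEEN :start AND :end
    #   ORDER BY time_stamp ASC
    # with :start/:end bound from startTime/endTime.
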
    def calcSimilarity(self, x: np.ndarray, y: np.ndarray, m: str = 'euc') -> float:
        """
        Similarity between vectors x and y, in (0, 1]:
        - m='euc' -> based on Euclidean distance
        - m='cbd' -> based on city-block (Manhattan) distance
        """
        if len(x) != len(y):
            return 0.0
        if m == 'cbd':
            arr = [1.0 / (1.0 + abs(p - q)) for p, q in zip(x, y)]
            return float(np.sum(arr) / len(arr))
        else:
            diffsq = [(p - q) ** 2 for p, q in zip(x, y)]
            return float(1.0 / (1.0 + math.sqrt(np.sum(diffsq))))

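    # Worked check (my arithmetic, not from the source): for x=[1, 2], y=[1, 4],
    # 'euc' gives 1 / (1 + sqrt(0 + 4)) = 1/3, while 'cbd' averages the
    # per-component scores: (1/(1+0) + 1/(1+2)) / 2 = 2/3.
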
    def genDLMatrix(self, trainDataset: np.ndarray, dataSize4D=100, dataSize4L=50) -> int:
        """
        Build the D/L matrices from trainDataset:
        - if the sample count < dataSize4D + dataSize4L, return -1;
        - otherwise construct matrixD and matrixL, compute healthyResidual via
          locally weighted regression, and return 0.
        """
        m, n = trainDataset.shape
        if m < dataSize4D + dataSize4L:
            return -1

        # Step 1: for each feature, add its min/max samples to matrixD.
        self.matrixD = []
        selectIndex4D = []
        for i in range(n):
            col_i = trainDataset[:, i]
            idx_min = np.argmin(col_i)
            idx_max = np.argmax(col_i)
            self.matrixD.append(trainDataset[idx_min, :].tolist())
            selectIndex4D.append(idx_min)
            self.matrixD.append(trainDataset[idx_max, :].tolist())
            selectIndex4D.append(idx_max)

        # Step 2: from the remaining samples, repeatedly pick the one with the
        # largest average distance to matrixD, until matrixD has dataSize4D rows.
        while len(selectIndex4D) < dataSize4D:
            freeList = list(set(range(len(trainDataset))) - set(selectIndex4D))
            distAvg = []
            for idx in freeList:
                tmp = trainDataset[idx, :]
                dlist = [1.0 - self.calcSimilarity(x, tmp) for x in self.matrixD]
                distAvg.append(np.mean(dlist))
            select_id = freeList[int(np.argmax(distAvg))]
            self.matrixD.append(trainDataset[select_id, :].tolist())
            selectIndex4D.append(select_id)
        self.matrixD = np.array(self.matrixD)

        # Build a BallTree over matrixD for the locally weighted regression.
        self.normalDataBallTree = BallTree(
            self.matrixD,
            leaf_size=4,
            metric=lambda a, b: 1.0 - self.calcSimilarity(a, b)
        )

        # Step 3: use all training samples as matrixL.
        self.matrixL = trainDataset.copy()

        # Step 4: healthy residuals via locally weighted regression.
        self.healthyResidual = self.calcResidualByLocallyWeightedLR(self.matrixL)
        return 0

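    # Shape sketch (derived from the code above, not an external spec): for a
    # training set of shape (m, n), matrixD ends up with max(2*n, dataSize4D)
    # rows (Step 1 can insert duplicate samples), matrixL is the full (m, n)
    # training set, and healthyResidual therefore has shape (m, n).
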
    def calcResidualByLocallyWeightedLR(self, newStates: np.ndarray) -> np.ndarray:
        """
        For each sample in newStates, run a locally weighted regression over its
        20 nearest neighbours in matrixD and compute the residual. Returns a
        residual matrix of shape [len(newStates), n_features].
        """
        est_list = []
        # Guard k in case matrixD has fewer than 20 rows.
        k = min(20, len(self.matrixD))
        for x in newStates:
            dist, idxs = self.normalDataBallTree.query([x], k=k, return_distance=True)
            # Inverse-distance weights, softened by 1e-1 to avoid division by zero.
            w = 1.0 / (dist[0] + 1e-1)
            w = w / np.sum(w)
            est = np.sum([w_i * self.matrixD[j] for w_i, j in zip(w, idxs[0])], axis=0)
            est_list.append(est)
        est_arr = np.reshape(np.array(est_list), (len(est_list), -1))
        return est_arr - newStates

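    # In symbols (restating the loop above): with neighbours d_j of a state x
    # and weights w_j = 1 / (dist_j + 0.1) normalised to sum to 1, the estimate
    # is x_hat = sum_j w_j * d_j, and the residual row is x_hat - x.
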
    def calcSPRT(
        self,
        newStates: np.ndarray,
        feature_weight: np.ndarray,
        alpha: float = 0.1,
        beta: float = 0.1,
        decisionGroup: int = 5
    ) -> list[float]:
        """
        Run a Wald SPRT over newStates; returns a list of scores of length
        len(newStates) - decisionGroup + 1, each in [-1, 1]:
        - closer to  1 -> more "abnormal (dangerous)"
        - closer to -1 -> more "normal"
        """
        # 1) Residuals of the new states, weighted per feature.
        stateRes = self.calcResidualByLocallyWeightedLR(newStates)
        weightedStateResidual = [np.dot(x, feature_weight) for x in stateRes]
        weightedHealthyResidual = [np.dot(x, feature_weight) for x in self.healthyResidual]

        # 2) Distribution of the healthy residuals.
        mu0 = float(np.mean(weightedHealthyResidual))
        sigma0 = float(np.std(weightedHealthyResidual))
        # Guard against a degenerate (constant) healthy residual.
        sigma0 = max(sigma0, 1e-12)

        # 3) SPRT decision thresholds.
        lowThres = np.log(beta / (1.0 - alpha))    # < 0
        highThres = np.log((1.0 - beta) / alpha)   # > 0

        flags: list[float] = []
        length = len(weightedStateResidual)
        for i in range(0, length - decisionGroup + 1):
            segment = weightedStateResidual[i : i + decisionGroup]
            mu1 = float(np.mean(segment))
            # Log-likelihood ratio of the segment under N(mu1, sigma0) vs N(mu0, sigma0).
            si = (
                np.sum(segment) * (mu1 - mu0) / (sigma0**2)
                - decisionGroup * ((mu1**2) - (mu0**2)) / (2.0 * (sigma0**2))
            )
            # Clamp si into [lowThres, highThres] ...
            si = max(min(si, highThres), lowThres)
            # ... then normalise each side by its threshold magnitude.
            if si > 0:
                norm_si = float(si / highThres)
            else:
                # Divide by |lowThres| so that normal segments map to negative
                # scores; si / lowThres would flip the sign and break the
                # documented [-1, 1] range.
                norm_si = float(si / abs(lowThres))
            flags.append(norm_si)
        return flags

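    # Numeric check (computed here, not from the source): with the defaults
    # alpha = beta = 0.1, lowThres = ln(0.1/0.9) ~= -2.197 and
    # highThres = ln(0.9/0.1) ~= +2.197, so a clamped si of +1.10 maps to a
    # score of ~0.5, and si = lowThres maps to exactly -1.0.
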
    def check_threshold(self) -> pd.DataFrame:
        """
        Threshold analysis (threshold 0.99), long format: returns rows for
        every channel present in the data; missing channels are skipped.
        """
        THRESHOLD = 0.99
        df = self._get_data_by_filter()
        if df.empty:
            return pd.DataFrame(columns=["time_stamp", "temp_channel", "SPRT_score", "status"])

        # The four temperature channels (English column names).
        temp_cols_all = [
            'main_bearing_temperature',
            'gearbox_oil_temperature',
            'generatordrive_end_bearing_temperature',
            'generatornon_drive_end_bearing_temperature'
        ]
        # Keep only the columns that actually exist.
        temp_cols = [c for c in temp_cols_all if c in df.columns]
        if not temp_cols:
            return pd.DataFrame(columns=["time_stamp", "temp_channel", "SPRT_score", "status"])

        # Coerce to numeric values and timestamps.
        df[temp_cols] = df[temp_cols].apply(pd.to_numeric, errors='coerce')
        df['time_stamp'] = pd.to_datetime(df['time_stamp'], errors='coerce')

        records = []
        # English column -> Chinese display name.
        cn_map = {
            'main_bearing_temperature': '主轴承温度',
            'gearbox_oil_temperature': '齿轮箱油温',
            'generatordrive_end_bearing_temperature': '发电机驱动端轴承温度',
            'generatornon_drive_end_bearing_temperature': '发电机非驱动端轴承温度'
        }

        for col in temp_cols:
            sub = df[['time_stamp', col]].dropna()
            if sub.empty:
                continue
            arr = sub[col].values.reshape(-1, 1)
            ts = sub['time_stamp'].dt.strftime("%Y-%m-%d %H:%M:%S").tolist()

            # First half trains the model; the second half is scored.
            half = len(arr) // 2
            train = arr[:half]
            test = arr[half:]
            # Skip the channel if there is too little data to train.
            if self.genDLMatrix(train, dataSize4D=60, dataSize4L=5) != 0:
                continue

            flags = self.calcSPRT(test, np.array([1.0]), decisionGroup=1)
            for i, score in enumerate(flags):
                records.append({
                    "time_stamp": ts[half + i],
                    "temp_channel": cn_map[col],
                    "SPRT_score": score,
                    # Status values: 危险 = dangerous, 正常 = normal.
                    "status": "危险" if score > THRESHOLD else "正常"
                })
        return pd.DataFrame(records, columns=["time_stamp", "temp_channel", "SPRT_score", "status"])

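    # Expected long-format shape (illustrative values only, not real output):
    #   time_stamp            temp_channel  SPRT_score  status
    #   2024-01-04 12:00:00   主轴承温度     0.370       正常
    #   2024-01-04 12:01:00   主轴承温度     0.995       危险
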
    def get_trend(self) -> dict:
        """
        Trend analysis: scores each channel independently; a missing channel or
        insufficient training data yields an empty structure for that channel.
        """
        df = self._get_data_by_filter()

        # English column -> output field name.
        key_map = {
            'main_bearing_temperature': 'main_bearing',
            'gearbox_oil_temperature': 'gearbox_oil',
            'generatordrive_end_bearing_temperature': 'generator_drive_end',
            'generatornon_drive_end_bearing_temperature': 'generator_nondrive_end'
        }
        # Pre-seed the result with empty structures.
        result = {v: {} for v in key_map.values()}
        if df.empty:
            return {"data": result, "code": 200, "message": "success"}

        df['time_stamp'] = pd.to_datetime(df['time_stamp'], errors='coerce')
        for col, out_key in key_map.items():
            if col not in df.columns:
                continue
            sub = df[['time_stamp', col]].dropna()
            if sub.empty:
                continue
            vals = pd.to_numeric(sub[col], errors='coerce').values
            ts_list = sub['time_stamp'].dt.strftime("%Y-%m-%d %H:%M:%S").tolist()

            half = len(vals) // 2
            train = vals[:half].reshape(-1, 1)
            test = vals[half:].reshape(-1, 1)
            # Empty score list when there is too little data to train.
            if self.genDLMatrix(train, dataSize4D=60, dataSize4L=5) != 0:
                flags = []
            else:
                flags = self.calcSPRT(test, np.array([1.0]), decisionGroup=1)
            result[out_key] = {
                "timestamps": ts_list[half:],
                "values": flags
            }
        return {"data": result, "code": 200, "message": "success"}
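

# Minimal usage sketch, assuming a reachable database behind get_engine and an
# existing WOG01312_minute table; the turbine codes and time window below are
# hypothetical placeholders.
if __name__ == "__main__":
    analyzer = MSET_Temp(
        windCode="WOG01312",
        windTurbineNumberList=["WT1", "WT2"],
        startTime="2024-01-01 00:00",
        endTime="2024-01-07 00:00",
    )
    print(analyzer.check_threshold().head())
    print(analyzer.get_trend()["data"].keys())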