
Pre-training: full farm

Xmia 1 week ago
parent
commit
450ca24334

+ 122 - 304
Temp_Diag.PY

@@ -1,362 +1,180 @@
-# temp_diag.py
+# Temp_Diag.py
 
 import numpy as np
 import pandas as pd
 from sklearn.neighbors import BallTree
 from sqlalchemy import create_engine, text
-import math
+import math, joblib, os
 
 class MSET_Temp:
     """
-    MSET + SPRT based temperature trend/threshold analysis class.
-    Query conditions are determined by the wind_turbine_number column and the time_stamp range;
-    the SPRT threshold is fixed at 0.99, and calcSPRT outputs values in [-1, 1].
+    MSET + SPRT temperature trend/threshold analysis class.
+    - Offline: train genDLMatrix on full-farm data → save_model
+    - Online: _get_data_by_filter fetches data for the windTurbineNumberList & time range sent by the frontend → predict_SPRT
     """
 
-    def __init__(self, windCode: str, windTurbineNumberList: list[str], startTime: str, endTime: str):
-        """
-        :param windCode: turbine type or unit code, used to build the table name, e.g. "WOG01312" → table "WOG01312_minute"
-        :param windTurbineNumberList: list of wind_turbine_number values (turbine IDs) to query
-        :param startTime: start time string, format "YYYY-MM-DD HH:MM"
-        :param endTime: end time string, format "YYYY-MM-DD HH:MM"
-        """
+    def __init__(self,
+                 windCode: str,
+                 windTurbineNumberList: list[str],
+                 startTime: str,
+                 endTime: str):
         self.windCode = windCode.strip()
-        self.windTurbineNumberList = windTurbineNumberList
-        # force truncation to whole seconds
+        self.windTurbineNumberList = windTurbineNumberList or []
         self.startTime = startTime
         self.endTime   = endTime
 
-        # D/L matrix state
+        # assigned by offline training or load_model
         self.matrixD = None
         self.matrixL = None
         self.healthyResidual = None
         self.normalDataBallTree = None
 
-    # def _truncate_to_seconds(self, dt_str: str) -> str:
-    #     """
-    #     Truncate an ISO timestamp string (possibly containing milliseconds)
-    #     to whole seconds, returning "YYYY-MM-DD HH:MM:SS".
-    #     E.g. "2025-06-01T12:34:56.789Z" → "2025-06-01 12:34:56"
-    #     """
-    #     # replace 'T' with a space and strip any trailing "Z"
-    #     s = dt_str.replace("T", " ").rstrip("Z")
-    #     # truncate fractional seconds if present
-    #     if "." in s:
-    #         s = s.split(".")[0]
-    #     # also truncate any "+xx:xx" timezone suffix
-    #     if "+" in s:
-    #         s = s.split("+")[0]
-    #     return s.strip()
+        # SPRT parameters (set during offline training)
+        self.feature_weight: np.ndarray | None = None
+        self.alpha: float = 0.1
+        self.beta:  float = 0.1
 
     def _get_data_by_filter(self) -> pd.DataFrame:
         """
-        Batch query by the wind_turbine_number column and the time_stamp range.
-        Returns a complete DataFrame sorted by time_stamp ascending.
+        For online inference: pull data for the turbine list & time range given by the frontend;
+        if the turbine list is empty, pull full-farm data by time range only.
         """
-        table_name = f"{self.windCode}_minute"
+        table = f"{self.windCode}_minute"
         engine = create_engine(
-          #  "mysql+pymysql://root:admin123456@106.120.102.238:10336/energy_data_prod"
-            "mysql+pymysql://root:admin123456@192.168.50.235:30306/energy_data_prod"
+            "mysql+pymysql://root:admin123456@106.120.102.238:10336/energy_data_prod"
         )
+        if self.windTurbineNumberList:
+            turbines = ",".join(f"'{wt.strip()}'" for wt in self.windTurbineNumberList)
+            where = f"wind_turbine_number IN ({turbines}) AND time_stamp BETWEEN :start AND :end"
+        else:
+            where = "time_stamp BETWEEN :start AND :end"
 
-        # build the SQL fragment for the wind_turbine_number list: ('WT1','WT2',...)
-        turbines = ",".join(f"'{wt.strip()}'" for wt in self.windTurbineNumberList)
         sql = text(f"""
             SELECT *
-            FROM {table_name}
-            WHERE wind_turbine_number IN ({turbines})
-              AND time_stamp BETWEEN :start AND :end
+            FROM {table}
+            WHERE {where}
             ORDER BY time_stamp ASC
         """)
-
-        df = pd.read_sql(sql, engine, params={"start": self.startTime, "end": self.endTime})
+        df = pd.read_sql(sql, engine,
+                         params={"start": self.startTime, "end": self.endTime})
         return df
 
     def calcSimilarity(self, x: np.ndarray, y: np.ndarray, m: str = 'euc') -> float:
-        """
-        Similarity of vectors x and y, in (0, 1]:
-          - m='euc' → Euclidean distance
-          - m='cbd' → city-block (Manhattan) distance
-        """
         if len(x) != len(y):
             return 0.0
-
         if m == 'cbd':
-            arr = [1.0 / (1.0 + abs(p - q)) for p, q in zip(x, y)]
-            return float(np.sum(arr) / len(arr))
-        else:
-            diffsq = [(p - q) ** 2 for p, q in zip(x, y)]
-            return float(1.0 / (1.0 + math.sqrt(np.sum(diffsq))))
+            return float(np.mean([1.0 / (1.0 + abs(p - q)) for p, q in zip(x, y)]))
+        diffsq = np.sum((x - y) ** 2)
+        return float(1.0 / (1.0 + math.sqrt(diffsq)))
 
-    def genDLMatrix(self, trainDataset: np.ndarray, dataSize4D=100, dataSize4L=50) -> int:
+    def genDLMatrix(self, trainDataset: np.ndarray,
+                    dataSize4D=100, dataSize4L=50) -> int:
         """
-        Build the D/L matrices from trainDataset:
-          - if the sample count < dataSize4D + dataSize4L, return -1
-          - otherwise construct matrixD and matrixL, derive healthyResidual via locally weighted regression, and return 0
+        For offline training: builds matrixD, matrixL, the healthy residual and the BallTree
         """
         m, n = trainDataset.shape
         if m < dataSize4D + dataSize4L:
             return -1
 
-        # Step 1: add each feature's min/max samples to matrixD
-        self.matrixD = []
-        selectIndex4D = []
+        # Step 1: per-dimension min/max samples go into D
+        D_idx = []
+        D = []
         for i in range(n):
-            col_i = trainDataset[:, i]
-            idx_min = np.argmin(col_i)
-            idx_max = np.argmax(col_i)
-            self.matrixD.append(trainDataset[idx_min, :].tolist())
-            selectIndex4D.append(idx_min)
-            self.matrixD.append(trainDataset[idx_max, :].tolist())
-            selectIndex4D.append(idx_max)
-
-        # Step 2: from the remaining samples, iteratively pick the one with the largest mean distance to matrixD, until matrixD has dataSize4D rows
-        while len(selectIndex4D) < dataSize4D:
-            freeList = list(set(range(len(trainDataset))) - set(selectIndex4D))
-            distAvg = []
-            for idx in freeList:
-                tmp = trainDataset[idx, :]
-                dlist = [1.0 - self.calcSimilarity(x, tmp) for x in self.matrixD]
-                distAvg.append(np.mean(dlist))
-            select_id = freeList[int(np.argmax(distAvg))]
-            self.matrixD.append(trainDataset[select_id, :].tolist())
-            selectIndex4D.append(select_id)
-
-        self.matrixD = np.array(self.matrixD)
-
-        # build a BallTree on matrixD for locally weighted regression
+            col = trainDataset[:, i]
+            imin, imax = np.argmin(col), np.argmax(col)
+            for idx in (imin, imax):
+                D.append(trainDataset[idx].tolist())
+                D_idx.append(idx)
+        # Step 2: iteratively pick the most distant samples until dataSize4D
+        while len(D_idx) < dataSize4D:
+            free = list(set(range(m)) - set(D_idx))
+            scores = []
+            for idx in free:
+                dists = [1-self.calcSimilarity(trainDataset[idx], d) for d in D]
+                scores.append((np.mean(dists), idx))
+            _, pick = max(scores)
+            D.append(trainDataset[pick].tolist()); D_idx.append(pick)
+        self.matrixD = np.array(D)
+
+        # BallTree + matrixL + healthyResidual
         self.normalDataBallTree = BallTree(
             self.matrixD,
             leaf_size=4,
             metric=lambda a, b: 1.0 - self.calcSimilarity(a, b)
         )
-
-        # Step 3: use all training samples as matrixL
         self.matrixL = trainDataset.copy()
-
-        # Step 4: compute the healthy residual via locally weighted regression
-        self.healthyResidual = self.calcResidualByLocallyWeightedLR(self.matrixL)
+        self.healthyResidual = self._calcResidual(self.matrixL)
         return 0
 
-    def calcResidualByLocallyWeightedLR(self, newStates: np.ndarray) -> np.ndarray:
-        """
-        For each sample in newStates, run locally weighted regression over its 20
-        nearest neighbors in matrixD and compute the residual.
-        Returns a residual matrix of shape [len(newStates), n_features].
-        """
-        est_list = []
-        for x in newStates:
+    def _calcResidual(self, states: np.ndarray) -> np.ndarray:
+        ests = []
+        for x in states:
             dist, idxs = self.normalDataBallTree.query([x], k=20, return_distance=True)
-            w = 1.0 / (dist[0] + 1e-1)
-            w = w / np.sum(w)
-            est = np.sum([w_i * self.matrixD[j] for w_i, j in zip(w, idxs[0])], axis=0)
-            est_list.append(est)
-        est_arr = np.reshape(np.array(est_list), (len(est_list), -1))
-        return est_arr - newStates
-
-    def calcSPRT(
-        self,
-        newsStates: np.ndarray,
-        feature_weight: np.ndarray,
-        alpha: float = 0.1,
-        beta: float = 0.1,
-        decisionGroup: int = 5
-    ) -> list[float]:
-        """
-        Run the Wald SPRT over newsStates; returns a list of scores of length
-        len(newsStates) - decisionGroup + 1, each in [-1, 1]:
-          - closer to 1 → more "abnormal" (dangerous)
-          - closer to -1 → more "normal"
-        """
-        # 1) compute residuals and apply feature weighting
-        stateRes = self.calcResidualByLocallyWeightedLR(newsStates)
-        weightedStateResidual = [np.dot(x, feature_weight) for x in stateRes]
-        weightedHealthyResidual = [np.dot(x, feature_weight) for x in self.healthyResidual]
-
-        # 2) distribution statistics of the healthy residual
-        mu0 = float(np.mean(weightedHealthyResidual))
-        sigma0 = float(np.std(weightedHealthyResidual))
-
-        # 3) SPRT lower/upper decision thresholds
-        lowThres = np.log(beta / (1.0 - alpha))    # < 0
-        highThres = np.log((1.0 - beta) / alpha)   # > 0
-
-        flags: list[float] = []
-        length = len(weightedStateResidual)
-        for i in range(0, length - decisionGroup + 1):
-            segment = weightedStateResidual[i : i + decisionGroup]
-            mu1 = float(np.mean(segment))
-            si = (
-                np.sum(segment) * (mu1 - mu0) / (sigma0**2)
-                - decisionGroup * ((mu1**2) - (mu0**2)) / (2.0 * (sigma0**2))
-            )
-
-            # clamp si to [lowThres, highThres]
-            si = max(min(si, highThres), lowThres)
-            # normalize the positive and negative sides separately
-            if si > 0:
-                norm_si = float(si / highThres)
-            else:
-                norm_si = float(si / lowThres)
-            flags.append(norm_si)
-
+            w = 1.0 / (dist[0] + 1e-1)
+            w = w / w.sum()
+            ests.append(np.sum([wi * self.matrixD[j] for wi, j in zip(w, idxs[0])], axis=0))
+        est = np.array(ests).reshape(len(ests), -1)
+        return est - states
+
+    def calcSPRT(self,
+                 newsStates: np.ndarray,
+                 feature_weight: np.ndarray,
+                 alpha: float = 0.1,
+                 beta: float = 0.1,
+                 decisionGroup: int = 5) -> list[float]:
+        # 1) residuals + feature weighting
+        resN = self._calcResidual(newsStates)
+        wN = [np.dot(r, feature_weight) for r in resN]
+        wH = [np.dot(r, feature_weight) for r in self.healthyResidual]
+        mu0, sigma0 = np.mean(wH), np.std(wH)
+        low  = math.log(beta/(1-alpha))
+        high = math.log((1-beta)/alpha)
+
+        flags = []
+        for i in range(len(wN)-decisionGroup+1):
+            seg = wN[i:i+decisionGroup]; mu1 = np.mean(seg)
+            si = (sum(seg)*(mu1-mu0)/sigma0**2
+                  - decisionGroup*((mu1**2-mu0**2)/(2*sigma0**2)))
+            si = max(min(si, high), low)
+            flags.append(si/high if si>0 else si/low)
         return flags
 
-    def check_threshold(self) -> pd.DataFrame:
-        """
-        Threshold analysis (threshold fixed at 0.99). Returns a long-format DataFrame
-        with columns ["time_stamp", "temp_channel", "SPRT_score", "status"], where
-        status = "危险" (danger) if SPRT_score > 0.99 else "正常" (normal).
-        """
-        THRESHOLD = 0.99
-
-        # 1) query the raw data by turbine ID + time range
-        df_concat = self._get_data_by_filter()
-        if df_concat.empty:
-            return pd.DataFrame(columns=["time_stamp", "temp_channel", "SPRT_score", "status"])
-
-        # 2) keep only the temperature columns that exist
-        temp_cols_all = [
-            'main_bearing_temperature',
-            'gearbox_oil_temperature',
-            'generatordrive_end_bearing_temperature',
-            'generatornon_drive_end_bearing_temperature'
-        ]
-        temp_cols = [c for c in temp_cols_all if c in df_concat.columns]
-        if not temp_cols:
-            return pd.DataFrame(columns=["time_stamp", "temp_channel", "SPRT_score", "status"])
-
-        # 3) convert to numeric & drop NaN rows
-        df_concat[temp_cols] = df_concat[temp_cols].apply(pd.to_numeric, errors='coerce')
-        df_concat = df_concat.dropna(subset=temp_cols + ['time_stamp'])
-        if df_concat.empty:
-            return pd.DataFrame(columns=["time_stamp", "temp_channel", "SPRT_score", "status"])
-
-        # 4) convert time_stamp to datetime
-        df_concat['time_stamp'] = pd.to_datetime(df_concat['time_stamp'])
-        x_date = df_concat['time_stamp']
-
-        # 5) extract the temperature columns into a NumPy array
-        arr = df_concat[temp_cols].values  # shape = [n_records, n_channels]
-        m, n = arr.shape
-        half = m // 2
-
-        all_flags: list[list[float]] = []
-        for i in range(n):
-            channel = arr[:, i]
-            train = channel[:half].reshape(-1, 1)
-            test  = channel[half:].reshape(-1, 1)
-
-            # build the D/L matrices from the training half
-            if self.genDLMatrix(train, dataSize4D=60, dataSize4L=5) != 0:
-                # if the training set is too small, return an empty table
-                return pd.DataFrame(columns=["time_stamp", "temp_channel", "SPRT_score", "status"])
-
-            feature_w = np.array([1.0])
-            flags = self.calcSPRT(test, feature_w, decisionGroup=1)
-            all_flags.append(flags)
-
-        # 6) assemble a wide table, then melt it into long format
-        flags_arr = np.array(all_flags)  # shape = [n_channels, n_test_samples]
-        num_test = flags_arr.shape[1]
-        ts = x_date.iloc[half : half + num_test].reset_index(drop=True)
-
-        wide = pd.DataFrame({"time_stamp": ts})
-        for idx, col in enumerate(temp_cols):
-            wide[col] = flags_arr[idx, :]
-
-        df_long = wide.melt(
-            id_vars=["time_stamp"],
-            value_vars=temp_cols,
-            var_name="temp_channel",
-            value_name="SPRT_score"
+    def predict_SPRT(self,
+                     newsStates: np.ndarray,
+                     decisionGroup: int = 5) -> list[float]:
+        """在线推理:用已加载的 matrixD、healthyResidual、feature_weight、alpha、beta"""
+        return self.calcSPRT(
+            newsStates,
+            self.feature_weight,
+            alpha=self.alpha,
+            beta=self.beta,
+            decisionGroup=decisionGroup
         )
-        # convert time_stamp from datetime to string, format "YYYY-MM-DD HH:MM:SS"
-        df_long['time_stamp'] = pd.to_datetime(df_long['time_stamp']).dt.strftime("%Y-%m-%d %H:%M:%S")
 
-        # 7) add a status column: SPRT_score > 0.99 → "危险" (danger), else "正常" (normal)
-        df_long['status'] = df_long['SPRT_score'].apply(
-            lambda x: "危险" if x > THRESHOLD else "正常"
+    def save_model(self, path: str):
+        """离线训练后持久化:matrixD, healthyResidual, feature_weight, alpha, beta"""
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+        joblib.dump({
+            'matrixD': self.matrixD,
+            'healthyResidual': self.healthyResidual,
+            'feature_weight': self.feature_weight,
+            'alpha': self.alpha,
+            'beta': self.beta,
+        }, path)
+
+    @classmethod
+    def load_model(cls, path: str) -> 'MSET_Temp':
+        """在线启动时反序列化并重建 BallTree"""
+        data = joblib.load(path)
+        inst = cls('', [], '', '')
+        inst.matrixD = data['matrixD']
+        inst.healthyResidual = data['healthyResidual']
+        inst.feature_weight = data['feature_weight']
+        inst.alpha = data['alpha']
+        inst.beta  = data['beta']
+        inst.normalDataBallTree = BallTree(
+            inst.matrixD,
+            leaf_size=4,
+            metric=lambda a, b: 1.0 - inst.calcSimilarity(a, b)
         )
-
-        # 8) map the temp_channel column's English names to Chinese
-        temp_channel_mapping = {
-            'main_bearing_temperature': '主轴承温度',
-            'gearbox_oil_temperature': '齿轮箱油温',
-            'generatordrive_end_bearing_temperature': '发电机驱动端轴承温度',
-            'generatornon_drive_end_bearing_temperature': '发电机非驱动端轴承温度'
-        }
-
-        df_long['temp_channel'] = df_long['temp_channel'].map(temp_channel_mapping)
-
-        return df_long
-
-    def get_trend(self) -> dict:
-        """
-        Trend analysis.
-        Fetch the temperature trend: temperature data returned over time.
-        Return format: {
-            "timestamps": [list of ISO 8601 strings],
-            "channels": [
-                {"temp_channel": "main_bearing_temperature", "values": [list of floats]},
-                {"temp_channel": "gearbox_oil_temperature", "values": [...]},
-                ...
-            ],
-            "unit": "°C"
-        }
-        """
-        df = self._get_data_by_filter()
-
-        if df.empty:
-            return {"timestamps": [], "channels": [], "unit": "°C"}
-
-        # all temperature columns to check
-        temp_cols_all = [
-            'main_bearing_temperature',
-            'gearbox_oil_temperature',
-            'generatordrive_end_bearing_temperature',
-            'generatornon_drive_end_bearing_temperature'
-        ]
-        # keep only the columns actually present
-        temp_cols = [c for c in temp_cols_all if c in df.columns]
-        
-        # if there are no temperature columns, return empty data
-        if not temp_cols:
-            return {"timestamps": [], "channels": [], "unit": "°C"}
-
-        # convert to numeric and drop NaN rows
-        df[temp_cols] = df[temp_cols].apply(pd.to_numeric, errors='coerce')
-        df = df.dropna(subset=temp_cols + ['time_stamp'])
-
-        # format the timestamps as `YYYY-MM-DD HH:MM:SS`
-        df['time_stamp'] = pd.to_datetime(df['time_stamp']).dt.strftime("%Y-%m-%d %H:%M:%S")
-        df = df.sort_values('time_stamp').reset_index(drop=True)
-
-        # timestamps as ISO 8601 strings
-        timestamps = df['time_stamp'].tolist()
-
-        # for each channel, collect its values from the corresponding rows
-        channels_data = []
-        for col in temp_cols:
-            channels_data.append({
-                "temp_channel": col,
-                "values": df[col].tolist()
-            })
-            
-        # map the temp_channel English names to Chinese
-        temp_channel_mapping = {
-            'main_bearing_temperature': '主轴承温度',
-            'gearbox_oil_temperature': '齿轮箱油温',
-            'generatordrive_end_bearing_temperature': '发电机驱动端轴承温度',
-            'generatornon_drive_end_bearing_temperature': '发电机非驱动端轴承温度'
-        }
-
-        for channel in channels_data:
-            channel['temp_channel'] = temp_channel_mapping.get(channel['temp_channel'], channel['temp_channel'])
-
-
-        return {
-            "timestamps": timestamps,
-            "channels": channels_data,
-            "unit": "°C"
-        }
-
+        return inst
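
The class now splits offline fitting from online scoring. A minimal sketch of the round trip, using synthetic single-channel data and an illustrative model path (neither is part of this commit):

# sketch: offline genDLMatrix → save_model, then load_model → predict_SPRT online
import numpy as np
from Temp_Diag import MSET_Temp

healthy = np.random.normal(40.0, 1.5, size=(500, 1))  # synthetic "healthy" temperatures
m = MSET_Temp('WOF091200030', [], '', '')
m.feature_weight = np.ones(1)                         # single channel → unit weight
assert m.genDLMatrix(healthy, dataSize4D=60, dataSize4L=5) == 0
m.save_model('models/WOF091200030/main_bearing_temperature.pkl')

# online side: reload and score a new window; scores fall in [-1, 1]
m2 = MSET_Temp.load_model('models/WOF091200030/main_bearing_temperature.pkl')
scores = m2.predict_SPRT(np.random.normal(40.0, 1.5, size=(50, 1)), decisionGroup=1)
print(max(scores))  # values above 0.99 are flagged "危险" downstream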

+ 90 - 122
api_tempdiag.py

@@ -1,146 +1,114 @@
 # api_tempdiag.py
 
-from fastapi import FastAPI, HTTPException
+import os, glob
+import pandas as pd
+from fastapi import FastAPI
 from fastapi.responses import JSONResponse
-from typing import List
 from pydantic import BaseModel, model_validator
-import uvicorn
+from typing import List
 
-from Temp_Diag import MSET_Temp  
+from Temp_Diag import MSET_Temp
 
 app = FastAPI(title="Temperature Diagnosis API")
 
+# global model cache: { windCode: { channel_name: MSET_Temp, … }, … }
+MODEL_STORE: dict[str, dict[str, MSET_Temp]] = {}
+
+# English → Chinese channel-name mapping
+cn_map = {
+    'main_bearing_temperature': '主轴承温度',
+    'gearbox_oil_temperature': '齿轮箱油温',
+    'generatordrive_end_bearing_temperature': '发电机驱动端轴承温度',
+    'generatornon_drive_end_bearing_temperature': '发电机非驱动端轴承温度'
+}
 
 class TemperatureInput(BaseModel):
     windCode: str
     windTurbineNumberList: List[str]
-    startTime: str  # e.g. "2024-06-08 00:00"
-    endTime: str    # e.g. "2024-06-08 01:00"
+    startTime: str
+    endTime: str
 
     @model_validator(mode='before')
-    def normalize_fields(cls, values):
-        # ensure windTurbineNumberList is a list
-        raw = values.get('windTurbineNumberList')
+    def ensure_list(cls, v):
+        raw = v.get('windTurbineNumberList')
         if isinstance(raw, str):
-            values['windTurbineNumberList'] = [raw]
-        return values
+            v['windTurbineNumberList'] = [raw]
+        return v
+
+class TemperatureThresholdInput(TemperatureInput):
+    pageNo:   int
+    pageSize: int
+
+@app.on_event("startup")
+def load_all_models():
+    for f in glob.glob("models/*/*.pkl"):
+        wc = os.path.basename(os.path.dirname(f))      # directory name → windCode
+        ch = os.path.splitext(os.path.basename(f))[0]  # file name (minus .pkl) → channel
+        MODEL_STORE.setdefault(wc, {})[ch] = MSET_Temp.load_model(f)
+    print("Models loaded:", {k: list(v.keys()) for k, v in MODEL_STORE.items()})
 
-# threshold analysis
 @app.post("/temperature/threshold")
-async def route_threshold(input_data: TemperatureInput):
-    """
-    Threshold analysis endpoint (threshold fixed at 0.99):
-      - Input:
-        {
-          "windCode": "WOF01000010",
-          "windTurbineNumberList": ["WOG00542"],
-          "startTime": "2023-01-01 00:00",
-          "endTime": "2023-01-05 12:00"
-        }
-      - Returns:
-        {
-          "data": {
-            "type": "temperature_threshold",
-            "records": [
-              {
-                "time_stamp": "2024-06-08 00:05:00",
-                "temp_channel": "main_bearing_temperature",
-                "SPRT_score": 0.123,
-                "status": "正常"
-              },
-              ...
-            ]
-          },
-          "code": 200,
-          "message": "success"
-        }
-    """
+async def route_threshold(inp: TemperatureThresholdInput):
     try:
-        analyzer = MSET_Temp(
-            windCode=input_data.windCode,
-            windTurbineNumberList=input_data.windTurbineNumberList,
-            startTime=input_data.startTime,
-            endTime=input_data.endTime
-        )
-        df_result = analyzer.check_threshold()  # DataFrame 长格式
-        records = df_result.to_dict(orient="records")
-        return {
-            "data": {
-                "type": "temperature_threshold",
-                "records": records
-            },
-            "code": 200,
-            "message": "success"
-        }
-    except Exception as e:
-        return JSONResponse(
-            status_code=500,
-            content={
-                "code": 500,
-                "message": "analysis failed",
-                "detail": str(e)
-            }
-        )
+        analyzer = MSET_Temp(inp.windCode, inp.windTurbineNumberList, inp.startTime, inp.endTime)
+        df = analyzer._get_data_by_filter()
+        if df.empty:
+            return {"data":{"type":"temperature_threshold","records":[],"totalSize":0},"code":200,"message":"success"}
 
-# trend analysis (not currently called)
-@app.post("/temperature/trend")
-async def route_trend(input_data: TemperatureInput):
-    """
-    Trend analysis endpoint:
-      - Input:
-        {
-          "windCode": "WOF01000010",
-          "windTurbineNumberList": ["WOG00542"],
-          "startTime": "2023-01-01 00:00",
-          "endTime": "2023-01-05 12:00"
-        }
-      - Returns:
-        {
-          "data": {
-            "type": "temperature_trend",
-            "timestamps": [ "2024-06-08 00:00:00", ... ],
-            "channels": [
-               { "temp_channel": "main_bearing_temperature", "values": [24.5, 24.7, ...] },
-               ...
-            ],
-            "unit": "°C"
-          },
-          "code": 200,
-          "message": "success"
-        }
-    """
-    try:
-        analyzer = MSET_Temp(
-            windCode=input_data.windCode,
-            windTurbineNumberList=input_data.windTurbineNumberList,
-            startTime=input_data.startTime,
-            endTime=input_data.endTime
-        )
-        result = analyzer.get_trend()
-        return {
-            "data": {
-                "type": "temperature_trend",
-                **result
-            },
-            "code": 200,
-            "message": "success"
-        }
+        df['time_stamp'] = pd.to_datetime(df['time_stamp'])
+        records = []
+        for eng, cn in cn_map.items():
+            if eng not in df.columns: continue
+            sub = df[['time_stamp', eng]].dropna()
+            arr = sub[eng].values.reshape(-1,1)
+            ts  = sub['time_stamp'].dt.strftime("%Y-%m-%d %H:%M:%S").tolist()
+            model = MODEL_STORE.get(inp.windCode,{}).get(eng)
+            if not model: continue
+            flags = model.predict_SPRT(arr, decisionGroup=1)
+            for i, sc in enumerate(flags):
+                records.append({
+                    "time_stamp": ts[i],
+                    "temp_channel": cn,
+                    "SPRT_score": sc,
+                    "status": "危险" if sc>0.99 else "正常"
+                })
+
+        total = len(records)
+        start = (inp.pageNo-1)*inp.pageSize
+        end   = start+inp.pageSize
+        return {"data":{"type":"temperature_threshold",
+                        "records":records[start:end],
+                        "totalSize":total},"code":200,"message":"success"}
     except Exception as e:
-        return JSONResponse(
-            status_code=500,
-            content={
-                "code": 500,
-                "message": "analysis failed",
-                "detail": str(e)
-            }
-        )
+        return JSONResponse(status_code=500, content={"code":500,"message":"analysis failed","detail":str(e)})
 
+@app.post("/SPRT/trend")
+async def route_trend(inp: TemperatureInput):
+    try:
+        analyzer = MSET_Temp(inp.windCode, inp.windTurbineNumberList, inp.startTime, inp.endTime)
+        df = analyzer._get_data_by_filter()
+        df['time_stamp'] = pd.to_datetime(df['time_stamp'])
+        result = {}
+        for eng, key in {
+            'main_bearing_temperature':'main_bearing',
+            'gearbox_oil_temperature':'gearbox_oil',
+            'generatordrive_end_bearing_temperature':'generator_drive_end',
+            'generatornon_drive_end_bearing_temperature':'generator_nondrive_end'
+        }.items():
+            if eng not in df.columns:
+                result[key] = {"timestamps":[],"values":[]}
+                continue
+            sub = df[['time_stamp',eng]].dropna()
+            arr = sub[eng].values.reshape(-1,1)
+            ts  = sub['time_stamp'].dt.strftime("%Y-%m-%d %H:%M:%S").tolist()
+            model = MODEL_STORE.get(inp.windCode,{}).get(eng)
+            vals  = model.predict_SPRT(arr, decisionGroup=1) if model else []
+            result[key] = {"timestamps": ts, "values": vals}
 
+        return {"data":{"type":"SPRT_trend",**result},"code":200,"message":"success"}
+    except Exception as e:
+        return JSONResponse(status_code=500, content={"code":500,"message":"analysis failed","detail":str(e)})
 
 if __name__ == "__main__":
-    uvicorn.run(
-        "main:app",
-        host="0.0.0.0",
-        port=8000,
-        reload=True
-    )
+    import uvicorn
+    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)

BIN
models/WOF091200030/generatordrive_end_bearing_temperature.pkl


BIN
models/WOF091200030/main_bearing_temperature.pkl


+ 61 - 0
train_temp.py

@@ -0,0 +1,61 @@
+# train_temp.py
+
+import numpy as np
+import pandas as pd
+from sqlalchemy import create_engine
+from Temp_Diag import MSET_Temp
+import os
+
+# ——— config ———
+windCode   = "WOF091200030"
+# turbines = ["WOG01355"]
+start, end = "2024-06-12 00:00:00", "2024-07-12 00:00:00"
+model_root = "models"
+channels   = [
+    'main_bearing_temperature',
+    'gearbox_oil_temperature',
+    'generatordrive_end_bearing_temperature',
+    'generatornon_drive_end_bearing_temperature'
+]
+engine = create_engine(
+    "mysql+pymysql://root:admin123456@106.120.102.238:10336/energy_data_prod"
+)
+# ——————————
+
+def fetch_channel_data(channel: str) -> np.ndarray:
+    # Optional turbine filter: while `turbines` above is commented out, this must
+    # stay out of the f-string (referencing it there raises a NameError):
+    #   AND wind_turbine_number IN ({','.join(f"'{t}'" for t in turbines)})
+    sql = f"""
+      SELECT {channel}
+      FROM {windCode}_minute
+      WHERE time_stamp BETWEEN '{start}' AND '{end}'
+      ORDER BY time_stamp ASC
+    """
+    df = pd.read_sql(sql, engine).dropna(subset=[channel])
+    print(f"[TRAIN] {channel} 共 {len(df)} 条")
+    return df[channel].values.reshape(-1,1)
+
+if __name__ == "__main__":
+    for ch in channels:
+        data = fetch_channel_data(ch)
+        if data.shape[0] < 65:
+            print(f"[TRAIN] {ch} 样本不足,跳过")
+            continue
+
+        model = MSET_Temp(windCode, [], start, end)
+        model.feature_weight = np.ones(data.shape[1])
+        model.alpha = 0.1; model.beta = 0.1
+
+        if model.genDLMatrix(data, dataSize4D=60, dataSize4L=5) != 0:
+            print(f"[TRAIN] {ch} D/L 构建失败")
+            continue
+
+        out_dir = os.path.join(model_root, windCode)
+        os.makedirs(out_dir, exist_ok=True)
+        path = os.path.join(out_dir, f"{ch}.pkl")
+        model.save_model(path)
+        print(f"[TRAIN] 已保存 {path}")
+
+    print("\n[TRAIN] 模型文件列表:")
+    for root, _, files in os.walk(model_root):
+        for fn in files:
+            print(" ", os.path.join(root, fn))