# 表头读取类 import pandas as pd import numpy as np from typing import List, Set, Tuple, Dict, Optional from file_scanner import ParquetFileInfo class SchemaReader: """读取parquet文件的表头信息,识别时间字段""" def __init__(self, file_infos: List[ParquetFileInfo], time_column_aliases: List[str] = None): self.file_infos = file_infos self.time_column_aliases = time_column_aliases or [ "data_time", "time", "timestamp", "datetime", "采集时间", "时间", "记录时间", "数据时间", "Time", "Timestamp" ] self.all_columns: Set[str] = set() self.time_columns: Dict[str, Set[str]] = {} # 文件名 -> 时间字段集合 self.numeric_columns: Set[str] = set() self.string_columns: Set[str] = set() self.identified_time_column: Optional[str] = None # 识别出的主要时间字段 def read_all_headers(self) -> Set[str]: """读取所有文件的表头,返回并集,并识别时间字段""" print("正在读取所有parquet文件的表头并识别时间字段...") time_column_candidates = {} for i, file_info in enumerate(self.file_infos): try: # 只读取元数据,不加载数据 df = pd.read_parquet(file_info.file_path, engine='pyarrow') # 添加列名 self.all_columns.update(df.columns.tolist()) # 识别时间字段 time_cols_in_file = self._identify_time_columns(df.columns.tolist()) self.time_columns[file_info.file_path] = time_cols_in_file # 统计时间字段候选 for col in time_cols_in_file: time_column_candidates[col] = time_column_candidates.get(col, 0) + 1 # 分析列类型 for col in df.columns: if df[col].dtype in [np.float64, np.float32, np.int64, np.int32, np.float16, np.int16, np.int8]: self.numeric_columns.add(col) elif df[col].dtype == 'object': self.string_columns.add(col) elif df[col].dtype == 'datetime64[ns]': # 已经是datetime类型,添加到时间字段 if col not in self.time_columns[file_info.file_path]: self.time_columns[file_info.file_path].add(col) time_column_candidates[col] = time_column_candidates.get(col, 0) + 1 if (i + 1) % 10 == 0: print(f"已处理 {i+1}/{len(self.file_infos)} 个文件") except Exception as e: print(f"读取文件 {file_info.file_path} 失败: {e}") continue # 确定主要时间字段 self.identified_time_column = self._determine_primary_time_column(time_column_candidates) # 为每个文件设置识别到的时间字段名 for file_info in self.file_infos: if file_info.file_path in self.time_columns: time_cols = self.time_columns[file_info.file_path] # 优先使用识别到的主要时间字段,否则使用文件中的第一个时间字段 if self.identified_time_column and self.identified_time_column in time_cols: file_info.data_time_column = self.identified_time_column elif time_cols: file_info.data_time_column = next(iter(time_cols)) # 添加额外的字段 base_columns = {'id_farm', 'name_farm', 'no_model_turbine', 'id_turbine'} self.all_columns.update(base_columns) self.string_columns.update(base_columns) # 确保data_time在列集合中 if self.identified_time_column: self.all_columns.add(self.identified_time_column) print(f"\n表头读取完成,共 {len(self.all_columns)} 个字段") print(f"识别到的主要时间字段: {self.identified_time_column}") print(f"时间字段候选统计: {time_column_candidates}") print(f"数值字段: {len(self.numeric_columns)} 个") print(f"字符串字段: {len(self.string_columns)} 个") return self.all_columns def _identify_time_columns(self, columns: List[str]) -> Set[str]: """识别时间字段""" time_cols = set() for col in columns: col_lower = col.lower() # 检查是否匹配任何时间字段别名 for alias in self.time_column_aliases: alias_lower = alias.lower() if alias_lower in col_lower or col_lower in alias_lower: time_cols.add(col) break return time_cols def _determine_primary_time_column(self, time_column_candidates: Dict[str, int]) -> Optional[str]: """确定主要时间字段""" if not time_column_candidates: return None # 按出现频率排序 sorted_candidates = sorted(time_column_candidates.items(), key=lambda x: x[1], reverse=True) # 优先选择完全匹配"data_time"的字段 for col, count in sorted_candidates: if col.lower() == "data_time": return col # 如果没有data_time,选择出现频率最高的时间字段 primary_col, count = sorted_candidates[0] # 如果出现频率超过文件总数的50%,则认为是主要时间字段 if count >= len(self.file_infos) * 0.5: return primary_col else: print(f"警告: 未找到统一的时间字段,最高频率字段 '{primary_col}' 仅出现在 {count}/{len(self.file_infos)} 个文件中") return primary_col def get_sql_columns(self) -> List[str]: """获取SQL列定义,尝试推断数据类型""" sql_columns = [] for column in sorted(self.all_columns): # 检查是否是时间字段 is_time_column = False if column == self.identified_time_column: is_time_column = True else: col_lower = column.lower() for alias in self.time_column_aliases: if alias.lower() in col_lower: is_time_column = True break if column in ['id_farm', 'name_farm', 'no_model_turbine', 'id_turbine']: # 元数据字段 sql_type = 'VARCHAR(100)' elif is_time_column: # 时间字段 sql_type = 'DATETIME' elif any(keyword in column.lower() for keyword in ['status', 'code', 'flag']): # 状态码字段 sql_type = 'VARCHAR(50)' elif any(keyword in column.lower() for keyword in ['id', 'no', 'num']): # ID字段 sql_type = 'VARCHAR(50)' else: # 默认数值字段 sql_type = 'DOUBLE' sql_columns.append(f"`{column}` {sql_type}") return sql_columns def get_unique_key_columns(self, default_keys: List[str] = None) -> List[str]: """获取唯一键列""" if default_keys is None: default_keys = ['id_farm', 'id_turbine', 'data_time'] # 确保data_time在列集合中 if self.identified_time_column and self.identified_time_column != 'data_time': # 如果识别到的时间字段不是'data_time',我们需要在创建表时将其重命名 print(f"注意: 识别到的时间字段为 '{self.identified_time_column}',将作为 'data_time' 处理") return default_keys def get_update_columns(self, exclude_keys: List[str] = None) -> List[str]: """获取需要更新的列(排除唯一键)""" if exclude_keys is None: exclude_keys = ['id_farm', 'id_turbine', 'data_time', 'id'] update_columns = [] for column in sorted(self.all_columns): if column not in exclude_keys: update_columns.append(column) return update_columns