# schema_reader.py
# Header-reading class: collects parquet column headers and identifies time columns.
from typing import List, Set, Tuple, Dict, Optional

import numpy as np
import pandas as pd

from file_scanner import ParquetFileInfo
  6. class SchemaReader:
  7. """读取parquet文件的表头信息,识别时间字段"""
  8. def __init__(self, file_infos: List[ParquetFileInfo], time_column_aliases: List[str] = None):
  9. self.file_infos = file_infos
  10. self.time_column_aliases = time_column_aliases or [
  11. "data_time", "time", "timestamp", "datetime", "采集时间",
  12. "时间", "记录时间", "数据时间", "Time", "Timestamp"
  13. ]
  14. self.all_columns: Set[str] = set()
  15. self.time_columns: Dict[str, Set[str]] = {} # 文件名 -> 时间字段集合
  16. self.numeric_columns: Set[str] = set()
  17. self.string_columns: Set[str] = set()
  18. self.identified_time_column: Optional[str] = None # 识别出的主要时间字段
  19. def read_all_headers(self) -> Set[str]:
  20. """读取所有文件的表头,返回并集,并识别时间字段"""
  21. print("正在读取所有parquet文件的表头并识别时间字段...")
  22. time_column_candidates = {}
  23. for i, file_info in enumerate(self.file_infos):
  24. try:
  25. # 只读取元数据,不加载数据
  26. df = pd.read_parquet(file_info.file_path, engine='pyarrow')
  27. # 添加列名
  28. self.all_columns.update(df.columns.tolist())
  29. # 识别时间字段
  30. time_cols_in_file = self._identify_time_columns(df.columns.tolist())
  31. self.time_columns[file_info.file_path] = time_cols_in_file
  32. # 统计时间字段候选
  33. for col in time_cols_in_file:
  34. time_column_candidates[col] = time_column_candidates.get(col, 0) + 1
  35. # 分析列类型
  36. for col in df.columns:
  37. if df[col].dtype in [np.float64, np.float32, np.int64, np.int32, np.float16, np.int16, np.int8]:
  38. self.numeric_columns.add(col)
  39. elif df[col].dtype == 'object':
  40. self.string_columns.add(col)
  41. elif df[col].dtype == 'datetime64[ns]':
  42. # 已经是datetime类型,添加到时间字段
  43. if col not in self.time_columns[file_info.file_path]:
  44. self.time_columns[file_info.file_path].add(col)
  45. time_column_candidates[col] = time_column_candidates.get(col, 0) + 1
  46. if (i + 1) % 10 == 0:
  47. print(f"已处理 {i+1}/{len(self.file_infos)} 个文件")
  48. except Exception as e:
  49. print(f"读取文件 {file_info.file_path} 失败: {e}")
  50. continue
  51. # 确定主要时间字段
  52. self.identified_time_column = self._determine_primary_time_column(time_column_candidates)
  53. # 为每个文件设置识别到的时间字段名
  54. for file_info in self.file_infos:
  55. if file_info.file_path in self.time_columns:
  56. time_cols = self.time_columns[file_info.file_path]
  57. # 优先使用识别到的主要时间字段,否则使用文件中的第一个时间字段
  58. if self.identified_time_column and self.identified_time_column in time_cols:
  59. file_info.data_time_column = self.identified_time_column
  60. elif time_cols:
  61. file_info.data_time_column = next(iter(time_cols))
  62. # 添加额外的字段
  63. base_columns = {'id_farm', 'name_farm', 'no_model_turbine', 'id_turbine'}
  64. self.all_columns.update(base_columns)
  65. self.string_columns.update(base_columns)
  66. # 确保data_time在列集合中
  67. if self.identified_time_column:
  68. self.all_columns.add(self.identified_time_column)
  69. print(f"\n表头读取完成,共 {len(self.all_columns)} 个字段")
  70. print(f"识别到的主要时间字段: {self.identified_time_column}")
  71. print(f"时间字段候选统计: {time_column_candidates}")
  72. print(f"数值字段: {len(self.numeric_columns)} 个")
  73. print(f"字符串字段: {len(self.string_columns)} 个")
  74. return self.all_columns
  75. def _identify_time_columns(self, columns: List[str]) -> Set[str]:
  76. """识别时间字段"""
  77. time_cols = set()
  78. for col in columns:
  79. col_lower = col.lower()
  80. # 检查是否匹配任何时间字段别名
  81. for alias in self.time_column_aliases:
  82. alias_lower = alias.lower()
  83. if alias_lower in col_lower or col_lower in alias_lower:
  84. time_cols.add(col)
  85. break
  86. return time_cols
  87. def _determine_primary_time_column(self, time_column_candidates: Dict[str, int]) -> Optional[str]:
  88. """确定主要时间字段"""
  89. if not time_column_candidates:
  90. return None
  91. # 按出现频率排序
  92. sorted_candidates = sorted(time_column_candidates.items(), key=lambda x: x[1], reverse=True)
  93. # 优先选择完全匹配"data_time"的字段
  94. for col, count in sorted_candidates:
  95. if col.lower() == "data_time":
  96. return col
  97. # 如果没有data_time,选择出现频率最高的时间字段
  98. primary_col, count = sorted_candidates[0]
  99. # 如果出现频率超过文件总数的50%,则认为是主要时间字段
  100. if count >= len(self.file_infos) * 0.5:
  101. return primary_col
  102. else:
  103. print(f"警告: 未找到统一的时间字段,最高频率字段 '{primary_col}' 仅出现在 {count}/{len(self.file_infos)} 个文件中")
  104. return primary_col
  105. def get_sql_columns(self) -> List[str]:
  106. """获取SQL列定义,尝试推断数据类型"""
  107. sql_columns = []
  108. for column in sorted(self.all_columns):
  109. # 检查是否是时间字段
  110. is_time_column = False
  111. if column == self.identified_time_column:
  112. is_time_column = True
  113. else:
  114. col_lower = column.lower()
  115. for alias in self.time_column_aliases:
  116. if alias.lower() in col_lower:
  117. is_time_column = True
  118. break
  119. if column in ['id_farm', 'name_farm', 'no_model_turbine', 'id_turbine']:
  120. # 元数据字段
  121. sql_type = 'VARCHAR(100)'
  122. elif is_time_column:
  123. # 时间字段
  124. sql_type = 'DATETIME'
  125. elif any(keyword in column.lower() for keyword in ['status', 'code', 'flag']):
  126. # 状态码字段
  127. sql_type = 'VARCHAR(50)'
  128. elif any(keyword in column.lower() for keyword in ['id', 'no', 'num']):
  129. # ID字段
  130. sql_type = 'VARCHAR(50)'
  131. else:
  132. # 默认数值字段
  133. sql_type = 'DOUBLE'
  134. sql_columns.append(f"`{column}` {sql_type}")
  135. return sql_columns
  136. def get_unique_key_columns(self, default_keys: List[str] = None) -> List[str]:
  137. """获取唯一键列"""
  138. if default_keys is None:
  139. default_keys = ['id_farm', 'id_turbine', 'data_time']
  140. # 确保data_time在列集合中
  141. if self.identified_time_column and self.identified_time_column != 'data_time':
  142. # 如果识别到的时间字段不是'data_time',我们需要在创建表时将其重命名
  143. print(f"注意: 识别到的时间字段为 '{self.identified_time_column}',将作为 'data_time' 处理")
  144. return default_keys
  145. def get_update_columns(self, exclude_keys: List[str] = None) -> List[str]:
  146. """获取需要更新的列(排除唯一键)"""
  147. if exclude_keys is None:
  148. exclude_keys = ['id_farm', 'id_turbine', 'data_time', 'id']
  149. update_columns = []
  150. for column in sorted(self.all_columns):
  151. if column not in exclude_keys:
  152. update_columns.append(column)
  153. return update_columns