# Schema reader: collects parquet file headers and identifies time columns
import pandas as pd
import numpy as np
from typing import List, Set, Tuple, Dict, Optional

# ParquetFileInfo (from file_scanner) is used below through its `file_path`
# attribute, and its `data_time_column` attribute is filled in by this class.
from file_scanner import ParquetFileInfo


class SchemaReader:
    """Reads the headers of parquet files and identifies time columns."""

    def __init__(self, file_infos: List[ParquetFileInfo], time_column_aliases: Optional[List[str]] = None):
        self.file_infos = file_infos
        self.time_column_aliases = time_column_aliases or [
            "data_time", "time", "timestamp", "datetime", "采集时间",
            "时间", "记录时间", "数据时间", "Time", "Timestamp"
        ]
        self.all_columns: Set[str] = set()
        self.time_columns: Dict[str, Set[str]] = {}  # file path -> set of time column names
        self.numeric_columns: Set[str] = set()
        self.string_columns: Set[str] = set()
        self.identified_time_column: Optional[str] = None  # primary time column, once identified

    def read_all_headers(self) -> Set[str]:
        """Read the headers of all files, return their union, and identify time columns."""
        print("Reading the headers of all parquet files and identifying time columns...")

        time_column_candidates = {}

        for i, file_info in enumerate(self.file_infos):
            try:
                # Load the file to inspect its columns and dtypes
                # (note: this reads the full data, not only the metadata; see the
                # schema-only sketch after this method for a lighter alternative)
                df = pd.read_parquet(file_info.file_path, engine='pyarrow')

                # Collect the column names
                self.all_columns.update(df.columns.tolist())

                # Identify time columns by name
                time_cols_in_file = self._identify_time_columns(df.columns.tolist())
                self.time_columns[file_info.file_path] = time_cols_in_file

                # Tally the time column candidates
                for col in time_cols_in_file:
                    time_column_candidates[col] = time_column_candidates.get(col, 0) + 1

                # Classify column dtypes
                for col in df.columns:
                    if df[col].dtype in [np.float64, np.float32, np.int64, np.int32, np.float16, np.int16, np.int8]:
                        self.numeric_columns.add(col)
                    elif df[col].dtype == 'object':
                        self.string_columns.add(col)
                    elif df[col].dtype == 'datetime64[ns]':
                        # Already a datetime dtype, so record it as a time column as well
                        if col not in self.time_columns[file_info.file_path]:
                            self.time_columns[file_info.file_path].add(col)
                            time_column_candidates[col] = time_column_candidates.get(col, 0) + 1

                if (i + 1) % 10 == 0:
                    print(f"Processed {i + 1}/{len(self.file_infos)} files")

            except Exception as e:
                print(f"Failed to read file {file_info.file_path}: {e}")
                continue

        # Determine the primary time column
        self.identified_time_column = self._determine_primary_time_column(time_column_candidates)

        # Record the identified time column name on each file info
        for file_info in self.file_infos:
            if file_info.file_path in self.time_columns:
                time_cols = self.time_columns[file_info.file_path]
                # Prefer the primary time column; otherwise fall back to the first time column found in the file
                if self.identified_time_column and self.identified_time_column in time_cols:
                    file_info.data_time_column = self.identified_time_column
                elif time_cols:
                    file_info.data_time_column = next(iter(time_cols))

        # Add the extra metadata columns
        base_columns = {'id_farm', 'name_farm', 'no_model_turbine', 'id_turbine'}
        self.all_columns.update(base_columns)
        self.string_columns.update(base_columns)

        # Make sure the identified time column is included in the column set
        if self.identified_time_column:
            self.all_columns.add(self.identified_time_column)

        print(f"\nHeader reading finished: {len(self.all_columns)} columns in total")
        print(f"Identified primary time column: {self.identified_time_column}")
        print(f"Time column candidate counts: {time_column_candidates}")
        print(f"Numeric columns: {len(self.numeric_columns)}")
        print(f"String columns: {len(self.string_columns)}")

        return self.all_columns

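    # Note: pd.read_parquet() above loads every row of each file. If only the column
    # names are needed, pyarrow can read just the footer schema instead. A minimal
    # sketch (not wired into this class, since the dtype classification above relies
    # on pandas dtypes):
    #
    #   import pyarrow.parquet as pq
    #   schema = pq.read_schema(file_info.file_path)
    #   columns = schema.names
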
    def _identify_time_columns(self, columns: List[str]) -> Set[str]:
        """Identify time columns by name."""
        time_cols = set()

        for col in columns:
            col_lower = col.lower()
            # Check whether the name matches any time column alias
            # (substring match in either direction, so e.g. "time" also matches "runtime")
            for alias in self.time_column_aliases:
                alias_lower = alias.lower()
                if alias_lower in col_lower or col_lower in alias_lower:
                    time_cols.add(col)
                    break

        return time_cols

    def _determine_primary_time_column(self, time_column_candidates: Dict[str, int]) -> Optional[str]:
        """Determine the primary time column."""
        if not time_column_candidates:
            return None

        # Sort candidates by how many files they appear in
        sorted_candidates = sorted(time_column_candidates.items(), key=lambda x: x[1], reverse=True)

        # Prefer a column whose name is exactly "data_time"
        for col, count in sorted_candidates:
            if col.lower() == "data_time":
                return col

        # Otherwise pick the most frequent time column
        primary_col, count = sorted_candidates[0]

        # Treat it as the primary time column if it appears in at least 50% of the files
        if count >= len(self.file_infos) * 0.5:
            return primary_col
        else:
            print(f"Warning: no consistent time column found; the most frequent candidate "
                  f"'{primary_col}' only appears in {count}/{len(self.file_infos)} files")
            return primary_col

    def get_sql_columns(self) -> List[str]:
        """Build SQL column definitions, inferring data types from column names."""
        sql_columns = []

        for column in sorted(self.all_columns):
            # Check whether this is a time column
            is_time_column = False
            if column == self.identified_time_column:
                is_time_column = True
            else:
                col_lower = column.lower()
                for alias in self.time_column_aliases:
                    if alias.lower() in col_lower:
                        is_time_column = True
                        break

            if column in ['id_farm', 'name_farm', 'no_model_turbine', 'id_turbine']:
                # Metadata columns
                sql_type = 'VARCHAR(100)'
            elif is_time_column:
                # Time columns
                sql_type = 'DATETIME'
            elif any(keyword in column.lower() for keyword in ['status', 'code', 'flag']):
                # Status/code columns
                sql_type = 'VARCHAR(50)'
            elif any(keyword in column.lower() for keyword in ['id', 'no', 'num']):
                # ID-like columns. Note: this is a substring match, so a name such as
                # "humidity" (which contains "id") also lands here.
                sql_type = 'VARCHAR(50)'
            else:
                # Everything else is assumed to be numeric
                sql_type = 'DOUBLE'

            sql_columns.append(f"`{column}` {sql_type}")

        return sql_columns

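    # Example of how the fragments returned by get_sql_columns() might be assembled
    # into DDL. This is a sketch only: the table name `scada_data` is a placeholder,
    # and the target dialect (backtick-quoted identifiers, DATETIME/DOUBLE types) is
    # assumed to be MySQL-compatible:
    #
    #   column_defs = ",\n  ".join(reader.get_sql_columns())
    #   ddl = f"CREATE TABLE IF NOT EXISTS `scada_data` (\n  {column_defs}\n)"
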
    def get_unique_key_columns(self, default_keys: Optional[List[str]] = None) -> List[str]:
        """Get the unique key columns."""
        if default_keys is None:
            default_keys = ['id_farm', 'id_turbine', 'data_time']

        # If the identified time column is not named 'data_time', it has to be renamed when the table is created
        if self.identified_time_column and self.identified_time_column != 'data_time':
            print(f"Note: identified time column is '{self.identified_time_column}'; it will be handled as 'data_time'")

        return default_keys

    def get_update_columns(self, exclude_keys: Optional[List[str]] = None) -> List[str]:
        """Get the columns to update (excluding the unique key columns)."""
        if exclude_keys is None:
            exclude_keys = ['id_farm', 'id_turbine', 'data_time', 'id']

        update_columns = []
        for column in sorted(self.all_columns):
            if column not in exclude_keys:
                update_columns.append(column)

        return update_columns
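

# Minimal usage sketch. It assumes the caller already has a List[ParquetFileInfo]
# (produced elsewhere, e.g. by file_scanner); this helper is illustrative only and
# is not required by SchemaReader itself.
def summarize_schema(file_infos: List[ParquetFileInfo]) -> Dict[str, List[str]]:
    """Read all headers and return the derived SQL pieces (illustrative helper)."""
    reader = SchemaReader(file_infos)
    reader.read_all_headers()
    return {
        "sql_columns": reader.get_sql_columns(),
        "unique_keys": reader.get_unique_key_columns(),
        "update_columns": reader.get_update_columns(),
    }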