# thread_pool.py
  1. import concurrent.futures
  2. from typing import List, Callable, Any, Tuple, Optional, Set
  3. import logging
  4. import os
  5. import json
  6. import threading
  7. from datetime import datetime
  8. from file_scanner import ParquetFileInfo
  9. logger = logging.getLogger(__name__)
  10. class ProcessedRecordManager:
  11. """已处理文件记录管理器(线程安全)"""
  12. def __init__(self, record_file: str = "record_processed.json"):
  13. """
  14. 初始化记录管理器
  15. Args:
  16. record_file: 记录文件路径
  17. """
  18. self.record_file = record_file
  19. self.processed_files: Set[str] = set()
  20. self._lock = threading.Lock() # 线程锁,确保线程安全
  21. self._load_records()
  22. def _load_records(self):
  23. """加载已处理文件记录"""
  24. try:
  25. with self._lock:
  26. if os.path.exists(self.record_file):
  27. with open(self.record_file, 'r', encoding='utf-8') as f:
  28. records = json.load(f)
  29. if isinstance(records, list):
  30. self.processed_files = set(records)
  31. logger.info(f"📁 已加载 {len(self.processed_files)} 个已处理文件记录")
  32. else:
  33. logger.info(f"📁 记录文件不存在,将创建新文件: {self.record_file}")
  34. except Exception as e:
  35. logger.warning(f"❌ 加载记录文件失败: {e}")
  36. self.processed_files = set()
  37. def is_processed(self, file_path: str) -> bool:
  38. """检查文件是否已处理过"""
  39. with self._lock:
  40. return file_path in self.processed_files
  41. def add_record(self, file_path: str, metadata: dict = None):
  42. """添加处理记录并立即保存(线程安全)"""
  43. with self._lock:
  44. if file_path not in self.processed_files:
  45. self.processed_files.add(file_path)
  46. self._save_records_internal(file_path, metadata)
  47. def _save_records_internal(self, file_path: str = None, metadata: dict = None):
  48. """内部方法:保存记录到文件(线程安全)"""
  49. try:
  50. # 转换为列表并排序,便于阅读
  51. records_list = sorted(list(self.processed_files))
  52. # 准备要保存的数据
  53. save_data = records_list
  54. with open(self.record_file, 'w', encoding='utf-8') as f:
  55. json.dump(save_data, f, ensure_ascii=False, indent=2)
  56. logger.info(f"💾 已保存 {len(records_list)} 个处理记录到 {self.record_file}")
  57. return True
  58. except Exception as e:
  59. logger.error(f"❌ 保存记录文件失败: {e}")
  60. return False
  61. def get_record_count(self) -> int:
  62. """获取记录数量"""
  63. with self._lock:
  64. return len(self.processed_files)
  65. def clear_records(self):
  66. """清空记录"""
  67. with self._lock:
  68. self.processed_files.clear()
  69. try:
  70. if os.path.exists(self.record_file):
  71. os.remove(self.record_file)
  72. logger.info(f"🗑️ 已删除记录文件: {self.record_file}")
  73. except Exception as e:
  74. logger.warning(f"删除记录文件失败: {e}")
  75. class ThreadPoolManager:
  76. """线程池管理器,集成记录管理功能"""
  77. def __init__(self, max_workers: int = 20, record_file: str = "record_processed.json"):
  78. self.max_workers = max_workers
  79. self.executor = None
  80. self.record_manager = ProcessedRecordManager(record_file)
  81. def process_files(self, file_infos: List[ParquetFileInfo],
  82. process_func: Callable[[ParquetFileInfo], Any]) -> List[Tuple[ParquetFileInfo, Any]]:
  83. """
  84. 使用线程池处理文件,自动记录成功处理的文件
  85. Args:
  86. file_infos: 文件信息列表
  87. process_func: 处理函数
  88. Returns:
  89. 处理结果列表
  90. """
  91. # 先过滤掉已处理过的文件
  92. unprocessed_files = []
  93. for file_info in file_infos:
  94. if self.record_manager.is_processed(file_info.file_path):
  95. logger.debug(f"⏭️ 跳过已处理文件: {os.path.basename(file_info.file_path)}")
  96. else:
  97. unprocessed_files.append(file_info)
  98. if not unprocessed_files:
  99. logger.info("🎉 所有文件都已处理过,无需处理新文件")
  100. return []
  101. logger.info(f"📝 需要处理 {len(unprocessed_files)} 个新文件(跳过了 {len(file_infos)-len(unprocessed_files)} 个已处理文件)")
  102. results = []
  103. with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
  104. self.executor = executor
  105. # 提交所有任务
  106. future_to_file = {
  107. executor.submit(process_func, file_info): file_info
  108. for file_info in unprocessed_files
  109. }
  110. # 处理完成的任务
  111. completed = 0
  112. total = len(unprocessed_files)
  113. for future in concurrent.futures.as_completed(future_to_file):
  114. file_info = future_to_file[future]
  115. completed += 1
  116. try:
  117. result = future.result()
  118. if isinstance(result, tuple) and len(result) == 3:
  119. # 成功处理,记录文件
  120. total_rows, inserted_rows, updated_rows = result
  121. results.append((file_info, result))
  122. # ✅ 立即记录成功处理的文件
  123. self.record_manager.add_record(file_info.file_path, {
  124. 'total_rows': total_rows,
  125. 'inserted_rows': inserted_rows,
  126. 'updated_rows': updated_rows,
  127. 'processed_time': datetime.now().isoformat()
  128. })
  129. logger.info(f"✅ 进度: {completed}/{total} - 文件 {file_info.turbine_id}.parquet "
  130. f"处理完成, 总行: {total_rows}, 插入: {inserted_rows}, 更新: {updated_rows}")
  131. else:
  132. # 处理失败,不记录
  133. results.append((file_info, result))
  134. logger.error(f"❌ 进度: {completed}/{total} - 文件 {file_info.turbine_id}.parquet 处理失败")
  135. except Exception as e:
  136. # 处理异常,不记录
  137. logger.error(f"❌ 处理文件 {file_info.file_path} 时出错: {e}")
  138. results.append((file_info, e))
  139. return results
  140. def process_with_data_loader(self, file_infos: List[ParquetFileInfo],
  141. data_loader: Any) -> List[Tuple[ParquetFileInfo, Any]]:
  142. """
  143. 使用DataLoader处理文件,自动记录成功处理的文件
  144. Args:
  145. file_infos: 文件信息列表
  146. data_loader: DataLoader实例
  147. Returns:
  148. 处理结果列表
  149. """
  150. return self.process_files(file_infos, data_loader.load_file)
  151. def get_record_count(self) -> int:
  152. """获取已处理文件数量"""
  153. return self.record_manager.get_record_count()
  154. def get_unprocessed_count(self, all_files: List[ParquetFileInfo]) -> int:
  155. """获取未处理文件数量"""
  156. unprocessed = 0
  157. for file_info in all_files:
  158. if not self.record_manager.is_processed(file_info.file_path):
  159. unprocessed += 1
  160. return unprocessed