# thread_pool.py
# Thread pool class
import concurrent.futures
from typing import List, Callable, Any, Tuple, Optional
import logging
from file_scanner import ParquetFileInfo
logger = logging.getLogger(__name__)
  7. class ThreadPoolManager:
  8. """线程池管理器"""
  9. def __init__(self, max_workers: int = 20):
  10. self.max_workers = max_workers
  11. self.executor = None
  12. def process_files(self, file_infos: List[ParquetFileInfo],
  13. process_func: Callable[[ParquetFileInfo], Any]) -> List[Tuple[ParquetFileInfo, Any]]:
  14. """使用线程池处理文件"""
  15. results = []
  16. with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
  17. self.executor = executor
  18. # 提交所有任务
  19. future_to_file = {
  20. executor.submit(process_func, file_info): file_info
  21. for file_info in file_infos
  22. }
  23. # 处理完成的任务
  24. completed = 0
  25. total = len(file_infos)
  26. for future in concurrent.futures.as_completed(future_to_file):
  27. file_info = future_to_file[future]
  28. completed += 1
  29. try:
  30. result = future.result()
  31. results.append((file_info, result))
  32. if isinstance(result, tuple) and len(result) == 3:
  33. total_rows, inserted_rows, updated_rows = result
  34. logger.info(f"进度: {completed}/{total} - 文件 {file_info.turbine_id}.parquet "
  35. f"处理完成, 总行: {total_rows}, 插入: {inserted_rows}, 更新: {updated_rows}")
  36. else:
  37. logger.info(f"进度: {completed}/{total} - 文件 {file_info.turbine_id}.parquet 处理完成")
  38. except Exception as e:
  39. logger.error(f"处理文件 {file_info.file_path} 时出错: {e}")
  40. results.append((file_info, e))
  41. return results
  42. def process_with_data_loader(self, file_infos: List[ParquetFileInfo],
  43. data_loader: Any) -> List[Tuple[ParquetFileInfo, Any]]:
  44. """使用DataLoader处理文件"""
  45. return self.process_files(file_infos, data_loader.load_file)