# Axis1DataImpl.py
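"""Horizontal (axis=1) merge step of the data import pipeline.

Axis1DataImpl reads the unzipped source files in parallel, renames columns
according to the configured mapping, groups rows by the split columns and
writes one intermediate CSV per (split value, column) pair under the merge
temp directory. It then concatenates those CSVs column-wise on the index
columns and writes one wide CSV per split value to the process temp
directory, reporting transfer progress along the way.
"""
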
import multiprocessing
import os
import traceback

import pandas as pd

from service.import_data_service import update_transfer_progress
from trans.ExecParam import ExecParam
from utils.file.trans_methods import split_array, read_excel_files, read_file_to_df, find_header
from utils.log.import_data_log import log_print
from utils.systeminfo.sysinfo import use_files_get_max_cpu_count, get_dir_size, max_file_size_get_max_cpu_count


class Axis1DataImpl(object):

    def __init__(self, id, process_count, now_count, exec_param: ExecParam, save_db=True):
        self.id = id
        self.process_count = process_count
        self.now_count = now_count
        self.exec_param = exec_param
        self.save_db = save_db
        self.lock_map = dict()
        # for i in range(4):
        #     self.lock_map[i] = multiprocessing.Manager().Lock()
        # A single manager-backed lock shared by all worker processes, plus a shared
        # dict recording which (split value, column) CSVs have already been created.
        self.lock = multiprocessing.Manager().Lock()
        self.field_dict = multiprocessing.Manager().dict()

    def get_lock(self, split_col_value, col_name):
        """Return (first_time, lock) for the given split value / column pair.

        first_time is True when the pair has not been seen before, so the caller
        writes a new CSV with a header instead of appending.
        """
        boolean_first_time = False
        if split_col_value:
            field_name = f'{split_col_value}_{col_name}'
        else:
            field_name = col_name

        exists_count = len(self.field_dict.keys())
        # if exists_count >= 4:
        #     return self.lock

        # Guard the check-and-set so two workers cannot both observe the pair as new
        # and both write a header row.
        with self.lock:
            if field_name not in self.field_dict:
                boolean_first_time = True
                self.field_dict[field_name] = len(self.field_dict.keys()) + 1

        # return boolean_first_time, self.lock_map[self.field_dict[field_name]]
        return boolean_first_time, self.lock

    def read_and_save_file(self, file_path):
        """Read one source file, split it by the configured split columns and append
        each general column (plus the index columns) to its per-split CSV."""
        if self.exec_param.has_header:
            header = find_header(file_path, self.exec_param.use_cols)
            if header is None:
                raise Exception(f"No header row found in file {os.path.basename(file_path)}")
        else:
            header = None

        df = read_file_to_df(file_path, header=header)
        col_map = {'file_name': 'file_name'}
        if 'sheet_name' in df.columns:
            col_map['sheet_name'] = 'sheet_name'

        for col in df.columns:
            if col in self.exec_param.use_cols and col not in ['file_name', 'sheet_name']:
                col_map[col] = self.exec_param.mapping_cols[col]

        df = df.rename(columns=col_map)

        # Raise an error if the data is missing any of the index columns.
        for col in self.exec_param.index_cols:
            if col not in df.columns:
                log_print(f"{file_path} is missing index column {col}")
                raise Exception(f"{file_path} is missing index column {col}")

        if self.exec_param.split_cols:
            df['split_col'] = df[self.exec_param.split_cols].apply(
                lambda x: '_'.join([str(i).replace(' ', '_').replace(':', '_') for i in x.values]),
                axis=1)
        else:
            df['split_col'] = 'All'

        split_col = df['split_col'].unique()

        if len(split_col) >= 1000:
            log_print(f"{file_path} produces too many split groups (1000 or more)")
            raise Exception(f"{file_path} produces too many split groups (1000 or more)")

        general_fields = list(df.columns)
        general_fields.remove('split_col')
        general_fields.remove('file_name')
        if 'sheet_name' in general_fields:
            general_fields.remove('sheet_name')

        for col in self.exec_param.index_cols:
            general_fields.remove(col)

        for split_col_value in split_col:
            for col in general_fields:
                now_cols = [i for i in self.exec_param.index_cols]
                now_cols.append(col)
                now_df = df[df['split_col'] == split_col_value][now_cols]
                boolean_first_time, lock = self.get_lock(split_col_value, col)
                with lock:
                    path = os.path.join(self.exec_param.path_param.get_merge_tmp_path(), split_col_value)
                    os.makedirs(path, exist_ok=True)
                    if boolean_first_time:
                        now_df.to_csv(os.path.join(path, f'{col}.csv'), index=False, encoding='utf-8')
                    else:
                        now_df.to_csv(os.path.join(path, f'{col}.csv'), index=False, mode='a', header=False,
                                      encoding='utf-8')

    def read_merge_df_to_process(self, base_name):
        """Concatenate all per-column CSVs of one split value into a single wide CSV."""
        path = os.path.join(self.exec_param.path_param.get_merge_tmp_path(), base_name)
        all_files = os.listdir(path)
        dfs = [pd.read_csv(os.path.join(path, i), encoding='utf-8', index_col=self.exec_param.index_cols)
               for i in all_files]
        df = pd.concat(dfs, axis=1)
        df.reset_index(inplace=True)
        df.to_csv(os.path.join(self.exec_param.path_param.get_process_tmp_path(), base_name + '.csv'),
                  index=False, encoding='utf-8')

    def run(self):
        if len(self.exec_param.index_cols) == 0:
            log_print("Merging tables requires index columns to be specified")
            raise Exception("Merging tables requires index columns to be specified")

        all_files = read_excel_files(self.exec_param.path_param.get_unzip_tmp_path())
        split_count = use_files_get_max_cpu_count(all_files)
        all_arrays = split_array(all_files, split_count)

        log_print("Start reading files for horizontal merge, total files:", len(all_files),
                  ", number of file chunks:", split_count)

        # Phase 1: read and split the source files chunk by chunk (progress 20% -> 70%).
        for index, now_array in enumerate(all_arrays):
            with multiprocessing.Pool(split_count) as pool:
                pool.map(self.read_and_save_file, now_array)
            update_transfer_progress(self.id, round(20 + 50 * (index + 1) / len(all_arrays)), self.process_count,
                                     self.now_count, self.save_db)

        # Phase 2: merge the per-column CSVs of each split value column-wise (progress 80%).
        all_dirs = os.listdir(self.exec_param.path_param.get_merge_tmp_path())
        dir_size = get_dir_size(os.path.join(self.exec_param.path_param.get_merge_tmp_path(), all_dirs[0]))
        pool_count = max_file_size_get_max_cpu_count(dir_size)
        pool_count = min(pool_count, len(all_files))

        with multiprocessing.Pool(pool_count) as pool:
            pool.map(self.read_merge_df_to_process, all_dirs)

        update_transfer_progress(self.id, 80, self.process_count, self.now_count, self.save_db)
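

# Usage sketch (illustrative only): how Axis1DataImpl might be driven by a caller.
# Building the ExecParam is outside this module; the fields listed below are the
# ones this class actually reads (use_cols, mapping_cols, index_cols, split_cols,
# has_header, path_param). The id/process_count/now_count values are placeholders.
#
#     exec_param = ...  # a fully populated trans.ExecParam.ExecParam instance
#     trans = Axis1DataImpl(id='task-1', process_count=1, now_count=0,
#                           exec_param=exec_param, save_db=False)
#     trans.run()  # phase 1: split the source files; phase 2: merge them column-wise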