# common_utils.py
  1. import os
  2. from datetime import datetime
  3. import chardet
  4. import pandas as pd
  5. from sqlalchemy import create_engine, text, inspect
  6. from utils.log.trans_log import logger
  7. def get_engine(host='192.168.50.235', port=30306, user='root', pwd='admin123456', db='datang'):
  8. engine = create_engine(f'mysql+pymysql://{user}:{pwd}@{host}:{port}/{db}')
  9. return engine
  10. 1
  11. def print_log(message: str = '', bool_print_log=False, pre_time=None):
  12. log_now = datetime.now()
  13. if bool_print_log:
  14. if pre_time is None:
  15. logger.info(message)
  16. else:
  17. logger.info('%s,耗时:%s', message, log_now - pre_time)
  18. return log_now
  19. # 获取文件编码
  20. def detect_file_encoding(filename):
  21. with open(filename, 'rb') as f:
  22. rawdata = f.read(1000)
  23. result = chardet.detect(rawdata)
  24. encoding = result['encoding']
  25. if encoding is None:
  26. encoding = 'gb18030'
  27. if encoding.lower() in ['utf-8', 'ascii', 'utf8', 'utf-8-sig']:
  28. return 'utf-8'
  29. return 'gb18030'
  30. def df_save_file(df: pd.DataFrame = None, file_name: str = '', bool_add_time: bool = True, bool_print_log: bool = True):
  31. if df is None:
  32. raise Exception('df不能为None')
  33. res_file_name = file_name
  34. if bool_add_time:
  35. now = datetime.now()
  36. now_str = now.strftime('%Y%m%d%H%M%S')
  37. res_file_name = os.path.join(os.path.dirname(file_name), now_str + '-' + os.path.basename(file_name))
  38. log_now = print_log(f'开始保存df:{df.shape}到{res_file_name}', bool_print_log)
  39. os.makedirs(os.path.dirname(file_name), exist_ok=True)
  40. if file_name.endswith('csv'):
  41. df.to_csv(res_file_name, index=False, encoding='utf8')
  42. elif file_name.endswith('xls') or file_name.endswith('xlsx'):
  43. df.to_excel(res_file_name, index=False)
  44. elif file_name.endswith('parquet'):
  45. df.to_parquet(res_file_name, index=False)
  46. else:
  47. raise Exception('需要添加映射')
  48. print_log(f'完成保存df:{df.shape}到{res_file_name}', bool_print_log, pre_time=log_now)
  49. return res_file_name
  50. def read_file(file_path, read_cols: list = None, header: int = 0,
  51. nrows: int = None, bool_print_log: bool = True) -> pd.DataFrame:
  52. # 参数验证
  53. if not file_path or not isinstance(file_path, str):
  54. raise ValueError("文件路径不能为空且必须是字符串")
  55. if not os.path.exists(file_path):
  56. raise FileNotFoundError(f"文件不存在: {file_path}")
  57. file_ext = file_path.lower()
  58. # 构建读取参数的通用字典
  59. read_params = {'header': header}
  60. if nrows is not None:
  61. read_params['nrows'] = nrows
  62. try:
  63. log_now = print_log(f'开始读取文件: {file_path}', bool_print_log)
  64. # 根据文件扩展名选择读取方式
  65. if file_ext.endswith('.csv'):
  66. encoding = detect_file_encoding(file_path)
  67. read_params['encoding'] = encoding
  68. if read_cols:
  69. df = pd.read_csv(file_path, usecols=read_cols, **read_params)
  70. else:
  71. df = pd.read_csv(file_path, **read_params)
  72. elif file_ext.endswith(('.xls', '.xlsx')):
  73. if read_cols:
  74. df = pd.read_excel(file_path, usecols=read_cols, **read_params)
  75. else:
  76. df = pd.read_excel(file_path, **read_params)
  77. elif file_ext.endswith('.parquet'):
  78. # parquet文件使用不同的参数名
  79. parquet_params = {'columns': read_cols} if read_cols else {}
  80. df = pd.read_parquet(file_path, **parquet_params)
  81. elif file_ext.endswith('.json'):
  82. if read_cols:
  83. df = pd.read_json(file_path, **read_params)[read_cols]
  84. else:
  85. df = pd.read_json(file_path, **read_params)
  86. else:
  87. supported_formats = ['.csv', '.xls', '.xlsx', '.parquet', '.json']
  88. raise ValueError(
  89. f"不支持的文件格式: {os.path.splitext(file_path)[1]}\n"
  90. f"支持的文件格式: {', '.join(supported_formats)}"
  91. )
  92. print_log(f'文件读取成功: {file_path}, 数据量: {df.shape}', bool_print_log, pre_time=log_now)
  93. return df
  94. except pd.errors.EmptyDataError:
  95. print_log(f'文件为空: {file_path}', bool_print_log)
  96. return pd.DataFrame()
  97. except Exception as e:
  98. print_log(f'读取文件失败: {file_path}, 错误: {str(e)}', bool_print_log)
  99. raise
  100. def df_save_table(engine, df: pd.DataFrame, table_name, pre_save: str = None):
  101. """
  102. engine: 数据库引擎
  103. df: DataFrame
  104. table_name: 表名
  105. pre_save: 删除表(DROP),清空表(TRUNCATE)或者不处理(None)
  106. """
  107. log_time = print_log(f'开始保存到表{table_name},数据量:{df.shape},前置处理表:{pre_save}', bool_print_log=True)
  108. if not pre_save is None:
  109. with engine.connect() as conn:
  110. # 检查表是否存在
  111. inspector = inspect(engine)
  112. if table_name in inspector.get_table_names():
  113. # 删除表
  114. conn.execute(text(f"{pre_save} TABLE {table_name}"))
  115. conn.commit()
  116. df.to_sql(table_name, con=engine, if_exists='append', index=False)
  117. print_log(f'开始保存到表{table_name},数据量:{df.shape},前置处理表:{pre_save}', bool_print_log=True,
  118. pre_time=log_time)
  119. from pathlib import Path
  120. def get_all_files(read_path: str, pre_filter: tuple = None):
  121. result = list()
  122. if os.path.isfile(read_path):
  123. if pre_filter is None:
  124. result.append(str(Path(read_path)))
  125. else:
  126. if read_path.split('.')[-1].endswith(pre_filter):
  127. result.append(str(Path(read_path)))
  128. else:
  129. for root, dir, files in os.walk(read_path):
  130. for file in files:
  131. whole_path = os.path.join(root, file)
  132. if pre_filter is None:
  133. result.append(str(Path(whole_path)))
  134. else:
  135. if whole_path.split('.')[-1].endswith(pre_filter):
  136. result.append(str(Path(whole_path)))
  137. return result