# ConnectMysql.py
import os
import tempfile
import time
import traceback
from os import *
from urllib.parse import quote_plus

import pandas as pd
import pymysql
from pymysql.cursors import DictCursor
from sqlalchemy import create_engine, text

from utils.conf.read_conf import yaml_conf
from utils.log.trans_log import logger
  11. class ConnectMysql:
  12. def __init__(self, connet_name):
  13. conf_path = path.abspath(__file__).split("utils")[0] + 'conf' + sep + 'config.yaml'
  14. self.yaml_data = yaml_conf(conf_path)
  15. self.connet_name = connet_name
  16. self.config = self.yaml_data[self.connet_name]
  17. self.database = self.config['database']
  18. # 从连接池中获取一个连接
  19. def get_conn(self):
  20. return pymysql.connect(**self.config, local_infile=True)
  21. # 使用连接执行sql
  22. def execute(self, sql, params=tuple()):
  23. with self.get_conn() as conn:
  24. with conn.cursor(cursor=DictCursor) as cursor:
  25. try:
  26. cursor.execute(sql, params)
  27. logger.info(f"开始执行SQL:{cursor._executed}")
  28. conn.commit()
  29. result = cursor.fetchall()
  30. return result
  31. except Exception as e:
  32. logger.info(f"执行sql:{sql},报错:{e}")
  33. logger.info(traceback.format_exc())
  34. conn.rollback()
  35. raise e
  36. def get_engine(self):
  37. config = self.config
  38. username = config['user']
  39. password = config['password']
  40. host = config['host']
  41. port = config['port']
  42. dbname = config['database']
  43. return create_engine(f'mysql+pymysql://{username}:{password}@{host}:{port}/{dbname}?local_infile=1')
  44. def execute_df_save(self, df, table_name, batch_count=1000):
  45. df.to_sql(table_name, self.get_engine(), index=False, if_exists='append', chunksize=batch_count)
  46. def read_sql_to_df(self, sql):
  47. df = pd.read_sql_query(sql, self.get_engine())
  48. return df
  49. def safe_load_data_local(self, df, table_name, batch_size=30000):
  50. """
  51. 安全加载数据到TiDB,包含以下优化:
  52. 1. 分批处理避免内存溢出
  53. 2. 完善的连接管理
  54. 3. 错误处理和重试机制
  55. """
  56. total_rows = len(df)
  57. success_rows = 0
  58. engine = self.get_engine()
  59. for i in range(0, total_rows, batch_size):
  60. batch = df.iloc[i:i + batch_size]
  61. retry_count = 0
  62. max_retries = 4
  63. while retry_count < max_retries:
  64. try:
  65. with tempfile.NamedTemporaryFile(mode='w') as tmp:
  66. batch.to_csv(tmp, index=False, header=False, sep='\t')
  67. tmp.flush()
  68. with engine.begin() as conn: # 自动提交事务
  69. # 设置当前会话内存配额
  70. conn.execute(text("SET tidb_mem_quota_query = 2147483648")) # 2GB
  71. # 执行LOAD DATA
  72. conn.execute(text(f"""
  73. LOAD DATA LOCAL INFILE '{tmp.name}'
  74. INTO TABLE {table_name}
  75. FIELDS TERMINATED BY '\t'
  76. LINES TERMINATED BY '\n'
  77. """))
  78. success_rows += len(batch)
  79. logger.info(f"成功加载批次 {i // batch_size + 1}: {len(batch)} 行")
  80. break # 成功则跳出重试循环
  81. except Exception as e:
  82. retry_count += 1
  83. logger.info(f"批次 {i // batch_size + 1} 第 {retry_count} 次尝试失败: {str(e)}")
  84. if retry_count >= max_retries:
  85. logger.error(f"批次 {i // batch_size + 1} 达到最大重试次数")
  86. logger.error(traceback.format_exc())
  87. raise
  88. time.sleep(2 ** retry_count) # 指数退避
  89. logger.info(f"数据加载完成: 总计 {success_rows}/{total_rows} 行")
  90. return success_rows