ConnectMysql.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2024/6/7
  3. # @Author : 魏志亮
  4. import os
  5. import traceback
  6. from typing import Any, Dict, List, Tuple, Union
  7. import pandas as pd
  8. import pymysql
  9. from pymysql.cursors import DictCursor
  10. from sqlalchemy import create_engine
  11. from sqlalchemy.engine import Engine
  12. from utils.conf.read_conf import load_yaml_config
  13. from utils.log.trans_log import error, info, debug
  14. class MySQLDatabase:
  15. """MySQL数据库连接管理类"""
  16. # 类级别的引擎缓存,避免重复创建
  17. _engine_cache = {}
  18. def __init__(self, connection_name: str):
  19. """
  20. 初始化MySQL数据库连接
  21. Args:
  22. connection_name: 配置文件中对应的连接名称
  23. """
  24. # 获取配置文件路径
  25. config_path = os.environ.get('ETL_CONF')
  26. if not config_path:
  27. raise ValueError("环境变量 ETL_CONF 未设置")
  28. # 加载配置
  29. self.yaml_data = load_yaml_config(config_path)
  30. self.connection_name = connection_name
  31. # 验证配置是否存在
  32. if connection_name not in self.yaml_data:
  33. raise KeyError(f"配置中不存在连接名称: {connection_name}")
  34. self.config = self.yaml_data[connection_name]
  35. self.database = self.config.get('database', '')
  36. # 验证必要配置项
  37. required_keys = ['host', 'user', 'password', 'database']
  38. missing_keys = [key for key in required_keys if key not in self.config]
  39. if missing_keys:
  40. raise KeyError(f"连接配置缺少必要项: {missing_keys}")
  41. def get_connection(self) -> pymysql.Connection:
  42. """
  43. 从连接池中获取一个连接
  44. Returns:
  45. pymysql连接对象
  46. """
  47. # 创建连接配置副本,避免修改原配置
  48. conn_config = self.config.copy()
  49. # 移除可能不需要的配置项(如果有)
  50. conn_config.pop('charset', None) # pymysql连接时charset参数可能会冲突
  51. return pymysql.connect(
  52. cursorclass=DictCursor,
  53. charset='utf8mb4',
  54. **conn_config
  55. )
  56. def execute_query(self, sql: str, params: Union[Tuple, List, Dict] = None) -> List[Dict[str, Any]]:
  57. """
  58. 执行SQL查询并返回结果
  59. Args:
  60. sql: SQL语句
  61. params: SQL参数,可以是元组、列表或字典
  62. Returns:
  63. 查询结果列表,每个元素为字典形式
  64. Raises:
  65. Exception: SQL执行错误时抛出
  66. """
  67. params = params or ()
  68. conn = None
  69. cursor = None
  70. try:
  71. conn = self.get_connection()
  72. cursor = conn.cursor()
  73. # 执行SQL
  74. cursor.execute(sql, params)
  75. debug("开始执行SQL:\n", cursor._executed)
  76. # 提交事务
  77. conn.commit()
  78. # 获取结果
  79. result = cursor.fetchall()
  80. return result
  81. except Exception as e:
  82. error(f"执行SQL出错: {sql}")
  83. error(f"错误信息: {e}")
  84. error(traceback.format_exc())
  85. if conn:
  86. conn.rollback()
  87. raise e
  88. finally:
  89. # 确保资源被释放
  90. if cursor:
  91. cursor.close()
  92. if conn:
  93. conn.close()
  94. def execute_update(self, sql: str, params: Union[Tuple, List, Dict] = None) -> int:
  95. """
  96. 执行更新操作(INSERT, UPDATE, DELETE)
  97. Args:
  98. sql: SQL语句
  99. params: SQL参数
  100. Returns:
  101. 影响的行数
  102. """
  103. params = params or ()
  104. conn = None
  105. cursor = None
  106. try:
  107. conn = self.get_connection()
  108. cursor = conn.cursor()
  109. cursor.execute(sql, params)
  110. debug("开始执行SQL:", cursor._executed)
  111. conn.commit()
  112. return cursor.rowcount
  113. except Exception as e:
  114. error(f"执行更新SQL出错: {sql}")
  115. error(f"错误信息: {e}")
  116. error(traceback.format_exc())
  117. if conn:
  118. conn.rollback()
  119. raise e
  120. finally:
  121. if cursor:
  122. cursor.close()
  123. if conn:
  124. conn.close()
  125. def get_engine(self) -> Engine:
  126. """
  127. 获取SQLAlchemy引擎,使用缓存避免重复创建
  128. Returns:
  129. SQLAlchemy引擎对象
  130. """
  131. # 构建缓存键
  132. config = self.config
  133. cache_key = f"{config['host']}:{config['port']}:{config['user']}:{config['database']}"
  134. # 检查缓存中是否已有引擎
  135. if cache_key not in self._engine_cache:
  136. username = config['user']
  137. password = config['password']
  138. host = config['host']
  139. port = config['port']
  140. dbname = config['database']
  141. # 构建连接URL
  142. connection_url = f'mysql+pymysql://{username}:{password}@{host}:{port}/{dbname}?charset=utf8mb4'
  143. # 创建引擎并缓存
  144. self._engine_cache[cache_key] = create_engine(
  145. connection_url,
  146. pool_size=10, # 增加连接池大小
  147. pool_recycle=3600,
  148. pool_pre_ping=True, # 连接池预ping,确保连接有效
  149. echo=False # 设置为True可打印SQL日志
  150. )
  151. return self._engine_cache[cache_key]
  152. def save_dataframe(self, df: pd.DataFrame, table_name: str, chunk_size: int = 10000,
  153. if_exists: str = 'append') -> None:
  154. """
  155. 将DataFrame保存到数据库表
  156. Args:
  157. df: pandas DataFrame对象
  158. table_name: 目标表名
  159. chunk_size: 每批写入的行数
  160. if_exists: 表存在时的处理方式:'fail', 'replace', 'append'
  161. """
  162. try:
  163. df.to_sql(
  164. table_name,
  165. self.get_engine(),
  166. index=False,
  167. if_exists=if_exists,
  168. chunksize=chunk_size,
  169. method='multi' # 使用多值插入提高性能
  170. )
  171. info(f"成功保存 {len(df)} 条数据到表 {table_name}")
  172. except Exception as e:
  173. error(f"保存DataFrame到表 {table_name} 失败: {e}")
  174. error(traceback.format_exc())
  175. raise e
  176. def read_sql_to_dataframe(self, sql: str) -> pd.DataFrame:
  177. """
  178. 执行SQL查询并返回DataFrame
  179. Args:
  180. sql: SQL查询语句
  181. Returns:
  182. 查询结果的DataFrame
  183. """
  184. try:
  185. df = pd.read_sql_query(sql, self.get_engine())
  186. debug(f"查询返回 {len(df)} 行数据")
  187. return df
  188. except Exception as e:
  189. error(f"执行SQL查询失败: {sql}")
  190. error(f"错误信息: {e}")
  191. error(traceback.format_exc())
  192. raise e
  193. # 为了保持向后兼容,保留原方法名(可选)
  194. get_conn = get_connection
  195. execute = execute_query
  196. execute_df_save = save_dataframe
  197. read_sql_to_df = read_sql_to_dataframe