# -*- coding: utf-8 -*-
# @Time : 2024/6/7
# @Author : 魏志亮
import os
import traceback
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.parse import quote_plus

import pandas as pd
import pymysql
from pymysql.cursors import DictCursor
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine

from utils.conf.read_conf import load_yaml_config
from utils.log.trans_log import error, info, debug


class MySQLDatabase:
    """Manage MySQL access for one named connection from the ETL YAML config.

    Raw statements go through short-lived pymysql connections (one per call);
    DataFrame reads/writes go through a SQLAlchemy engine that is cached at
    class level and shared by every instance pointing at the same
    host/port/user/database.
    """

    # Class-level engine cache so equivalent instances reuse one engine
    # (and therefore one connection pool) instead of creating a new one.
    _engine_cache: Dict[str, Engine] = {}

    # Port used by get_engine when the configuration does not specify one.
    _DEFAULT_PORT = 3306

    def __init__(self, connection_name: str):
        """Load the connection settings registered under ``connection_name``.

        Args:
            connection_name: Key of the connection entry in the YAML file
                whose path is given by the ``ETL_CONF`` environment variable.

        Raises:
            ValueError: If ``ETL_CONF`` is unset or empty.
            KeyError: If the connection name is not in the config, or the
                entry lacks one of host/user/password/database.
        """
        config_path = os.environ.get('ETL_CONF')
        if not config_path:
            raise ValueError("环境变量 ETL_CONF 未设置")

        self.yaml_data = load_yaml_config(config_path)
        self.connection_name = connection_name

        if connection_name not in self.yaml_data:
            raise KeyError(f"配置中不存在连接名称: {connection_name}")

        self.config = self.yaml_data[connection_name]
        self.database = self.config.get('database', '')

        # Fail early on incomplete entries rather than at first use.
        required_keys = ['host', 'user', 'password', 'database']
        missing_keys = [key for key in required_keys if key not in self.config]
        if missing_keys:
            raise KeyError(f"连接配置缺少必要项: {missing_keys}")

    def get_connection(self) -> pymysql.Connection:
        """Open and return a new pymysql connection.

        A fresh connection is created on every call (no pooling here); the
        caller is responsible for closing it. Rows are returned as dicts
        because the connection uses ``DictCursor``.

        Returns:
            A new pymysql connection using the ``utf8mb4`` charset.
        """
        # Work on a copy so the stored configuration is never mutated.
        conn_config = self.config.copy()
        # Drop any configured charset: it is always forced to utf8mb4 below,
        # and passing it twice would raise a duplicate-keyword TypeError.
        conn_config.pop('charset', None)

        return pymysql.connect(
            cursorclass=DictCursor,
            charset='utf8mb4',
            **conn_config
        )

    def execute_query(
        self,
        sql: str,
        params: Optional[Union[Tuple, List, Dict]] = None,
    ) -> List[Dict[str, Any]]:
        """Execute a statement and return all fetched rows.

        Args:
            sql: SQL statement, optionally with pymysql placeholders.
            params: Values for the placeholders (tuple, list or dict).

        Returns:
            The rows returned by ``cursor.fetchall()``, one dict per row.

        Raises:
            Exception: Any error raised while connecting or executing is
                logged, the transaction is rolled back, and the error is
                re-raised unchanged.
        """
        params = params or ()
        conn = None
        cursor = None

        try:
            conn = self.get_connection()
            cursor = conn.cursor()

            cursor.execute(sql, params)
            # _executed holds the statement with parameters substituted.
            debug("开始执行SQL:\n", cursor._executed)

            conn.commit()

            return cursor.fetchall()
        except Exception as e:
            error(f"执行SQL出错: {sql}")
            error(f"错误信息: {e}")
            error(traceback.format_exc())
            if conn:
                conn.rollback()
            # Bare raise keeps the original traceback intact.
            raise
        finally:
            # Always release the cursor and the connection.
            if cursor:
                cursor.close()
            if conn:
                conn.close()

    def execute_update(
        self,
        sql: str,
        params: Optional[Union[Tuple, List, Dict]] = None,
    ) -> int:
        """Execute a data-modifying statement (INSERT, UPDATE, DELETE).

        Args:
            sql: SQL statement, optionally with pymysql placeholders.
            params: Values for the placeholders (tuple, list or dict).

        Returns:
            The cursor's ``rowcount`` after the statement has been committed.

        Raises:
            Exception: Any error raised while connecting or executing is
                logged, the transaction is rolled back, and the error is
                re-raised unchanged.
        """
        params = params or ()
        conn = None
        cursor = None

        try:
            conn = self.get_connection()
            cursor = conn.cursor()

            cursor.execute(sql, params)
            debug("开始执行SQL:", cursor._executed)

            conn.commit()
            return cursor.rowcount
        except Exception as e:
            error(f"执行更新SQL出错: {sql}")
            error(f"错误信息: {e}")
            error(traceback.format_exc())
            if conn:
                conn.rollback()
            # Bare raise keeps the original traceback intact.
            raise
        finally:
            if cursor:
                cursor.close()
            if conn:
                conn.close()

    def get_engine(self) -> Engine:
        """Return the SQLAlchemy engine for this connection, creating it once.

        Engines are cached in the class-level ``_engine_cache`` under a
        ``host:port:user:database`` key.

        Returns:
            The cached (or newly created) SQLAlchemy engine.
        """
        config = self.config
        host = config['host']
        # 'port' is not a required config key, so fall back to the default
        # instead of raising KeyError for configs that omit it.
        port = config.get('port', self._DEFAULT_PORT)
        username = config['user']
        dbname = config['database']

        cache_key = f"{host}:{port}:{username}:{dbname}"

        engine = self._engine_cache.get(cache_key)
        if engine is None:
            # The config stores raw credentials (they are handed to
            # pymysql.connect as-is), so they must be percent-encoded before
            # being embedded in a URL; otherwise characters such as '@', ':'
            # or '/' would corrupt the URL.
            safe_user = quote_plus(str(username))
            safe_password = quote_plus(str(config['password']))
            connection_url = (
                f'mysql+pymysql://{safe_user}:{safe_password}'
                f'@{host}:{port}/{dbname}?charset=utf8mb4'
            )

            engine = create_engine(
                connection_url,
                pool_size=10,
                pool_recycle=3600,
                pool_pre_ping=True,  # validate pooled connections before use
                echo=False  # set to True to log emitted SQL
            )
            self._engine_cache[cache_key] = engine

        return engine

    def save_dataframe(
        self,
        df: pd.DataFrame,
        table_name: str,
        chunk_size: int = 10000,
        if_exists: str = 'append',
    ) -> None:
        """Write a DataFrame to a database table via ``DataFrame.to_sql``.

        Args:
            df: The DataFrame to persist (its index is not written).
            table_name: Name of the target table.
            chunk_size: Number of rows written per batch.
            if_exists: What to do when the table exists: 'fail', 'replace'
                or 'append'.

        Raises:
            Exception: Any error raised by ``to_sql`` is logged and re-raised.
        """
        try:
            df.to_sql(
                table_name,
                self.get_engine(),
                index=False,
                if_exists=if_exists,
                chunksize=chunk_size,
                method='multi'  # multi-row INSERT statements
            )
            info(f"成功保存 {len(df)} 条数据到表 {table_name}")
        except Exception as e:
            error(f"保存DataFrame到表 {table_name} 失败: {e}")
            error(traceback.format_exc())
            raise

    def read_sql_to_dataframe(self, sql: str) -> pd.DataFrame:
        """Run a SQL query through the cached engine and return a DataFrame.

        Args:
            sql: The SQL query to execute.

        Returns:
            The query result as a DataFrame.

        Raises:
            Exception: Any error raised by pandas/SQLAlchemy is logged and
                re-raised.
        """
        try:
            df = pd.read_sql_query(sql, self.get_engine())
            debug(f"查询返回 {len(df)} 行数据")
            return df
        except Exception as e:
            error(f"执行SQL查询失败: {sql}")
            error(f"错误信息: {e}")
            error(traceback.format_exc())
            raise

    # Backward-compatible aliases for the previous method names.
    get_conn = get_connection
    execute = execute_query
    execute_df_save = save_dataframe
    read_sql_to_df = read_sql_to_dataframe