| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246 |
- # -*- coding: utf-8 -*-
- # @Time : 2024/6/7
- # @Author : 魏志亮
- import os
- import traceback
- from typing import Any, Dict, List, Tuple, Union
- import pandas as pd
- import pymysql
- from pymysql.cursors import DictCursor
- from sqlalchemy import create_engine
- from sqlalchemy.engine import Engine
- from utils.conf.read_conf import load_yaml_config
- from utils.log.trans_log import error, info, debug
class MySQLDatabase:
    """MySQL access helper configured from a named entry in the YAML file at $ETL_CONF.

    Provides two access paths:
      * raw pymysql (``get_connection`` / ``execute_query`` / ``execute_update``),
        which opens a brand-new connection for every call; and
      * pandas I/O (``save_dataframe`` / ``read_sql_to_dataframe``) through a
        SQLAlchemy engine that is cached at class level and therefore shared by
        every instance pointing at the same host/port/user/database.
    """

    # Class-level engine cache keyed by "host:port:user:database". Every instance
    # mutates this same dict, so repeated instances for one target reuse one pool.
    _engine_cache = {}

    def __init__(self, connection_name: str):
        """
        Load and validate the connection settings named ``connection_name``.

        Args:
            connection_name: Top-level key in the ETL_CONF YAML file whose mapping
                holds the connection settings.

        Raises:
            ValueError: The ETL_CONF environment variable is not set.
            KeyError: ``connection_name`` is not in the config, or its entry lacks
                one of host/user/password/database.
        """
        # Path of the YAML config file comes from the environment.
        config_path = os.environ.get('ETL_CONF')
        if not config_path:
            raise ValueError("环境变量 ETL_CONF 未设置")

        self.yaml_data = load_yaml_config(config_path)
        self.connection_name = connection_name

        if connection_name not in self.yaml_data:
            raise KeyError(f"配置中不存在连接名称: {connection_name}")

        self.config = self.yaml_data[connection_name]
        self.database = self.config.get('database', '')

        # 'port' is not required here; get_engine falls back to 3306 without it.
        required_keys = ['host', 'user', 'password', 'database']
        missing_keys = [key for key in required_keys if key not in self.config]
        if missing_keys:
            raise KeyError(f"连接配置缺少必要项: {missing_keys}")

    def get_connection(self) -> pymysql.Connection:
        """
        Open a new pymysql connection for this configuration.

        No pooling happens here: every call opens a fresh connection, and the
        caller is responsible for closing it.

        Returns:
            A pymysql connection whose cursors return rows as dicts (DictCursor),
            using the utf8mb4 charset.
        """
        # Work on a copy so the shared config dict is never mutated.
        conn_config = self.config.copy()
        # charset is forced to utf8mb4 below; drop any configured value so it is
        # not passed twice (a duplicate keyword argument would raise TypeError).
        conn_config.pop('charset', None)
        return pymysql.connect(
            cursorclass=DictCursor,
            charset='utf8mb4',
            **conn_config
        )

    @staticmethod
    def _rollback_quietly(conn) -> None:
        """Roll back ``conn`` if present, logging instead of raising on failure.

        Used on error paths so that a failing rollback (e.g. on a dropped
        connection) cannot replace the exception that triggered it.
        """
        if conn is None:
            return
        try:
            conn.rollback()
        except Exception as rollback_err:
            error(f"回滚失败: {rollback_err}")

    def _run(self, sql: str, params, debug_label: str, error_label: str, collect):
        """Execute ``sql`` on a fresh connection, commit, and return ``collect(cursor)``.

        Shared implementation of execute_query / execute_update. On failure, it
        logs the SQL, the error and the traceback, rolls back, and re-raises the
        original exception. Cursor and connection are always released.

        Args:
            sql: SQL statement.
            params: Arguments passed straight to ``cursor.execute``.
            debug_label: Prefix for the debug log of the executed statement.
            error_label: Prefix for the error log line that shows ``sql``.
            collect: Called with the cursor after commit to produce the result.
        """
        conn = None
        cursor = None
        try:
            conn = self.get_connection()
            cursor = conn.cursor()
            cursor.execute(sql, params)
            # _executed is pymysql's (private) record of the statement it sent.
            debug(debug_label, cursor._executed)
            conn.commit()
            return collect(cursor)
        except Exception as e:
            error(f"{error_label}: {sql}")
            error(f"错误信息: {e}")
            error(traceback.format_exc())
            self._rollback_quietly(conn)
            raise
        finally:
            # Nested so a failing cursor.close() cannot leak the connection.
            try:
                if cursor is not None:
                    cursor.close()
            finally:
                # Skip close() on a connection that is already gone; closing it
                # again would raise and mask the original exception.
                if conn is not None and conn.open:
                    conn.close()

    def execute_query(self, sql: str, params: Union[Tuple, List, Dict, None] = None) -> List[Dict[str, Any]]:
        """
        Execute a SQL statement, commit, and return all fetched rows.

        Args:
            sql: SQL statement.
            params: Statement arguments as a tuple, list or dict.

        Returns:
            The rows from ``cursor.fetchall()``, one dict per row.

        Raises:
            Exception: Whatever the driver raised. It is logged and the
                transaction is rolled back first.
        """
        # NOTE(review): None/empty params become (), so pymysql presumably still
        # applies %-formatting to sql even without placeholders -- a literal '%'
        # likely has to be written as '%%'. Kept as-is for existing callers.
        return self._run(
            sql,
            params or (),
            "开始执行SQL:\n",
            "执行SQL出错",
            lambda cur: cur.fetchall(),
        )

    def execute_update(self, sql: str, params: Union[Tuple, List, Dict, None] = None) -> int:
        """
        Execute a data-modifying statement (INSERT, UPDATE, DELETE) and commit.

        Args:
            sql: SQL statement.
            params: Statement arguments as a tuple, list or dict.

        Returns:
            ``cursor.rowcount`` for the statement (number of affected rows).

        Raises:
            Exception: Whatever the driver raised. It is logged and the
                transaction is rolled back first.
        """
        return self._run(
            sql,
            params or (),
            "开始执行SQL:",
            "执行更新SQL出错",
            lambda cur: cur.rowcount,
        )

    def get_engine(self) -> Engine:
        """
        Return the SQLAlchemy engine for this configuration, creating it once.

        Engines are cached in the class-level ``_engine_cache``, keyed by
        host:port:user:database, so all instances with the same target share one
        pool (size 10, recycled hourly, pre-pinged before use).

        Returns:
            A SQLAlchemy engine using the pymysql driver and utf8mb4.
        """
        from urllib.parse import quote

        config = self.config
        # 3306 is the standard MySQL port; 'port' is not a required config key.
        port = config.get('port', 3306)
        cache_key = f"{config['host']}:{port}:{config['user']}:{config['database']}"

        engine = self._engine_cache.get(cache_key)
        if engine is None:
            # Percent-encode the credentials: characters such as '@', ':', '/',
            # '#' or '%' would otherwise corrupt the URL. str() covers numeric
            # values parsed from YAML.
            username = quote(str(config['user']), safe='')
            password = quote(str(config['password']), safe='')
            connection_url = (
                f"mysql+pymysql://{username}:{password}@{config['host']}:{port}/"
                f"{config['database']}?charset=utf8mb4"
            )
            engine = create_engine(
                connection_url,
                pool_size=10,
                pool_recycle=3600,
                pool_pre_ping=True,  # validate pooled connections before use
                echo=False,  # set True to log every SQL statement
            )
            self._engine_cache[cache_key] = engine
        return engine

    def save_dataframe(self, df: pd.DataFrame, table_name: str, chunk_size: int = 10000,
                       if_exists: str = 'append') -> None:
        """
        Write a DataFrame to a database table via the cached engine.

        Args:
            df: Data to write; the DataFrame index is not written.
            table_name: Target table name.
            chunk_size: Number of rows written per batch.
            if_exists: Behaviour when the table exists: 'fail', 'replace' or 'append'.

        Raises:
            Exception: Whatever ``DataFrame.to_sql`` raised, after logging it.
        """
        try:
            df.to_sql(
                table_name,
                self.get_engine(),
                index=False,
                if_exists=if_exists,
                chunksize=chunk_size,
                method='multi',  # multi-row INSERT statements for throughput
            )
            info(f"成功保存 {len(df)} 条数据到表 {table_name}")
        except Exception as e:
            error(f"保存DataFrame到表 {table_name} 失败: {e}")
            error(traceback.format_exc())
            raise

    def read_sql_to_dataframe(self, sql: str, params: Union[Tuple, List, Dict, None] = None) -> pd.DataFrame:
        """
        Run a query through the cached engine and return the result as a DataFrame.

        Args:
            sql: SQL query.
            params: Optional query arguments forwarded to ``pd.read_sql_query``.
                The driver is pymysql, so placeholders are %s or %(name)s.
                Default None means no arguments, as before.

        Returns:
            The query result.

        Raises:
            Exception: Whatever ``pd.read_sql_query`` raised, after logging it.
        """
        try:
            df = pd.read_sql_query(sql, self.get_engine(), params=params)
            debug(f"查询返回 {len(df)} 行数据")
            return df
        except Exception as e:
            error(f"执行SQL查询失败: {sql}")
            error(f"错误信息: {e}")
            error(traceback.format_exc())
            raise

    # Legacy method names kept as aliases for backward compatibility.
    get_conn = get_connection
    execute = execute_query
    execute_df_save = save_dataframe
    read_sql_to_df = read_sql_to_dataframe
|