import os

import numpy as np
import pandas as pd
import pymysql
from dbutils.pooled_db import PooledDB
from pandas import DataFrame
from pymysql.cursors import DictCursor

from conf.db import mysql_config
from utils.log.trans_log import trans_print


class ConnectMysqlPool:
    """Thin wrapper around a DBUtils ``PooledDB`` pool of pymysql connections.

    Attributes:
        yaml_data: The ``mysql_config`` mapping imported from ``conf.db``.
        connet_name: Key prefix used to look up this pool's settings in
            ``yaml_data``.
        pool: The ``PooledDB`` instance built by ``create_mysql_pool``.

    Methods:
        create_mysql_pool: Build the connection pool from configuration.
        get_conn: Borrow a connection from the pool.
        execute: Run one SQL statement and return its rows as dicts.
        save_dict: ``REPLACE INTO`` a single row given as a dict.
        df_batch_save: ``INSERT`` a DataFrame in fixed-size batches.
    """

    def __init__(self, connet_name):
        """Create the pool for the configuration entry named ``connet_name``.

        Args:
            connet_name (str): Configuration key prefix. ``mysql_config`` is
                read at ``connet_name`` (connection arguments) and at
                ``connet_name + '_connect_pool_config'`` (pool arguments).
        """
        self.yaml_data = mysql_config
        self.connet_name = connet_name
        self.pool = self.create_mysql_pool()

    def create_mysql_pool(self):
        """Build a ``PooledDB`` from the configured pool and connection settings.

        Returns:
            PooledDB: A new connection pool backed by pymysql.
        """
        pool_config = self.yaml_data[self.connet_name + '_connect_pool_config']
        connect_config = self.yaml_data[self.connet_name]
        return PooledDB(
            **pool_config,
            **connect_config,
            # Per the DBUtils documentation, ping=2 checks the connection
            # whenever a cursor is created.
            ping=2,
            creator=pymysql,
        )

    def get_conn(self):
        """Borrow a connection from the pool.

        Returns:
            A pooled connection object; closing it hands it back to the pool.
        """
        return self.pool.connection()

    def execute(self, sql, params=tuple()):
        """Execute one SQL statement, commit, and return the fetched rows.

        Args:
            sql (str): SQL text, using ``%s`` placeholders for ``params``.
            params (tuple): Values bound to the placeholders.

        Returns:
            The result of ``cursor.fetchall()`` on a ``DictCursor``
            (each row is a dict keyed by column name).

        Raises:
            Exception: Any error raised while executing is logged, the
                transaction is rolled back, and the error is re-raised.
        """
        with self.get_conn() as conn:
            with conn.cursor(cursor=DictCursor) as cursor:
                try:
                    # Log the fully rendered statement *before* running it,
                    # via the public mogrify() rather than cursor._executed.
                    trans_print("开始执行SQL:", cursor.mogrify(sql, params))
                    cursor.execute(sql, params)
                    conn.commit()
                    return cursor.fetchall()
                except Exception as e:
                    trans_print(f"执行sql:{sql},报错:{e}")
                    conn.rollback()
                    raise

    def save_dict(self, table_name: str, params: dict):
        """Write one row with ``REPLACE INTO``.

        Args:
            table_name: Target table; interpolated into the SQL text as-is.
            params: Mapping of column name to value. Keys are interpolated
                into the SQL text as-is; values are bound as parameters.

        Raises:
            Exception: Any error raised while executing is logged, the
                transaction is rolled back, and the error is re-raised.
        """
        columns = list(params)
        col_str = ",".join(columns)
        data_s_str = ",".join(["%s"] * len(columns))
        insert_sql = f"replace into {table_name} ({col_str}) values ({data_s_str})"

        with self.get_conn() as conn:
            with conn.cursor() as cursor:
                try:
                    cursor.execute(insert_sql, tuple(params.values()))
                    conn.commit()
                except Exception as e:
                    trans_print(f"执行sql:{insert_sql},报错:{e}")
                    conn.rollback()
                    raise

    def df_batch_save(self, table_name: str, df: DataFrame, batch_count=20000):
        """Insert every row of ``df`` into ``table_name`` in batches.

        Each batch is sent with ``executemany`` on its own pooled connection
        and committed separately, so a failure rolls back only the batch that
        failed; earlier batches stay committed.

        Args:
            table_name: Target table (back-quoted in the generated SQL).
            df: Rows to insert. Column labels are used as column names and
                interpolated into the SQL text as-is. The frame itself is
                not modified.
            batch_count: Maximum number of rows per ``executemany`` call.

        Raises:
            ValueError: If ``batch_count`` is less than 1.
            Exception: Any database error is re-raised after rolling back
                the current batch.
        """
        if batch_count < 1:
            raise ValueError("batch_count must be a positive integer")

        col_str = ",".join(str(col) for col in df.columns)
        data_s_str = ",".join(["%s"] * len(df.columns))
        insert_sql = f"INSERT INTO `{table_name}` ({col_str}) values ({data_s_str})"

        # Map NaN/NaT to None (SQL NULL) on an object-dtype copy so the
        # caller's DataFrame is left untouched.
        clean_df = df.astype(object).where(df.notna(), None)

        total_count = clean_df.shape[0]
        for start in range(0, total_count, batch_count):
            batch_df = clean_df.iloc[start:start + batch_count]
            values = [tuple(row) for row in batch_df.values]
            with self.get_conn() as conn:
                with conn.cursor() as cursor:
                    try:
                        cursor.executemany(insert_sql, values)
                        conn.commit()
                    except Exception:
                        conn.rollback()
                        raise
            # Report rows actually written; the last batch may be short.
            saved_count = min(start + batch_count, total_count)
            trans_print(
                "总条数" + str(total_count) + ",已保存:" + str(saved_count))


if __name__ == '__main__':
    plt = ConnectMysqlPool("plt")
    print(plt.execute("select * from data_transfer limit 2"))

    trans = ConnectMysqlPool("trans")
    df = pd.DataFrame()
    df['name'] = ['name' + str(i) for i in range(1000)]
    # df_batch_save returns None, so there is nothing useful to print.
    trans.df_batch_save('test', df, 33)