123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157 |
- import numpy as np
- import pandas as pd
- import pymysql
- from dbutils.pooled_db import PooledDB
- import os
- from pandas import DataFrame
- from pymysql.cursors import DictCursor
- from conf.db import mysql_config
- from utils.log.trans_log import trans_print
class ConnectMysqlPool:
    """Pooled access to a MySQL database selected by configuration name.

    Settings are read from ``mysql_config``: the entry ``<connet_name>``
    and the entry ``<connet_name>_connect_pool_config`` are both unpacked
    as keyword arguments into ``PooledDB``.

    Attributes:
        yaml_data: The ``mysql_config`` mapping the settings are read from.
        connet_name (str): Configuration key prefix chosen by the caller.
        pool (PooledDB): Connection pool built by ``create_mysql_pool``.

    Methods:
        create_mysql_pool: Build the ``PooledDB`` pool from the configuration.
        get_conn: Borrow a connection from the pool.
        execute: Run one SQL statement and return the fetched rows.
        save_dict: ``REPLACE INTO`` a single row given as a dict.
        df_batch_save: ``INSERT`` a DataFrame in batches.
    """

    def __init__(self, connet_name):
        """Store the configuration and create the connection pool.

        Args:
            connet_name (str): Key in ``mysql_config`` holding the connection
                arguments; ``connet_name + '_connect_pool_config'`` must hold
                the pool arguments.
        """
        self.yaml_data = mysql_config
        self.connet_name = connet_name
        # Create the connection pool.
        self.pool = self.create_mysql_pool()

    def create_mysql_pool(self):
        """Create the MySQL connection pool from the stored configuration.

        Returns:
            PooledDB: A pool that uses ``pymysql`` as the connection creator.

        Raises:
            KeyError: If either configuration entry is missing.
        """
        pool = PooledDB(
            **self.yaml_data[self.connet_name + '_connect_pool_config'],
            **self.yaml_data[self.connet_name],
            # Per the DBUtils documentation, ping=2 checks the connection
            # whenever a cursor is created.
            ping=2,
            creator=pymysql
        )
        return pool

    def get_conn(self):
        """Borrow a connection from the pool.

        Returns:
            The pooled connection object returned by ``self.pool.connection()``.
        """
        return self.pool.connection()

    def execute(self, sql, params=tuple()):
        """Execute one SQL statement, commit, and return the fetched rows.

        Args:
            sql (str): SQL statement, optionally containing ``%s`` placeholders.
            params (tuple): Values bound to the placeholders.

        Returns:
            The result of ``cursor.fetchall()``; a ``DictCursor`` is used, so
            each row is keyed by column name.

        Raises:
            Exception: Any error raised while executing is logged, the
                transaction is rolled back, and the error is re-raised.
        """
        with self.get_conn() as conn:
            with conn.cursor(cursor=DictCursor) as cursor:
                try:
                    cursor.execute(sql, params)
                    # Log the statement exactly as it was sent to the server.
                    trans_print("开始执行SQL:", cursor._executed)
                    conn.commit()
                    return cursor.fetchall()
                except Exception as e:
                    trans_print(f"执行sql:{sql},报错:{e}")
                    conn.rollback()
                    raise

    def save_dict(self, table_name: str, params: dict):
        """Write one row with ``REPLACE INTO``.

        The table name and the dict keys are interpolated into the SQL text
        as-is, so they must come from trusted code; only the values are bound
        as parameters.

        Args:
            table_name: Target table.
            params: Mapping of column name to value.

        Raises:
            Exception: Any error raised while executing is logged, the
                transaction is rolled back, and the error is re-raised.
        """
        col_str = ",".join(params)
        data_s_str = ",".join(["%s"] * len(params))
        insert_sql = f"replace into {table_name} ({col_str}) values ({data_s_str})"
        with self.get_conn() as conn:
            with conn.cursor() as cursor:
                try:
                    cursor.execute(insert_sql, tuple(params.values()))
                    conn.commit()
                except Exception as e:
                    trans_print(f"执行sql:{insert_sql},报错:{e}")
                    conn.rollback()
                    raise

    @staticmethod
    def _iter_batches(df: DataFrame, batch_count: int):
        """Yield ``(rows_done, rows)`` for consecutive slices of ``df``.

        ``rows`` is a list of plain tuples in which every missing value
        (NaN/NaT/None) has been replaced by ``None`` so the driver sends SQL
        NULL. The conversion works on a per-batch copy, so ``df`` itself is
        never modified. ``rows_done`` is the cumulative row count, capped at
        ``len(df)``. No batch is yielded for an empty frame.

        Raises:
            ValueError: If ``batch_count`` is smaller than 1.
        """
        if batch_count < 1:
            raise ValueError("batch_count must be a positive integer")
        total_count = len(df)
        for start in range(0, total_count, batch_count):
            chunk = df.iloc[start:start + batch_count]
            # Cast to object first so that None survives in numeric columns.
            chunk = chunk.astype(object).where(chunk.notna(), None)
            rows = list(chunk.itertuples(index=False, name=None))
            yield min(start + batch_count, total_count), rows

    def df_batch_save(self, table_name: str, df: DataFrame, batch_count=20000):
        """Insert all rows of ``df`` into ``table_name`` in batches.

        Each batch is sent with ``executemany`` on its own pooled connection
        and committed separately, so batches committed before a failure stay
        in the table. Column names are taken from ``df.columns`` and, like the
        table name, are interpolated into the SQL text as-is.

        Args:
            table_name: Target table.
            df: Data to insert; it is not modified.
            batch_count: Maximum number of rows per batch.

        Raises:
            ValueError: If ``batch_count`` is smaller than 1.
            Exception: Any error raised while inserting a batch; that batch is
                rolled back and the error is re-raised.
        """
        col_str = ",".join(df.columns)
        data_s_str = ",".join(["%s"] * len(df.columns))
        insert_sql = f"INSERT INTO `{table_name}` ({col_str}) values ({data_s_str})"
        total_count = df.shape[0]
        for saved_count, rows in self._iter_batches(df, batch_count):
            with self.get_conn() as conn:
                with conn.cursor() as cursor:
                    try:
                        cursor.executemany(insert_sql, rows)
                        conn.commit()
                    except Exception:
                        conn.rollback()
                        raise
            trans_print(
                "总条数" + str(total_count) + ",已保存:" + str(saved_count))
if __name__ == '__main__':
    # Manual smoke test: needs the "plt" and "trans" entries in mysql_config
    # and reachable databases behind them.
    plt_pool = ConnectMysqlPool("plt")
    sample_rows = plt_pool.execute("select * from data_transfer limit 2")
    print(sample_rows)

    # Insert 1000 generated names into `test`, 33 rows per batch.
    trans_pool = ConnectMysqlPool("trans")
    demo_df = pd.DataFrame()
    demo_df['name'] = [f"name{idx}" for idx in range(1000)]
    save_result = trans_pool.df_batch_save('test', demo_df, 33)
    print(save_result)
|