ConnectMysqlPool.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. import numpy as np
  2. import pandas as pd
  3. import pymysql
  4. from dbutils.pooled_db import PooledDB
  5. import os
  6. from pandas import DataFrame
  7. from pymysql.cursors import DictCursor
  8. from utils.conf.read_conf import yaml_conf
  9. from utils.log.trans_log import trans_print
  10. class ConnectMysqlPool:
  11. """
  12. 连接MySQL数据库的连接池类。
  13. 属性:
  14. db_account (dict): 数据库账号信息,包括用户名和密码等。
  15. db (str): 数据库名称。
  16. pool (PooledDB): MySQL连接池对象。
  17. 方法:
  18. __init__: 初始化连接池类实例。
  19. _obtaining_data: 从配置文件中获取测试数据。
  20. create_mysql_pool: 创建MySQL连接池。
  21. get_conn: 从连接池中获取一个连接。
  22. close: 关闭数据库连接和游标。
  23. execute: 使用连接执行SQL语句。
  24. """
  25. def __init__(self, connet_name):
  26. """
  27. 初始化连接池类实例。
  28. 参数:
  29. db (str): 测试库名称。
  30. db_account (dict): 包含数据库账号信息的字典。
  31. """
  32. file_path = os.path.join(
  33. os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
  34. "conf",
  35. "db.yaml"
  36. )
  37. self.yaml_data = yaml_conf(file_path)
  38. self.connet_name = connet_name
  39. # 创建连接池
  40. self.pool = self.create_mysql_pool()
  41. # 创建MySQL连接池
  42. def create_mysql_pool(self):
  43. """
  44. 根据配置信息创建MySQL连接池。
  45. 返回:
  46. PooledDB: MySQL连接池对象。
  47. """
  48. pool = PooledDB(
  49. **self.yaml_data[self.connet_name + '_connect_pool_config'],
  50. **self.yaml_data[self.connet_name],
  51. creator=pymysql
  52. )
  53. return pool
  54. # 从连接池中获取一个连接
  55. def get_conn(self):
  56. """
  57. 从连接池中获取一个数据库连接。
  58. 返回:
  59. connection: 数据库连接对象。
  60. """
  61. return self.pool.connection()
  62. # 使用连接执行sql
  63. def execute(self, sql, params=tuple()):
  64. """
  65. 使用获取的连接执行SQL语句。
  66. 参数:
  67. sql (str): SQL语句。
  68. params (tuple): SQL参数。
  69. 返回:
  70. list: 执行SQL语句后的结果集,若执行出错则返回None。
  71. """
  72. with self.get_conn() as conn:
  73. with conn.cursor(cursor=DictCursor) as cursor:
  74. try:
  75. cursor.execute(sql, params)
  76. trans_print("开始执行SQL:", cursor._executed)
  77. conn.commit()
  78. result = cursor.fetchall()
  79. return result
  80. except Exception as e:
  81. print(f"执行sql:{sql},报错:{e}")
  82. conn.rollback()
  83. raise e
  84. def save_dict(self, table_name: str, params: dict):
  85. keys = params.keys()
  86. col_str = ",".join(keys)
  87. data_s_str = ",".join(["%s"] * len(keys))
  88. insert_sql = f"replace into {table_name} ({col_str}) values ({data_s_str})"
  89. with self.get_conn() as conn:
  90. with conn.cursor() as cursor:
  91. try:
  92. cursor.execute(insert_sql, tuple(params.values()))
  93. conn.commit()
  94. except Exception as e:
  95. print(f"执行sql:{insert_sql},报错:{e}")
  96. conn.rollback()
  97. raise e
  98. # 使用连接执行sql
  99. def df_batch_save(self, table_name: str, df: DataFrame, batch_count=20000):
  100. col_str = ",".join(df.columns)
  101. data_s_str = ",".join(["%s"] * len(df.columns))
  102. insert_sql = f"INSERT INTO `{table_name}` ({col_str}) values ({data_s_str})"
  103. # 转化nan到null
  104. df.replace(np.nan, None, inplace=True)
  105. total_count = df.shape[0]
  106. for i in range(0, total_count + 1, batch_count):
  107. with self.get_conn() as conn:
  108. with conn.cursor() as cursor:
  109. try:
  110. query_df = df.iloc[i:i + batch_count]
  111. values = [tuple(data) for data in query_df.values]
  112. cursor.executemany(insert_sql, values)
  113. conn.commit()
  114. result = cursor.fetchall()
  115. print(
  116. "总条数" + str(df.shape[0]) + ",已保存:" + str(i + batch_count))
  117. except Exception as e:
  118. conn.rollback()
  119. raise e
  120. if __name__ == '__main__':
  121. plt = ConnectMysqlPool("plt")
  122. print(plt.execute("select * from data_transfer limit 2"))
  123. trans = ConnectMysqlPool("trans")
  124. df = pd.DataFrame()
  125. df['name'] = ['name' + str(i) for i in range(1000)]
  126. print(trans.df_batch_save('test', df, 33))