ConnectMysqlPool.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. import numpy as np
  2. import pandas as pd
  3. import pymysql
  4. from dbutils.pooled_db import PooledDB
  5. import os
  6. from pandas import DataFrame
  7. from pymysql.cursors import DictCursor
  8. from conf.db import mysql_config
  9. from utils.log.trans_log import trans_print
  10. class ConnectMysqlPool:
  11. """
  12. 连接MySQL数据库的连接池类。
  13. 属性:
  14. db_account (dict): 数据库账号信息,包括用户名和密码等。
  15. db (str): 数据库名称。
  16. pool (PooledDB): MySQL连接池对象。
  17. 方法:
  18. __init__: 初始化连接池类实例。
  19. _obtaining_data: 从配置文件中获取测试数据。
  20. create_mysql_pool: 创建MySQL连接池。
  21. get_conn: 从连接池中获取一个连接。
  22. close: 关闭数据库连接和游标。
  23. execute: 使用连接执行SQL语句。
  24. """
  25. def __init__(self, connet_name):
  26. """
  27. 初始化连接池类实例。
  28. 参数:
  29. db (str): 测试库名称。
  30. db_account (dict): 包含数据库账号信息的字典。
  31. """
  32. file_path = os.path.join(
  33. os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
  34. "conf",
  35. "db.yaml"
  36. )
  37. self.yaml_data = mysql_config
  38. self.connet_name = connet_name
  39. # 创建连接池
  40. self.pool = self.create_mysql_pool()
  41. # 创建MySQL连接池
  42. def create_mysql_pool(self):
  43. """
  44. 根据配置信息创建MySQL连接池。
  45. 返回:
  46. PooledDB: MySQL连接池对象。
  47. """
  48. pool = PooledDB(
  49. **self.yaml_data[self.connet_name + '_connect_pool_config'],
  50. **self.yaml_data[self.connet_name],
  51. ping=2,
  52. creator=pymysql
  53. )
  54. return pool
  55. # 从连接池中获取一个连接
  56. def get_conn(self):
  57. """
  58. 从连接池中获取一个数据库连接。
  59. 返回:
  60. connection: 数据库连接对象。
  61. """
  62. return self.pool.connection()
  63. # 使用连接执行sql
  64. def execute(self, sql, params=tuple()):
  65. """
  66. 使用获取的连接执行SQL语句。
  67. 参数:
  68. sql (str): SQL语句。
  69. params (tuple): SQL参数。
  70. 返回:
  71. list: 执行SQL语句后的结果集,若执行出错则返回None。
  72. """
  73. with self.get_conn() as conn:
  74. with conn.cursor(cursor=DictCursor) as cursor:
  75. try:
  76. cursor.execute(sql, params)
  77. trans_print("开始执行SQL:", cursor._executed)
  78. conn.commit()
  79. result = cursor.fetchall()
  80. return result
  81. except Exception as e:
  82. trans_print(f"执行sql:{sql},报错:{e}")
  83. conn.rollback()
  84. raise e
  85. def save_dict(self, table_name: str, params: dict):
  86. keys = params.keys()
  87. col_str = ",".join(keys)
  88. data_s_str = ",".join(["%s"] * len(keys))
  89. insert_sql = f"replace into {table_name} ({col_str}) values ({data_s_str})"
  90. with self.get_conn() as conn:
  91. with conn.cursor() as cursor:
  92. try:
  93. cursor.execute(insert_sql, tuple(params.values()))
  94. conn.commit()
  95. except Exception as e:
  96. trans_print(f"执行sql:{insert_sql},报错:{e}")
  97. conn.rollback()
  98. raise e
  99. # 使用连接执行sql
  100. def df_batch_save(self, table_name: str, df: DataFrame, batch_count=20000):
  101. col_str = ",".join(df.columns)
  102. data_s_str = ",".join(["%s"] * len(df.columns))
  103. insert_sql = f"INSERT INTO `{table_name}` ({col_str}) values ({data_s_str})"
  104. # 转化nan到null
  105. df.replace(np.nan, None, inplace=True)
  106. total_count = df.shape[0]
  107. for i in range(0, total_count + 1, batch_count):
  108. with self.get_conn() as conn:
  109. with conn.cursor() as cursor:
  110. try:
  111. query_df = df.iloc[i:i + batch_count]
  112. if not query_df.empty:
  113. values = [tuple(data) for data in query_df.values]
  114. cursor.executemany(insert_sql, values)
  115. conn.commit()
  116. result = cursor.fetchall()
  117. trans_print(
  118. "总条数" + str(df.shape[0]) + ",已保存:" + str(i + batch_count))
  119. except Exception as e:
  120. conn.rollback()
  121. raise e
  122. if __name__ == '__main__':
  123. plt = ConnectMysqlPool("plt")
  124. print(plt.execute("select * from data_transfer limit 2"))
  125. trans = ConnectMysqlPool("trans")
  126. df = pd.DataFrame()
  127. df['name'] = ['name' + str(i) for i in range(1000)]
  128. print(trans.df_batch_save('test', df, 33))