ConnectMysqlPool.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. import numpy as np
  2. import pandas as pd
  3. import pymysql
  4. from dbutils.pooled_db import PooledDB
  5. import os
  6. from pandas import DataFrame
  7. from pymysql.cursors import DictCursor
  8. from conf.db import mysql_config
  9. from utils.log.trans_log import trans_print
  10. class ConnectMysqlPool:
  11. """
  12. 连接MySQL数据库的连接池类。
  13. 属性:
  14. db_account (dict): 数据库账号信息,包括用户名和密码等。
  15. db (str): 数据库名称。
  16. pool (PooledDB): MySQL连接池对象。
  17. 方法:
  18. __init__: 初始化连接池类实例。
  19. _obtaining_data: 从配置文件中获取测试数据。
  20. create_mysql_pool: 创建MySQL连接池。
  21. get_conn: 从连接池中获取一个连接。
  22. close: 关闭数据库连接和游标。
  23. execute: 使用连接执行SQL语句。
  24. """
  25. def __init__(self, connet_name):
  26. """
  27. 初始化连接池类实例。
  28. 参数:
  29. db (str): 测试库名称。
  30. db_account (dict): 包含数据库账号信息的字典。
  31. """
  32. self.yaml_data = mysql_config
  33. self.connet_name = connet_name
  34. self.env = os.environ['env']
  35. print("self.env", self.env)
  36. # 创建连接池
  37. self.pool = self.create_mysql_pool()
  38. # 创建MySQL连接池
  39. def create_mysql_pool(self):
  40. """
  41. 根据配置信息创建MySQL连接池。
  42. 返回:
  43. PooledDB: MySQL连接池对象。
  44. """
  45. pool = PooledDB(
  46. **self.yaml_data[self.connet_name + '_connect_pool_config'],
  47. **self.yaml_data[self.connet_name + "_" + self.env],
  48. ping=2,
  49. creator=pymysql
  50. )
  51. return pool
  52. # 从连接池中获取一个连接
  53. def get_conn(self):
  54. """
  55. 从连接池中获取一个数据库连接。
  56. 返回:
  57. connection: 数据库连接对象。
  58. """
  59. return self.pool.connection()
  60. # 使用连接执行sql
  61. def execute(self, sql, params=tuple()):
  62. """
  63. 使用获取的连接执行SQL语句。
  64. 参数:
  65. sql (str): SQL语句。
  66. params (tuple): SQL参数。
  67. 返回:
  68. list: 执行SQL语句后的结果集,若执行出错则返回None。
  69. """
  70. with self.get_conn() as conn:
  71. with conn.cursor(cursor=DictCursor) as cursor:
  72. try:
  73. cursor.execute(sql, params)
  74. trans_print("开始执行SQL:", cursor._executed)
  75. conn.commit()
  76. result = cursor.fetchall()
  77. return result
  78. except Exception as e:
  79. trans_print(f"执行sql:{sql},报错:{e}")
  80. conn.rollback()
  81. raise e
  82. def save_dict(self, table_name: str, params: dict):
  83. keys = params.keys()
  84. col_str = ",".join(keys)
  85. data_s_str = ",".join(["%s"] * len(keys))
  86. insert_sql = f"replace into {table_name} ({col_str}) values ({data_s_str})"
  87. with self.get_conn() as conn:
  88. with conn.cursor() as cursor:
  89. try:
  90. cursor.execute(insert_sql, tuple(params.values()))
  91. conn.commit()
  92. except Exception as e:
  93. trans_print(f"执行sql:{insert_sql},报错:{e}")
  94. conn.rollback()
  95. raise e
  96. # 使用连接执行sql
  97. def df_batch_save(self, table_name: str, df: DataFrame, batch_count=20000):
  98. col_str = ",".join(df.columns)
  99. data_s_str = ",".join(["%s"] * len(df.columns))
  100. insert_sql = f"INSERT INTO `{table_name}` ({col_str}) values ({data_s_str})"
  101. # 转化nan到null
  102. df.replace(np.nan, None, inplace=True)
  103. total_count = df.shape[0]
  104. for i in range(0, total_count + 1, batch_count):
  105. with self.get_conn() as conn:
  106. with conn.cursor() as cursor:
  107. try:
  108. query_df = df.iloc[i:i + batch_count]
  109. if not query_df.empty:
  110. values = [tuple(data) for data in query_df.values]
  111. cursor.executemany(insert_sql, values)
  112. conn.commit()
  113. result = cursor.fetchall()
  114. trans_print(
  115. "总条数" + str(df.shape[0]) + ",已保存:" + str(i + batch_count))
  116. except Exception as e:
  117. conn.rollback()
  118. raise e
  119. if __name__ == '__main__':
  120. plt = ConnectMysqlPool("plt")
  121. print(plt.execute("select * from data_transfer limit 2"))
  122. trans = ConnectMysqlPool("trans")
  123. df = pd.DataFrame()
  124. df['name'] = ['name' + str(i) for i in range(1000)]
  125. print(trans.df_batch_save('test', df, 33))