```python
import os
import tempfile
import time
import traceback
from urllib.parse import quote_plus

import pandas as pd
import pymysql
from pymysql.cursors import DictCursor
from sqlalchemy import create_engine, text

from utils.conf.read_conf import yaml_conf
from utils.log.trans_log import logger
class ConnectMysql:

    def __init__(self, connect_name):
        # Resolve <project root>/conf/config.yaml relative to this file
        conf_path = os.path.abspath(__file__).split("utils")[0] + 'conf' + os.sep + 'config.yaml'
        self.yaml_data = yaml_conf(conf_path)
        self.connect_name = connect_name
        self.config = self.yaml_data[self.connect_name]
        self.database = self.config['database']
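    # Expected config.yaml layout (an assumption -- the file itself is not
    # shown in the source). Each top-level key is a connection name whose
    # values are passed straight to pymysql.connect():
    #
    #   <connect_name>:
    #     host: 127.0.0.1
    #     port: 4000
    #     user: root
    #     password: "secret"
    #     database: mydb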
    # Open a new database connection using the settings from config.yaml
    # (despite what the original comment said, pymysql.connect() creates a
    # fresh connection every time; there is no pooling here)
    def get_conn(self):
        return pymysql.connect(**self.config, local_infile=True)

    # Execute a single statement on its own connection and return all rows
    def execute(self, sql, params=()):
        with self.get_conn() as conn:
            with conn.cursor(cursor=DictCursor) as cursor:
                try:
                    logger.info(f"Executing SQL: {cursor.mogrify(sql, params)}")
                    cursor.execute(sql, params)
                    conn.commit()
                    result = cursor.fetchall()
                    return result
                except Exception as e:
                    logger.error(f"SQL {sql} failed: {e}")
                    logger.error(traceback.format_exc())
                    conn.rollback()
                    raise
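    # Example call (hypothetical table name; DictCursor makes each row a dict):
    #   rows = db.execute("SELECT * FROM wind_data WHERE id = %s", (1,))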
    def get_engine(self):
        config = self.config
        username = config['user']
        # URL-encode the password so characters like @, :, / don't break the DSN
        password = quote_plus(config['password'])
        host = config['host']
        port = config['port']
        dbname = config['database']
        return create_engine(
            f'mysql+pymysql://{username}:{password}@{host}:{port}/{dbname}?local_infile=1')
    # Append a DataFrame to a table, writing batch_count rows per round trip
    def execute_df_save(self, df, table_name, batch_count=1000):
        df.to_sql(table_name, self.get_engine(), index=False, if_exists='append', chunksize=batch_count)

    # Run a query and return the result set as a DataFrame
    def read_sql_to_df(self, sql):
        df = pd.read_sql_query(sql, self.get_engine())
        return df
    def safe_load_data_local(self, df, table_name, batch_size=30000):
        """
        Safely load a DataFrame into TiDB, with the following safeguards:
        1. Batched writes to avoid exhausting memory
        2. Proper connection management
        3. Error handling with a retry mechanism
        """
        total_rows = len(df)
        success_rows = 0
        engine = self.get_engine()
        for i in range(0, total_rows, batch_size):
            batch = df.iloc[i:i + batch_size]
            retry_count = 0
            max_retries = 4
            while retry_count < max_retries:
                try:
                    with tempfile.NamedTemporaryFile(mode='w') as tmp:
                        batch.to_csv(tmp, index=False, header=False, sep='\t')
                        tmp.flush()
                        with engine.begin() as conn:  # commits automatically on success
                            # Set the memory quota for the current session (2 GB)
                            conn.execute(text("SET tidb_mem_quota_query = 2147483648"))
                            # Run LOAD DATA; \t and \n are escaped so that MySQL,
                            # not Python, interprets the terminator sequences
                            conn.execute(text(f"""
                                LOAD DATA LOCAL INFILE '{tmp.name}'
                                INTO TABLE {table_name}
                                FIELDS TERMINATED BY '\\t'
                                LINES TERMINATED BY '\\n'
                            """))
                    success_rows += len(batch)
                    logger.info(f"Loaded batch {i // batch_size + 1}: {len(batch)} rows")
                    break  # success: leave the retry loop
                except Exception as e:
                    retry_count += 1
                    logger.warning(f"Batch {i // batch_size + 1}, attempt {retry_count} failed: {e}")
                    if retry_count >= max_retries:
                        logger.error(f"Batch {i // batch_size + 1} reached the retry limit")
                        logger.error(traceback.format_exc())
                        raise
                    time.sleep(2 ** retry_count)  # exponential backoff
        logger.info(f"Data load complete: {success_rows}/{total_rows} rows in total")
        return success_rows
```
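A minimal usage sketch, assuming `config.yaml` defines a connection named `tidb_prod` and that the tables exist (both names are placeholders, not from the source):

```python
# Hypothetical names: 'tidb_prod' must match a key in config.yaml,
# and 'wind_data' / 'wind_data_copy' stand in for real tables.
db = ConnectMysql('tidb_prod')

# Read a query result into a DataFrame
df = db.read_sql_to_df("SELECT * FROM wind_data LIMIT 1000")

# Bulk-load it into another table via LOAD DATA LOCAL INFILE,
# in batches with retries and exponential backoff
db.safe_load_data_local(df, 'wind_data_copy', batch_size=30000)
```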