import functools
import logging
from contextlib import contextmanager

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import Session
from sqlalchemy.sql import text

from utils.rdbmsUtil.factoryRegistry import FactoryRegistry

# NOTE(review): declarative_base, Column, Integer, String and Session are not
# used in this module; confirm no other module imports them from here before
# removing them.

logger = logging.getLogger(__name__)


class DatabaseUtil:
    """Own a SQLAlchemy engine and a session factory bound to it."""

    def __init__(self, url, pool_size=5, max_overflow=10, pool_recycle=1000,
                 pool_pre_ping=True, connect_timeout=10, read_timeout=120):
        """Create the pooled engine and the bound ``sessionmaker``.

        Args:
            url: database URL passed to ``create_engine``.
            pool_size: number of connections kept in the pool.
            max_overflow: extra connections allowed beyond ``pool_size``.
            pool_recycle: seconds after which a pooled connection is recycled.
            pool_pre_ping: test a connection's liveness before handing it out.
            connect_timeout: connection timeout in seconds (driver connect arg).
            read_timeout: read timeout in seconds (driver connect arg).
        """
        self.engine = create_engine(
            url,
            pool_size=pool_size,  # pool size
            max_overflow=max_overflow,  # max overflow connections
            pool_recycle=pool_recycle,  # connection recycle time
            pool_pre_ping=pool_pre_ping,  # check connection availability
            # connect_args go straight to the DBAPI driver's connect().
            # NOTE(review): 'read_timeout' is driver-specific (e.g. PyMySQL);
            # confirm the configured driver accepts it.
            connect_args={
                'connect_timeout': connect_timeout,  # connect timeout (seconds)
                'read_timeout': read_timeout,  # read timeout (seconds)
            },
        )
        self.Session = sessionmaker(bind=self.engine)

    def get_session(self):
        """Return a new session from the bound session factory."""
        return self.Session()

    @contextmanager
    def session_scope(self):
        """Provide a transactional scope around a series of operations.

        Yields a fresh session. On normal exit the session is committed; if
        the body raises an ``Exception`` it is rolled back and the exception
        is re-raised. The session is always closed afterwards.
        """
        session = self.get_session()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    def close_session(self, session):
        """Explicitly close a given session; ``None`` is ignored."""
        if session is not None:
            session.close()


def sql_operation(sql, entity_type, batch=False):
    """Build a decorator that turns a method into a SQL execution.

    The body of the decorated function is never run. The wrapper has the
    signature ``(self, session, *args, **kwargs)`` and executes ``sql`` on
    ``session`` with bound parameters, which guards against SQL injection.

    Args:
        sql: SQL string with named bind parameters (``:name``), or a callable
            that is called with the wrapper's ``*args`` / ``**kwargs`` and
            returns such a string.
        entity_type: key used to look up a row factory via
            ``FactoryRegistry.get_factory``.
        batch: if True, ``args[0]`` must be an iterable of entity objects.
            The statement runs once per entity, with the entity's public
            instance attributes as bind parameters, and no rows are collected.

    Returns:
        A decorator. The wrapped method returns a list of factory-built
        objects when a single (non-batch) statement returns rows, and an
        empty list otherwise.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, session, *args, **kwargs):
            factory = FactoryRegistry.get_factory(entity_type)
            # A callable sql builds the statement dynamically from the call's arguments.
            resolved_sql = sql(*args, **kwargs) if callable(sql) else sql

            # Debug output is lazy and goes to the logger instead of stdout.
            logger.debug(
                "sql_operation self=%r session=%r (%s) sql=%s args=%r kwargs=%r",
                self, session, type(session), resolved_sql, args, kwargs,
            )

            statement = text(resolved_sql)  # build once, reuse for every execution
            results = []

            if batch:
                for entity in args[0]:
                    # Bind the entity's public attributes as parameters (never string-format SQL).
                    params = {
                        key: getattr(entity, key)
                        for key in entity.__dict__
                        if not key.startswith('_')
                    }
                    result = session.execute(statement, params)
                    logger.debug("sql_operation batch result: %r", result)
                return results

            # Single operation: keyword arguments are the bind parameters.
            result = session.execute(statement, kwargs)
            if result.returns_rows:
                results.extend(factory(**dict(row)) for row in result.mappings())
            return results

        return wrapper

    return decorator