12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697 |
- from sqlalchemy import create_engine
- from sqlalchemy.orm import sessionmaker, declarative_base
- from sqlalchemy import Column, Integer, String
- from sqlalchemy.orm import Session
- from sqlalchemy.sql import text
- from contextlib import contextmanager
- from utils.rdbmsUtil.factoryRegistry import FactoryRegistry
# Database utility class
class DatabaseUtil:
    """Own a SQLAlchemy engine and a session factory bound to it.

    Args:
        url: Database URL handed to ``create_engine``.
        pool_size: Forwarded as the engine's ``pool_size``.
        max_overflow: Forwarded as the engine's ``max_overflow``.
        pool_recycle: Forwarded as the engine's ``pool_recycle``.
        pool_pre_ping: Forwarded as the engine's ``pool_pre_ping``.
        connect_timeout: Connection timeout in seconds, sent to the DBAPI
            driver through ``connect_args``. Pass ``None`` to leave the key
            out for drivers that do not accept it.
        read_timeout: Read timeout in seconds, sent to the DBAPI driver
            through ``connect_args``. Pass ``None`` to leave the key out for
            drivers that do not accept it.
    """

    def __init__(self, url, pool_size=5, max_overflow=10, pool_recycle=1000,
                 pool_pre_ping=True, connect_timeout=10, read_timeout=120):
        timeouts = {
            'connect_timeout': connect_timeout,
            'read_timeout': read_timeout,
        }
        # These keywords go straight to the driver's connect(); a driver that
        # does not know one of them rejects the call, so None means "omit".
        connect_args = {
            key: value for key, value in timeouts.items() if value is not None
        }
        self.engine = create_engine(
            url,
            pool_size=pool_size,
            max_overflow=max_overflow,
            pool_recycle=pool_recycle,
            pool_pre_ping=pool_pre_ping,
            connect_args=connect_args,
        )
        self.Session = sessionmaker(bind=self.engine)

    def get_session(self):
        """Return a new session from the factory; the caller must close it."""
        return self.Session()

    @contextmanager
    def session_scope(self):
        """Provide a transactional scope around a series of operations.

        Yields a new session. The session is committed when the ``with`` body
        finishes normally, rolled back (and the exception re-raised) when the
        body raises an ``Exception``, and closed in every case.
        """
        session = self.get_session()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    def close_session(self, session):
        """Explicitly close a given session; a falsy value is ignored."""
        if session:
            session.close()
# Transaction management (disabled): a variant of DatabaseUtil.session_scope
# that takes an existing session instead of creating one.
# @contextmanager
# def transaction_scope(session: Session):
#     try:
#         yield session
#         session.commit()
#     except Exception as e:
#         session.rollback()
#         raise e
#     finally:
#         session.close()
# Decorator that adds support for dynamic SQL and batch operations
def sql_operation(sql, entity_type, batch=False):
    """Build a decorator that runs a SQL statement in place of a method body.

    The decorated function is used only for its metadata (name, docstring);
    its body is never called. The returned wrapper has the signature
    ``wrapper(self, session, *args, **kwargs)``.

    Args:
        sql: Either a SQL string, or a callable invoked as
            ``sql(*args, **kwargs)`` that returns the SQL string.
        entity_type: Key passed to ``FactoryRegistry.get_factory``; the
            factory it returns builds one result object per returned row.
        batch: When true, ``args[0]`` is iterated and the statement is
            executed once per entity, binding the entity's public (non
            underscore-prefixed) instance attributes as parameters. When
            false, the statement is executed once with ``kwargs`` as the
            bound parameters.

    Returns:
        A decorator. The wrapper it produces returns a list: one
        ``factory(**row)`` object per row for a non-batch statement that
        returns rows, otherwise an empty list.

    Raises:
        IndexError: From the wrapper, in batch mode, when no positional
            argument follows ``session``.
    """
    # Imported here so the block carries its own dependencies.
    import functools
    import logging

    logger = logging.getLogger(__name__)

    def decorator(func):
        @functools.wraps(func)  # keep the decorated method's name/docstring
        def wrapper(self, session, *args, **kwargs):
            factory = FactoryRegistry.get_factory(entity_type)
            # A callable builds the SQL string dynamically from the call
            # arguments; anything else is used as the SQL string itself.
            resolved_sql = sql(*args, **kwargs) if callable(sql) else sql
            logger.debug(
                "sql_operation self=%r session=%r (%s) sql=%s args=%r kwargs=%r",
                self, session, type(session), resolved_sql, args, kwargs,
            )
            statement = text(resolved_sql)

            if batch:
                for entity in args[0]:
                    # Bind values as parameters rather than formatting them
                    # into the SQL, to prevent injection.
                    params = {
                        key: getattr(entity, key)
                        for key in entity.__dict__
                        if not key.startswith('_')
                    }
                    outcome = session.execute(statement, params)
                    logger.debug("batch statement result: %r", outcome)
                return []

            result = session.execute(statement, kwargs)
            if not result.returns_rows:
                return []
            return [factory(**dict(row)) for row in result.mappings()]

        return wrapper

    return decorator
|