databaseUtil.py 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. from sqlalchemy import create_engine
  2. from sqlalchemy.orm import sessionmaker, declarative_base
  3. from sqlalchemy import Column, Integer, String
  4. from sqlalchemy.orm import Session
  5. from sqlalchemy.sql import text
  6. from contextlib import contextmanager
  7. from utils.rdbmsUtil.factoryRegistry import FactoryRegistry
  8. # 数据库类
  9. class DatabaseUtil:
  10. def __init__(self, url, pool_size=5, max_overflow=10, pool_recycle=1000, pool_pre_ping=True, connect_timeout=10,read_timeout=120):
  11. self.engine = create_engine(
  12. url,
  13. pool_size=pool_size, # 设置最大连接池大小
  14. max_overflow=max_overflow, # 设置最大溢出连接数
  15. pool_recycle=pool_recycle, # 设置连接回收时间
  16. pool_pre_ping=pool_pre_ping, # 检查连接的可用性
  17. connect_args={
  18. 'connect_timeout': connect_timeout, # 连接超时时间(秒)
  19. 'read_timeout': read_timeout # 读取超时时间(秒)
  20. } # 增加连接超时设置
  21. )
  22. self.Session = sessionmaker(bind=self.engine)
  23. def get_session(self):
  24. return self.Session()
  25. @contextmanager
  26. def session_scope(self):
  27. """Provide a transactional scope around a series of operations."""
  28. session = self.get_session()
  29. try:
  30. yield session
  31. session.commit()
  32. except Exception as e:
  33. session.rollback()
  34. raise
  35. finally:
  36. session.close()
  37. def close_session(self, session):
  38. """Explicitly close a given session."""
  39. if session:
  40. session.close()
# # Transaction management
# # (Commented out; DatabaseUtil.session_scope provides the same
# # commit / rollback / close pattern.)
# @contextmanager
# def transaction_scope(session:Session):
#     try:
#         yield session
#         session.commit()
#     except Exception as e:
#         session.rollback()
#         raise e
#     finally:
#         session.close()
  52. # 实现动态 SQL 和批量操作支持的装饰器
  53. def sql_operation(sql, entity_type, batch=False):
  54. def decorator(func):
  55. def wrapper(self, session, *args, **kwargs):
  56. factory = FactoryRegistry.get_factory(entity_type)
  57. if callable(sql): # 检查sql是否为可调用对象,即函数
  58. resolved_sql = sql(*args, **kwargs) # 如果是函数,执行它以获取SQL字符串
  59. else:
  60. resolved_sql = sql # 否则直接使用sql字符串
  61. print(f"Self: {self}")
  62. print(f"Session: {session}")
  63. print(f"Session type: {type(session)}")
  64. print(f"SQL: {resolved_sql}")
  65. print(f"Args: {args}")
  66. print(f"Kwargs: {kwargs}")
  67. results = []
  68. if batch:
  69. for entity in args[0]:
  70. # 使用参数直接执行SQL,防止注入
  71. params = {key: getattr(
  72. entity, key) for key in entity.__dict__ if not key.startswith('_')}
  73. r = session.execute(text(resolved_sql), params)
  74. print(r)
  75. else:
  76. # 单个操作的情况
  77. result = session.execute(text(resolved_sql), kwargs)
  78. if result.returns_rows:
  79. results.extend([factory(**dict(row))
  80. for row in result.mappings()])
  81. else:
  82. # 如果不需要处理返回结果,这里可以直接pass
  83. pass
  84. return results
  85. return wrapper
  86. return decorator