You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
217 lines
6.9 KiB
217 lines
6.9 KiB
import traceback
|
|
import paginate_sqlalchemy
|
|
from sqlalchemy import create_engine, desc
|
|
from app.conf.Setting import settings
|
|
from sqlalchemy.orm import sessionmaker
|
|
from sqlalchemy.sql import func
|
|
from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList
|
|
from app.lib.Log import logger
|
|
|
|
class OrmBase(object):
    """Generic SQLAlchemy data-access helper for a single mapped entity type.

    Every public method runs one ORM operation.  Outside a transaction each
    call opens its own session from ``DBSession``, commits write operations,
    and closes the session again, so objects returned by the ``find*``
    methods are detached from their (closed) session.

    Between ``beginTrans()`` and ``commitTrans()`` all calls share the session
    opened by ``beginTrans``: writes stay uncommitted (and the session stays
    open) until ``commitTrans()`` runs.  If any call fails inside a
    transaction, the session is rolled back and closed and the instance
    leaves transaction mode.
    """

    # Database engine shared by every OrmBase instance; created once, when the
    # class body executes (i.e. at module import).
    engine = create_engine(settings.SQLALCHEMY_DATABASE_URI, connect_args={'check_same_thread': False}, pool_size=100)

    # Session factory bound to the shared engine.
    DBSession = sessionmaker(bind=engine, autoflush=False, autocommit=False, expire_on_commit=True)

    # Class-level default session; each operation rebinds ``self.session`` on
    # the instance.
    session = DBSession()

    # True between beginTrans() and commitTrans() (transaction mode).
    transFlag = False

    def __init__(self, entityType):
        """:param entityType: the mapped ORM class this helper queries and modifies."""
        super().__init__()
        # Mapped class used by the query/delete/count helpers.
        self.entityType = entityType

    def _run_in_session(self, operation, commit=False):
        """Run ``operation(session)`` under this object's session rules.

        Outside a transaction a fresh session is opened, committed after a
        successful ``operation`` when *commit* is true, and always closed.
        Inside a transaction the transaction's session is reused and left
        open so that ``commitTrans()`` can commit all pending work.

        :param operation: callable receiving the session to use; its return
            value is returned by this method.
        :param commit: commit after ``operation`` succeeds (ignored inside a
            transaction, where ``commitTrans()`` commits instead).
        :raises Exception: whatever ``operation`` or the commit raised, after
            the traceback is logged and the session is rolled back.
        """
        in_trans = self.transFlag
        if not in_trans:
            self.session = self.DBSession()
        session = self.session
        succeeded = False
        try:
            result = operation(session)
            if commit and not in_trans:
                session.commit()
            succeeded = True
            return result
        except Exception:
            logger.error(traceback.format_exc())
            if in_trans:
                # The failure aborts the transaction; leave transaction mode so
                # later calls commit on their own instead of waiting forever
                # for a commitTrans() that the caller may never reach.
                self.transFlag = False
            session.rollback()
            raise
        finally:
            # Keep a live transaction's session open for the next call; close
            # it in every other case (including a failed transaction).
            if not (in_trans and succeeded):
                session.close()

    @staticmethod
    def _is_filter(where):
        """Return True when *where* holds only SQL filter expressions.

        Otherwise the values are treated as a primary key.
        """
        return bool(where) and all(
            isinstance(clause, (BinaryExpression, BooleanClauseList)) for clause in where)

    def beginTrans(self):
        """Open a dedicated session and enter transaction mode.

        Until ``commitTrans()`` is called, every operation on this instance
        runs in this session and its writes are left uncommitted.
        """
        self.session = self.DBSession()
        self.transFlag = True

    def commitTrans(self):
        """Commit the transaction session, close it, and leave transaction mode.

        :raises Exception: any commit error, after logging and rollback.
        """
        self.transFlag = False
        try:
            self.session.commit()
        except Exception:
            logger.error(traceback.format_exc())
            self.session.rollback()
            raise
        finally:
            self.session.close()

    def insert(self, entity):
        """Add *entity* (committed immediately unless in a transaction).

        :return: True on success.
        """
        self._run_in_session(lambda session: session.add(entity), commit=True)
        return True

    def insert_many(self, entity_list):
        """Save every object in *entity_list* via ``Session.bulk_save_objects``.

        The save is committed immediately unless in a transaction.

        :return: True on success.
        """
        self._run_in_session(lambda session: session.bulk_save_objects(entity_list), commit=True)
        return True

    def update(self, entity):
        """Merge *entity*'s state into the database (committed unless in a transaction)."""
        self._run_in_session(lambda session: session.merge(entity), commit=True)

    def delete(self, where):
        """Delete every ``entityType`` row matching the filter expression *where*.

        The delete is committed immediately unless in a transaction.
        """
        self._run_in_session(
            lambda session: session.query(self.entityType).filter(where).delete(),
            commit=True)

    def findEntity(self, *where):
        """Return a single ``entityType`` instance, or ``None`` if none matches.

        :param where: either one or more SQLAlchemy filter expressions (the
            first row returned is used), or the primary-key value(s), which
            are looked up with ``Query.get``.
        """
        def load(session):
            query = session.query(self.entityType)
            if self._is_filter(where):
                return query.filter(*where).first()
            return query.get(where)

        return self._run_in_session(load)

    def findLastEntity(self, order_by_field, *where):
        """Return the matching entity with the greatest *order_by_field*, or ``None``.

        :param order_by_field: column to sort on, descending.
        :param where: filter expressions, or primary-key value(s).  For a
            primary-key lookup at most one row can match, so no ordering is
            applied.
        """
        def load(session):
            query = session.query(self.entityType)
            if self._is_filter(where):
                return query.filter(*where).order_by(desc(order_by_field)).first()
            return query.get(where)

        return self._run_in_session(load)

    def findList(self, *where):
        """Return a lazy ``Query`` over ``entityType`` filtered by *where*.

        No SQL runs here.  The query executes when the caller iterates it or
        passes it to ``queryPage``.  Outside a transaction it is bound to a
        session that has already been closed; SQLAlchemy reopens a closed
        session on next use.
        """
        return self._run_in_session(
            lambda session: session.query(self.entityType).filter(*where))

    def findAllList(self):
        """Return a list of every ``entityType`` row."""
        return self._run_in_session(lambda session: session.query(self.entityType).all())

    def queryPage(self, orm_query, pageParam):
        """Return one page of *orm_query* results and record the total row count.

        :param orm_query: a ``Query``, e.g. the one returned by ``findList``.
        :param pageParam: object read for ``curPage`` and ``pageRows``; its
            ``totalRecords`` attribute is set to the total number of matches.
        :return: the items on the requested page.
        """
        try:
            page = paginate_sqlalchemy.SqlalchemyOrmPage(
                orm_query, page=pageParam.curPage, items_per_page=pageParam.pageRows,
                db_session=self.DBSession)
            pageParam.totalRecords = page.item_count
            return page.items
        except Exception:
            logger.error(traceback.format_exc())
            raise

    def findCount(self, *where):
        """Return the number of ``entityType`` rows matching *where*."""
        return self._run_in_session(
            lambda session: session.query(func.count('*'))
            .select_from(self.entityType).filter(*where).scalar())

    def findMax(self, prop, *where):
        """Return the maximum of column *prop* over rows matching *where* (``None`` if no rows)."""
        return self._run_in_session(
            lambda session: session.query(func.max(prop))
            .select_from(self.entityType).filter(*where).scalar())

    def execute(self, sql, *args):
        """Execute raw *sql* (with optional bind parameters) and return the result object.

        This method never commits.  Use ``executeNoParam`` or a transaction
        for statements that modify data; outside a transaction the session is
        closed without a commit.

        NOTE(review): outside a transaction the session, and with it the
        connection, is closed before the result is returned.  Fetching rows
        afterwards may therefore fail with some drivers.  Confirm how callers
        consume the result.
        """
        return self._run_in_session(lambda session: session.execute(sql, *args))

    def executeNoParam(self, sql):
        """Execute raw *sql* without parameters; committed immediately unless in a transaction."""
        self._run_in_session(lambda session: session.execute(sql), commit=True)
|