324 lines
8.0 KiB
Python
324 lines
8.0 KiB
Python
# Database utilities
|
||
|
||
import datetime
|
||
import threading
|
||
import time
|
||
import traceback
|
||
import warnings
|
||
from contextlib import contextmanager
|
||
from functools import wraps
|
||
from urllib.parse import urlparse, parse_qsl
|
||
|
||
import pymysql
|
||
from dbutils.pooled_db import PooledDB
|
||
|
||
warnings.filterwarnings("ignore")
|
||
|
||
from setting import dbUrl
|
||
|
||
|
||
def __parseresult_to_dict(parsed):
    """Convert a parsed database URL into connection keyword arguments.

    Args:
        parsed: Result of ``urllib.parse.urlparse`` for a URL such as
            ``mysql://user:pw@host:3306/dbname?charset=utf8&maxsize=5``.

    Returns:
        dict: ``db`` plus, when present in the URL, ``user``, ``passwd``,
        ``host`` and ``port``, followed by every query-string parameter
        (query parameters win on a name clash). Query values are coerced:
        ``true``/``false`` to bool, digit strings to int,
        ``<digits>.<digits>`` to float and ``null``/``none`` to ``None``;
        anything else stays a string. A ``maxsize`` entry is renamed to
        ``maxconnections``.
    """
    # The path is "/<dbname>"; drop the leading slash and any stray "?" tail.
    connect_kwargs = {'db': parsed.path[1:].split('?')[0]}
    if parsed.username:
        connect_kwargs['user'] = parsed.username
    if parsed.hostname:
        connect_kwargs['host'] = parsed.hostname
    if parsed.port:
        connect_kwargs['port'] = parsed.port
    if parsed.password:
        # The password is handed over under the ``passwd`` keyword.
        connect_kwargs['passwd'] = parsed.password

    # Additional connection arguments come from the query string.
    for key, value in parse_qsl(parsed.query, keep_blank_values=True):
        lowered = value.lower()
        if lowered == 'false':
            value = False
        elif lowered == 'true':
            value = True
        elif value.isdigit():
            # str.isdigit() also accepts characters such as superscript
            # digits that int() rejects; keep those values as strings
            # instead of letting ValueError escape.
            try:
                value = int(value)
            except ValueError:
                pass
        elif '.' in value and all(p.isdigit() for p in value.split('.', 1)):
            try:
                value = float(value)
            except ValueError:
                pass
        elif lowered in ('null', 'none'):
            value = None
        connect_kwargs[key] = value

    if 'maxsize' in connect_kwargs:
        connect_kwargs['maxconnections'] = connect_kwargs.pop('maxsize')
    return connect_kwargs
|
||
|
||
|
||
def __create_pool(url):
    """Build a ``PooledDB`` connection pool from a database URL.

    The URL is parsed with ``urlparse`` and translated into keyword
    arguments by ``__parseresult_to_dict``; ``pymysql`` is passed as the
    pool's creator and ``1`` as its second positional argument
    (``mincached`` in DBUtils -- TODO confirm against the installed version).
    """
    pool_kwargs = __parseresult_to_dict(urlparse(url))
    return PooledDB(pymysql, 1, **pool_kwargs)
|
||
|
||
|
||
# Database connection state.
# Each name is created only when it is not already present in the module
# namespace, so re-executing the module body keeps the existing objects.
# (``global`` statements are no-ops at module level and are not needed.)
if "transaction_map" not in globals():
    # Maps a thread ident to the connection of that thread's open transaction.
    transaction_map = {}

if "pool" not in globals():
    pool = __create_pool(dbUrl)
|
||
|
||
|
||
def __get_connection():
    """Return the connection to use on the calling thread.

    If the thread has an entry in ``transaction_map`` that connection is
    returned; otherwise a connection is taken from ``pool``.
    """
    try:
        return transaction_map[threading.get_ident()]
    except KeyError:
        return pool.connection()
|
||
|
||
|
||
def __close_connection(conn):
    """Close ``conn`` unless the calling thread has a transaction registered.

    While the thread has an entry in ``transaction_map`` the connection is
    left open; ``end_transaction`` closes it later.
    """
    if threading.get_ident() not in transaction_map:
        conn.close()
|
||
|
||
|
||
@contextmanager
def dbp():
    """Context manager that yields a database connection.

    The connection comes from ``__get_connection`` and is handed to
    ``__close_connection`` when the ``with`` block ends. The release runs in
    a ``finally`` clause so the connection is also released when the body
    raises.

    Yields:
        The connection object returned by ``__get_connection``.
    """
    conn = __get_connection()
    try:
        yield conn
    finally:
        __close_connection(conn)
|
||
|
||
|
||
def execute_sql(sql, params=None):
    """Execute one SQL statement and commit.

    Args:
        sql: Statement text, optionally with placeholders.
        params: Values for the placeholders, passed to ``cursor.execute``.

    NOTE(review): the commit is issued even when the calling thread has an
    open transaction (see ``start_transaction``) -- confirm this is intended.
    """
    with dbp() as db:
        c = db.cursor()
        try:
            c.execute(sql, params)
            db.commit()
        finally:
            # Close the cursor even if execute/commit raised.
            c.close()
|
||
|
||
|
||
def execute_sql_list(sqls):
    """Execute several SQL statements on one connection, then commit once.

    Args:
        sqls: Iterable of statement strings; each is executed without
            parameters, in order.

    NOTE(review): the commit is issued even when the calling thread has an
    open transaction (see ``start_transaction``) -- confirm this is intended.
    """
    with dbp() as db:
        c = db.cursor()
        try:
            for sql in sqls:
                c.execute(sql)
            db.commit()
        finally:
            # Close the cursor even if a statement or the commit raised.
            c.close()
|
||
|
||
|
||
def __get_obj_list_sql(obj_list, table, replace=True):
    """Build a parameterized bulk insert/replace statement for a list of dicts.

    The column list is taken from the keys of the first object. The values of
    every object are then read in that same column order, so objects whose
    keys were inserted in a different order still line up with the columns.

    Args:
        obj_list: List of dicts mapping column name to value.
        table: Target table name (wrapped in backticks).
        replace: ``True`` emits ``replace INTO``, ``False`` emits
            ``insert INTO``.

    Returns:
        tuple: ``(sql, params)`` where ``params`` is a list holding one value
        tuple per object; ``("", [])`` when ``obj_list`` is empty.

    Raises:
        KeyError: If an object lacks one of the first object's keys.
    """
    if not obj_list:
        return "", []

    columns = list(obj_list[0].keys())
    column_sql = ",".join(f"`{name}`" for name in columns)
    placeholder_sql = ",".join("%s" for _ in columns)
    verb = "replace" if replace else "insert"
    sql = f"""{verb} INTO `{table}` ({column_sql}) VALUES ({placeholder_sql})"""
    params = [tuple(obj[name] for name in columns) for obj in obj_list]
    return sql, params
|
||
|
||
|
||
def __get_obj_update_sql(obj, table, key):
    """Build a parameterized UPDATE statement for one dict.

    Every entry of ``obj`` except ``key`` becomes a ``column=%s`` assignment.
    The row is selected with ``where `key`=%s`` and the key value is passed
    as the last parameter rather than being formatted into the SQL text, so
    quotes or ``%`` characters in the value cannot alter the statement.
    ``obj`` itself is not modified.

    Args:
        obj: Dict mapping column name to value; must contain ``key``.
        table: Target table name (wrapped in backticks).
        key: Name of the column that identifies the row.

    Returns:
        tuple: ``(sql, params)`` with ``params`` a tuple of the assignment
        values followed by the key value.

    Raises:
        KeyError: If ``key`` is not present in ``obj``.
    """
    key_value = obj[key]
    columns = [name for name in obj if name != key]
    assignments = ",".join(f"`{name}`=%s" for name in columns)
    sql = f"""update `{table}` set {assignments} where `{key}`=%s"""
    params = tuple(obj[name] for name in columns) + (key_value,)
    return sql, params
|
||
|
||
def sql_to_dict(sql, params=None):
    """Run a query and return its rows as a list of ``Map`` objects.

    Args:
        sql: Query text, optionally with placeholders.
        params: Values for the placeholders, passed to ``cursor.execute``.

    Returns:
        list: One ``Map`` per row, keyed by column name. A ``bytes`` value of
        length 1 is converted to a bool (``True`` only for ``b'\\x01'``).
        An empty list is returned when the statement produces no result set.
    """
    with dbp() as db:
        c = db.cursor()
        try:
            c.execute(sql, params)
            # DB-API cursors report description=None when the statement
            # returned no result set; there is nothing to fetch then.
            if c.description is None:
                return []
            colnames = [column[0] for column in c.description]
            rows = c.fetchall()
        finally:
            # Close the cursor even if execute/fetch raised.
            c.close()

    ret_list = []
    for row in rows:
        d = Map()
        for name, value in zip(colnames, row):
            if isinstance(value, bytes) and len(value) == 1:
                d[name] = value == b'\x01'
            else:
                d[name] = value
        ret_list.append(d)
    return ret_list
|
||
|
||
|
||
def start_transaction():
    """Begin a transaction for the calling thread.

    A connection is obtained via ``__get_connection``, its ``autocommit``
    attribute is set to ``False`` and it is stored in ``transaction_map``
    under the thread ident, so later helper calls on this thread reuse it.

    NOTE(review): ``autocommit`` is assigned as a plain attribute; whether
    that changes the driver's autocommit mode depends on the connection
    wrapper -- confirm.

    Returns:
        int: The calling thread's ident.
    """
    thread_id = threading.get_ident()
    connection = __get_connection()
    connection.autocommit = False
    transaction_map[thread_id] = connection
    return thread_id
|
||
|
||
|
||
def end_transaction(rockback=False):
    """Finish the calling thread's transaction and close its connection.

    Args:
        rockback: When true the transaction is rolled back, otherwise it is
            committed.

    Raises:
        KeyError: If the calling thread has no entry in ``transaction_map``.
    """
    connection = transaction_map.pop(threading.get_ident())
    try:
        finish = connection.rollback if rockback else connection.commit
        finish()
    finally:
        # The connection is closed whether or not commit/rollback succeeded.
        connection.close()
|
||
|
||
|
||
@contextmanager
def transaction_code():
    """Context manager that runs its body inside a transaction.

    The transaction is committed when the body finishes normally. If the
    body raises an ``Exception``, the traceback is printed, the transaction
    is rolled back and the exception is suppressed. Other ``BaseException``
    subclasses trigger a rollback and are re-raised.

    The commit runs outside the ``try`` body: a failing commit therefore
    propagates as-is instead of triggering a second ``end_transaction`` call
    on an already-removed transaction.

    Yields:
        int: The thread ident returned by ``start_transaction``.
    """
    f = start_transaction()
    try:
        yield f
    except Exception:
        traceback.print_exc()
        end_transaction(True)
    except BaseException:
        # e.g. KeyboardInterrupt: release the transaction, then propagate.
        end_transaction(True)
        raise
    else:
        end_transaction()
|
||
|
||
|
||
# 事务
|
||
def transaction(target_function):
    """Decorator that runs the wrapped function inside a transaction.

    ``start_transaction`` is called before the function. When it returns,
    the transaction is committed via ``end_transaction()`` and the result is
    returned. When it raises, the transaction is rolled back via
    ``end_transaction(True)`` and the exception is re-raised, so the thread
    does not stay registered in a transaction that never ends.
    """
    @wraps(target_function)
    def wrapper(*args, **kwargs):
        start_transaction()
        try:
            ret = target_function(*args, **kwargs)
        except BaseException:
            end_transaction(True)
            raise
        end_transaction()
        return ret

    return wrapper
|
||
|
||
|
||
def insert(obj, table):
    """Write one dict to ``table`` and return the cursor's ``lastrowid``.

    The statement is built by ``__get_obj_list_sql`` with its default
    arguments.

    Args:
        obj: Dict mapping column name to value.
        table: Target table name.

    Returns:
        The ``lastrowid`` reported by the cursor after the statement.

    NOTE(review): the commit is issued even when the calling thread has an
    open transaction (see ``start_transaction``) -- confirm this is intended.
    """
    (sql, params) = __get_obj_list_sql([obj], table)
    with dbp() as db:
        c = db.cursor()
        try:
            c.execute(sql, params[0])
            db.commit()
            return c.lastrowid
        finally:
            # Close the cursor even if execute/commit raised.
            c.close()
|
||
|
||
def update(obj, table, key="id"):
    """Update the row of ``table`` identified by ``obj[key]``.

    The statement and its parameters are built by ``__get_obj_update_sql``.

    Args:
        obj: Dict mapping column name to value; must contain ``key``.
        table: Target table name.
        key: Name of the column that identifies the row.

    NOTE(review): the commit is issued even when the calling thread has an
    open transaction (see ``start_transaction``) -- confirm this is intended.
    """
    (sql, params) = __get_obj_update_sql(obj, table, key)
    with dbp() as db:
        c = db.cursor()
        try:
            c.execute(sql, params)
            db.commit()
        finally:
            # Close the cursor even if execute/commit raised.
            c.close()
|
||
|
||
def inserts(obj_list, table):
    """Write a list of dicts to ``table`` with one ``executemany`` call.

    Nothing is executed when ``obj_list`` is empty. The statement is built by
    ``__get_obj_list_sql`` with its default arguments.

    Args:
        obj_list: List of dicts mapping column name to value.
        table: Target table name.

    NOTE(review): the commit is issued even when the calling thread has an
    open transaction (see ``start_transaction``) -- confirm this is intended.
    """
    if not obj_list:
        return
    (sql, params) = __get_obj_list_sql(obj_list, table)
    with dbp() as db:
        c = db.cursor()
        try:
            c.executemany(sql, params)
            db.commit()
        finally:
            # Close the cursor even if executemany/commit raised.
            c.close()
|
||
|
||
def get(table, id, idstr="id"):
    """Fetch the first row of ``table`` whose ``idstr`` column equals ``id``.

    The id is sent as a bound parameter instead of being formatted into the
    SQL text, so string ids containing quotes cannot break or alter the
    query. ``table`` and ``idstr`` are still interpolated verbatim and must
    come from trusted code.

    Args:
        table: Table name.
        id: Value to look up.
        idstr: Name of the column to compare against.

    Returns:
        The first matching row as returned by ``sql_to_dict``, or ``None``
        when no row matches.
    """
    db_data = sql_to_dict(f"select * from {table} where {idstr}=%s", (id,))
    if db_data:
        return db_data[0]
    return None
|
||
|
||
def get_list(table, where=None):
    """Return all rows of ``table``, optionally filtered.

    Args:
        table: Table name, interpolated verbatim into the query.
        where: Optional raw SQL condition appended after ``where``; it is
            interpolated verbatim, so it must come from trusted code.

    Returns:
        The list produced by ``sql_to_dict`` for the query.
    """
    query = f"select * from {table}"
    if where:
        query = f"{query} where {where}"
    return sql_to_dict(query)
|
||
|
||
class Map(dict):
    """Dictionary whose items can also be read and written as attributes.

    Every item is mirrored into the instance ``__dict__``. Reading a missing
    attribute returns ``None`` instead of raising ``AttributeError``.

    All mutating dict methods are routed through the same bookkeeping so the
    attribute view never goes stale after ``update``/``pop``/``clear`` etc.

    Example:
        m = Map({'first_name': 'Eduardo'}, last_name='Pool', age=24, sports=['Soccer'])
        m.first_name      # 'Eduardo'
        m.missing         # None
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # update() funnels every pair through __setitem__, which also covers
        # construction from an iterable of (key, value) pairs.
        self.update(*args, **kwargs)

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails.
        return self.get(attr)

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        self.__dict__[key] = value

    def __delattr__(self, item):
        self.__delitem__(item)

    def __delitem__(self, key):
        super().__delitem__(key)
        self.__dict__.pop(key, None)

    def update(self, *args, **kwargs):
        """Same contract as ``dict.update``, keeping attributes in sync."""
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def setdefault(self, key, default=None):
        """Same contract as ``dict.setdefault``, keeping attributes in sync."""
        if key not in self:
            self[key] = default
        return self[key]

    def pop(self, key, *default):
        """Same contract as ``dict.pop``, keeping attributes in sync."""
        value = super().pop(key, *default)
        self.__dict__.pop(key, None)
        return value

    def popitem(self):
        """Same contract as ``dict.popitem``, keeping attributes in sync."""
        key, value = super().popitem()
        self.__dict__.pop(key, None)
        return key, value

    def clear(self):
        """Remove all items and their mirrored attributes."""
        super().clear()
        self.__dict__.clear()
|
||
|
||
|
||
def __update_setting():
    """Reload the module-level ``setting`` map from the ``setting`` table.

    Every row's ``name`` becomes a key and its ``value`` the stored value.
    Existing keys that are no longer returned by the query are left as is.
    """
    global setting
    for row in sql_to_dict("select name,value from setting"):
        setting[row["name"]] = row["value"]
|
||
|
||
|
||
def __update_setting_thread(interval=5):
    """Refresh ``setting`` forever, sleeping ``interval`` seconds between runs.

    Intended as a background thread target; it never returns. A failing
    refresh is reported with a traceback and retried on the next cycle, so a
    single error does not end the refresh loop.

    Args:
        interval: Seconds to sleep after each refresh attempt.
    """
    while True:
        try:
            __update_setting()
        except Exception:
            # Keep the loop alive; the next iteration retries.
            traceback.print_exc()
        time.sleep(interval)
|
||
|
||
|
||
# System settings.
# Created once: when ``setting`` is not yet in the module namespace it is
# initialised, loaded from the database, and a daemon thread is started to
# keep it refreshed.
# NOTE(review): the original indentation was lost; all three statements are
# assumed to belong to the ``if`` body -- confirm.
if "setting" not in globals():
    setting = Map()
    __update_setting()
    refresher = threading.Thread(target=__update_setting_thread, daemon=True)
    refresher.start()
|
||
|
||
|
||
def get_table_desc(table):
    """Describe the columns of ``table``.

    Runs ``show full fields`` and returns one ``Map`` per column with the
    keys ``name``, ``type`` and ``commnet`` (spelling kept as is because it
    is part of the returned data).
    """
    columns = sql_to_dict(f"show full fields from `{table}`")
    return [
        Map({"name": column.Field, "type": column.Type, "commnet": column.Comment})
        for column in columns
    ]
|
||
|