79 lines
3.1 KiB
Python
79 lines
3.1 KiB
Python
import logging
|
||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||
from sqlalchemy.orm import sessionmaker
|
||
from sqlalchemy import text, select, inspect
|
||
from app.core.config import settings
|
||
import bcrypt
|
||
from app.models.base import Base
|
||
from app.models.user import User
|
||
from app.models.type import Type
|
||
|
||
# SQLite异步URL需要使用aiosqlite
|
||
DATABASE_URL = settings.DATABASE_URL.replace("sqlite:///", "sqlite+aiosqlite:///")
|
||
|
||
engine = create_async_engine(DATABASE_URL, echo=True)
|
||
AsyncSessionLocal = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||
|
||
# 在创建默认配置类型后添加特权模式配置
|
||
async def init_db():
|
||
"""初始化数据库,创建所有表"""
|
||
try:
|
||
# 检查表是否存在
|
||
async with engine.connect() as conn:
|
||
# 检查users表是否存在
|
||
result = await conn.execute(text("SELECT name FROM sqlite_master WHERE type='table' AND name='users'"))
|
||
table_exists = result.scalar() is not None
|
||
|
||
if not table_exists:
|
||
# 如果表不存在,创建所有表
|
||
async with engine.begin() as conn2:
|
||
await conn2.run_sync(Base.metadata.create_all)
|
||
|
||
# 创建默认数据
|
||
async with AsyncSessionLocal() as session:
|
||
# 创建默认管理员用户
|
||
password = "admin123" # 默认密码
|
||
salt = bcrypt.gensalt()
|
||
hashed_password = bcrypt.hashpw(password.encode(), salt)
|
||
|
||
admin_user = User(
|
||
username="admin",
|
||
hashed_password=hashed_password.decode(),
|
||
role="admin"
|
||
)
|
||
session.add(admin_user)
|
||
|
||
# 创建默认配置类型
|
||
default_type = Type(
|
||
type_name="default",
|
||
description="默认配置类型"
|
||
)
|
||
session.add(default_type)
|
||
|
||
# 添加特权模式配置(默认关闭)
|
||
from app.models.config import Config
|
||
privilege_config = Config(
|
||
type_id=1, # 默认类型ID为1
|
||
key="privilege_mode",
|
||
value="true",
|
||
key_description="特权模式:设置为true时,所有用户无需登录即可使用所有功能"
|
||
)
|
||
session.add(privilege_config)
|
||
|
||
await session.commit()
|
||
|
||
logging.info("数据库初始化成功,创建了默认管理员用户(admin/admin123)和默认配置类型")
|
||
else:
|
||
logging.info("数据库表已存在,跳过初始化")
|
||
|
||
except Exception as e:
|
||
logging.error(f"数据库初始化失败: {e}")
|
||
raise
|
||
|
||
async def get_db():
|
||
"""获取数据库会话"""
|
||
db = AsyncSessionLocal()
|
||
try:
|
||
yield db
|
||
finally:
|
||
await db.close() |