129 lines
4.1 KiB
Python
129 lines
4.1 KiB
Python
from fastapi import Depends, HTTPException, status
|
||
from fastapi.security import OAuth2PasswordBearer
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select
|
||
from jose import JWTError, jwt
|
||
from typing import Optional
|
||
from datetime import datetime, timedelta
|
||
|
||
from app.core.config import settings
|
||
from app.models.database import get_db
|
||
from app.models.user import User
|
||
from app.models.config import Config
|
||
|
||
# 设置 auto_error=False 使其在没有令牌时不抛出异常
|
||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token", auto_error=False)
|
||
|
||
# 修改 is_privilege_mode 函数
|
||
from app.core.config import settings
|
||
|
||
async def is_privilege_mode(db: AsyncSession) -> bool:
|
||
"""检查是否启用了特权模式"""
|
||
# 首先检查环境变量中的特权模式设置
|
||
if settings.PRIVILEGE_MODE:
|
||
return True
|
||
|
||
# 如果环境变量中没有启用特权模式,则检查数据库中的配置
|
||
try:
|
||
result = await db.execute(
|
||
select(Config).where(
|
||
Config.key == "privilege_mode"
|
||
)
|
||
)
|
||
config = result.scalar_one_or_none()
|
||
return config is not None and config.value.lower() == "true"
|
||
except Exception as e:
|
||
print(f"检查特权模式出错: {e}")
|
||
return False
|
||
|
||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
||
"""创建访问令牌"""
|
||
to_encode = data.copy()
|
||
if expires_delta:
|
||
expire = datetime.utcnow() + expires_delta
|
||
else:
|
||
expire = datetime.utcnow() + timedelta(minutes=15)
|
||
to_encode.update({"exp": expire})
|
||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm="HS256")
|
||
return encoded_jwt
|
||
|
||
async def get_current_user(
|
||
token: str = Depends(oauth2_scheme),
|
||
db: AsyncSession = Depends(get_db)
|
||
) -> User:
|
||
"""获取当前用户(必须已登录)"""
|
||
# 检查特权模式
|
||
if await is_privilege_mode(db):
|
||
# 特权模式下,返回管理员用户
|
||
result = await db.execute(select(User).where(User.role == "admin").limit(1))
|
||
admin_user = result.scalar_one_or_none()
|
||
if admin_user:
|
||
return admin_user
|
||
|
||
credentials_exception = HTTPException(
|
||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||
detail="无法验证凭据",
|
||
headers={"WWW-Authenticate": "Bearer"},
|
||
)
|
||
|
||
if not token:
|
||
raise credentials_exception
|
||
|
||
try:
|
||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
|
||
username: str = payload.get("sub")
|
||
if username is None:
|
||
raise credentials_exception
|
||
except JWTError:
|
||
raise credentials_exception
|
||
|
||
result = await db.execute(select(User).where(User.username == username))
|
||
user = result.scalar_one_or_none()
|
||
|
||
if user is None:
|
||
raise credentials_exception
|
||
|
||
return user
|
||
|
||
async def get_current_user_optional(
|
||
token: str = Depends(oauth2_scheme),
|
||
db: AsyncSession = Depends(get_db)
|
||
) -> Optional[User]:
|
||
"""获取当前用户(可选,未登录返回None)"""
|
||
# 检查特权模式
|
||
if await is_privilege_mode(db):
|
||
# 特权模式下,返回管理员用户
|
||
result = await db.execute(select(User).where(User.role == "admin").limit(1))
|
||
admin_user = result.scalar_one_or_none()
|
||
if admin_user:
|
||
return admin_user
|
||
|
||
if not token:
|
||
return None
|
||
|
||
try:
|
||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
|
||
username: str = payload.get("sub")
|
||
if username is None:
|
||
return None
|
||
except JWTError:
|
||
return None
|
||
|
||
result = await db.execute(select(User).where(User.username == username))
|
||
user = result.scalar_one_or_none()
|
||
|
||
return user
|
||
|
||
async def get_current_active_user(
|
||
current_user: User = Depends(get_current_user),
|
||
) -> User:
|
||
"""获取当前活跃用户"""
|
||
return current_user
|
||
|
||
async def get_current_admin_user(
|
||
current_user: User = Depends(get_current_user),
|
||
) -> User:
|
||
"""获取当前管理员用户"""
|
||
if current_user.role != "admin":
|
||
raise HTTPException(status_code=403, detail="权限不足")
|
||
return current_user |