config_center/app/api/deps.py
2025-03-03 22:28:34 +08:00

129 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from jose import JWTError, jwt
from typing import Optional
from datetime import datetime, timedelta
from app.core.config import settings
from app.models.database import get_db
from app.models.user import User
from app.models.config import Config
# 设置 auto_error=False 使其在没有令牌时不抛出异常
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token", auto_error=False)
# 修改 is_privilege_mode 函数
from app.core.config import settings
async def is_privilege_mode(db: AsyncSession) -> bool:
"""检查是否启用了特权模式"""
# 首先检查环境变量中的特权模式设置
if settings.PRIVILEGE_MODE:
return True
# 如果环境变量中没有启用特权模式,则检查数据库中的配置
try:
result = await db.execute(
select(Config).where(
Config.key == "privilege_mode"
)
)
config = result.scalar_one_or_none()
return config is not None and config.value.lower() == "true"
except Exception as e:
print(f"检查特权模式出错: {e}")
return False
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
"""创建访问令牌"""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm="HS256")
return encoded_jwt
async def get_current_user(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_db)
) -> User:
"""获取当前用户(必须已登录)"""
# 检查特权模式
if await is_privilege_mode(db):
# 特权模式下,返回管理员用户
result = await db.execute(select(User).where(User.role == "admin").limit(1))
admin_user = result.scalar_one_or_none()
if admin_user:
return admin_user
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无法验证凭据",
headers={"WWW-Authenticate": "Bearer"},
)
if not token:
raise credentials_exception
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
username: str = payload.get("sub")
if username is None:
raise credentials_exception
except JWTError:
raise credentials_exception
result = await db.execute(select(User).where(User.username == username))
user = result.scalar_one_or_none()
if user is None:
raise credentials_exception
return user
async def get_current_user_optional(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_db)
) -> Optional[User]:
"""获取当前用户可选未登录返回None"""
# 检查特权模式
if await is_privilege_mode(db):
# 特权模式下,返回管理员用户
result = await db.execute(select(User).where(User.role == "admin").limit(1))
admin_user = result.scalar_one_or_none()
if admin_user:
return admin_user
if not token:
return None
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
username: str = payload.get("sub")
if username is None:
return None
except JWTError:
return None
result = await db.execute(select(User).where(User.username == username))
user = result.scalar_one_or_none()
return user
async def get_current_active_user(
current_user: User = Depends(get_current_user),
) -> User:
"""获取当前活跃用户"""
return current_user
async def get_current_admin_user(
current_user: User = Depends(get_current_user),
) -> User:
"""获取当前管理员用户"""
if current_user.role != "admin":
raise HTTPException(status_code=403, detail="权限不足")
return current_user