300 lines
9.4 KiB
Python
300 lines
9.4 KiB
Python
"""
|
||
User authentication API routes.
|
||
"""
|
||
from typing import Optional
|
||
from ninja import Router, Schema
|
||
from ninja.errors import HttpError
|
||
from ninja_jwt.authentication import JWTAuth
|
||
from ninja_jwt.tokens import RefreshToken
|
||
from django.conf import settings
|
||
from django.http import HttpRequest, HttpResponse
|
||
from urllib.parse import urlparse, urlencode
|
||
import requests
|
||
|
||
from .models import User
|
||
from .schemas import UserOut, UserPrivateOut, UserUpdate, TokenOut, OAuthCallbackIn, MessageOut, RegisterIn, LoginIn
|
||
|
||
# Router collecting the authentication endpoints defined in this module;
# presumably mounted on the project's NinjaAPI elsewhere (verify the URL
# prefix there).
router = Router()
|
||
|
||
|
||
def get_current_user(request: HttpRequest) -> Optional[User]:
    """Return the user attached to ``request`` by authentication, if any.

    Yields ``None`` when the request has no ``auth`` attribute or when that
    attribute is falsy (i.e. the request is not authenticated).
    """
    authenticated = getattr(request, "auth", None)
    return authenticated if authenticated else None
|
||
|
||
|
||
def _is_valid_url(value: str) -> bool:
|
||
if not value:
|
||
return False
|
||
parsed = urlparse(value)
|
||
return parsed.scheme in {"http", "https"} and bool(parsed.netloc)
|
||
|
||
|
||
def _require_oauth_config():
    """Abort with HTTP 500 when the OAuth client settings are incomplete.

    Settings are checked in a fixed order and only the first problem is
    reported. The client ID and the three provider endpoints only need to be
    truthy; the redirect URI must also be an absolute http(s) URL.
    """
    required_settings = (
        ("OAUTH_CLIENT_ID", "OAuth 未配置客户端 ID"),
        ("OAUTH_AUTHORIZE_URL", "OAuth 未配置授权地址"),
        ("OAUTH_TOKEN_URL", "OAuth 未配置令牌地址"),
        ("OAUTH_USERINFO_URL", "OAuth 未配置用户信息地址"),
    )
    for setting_name, error_message in required_settings:
        if not getattr(settings, setting_name):
            raise HttpError(500, error_message)

    if not _is_valid_url(settings.OAUTH_REDIRECT_URI):
        raise HttpError(500, "OAuth 回调地址无效")
|
||
|
||
|
||
@router.get("/me", response=UserPrivateOut, auth=JWTAuth())
def get_me(request):
    """Return the authenticated user's private profile.

    ``JWTAuth`` places the authenticated user on ``request.auth``; the
    ``UserPrivateOut`` response schema serializes it.
    """
    current_user = request.auth
    return current_user
|
||
|
||
|
||
@router.patch("/me", response=UserPrivateOut, auth=JWTAuth())
def update_me(request, data: UserUpdate):
    """Partially update the authenticated user's profile.

    Only fields that are not ``None`` in ``data`` are applied. All input is
    validated before any attribute is assigned, so a rejected request leaves
    the user object untouched.

    Raises:
        HttpError(400): malformed email, email already used by another
            account, or a name longer than 50 characters.
    """
    user = request.auth

    # Validate everything before mutating the user.
    if data.email is not None:
        validate_email(data.email)
        # Enforce the same uniqueness rule as ``register``. Excluding the
        # current user lets a client resubmit its own unchanged email, and an
        # empty string (clearing the email) is not checked for collisions.
        if (
            data.email
            and User.objects.filter(email=data.email).exclude(pk=user.pk).exists()
        ):
            raise HttpError(400, "邮箱已被注册")

    if data.name is not None and len(data.name) > 50:
        raise HttpError(400, "名称不能超过50个字符")

    if data.name is not None:
        user.name = data.name
    if data.email is not None:
        user.email = data.email
    if data.avatar is not None:
        user.avatar = data.avatar

    user.save()
    return user
|
||
|
||
|
||
@router.post("/logout", response=MessageOut, auth=JWTAuth())
def logout(request, response: HttpResponse = None):
    """Log out the current user.

    JWTs are stateless, so nothing is revoked server-side and the client must
    discard its stored tokens. Any ``access_token`` / ``refresh_token``
    cookies are cleared on the response that is actually sent.

    Args:
        request: The authenticated request.
        response: Temporal response object that Django Ninja injects for a
            parameter annotated ``HttpResponse``. Cookie changes made on it
            are merged into the final response. It is ``None`` when the
            function is called directly rather than through the router.
    """
    # Cookies must be cleared on the response Ninja sends. Deleting them on a
    # locally constructed HttpResponse that is never returned has no effect.
    if response is not None:
        response.delete_cookie("access_token")
        response.delete_cookie("refresh_token")
    return MessageOut(message="已退出登录", success=True)
|
||
|
||
|
||
class ChangePasswordIn(Schema):
    """Request body for ``POST /change-password``.

    Only the field types are declared here; length and equality rules are
    enforced by the ``change_password`` view.
    """

    # The user's existing password; the view verifies it with check_password.
    current_password: str
    # The replacement password; the view passes it to user.set_password.
    new_password: str
||
|
||
|
||
@router.post("/change-password", response=MessageOut, auth=JWTAuth())
def change_password(request, data: ChangePasswordIn):
    """Change the authenticated user's password.

    The current password must match. The new one must be 6-128 characters
    long and must differ from the current one. Checks run in that order, and
    the first failure is reported as HTTP 400.
    """
    account = request.auth
    old_pw, new_pw = data.current_password, data.new_password

    if not account.check_password(old_pw):
        raise HttpError(400, "当前密码错误")

    length = len(new_pw)
    if length < 6:
        raise HttpError(400, "新密码长度至少6位")
    if length > 128:
        raise HttpError(400, "新密码长度不能超过128位")
    if new_pw == old_pw:
        raise HttpError(400, "新密码不能与当前密码相同")

    account.set_password(new_pw)
    account.save()
    return MessageOut(message="密码已更新", success=True)
|
||
|
||
|
||
@router.post("/refresh", response=TokenOut)
def refresh_token(request, refresh_token: str):
    """Issue a new access token from a refresh token.

    The refresh token itself is returned unchanged alongside the new access
    token.

    NOTE(review): as a plain ``str`` parameter, Ninja reads ``refresh_token``
    from the query string, so the token can end up in access logs. Consider
    moving it to the request body (this would break existing clients).

    Raises:
        HttpError(401): the refresh token is malformed, expired, or otherwise
            rejected by ``ninja_jwt``.
    """
    from ninja_jwt.exceptions import TokenError

    try:
        refresh = RefreshToken(refresh_token)
    except TokenError:
        # Only token-validation failures map to 401. Unexpected errors (e.g.
        # database problems) surface as server errors instead of being
        # misreported as an invalid token.
        raise HttpError(401, "刷新令牌无效或已过期") from None

    return TokenOut(
        access_token=str(refresh.access_token),
        refresh_token=str(refresh),
    )
|
||
|
||
|
||
def validate_password(password: str) -> None:
    """Reject passwords shorter than 6 or longer than 128 characters.

    Raises:
        HttpError(400): when the length falls outside the accepted range.
    """
    size = len(password)
    if size < 6:
        raise HttpError(400, "密码长度至少6位")
    elif size > 128:
        raise HttpError(400, "密码长度不能超过128位")
|
||
|
||
|
||
def validate_email(email: Optional[str]) -> None:
    """Validate the basic shape of an email address.

    ``None`` and the empty string are accepted without checks; callers decide
    whether an email is required. Otherwise the entire string must match a
    simple ``local@domain.tld`` pattern.

    Raises:
        HttpError(400): when ``email`` is non-empty and malformed.
    """
    import re

    if not email:
        return
    # ``re.fullmatch`` anchors at the true end of the string. ``re.match``
    # with a ``$`` anchor would also accept a trailing newline, e.g.
    # "user@example.com\n".
    if re.fullmatch(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", email) is None:
        raise HttpError(400, "邮箱格式不正确")
|
||
|
||
|
||
@router.post("/register", response=TokenOut)
def register(request, data: RegisterIn):
    """Register a password-based account and return a JWT pair.

    Raises:
        HttpError(400): invalid password length, missing or malformed email,
            or a username (``open_id``) / email that is already taken.
    """
    from django.db import IntegrityError, transaction

    validate_password(data.password)

    # Email is mandatory for password accounts.
    if not data.email:
        raise HttpError(400, "邮箱为必填项")
    validate_email(data.email)

    # Pre-checks give specific messages in the common case; they cannot
    # prevent a concurrent registration, which is handled below.
    if User.objects.filter(open_id=data.open_id).exists():
        raise HttpError(400, "用户名已被使用")
    if User.objects.filter(email=data.email).exists():
        raise HttpError(400, "邮箱已被注册")

    try:
        # The savepoint keeps any enclosing transaction (e.g. with
        # ATOMIC_REQUESTS) usable if the insert fails.
        with transaction.atomic():
            user = User.objects.create_user(
                open_id=data.open_id,
                password=data.password,
                name=data.name or data.open_id,  # Display name defaults to the username.
                email=data.email,
                login_method="password",
            )
    except IntegrityError:
        # Most likely a concurrent request registered the same username or
        # email between the checks above and this insert. This assumes those
        # columns carry unique constraints; verify against the User model.
        raise HttpError(400, "用户名或邮箱已被注册") from None

    refresh = RefreshToken.for_user(user)
    return TokenOut(
        access_token=str(refresh.access_token),
        refresh_token=str(refresh),
    )
|
||
|
||
|
||
@router.post("/login", response=TokenOut)
def login(request, data: LoginIn):
    """Authenticate with a username (``open_id``) or email plus password.

    ``data.open_id`` is matched against both ``open_id`` and ``email``. When
    it matches more than one account, the ``open_id`` match wins.

    Raises:
        HttpError(401): unknown account, wrong password, or deactivated
            account. A single generic message is used for all three so that
            callers cannot tell which one occurred.
    """
    from django.db.models import Q

    identifier = data.open_id
    try:
        user = User.objects.get(Q(open_id=identifier) | Q(email=identifier))
    except User.DoesNotExist:
        user = None
    except User.MultipleObjectsReturned:
        # Several accounts matched (e.g. one by username, one by email):
        # prefer the exact username match.
        user = User.objects.filter(open_id=identifier).first()

    if user is None:
        # Run the password hasher anyway, so a missing account takes about as
        # long as a wrong password (the same mitigation Django's ModelBackend
        # uses against account enumeration by timing).
        User().set_password(data.password)
        raise HttpError(401, "账号或密码错误")

    if not user.check_password(data.password):
        raise HttpError(401, "账号或密码错误")

    # Refuse deactivated accounts as Django's ModelBackend does; a model
    # without an ``is_active`` attribute is treated as active.
    if not getattr(user, "is_active", True):
        raise HttpError(401, "账号或密码错误")

    refresh = RefreshToken.for_user(user)
    return TokenOut(
        access_token=str(refresh.access_token),
        refresh_token=str(refresh),
    )
|
||
|
||
|
||
@router.get("/oauth/url")
def get_oauth_url(request, redirect_uri: Optional[str] = None):
    """Build the provider authorization URL for the authorization-code flow.

    Args:
        redirect_uri: Optional override for the callback URL. Falls back to
            ``settings.OAUTH_REDIRECT_URI`` and must be an absolute http(s)
            URL.

    Returns:
        ``{"url": <authorization URL>}``.

    Raises:
        HttpError(500): the OAuth settings are incomplete.
        HttpError(400): the effective redirect URI is not a valid http(s) URL.
    """
    _require_oauth_config()

    redirect = redirect_uri or settings.OAUTH_REDIRECT_URI
    if not _is_valid_url(redirect):
        raise HttpError(400, "回调地址无效")
    # NOTE(review): the token exchange in ``oauth_callback`` sends
    # ``settings.OAUTH_REDIRECT_URI``. Providers that require both redirect
    # URIs to match will reject codes obtained through an override here. No
    # ``state`` parameter is sent either; consider adding one for CSRF
    # protection.

    query = urlencode(
        {
            "client_id": settings.OAUTH_CLIENT_ID,
            "redirect_uri": redirect,
            "response_type": "code",
        }
    )

    base = settings.OAUTH_AUTHORIZE_URL
    # The configured endpoint may already carry its own query parameters.
    if "?" not in base:
        separator = "?"
    elif base.endswith(("?", "&")):
        separator = ""
    else:
        separator = "&"
    return {"url": f"{base}{separator}{query}"}
|
||
|
||
|
||
# Seconds to wait for each request to the OAuth provider. Without a timeout,
# ``requests`` can block the worker indefinitely.
_OAUTH_HTTP_TIMEOUT = 10


def _oauth_exchange_code(code: str) -> str:
    """Exchange an authorization code for the provider's access token."""
    token_response = requests.post(
        settings.OAUTH_TOKEN_URL,
        data={
            "client_id": settings.OAUTH_CLIENT_ID,
            "client_secret": settings.OAUTH_CLIENT_SECRET,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": settings.OAUTH_REDIRECT_URI,
        },
        timeout=_OAUTH_HTTP_TIMEOUT,
    )
    if token_response.status_code != 200:
        raise HttpError(400, "OAuth token exchange failed")

    access_token = token_response.json().get("access_token")
    if not access_token:
        raise HttpError(400, "OAuth token exchange failed")
    return access_token


def _oauth_fetch_userinfo(access_token: str) -> dict:
    """Fetch the user's profile from the provider's userinfo endpoint."""
    user_response = requests.get(
        settings.OAUTH_USERINFO_URL,
        headers={"Authorization": f"Bearer {access_token}"},
        timeout=_OAUTH_HTTP_TIMEOUT,
    )
    if user_response.status_code != 200:
        raise HttpError(400, "Failed to get user info")
    return user_response.json()


def _oauth_upsert_user(user_info: dict) -> User:
    """Create or update the local user identified by the provider's subject."""
    raw_id = user_info.get("sub") or user_info.get("id")
    open_id = "" if raw_id is None else str(raw_id).strip()
    if not open_id:
        # Without an identifier, update_or_create(open_id=None) would map
        # every such login onto the same row.
        raise HttpError(400, "OAuth 用户信息缺少用户标识")

    # ``register`` lets users choose their own ``open_id``. Refuse to merge an
    # OAuth identity into a password account that happens to use the same
    # value; otherwise a pre-registered username could capture a victim's
    # OAuth login.
    existing = User.objects.filter(open_id=open_id).first()
    if existing is not None and existing.login_method == "password":
        raise HttpError(400, "该用户名已被密码账号占用,无法通过 OAuth 登录")

    profile = {
        "name": user_info.get("name"),
        "email": user_info.get("email"),
        "avatar": user_info.get("picture") or user_info.get("avatar"),
    }
    # Fields the provider did not send keep their stored (or model-default)
    # values instead of being overwritten with None.
    defaults = {field: value for field, value in profile.items() if value is not None}
    defaults["login_method"] = "oauth"

    user, _created = User.objects.update_or_create(open_id=open_id, defaults=defaults)
    return user


@router.post("/oauth/callback", response=TokenOut)
def oauth_callback(request, data: OAuthCallbackIn):
    """Complete the OAuth code flow: exchange the code, sync the user, issue JWTs.

    Raises:
        HttpError(500): incomplete OAuth settings, or any unexpected failure
            (logged with traceback).
        HttpError(400): the provider rejected the code, the userinfo request
            failed, the profile lacks an identifier, or the identifier belongs
            to a password account.
    """
    import logging

    try:
        _require_oauth_config()
        access_token = _oauth_exchange_code(data.code)
        user_info = _oauth_fetch_userinfo(access_token)
        user = _oauth_upsert_user(user_info)
        refresh = RefreshToken.for_user(user)
        return TokenOut(
            access_token=str(refresh.access_token),
            refresh_token=str(refresh),
        )
    except HttpError:
        # Deliberate errors already carry the right status and message; the
        # blanket handler below must not turn them into a generic 500.
        raise
    except Exception:
        logging.getLogger(__name__).exception("OAuth callback failed")
        raise HttpError(500, "OAuth 登录失败") from None
|
||
|
||
|
||
# Development endpoint for testing without OAuth
@router.post("/dev/login", response=TokenOut)
def dev_login(request, open_id: str, name: Optional[str] = None):
    """Log in as ``open_id`` without any credentials (only when DEBUG is on).

    On first use, the user is created with ``login_method="dev"`` and a
    display name of ``name`` (or ``"Dev User <open_id>"``). An existing user
    is returned as-is. Responds 403 whenever ``settings.DEBUG`` is false.

    NOTE(review): with DEBUG enabled this logs in as ANY existing account,
    password users included; make sure DEBUG is never on in a reachable
    deployment.
    """
    if settings.DEBUG:
        user = User.objects.get_or_create(
            open_id=open_id,
            defaults={"name": name or f"Dev User {open_id}", "login_method": "dev"},
        )[0]
        tokens = RefreshToken.for_user(user)
        return TokenOut(
            access_token=str(tokens.access_token),
            refresh_token=str(tokens),
        )
    raise HttpError(403, "Not available in production")
|