"""User authentication API routes."""

import logging
import re
from typing import Optional
from urllib.parse import urlencode, urlparse

import requests
from django.conf import settings
from django.db import IntegrityError
from django.db.models import Q
from django.http import HttpRequest, HttpResponse
from ninja import Router, Schema
from ninja.errors import HttpError
from ninja_jwt.authentication import JWTAuth
from ninja_jwt.tokens import RefreshToken

from .models import User
from .schemas import (
    LoginIn,
    MessageOut,
    OAuthCallbackIn,
    RegisterIn,
    TokenOut,
    UserOut,
    UserPrivateOut,
    UserUpdate,
)

logger = logging.getLogger(__name__)

router = Router()

# Seconds to wait on each outbound request to the OAuth provider. Without a
# timeout, a stalled provider would block the request worker indefinitely.
OAUTH_HTTP_TIMEOUT = 10

# Compiled once at import time. Used with fullmatch(): with re.match and a
# trailing "$", an input ending in "\n" would still be accepted.
_EMAIL_RE = re.compile(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}')


def get_current_user(request: HttpRequest) -> Optional[User]:
    """Return the authenticated user attached to *request*, or ``None``.

    Reads ``request.auth``; on routes declared with ``auth=JWTAuth()`` this
    holds the authenticated ``User`` (see ``get_me``).
    """
    if hasattr(request, 'auth') and request.auth:
        return request.auth
    return None


def _is_valid_url(value: Optional[str]) -> bool:
    """Return True if *value* is an absolute http(s) URL with a host part."""
    if not value:
        return False
    parsed = urlparse(value)
    return parsed.scheme in {"http", "https"} and bool(parsed.netloc)


def _require_oauth_config() -> None:
    """Ensure the OAuth settings needed by the OAuth endpoints are present.

    Settings are read with ``getattr`` so that an unset setting produces the
    descriptive error below rather than an ``AttributeError``.

    Raises:
        HttpError(500): a required setting is missing, or the configured
            redirect URI is not an absolute http(s) URL.
    """
    if not getattr(settings, "OAUTH_CLIENT_ID", None):
        raise HttpError(500, "OAuth 未配置客户端 ID")
    if not getattr(settings, "OAUTH_AUTHORIZE_URL", None):
        raise HttpError(500, "OAuth 未配置授权地址")
    if not getattr(settings, "OAUTH_TOKEN_URL", None):
        raise HttpError(500, "OAuth 未配置令牌地址")
    if not getattr(settings, "OAUTH_USERINFO_URL", None):
        raise HttpError(500, "OAuth 未配置用户信息地址")
    if not _is_valid_url(getattr(settings, "OAUTH_REDIRECT_URI", None)):
        raise HttpError(500, "OAuth 回调地址无效")


def _validate_name(name: str) -> None:
    """Raise ``HttpError(400)`` if *name* is longer than 50 characters."""
    if len(name) > 50:
        raise HttpError(400, "名称不能超过50个字符")


@router.get("/me", response=UserPrivateOut, auth=JWTAuth())
def get_me(request):
    """Return the currently authenticated user."""
    return request.auth


@router.patch("/me", response=UserPrivateOut, auth=JWTAuth())
def update_me(request, data: UserUpdate):
    """Partially update the current user's name, email and/or avatar.

    Only fields that are not ``None`` in *data* are changed. All input is
    validated before any attribute is assigned or saved.

    Raises:
        HttpError(400): malformed email, email already used by another
            account, or name longer than 50 characters.
    """
    user = request.auth

    # Validate everything first.
    if data.email is not None:
        validate_email(data.email)
        # Same uniqueness rule as registration. The user's own row is excluded
        # so resubmitting an unchanged email is allowed; an empty string
        # (clearing the email) is not checked.
        if (
            data.email
            and User.objects.filter(email=data.email).exclude(pk=user.pk).exists()
        ):
            raise HttpError(400, "邮箱已被注册")
    if data.name is not None:
        _validate_name(data.name)

    # Then apply.
    if data.name is not None:
        user.name = data.name
    if data.email is not None:
        user.email = data.email
    if data.avatar is not None:
        user.avatar = data.avatar

    user.save()
    return user


@router.post("/logout", response=MessageOut, auth=JWTAuth())
def logout(request):
    """Log out the current user (the client should discard its tokens).

    Nothing is revoked server-side here: JWTs are stateless, and this module
    returns tokens in the response body and never sets them as cookies. The
    client is responsible for removing the tokens from its storage.
    """
    return MessageOut(message="已退出登录", success=True)


class ChangePasswordIn(Schema):
    """Input schema for changing the current user's password."""

    current_password: str
    new_password: str


@router.post("/change-password", response=MessageOut, auth=JWTAuth())
def change_password(request, data: ChangePasswordIn):
    """Change the current user's password after verifying the old one.

    Raises:
        HttpError(400): wrong current password, new password shorter than 6 or
            longer than 128 characters, or identical to the current one.
    """
    user = request.auth

    # Verify the current password.
    if not user.check_password(data.current_password):
        raise HttpError(400, "当前密码错误")

    # Validate the new password.
    if len(data.new_password) < 6:
        raise HttpError(400, "新密码长度至少6位")
    if len(data.new_password) > 128:
        raise HttpError(400, "新密码长度不能超过128位")
    if data.current_password == data.new_password:
        raise HttpError(400, "新密码不能与当前密码相同")

    # Store the new (hashed) password.
    user.set_password(data.new_password)
    user.save()

    return MessageOut(message="密码已更新", success=True)


@router.post("/refresh", response=TokenOut)
def refresh_token(request, refresh_token: str):
    """Issue a new access token from a valid refresh token.

    *refresh_token* is a plain ``str`` argument, which django-ninja reads from
    the query string. The refresh token is returned re-encoded from the
    supplied one; no rotation happens here.

    Raises:
        HttpError(401): the refresh token is invalid or expired.
    """
    try:
        refresh = RefreshToken(refresh_token)
        return TokenOut(
            access_token=str(refresh.access_token),
            refresh_token=str(refresh),
        )
    except Exception:
        raise HttpError(401, "刷新令牌无效或已过期") from None


def validate_password(password: str) -> None:
    """Raise ``HttpError(400)`` unless *password* is 6–128 characters long."""
    if len(password) < 6:
        raise HttpError(400, "密码长度至少6位")
    if len(password) > 128:
        raise HttpError(400, "密码长度不能超过128位")


def validate_email(email: Optional[str]) -> None:
    """Raise ``HttpError(400)`` if *email* is non-empty and malformed.

    ``None`` and ``""`` are accepted; callers that require an email must check
    for that themselves.
    """
    if email and not _EMAIL_RE.fullmatch(email):
        raise HttpError(400, "邮箱格式不正确")


@router.post("/register", response=TokenOut)
def register(request, data: RegisterIn):
    """Register a new password-based user and return a JWT pair.

    The display name defaults to the username (``open_id``) when not given.

    Raises:
        HttpError(400): invalid password, missing or malformed email, name
            longer than 50 characters, or username/email already taken.
    """
    # Input validation.
    validate_password(data.password)

    # Email is required and must be well-formed.
    if not data.email:
        raise HttpError(400, "邮箱为必填项")
    validate_email(data.email)

    # Same name limit as profile updates.
    if data.name:
        _validate_name(data.name)

    # Username must be unused.
    if User.objects.filter(open_id=data.open_id).exists():
        raise HttpError(400, "用户名已被使用")

    # Email must be unused.
    if User.objects.filter(email=data.email).exists():
        raise HttpError(400, "邮箱已被注册")

    try:
        user = User.objects.create_user(
            open_id=data.open_id,
            password=data.password,
            name=data.name or data.open_id,  # display name defaults to the username
            email=data.email,
            login_method="password",
        )
    except IntegrityError:
        # Most likely a concurrent registration took the same username/email
        # between the existence checks above and this insert.
        raise HttpError(400, "用户名或邮箱已被使用") from None

    refresh = RefreshToken.for_user(user)
    return TokenOut(
        access_token=str(refresh.access_token),
        refresh_token=str(refresh),
    )


@router.post("/login", response=TokenOut)
def login(request, data: LoginIn):
    """Log in with a username or email plus password and return a JWT pair.

    ``data.open_id`` is matched against both the ``open_id`` and ``email``
    columns. The same generic 401 is returned for an unknown account, a wrong
    password, or an inactive account, so the response does not reveal which.
    """
    try:
        user = User.objects.get(Q(open_id=data.open_id) | Q(email=data.open_id))
    except User.DoesNotExist:
        raise HttpError(401, "账号或密码错误") from None
    except User.MultipleObjectsReturned:
        # The identifier matched several rows (e.g. one user's username equals
        # another user's email); prefer the exact username match.
        try:
            user = User.objects.get(open_id=data.open_id)
        except User.DoesNotExist:
            raise HttpError(401, "账号或密码错误") from None

    # Reject deactivated accounts like Django's ModelBackend does; getattr
    # tolerates user models without an is_active field.
    if not user.check_password(data.password) or not getattr(user, "is_active", True):
        raise HttpError(401, "账号或密码错误")

    refresh = RefreshToken.for_user(user)
    return TokenOut(
        access_token=str(refresh.access_token),
        refresh_token=str(refresh),
    )


@router.get("/oauth/url")
def get_oauth_url(request, redirect_uri: Optional[str] = None):
    """Return ``{"url": ...}``, the provider authorization URL to redirect to.

    *redirect_uri* overrides ``settings.OAUTH_REDIRECT_URI`` when given and must
    be an absolute http(s) URL.

    NOTE(review): ``oauth_callback`` always exchanges the code using
    ``settings.OAUTH_REDIRECT_URI``. Providers that require the redirect URI to
    match the authorization request will reject codes obtained through a
    custom *redirect_uri*. Also, no ``state`` parameter is sent, so nothing
    binds the callback to this request (CSRF); consider adding one.

    Raises:
        HttpError(500): OAuth is not configured.
        HttpError(400): *redirect_uri* is not a valid http(s) URL.
    """
    _require_oauth_config()
    redirect = redirect_uri or settings.OAUTH_REDIRECT_URI
    if not _is_valid_url(redirect):
        raise HttpError(400, "回调地址无效")

    query = urlencode(
        {
            "client_id": settings.OAUTH_CLIENT_ID,
            "redirect_uri": redirect,
            "response_type": "code",
        }
    )
    oauth_url = f"{settings.OAUTH_AUTHORIZE_URL}?{query}"
    return {"url": oauth_url}


@router.post("/oauth/callback", response=TokenOut)
def oauth_callback(request, data: OAuthCallbackIn):
    """Exchange an OAuth authorization code, upsert the user, return a JWT pair.

    The user is keyed on the provider's ``sub`` (or ``id``) claim, stored in
    ``open_id``; name, email and avatar are overwritten on every login.

    Raises:
        HttpError(500): OAuth is not configured, or an unexpected error occurred
            (network failure, timeout, malformed response); the latter is logged.
        HttpError(400): the provider rejected the code exchange or the userinfo
            request, or returned no access token / user identifier.
    """
    try:
        _require_oauth_config()

        # Exchange the authorization code for an access token.
        token_response = requests.post(
            settings.OAUTH_TOKEN_URL,
            data={
                "client_id": settings.OAUTH_CLIENT_ID,
                "client_secret": settings.OAUTH_CLIENT_SECRET,
                "code": data.code,
                "grant_type": "authorization_code",
                "redirect_uri": settings.OAUTH_REDIRECT_URI,
            },
            timeout=OAUTH_HTTP_TIMEOUT,
        )
        if token_response.status_code != 200:
            raise HttpError(400, "OAuth token exchange failed")
        access_token = token_response.json().get("access_token")
        if not access_token:
            raise HttpError(400, "OAuth token exchange failed")

        # Fetch the user's profile from the provider.
        user_response = requests.get(
            settings.OAUTH_USERINFO_URL,
            headers={"Authorization": f"Bearer {access_token}"},
            timeout=OAUTH_HTTP_TIMEOUT,
        )
        if user_response.status_code != 200:
            raise HttpError(400, "Failed to get user info")
        user_info = user_response.json()

        # Without a stable subject we cannot tell which local user this is.
        subject = user_info.get("sub") or user_info.get("id")
        if subject is None or subject == "":
            raise HttpError(400, "Failed to get user info")

        # Create or update the local user.
        user, _ = User.objects.update_or_create(
            open_id=str(subject),
            defaults={
                "name": user_info.get("name"),
                "email": user_info.get("email"),
                "avatar": user_info.get("picture") or user_info.get("avatar"),
                "login_method": "oauth",
            },
        )

        refresh = RefreshToken.for_user(user)
        return TokenOut(
            access_token=str(refresh.access_token),
            refresh_token=str(refresh),
        )
    except HttpError:
        # Deliberate API errors raised above keep their status and message.
        raise
    except Exception:
        logger.exception("OAuth callback failed")
        raise HttpError(500, "OAuth 登录失败") from None


# Development endpoint for testing without OAuth.
@router.post("/dev/login", response=TokenOut)
def dev_login(request, open_id: str, name: Optional[str] = None):
    """Log in (creating the user if needed) without credentials; DEBUG only.

    Raises:
        HttpError(403): ``settings.DEBUG`` is false.
    """
    if not settings.DEBUG:
        raise HttpError(403, "Not available in production")

    user, _ = User.objects.get_or_create(
        open_id=open_id,
        defaults={
            "name": name or f"Dev User {open_id}",
            "login_method": "dev",
        },
    )

    refresh = RefreshToken.for_user(user)
    return TokenOut(
        access_token=str(refresh.access_token),
        refresh_token=str(refresh),
    )