Files
ai_web/backend/apps/users/api.py
2026-01-28 16:00:56 +08:00

300 lines
9.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
User authentication API routes.
"""
from typing import Optional
from ninja import Router, Schema
from ninja.errors import HttpError
from ninja_jwt.authentication import JWTAuth
from ninja_jwt.tokens import RefreshToken
from django.conf import settings
from django.http import HttpRequest, HttpResponse
from urllib.parse import urlparse, urlencode
import requests
from .models import User
from .schemas import UserOut, UserPrivateOut, UserUpdate, TokenOut, OAuthCallbackIn, MessageOut, RegisterIn, LoginIn
# Router holding the user/auth endpoints below; presumably mounted on the
# project's NinjaAPI elsewhere (confirm the URL prefix there).
router = Router()
def get_current_user(request: HttpRequest) -> Optional[User]:
    """Return the user attached to *request* by authentication, or ``None``.

    The JWT-protected routes in this module treat ``request.auth`` as the
    authenticated user. A missing or falsy ``auth`` attribute yields ``None``.
    """
    authenticated = getattr(request, "auth", None)
    return authenticated if authenticated else None
def _is_valid_url(value: str) -> bool:
if not value:
return False
parsed = urlparse(value)
return parsed.scheme in {"http", "https"} and bool(parsed.netloc)
def _require_oauth_config():
    """Abort with HTTP 500 unless the OAuth settings this module uses are set.

    The settings are checked in a fixed order, and the first missing one
    determines the error message. The redirect URI must also be an absolute
    http(s) URL.
    """
    required_settings = (
        ("OAUTH_CLIENT_ID", "OAuth 未配置客户端 ID"),
        ("OAUTH_AUTHORIZE_URL", "OAuth 未配置授权地址"),
        ("OAUTH_TOKEN_URL", "OAuth 未配置令牌地址"),
        ("OAUTH_USERINFO_URL", "OAuth 未配置用户信息地址"),
    )
    for setting_name, error_message in required_settings:
        if not getattr(settings, setting_name):
            raise HttpError(500, error_message)
    if not _is_valid_url(settings.OAUTH_REDIRECT_URI):
        raise HttpError(500, "OAuth 回调地址无效")
@router.get("/me", response=UserPrivateOut, auth=JWTAuth())
def get_me(request):
    """Return the authenticated caller's own profile.

    ``JWTAuth`` guards the route. The user object found on ``request.auth``
    is serialized through ``UserPrivateOut``.
    """
    current_user = request.auth
    return current_user
@router.patch("/me", response=UserPrivateOut, auth=JWTAuth())
def update_me(request, data: UserUpdate):
    """Partially update the authenticated user's profile.

    Only fields of *data* that are not ``None`` are applied. All validation
    runs before any attribute is assigned, so a rejected request leaves the
    user unchanged.

    Raises:
        HttpError(400): malformed email, email already used by another
            account, or a name longer than 50 characters.
    """
    user = request.auth

    if data.email is not None:
        # Validate the email format.
        validate_email(data.email)
        # Keep emails unique across accounts, as ``register`` does. Duplicate
        # emails would also make the email branch of ``login`` ambiguous.
        if (
            data.email
            and User.objects.filter(email=data.email).exclude(pk=user.pk).exists()
        ):
            raise HttpError(400, "邮箱已被注册")
    if data.name is not None and len(data.name) > 50:
        raise HttpError(400, "名称不能超过50个字符")

    if data.name is not None:
        user.name = data.name
    if data.email is not None:
        user.email = data.email
    if data.avatar is not None:
        user.avatar = data.avatar
    user.save()
    return user
@router.post("/logout", response=MessageOut, auth=JWTAuth())
def logout(request, response: HttpResponse = None):
    """Log out the current user.

    No server-side token revocation happens here, so the client must discard
    its stored tokens. If the tokens were also kept in ``access_token`` /
    ``refresh_token`` cookies, those cookies are cleared on the outgoing
    response.

    Args:
        response: django-ninja's temporal response. When a view parameter is
            annotated exactly as ``HttpResponse``, ninja injects the response
            object it will actually send. The annotation must stay plain
            ``HttpResponse`` (not ``Optional[...]``) for that injection to
            happen. It defaults to ``None`` so direct Python calls still work.
    """
    if response is not None:
        # Mutate the response ninja will send. Cookie changes on a freshly
        # constructed HttpResponse would never reach the client.
        response.delete_cookie("access_token")
        response.delete_cookie("refresh_token")
    return MessageOut(message="已退出登录", success=True)
class ChangePasswordIn(Schema):
    """Request body for ``POST /change-password``.

    Only the field types are declared here. Length limits and the
    "must differ" rule are enforced in ``change_password``.
    """
    current_password: str  # must match the user's existing password
    new_password: str  # replacement password
@router.post("/change-password", response=MessageOut, auth=JWTAuth())
def change_password(request, data: ChangePasswordIn):
    """Change the authenticated user's password.

    The current password is verified first. The new password must be 6-128
    characters long and differ from the current one.

    Raises:
        HttpError(400): wrong current password or a rejected new password.
    """
    account = request.auth
    old_secret, new_secret = data.current_password, data.new_password

    if not account.check_password(old_secret):
        raise HttpError(400, "当前密码错误")

    # Rules are reported in order; the first violated one wins.
    new_password_rules = (
        (len(new_secret) < 6, "新密码长度至少6位"),
        (len(new_secret) > 128, "新密码长度不能超过128位"),
        (old_secret == new_secret, "新密码不能与当前密码相同"),
    )
    for violated, message in new_password_rules:
        if violated:
            raise HttpError(400, message)

    account.set_password(new_secret)
    account.save()
    return MessageOut(message="密码已更新", success=True)
@router.post("/refresh", response=TokenOut)
def refresh_token(request, refresh_token: str):
    """Issue a new access token from a valid refresh token.

    The returned ``refresh_token`` is the submitted token re-encoded. It is
    not rotated here.

    NOTE(review): ``refresh_token`` is a plain ``str`` parameter, so ninja
    reads it from the query string. Tokens in URLs can end up in access logs;
    consider moving it into a request body schema.

    Raises:
        HttpError(401): the token is malformed, expired, or otherwise rejected
            by ninja_jwt. Other unexpected errors propagate instead of being
            disguised as an auth failure.
    """
    from ninja_jwt.exceptions import TokenError

    try:
        refresh = RefreshToken(refresh_token)
    except TokenError:
        raise HttpError(401, "刷新令牌无效或已过期") from None
    return TokenOut(
        access_token=str(refresh.access_token),
        refresh_token=str(refresh),
    )
def validate_password(password: str) -> None:
    """Reject passwords outside the accepted 6-128 character range.

    Raises:
        HttpError(400): the password is too short or too long.
    """
    length = len(password)
    if length < 6:
        raise HttpError(400, "密码长度至少6位")
    if length > 128:
        raise HttpError(400, "密码长度不能超过128位")
def validate_email(email: Optional[str]) -> None:
    """Reject syntactically invalid email addresses.

    Empty values (``None`` or ``""``) are accepted. Callers that require an
    email must check for presence themselves.

    The whole string must match the pattern. ``re.match`` with a trailing
    ``$`` anchor would also accept a value ending in a newline, because ``$``
    matches just before a final newline. ``re.fullmatch`` does not.

    Raises:
        HttpError(400): *email* is non-empty and not a plausible address.
    """
    if email:
        import re

        email_pattern = r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
        if not re.fullmatch(email_pattern, email):
            raise HttpError(400, "邮箱格式不正确")
@router.post("/register", response=TokenOut)
def register(request, data: RegisterIn):
    """Create a password account and return a fresh JWT pair.

    Checks run in this order: password length, email presence and format,
    username availability, email availability.

    Raises:
        HttpError(400): invalid password or email, missing email, or a
            username or email that is already taken.
    """
    from django.db.models import Q

    # Input validation.
    validate_password(data.password)
    # Email is required and must be well-formed.
    if not data.email:
        raise HttpError(400, "邮箱为必填项")
    validate_email(data.email)

    # ``login`` accepts a username or an email in the same field and prefers
    # the username on a clash. A new username equal to another account's
    # email would therefore capture that string and lock the existing user
    # out of email login, so treat it as taken.
    if User.objects.filter(Q(open_id=data.open_id) | Q(email=data.open_id)).exists():
        raise HttpError(400, "用户名已被使用")
    if User.objects.filter(email=data.email).exists():
        raise HttpError(400, "邮箱已被注册")

    user = User.objects.create_user(
        open_id=data.open_id,
        password=data.password,
        name=data.name or data.open_id,  # display name defaults to the username
        email=data.email,
        login_method="password",
    )
    refresh = RefreshToken.for_user(user)
    return TokenOut(
        access_token=str(refresh.access_token),
        refresh_token=str(refresh),
    )
@router.post("/login", response=TokenOut)
def login(request, data: LoginIn):
    """Authenticate with a username (``open_id``) or email plus a password.

    ``data.open_id`` is matched against both ``User.open_id`` and
    ``User.email``. If several accounts match, the one whose ``open_id``
    equals it is used. Unknown accounts, wrong passwords and inactive accounts
    all produce the same 401, so the response does not reveal which accounts
    exist.

    Raises:
        HttpError(401): authentication failed.
    """
    from django.db.models import Q

    identifier = data.open_id
    try:
        user = User.objects.get(Q(open_id=identifier) | Q(email=identifier))
    except User.DoesNotExist:
        user = None
    except User.MultipleObjectsReturned:
        # Several accounts match: prefer the exact username match.
        # ``.first()`` cannot raise MultipleObjectsReturned again.
        user = User.objects.filter(open_id=identifier).first()

    if user is None:
        # Run the password hasher anyway so an unknown account takes about as
        # long as a wrong password (the same mitigation Django's ModelBackend uses).
        User().set_password(data.password)
        raise HttpError(401, "账号或密码错误")

    # ``is_active`` is read defensively, as Django's ModelBackend does.
    # NOTE(review): confirm this flag's meaning on the project's User model.
    if not user.check_password(data.password) or not getattr(user, "is_active", True):
        raise HttpError(401, "账号或密码错误")

    refresh = RefreshToken.for_user(user)
    return TokenOut(
        access_token=str(refresh.access_token),
        refresh_token=str(refresh),
    )
@router.get("/oauth/url")
def get_oauth_url(request, redirect_uri: Optional[str] = None):
    """Build the provider authorization URL the client should redirect to.

    Args:
        redirect_uri: optional override. Defaults to
            ``settings.OAUTH_REDIRECT_URI``.

    Returns:
        ``{"url": ...}``: ``settings.OAUTH_AUTHORIZE_URL`` with ``client_id``,
        ``redirect_uri`` and ``response_type=code`` added to its query string.

    Raises:
        HttpError(500): OAuth settings are incomplete.
        HttpError(400): the redirect URI is not an absolute http(s) URL.
    """
    _require_oauth_config()
    redirect = redirect_uri or settings.OAUTH_REDIRECT_URI
    if not _is_valid_url(redirect):
        raise HttpError(400, "回调地址无效")

    # NOTE(review): oauth_callback always exchanges the code using
    # settings.OAUTH_REDIRECT_URI. RFC 6749 §4.1.3 requires that value to be
    # identical to the redirect_uri sent here, so a custom override will
    # likely fail at the token exchange. No ``state`` parameter (CSRF
    # protection) is sent either.
    params = urlencode(
        {
            "client_id": settings.OAUTH_CLIENT_ID,
            "redirect_uri": redirect,
            "response_type": "code",
        }
    )
    # Merge with any query string already present in the configured URL.
    # Blindly appending "?" would produce ".../authorize?a=b?client_id=...".
    authorize = urlparse(settings.OAUTH_AUTHORIZE_URL)
    query = f"{authorize.query}&{params}" if authorize.query else params
    oauth_url = authorize._replace(query=query).geturl()
    return {"url": oauth_url}
# Seconds to wait for each call to the OAuth provider before giving up.
_OAUTH_HTTP_TIMEOUT = 10


def _exchange_oauth_code(code: str) -> str:
    """Exchange an authorization code for the provider's access token.

    Raises:
        HttpError(400): the provider rejected the code or returned no token.
    """
    token_response = requests.post(
        settings.OAUTH_TOKEN_URL,
        data={
            "client_id": settings.OAUTH_CLIENT_ID,
            "client_secret": settings.OAUTH_CLIENT_SECRET,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": settings.OAUTH_REDIRECT_URI,
        },
        timeout=_OAUTH_HTTP_TIMEOUT,
    )
    if token_response.status_code != 200:
        raise HttpError(400, "OAuth token exchange failed")
    access_token = token_response.json().get("access_token")
    if not access_token:
        raise HttpError(400, "OAuth token exchange failed")
    return access_token


def _fetch_oauth_userinfo(access_token: str) -> dict:
    """Fetch the provider's user-info document for *access_token*.

    Raises:
        HttpError(400): the provider did not answer with HTTP 200.
    """
    user_response = requests.get(
        settings.OAUTH_USERINFO_URL,
        headers={"Authorization": f"Bearer {access_token}"},
        timeout=_OAUTH_HTTP_TIMEOUT,
    )
    if user_response.status_code != 200:
        raise HttpError(400, "Failed to get user info")
    return user_response.json()


@router.post("/oauth/callback", response=TokenOut)
def oauth_callback(request, data: OAuthCallbackIn):
    """Complete the OAuth authorization-code flow and return a JWT pair.

    The flow has four steps:

    1. Exchange ``data.code`` for a provider access token.
    2. Fetch the provider's user info.
    3. Create or update the local user keyed by the provider's ``sub``,
       falling back to ``id``.
    4. Issue JWTs for that user.

    Raises:
        HttpError(500): OAuth settings are incomplete, or an unexpected error
            occurred (network failure, timeout, malformed provider response).
        HttpError(400): the provider rejected the code or the user-info
            request, the user info has no subject identifier, or the subject
            collides with a password-registered account.
    """
    import logging

    try:
        _require_oauth_config()
        access_token = _exchange_oauth_code(data.code)
        user_info = _fetch_oauth_userinfo(access_token)

        subject = user_info.get("sub") or user_info.get("id")
        if subject is None or subject == "":
            # Without a subject, every such login would map to the same
            # open_id (None) and share a single account.
            raise HttpError(400, "Failed to get user info")
        open_id = str(subject)

        existing = User.objects.filter(open_id=open_id).first()
        if existing is not None and existing.login_method == "password":
            # ``register`` lets anyone choose an arbitrary open_id. Merging an
            # OAuth identity into such an account would log the OAuth user
            # into an account whose password someone else may know.
            raise HttpError(400, "OAuth 账号与已有密码账号冲突")

        # Create the user, or refresh its profile fields from the provider.
        user, _created = User.objects.update_or_create(
            open_id=open_id,
            defaults={
                "name": user_info.get("name"),
                "email": user_info.get("email"),
                "avatar": user_info.get("picture") or user_info.get("avatar"),
                "login_method": "oauth",
            },
        )

        refresh = RefreshToken.for_user(user)
        return TokenOut(
            access_token=str(refresh.access_token),
            refresh_token=str(refresh),
        )
    except HttpError:
        # Deliberate client/config errors keep their status and message
        # instead of being flattened into the generic 500 below.
        raise
    except Exception as exc:
        logging.getLogger(__name__).exception("OAuth callback failed")
        raise HttpError(500, "OAuth 登录失败") from exc
# Development-only endpoint for signing in without an OAuth provider.
@router.post("/dev/login", response=TokenOut)
def dev_login(request, open_id: str, name: Optional[str] = None):
    """Sign in as user *open_id*, creating it if needed, without credentials.

    Refuses with 403 unless ``settings.DEBUG`` is true. A newly created user
    gets *name* (or ``"Dev User <open_id>"``) as display name and
    ``login_method="dev"``. An existing user is returned unchanged, whatever
    its login method.
    """
    if settings.DEBUG:
        user = User.objects.get_or_create(
            open_id=open_id,
            defaults={"name": name or f"Dev User {open_id}", "login_method": "dev"},
        )[0]
        tokens = RefreshToken.for_user(user)
        return TokenOut(
            access_token=str(tokens.access_token),
            refresh_token=str(tokens),
        )
    raise HttpError(403, "Not available in production")