Files
ai_web/backend/apps/users/api.py

300 lines
9.4 KiB
Python
Raw Normal View History

2026-01-27 18:15:43 +08:00
"""
User authentication API routes.
"""
from typing import Optional
2026-01-28 16:00:56 +08:00
from ninja import Router, Schema
from ninja.errors import HttpError
2026-01-27 18:15:43 +08:00
from ninja_jwt.authentication import JWTAuth
from ninja_jwt.tokens import RefreshToken
from django.conf import settings
from django.http import HttpRequest, HttpResponse
2026-01-28 16:00:56 +08:00
from urllib.parse import urlparse, urlencode
2026-01-27 18:15:43 +08:00
import requests
from .models import User
2026-01-28 16:00:56 +08:00
from .schemas import UserOut, UserPrivateOut, UserUpdate, TokenOut, OAuthCallbackIn, MessageOut, RegisterIn, LoginIn
2026-01-27 18:15:43 +08:00
# Router that every endpoint in this module registers on via the
# @router.<method> decorators below (presumably mounted on the project's
# NinjaAPI elsewhere — confirm the URL prefix there).
router = Router()
def get_current_user(request: HttpRequest) -> Optional[User]:
    """Return the authenticated user attached to *request*, or ``None``.

    Django Ninja auth classes place their result on ``request.auth``; a
    missing attribute or a falsy value is treated as "not logged in".
    """
    authenticated = getattr(request, "auth", None)
    return authenticated or None
def _is_valid_url(value: str) -> bool:
if not value:
return False
parsed = urlparse(value)
return parsed.scheme in {"http", "https"} and bool(parsed.netloc)
def _require_oauth_config():
    """Abort with HTTP 500 unless the OAuth settings this flow needs are set.

    Settings are checked in a fixed order and the first problem found is
    reported, so each error message names exactly one missing value.
    """
    required_settings = (
        ("OAUTH_CLIENT_ID", "OAuth 未配置客户端 ID"),
        ("OAUTH_AUTHORIZE_URL", "OAuth 未配置授权地址"),
        ("OAUTH_TOKEN_URL", "OAuth 未配置令牌地址"),
        ("OAUTH_USERINFO_URL", "OAuth 未配置用户信息地址"),
    )
    for setting_name, error_message in required_settings:
        if not getattr(settings, setting_name):
            raise HttpError(500, error_message)

    # The redirect URI must be a full http(s) URL, not merely non-empty.
    if not _is_valid_url(settings.OAUTH_REDIRECT_URI):
        raise HttpError(500, "OAuth 回调地址无效")
@router.get("/me", response=UserPrivateOut, auth=JWTAuth())
def get_me(request):
    """Return the caller's own profile, serialized with ``UserPrivateOut``.

    The route is guarded by ``JWTAuth``, so ``request.auth`` holds the
    authenticated user when this runs.
    """
    me = request.auth
    return me
@router.patch("/me", response=UserPrivateOut, auth=JWTAuth())
def update_me(request, data: UserUpdate):
    """Partially update the authenticated user's profile.

    Only fields supplied with a non-``None`` value in *data* are changed.
    All validation runs before anything is assigned, so a rejected request
    leaves the user object untouched.

    Raises:
        HttpError: 400 if the email is malformed or already used by another
            account, or if the name exceeds 50 characters.
    """
    user = request.auth

    # Validate the email format.
    if data.email is not None:
        validate_email(data.email)
        # Registration refuses duplicate emails; enforce the same rule here,
        # otherwise email-based login could match several accounts.
        if data.email and User.objects.filter(email=data.email).exclude(pk=user.pk).exists():
            raise HttpError(400, "邮箱已被注册")

    if data.name is not None and len(data.name) > 50:
        raise HttpError(400, "名称不能超过50个字符")

    if data.name is not None:
        user.name = data.name
    if data.email is not None:
        user.email = data.email
    if data.avatar is not None:
        user.avatar = data.avatar
    user.save()
    return user
@router.post("/logout", response=MessageOut, auth=JWTAuth())
def logout(request, response: HttpResponse = None):
    """Log the current user out.

    JWTs are stateless: nothing is revoked server-side, so the client must
    discard its stored tokens. Any ``access_token`` / ``refresh_token``
    cookies are also expired on the outgoing response.

    Args:
        request: The authenticated request.
        response: Django Ninja's temporal response object, injected because
            the parameter is annotated ``HttpResponse``. Cookie changes made
            on it are carried into the response that is actually sent. The
            ``None`` default keeps direct ``logout(request)`` calls working.
    """
    # Cookies must be deleted on the injected response. A freshly built
    # HttpResponse would simply be discarded, because the MessageOut returned
    # below is what gets serialized.
    if response is not None:
        response.delete_cookie("access_token")
        response.delete_cookie("refresh_token")
    return MessageOut(message="已退出登录", success=True)
class ChangePasswordIn(Schema):
    """Request body for ``POST /change-password``."""
    current_password: str  # must match the user's existing password
    new_password: str  # 6-128 characters and different from current_password
@router.post("/change-password", response=MessageOut, auth=JWTAuth())
def change_password(request, data: ChangePasswordIn):
    """Replace the authenticated user's password.

    The current password must be confirmed first. The new password must be
    6-128 characters long and must differ from the current one.

    Raises:
        HttpError: 400 for a wrong current password or an unacceptable new one.
    """
    account = request.auth
    old_pw, new_pw = data.current_password, data.new_password

    # Confirm the caller knows the existing password before anything else.
    if not account.check_password(old_pw):
        raise HttpError(400, "当前密码错误")

    new_len = len(new_pw)
    if new_len < 6:
        raise HttpError(400, "新密码长度至少6位")
    if new_len > 128:
        raise HttpError(400, "新密码长度不能超过128位")
    if new_pw == old_pw:
        raise HttpError(400, "新密码不能与当前密码相同")

    account.set_password(new_pw)
    account.save()
    return MessageOut(message="密码已更新", success=True)
@router.post("/refresh", response=TokenOut)
def refresh_token(request, refresh_token: str):
    """Issue a new access token from a refresh token.

    Args:
        refresh_token: The encoded refresh token. NOTE(review): as a scalar,
            non-path parameter, Ninja should read this from the query string,
            where it may end up in access logs — consider moving it to the body.

    Returns:
        A ``TokenOut`` with the new access token and the same refresh token.

    Raises:
        HttpError: 401 if the token is malformed, expired, or otherwise
            rejected by ninja_jwt.
    """
    from ninja_jwt.exceptions import TokenError

    # Only token-validation failures mean "invalid or expired". Any other
    # exception is a genuine bug and should not be disguised as a 401.
    try:
        refresh = RefreshToken(refresh_token)
    except TokenError as exc:
        raise HttpError(401, "刷新令牌无效或已过期") from exc
    return TokenOut(
        access_token=str(refresh.access_token),
        refresh_token=str(refresh),
    )
def validate_password(password: str) -> None:
    """Reject passwords shorter than 6 or longer than 128 characters.

    Raises:
        HttpError: 400 describing which length bound was violated.
    """
    size = len(password)
    problem = None
    if size < 6:
        problem = "密码长度至少6位"
    elif size > 128:
        problem = "密码长度不能超过128位"
    if problem is not None:
        raise HttpError(400, problem)
def validate_email(email: Optional[str]) -> None:
    """Check that a non-empty *email* has a basic ``local@domain.tld`` shape.

    Empty values (``None`` or ``""``) are accepted without checks, so callers
    that require an address must test for emptiness themselves.

    Raises:
        HttpError: 400 if a non-empty value does not match the pattern.
    """
    if not email:
        return
    import re

    # fullmatch instead of match + "$": "$" also matches just before a
    # trailing newline, which let values like "a@b.com\n" through.
    pattern = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
    if re.fullmatch(pattern, email) is None:
        raise HttpError(400, "邮箱格式不正确")
@router.post("/register", response=TokenOut)
def register(request, data: RegisterIn):
    """Register a new password-based account and return a JWT pair.

    Checks run in this order: password length, then email presence and
    format, then username (``open_id``) and email availability. The display
    name defaults to the username when ``data.name`` is empty.

    Raises:
        HttpError: 400 for invalid input or an already-taken username/email.
    """
    from django.db import IntegrityError, transaction

    validate_password(data.password)
    # Email is mandatory for password accounts and must look like an address.
    if not data.email:
        raise HttpError(400, "邮箱为必填项")
    validate_email(data.email)

    # Field-specific messages for the common duplicate cases.
    if User.objects.filter(open_id=data.open_id).exists():
        raise HttpError(400, "用户名已被使用")
    if User.objects.filter(email=data.email).exists():
        raise HttpError(400, "邮箱已被注册")

    try:
        # The savepoint keeps a failed INSERT from breaking an enclosing
        # transaction (e.g. when ATOMIC_REQUESTS is enabled).
        with transaction.atomic():
            user = User.objects.create_user(
                open_id=data.open_id,
                password=data.password,
                name=data.name or data.open_id,  # display name defaults to the username
                email=data.email,
                login_method="password",
            )
    except IntegrityError as exc:
        # A concurrent request can take the same username/email between the
        # checks above and this insert. If the table enforces uniqueness
        # (assumed — confirm on the User model), report that as 400, not 500.
        raise HttpError(400, "用户名或邮箱已被使用") from exc

    refresh = RefreshToken.for_user(user)
    return TokenOut(
        access_token=str(refresh.access_token),
        refresh_token=str(refresh),
    )
@router.post("/login", response=TokenOut)
def login(request, data: LoginIn):
    """Authenticate with a username (``open_id``) or email plus password.

    ``data.open_id`` is matched against both ``open_id`` and ``email``. If it
    matches more than one account, the exact ``open_id`` match wins. Every
    failure uses the same 401 message, so callers cannot tell a missing
    account from a wrong password.

    Raises:
        HttpError: 401 for an unknown account, a wrong password, or an
            inactive account.
    """
    from django.db.models import Q

    identifier = data.open_id
    try:
        user = User.objects.get(Q(open_id=identifier) | Q(email=identifier))
    except User.DoesNotExist:
        user = None
    except User.MultipleObjectsReturned:
        # e.g. one user's username equals another user's email: prefer the
        # username match.
        user = User.objects.filter(open_id=identifier).first()

    if user is None:
        # Run the password hasher anyway, so an unknown account takes about as
        # long as a wrong password (same mitigation as Django's ModelBackend).
        User().set_password(data.password)
        raise HttpError(401, "账号或密码错误")

    # Deactivated accounts are refused with the same generic message.
    if not user.check_password(data.password) or not getattr(user, "is_active", True):
        raise HttpError(401, "账号或密码错误")

    refresh = RefreshToken.for_user(user)
    return TokenOut(
        access_token=str(refresh.access_token),
        refresh_token=str(refresh),
    )
def _append_query(url: str, params: dict) -> str:
    """Return *url* with *params* URL-encoded into its query string.

    Any query parameters and fragment already in *url* are preserved, so an
    endpoint configured as ``https://idp/authorize?tenant=x`` still yields a
    well-formed URL.
    """
    from urllib.parse import urlsplit, urlunsplit

    parts = urlsplit(url)
    extra = urlencode(params)
    query = f"{parts.query}&{extra}" if parts.query else extra
    return urlunsplit(parts._replace(query=query))


@router.get("/oauth/url")
def get_oauth_url(request, redirect_uri: Optional[str] = None):
    """Build the provider authorization URL the client should navigate to.

    Args:
        redirect_uri: Optional callback URL; falls back to
            ``settings.OAUTH_REDIRECT_URI``. Must be an absolute http(s) URL.

    Returns:
        ``{"url": <authorization URL>}``.

    Raises:
        HttpError: 500 if OAuth is not configured; 400 if the redirect URI is
            not a valid http(s) URL.
    """
    _require_oauth_config()

    redirect = redirect_uri or settings.OAUTH_REDIRECT_URI
    if not _is_valid_url(redirect):
        raise HttpError(400, "回调地址无效")

    # NOTE(review): the /oauth/callback handler always exchanges the code
    # using settings.OAUTH_REDIRECT_URI, so a different redirect_uri here will
    # likely be rejected by the provider at token exchange (OAuth 2.0 requires
    # the two to match). There is also no `state` parameter for CSRF
    # protection. Consider an allow-list plus state handling.
    oauth_url = _append_query(
        settings.OAUTH_AUTHORIZE_URL,
        {
            "client_id": settings.OAUTH_CLIENT_ID,
            "redirect_uri": redirect,
            "response_type": "code",
        },
    )
    return {"url": oauth_url}
def _exchange_oauth_code(code: str) -> str:
    """Trade an authorization code for the provider's access token.

    Raises:
        HttpError: 400 if the token endpoint answers non-200 or omits
            ``access_token``.
    """
    token_response = requests.post(
        settings.OAUTH_TOKEN_URL,
        data={
            "client_id": settings.OAUTH_CLIENT_ID,
            "client_secret": settings.OAUTH_CLIENT_SECRET,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": settings.OAUTH_REDIRECT_URI,
        },
        # Without a timeout a stalled provider would block this worker
        # indefinitely; overridable via settings.OAUTH_HTTP_TIMEOUT.
        timeout=getattr(settings, "OAUTH_HTTP_TIMEOUT", 10),
    )
    if token_response.status_code != 200:
        raise HttpError(400, "OAuth token exchange failed")
    access_token = token_response.json().get("access_token")
    if not access_token:
        raise HttpError(400, "OAuth token exchange failed")
    return access_token


def _fetch_oauth_userinfo(access_token: str) -> dict:
    """Fetch the user-info document for *access_token* from the provider.

    Raises:
        HttpError: 400 if the user-info endpoint answers non-200.
    """
    user_response = requests.get(
        settings.OAUTH_USERINFO_URL,
        headers={"Authorization": f"Bearer {access_token}"},
        timeout=getattr(settings, "OAUTH_HTTP_TIMEOUT", 10),
    )
    if user_response.status_code != 200:
        raise HttpError(400, "Failed to get user info")
    return user_response.json()


@router.post("/oauth/callback", response=TokenOut)
def oauth_callback(request, data: OAuthCallbackIn):
    """Complete the OAuth code flow, create/update the user and return JWTs.

    The provider's ``sub`` (or ``id``) becomes the user's ``open_id``. Name,
    email and avatar are refreshed from the provider on every login.

    Raises:
        HttpError: 500 if OAuth is unconfigured or an unexpected error occurs.
            400 if the provider rejects the code, the user-info call fails,
            the user info has no id, or that id already belongs to a local
            (non-OAuth) account.
    """
    try:
        _require_oauth_config()
        access_token = _exchange_oauth_code(data.code)
        user_info = _fetch_oauth_userinfo(access_token)

        raw_id = user_info.get("sub") or user_info.get("id")
        if raw_id in (None, ""):
            # Without an id, update_or_create(open_id=None) would merge every
            # such login into one shared account.
            raise HttpError(400, "OAuth user info has no id")
        open_id = str(raw_id)

        # open_id is shared with password and dev accounts. A provider id that
        # equals a local username must not log the caller into that account.
        existing = User.objects.filter(open_id=open_id).first()
        if existing is not None and existing.login_method != "oauth":
            raise HttpError(400, "该用户名已被本地账号占用，无法通过 OAuth 登录")

        user, _ = User.objects.update_or_create(
            open_id=open_id,
            defaults={
                "name": user_info.get("name"),
                "email": user_info.get("email"),
                "avatar": user_info.get("picture") or user_info.get("avatar"),
                "login_method": "oauth",
            },
        )

        refresh = RefreshToken.for_user(user)
        return TokenOut(
            access_token=str(refresh.access_token),
            refresh_token=str(refresh),
        )
    except HttpError:
        # Deliberate errors raised above already carry the right status and
        # message; don't collapse them into the generic 500 below.
        raise
    except Exception as exc:
        raise HttpError(500, "OAuth 登录失败") from exc
# Development-only shortcut for testing without an OAuth provider.
@router.post("/dev/login", response=TokenOut)
def dev_login(request, open_id: str, name: Optional[str] = None):
    """Sign in as (or create) a user by ``open_id``; only served when DEBUG is on.

    An existing user with *open_id* is reused as-is. Otherwise one is created
    with login method ``"dev"`` and either *name* or a generated
    ``"Dev User <open_id>"`` display name.
    """
    if not settings.DEBUG:
        raise HttpError(403, "Not available in production")

    new_user_fields = {
        "name": name or f"Dev User {open_id}",
        "login_method": "dev",
    }
    user, _created = User.objects.get_or_create(open_id=open_id, defaults=new_user_fields)

    tokens = RefreshToken.for_user(user)
    return TokenOut(access_token=str(tokens.access_token), refresh_token=str(tokens))