haha
This commit is contained in:
@@ -2,15 +2,17 @@
|
||||
User authentication API routes.
|
||||
"""
|
||||
from typing import Optional
|
||||
from ninja import Router
|
||||
from ninja import Router, Schema
|
||||
from ninja.errors import HttpError
|
||||
from ninja_jwt.authentication import JWTAuth
|
||||
from ninja_jwt.tokens import RefreshToken
|
||||
from django.conf import settings
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from urllib.parse import urlparse, urlencode
|
||||
import requests
|
||||
|
||||
from .models import User
|
||||
from .schemas import UserOut, UserUpdate, TokenOut, OAuthCallbackIn, MessageOut, RegisterIn, LoginIn
|
||||
from .schemas import UserOut, UserPrivateOut, UserUpdate, TokenOut, OAuthCallbackIn, MessageOut, RegisterIn, LoginIn
|
||||
|
||||
router = Router()
|
||||
|
||||
@@ -22,18 +24,44 @@ def get_current_user(request: HttpRequest) -> Optional[User]:
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/me", response=UserOut, auth=JWTAuth())
|
||||
def _is_valid_url(value: str) -> bool:
|
||||
if not value:
|
||||
return False
|
||||
parsed = urlparse(value)
|
||||
return parsed.scheme in {"http", "https"} and bool(parsed.netloc)
|
||||
|
||||
|
||||
def _require_oauth_config():
|
||||
if not settings.OAUTH_CLIENT_ID:
|
||||
raise HttpError(500, "OAuth 未配置客户端 ID")
|
||||
if not settings.OAUTH_AUTHORIZE_URL:
|
||||
raise HttpError(500, "OAuth 未配置授权地址")
|
||||
if not settings.OAUTH_TOKEN_URL:
|
||||
raise HttpError(500, "OAuth 未配置令牌地址")
|
||||
if not settings.OAUTH_USERINFO_URL:
|
||||
raise HttpError(500, "OAuth 未配置用户信息地址")
|
||||
if not _is_valid_url(settings.OAUTH_REDIRECT_URI):
|
||||
raise HttpError(500, "OAuth 回调地址无效")
|
||||
|
||||
|
||||
@router.get("/me", response=UserPrivateOut, auth=JWTAuth())
|
||||
def get_me(request):
|
||||
"""Get current user information."""
|
||||
return request.auth
|
||||
|
||||
|
||||
@router.patch("/me", response=UserOut, auth=JWTAuth())
|
||||
@router.patch("/me", response=UserPrivateOut, auth=JWTAuth())
|
||||
def update_me(request, data: UserUpdate):
|
||||
"""Update current user information."""
|
||||
user = request.auth
|
||||
|
||||
# 验证邮箱格式
|
||||
if data.email is not None:
|
||||
validate_email(data.email)
|
||||
|
||||
if data.name is not None:
|
||||
if len(data.name) > 50:
|
||||
raise HttpError(400, "名称不能超过50个字符")
|
||||
user.name = data.name
|
||||
if data.email is not None:
|
||||
user.email = data.email
|
||||
@@ -55,6 +83,36 @@ def logout(request):
|
||||
return MessageOut(message="已退出登录", success=True)
|
||||
|
||||
|
||||
class ChangePasswordIn(Schema):
|
||||
"""Change password input schema."""
|
||||
current_password: str
|
||||
new_password: str
|
||||
|
||||
|
||||
@router.post("/change-password", response=MessageOut, auth=JWTAuth())
|
||||
def change_password(request, data: ChangePasswordIn):
|
||||
"""Change current user's password."""
|
||||
user = request.auth
|
||||
|
||||
# 验证当前密码
|
||||
if not user.check_password(data.current_password):
|
||||
raise HttpError(400, "当前密码错误")
|
||||
|
||||
# 验证新密码
|
||||
if len(data.new_password) < 6:
|
||||
raise HttpError(400, "新密码长度至少6位")
|
||||
if len(data.new_password) > 128:
|
||||
raise HttpError(400, "新密码长度不能超过128位")
|
||||
if data.current_password == data.new_password:
|
||||
raise HttpError(400, "新密码不能与当前密码相同")
|
||||
|
||||
# 更新密码
|
||||
user.set_password(data.new_password)
|
||||
user.save()
|
||||
|
||||
return MessageOut(message="密码已更新", success=True)
|
||||
|
||||
|
||||
@router.post("/refresh", response=TokenOut)
|
||||
def refresh_token(request, refresh_token: str):
|
||||
"""Refresh access token using refresh token."""
|
||||
@@ -64,19 +122,50 @@ def refresh_token(request, refresh_token: str):
|
||||
access_token=str(refresh.access_token),
|
||||
refresh_token=str(refresh),
|
||||
)
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 401
|
||||
except Exception:
|
||||
raise HttpError(401, "刷新令牌无效或已过期")
|
||||
|
||||
|
||||
def validate_password(password: str) -> None:
|
||||
"""Validate password strength."""
|
||||
if len(password) < 6:
|
||||
raise HttpError(400, "密码长度至少6位")
|
||||
if len(password) > 128:
|
||||
raise HttpError(400, "密码长度不能超过128位")
|
||||
|
||||
|
||||
def validate_email(email: Optional[str]) -> None:
|
||||
"""Validate email format."""
|
||||
if email:
|
||||
import re
|
||||
email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
|
||||
if not re.match(email_pattern, email):
|
||||
raise HttpError(400, "邮箱格式不正确")
|
||||
|
||||
|
||||
@router.post("/register", response=TokenOut)
|
||||
def register(request, data: RegisterIn):
|
||||
"""Register new user with password."""
|
||||
# 输入验证
|
||||
validate_password(data.password)
|
||||
|
||||
# 邮箱必填且格式验证
|
||||
if not data.email:
|
||||
raise HttpError(400, "邮箱为必填项")
|
||||
validate_email(data.email)
|
||||
|
||||
# 检查用户名是否已存在
|
||||
if User.objects.filter(open_id=data.open_id).exists():
|
||||
return {"error": "账号已存在"}, 400
|
||||
raise HttpError(400, "用户名已被使用")
|
||||
|
||||
# 检查邮箱是否已存在
|
||||
if User.objects.filter(email=data.email).exists():
|
||||
raise HttpError(400, "邮箱已被注册")
|
||||
|
||||
user = User.objects.create_user(
|
||||
open_id=data.open_id,
|
||||
password=data.password,
|
||||
name=data.name,
|
||||
name=data.name or data.open_id, # 默认显示名称为用户名
|
||||
email=data.email,
|
||||
login_method="password",
|
||||
)
|
||||
@@ -89,13 +178,23 @@ def register(request, data: RegisterIn):
|
||||
|
||||
@router.post("/login", response=TokenOut)
|
||||
def login(request, data: LoginIn):
|
||||
"""Login with open_id and password."""
|
||||
"""Login with open_id or email and password."""
|
||||
from django.db.models import Q
|
||||
|
||||
# 支持用户名或邮箱登录
|
||||
try:
|
||||
user = User.objects.get(open_id=data.open_id)
|
||||
user = User.objects.get(Q(open_id=data.open_id) | Q(email=data.open_id))
|
||||
except User.DoesNotExist:
|
||||
return {"error": "账号或密码错误"}, 401
|
||||
raise HttpError(401, "账号或密码错误")
|
||||
except User.MultipleObjectsReturned:
|
||||
# 如果同时匹配多个用户,优先使用open_id匹配
|
||||
try:
|
||||
user = User.objects.get(open_id=data.open_id)
|
||||
except User.DoesNotExist:
|
||||
raise HttpError(401, "账号或密码错误")
|
||||
|
||||
if not user.check_password(data.password):
|
||||
return {"error": "账号或密码错误"}, 401
|
||||
raise HttpError(401, "账号或密码错误")
|
||||
refresh = RefreshToken.for_user(user)
|
||||
return TokenOut(
|
||||
access_token=str(refresh.access_token),
|
||||
@@ -106,27 +205,29 @@ def login(request, data: LoginIn):
|
||||
@router.get("/oauth/url")
|
||||
def get_oauth_url(request, redirect_uri: Optional[str] = None):
|
||||
"""Get OAuth authorization URL."""
|
||||
# This would integrate with Manus SDK or other OAuth provider
|
||||
client_id = settings.OAUTH_CLIENT_ID
|
||||
_require_oauth_config()
|
||||
redirect = redirect_uri or settings.OAUTH_REDIRECT_URI
|
||||
|
||||
# Example OAuth URL (adjust based on actual OAuth provider)
|
||||
oauth_url = f"https://oauth.example.com/authorize?client_id={client_id}&redirect_uri={redirect}&response_type=code"
|
||||
|
||||
if not _is_valid_url(redirect):
|
||||
raise HttpError(400, "回调地址无效")
|
||||
query = urlencode(
|
||||
{
|
||||
"client_id": settings.OAUTH_CLIENT_ID,
|
||||
"redirect_uri": redirect,
|
||||
"response_type": "code",
|
||||
}
|
||||
)
|
||||
oauth_url = f"{settings.OAUTH_AUTHORIZE_URL}?{query}"
|
||||
return {"url": oauth_url}
|
||||
|
||||
|
||||
@router.post("/oauth/callback", response=TokenOut)
|
||||
def oauth_callback(request, data: OAuthCallbackIn):
|
||||
"""Handle OAuth callback and create/update user."""
|
||||
# This would exchange the code for tokens with the OAuth provider
|
||||
# and create or update the user in the database
|
||||
|
||||
# Example implementation (adjust based on actual OAuth provider)
|
||||
try:
|
||||
_require_oauth_config()
|
||||
# Exchange code for access token
|
||||
token_response = requests.post(
|
||||
"https://oauth.example.com/token",
|
||||
settings.OAUTH_TOKEN_URL,
|
||||
data={
|
||||
"client_id": settings.OAUTH_CLIENT_ID,
|
||||
"client_secret": settings.OAUTH_CLIENT_SECRET,
|
||||
@@ -137,18 +238,18 @@ def oauth_callback(request, data: OAuthCallbackIn):
|
||||
)
|
||||
|
||||
if token_response.status_code != 200:
|
||||
return {"error": "OAuth token exchange failed"}, 400
|
||||
raise HttpError(400, "OAuth token exchange failed")
|
||||
|
||||
oauth_data = token_response.json()
|
||||
|
||||
# Get user info from OAuth provider
|
||||
user_response = requests.get(
|
||||
"https://oauth.example.com/userinfo",
|
||||
settings.OAUTH_USERINFO_URL,
|
||||
headers={"Authorization": f"Bearer {oauth_data['access_token']}"}
|
||||
)
|
||||
|
||||
if user_response.status_code != 200:
|
||||
return {"error": "Failed to get user info"}, 400
|
||||
raise HttpError(400, "Failed to get user info")
|
||||
|
||||
user_info = user_response.json()
|
||||
|
||||
@@ -171,8 +272,8 @@ def oauth_callback(request, data: OAuthCallbackIn):
|
||||
refresh_token=str(refresh),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 500
|
||||
except Exception:
|
||||
raise HttpError(500, "OAuth 登录失败")
|
||||
|
||||
|
||||
# Development endpoint for testing without OAuth
|
||||
@@ -180,7 +281,7 @@ def oauth_callback(request, data: OAuthCallbackIn):
|
||||
def dev_login(request, open_id: str, name: Optional[str] = None):
|
||||
"""Development login endpoint (disable in production)."""
|
||||
if not settings.DEBUG:
|
||||
return {"error": "Not available in production"}, 403
|
||||
raise HttpError(403, "Not available in production")
|
||||
|
||||
user, created = User.objects.get_or_create(
|
||||
open_id=open_id,
|
||||
|
||||
Reference in New Issue
Block a user