Files
ai_api_web/backend/routes/models.py
2026-01-22 18:26:47 +08:00

148 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from flask import Blueprint, request, jsonify
from models import db
from models.model import Model
from sqlalchemy import or_, and_
# Blueprint grouping the model-catalogue HTTP endpoints defined below.
# NOTE(review): the '' route on get_models implies the app registers this
# blueprint with a url_prefix — confirm at the registration site.
models_bp = Blueprint('models', __name__)
# Upper bound on page size so one request cannot pull an arbitrarily large
# result set. NOTE(review): raise this if the frontend legitimately needs
# bigger pages.
_MAX_PER_PAGE = 100


def _contains_filter(column, value):
    """Return a ``column LIKE '%value%'`` clause with wildcards in *value* escaped.

    The match is a plain substring search. It therefore finds a JSON-encoded
    list entry (``"value"``) as well as any other occurrence of the text; for
    example, ``gpt`` also matches ``gpt-4``. ``autoescape`` ensures ``%`` and
    ``_`` typed by the user are matched literally instead of acting as SQL
    wildcards.
    """
    return column.contains(value, autoescape=True)


@models_bp.route('', methods=['GET'])
def get_models():
    """Return active models, filtered, searched and paginated.

    Query parameters:
        page (int, default 1): 1-based page number; values < 1 become 1.
        per_page (int, default 20): page size, clamped to [1, _MAX_PER_PAGE].
        provider (str): exact provider match; '' or 'all' disables the filter.
        tags (str, repeatable): every given tag must occur in ``Model.tags``;
            omitted or containing 'all' disables the filter.
        token_group (str): must occur in ``Model.available_groups``;
            '' or 'all' disables the filter.
        billing_type, endpoint_type (str): exact match; '' or 'all' disables.
        search (str): substring match against name or description.
        show_recharge_price, show_multiplier ('true'/'false'): passed through
            as the two positional flags of ``Model.to_dict``.

    Returns:
        (JSON response, 200) with ``models`` (current page), ``pagination``
        and ``filters`` (option lists for the frontend filter controls).
    """
    import json

    args = request.args
    page = max(args.get('page', 1, type=int), 1)
    per_page = min(max(args.get('per_page', 20, type=int), 1), _MAX_PER_PAGE)
    provider = args.get('provider', '')
    tags = [tag for tag in args.getlist('tags') if tag]  # several tags allowed
    token_group = args.get('token_group', '')
    billing_type = args.get('billing_type', '')
    endpoint_type = args.get('endpoint_type', '')
    search_query = args.get('search', '')
    show_recharge_price = args.get('show_recharge_price', 'false').lower() == 'true'
    show_multiplier = args.get('show_multiplier', 'false').lower() == 'true'

    query = Model.query.filter_by(is_active=True)

    if provider and provider != 'all':
        query = query.filter(Model.provider == provider)

    if tags and 'all' not in tags:
        # One condition per tag, so a model must carry every requested tag.
        for tag in tags:
            query = query.filter(_contains_filter(Model.tags, tag))

    if token_group and token_group != 'all':
        query = query.filter(_contains_filter(Model.available_groups, token_group))

    if billing_type and billing_type != 'all':
        query = query.filter(Model.billing_type == billing_type)

    if endpoint_type and endpoint_type != 'all':
        query = query.filter(Model.endpoint_type == endpoint_type)

    if search_query:
        query = query.filter(or_(
            _contains_filter(Model.name, search_query),
            _contains_filter(Model.description, search_query),
        ))

    pagination = query.order_by(Model.created_at.desc()).paginate(
        page=page, per_page=per_page, error_out=False
    )

    # Filter options for the frontend; NULL/empty providers are not useful choices.
    provider_rows = db.session.query(Model.provider).filter_by(is_active=True).distinct().all()
    providers = [row[0] for row in provider_rows if row[0]]

    # Only the tags column is needed, so avoid loading full Model rows.
    all_tags = set()
    for (raw_tags,) in db.session.query(Model.tags).filter_by(is_active=True).all():
        if not raw_tags:
            continue
        try:
            parsed = json.loads(raw_tags)
        except (ValueError, TypeError):
            # Malformed tag data on one row should not fail the whole listing.
            continue
        if isinstance(parsed, list):
            # Keep only string tags so sorted() below cannot hit mixed types.
            all_tags.update(tag for tag in parsed if isinstance(tag, str))

    return jsonify({
        'models': [model.to_dict(show_recharge_price, show_multiplier) for model in pagination.items],
        'pagination': {
            'page': page,
            'per_page': per_page,
            'total': pagination.total,
            'pages': pagination.pages
        },
        'filters': {
            'providers': providers,
            'tags': sorted(all_tags),
            'billing_types': ['pay_as_you_go', 'per_request'],
            'endpoint_types': ['anthropic', 'openai']
        }
    }), 200
@models_bp.route('/<int:model_id>', methods=['GET'])
def get_model(model_id):
    """Return the detail view of a single model looked up by primary key.

    Both to_dict flags (recharge price and multiplier) are always switched on
    here. The lookup does not filter on ``is_active``, so inactive models are
    returned too. NOTE(review): confirm that is intended.

    Responds 404 via ``get_or_404`` when no model has ``model_id``.
    """
    found = Model.query.get_or_404(model_id)
    detail = found.to_dict(show_recharge_price=True, show_multiplier=True)
    payload = {'model': detail}
    return jsonify(payload), 200
@models_bp.route('/filters', methods=['GET'])
def get_filters():
    """Return the option lists used to populate the model filter controls.

    Returns:
        (JSON response, 200) with:
        providers: distinct non-empty providers of active models.
        tags: sorted, de-duplicated string tags parsed from the JSON ``tags``
            column of active models (rows with malformed data are skipped).
        groups: distinct non-empty ``Token.group`` values across all tokens.
        billing_types / endpoint_types: fixed ``{'value', 'label'}`` lists.
    """
    import json
    # Imported locally, presumably to avoid a circular import with the
    # models package — confirm before hoisting to module level.
    from models.token import Token

    # All groups, taken from the Token table.
    group_rows = db.session.query(Token.group).distinct().all()
    groups = [row[0] for row in group_rows if row[0]]

    # All providers of active models; skip NULL/empty entries like groups above.
    provider_rows = db.session.query(Model.provider).filter_by(is_active=True).distinct().all()
    providers = [row[0] for row in provider_rows if row[0]]

    # All tags. Only the tags column is fetched, not full Model rows.
    all_tags = set()
    for (raw_tags,) in db.session.query(Model.tags).filter_by(is_active=True).all():
        if not raw_tags:
            continue
        try:
            parsed = json.loads(raw_tags)
        except (ValueError, TypeError):
            # Malformed tag data on one row should not fail the whole request.
            continue
        if isinstance(parsed, list):
            # Keep only string tags so sorted() below cannot hit mixed types.
            all_tags.update(tag for tag in parsed if isinstance(tag, str))

    return jsonify({
        'providers': providers,
        'tags': sorted(all_tags),
        'groups': groups,
        'billing_types': [
            {'value': 'pay_as_you_go', 'label': '按量计费'},
            {'value': 'per_request', 'label': '按次计费'}
        ],
        'endpoint_types': [
            {'value': 'anthropic', 'label': 'anthropic'},
            {'value': 'openai', 'label': 'openai'}
        ]
    }), 200