341 lines
14 KiB
Python
341 lines
14 KiB
Python
"""
|
||
检查数据库结构和模型类是否一致的 Django 管理命令
|
||
"""
|
||
from django.core.management.base import BaseCommand
|
||
from django.db import connection
|
||
from django.apps import apps
|
||
from django.core.management.color import no_style
|
||
from django.db import models
|
||
import sys
|
||
|
||
|
||
class Command(BaseCommand):
|
||
help = '检查数据库结构和模型类是否一致'
|
||
|
||
def add_arguments(self, parser):
|
||
parser.add_argument(
|
||
'--app',
|
||
type=str,
|
||
help='指定要检查的应用名称(如 User, finance, business)',
|
||
)
|
||
parser.add_argument(
|
||
'--model',
|
||
type=str,
|
||
help='指定要检查的模型名称',
|
||
)
|
||
parser.add_argument(
|
||
'--fix',
|
||
action='store_true',
|
||
help='尝试修复不一致的问题(生成迁移文件)',
|
||
)
|
||
|
||
def handle(self, *args, **options):
|
||
self.stdout.write(self.style.SUCCESS('开始检查数据库结构和模型类...\n'))
|
||
|
||
app_name = options.get('app')
|
||
model_name = options.get('model')
|
||
fix_mode = options.get('fix', False)
|
||
|
||
# 获取所有应用
|
||
if app_name:
|
||
apps_to_check = [apps.get_app_config(app_name)]
|
||
else:
|
||
apps_to_check = [apps.get_app_config('User'),
|
||
apps.get_app_config('finance'),
|
||
apps.get_app_config('business')]
|
||
|
||
issues = []
|
||
warnings = []
|
||
|
||
with connection.cursor() as cursor:
|
||
# 获取数据库名(只获取一次,提高性能)
|
||
cursor.execute("SELECT DATABASE()")
|
||
db_name = cursor.fetchone()[0]
|
||
if not db_name:
|
||
self.stdout.write(self.style.ERROR('无法获取数据库名,检查终止'))
|
||
return None
|
||
|
||
for app_config in apps_to_check:
|
||
self.stdout.write(self.style.WARNING(f'\n检查应用: {app_config.name}'))
|
||
self.stdout.write('=' * 80)
|
||
|
||
# 获取应用中的所有模型
|
||
models_to_check = app_config.get_models()
|
||
|
||
if model_name:
|
||
models_to_check = [m for m in models_to_check if m.__name__ == model_name]
|
||
|
||
for model in models_to_check:
|
||
self.stdout.write(f'\n检查模型: {model.__name__}')
|
||
self.stdout.write('-' * 80)
|
||
|
||
# 获取数据库表名
|
||
db_table = model._meta.db_table
|
||
|
||
# 检查表是否存在
|
||
table_exists = self._check_table_exists(cursor, db_name, db_table)
|
||
|
||
if not table_exists:
|
||
issue = {
|
||
'type': 'missing_table',
|
||
'app': app_config.name,
|
||
'model': model.__name__,
|
||
'table': db_table,
|
||
'message': f'表 {db_table} 不存在于数据库中'
|
||
}
|
||
issues.append(issue)
|
||
self.stdout.write(self.style.ERROR(f' ❌ 表 {db_table} 不存在'))
|
||
continue
|
||
|
||
self.stdout.write(self.style.SUCCESS(f' ✅ 表 {db_table} 存在'))
|
||
|
||
# 获取数据库表的字段信息
|
||
db_fields = self._get_db_fields(cursor, db_name, db_table)
|
||
|
||
# 获取模型类的字段信息
|
||
model_fields = self._get_model_fields(model)
|
||
|
||
# 检查字段
|
||
field_issues = self._compare_fields(model, db_table, model_fields, db_fields)
|
||
issues.extend(field_issues)
|
||
|
||
# 检查多对多关系表
|
||
m2m_issues = self._check_m2m_tables(cursor, db_name, model)
|
||
issues.extend(m2m_issues)
|
||
|
||
# 输出结果
|
||
self.stdout.write('\n' + '=' * 80)
|
||
self.stdout.write(self.style.SUCCESS('\n检查完成!\n'))
|
||
|
||
if issues:
|
||
self.stdout.write(self.style.ERROR(f'发现 {len(issues)} 个问题:\n'))
|
||
for i, issue in enumerate(issues, 1):
|
||
self.stdout.write(f'{i}. [{issue["type"]}] {issue["message"]}')
|
||
if 'details' in issue:
|
||
for detail in issue['details']:
|
||
self.stdout.write(f' - {detail}')
|
||
else:
|
||
self.stdout.write(self.style.SUCCESS('✅ 数据库结构和模型类完全一致!'))
|
||
|
||
if warnings:
|
||
self.stdout.write(self.style.WARNING(f'\n警告 ({len(warnings)} 个):\n'))
|
||
for i, warning in enumerate(warnings, 1):
|
||
self.stdout.write(f'{i}. {warning}')
|
||
|
||
# 如果发现问题且开启了修复模式
|
||
if issues and fix_mode:
|
||
self.stdout.write(self.style.WARNING('\n尝试生成迁移文件...'))
|
||
from django.core.management import call_command
|
||
try:
|
||
call_command('makemigrations', verbosity=0)
|
||
self.stdout.write(self.style.SUCCESS('迁移文件已生成,请运行 python manage.py migrate 应用迁移'))
|
||
except Exception as e:
|
||
self.stdout.write(self.style.ERROR(f'生成迁移文件失败: {str(e)}'))
|
||
|
||
# 返回 None 而不是整数,避免 Django 的 execute 方法报错
|
||
return None
|
||
|
||
def _check_table_exists(self, cursor, db_name, table_name):
|
||
"""检查表是否存在"""
|
||
try:
|
||
cursor.execute("""
|
||
SELECT COUNT(*)
|
||
FROM information_schema.tables
|
||
WHERE table_schema = %s
|
||
AND table_name = %s
|
||
""", [db_name, table_name])
|
||
return cursor.fetchone()[0] > 0
|
||
except Exception as e:
|
||
# 静默处理错误,避免输出过多
|
||
return False
|
||
|
||
def _get_db_fields(self, cursor, db_name, table_name):
|
||
"""获取数据库表的字段信息"""
|
||
try:
|
||
cursor.execute("""
|
||
SELECT
|
||
COLUMN_NAME,
|
||
DATA_TYPE,
|
||
CHARACTER_MAXIMUM_LENGTH,
|
||
IS_NULLABLE,
|
||
COLUMN_DEFAULT,
|
||
COLUMN_TYPE
|
||
FROM information_schema.COLUMNS
|
||
WHERE TABLE_SCHEMA = %s
|
||
AND TABLE_NAME = %s
|
||
ORDER BY ORDINAL_POSITION
|
||
""", [db_name, table_name])
|
||
|
||
fields = {}
|
||
for row in cursor.fetchall():
|
||
fields[row[0]] = {
|
||
'type': row[1],
|
||
'max_length': row[2],
|
||
'nullable': row[3] == 'YES',
|
||
'default': row[4],
|
||
'column_type': row[5]
|
||
}
|
||
return fields
|
||
except Exception as e:
|
||
self.stdout.write(self.style.ERROR(f' ❌ 获取表 {table_name} 的字段信息失败: {str(e)}'))
|
||
return {}
|
||
|
||
def _get_model_fields(self, model):
|
||
"""获取模型类的字段信息"""
|
||
fields = {}
|
||
# 使用 get_fields(include_parents=False) 只获取当前模型的字段
|
||
# 使用 get_fields() 会包含反向关系,需要过滤
|
||
for field in model._meta.get_fields(include_parents=False):
|
||
# 跳过多对多关系(单独处理)
|
||
if isinstance(field, models.ManyToManyField):
|
||
continue
|
||
|
||
# 跳过反向关系(related_name 定义的字段,如 approver_teams)
|
||
# auto_created=True 表示是 Django 自动创建的反向关系
|
||
if getattr(field, 'auto_created', False):
|
||
continue
|
||
|
||
# 跳过反向外键关系(如 Department.user_set)
|
||
# 如果字段没有 column 属性,说明不是数据库字段
|
||
if not hasattr(field, 'column'):
|
||
continue
|
||
|
||
# 只处理有 column 属性的字段(数据库字段)
|
||
field_name = field.column
|
||
|
||
fields[field_name] = {
|
||
'field': field,
|
||
'name': field.name,
|
||
'type': type(field).__name__,
|
||
'null': getattr(field, 'null', False),
|
||
'blank': getattr(field, 'blank', False),
|
||
'default': getattr(field, 'default', models.NOT_PROVIDED),
|
||
'max_length': getattr(field, 'max_length', None),
|
||
}
|
||
|
||
return fields
|
||
|
||
def _compare_fields(self, model, table_name, model_fields, db_fields):
|
||
"""比较模型字段和数据库字段"""
|
||
issues = []
|
||
|
||
# 检查模型中的字段是否在数据库中存在
|
||
for field_name, field_info in model_fields.items():
|
||
if field_name not in db_fields:
|
||
issue = {
|
||
'type': 'missing_field',
|
||
'app': model._meta.app_label,
|
||
'model': model.__name__,
|
||
'table': table_name,
|
||
'field': field_name,
|
||
'message': f'模型字段 {field_name} 在数据库表 {table_name} 中不存在'
|
||
}
|
||
issues.append(issue)
|
||
self.stdout.write(self.style.ERROR(f' ❌ 字段 {field_name} 在数据库中不存在'))
|
||
else:
|
||
# 检查字段类型是否匹配
|
||
db_field = db_fields[field_name]
|
||
field_type_issue = self._check_field_type(field_info, db_field, field_name)
|
||
if field_type_issue:
|
||
issues.append(field_type_issue)
|
||
self.stdout.write(self.style.WARNING(f' ⚠️ 字段 {field_name} 类型可能不匹配'))
|
||
else:
|
||
self.stdout.write(self.style.SUCCESS(f' ✅ 字段 {field_name} 匹配'))
|
||
|
||
# 检查数据库中是否有模型中没有的字段(可能是遗留字段)
|
||
model_field_names = set(model_fields.keys())
|
||
db_field_names = set(db_fields.keys())
|
||
extra_db_fields = db_field_names - model_field_names
|
||
|
||
# 排除 id 字段(Django 自动添加)
|
||
extra_db_fields.discard('id')
|
||
|
||
if extra_db_fields:
|
||
issue = {
|
||
'type': 'extra_field',
|
||
'app': model._meta.app_label,
|
||
'model': model.__name__,
|
||
'table': table_name,
|
||
'fields': list(extra_db_fields),
|
||
'message': f'数据库表 {table_name} 中有模型中没有的字段: {", ".join(extra_db_fields)}'
|
||
}
|
||
issues.append(issue)
|
||
self.stdout.write(self.style.WARNING(f' ⚠️ 数据库中有额外字段: {", ".join(extra_db_fields)}'))
|
||
|
||
return issues
|
||
|
||
def _check_field_type(self, model_field_info, db_field_info, field_name):
|
||
"""检查字段类型是否匹配"""
|
||
field = model_field_info['field']
|
||
field_type = model_field_info['type']
|
||
db_type = db_field_info['type'].upper()
|
||
|
||
# 类型映射
|
||
type_mapping = {
|
||
'CharField': 'VARCHAR',
|
||
'TextField': 'TEXT',
|
||
'IntegerField': 'INT',
|
||
'BigIntegerField': 'BIGINT',
|
||
'BooleanField': 'TINYINT',
|
||
'DateField': 'DATE',
|
||
'DateTimeField': 'DATETIME',
|
||
'ForeignKey': 'BIGINT', # 外键在数据库中通常是 BIGINT
|
||
}
|
||
|
||
expected_db_type = type_mapping.get(field_type, None)
|
||
|
||
if expected_db_type and db_type not in expected_db_type:
|
||
# 特殊处理:BooleanField 在 MySQL 中可能是 TINYINT(1)
|
||
if field_type == 'BooleanField' and 'TINYINT' in db_type:
|
||
return None
|
||
|
||
# 特殊处理:CharField 和 VARCHAR
|
||
if field_type == 'CharField' and 'VARCHAR' in db_type:
|
||
return None
|
||
|
||
# 特殊处理:TextField 和 TEXT/LONGTEXT
|
||
if field_type == 'TextField' and 'TEXT' in db_type:
|
||
return None
|
||
|
||
return {
|
||
'type': 'type_mismatch',
|
||
'app': field.model._meta.app_label,
|
||
'model': field.model.__name__,
|
||
'field': field_name,
|
||
'expected': expected_db_type,
|
||
'actual': db_type,
|
||
'message': f'字段 {field_name} 类型不匹配: 模型期望 {expected_db_type}, 数据库是 {db_type}'
|
||
}
|
||
|
||
return None
|
||
|
||
def _check_m2m_tables(self, cursor, db_name, model):
|
||
"""检查多对多关系表"""
|
||
issues = []
|
||
|
||
for field in model._meta.get_fields():
|
||
if isinstance(field, models.ManyToManyField):
|
||
# 获取多对多关系表名
|
||
try:
|
||
m2m_table = field.remote_field.through._meta.db_table
|
||
except AttributeError:
|
||
# 如果是自动创建的多对多关系,跳过
|
||
continue
|
||
|
||
# 检查表是否存在
|
||
if not self._check_table_exists(cursor, db_name, m2m_table):
|
||
issue = {
|
||
'type': 'missing_m2m_table',
|
||
'app': model._meta.app_label,
|
||
'model': model.__name__,
|
||
'table': m2m_table,
|
||
'field': field.name,
|
||
'message': f'多对多关系表 {m2m_table} 不存在'
|
||
}
|
||
issues.append(issue)
|
||
self.stdout.write(self.style.ERROR(f' ❌ 多对多关系表 {m2m_table} 不存在'))
|
||
else:
|
||
self.stdout.write(self.style.SUCCESS(f' ✅ 多对多关系表 {m2m_table} 存在'))
|
||
|
||
return issues
|