Files
shop-platform/update_table_structure.py

309 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import re
import os
# 读取文件内容
def read_file(file_path):
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
# 解析表结构,保留原始格式
def parse_table_structure(create_table_sql):
"""解析CREATE TABLE语句返回表结构信息保留原始格式"""
table_info = {}
# 提取完整表名(包括反引号)
table_name_pattern = re.compile(r'CREATE TABLE (.*?)\s*\(', re.IGNORECASE)
table_name_match = table_name_pattern.search(create_table_sql)
if table_name_match:
table_info['full_table_name'] = table_name_match.group(1).strip()
# 提取不带反引号的表名用于比较
bare_table_name = table_info['full_table_name'].strip('`')
table_info['table_name'] = bare_table_name
# 提取列信息,保留完整行格式
columns = []
# 找到括号内的内容
start_pos = create_table_sql.find('(') + 1
end_pos = create_table_sql.rfind(')')
table_body = create_table_sql[start_pos:end_pos]
# 处理每一行
lines = table_body.split('\n')
for line in lines:
line = line.strip()
# 跳过主键、索引等
if line.startswith('PRIMARY KEY') or line.startswith('UNIQUE KEY') or line.startswith('KEY') or line.startswith('CONSTRAINT'):
continue
if not line or line == ')':
continue
# 提取列名(带反引号或不带)
# 匹配第一个空格前的内容作为列名
column_name_end = line.find(' ')
if column_name_end == -1:
continue
column_name = line[:column_name_end]
# 提取不带反引号的列名用于比较
bare_column_name = column_name.strip('`')
columns.append({
'name': bare_column_name,
'full_name': column_name,
'full_line': line
})
table_info['columns'] = columns
table_info['create_sql'] = create_table_sql
return table_info
# 从niushop_database.sql中提取列定义用于更新
def extract_column_definitions(sql_content):
"""从SQL文件中提取列定义用于比较"""
column_definitions = {}
# 先匹配所有CREATE TABLE语句正确处理表结构
# 使用更严格的匹配,确保只匹配到真正的表定义
create_table_pattern = re.compile(r'CREATE TABLE `?(\w+)`?\s*\(([\s\S]*?)\)\s*(?:ENGINE|CHARACTER|COLLATE|COMMENT|;)', re.IGNORECASE | re.DOTALL)
create_table_matches = create_table_pattern.findall(sql_content)
for table_name, table_body in create_table_matches:
columns = {}
# 处理表体中的列定义
lines = table_body.split('\n')
for line in lines:
line = line.strip()
# 跳过主键、索引等
if line.startswith('PRIMARY KEY') or line.startswith('UNIQUE KEY') or line.startswith('KEY') or line.startswith('CONSTRAINT'):
continue
if not line or line == ')':
continue
# 跳过ALTER TABLE语句
if line.startswith('ALTER TABLE'):
continue
# 跳过ENGINE等表选项
if line.startswith('ENGINE') or line.startswith('CHARACTER') or line.startswith('COLLATE') or line.startswith('COMMENT'):
continue
# 跳过CREATE TABLE语句
if line.startswith('CREATE TABLE'):
continue
# 提取列名
column_name_end = line.find(' ')
if column_name_end == -1:
continue
column_name = line[:column_name_end].strip('`')
# 提取完整列定义
# 去掉末尾的逗号
col_def = line
if col_def.endswith(','):
col_def = col_def[:-1]
columns[column_name] = col_def.strip()
column_definitions[table_name] = columns
# 处理单独的ALTER TABLE语句提取添加的列
alter_table_pattern = re.compile(r'ALTER TABLE `?(\w+)`?\s*ADD COLUMN\s*(.*?);', re.IGNORECASE | re.DOTALL)
alter_table_matches = alter_table_pattern.findall(sql_content)
for table_name, column_def in alter_table_matches:
if table_name in column_definitions:
# 提取列名
column_name_end = column_def.find(' ')
if column_name_end != -1:
column_name = column_def[:column_name_end].strip('`')
# 确保列定义是有效的
if not column_def.startswith('ALTER TABLE') and not column_def.startswith('ENGINE'):
column_definitions[table_name][column_name] = column_def.strip()
return column_definitions
# 从init.sql中提取表结构
def extract_init_table_structures(sql_content):
"""从init.sql中提取表结构保留完整格式"""
table_structures = {}
# 匹配所有CREATE TABLE语句
create_table_pattern = re.compile(r'(DROP TABLE IF EXISTS `?(\w+)`?;[\s\S]*?CREATE TABLE `?(\w+)`?\s*\(([\s\S]*?)\)\s*ENGINE\s*=[\s\S]*?;)\s*', re.IGNORECASE | re.DOTALL)
create_table_matches = create_table_pattern.findall(sql_content)
for full_sql, drop_table_name, create_table_name, table_body in create_table_matches:
# 提取列信息
columns = []
lines = table_body.split('\n')
for line in lines:
line = line.strip()
# 跳过主键、索引等
if line.startswith('PRIMARY KEY') or line.startswith('UNIQUE KEY') or line.startswith('KEY') or line.startswith('CONSTRAINT'):
continue
if not line or line == ')':
continue
# 提取列名
column_name_end = line.find(' ')
if column_name_end == -1:
continue
column_name = line[:column_name_end].strip('`')
# 提取完整列定义行
columns.append({
'name': column_name,
'full_line': line
})
table_structures[create_table_name] = {
'full_sql': full_sql,
'columns': columns
}
return table_structures
# 更新init.sql中的表结构
def update_table_structure(init_table, niushop_columns):
"""更新表结构,只添加缺失的列,保留原有格式"""
init_columns_dict = {col['name']: col for col in init_table['columns']}
# 检查是否有缺失的列
missing_columns = []
for col_name, niushop_col_def in niushop_columns.items():
if col_name not in init_columns_dict:
missing_columns.append(niushop_col_def)
if not missing_columns:
return init_table['full_sql'] # 没有变化返回原SQL
# 构建新的CREATE TABLE语句
create_sql = init_table['full_sql']
# 找到最后一个列定义行的位置
# 查找最后一个列定义行在PRIMARY KEY之前
last_col_pos = create_sql.rfind('PRIMARY KEY')
if last_col_pos == -1:
last_col_pos = create_sql.rfind(')')
# 查找最后一个列定义行的结束位置
# 从last_col_pos往前找换行符
newline_pos = create_sql.rfind('\n', 0, last_col_pos)
if newline_pos == -1:
newline_pos = create_sql.rfind('\n')
# 提取前面的内容
prefix = create_sql[:newline_pos].rstrip()
suffix = create_sql[newline_pos:]
# 添加缺失的列
for col_def in missing_columns:
# 保持与原有列相同的缩进
# 从原有列中获取缩进
if init_table['columns']:
first_col_line = init_table['columns'][0]['full_line']
indent = first_col_line[:len(first_col_line) - len(first_col_line.lstrip())]
else:
indent = ' '
# 添加新列,带缩进和逗号
prefix += f'\n{indent}{col_def},'
# 重新组合CREATE TABLE语句
new_create_sql = prefix + suffix
return new_create_sql
# 主函数
def main():
# 文件路径
niushop_file = './docs/db/niushop_database.sql'
init_file = './docker/mysql/init/init.sql'
upgrade_file = './upgrade.sql'
# 读取文件内容
print("正在读取文件内容...")
niushop_content = read_file(niushop_file)
init_content = read_file(init_file)
# 提取niushop_database.sql中的列定义
print("正在提取niushop_database.sql中的列定义...")
niushop_columns = extract_column_definitions(niushop_content)
print(f"从niushop_database.sql提取到 {len(niushop_columns)} 个表的列定义")
# 提取init.sql中的表结构
print("正在提取init.sql中的表结构...")
init_tables = extract_init_table_structures(init_content)
print(f"从init.sql提取到 {len(init_tables)} 个表结构")
# 比较表结构差异,生成更新
print("正在比较表结构差异...")
upgrade_statements = []
updated_init_content = init_content
# 遍历init.sql中的所有表
for init_table_name, init_table in init_tables.items():
print(f"正在处理表 {init_table_name}...")
# 去掉lucky_前缀与niushop_database.sql中的表名进行比较
original_table_name = init_table_name
if original_table_name.startswith('lucky_'):
original_table_name = original_table_name[6:] # 去掉lucky_前缀
# 检查niushop_database.sql中是否有对应的表
if original_table_name in niushop_columns:
print(f"找到对应的表 {original_table_name}...")
niushop_cols = niushop_columns[original_table_name]
# 检查缺失的列
init_cols_dict = {col['name']: col for col in init_table['columns']}
missing_columns = []
for col_name, niushop_col_def in niushop_cols.items():
if col_name not in init_cols_dict:
missing_columns.append({
'name': col_name,
'definition': niushop_col_def
})
# 如果有缺失的列生成ALTER语句并更新init.sql
if missing_columns:
print(f"{init_table_name} 缺少 {len(missing_columns)} 个列...")
# 生成ALTER TABLE语句
for col in missing_columns:
# 提取列名(带反引号)
col_name_in_def = col['definition'].split(' ')[0]
alter_stmt = f"ALTER TABLE `{init_table_name}` ADD COLUMN {col['definition']};"
upgrade_statements.append(alter_stmt)
# 更新init.sql中的表结构
new_create_sql = update_table_structure(init_table, niushop_cols)
updated_init_content = updated_init_content.replace(init_table['full_sql'], new_create_sql)
# 写入升级脚本
print("正在写入升级脚本...")
with open(upgrade_file, 'w', encoding='utf-8') as f:
f.write("-- 数据库升级脚本\n")
f.write("-- 生成时间: 自动生成\n")
f.write("-- 描述: 根据niushop_database.sql更新init.sql的表结构\n\n")
f.write("USE shop_mallnew;\n\n")
for stmt in upgrade_statements:
f.write(f"{stmt}\n")
# 更新init.sql文件
print("正在更新init.sql文件...")
with open(init_file, 'w', encoding='utf-8') as f:
f.write(updated_init_content)
print(f"升级脚本已生成,共 {len(upgrade_statements)} 条ALTER语句")
print(f"升级脚本路径: {upgrade_file}")
if __name__ == '__main__':
main()