309 lines
12 KiB
Python
309 lines
12 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
import re
|
||
import os
|
||
|
||
# 读取文件内容
|
||
def read_file(file_path):
|
||
with open(file_path, 'r', encoding='utf-8') as f:
|
||
return f.read()
|
||
|
||
# 解析表结构,保留原始格式
|
||
def parse_table_structure(create_table_sql):
|
||
"""解析CREATE TABLE语句,返回表结构信息,保留原始格式"""
|
||
table_info = {}
|
||
|
||
# 提取完整表名(包括反引号)
|
||
table_name_pattern = re.compile(r'CREATE TABLE (.*?)\s*\(', re.IGNORECASE)
|
||
table_name_match = table_name_pattern.search(create_table_sql)
|
||
if table_name_match:
|
||
table_info['full_table_name'] = table_name_match.group(1).strip()
|
||
# 提取不带反引号的表名用于比较
|
||
bare_table_name = table_info['full_table_name'].strip('`')
|
||
table_info['table_name'] = bare_table_name
|
||
|
||
# 提取列信息,保留完整行格式
|
||
columns = []
|
||
|
||
# 找到括号内的内容
|
||
start_pos = create_table_sql.find('(') + 1
|
||
end_pos = create_table_sql.rfind(')')
|
||
table_body = create_table_sql[start_pos:end_pos]
|
||
|
||
# 处理每一行
|
||
lines = table_body.split('\n')
|
||
for line in lines:
|
||
line = line.strip()
|
||
# 跳过主键、索引等
|
||
if line.startswith('PRIMARY KEY') or line.startswith('UNIQUE KEY') or line.startswith('KEY') or line.startswith('CONSTRAINT'):
|
||
continue
|
||
if not line or line == ')':
|
||
continue
|
||
|
||
# 提取列名(带反引号或不带)
|
||
# 匹配第一个空格前的内容作为列名
|
||
column_name_end = line.find(' ')
|
||
if column_name_end == -1:
|
||
continue
|
||
column_name = line[:column_name_end]
|
||
|
||
# 提取不带反引号的列名用于比较
|
||
bare_column_name = column_name.strip('`')
|
||
|
||
columns.append({
|
||
'name': bare_column_name,
|
||
'full_name': column_name,
|
||
'full_line': line
|
||
})
|
||
|
||
table_info['columns'] = columns
|
||
table_info['create_sql'] = create_table_sql
|
||
|
||
return table_info
|
||
|
||
# 从niushop_database.sql中提取列定义,用于更新
|
||
|
||
def extract_column_definitions(sql_content):
|
||
"""从SQL文件中提取列定义,用于比较"""
|
||
column_definitions = {}
|
||
|
||
# 先匹配所有CREATE TABLE语句,正确处理表结构
|
||
# 使用更严格的匹配,确保只匹配到真正的表定义
|
||
create_table_pattern = re.compile(r'CREATE TABLE `?(\w+)`?\s*\(([\s\S]*?)\)\s*(?:ENGINE|CHARACTER|COLLATE|COMMENT|;)', re.IGNORECASE | re.DOTALL)
|
||
create_table_matches = create_table_pattern.findall(sql_content)
|
||
|
||
for table_name, table_body in create_table_matches:
|
||
columns = {}
|
||
|
||
# 处理表体中的列定义
|
||
lines = table_body.split('\n')
|
||
for line in lines:
|
||
line = line.strip()
|
||
# 跳过主键、索引等
|
||
if line.startswith('PRIMARY KEY') or line.startswith('UNIQUE KEY') or line.startswith('KEY') or line.startswith('CONSTRAINT'):
|
||
continue
|
||
if not line or line == ')':
|
||
continue
|
||
# 跳过ALTER TABLE语句
|
||
if line.startswith('ALTER TABLE'):
|
||
continue
|
||
# 跳过ENGINE等表选项
|
||
if line.startswith('ENGINE') or line.startswith('CHARACTER') or line.startswith('COLLATE') or line.startswith('COMMENT'):
|
||
continue
|
||
# 跳过CREATE TABLE语句
|
||
if line.startswith('CREATE TABLE'):
|
||
continue
|
||
|
||
# 提取列名
|
||
column_name_end = line.find(' ')
|
||
if column_name_end == -1:
|
||
continue
|
||
column_name = line[:column_name_end].strip('`')
|
||
|
||
# 提取完整列定义
|
||
# 去掉末尾的逗号
|
||
col_def = line
|
||
if col_def.endswith(','):
|
||
col_def = col_def[:-1]
|
||
|
||
columns[column_name] = col_def.strip()
|
||
|
||
column_definitions[table_name] = columns
|
||
|
||
# 处理单独的ALTER TABLE语句,提取添加的列
|
||
alter_table_pattern = re.compile(r'ALTER TABLE `?(\w+)`?\s*ADD COLUMN\s*(.*?);', re.IGNORECASE | re.DOTALL)
|
||
alter_table_matches = alter_table_pattern.findall(sql_content)
|
||
|
||
for table_name, column_def in alter_table_matches:
|
||
if table_name in column_definitions:
|
||
# 提取列名
|
||
column_name_end = column_def.find(' ')
|
||
if column_name_end != -1:
|
||
column_name = column_def[:column_name_end].strip('`')
|
||
# 确保列定义是有效的
|
||
if not column_def.startswith('ALTER TABLE') and not column_def.startswith('ENGINE'):
|
||
column_definitions[table_name][column_name] = column_def.strip()
|
||
|
||
return column_definitions
|
||
|
||
# 从init.sql中提取表结构
|
||
def extract_init_table_structures(sql_content):
|
||
"""从init.sql中提取表结构,保留完整格式"""
|
||
table_structures = {}
|
||
|
||
# 匹配所有CREATE TABLE语句
|
||
create_table_pattern = re.compile(r'(DROP TABLE IF EXISTS `?(\w+)`?;[\s\S]*?CREATE TABLE `?(\w+)`?\s*\(([\s\S]*?)\)\s*ENGINE\s*=[\s\S]*?;)\s*', re.IGNORECASE | re.DOTALL)
|
||
create_table_matches = create_table_pattern.findall(sql_content)
|
||
|
||
for full_sql, drop_table_name, create_table_name, table_body in create_table_matches:
|
||
# 提取列信息
|
||
columns = []
|
||
lines = table_body.split('\n')
|
||
for line in lines:
|
||
line = line.strip()
|
||
# 跳过主键、索引等
|
||
if line.startswith('PRIMARY KEY') or line.startswith('UNIQUE KEY') or line.startswith('KEY') or line.startswith('CONSTRAINT'):
|
||
continue
|
||
if not line or line == ')':
|
||
continue
|
||
|
||
# 提取列名
|
||
column_name_end = line.find(' ')
|
||
if column_name_end == -1:
|
||
continue
|
||
column_name = line[:column_name_end].strip('`')
|
||
|
||
# 提取完整列定义行
|
||
columns.append({
|
||
'name': column_name,
|
||
'full_line': line
|
||
})
|
||
|
||
table_structures[create_table_name] = {
|
||
'full_sql': full_sql,
|
||
'columns': columns
|
||
}
|
||
|
||
return table_structures
|
||
|
||
# 更新init.sql中的表结构
|
||
def update_table_structure(init_table, niushop_columns):
|
||
"""更新表结构,只添加缺失的列,保留原有格式"""
|
||
init_columns_dict = {col['name']: col for col in init_table['columns']}
|
||
|
||
# 检查是否有缺失的列
|
||
missing_columns = []
|
||
for col_name, niushop_col_def in niushop_columns.items():
|
||
if col_name not in init_columns_dict:
|
||
missing_columns.append(niushop_col_def)
|
||
|
||
if not missing_columns:
|
||
return init_table['full_sql'] # 没有变化,返回原SQL
|
||
|
||
# 构建新的CREATE TABLE语句
|
||
create_sql = init_table['full_sql']
|
||
|
||
# 找到最后一个列定义行的位置
|
||
# 查找最后一个列定义行(在PRIMARY KEY之前)
|
||
last_col_pos = create_sql.rfind('PRIMARY KEY')
|
||
if last_col_pos == -1:
|
||
last_col_pos = create_sql.rfind(')')
|
||
|
||
# 查找最后一个列定义行的结束位置
|
||
# 从last_col_pos往前找换行符
|
||
newline_pos = create_sql.rfind('\n', 0, last_col_pos)
|
||
if newline_pos == -1:
|
||
newline_pos = create_sql.rfind('\n')
|
||
|
||
# 提取前面的内容
|
||
prefix = create_sql[:newline_pos].rstrip()
|
||
suffix = create_sql[newline_pos:]
|
||
|
||
# 添加缺失的列
|
||
for col_def in missing_columns:
|
||
# 保持与原有列相同的缩进
|
||
# 从原有列中获取缩进
|
||
if init_table['columns']:
|
||
first_col_line = init_table['columns'][0]['full_line']
|
||
indent = first_col_line[:len(first_col_line) - len(first_col_line.lstrip())]
|
||
else:
|
||
indent = ' '
|
||
|
||
# 添加新列,带缩进和逗号
|
||
prefix += f'\n{indent}{col_def},'
|
||
|
||
# 重新组合CREATE TABLE语句
|
||
new_create_sql = prefix + suffix
|
||
|
||
return new_create_sql
|
||
|
||
# 主函数
|
||
def main():
|
||
# 文件路径
|
||
niushop_file = './docs/db/niushop_database.sql'
|
||
init_file = './docker/mysql/init/init.sql'
|
||
upgrade_file = './upgrade.sql'
|
||
|
||
# 读取文件内容
|
||
print("正在读取文件内容...")
|
||
niushop_content = read_file(niushop_file)
|
||
init_content = read_file(init_file)
|
||
|
||
# 提取niushop_database.sql中的列定义
|
||
print("正在提取niushop_database.sql中的列定义...")
|
||
niushop_columns = extract_column_definitions(niushop_content)
|
||
print(f"从niushop_database.sql提取到 {len(niushop_columns)} 个表的列定义")
|
||
|
||
# 提取init.sql中的表结构
|
||
print("正在提取init.sql中的表结构...")
|
||
init_tables = extract_init_table_structures(init_content)
|
||
print(f"从init.sql提取到 {len(init_tables)} 个表结构")
|
||
|
||
# 比较表结构差异,生成更新
|
||
print("正在比较表结构差异...")
|
||
upgrade_statements = []
|
||
updated_init_content = init_content
|
||
|
||
# 遍历init.sql中的所有表
|
||
for init_table_name, init_table in init_tables.items():
|
||
print(f"正在处理表 {init_table_name}...")
|
||
|
||
# 去掉lucky_前缀,与niushop_database.sql中的表名进行比较
|
||
original_table_name = init_table_name
|
||
if original_table_name.startswith('lucky_'):
|
||
original_table_name = original_table_name[6:] # 去掉lucky_前缀
|
||
|
||
# 检查niushop_database.sql中是否有对应的表
|
||
if original_table_name in niushop_columns:
|
||
print(f"找到对应的表 {original_table_name}...")
|
||
niushop_cols = niushop_columns[original_table_name]
|
||
|
||
# 检查缺失的列
|
||
init_cols_dict = {col['name']: col for col in init_table['columns']}
|
||
missing_columns = []
|
||
for col_name, niushop_col_def in niushop_cols.items():
|
||
if col_name not in init_cols_dict:
|
||
missing_columns.append({
|
||
'name': col_name,
|
||
'definition': niushop_col_def
|
||
})
|
||
|
||
# 如果有缺失的列,生成ALTER语句并更新init.sql
|
||
if missing_columns:
|
||
print(f"表 {init_table_name} 缺少 {len(missing_columns)} 个列...")
|
||
|
||
# 生成ALTER TABLE语句
|
||
for col in missing_columns:
|
||
# 提取列名(带反引号)
|
||
col_name_in_def = col['definition'].split(' ')[0]
|
||
alter_stmt = f"ALTER TABLE `{init_table_name}` ADD COLUMN {col['definition']};"
|
||
upgrade_statements.append(alter_stmt)
|
||
|
||
# 更新init.sql中的表结构
|
||
new_create_sql = update_table_structure(init_table, niushop_cols)
|
||
updated_init_content = updated_init_content.replace(init_table['full_sql'], new_create_sql)
|
||
|
||
# 写入升级脚本
|
||
print("正在写入升级脚本...")
|
||
with open(upgrade_file, 'w', encoding='utf-8') as f:
|
||
f.write("-- 数据库升级脚本\n")
|
||
f.write("-- 生成时间: 自动生成\n")
|
||
f.write("-- 描述: 根据niushop_database.sql更新init.sql的表结构\n\n")
|
||
f.write("USE shop_mallnew;\n\n")
|
||
|
||
for stmt in upgrade_statements:
|
||
f.write(f"{stmt}\n")
|
||
|
||
# 更新init.sql文件
|
||
print("正在更新init.sql文件...")
|
||
with open(init_file, 'w', encoding='utf-8') as f:
|
||
f.write(updated_init_content)
|
||
|
||
print(f"升级脚本已生成,共 {len(upgrade_statements)} 条ALTER语句")
|
||
print(f"升级脚本路径: {upgrade_file}")
|
||
|
||
if __name__ == '__main__':
|
||
main()
|
||
|