Files
shop-platform/update_table_structure.py

451 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import re
import os
# Read a whole file as text.
def read_file(file_path):
    """Return the complete contents of ``file_path`` decoded as UTF-8."""
    with open(file_path, mode='r', encoding='utf-8') as handle:
        contents = handle.read()
    return contents
# Parse one CREATE TABLE statement while keeping the original line text.
def _closing_paren_of_table_body(sql, start):
    """Return the index of the ')' matching the '(' that precedes ``start``.

    Parentheses inside single-quoted, double-quoted or backtick-quoted text
    are ignored, so ``int(11)`` nests correctly and ``COMMENT 'a)b'`` does not
    close the body. Returns -1 when no matching parenthesis exists.
    """
    depth = 1
    quote = ''
    index = start
    length = len(sql)
    while index < length:
        char = sql[index]
        if quote:
            # A backslash escapes the next character inside string literals
            # (but not inside backtick-quoted identifiers).
            if char == '\\' and quote != '`':
                index += 2
                continue
            if char == quote:
                quote = ''
        elif char in ("'", '"', '`'):
            quote = char
        elif char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
            if depth == 0:
                return index
        index += 1
    return -1


def parse_table_structure(create_table_sql):
    """Parse a CREATE TABLE statement into a dict, preserving line text.

    Args:
        create_table_sql: Text of a single ``CREATE TABLE`` statement.

    Returns:
        A dict with:
          * ``full_table_name`` / ``table_name`` -- the table name as written
            and with surrounding backticks removed (both keys are absent when
            no ``CREATE TABLE ... (`` header is found);
          * ``columns`` -- one dict per column line with ``name`` (backticks
            stripped), ``full_name`` (as written) and ``full_line`` (the
            stripped line, trailing comma included);
          * ``create_sql`` -- the input, unchanged.
    """
    table_info = {}

    # An optional IF NOT EXISTS clause is consumed here so that it does not
    # end up inside the captured table name.
    header_pattern = re.compile(
        r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(.*?)\s*\(',
        re.IGNORECASE,
    )
    header_match = header_pattern.search(create_table_sql)
    if header_match:
        full_table_name = header_match.group(1).strip()
        table_info['full_table_name'] = full_table_name
        table_info['table_name'] = full_table_name.strip('`')
        body_start = header_match.end()
    else:
        body_start = create_table_sql.find('(') + 1

    # Use the parenthesis that really closes the column list; the last ')' in
    # the statement may belong to a table option such as COMMENT='x(y)'.
    body_end = _closing_paren_of_table_body(create_table_sql, body_start)
    if body_end == -1:
        body_end = create_table_sql.rfind(')')
    table_body = create_table_sql[body_start:body_end]

    # Lines starting with these prefixes are keys/constraints, not columns.
    skipped_prefixes = ('PRIMARY KEY', 'UNIQUE KEY', 'KEY', 'CONSTRAINT')
    columns = []
    for raw_line in table_body.split('\n'):
        line = raw_line.strip()
        if not line or line == ')' or line.startswith(skipped_prefixes):
            continue
        # Everything before the first space is the column name as written.
        full_name, separator, _ = line.partition(' ')
        if not separator:
            continue
        columns.append({
            'name': full_name.strip('`'),
            'full_name': full_name,
            'full_line': line,
        })

    table_info['columns'] = columns
    table_info['create_sql'] = create_table_sql
    return table_info
# Collect column definitions from niushop_database.sql for later comparison.

# Lines inside a CREATE TABLE body that are not column definitions: keys,
# indexes, constraints and stray statement/table-option keywords. The trailing
# \b keeps unquoted column names such as KEYWORDS from being skipped.
_NON_COLUMN_LINE = re.compile(
    r'(?:PRIMARY KEY|UNIQUE KEY|UNIQUE INDEX|FOREIGN KEY|KEY|INDEX|FULLTEXT'
    r'|SPATIAL|CONSTRAINT|ALTER TABLE|CREATE TABLE|ENGINE|CHARACTER|COLLATE'
    r'|COMMENT)\b'
)


def _find_table_body_end(sql, start):
    """Return the index of the ')' matching the '(' that precedes ``start``.

    Quoted text (single quotes, double quotes, backticks) is skipped so that
    parentheses inside comments or defaults do not affect nesting. Returns -1
    when the parentheses are unbalanced.
    """
    depth = 1
    quote = ''
    index = start
    length = len(sql)
    while index < length:
        char = sql[index]
        if quote:
            # Backslash escapes the following character in string literals.
            if char == '\\' and quote != '`':
                index += 2
                continue
            if char == quote:
                quote = ''
        elif char in ("'", '"', '`'):
            quote = char
        elif char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
            if depth == 0:
                return index
        index += 1
    return -1


def extract_column_definitions(sql_content):
    """Extract per-table column definitions from a SQL dump.

    Every ``CREATE TABLE`` statement contributes one entry; columns added by
    later ``ALTER TABLE ... ADD COLUMN ...;`` statements are merged into the
    entry of a table that was already seen.

    Args:
        sql_content: Full text of the SQL file.

    Returns:
        ``{table_name: {column_name: definition}}`` where names have their
        backticks removed and ``definition`` is the stripped column line
        without its trailing comma.
    """
    column_definitions = {}

    # Locate only the statement header with a regex; the end of the body is
    # found by balancing parentheses. A lazy "...) ENGINE|CHARACTER|COMMENT"
    # regex would stop early at a column such as
    # "varchar(255) CHARACTER SET utf8mb4" and lose all following columns.
    header_pattern = re.compile(
        r'CREATE TABLE\s+(?:IF NOT EXISTS\s+)?`?(\w+)`?\s*\(',
        re.IGNORECASE,
    )
    for header in header_pattern.finditer(sql_content):
        body_start = header.end()
        body_end = _find_table_body_end(sql_content, body_start)
        if body_end == -1:
            # Unbalanced parentheses: the statement cannot be parsed reliably.
            continue

        columns = {}
        for raw_line in sql_content[body_start:body_end].split('\n'):
            line = raw_line.strip()
            if not line or line == ')' or _NON_COLUMN_LINE.match(line):
                continue
            # The column name is everything before the first space.
            name_token, separator, _ = line.partition(' ')
            if not separator:
                continue
            definition = line[:-1] if line.endswith(',') else line
            columns[name_token.strip('`')] = definition.strip()
        column_definitions[header.group(1)] = columns

    # Stand-alone ALTER TABLE statements that add a column.
    alter_pattern = re.compile(
        r'ALTER TABLE `?(\w+)`?\s*ADD COLUMN\s*(.*?);',
        re.IGNORECASE | re.DOTALL,
    )
    for table_name, column_def in alter_pattern.findall(sql_content):
        if table_name not in column_definitions:
            continue
        name_token, separator, _ = column_def.partition(' ')
        if not separator:
            continue
        # Guard against captures that are not a column definition at all.
        if column_def.startswith(('ALTER TABLE', 'ENGINE')):
            continue
        column_definitions[table_name][name_token.strip('`')] = column_def.strip()

    return column_definitions
# Pull the table structures out of init.sql.
def extract_init_table_structures(sql_content):
    """Extract table structures from init.sql, keeping each statement's text.

    Only ``DROP TABLE IF EXISTS ...;`` followed by ``CREATE TABLE ... (...)
    ENGINE=...;`` is recognised.

    Args:
        sql_content: Full text of the SQL file.

    Returns:
        ``{table_name: {'full_sql': str, 'columns': [...]}}`` keyed by the
        name in the CREATE TABLE clause. ``full_sql`` spans from the DROP
        statement to the terminating semicolon; each column is a dict with
        ``name`` (backticks removed) and ``full_line`` (the stripped line).
    """
    statement_pattern = re.compile(
        r'(DROP TABLE IF EXISTS `?(\w+)`?;[\s\S]*?CREATE TABLE `?(\w+)`?\s*\(([\s\S]*?)\)\s*ENGINE\s*=[\s\S]*?;)\s*',
        re.IGNORECASE | re.DOTALL,
    )
    # Key and constraint lines are not columns.
    skipped_prefixes = ('PRIMARY KEY', 'UNIQUE KEY', 'KEY', 'CONSTRAINT')

    structures = {}
    for found in statement_pattern.finditer(sql_content):
        whole_statement = found.group(1)
        created_name = found.group(3)
        body = found.group(4)

        column_entries = []
        for raw_line in body.split('\n'):
            text = raw_line.strip()
            if text.startswith(skipped_prefixes) or not text or text == ')':
                continue
            # The column name is whatever precedes the first space.
            name_part, separator, _ = text.partition(' ')
            if not separator:
                continue
            column_entries.append({
                'name': name_part.strip('`'),
                'full_line': text,
            })

        structures[created_name] = {
            'full_sql': whole_statement,
            'columns': column_entries,
        }
    return structures
# Add the columns that init.sql is missing to one of its table definitions.
def _existing_column_indent(init_table):
    """Return the leading whitespace used by the table's column lines.

    The indent is taken from the first column's ``full_line``. When that line
    carries no leading whitespace (for example because it was stripped while
    being extracted), the whitespace in front of the same text inside
    ``full_sql`` is used instead. A single space is returned for a table
    without columns.
    """
    columns = init_table['columns']
    if not columns:
        return ' '
    first_line = columns[0]['full_line']
    indent = first_line[:len(first_line) - len(first_line.lstrip())]
    if indent or not first_line:
        return indent

    full_sql = init_table['full_sql']
    position = full_sql.find(first_line)
    if position == -1:
        return ''
    line_start = full_sql.rfind('\n', 0, position) + 1
    leading = full_sql[line_start:position]
    return leading if leading.isspace() else ''


def update_table_structure(init_table, niushop_columns):
    """Return the table's SQL with the missing columns added.

    Existing text is left untouched; new column lines are inserted just
    before the last ``PRIMARY KEY`` occurrence, or before the final closing
    parenthesis when the statement has no ``PRIMARY KEY``.

    Args:
        init_table: Dict with ``full_sql`` (statement text) and ``columns``
            (list of dicts with ``name`` and ``full_line``).
        niushop_columns: ``{column_name: definition}`` of the reference
            table; definitions carry no trailing comma.

    Returns:
        The original ``full_sql`` when no column is missing (or when no
        insertion point can be found), otherwise the updated statement.
    """
    create_sql = init_table['full_sql']
    known_names = {col['name'] for col in init_table['columns']}
    missing_columns = [
        col_def
        for col_name, col_def in niushop_columns.items()
        if col_name not in known_names
    ]
    if not missing_columns:
        return create_sql  # nothing to add

    # Insertion anchor: the PRIMARY KEY clause, else the closing parenthesis.
    anchor = create_sql.rfind('PRIMARY KEY')
    followed_by_primary_key = anchor != -1
    if not followed_by_primary_key:
        anchor = create_sql.rfind(')')
    if anchor == -1:
        return create_sql  # not a recognisable CREATE TABLE body

    # Split at the line break before the anchor; for single-line SQL split
    # directly at the anchor so the new columns still land inside the body.
    newline_pos = create_sql.rfind('\n', 0, anchor)
    split_at = newline_pos if newline_pos != -1 else anchor
    prefix = create_sql[:split_at].rstrip()
    suffix = create_sql[split_at:]

    indent = _existing_column_indent(init_table)
    new_lines = [f'\n{indent}{col_def},' for col_def in missing_columns]

    if not followed_by_primary_key:
        # The new columns become the last items of the list: the previous
        # last item needs a comma and the final new column must not have one.
        if prefix and not prefix.endswith((',', '(')):
            prefix += ','
        new_lines[-1] = new_lines[-1][:-1]

    return prefix + ''.join(new_lines) + suffix
# Extract INSERT statements.

# Tokens that matter when cutting a SQL script into statements: quoted text
# (which may contain ';'), comments (which are dropped) and the ';' separator.
_SQL_TOKEN = re.compile(
    r"'(?:[^'\\]|\\.)*'"        # single-quoted string
    r'|"(?:[^"\\]|\\.)*"'       # double-quoted string
    r'|`[^`]*`'                 # backtick-quoted identifier
    r'|--(?=\s|\Z)[^\n]*'       # "-- " line comment
    r'|#[^\n]*'                 # "#" line comment
    r'|/\*.*?\*/'               # block comment
    r'|;',
    re.DOTALL,
)

# "INSERT INTO <table> [(<columns>)] VALUES"
_INSERT_HEADER = re.compile(
    r'INSERT\s+INTO\s+([^\s(]+)\s*(?:\(([^)]*)\))?\s*VALUES\b',
    re.IGNORECASE,
)


def _split_sql_statements(sql_content):
    """Split a SQL script on ';' while respecting quotes; comments are removed."""
    statements = []
    pieces = []
    cursor = 0
    for token in _SQL_TOKEN.finditer(sql_content):
        text = token.group()
        if text == ';':
            pieces.append(sql_content[cursor:token.start()])
            statements.append(''.join(pieces))
            pieces = []
            cursor = token.end()
        elif text.startswith(('--', '#', '/*')):
            # Replace the comment by a blank so neighbouring words stay apart.
            pieces.append(sql_content[cursor:token.start()])
            pieces.append(' ')
            cursor = token.end()
        # Quoted text is left in place and copied with the surrounding SQL.
    pieces.append(sql_content[cursor:])
    statements.append(''.join(pieces))
    return statements


def _split_value_tuples(values_part):
    """Return the content of every top-level ``(...)`` group in a VALUES list."""
    rows = []
    current = []
    depth = 0
    quote = ''
    escaped = False
    for char in values_part:
        if quote:
            if depth > 0:
                current.append(char)
            if escaped:
                escaped = False
            elif char == '\\':
                escaped = True
            elif char == quote:
                quote = ''
            continue
        if char in ('"', "'"):
            quote = char
            if depth > 0:
                current.append(char)
        elif char == '(':
            depth += 1
            # Only the outermost parenthesis delimits a row; nested ones
            # (function calls such as NOW()) are part of the row content.
            if depth > 1:
                current.append(char)
        elif char == ')':
            if depth == 0:
                continue  # stray parenthesis outside any row
            depth -= 1
            if depth > 0:
                current.append(char)
                continue
            content = ''.join(current).strip()
            if content:
                rows.append(content)
            current = []
        elif depth > 0:
            current.append(char)
    return rows


def extract_insert_statements(sql_content):
    """Extract the rows of every ``INSERT INTO ... VALUES ...`` statement.

    Args:
        sql_content: Full text of the SQL file.

    Returns:
        ``{table_name: {'columns': str, 'values': [str, ...]}}``. ``columns``
        is the column list of the first INSERT seen for the table (empty when
        that statement has none); ``values`` holds the text between the
        parentheses of each row, in file order, across all statements for the
        table. Statements that do not have the INSERT header shape are
        skipped.
    """
    insert_statements = {}
    for statement in _split_sql_statements(sql_content):
        statement = statement.strip()
        header = _INSERT_HEADER.match(statement)
        if header is None:
            continue

        table_name = header.group(1).strip('`')
        columns_str = (header.group(2) or '').strip()
        rows = _split_value_tuples(statement[header.end():])

        entry = insert_statements.setdefault(
            table_name, {'columns': columns_str, 'values': []}
        )
        entry['values'].extend(rows)
    return insert_statements
# Compare the seed data and build the INSERT statements for missing rows.
def generate_data_upgrade_statements(niushop_inserts, init_inserts, table_mapping):
    """Build INSERT statements for reference rows that init.sql lacks.

    Args:
        niushop_inserts: ``{table: {'columns': str, 'values': [str, ...]}}``
            extracted from the reference dump.
        init_inserts: Same structure, extracted from init.sql.
        table_mapping: ``{reference_table_name: init_table_name}``; reference
            tables without an entry are ignored.

    Returns:
        A list of ``INSERT INTO`` statements, in input order, one per
        reference row whose value text is not present for the mapped init
        table. The column list is included when the reference data has one.
    """
    upgrade_statements = []
    for niushop_table, niushop_insert_info in niushop_inserts.items():
        if niushop_table not in table_mapping:
            continue
        init_table = table_mapping[niushop_table]
        columns_str = niushop_insert_info.get('columns', '')

        # Build the lookup once per table so each row test is O(1).
        existing_rows = set(init_inserts.get(init_table, {}).get('values', []))

        if columns_str:
            target = f"INSERT INTO `{init_table}` ({columns_str})"
        else:
            # No column list available: emit the short INSERT form.
            target = f"INSERT INTO `{init_table}`"

        for row in niushop_insert_info.get('values', []):
            # Skip "rows" that look like a column list rather than data.
            if 'store_name' in row and 'site_id' in row:
                continue
            # Very short values are treated as probable parse errors.
            if len(row) < 5:
                continue
            if row in existing_rows:
                continue
            upgrade_statements.append(f"{target} VALUES ({row});")
    return upgrade_statements
# Script entry point.
def main():
    """Compare the reference dump with init.sql and write the upgrade output.

    Reads ``./docs/db/niushop_database.sql`` and
    ``./docker/mysql/init/init.sql``, then:

    * writes ``./upgrade.sql`` with one ``ALTER TABLE ... ADD COLUMN``
      statement per column missing from init.sql, followed by INSERT
      statements for missing seed rows;
    * rewrites init.sql with the missing columns added to its CREATE TABLE
      statements.

    init.sql table names may carry a ``lucky_`` prefix, which is removed
    before they are matched against the reference table names.
    """
    niushop_file = './docs/db/niushop_database.sql'
    init_file = './docker/mysql/init/init.sql'
    upgrade_file = './upgrade.sql'
    table_prefix = 'lucky_'

    def without_prefix(table_name):
        """Return ``table_name`` with a leading ``lucky_`` removed."""
        if table_name.startswith(table_prefix):
            return table_name[len(table_prefix):]
        return table_name

    print("正在读取文件内容...")
    niushop_content = read_file(niushop_file)
    init_content = read_file(init_file)

    print("正在提取niushop_database.sql中的列定义...")
    niushop_columns = extract_column_definitions(niushop_content)
    print(f"从niushop_database.sql提取到 {len(niushop_columns)} 个表的列定义")

    print("正在提取init.sql中的表结构...")
    init_tables = extract_init_table_structures(init_content)
    print(f"从init.sql提取到 {len(init_tables)} 个表结构")

    # reference table name -> init.sql table name
    table_mapping = {without_prefix(name): name for name in init_tables}

    print("正在比较表结构差异...")
    structure_upgrade_statements = []
    updated_init_content = init_content

    for init_table_name, init_table in init_tables.items():
        print(f"正在处理表 {init_table_name}...")
        original_table_name = without_prefix(init_table_name)
        if original_table_name not in niushop_columns:
            continue

        print(f"找到对应的表 {original_table_name}...")
        niushop_cols = niushop_columns[original_table_name]

        init_col_names = {col['name'] for col in init_table['columns']}
        missing_columns = [
            col_def
            for col_name, col_def in niushop_cols.items()
            if col_name not in init_col_names
        ]
        if not missing_columns:
            continue

        print(f"{init_table_name} 缺少 {len(missing_columns)} 个列...")
        for col_def in missing_columns:
            structure_upgrade_statements.append(
                f"ALTER TABLE `{init_table_name}` ADD COLUMN {col_def};"
            )

        # Patch the CREATE TABLE statement inside the init.sql text as well.
        new_create_sql = update_table_structure(init_table, niushop_cols)
        updated_init_content = updated_init_content.replace(
            init_table['full_sql'], new_create_sql
        )

    print("正在提取INSERT语句...")
    niushop_inserts = extract_insert_statements(niushop_content)
    init_inserts = extract_insert_statements(init_content)

    print("正在生成数据升级语句...")
    data_upgrade_statements = generate_data_upgrade_statements(
        niushop_inserts, init_inserts, table_mapping
    )

    print("正在写入升级脚本...")
    with open(upgrade_file, 'w', encoding='utf-8') as f:
        f.write("-- 数据库升级脚本\n")
        f.write("-- 生成时间: 自动生成\n")
        f.write("-- 描述: 根据niushop_database.sql更新init.sql的表结构和数据\n\n")
        f.write("USE shop_mallnew;\n\n")

        if structure_upgrade_statements:
            f.write("-- 表结构升级语句\n")
            for stmt in structure_upgrade_statements:
                f.write(f"{stmt}\n")
            f.write("\n")

        if data_upgrade_statements:
            f.write("-- 数据升级语句\n")
            for stmt in data_upgrade_statements:
                f.write(f"{stmt}\n")

    print("正在更新init.sql文件...")
    with open(init_file, 'w', encoding='utf-8') as f:
        f.write(updated_init_content)

    total_statements = len(structure_upgrade_statements) + len(data_upgrade_statements)
    print(f"升级脚本已生成,共 {total_statements} 条语句")
    print(f"其中表结构升级语句 {len(structure_upgrade_statements)} 条,数据升级语句 {len(data_upgrade_statements)}")
    print(f"升级脚本路径: {upgrade_file}")


# Run the upgrade only when executed as a script, not when imported.
if __name__ == '__main__':
    main()