#!/usr/bin/env python3 # -*- coding: utf-8 -*- import re import os # 读取文件内容 def read_file(file_path): with open(file_path, 'r', encoding='utf-8') as f: return f.read() # 解析表结构,保留原始格式 def parse_table_structure(create_table_sql): """解析CREATE TABLE语句,返回表结构信息,保留原始格式""" table_info = {} # 提取完整表名(包括反引号) table_name_pattern = re.compile(r'CREATE TABLE (.*?)\s*\(', re.IGNORECASE) table_name_match = table_name_pattern.search(create_table_sql) if table_name_match: table_info['full_table_name'] = table_name_match.group(1).strip() # 提取不带反引号的表名用于比较 bare_table_name = table_info['full_table_name'].strip('`') table_info['table_name'] = bare_table_name # 提取列信息,保留完整行格式 columns = [] # 找到括号内的内容 start_pos = create_table_sql.find('(') + 1 end_pos = create_table_sql.rfind(')') table_body = create_table_sql[start_pos:end_pos] # 处理每一行 lines = table_body.split('\n') for line in lines: line = line.strip() # 跳过主键、索引等 if line.startswith('PRIMARY KEY') or line.startswith('UNIQUE KEY') or line.startswith('KEY') or line.startswith('CONSTRAINT'): continue if not line or line == ')': continue # 提取列名(带反引号或不带) # 匹配第一个空格前的内容作为列名 column_name_end = line.find(' ') if column_name_end == -1: continue column_name = line[:column_name_end] # 提取不带反引号的列名用于比较 bare_column_name = column_name.strip('`') columns.append({ 'name': bare_column_name, 'full_name': column_name, 'full_line': line }) table_info['columns'] = columns table_info['create_sql'] = create_table_sql return table_info # 从niushop_database.sql中提取列定义,用于更新 def extract_column_definitions(sql_content): """从SQL文件中提取列定义,用于比较""" column_definitions = {} # 先匹配所有CREATE TABLE语句,正确处理表结构 # 使用更严格的匹配,确保只匹配到真正的表定义 create_table_pattern = re.compile(r'CREATE TABLE `?(\w+)`?\s*\(([\s\S]*?)\)\s*(?:ENGINE|CHARACTER|COLLATE|COMMENT|;)', re.IGNORECASE | re.DOTALL) create_table_matches = create_table_pattern.findall(sql_content) for table_name, table_body in create_table_matches: columns = {} # 处理表体中的列定义 lines = table_body.split('\n') for line in lines: line = line.strip() # 跳过主键、索引等 if line.startswith('PRIMARY KEY') or line.startswith('UNIQUE KEY') or line.startswith('KEY') or line.startswith('CONSTRAINT'): continue if not line or line == ')': continue # 跳过ALTER TABLE语句 if line.startswith('ALTER TABLE'): continue # 跳过ENGINE等表选项 if line.startswith('ENGINE') or line.startswith('CHARACTER') or line.startswith('COLLATE') or line.startswith('COMMENT'): continue # 跳过CREATE TABLE语句 if line.startswith('CREATE TABLE'): continue # 提取列名 column_name_end = line.find(' ') if column_name_end == -1: continue column_name = line[:column_name_end].strip('`') # 提取完整列定义 # 去掉末尾的逗号 col_def = line if col_def.endswith(','): col_def = col_def[:-1] columns[column_name] = col_def.strip() column_definitions[table_name] = columns # 处理单独的ALTER TABLE语句,提取添加的列 alter_table_pattern = re.compile(r'ALTER TABLE `?(\w+)`?\s*ADD COLUMN\s*(.*?);', re.IGNORECASE | re.DOTALL) alter_table_matches = alter_table_pattern.findall(sql_content) for table_name, column_def in alter_table_matches: if table_name in column_definitions: # 提取列名 column_name_end = column_def.find(' ') if column_name_end != -1: column_name = column_def[:column_name_end].strip('`') # 确保列定义是有效的 if not column_def.startswith('ALTER TABLE') and not column_def.startswith('ENGINE'): column_definitions[table_name][column_name] = column_def.strip() return column_definitions # 从init.sql中提取表结构 def extract_init_table_structures(sql_content): """从init.sql中提取表结构,保留完整格式""" table_structures = {} # 匹配所有CREATE TABLE语句 create_table_pattern = re.compile(r'(DROP TABLE IF EXISTS `?(\w+)`?;[\s\S]*?CREATE TABLE `?(\w+)`?\s*\(([\s\S]*?)\)\s*ENGINE\s*=[\s\S]*?;)\s*', re.IGNORECASE | re.DOTALL) create_table_matches = create_table_pattern.findall(sql_content) for full_sql, drop_table_name, create_table_name, table_body in create_table_matches: # 提取列信息 columns = [] lines = table_body.split('\n') for line in lines: line = line.strip() # 跳过主键、索引等 if line.startswith('PRIMARY KEY') or line.startswith('UNIQUE KEY') or line.startswith('KEY') or line.startswith('CONSTRAINT'): continue if not line or line == ')': continue # 提取列名 column_name_end = line.find(' ') if column_name_end == -1: continue column_name = line[:column_name_end].strip('`') # 提取完整列定义行 columns.append({ 'name': column_name, 'full_line': line }) table_structures[create_table_name] = { 'full_sql': full_sql, 'columns': columns } return table_structures # 更新init.sql中的表结构 def update_table_structure(init_table, niushop_columns): """更新表结构,只添加缺失的列,保留原有格式""" init_columns_dict = {col['name']: col for col in init_table['columns']} # 检查是否有缺失的列 missing_columns = [] for col_name, niushop_col_def in niushop_columns.items(): if col_name not in init_columns_dict: missing_columns.append(niushop_col_def) if not missing_columns: return init_table['full_sql'] # 没有变化,返回原SQL # 构建新的CREATE TABLE语句 create_sql = init_table['full_sql'] # 找到最后一个列定义行的位置 # 查找最后一个列定义行(在PRIMARY KEY之前) last_col_pos = create_sql.rfind('PRIMARY KEY') if last_col_pos == -1: last_col_pos = create_sql.rfind(')') # 查找最后一个列定义行的结束位置 # 从last_col_pos往前找换行符 newline_pos = create_sql.rfind('\n', 0, last_col_pos) if newline_pos == -1: newline_pos = create_sql.rfind('\n') # 提取前面的内容 prefix = create_sql[:newline_pos].rstrip() suffix = create_sql[newline_pos:] # 添加缺失的列 for col_def in missing_columns: # 保持与原有列相同的缩进 # 从原有列中获取缩进 if init_table['columns']: first_col_line = init_table['columns'][0]['full_line'] indent = first_col_line[:len(first_col_line) - len(first_col_line.lstrip())] else: indent = ' ' # 添加新列,带缩进和逗号 prefix += f'\n{indent}{col_def},' # 重新组合CREATE TABLE语句 new_create_sql = prefix + suffix return new_create_sql # 主函数 def main(): # 文件路径 niushop_file = './docs/db/niushop_database.sql' init_file = './docker/mysql/init/init.sql' upgrade_file = './upgrade.sql' # 读取文件内容 print("正在读取文件内容...") niushop_content = read_file(niushop_file) init_content = read_file(init_file) # 提取niushop_database.sql中的列定义 print("正在提取niushop_database.sql中的列定义...") niushop_columns = extract_column_definitions(niushop_content) print(f"从niushop_database.sql提取到 {len(niushop_columns)} 个表的列定义") # 提取init.sql中的表结构 print("正在提取init.sql中的表结构...") init_tables = extract_init_table_structures(init_content) print(f"从init.sql提取到 {len(init_tables)} 个表结构") # 比较表结构差异,生成更新 print("正在比较表结构差异...") upgrade_statements = [] updated_init_content = init_content # 遍历init.sql中的所有表 for init_table_name, init_table in init_tables.items(): print(f"正在处理表 {init_table_name}...") # 去掉lucky_前缀,与niushop_database.sql中的表名进行比较 original_table_name = init_table_name if original_table_name.startswith('lucky_'): original_table_name = original_table_name[6:] # 去掉lucky_前缀 # 检查niushop_database.sql中是否有对应的表 if original_table_name in niushop_columns: print(f"找到对应的表 {original_table_name}...") niushop_cols = niushop_columns[original_table_name] # 检查缺失的列 init_cols_dict = {col['name']: col for col in init_table['columns']} missing_columns = [] for col_name, niushop_col_def in niushop_cols.items(): if col_name not in init_cols_dict: missing_columns.append({ 'name': col_name, 'definition': niushop_col_def }) # 如果有缺失的列,生成ALTER语句并更新init.sql if missing_columns: print(f"表 {init_table_name} 缺少 {len(missing_columns)} 个列...") # 生成ALTER TABLE语句 for col in missing_columns: # 提取列名(带反引号) col_name_in_def = col['definition'].split(' ')[0] alter_stmt = f"ALTER TABLE `{init_table_name}` ADD COLUMN {col['definition']};" upgrade_statements.append(alter_stmt) # 更新init.sql中的表结构 new_create_sql = update_table_structure(init_table, niushop_cols) updated_init_content = updated_init_content.replace(init_table['full_sql'], new_create_sql) # 写入升级脚本 print("正在写入升级脚本...") with open(upgrade_file, 'w', encoding='utf-8') as f: f.write("-- 数据库升级脚本\n") f.write("-- 生成时间: 自动生成\n") f.write("-- 描述: 根据niushop_database.sql更新init.sql的表结构\n\n") f.write("USE shop_mallnew;\n\n") for stmt in upgrade_statements: f.write(f"{stmt}\n") # 更新init.sql文件 print("正在更新init.sql文件...") with open(init_file, 'w', encoding='utf-8') as f: f.write(updated_init_content) print(f"升级脚本已生成,共 {len(upgrade_statements)} 条ALTER语句") print(f"升级脚本路径: {upgrade_file}") if __name__ == '__main__': main()