In [1]:
import subprocess
import re
import os
import pandas as pd
from datetime import datetime
# ==================== 配置 ====================
# Training scripts launched by this driver (resolved relative to the current
# working directory).
GNN_SCRIPT = "GNN_MODEL.py"
ML_SCRIPT = "ML_MODEL.py"
# Glob patterns handed to find_latest_log(), which picks the most recently
# modified match. They must be wildcards rather than one fixed timestamped
# filename, otherwise a fresh run (new timestamp) is never picked up.
# NOTE(review): assumes the scripts name their logs "<prefix>_<date>_<time>.txt"
# as in the previously hard-coded filenames -- confirm against the scripts.
GNN_LOG_PATTERN = "training_log_*.txt"  # logs written by the GNN script
ML_LOG_PATTERN = "ML_training_log_*.txt"  # logs written by the ML script
# Output artefacts of the comparison step.
OUTPUT_CSV = "model_comparison_results.csv"
OUTPUT_SUMMARY = "model_comparison_summary.txt"
# ==================== 日志提取函数 ====================
def find_latest_log(pattern):
    """Return the most recently modified file matching a glob pattern.

    Args:
        pattern: Glob expression (may also be a literal path) to search for.

    Returns:
        The path of the match with the newest modification time, or ``None``
        when nothing matches.
    """
    import glob

    candidates = glob.glob(pattern)
    # max() raises on an empty sequence, so the no-match case is handled
    # by the conditional instead.
    return max(candidates, key=os.path.getmtime) if candidates else None
def extract_metrics_from_log(log_file):
    """Extract per-model MSE and MAE values from a training log.

    Two line formats are recognised:

    * GNN logs: ``<GCN|GAT|SAGE> -> MSE: <num>, MAE: <num>``
    * ML logs:  ``<SVM|XGBoost> 测试集 MSE: <num>, MAE: <num>``

    Numbers may be plain decimals or use scientific notation (``1e-05``).
    When a model appears more than once in the log, the last occurrence wins.
    ``SAGE`` is reported under the name ``GraphSAGE``.

    Args:
        log_file: Path to the log file; may be ``None`` or empty.

    Returns:
        A dict such as ``{'GCN': {'MSE': 1.02, 'MAE': 0.81}, ...}``. Empty
        when ``log_file`` is falsy, does not exist, or contains no metrics.
    """
    if not log_file or not os.path.exists(log_file):
        return {}

    with open(log_file, 'r', encoding='utf-8') as f:
        content = f.read()

    # A decimal number with an optional exponent. The looser "[0-9.]+" would
    # truncate "1e-05" to "1" and would swallow a trailing sentence period
    # ("0.81."), which float() then rejects.
    number = r'((?:[0-9]+(?:\.[0-9]*)?|\.[0-9]+)(?:[eE][-+]?[0-9]+)?)'
    patterns = [
        # GNN model format
        r'(GCN|GAT|SAGE)\s*->\s*MSE:\s*' + number + r',\s*MAE:\s*' + number,
        # Classical ML model format
        r'(SVM|XGBoost)\s*测试集\s*MSE:\s*' + number + r',\s*MAE:\s*' + number,
    ]

    results = {}
    for pattern in patterns:
        for model_name, mse, mae in re.findall(pattern, content):
            # Normalise the short log name to the full model name.
            if model_name == 'SAGE':
                model_name = 'GraphSAGE'
            results[model_name] = {'MSE': float(mse), 'MAE': float(mae)}
    return results
# ==================== 运行脚本 ====================
def run_script(script_name, script_label):
    """Run a Python script as a child process and report whether it succeeded.

    The child inherits this process's stdout/stderr, so its output is shown
    live rather than captured.

    Args:
        script_name: Path of the script to execute.
        script_label: Human-readable name used in the status messages.

    Returns:
        ``True`` when the script exits with return code 0; ``False`` on a
        non-zero return code, on timeout (10 minutes), or when the process
        cannot be started.
    """
    # Local import: the module's import block does not bring in ``sys``.
    import sys

    print(f"\n{'='*60}")
    print(f"▶ 运行 {script_label}: {script_name}")
    print('='*60)
    # Flush our banner before the child starts writing to the same stream,
    # otherwise buffered parent output can appear after the child's output.
    sys.stdout.flush()

    # Use the interpreter that is running this driver so the child sees the
    # same environment/packages; a bare "python" may be missing from PATH or
    # point at a different installation.
    interpreter = sys.executable or 'python'
    try:
        result = subprocess.run(
            [interpreter, script_name],
            timeout=600,  # 10-minute safety timeout
        )
    except subprocess.TimeoutExpired:
        print(f"⏰ {script_label} 运行超时(>10分钟)")
        return False
    except Exception as e:
        # Deliberately broad: this is a best-effort boundary, any launch
        # failure is reported and turned into a False result.
        print(f"❌ {script_label} 运行出错: {e}")
        return False

    if result.returncode == 0:
        print(f"✅ {script_label} 运行成功!")
        return True
    print(f"❌ {script_label} 运行失败,返回码: {result.returncode}")
    return False
# ==================== 生成对比报告 ====================
def generate_comparison_report(gnn_results, ml_results):
    """Build, save and print a comparison of all extracted model metrics.

    The merged results are ranked by ascending MSE (ties broken by a fixed
    model priority), written to ``OUTPUT_CSV``, summarised in
    ``OUTPUT_SUMMARY`` and echoed to the terminal.

    Args:
        gnn_results: Mapping ``model name -> {'MSE': float, 'MAE': float}``
            for the GNN models.
        ml_results: Same shape, for the classical ML models. On a name clash
            these entries override ``gnn_results``.

    Returns:
        None. When both mappings are empty, a message is printed and nothing
        is written.
    """
    # Merge; later (ML) entries win on duplicate model names.
    all_results = {**gnn_results, **ml_results}
    if not all_results:
        print("❌ 没有提取到任何指标数据!")
        return

    df = pd.DataFrame([
        {
            '模型': model,
            'MSE': metrics['MSE'],
            'MAE': metrics['MAE'],
        }
        for model, metrics in all_results.items()
    ])

    # ====== Sort by MSE; break ties with a fixed priority ======
    # Earlier in the list means higher priority. Models not listed map to
    # NaN and therefore sort after the listed ones within a tie.
    priority_order = ['GraphSAGE', 'GCN', 'GAT', 'XGBoost', 'SVM']
    priority_map = {model: idx for idx, model in enumerate(priority_order)}
    df['_priority'] = df['模型'].map(priority_map)
    df = (
        df.sort_values(['MSE', '_priority'], ascending=[True, True])
        .drop('_priority', axis=1)
        .reset_index(drop=True)
    )

    # The first row after sorting is the best model.
    best_model = df.iloc[0]['模型']
    best_mse = df.iloc[0]['MSE']
    best_mae = df.iloc[0]['MAE']

    # ====== Save CSV ======
    # utf-8-sig writes a BOM at the start of the file.
    df.to_csv(OUTPUT_CSV, index=False, encoding='utf-8-sig')
    print(f"\n📊 对比结果已保存: {OUTPUT_CSV}")

    # ====== Write the text summary ======
    with open(OUTPUT_SUMMARY, 'w', encoding='utf-8') as f:
        f.write("="*70 + "\n")
        f.write("银行业系统性风险预警模型 - 对比报告\n")
        f.write(f"生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write("="*70 + "\n\n")
        f.write("📊 模型性能对比(按 MSE 从优到劣排序)\n")
        f.write("-"*50 + "\n")
        for _, row in df.iterrows():
            f.write(f" {row['模型']:12s} | MSE: {row['MSE']:.6f} | MAE: {row['MAE']:.6f}\n")
        f.write("\n" + "-"*50 + "\n")
        f.write(f"🏆 最佳模型: {best_model}\n")
        f.write(f" MSE: {best_mse:.6f}\n")
        f.write(f" MAE: {best_mae:.6f}\n")

        # Average GNN vs. ML comparison, only when both families are present.
        gnn_models = [m for m in all_results if m in ('GCN', 'GAT', 'GraphSAGE')]
        ml_models = [m for m in all_results if m in ('SVM', 'XGBoost')]
        if gnn_models and ml_models:
            avg_gnn_mse = sum(all_results[m]['MSE'] for m in gnn_models) / len(gnn_models)
            avg_ml_mse = sum(all_results[m]['MSE'] for m in ml_models) / len(ml_models)
            f.write("\n" + "-"*50 + "\n")
            f.write("📈 GNN vs ML 平均性能对比\n")
            f.write(f" GNN 模型平均 MSE: {avg_gnn_mse:.6f}\n")
            f.write(f" ML 模型平均 MSE: {avg_ml_mse:.6f}\n")
            if avg_ml_mse:
                improvement = (avg_ml_mse - avg_gnn_mse) / avg_ml_mse * 100
                f.write(f" GNN 相对提升: {improvement:+.2f}%\n")
            else:
                # A zero ML baseline makes the relative improvement undefined;
                # report that instead of raising ZeroDivisionError mid-file.
                f.write(" GNN 相对提升: N/A (ML 模型平均 MSE 为 0)\n")
    print(f"📄 对比报告已保存: {OUTPUT_SUMMARY}")

    # ====== Echo to the terminal ======
    print("\n" + "="*60)
    print("📊 模型性能对比结果")
    print("="*60)
    print(df.to_string(index=False))
    print("\n" + "-"*60)
    print(f"🏆 最佳模型: {best_model} (MSE: {best_mse:.6f}, MAE: {best_mae:.6f})")
    print("="*60)
# ==================== 主程序 ====================
if __name__ == "__main__":
    # Opening banner.
    print("\n" + "="*70, "🏦 银行业系统性风险预警模型 - 统一运行脚本", "="*70, sep="\n")

    # Steps 1 and 2: run both training scripts. The ML script is run even
    # when the GNN script fails.
    gnn_success = run_script(GNN_SCRIPT, "GNN模型")
    ml_success = run_script(ML_SCRIPT, "ML模型")

    # Step 3: pull the metrics out of the newest log of each successful run.
    print("\n" + "="*60, "📂 提取模型指标...", "="*60, sep="\n")
    gnn_results, ml_results = {}, {}

    if gnn_success:
        gnn_log = find_latest_log(GNN_LOG_PATTERN)
        if not gnn_log:
            print(" ⚠️ 未找到 GNN 日志文件")
        else:
            print(f" 找到 GNN 日志: {gnn_log}")
            gnn_results = extract_metrics_from_log(gnn_log)
            print(f" 提取到 {len(gnn_results)} 个 GNN 模型结果")

    if ml_success:
        ml_log = find_latest_log(ML_LOG_PATTERN)
        if not ml_log:
            print(" ⚠️ 未找到 ML 日志文件")
        else:
            print(f" 找到 ML 日志: {ml_log}")
            ml_results = extract_metrics_from_log(ml_log)
            print(f" 提取到 {len(ml_results)} 个 ML 模型结果")

    # Step 4: produce the report when at least one model yielded metrics.
    if not (gnn_results or ml_results):
        print("\n❌ 未能提取到任何模型的指标数据,请检查日志文件是否完整。")
    else:
        generate_comparison_report(gnn_results, ml_results)

    print("\n✅ 全部完成!")
====================================================================== 🏦 银行业系统性风险预警模型 - 统一运行脚本 ====================================================================== ============================================================ ▶ 运行 GNN模型: GNN_MODEL.py ============================================================ ❌ GNN模型 运行失败,返回码: 2 ============================================================ ▶ 运行 ML模型: ML_MODEL.py ============================================================
python3: can't open file '/kaggle/working/GNN_MODEL.py': [Errno 2] No such file or directory python3: can't open file '/kaggle/working/ML_MODEL.py': [Errno 2] No such file or directory
❌ ML模型 运行失败,返回码: 2 ============================================================ 📂 提取模型指标... ============================================================ ❌ 未能提取到任何模型的指标数据,请检查日志文件是否完整。 ✅ 全部完成!