# ===== Double-XGBoost corn futures volatility forecasting (merged single file) =====
cat("=== 双重XGBoost玉米期货波动率预测 ===\n")
cat("开始加载模块...\n")
# ===== Package Installation =====
# Packages that must be installed are fetched from the Tsinghua CRAN mirror.
options(repos = c(CRAN = "https://mirrors.tuna.tsinghua.edu.cn/CRAN/"))
pkgs <- c("data.table", "TTR", "xgboost", "pROC", "fGarch", "ggplot2")
# Install only packages that cannot be found; installed ones are left as-is.
# Note: `p` and `pkgs` remain in the global environment and are therefore also
# written out by the save(list = ls(), ...) call at the end of the script.
for (p in pkgs) {
if (!requireNamespace(p, quietly = TRUE)) install.packages(p)
}
# Attach every dependency: the functions below call their exports unqualified.
library(data.table)
library(TTR)
library(xgboost)
library(pROC)
library(fGarch)
library(ggplot2)
cat("模块加载完成\n\n")
# ===== 1_data_process.R =====
# Load daily OHLCV + open-interest data from a CSV, or synthesize a series.
#
# The CSV is used when it has exactly the seven expected columns (names are
# matched case-insensitively, in any order). Columns are then renamed to
# lower case, put into the canonical order, `date` is converted to Date and
# rows are sorted by date. Otherwise a synthetic corn-futures series is
# generated from a fixed seed (this call sets the global RNG seed).
#
# @param file_path Path to the CSV file.
# @return A data.table with columns date, open, high, low, close, volume,
#   open_interest.
load_data <- function(file_path) {
  expected_cols <- c("date", "open", "high", "low", "close", "volume", "open_interest")
  data <- fread(file_path, encoding = "UTF-8")
  col_lower <- tolower(names(data))
  if (ncol(data) == length(expected_cols) && all(expected_cols %in% col_lower)) {
    # Rename by name, then reorder. The CSV may list the columns in any order,
    # so assigning expected_cols positionally would mislabel them.
    setnames(data, col_lower)
    setcolorder(data, expected_cols)
    data[, date := as.Date(date)]
    return(data[order(date)])
  }
  # Fallback: generate synthetic corn futures data (original project design)
  cat("CSV columns don't match expected OHLCV format. Generating synthetic corn futures data.\n")
  set.seed(27123225)
  calendar_days <- seq(as.Date("2020-01-01"), by = "day", length.out = 300)
  # Keep Monday-Friday. POSIXlt$wday (0 = Sunday ... 6 = Saturday) does not
  # depend on the locale, unlike weekdays(), whose English day names never
  # match under e.g. a Chinese locale and would leave zero rows.
  dates <- calendar_days[as.POSIXlt(calendar_days)$wday %in% 1:5]
  n <- length(dates)
  # Simulate price with a random walk in log space; draw order is unchanged so
  # the synthetic series is reproducible.
  log_returns <- rnorm(n, mean = 0.0001, sd = 0.02)
  close_prices <- 2400 * exp(cumsum(log_returns))
  open_prices <- c(2400, close_prices[-n]) * (1 + rnorm(n, 0, 0.003))
  high_prices <- pmax(open_prices, close_prices) * (1 + abs(rnorm(n, 0, 0.005)))
  low_prices <- pmin(open_prices, close_prices) * (1 - abs(rnorm(n, 0, 0.005)))
  volumes <- round(runif(n, 50000, 200000))
  oi <- round(800000 + cumsum(rnorm(n, 0, 1000)))
  data.table(
    date = dates,
    open = round(open_prices, 2),
    high = round(high_prices, 2),
    low = round(low_prices, 2),
    close = round(close_prices, 2),
    volume = volumes,
    open_interest = oi
  )
}
# Add the daily log return log(close_t / close_{t-1}) as column `return`.
# The first row has no previous close and is NA. Modifies `data` by reference
# and also returns it.
calculate_returns <- function(data) {
  data[, "return" := log(close / shift(close))]
  data
}
# Rolling volatility of the `return` column.
#
# rolling_vol: right-aligned sample sd over `window` rows (NA until the window
#   is full or while it contains an NA).
# annualized_vol: rolling_vol * sqrt(252) when `annualize` is TRUE, otherwise
#   an unscaled copy of rolling_vol.
# Columns are added by reference; `data` is returned.
calculate_volatility <- function(data, window = 20, annualize = TRUE) {
  data[, rolling_vol := frollapply(return, n = window, FUN = sd, align = "right")]
  ann_factor <- if (annualize) sqrt(252) else 1
  data[, annualized_vol := rolling_vol * ann_factor]
  data
}
# Forward-fill (last observation carried forward) NA values in place.
#
# Only numeric (integer/double) columns are filled. Date columns are left
# untouched, as before. Other non-numeric columns (character, factor, ...) are
# now skipped instead of making nafill() raise an error. Leading NAs have no
# earlier value and stay NA.
#
# @param data A data.table (modified by reference).
# @return `data`.
handle_missing <- function(data) {
  is_fillable <- vapply(
    data,
    function(x) is.numeric(x) && !inherits(x, "Date"),
    logical(1)
  )
  fill_cols <- names(data)[is_fillable]
  if (length(fill_cols) > 0) {
    setnafill(data, type = "locf", cols = fill_cols)
  }
  data
}
# Full preprocessing pipeline for one CSV file:
# load -> log returns -> 20-day annualized rolling volatility -> LOCF fill.
# Each step receives the previous step's data.table and returns it.
preprocess_data <- function(file_path) {
  steps <- list(
    calculate_returns,
    function(d) calculate_volatility(d, window = 20),
    handle_missing
  )
  Reduce(function(d, step) step(d), steps, load_data(file_path))
}
# ===== 2_features.R =====
# Add one rolling-sd column of `return` per window length, named
# "vol_<w>d" (e.g. vol_10d). Right-aligned and unannualized. Modifies `data`
# by reference and returns it.
calculate_historical_vol <- function(data, windows = c(10, 20, 30)) {
  vol_cols <- paste0("vol_", windows, "d")
  for (k in seq_along(windows)) {
    w <- windows[k]
    data[, (vol_cols[k]) := frollapply(return, n = w, FUN = sd, align = "right")]
  }
  data
}
# 20-row right-aligned rolling mean and sd of `volume`, stored as
# vol_mean_20d and vol_std_20d (by reference). Returns `data`.
calculate_volume_features <- function(data) {
  data[, c("vol_mean_20d", "vol_std_20d") := list(
    frollapply(volume, n = 20, FUN = mean, align = "right"),
    frollapply(volume, n = 20, FUN = sd, align = "right")
  )]
  data
}
# TTR-based technical indicators, added to `data` by reference:
#   RSI_6 / RSI_12 / RSI_24 : RSI of close over 6, 12 and 24 rows
#   ATR_10 / ATR_20         : "atr" column of TTR::ATR on (high, low, close)
#   OBV                     : on-balance volume from close and volume
#   MTM                     : 10-row momentum of close
# Returns `data`.
calculate_technical_indicators <- function(data) {
  for (n_rsi in c(6, 12, 24)) {
    data[, (paste0("RSI_", n_rsi)) := RSI(close, n = n_rsi)]
  }
  # Build the high/low/close input once; it is a separate copy, so adding
  # ATR columns to `data` does not affect it.
  hlc <- data[, .(high, low, close)]
  for (atr_n in c(10, 20)) {
    atr_mat <- ATR(HLC = hlc, n = atr_n)
    data[, (paste0("ATR_", atr_n)) := as.numeric(atr_mat[, "atr"])]
  }
  data[, c("OBV", "MTM") := list(OBV(close, volume), momentum(close, n = 10))]
  data
}
# Simulated "fundamental" features (no real spot/inventory data is loaded):
#   spot_price  : close scaled by 0.95 plus N(0, 0.03) multiplicative noise
#   inventory   : random walk around 100000 with N(0, 1000) steps
#   spot_return : log return of spot_price
#   spot_vol    : 20-row right-aligned rolling sd of spot_return
# The noise is drawn from a fixed `seed` so results are reproducible. The
# caller's RNG state is saved and restored on exit, so this function no longer
# silently resets the global random stream.
#
# @param data A data.table with a `close` column (modified by reference).
# @param seed Integer seed for the simulated noise (default 123, as before).
# @return `data`.
generate_fundamental_features <- function(data, seed = 123) {
  genv <- globalenv()
  had_seed <- exists(".Random.seed", envir = genv, inherits = FALSE)
  old_seed <- if (had_seed) get(".Random.seed", envir = genv, inherits = FALSE) else NULL
  on.exit({
    if (had_seed) {
      assign(".Random.seed", old_seed, envir = genv)
    } else if (exists(".Random.seed", envir = genv, inherits = FALSE)) {
      rm(".Random.seed", envir = genv)
    }
  }, add = TRUE)
  set.seed(seed)
  data[, spot_price := close * (0.95 + rnorm(.N, 0, 0.03))]
  data[, inventory := 100000 + cumsum(rnorm(.N, 0, 1000))]
  data[, spot_return := log(spot_price / shift(spot_price))]
  data[, spot_vol := frollapply(spot_return, n = 20, FUN = sd, align = "right")]
  data
}
# Add the calendar month (1-12) of `date` as column `month`, using
# data.table::month(). Modifies `data` by reference and returns it.
generate_time_features <- function(data) {
  data[, `:=`(month = month(date))]
  data
}
# Build every model feature on `data`, then forward-fill gaps.
#
# Each helper appends its columns with data.table `:=`. The final gap fill
# reuses handle_missing(), the same LOCF routine used in preprocessing,
# instead of a duplicated inline loop. Leading NAs produced by rolling windows
# have no earlier value and remain NA.
#
# @param data A data.table from preprocess_data().
# @return The data.table with feature columns added.
generate_all_features <- function(data) {
  data <- calculate_historical_vol(data)
  data <- calculate_volume_features(data)
  data <- calculate_technical_indicators(data)
  data <- generate_fundamental_features(data)
  data <- generate_time_features(data)
  handle_missing(data)
}
# ===== 3_labels.R =====
# Long-term regime label: y_long = 1L when annualized_vol >= threshold,
# 0L otherwise, NA when annualized_vol is NA. Added by reference; returns data.
create_long_term_label <- function(data, threshold = 0.10) {
  data[, y_long := fifelse(annualized_vol >= threshold, 1L, 0L)]
  data
}
# Short-term change label: y_short = 1L when the annualized volatility
# `prediction_period` rows AHEAD differs from the current value by at least
# `threshold` (absolute difference), else 0L. The last `prediction_period`
# rows have no future value and get NA.
#
# data.table::shift() treats a negative `n` as the opposite direction, so the
# previous shift(n = -prediction_period, type = "lead") actually LAGGED the
# series and compared against past volatility. A positive n with
# type = "lead" looks forward.
#
# Temporary columns future_vol / vol_diff are removed again. `data` is
# modified by reference and returned.
create_short_term_label <- function(data, prediction_period = 10, threshold = 0.02) {
  data[, future_vol := shift(annualized_vol, n = prediction_period, type = "lead")]
  data[, vol_diff := abs(future_vol - annualized_vol)]
  data[, y_short := as.integer(vol_diff >= threshold)]
  data[, c("future_vol", "vol_diff") := NULL]
  data
}
# Regression target: target_vol = annualized_vol `prediction_period` rows
# AHEAD (NA for the last `prediction_period` rows).
#
# data.table::shift() flips direction for a negative `n`, so the previous
# shift(n = -prediction_period, type = "lead") produced a LAG, i.e. past
# volatility. A positive n with type = "lead" looks forward.
create_target_volatility <- function(data, prediction_period = 10) {
  data[, target_vol := shift(annualized_vol, n = prediction_period, type = "lead")]
  data
}
# Build the long-term label, the short-term label and the regression target,
# then keep only rows where all three are available.
create_all_labels <- function(data, long_threshold = 0.10, short_threshold = 0.02, prediction_period = 10) {
  data <- create_long_term_label(data, threshold = long_threshold)
  data <- create_short_term_label(data, prediction_period = prediction_period,
                                  threshold = short_threshold)
  data <- create_target_volatility(data, prediction_period = prediction_period)
  labelled <- complete.cases(data[, .(y_long, y_short, target_vol)])
  data[labelled]
}
# ===== 4_double_xgboost.R =====
# Expanding-window time-series CV folds over rows 1..n.
#
# The rows are cut into `n_folds` consecutive test blocks (the last block
# absorbs the remainder). Each fold trains on rows 1..(test_start - gap_days),
# so at least the rows just before the test block are held out.
#
# Fixes vs. the previous version:
#   * train_end is no longer clamped to 1. That clamp made fold 1 train on
#     row 1, which is also in its test block, and ignored the gap for early
#     folds. Such folds now get an empty train_idx.
#   * train_end never reaches test_start, even for gap_days <= 0.
#   * n_folds is capped at n, so every test block is non-empty. For n < 1 an
#     empty list is returned.
#
# @return List of folds, each list(train_idx = <integer>, test_idx = <integer>).
create_time_series_cv <- function(n, n_folds = 5, gap_days = 30) {
  n <- as.integer(n)
  if (is.na(n) || n < 1L || n_folds < 1) return(list())
  n_folds <- min(as.integer(n_folds), n)
  fold_size <- n %/% n_folds
  lapply(seq_len(n_folds), function(i) {
    test_start <- (i - 1L) * fold_size + 1L
    test_end <- if (i == n_folds) n else i * fold_size
    train_end <- min(test_start - 1L, test_start - gap_days)
    list(
      train_idx = if (train_end >= 1) seq_len(train_end) else integer(0),
      test_idx = test_start:test_end
    )
  })
}
# Fit an XGBoost booster on feature matrix X and labels y.
#
# When `params` is NULL a binary-logistic configuration is used (AUC metric,
# depth 4, eta 0.1, 80% row/column subsampling, seed 123). Any non-NULL
# `params` replaces that configuration entirely. Training is silent.
#
# @return The trained xgboost model.
train_xgboost <- function(X, y, params = NULL, nrounds = 50) {
  default_params <- list(
    objective = "binary:logistic", eval_metric = "auc",
    max_depth = 4, eta = 0.1, subsample = 0.8, colsample_bytree = 0.8, seed = 123
  )
  chosen_params <- if (is.null(params)) default_params else params
  train_matrix <- xgb.DMatrix(data = as.matrix(X), label = y)
  xgb.train(params = chosen_params, data = train_matrix, nrounds = nrounds, verbose = 0)
}
# Binary classification metrics for probability predictions (cut-off 0.5).
#
# The confusion matrix is always built as 2x2 (rows = actual, cols = predicted,
# levels 0/1). Previously a class missing from `actual` or the predictions
# shrank the table, so diag() picked an off-diagonal cell and accuracy was
# wrong, and precision/recall silently became NA.
# Undefined ratios (zero denominators) are NA.
#
# AUC uses pROC with controls = 0, cases = 1 and direction "<". pROC's default
# direction = "auto" flips the comparison to whichever side scores higher,
# which inflates AUC for worse-than-random models. AUC is NA when pROC fails
# (e.g. only one class present, or pROC unavailable).
#
# @param preds Predicted probabilities of class 1.
# @param actual 0/1 labels.
# @return list(accuracy, precision, recall, f1, auc).
evaluate_classification <- function(preds, actual) {
  pred_class <- as.integer(preds >= 0.5)
  lv <- c(0L, 1L)
  confusion <- table(factor(actual, levels = lv), factor(pred_class, levels = lv))
  tn <- as.numeric(confusion[1, 1])
  fp <- as.numeric(confusion[1, 2])
  fn <- as.numeric(confusion[2, 1])
  tp <- as.numeric(confusion[2, 2])
  total <- tn + fp + fn + tp
  accuracy <- if (total > 0) (tn + tp) / total else NA_real_
  precision <- if (tp + fp > 0) tp / (tp + fp) else NA_real_
  recall <- if (tp + fn > 0) tp / (tp + fn) else NA_real_
  f1 <- if (!is.na(precision) && !is.na(recall) && precision + recall > 0) {
    2 * precision * recall / (precision + recall)
  } else {
    NA_real_
  }
  roc_obj <- tryCatch(
    pROC::roc(response = actual, predictor = preds, levels = c(0, 1),
              direction = "<", quiet = TRUE),
    error = function(e) NULL
  )
  auc_val <- if (!is.null(roc_obj)) as.numeric(roc_obj$auc) else NA_real_
  list(accuracy = accuracy, precision = precision, recall = recall, f1 = f1, auc = auc_val)
}
# Time-series cross-validation of an XGBoost classifier for one target column.
#
# Rows with any missing feature or a missing target are dropped first. Folds
# come from create_time_series_cv() over the remaining rows. Folds whose
# training set is empty or single-class are skipped.
#
# Each fold result's `test_idx` holds ROW NUMBERS OF `data` (mapped back
# through the filter). Previously it indexed the filtered copy, so callers
# doing data[test_idx] read the wrong dates and target values.
#
# Skipped folds leave NULL holes in `fold_results` / `feature_importance`, and
# trailing skipped folds are absent. Callers pair long/short results by fold
# position, so the lists are intentionally not compacted. Averages are taken
# over the folds that ran only (previously the NULL holes turned them into NA).
#
# @return list(fold_results, avg_results, feature_importance).
run_time_series_cv <- function(data, features, target, n_folds = 5, gap_days = 30) {
  valid_rows <- complete.cases(data[, ..features]) & !is.na(data[[target]])
  data_rows <- which(valid_rows)
  data_cv <- data[valid_rows]
  n <- nrow(data_cv)
  cat(sprintf(" %s: %d valid rows\n", target, n))
  if (n < n_folds * 2) {
    n_folds <- max(2, floor(n / 10))
    cat(sprintf(" Adjusted to %d folds\n", n_folds))
  }
  folds <- create_time_series_cv(n, n_folds, gap_days)
  results <- list()
  all_importance <- list()
  for (i in seq_along(folds)) {
    train_idx <- folds[[i]]$train_idx
    test_idx <- folds[[i]]$test_idx
    if (length(train_idx) == 0) {
      cat(sprintf(" Fold %d: skipped (no training rows before the gap)\n", i))
      next
    }
    y_train <- data_cv[[target]][train_idx]
    if (length(unique(y_train)) < 2) {
      cat(sprintf(" Fold %d: skipped (single class in train)\n", i))
      next
    }
    X_train <- as.matrix(data_cv[train_idx, ..features])
    X_test <- as.matrix(data_cv[test_idx, ..features])
    y_test <- data_cv[[target]][test_idx]
    model <- train_xgboost(X_train, y_train)
    preds <- predict(model, X_test)
    metrics <- evaluate_classification(preds, y_test)
    results[[i]] <- list(fold = i, test_idx = data_rows[test_idx], predictions = preds, eval = metrics)
    all_importance[[i]] <- xgb.importance(feature_names = features, model = model)
    cat(sprintf(" Fold %d: Acc=%.2f%% AUC=%.4f\n", i, metrics$accuracy * 100, metrics$auc))
  }
  # Mean of one metric over the folds that actually ran (NA if none ran).
  fold_mean <- function(metric) {
    completed <- Filter(Negate(is.null), results)
    vals <- vapply(completed, function(r) as.numeric(r$eval[[metric]]), numeric(1))
    if (length(vals) == 0) NA_real_ else mean(vals, na.rm = TRUE)
  }
  avg_results <- list(
    accuracy = fold_mean("accuracy"),
    precision = fold_mean("precision"),
    recall = fold_mean("recall"),
    f1 = fold_mean("f1"),
    auc = fold_mean("auc")
  )
  list(fold_results = results, avg_results = avg_results, feature_importance = all_importance)
}
# Turn long/short classifier fold outputs into volatility forecasts.
#
# Rows are taken from data[test_idx] of each fold; actual_vol is target_vol.
#   * Only short-model folds available: predicted_vol is annualized_vol * 0.9
#     when p_short < 0.5, otherwise * 1.1.
#   * Only long-model folds available:
#     predicted_vol = annualized_vol * (1 + (p_long - 0.5) * 0.2).
#   * Both available: folds are paired by position (skipped if either side is
#     NULL). predicted_vol is annualized_vol when p_short < 0.5, otherwise
#     the long-model scaling above.
# Always returns a data.table with columns date, predicted_vol, actual_vol,
# long_pred, short_pred (zero rows if nothing could be paired).
ensemble_predict_volatility <- function(data, long_results, short_results) {
  empty_preds <- data.table(date = as.Date(character()), predicted_vol = numeric(),
                            actual_vol = numeric(), long_pred = numeric(), short_pred = numeric())
  long_folds <- long_results$fold_results
  short_folds <- short_results$fold_results
  # Output rows for one fold, given the rows of `data` it covers.
  fold_rows <- function(rows, predicted_vol, long_pred, short_pred) {
    data.table(
      date = rows$date, predicted_vol = predicted_vol,
      actual_vol = rows$target_vol, long_pred = long_pred, short_pred = short_pred
    )
  }
  pieces <- if (length(long_folds) == 0) {
    cat(" NOTE: long-fold all skipped, using short-fold results directly\n")
    lapply(Filter(Negate(is.null), short_folds), function(fold) {
      rows <- data[fold$test_idx]
      p_short <- fold$predictions
      vol_guess <- ifelse(p_short < 0.5, rows$annualized_vol * 0.9, rows$annualized_vol * 1.1)
      fold_rows(rows, vol_guess, NA_real_, p_short)
    })
  } else if (length(short_folds) == 0) {
    cat(" NOTE: short-fold all skipped, using long-fold results directly\n")
    lapply(Filter(Negate(is.null), long_folds), function(fold) {
      rows <- data[fold$test_idx]
      p_long <- fold$predictions
      fold_rows(rows, rows$annualized_vol * (1 + (p_long - 0.5) * 0.2), p_long, NA_real_)
    })
  } else {
    n_pairs <- min(length(long_folds), length(short_folds))
    lapply(seq_len(n_pairs), function(i) {
      lf <- long_folds[[i]]
      sf <- short_folds[[i]]
      if (is.null(lf) || is.null(sf)) return(NULL)
      rows <- data[lf$test_idx]
      vol_guess <- ifelse(
        sf$predictions < 0.5,
        rows$annualized_vol,
        rows$annualized_vol * (1 + (lf$predictions - 0.5) * 0.2)
      )
      fold_rows(rows, vol_guess, lf$predictions, sf$predictions)
    })
  }
  rbindlist(c(list(empty_preds), pieces))
}
# Cross-validate the long-term (y_long) and short-term (y_short) classifiers
# with identical fold settings, then combine their fold outputs into
# volatility forecasts.
#
# @return list(long_xgb, short_xgb, ensemble_predictions).
train_double_xgboost <- function(data, features, n_folds = 5, gap_days = 30) {
  cv <- lapply(
    c(long_xgb = "y_long", short_xgb = "y_short"),
    function(label_col) run_time_series_cv(data, features, label_col, n_folds, gap_days)
  )
  c(cv, list(ensemble_predictions = ensemble_predict_volatility(data, cv$long_xgb, cv$short_xgb)))
}
# ===== 5_baselines.R =====
# Rolling GARCH(1,1) volatility baseline.
#
# For each forecast origin i in (window_size + 1)..(n - prediction_period),
# GARCH(1,1) is fitted (fGarch::garchFit) on the `window_size` returns ending
# at row i, i.e. the same information available to the row-i features. The
# `prediction_period`-step-ahead conditional sd is then annualized with
# sqrt(252). That forecast refers to row i + prediction_period, the row
# target_vol[i] is meant to hold.
#
# Fixes vs. the previous version:
#   * fGarch's predict()$standardDeviation is already a standard deviation.
#     The extra sqrt() produced values on the wrong scale.
#   * The fit window ended at i - 1, one row short of the horizon alignment.
#   * Too few rows previously produced a descending index range. An empty
#     result is now returned instead.
#   * Failed fits were meant to be recorded by the error handler, but its
#     rbind only changed a local copy. They are now collected as NA, counted,
#     reported and dropped.
#
# @return data.table(date, garch_pred, actual_vol) for successful fits.
train_garch <- function(data, window_size = 60, prediction_period = 10) {
  n <- nrow(data)
  empty_preds <- data.table(date = as.Date(character()), garch_pred = numeric(), actual_vol = numeric())
  first_origin <- window_size + 1
  last_origin <- n - prediction_period
  if (last_origin < first_origin) return(empty_preds)
  rows <- lapply(first_origin:last_origin, function(i) {
    window_returns <- data$return[(i - window_size + 1):i]
    window_returns <- window_returns[is.finite(window_returns)]
    garch_vol <- tryCatch({
      garch_model <- garchFit(~ garch(1, 1), data = window_returns, trace = FALSE)
      forecast <- predict(garch_model, n.ahead = prediction_period)
      forecast$standardDeviation[prediction_period] * sqrt(252)
    }, error = function(e) NA_real_)
    data.table(date = data$date[i], garch_pred = garch_vol, actual_vol = data$target_vol[i])
  })
  predictions <- rbindlist(c(list(empty_preds), rows))
  n_failed <- sum(is.na(predictions$garch_pred))
  if (n_failed > 0) {
    cat(sprintf(" GARCH: %d forecast origin(s) failed to fit and were dropped\n", n_failed))
  }
  predictions[!is.na(garch_pred)]
}
# ===== 6_evaluation.R =====
# Point-forecast error metrics for volatility predictions. Pairs where either
# side is NA are ignored (na.rm = TRUE). The relative metrics divide by
# `actual`, so a zero actual value yields Inf.
calculate_mse <- function(preds, actual) mean((preds - actual)^2, na.rm = TRUE)
calculate_mae <- function(preds, actual) mean(abs(preds - actual), na.rm = TRUE)
# Heteroskedasticity-adjusted (relative) errors: each error is scaled by the
# realized level. HMAE = mean(|e| / actual), HMSE = mean((e / actual)^2).
# HMSE previously divided the squared error by `actual` instead of `actual^2`,
# which left it in volatility units and inconsistent with HMAE.
calculate_hmse <- function(preds, actual) mean(((preds - actual) / actual)^2, na.rm = TRUE)
calculate_hmae <- function(preds, actual) mean(abs(preds - actual) / actual, na.rm = TRUE)
# Bundle all four metrics as list(mse, mae, hmse, hmae).
evaluate_regression <- function(preds, actual) {
  list(mse = calculate_mse(preds, actual), mae = calculate_mae(preds, actual),
       hmse = calculate_hmse(preds, actual), hmae = calculate_hmae(preds, actual))
}
# Compare Double-XGBoost and GARCH(1,1) volatility forecasts on their common
# dates.
#
# Returns list(comparison_table, mse_reduction, meets_target).
#   * mse_reduction is the percentage MSE reduction of XGBoost relative to
#     GARCH.
#   * meets_target is TRUE only when that reduction is a number >= 35. An
#     undefined reduction (e.g. GARCH MSE of 0) gives FALSE instead of NA.
# When either input is NULL or empty, or there are no common dates, a warning
# is printed and an all-NA table is returned with mse_reduction = 0 and
# meets_target = FALSE. Previously a NULL garch_preds crashed.
compare_models <- function(double_xgb_preds, garch_preds) {
  no_comparison <- function(message_text) {
    cat(message_text)
    list(
      comparison_table = data.table(
        Model = c("Double XGBoost", "GARCH(1,1)"),
        MSE = c(NA_real_, NA_real_), MAE = c(NA_real_, NA_real_),
        HMSE = c(NA_real_, NA_real_), HMAE = c(NA_real_, NA_real_)
      ),
      mse_reduction = 0, meets_target = FALSE
    )
  }
  if (is.null(double_xgb_preds) || is.null(garch_preds) ||
      nrow(double_xgb_preds) == 0 || nrow(garch_preds) == 0) {
    return(no_comparison("WARNING: No overlapping predictions for comparison\n"))
  }
  # Match Date against Date directly, with no intermediate set of dates.
  xgb_data <- double_xgb_preds[date %in% garch_preds$date]
  garch_data <- garch_preds[date %in% double_xgb_preds$date]
  if (nrow(xgb_data) == 0) {
    return(no_comparison("WARNING: No common dates between XGBoost and GARCH predictions\n"))
  }
  xgb_eval <- evaluate_regression(xgb_data$predicted_vol, xgb_data$actual_vol)
  garch_eval <- evaluate_regression(garch_data$garch_pred, garch_data$actual_vol)
  comparison <- data.table(
    Model = c("Double XGBoost", "GARCH(1,1)"),
    MSE = c(xgb_eval$mse, garch_eval$mse), MAE = c(xgb_eval$mae, garch_eval$mae),
    HMSE = c(xgb_eval$hmse, garch_eval$hmse), HMAE = c(xgb_eval$hmae, garch_eval$hmae)
  )
  mse_reduction <- (garch_eval$mse - xgb_eval$mse) / garch_eval$mse * 100
  list(comparison_table = comparison, mse_reduction = mse_reduction,
       meets_target = isTRUE(mse_reduction >= 35))
}
# Print averaged CV metrics (accuracy, precision, recall as percentages; F1
# and AUC to 4 decimals) for the long-term model, then the short-term model.
# Each argument must carry an `avg_results` list as built by
# run_time_series_cv().
print_classification_results <- function(long_results, short_results) {
  show_block <- function(header, res) {
    cat(header)
    avg <- res$avg_results
    cat(sprintf("准确率: %.2f%%\n", avg$accuracy * 100))
    cat(sprintf("精确率: %.2f%%\n", avg$precision * 100))
    cat(sprintf("召回率: %.2f%%\n", avg$recall * 100))
    cat(sprintf("F1: %.4f\n", avg$f1))
    cat(sprintf("AUC: %.4f\n", avg$auc))
  }
  show_block("=== 长期XGBoost结果 ===\n", long_results)
  show_block("\n=== 短期XGBoost结果 ===\n", short_results)
}
# Print the regression comparison table, the MSE reduction percentage and
# whether the >= 35% target was met ("是"/"否"; NA prints as "NA").
print_regression_results <- function(comparison) {
  verdict <- ifelse(comparison$meets_target, "是", "否")
  cat("\n=== 回归指标对比 ===\n")
  print(comparison$comparison_table)
  cat(sprintf("\n双重XGBoost MSE相比GARCH降低: %.2f%%\n", comparison$mse_reduction))
  cat(sprintf("达到目标(>=35%%): %s\n", verdict))
}
# Top `top_n` features by mean XGBoost Gain across CV folds.
#
# Fold tables are stacked (NULL entries from skipped folds are ignored). Gain
# is averaged per feature over the folds in which that feature appears, then
# sorted in descending order. head() caps the result at the number of distinct
# features. The previous 1:min(top_n, <stacked row count>) indexing could run
# past that number and return NA rows.
#
# @return data.table(Feature, Importance). When nothing is available, a
#   warning is printed and a single "N/A" placeholder row is returned.
get_top_features <- function(feature_importance_list, top_n = 5) {
  placeholder <- data.table(Feature = "N/A", Importance = 0)
  if (is.null(feature_importance_list) || length(feature_importance_list) == 0) {
    cat(" WARNING: No feature importance available\n")
    return(placeholder)
  }
  all_importance <- rbindlist(feature_importance_list, fill = TRUE)
  if (nrow(all_importance) == 0) {
    cat(" WARNING: Empty feature importance table\n")
    return(placeholder)
  }
  avg_importance <- all_importance[, .(Importance = mean(Gain, na.rm = TRUE)), by = Feature][order(-Importance)]
  head(avg_importance, top_n)
}
# ===== 7_plot.R =====
# Three line charts over `date`: closing price (p1), daily log return (p2) and
# annualized volatility (p3). All use a minimal theme with x labels rotated
# 45 degrees.
plot_price_return_vol <- function(data) {
  rotated_minimal <- theme_minimal() + theme(axis.text.x = element_text(angle = 45, hjust = 1))
  price_plot <- ggplot(data, aes(x = date, y = close)) +
    geom_line(color = "blue") +
    labs(title = "玉米期货收盘价", x = "日期", y = "收盘价(美分/蒲式耳)") +
    rotated_minimal
  return_plot <- ggplot(data, aes(x = date, y = return)) +
    geom_line(color = "red") +
    labs(title = "日度对数收益率", x = "日期", y = "收益率") +
    rotated_minimal
  vol_plot <- ggplot(data, aes(x = date, y = annualized_vol)) +
    geom_line(color = "green") +
    labs(title = "年化波动率", x = "日期", y = "波动率") +
    rotated_minimal
  list(p1 = price_plot, p2 = return_plot, p3 = vol_plot)
}
# Realized target volatility (solid blue) vs. Double-XGBoost prediction
# (dashed red) over `date`.
# Line thickness uses `linewidth`. The `size` aesthetic for lines has been
# deprecated since ggplot2 3.4.0 and triggers deprecation warnings.
plot_double_xgboost_prediction <- function(preds) {
  ggplot(preds, aes(x = date)) +
    geom_line(aes(y = actual_vol, color = "实际波动率"), linewidth = 1) +
    geom_line(aes(y = predicted_vol, color = "预测波动率"), linewidth = 1, linetype = "dashed") +
    labs(title = "双重XGBoost波动率预测 vs 实际值", x = "日期", y = "波动率") +
    scale_color_manual(values = c("实际波动率" = "blue", "预测波动率" = "red")) +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1), legend.position = "bottom")
}
# Realized target volatility (solid blue) vs. GARCH(1,1) forecast (dashed red)
# over `date`.
# Line thickness uses `linewidth`. The `size` aesthetic for lines has been
# deprecated since ggplot2 3.4.0 and triggers deprecation warnings.
plot_garch_prediction <- function(preds) {
  ggplot(preds, aes(x = date)) +
    geom_line(aes(y = actual_vol, color = "实际波动率"), linewidth = 1) +
    geom_line(aes(y = garch_pred, color = "GARCH预测"), linewidth = 1, linetype = "dashed") +
    labs(title = "GARCH(1,1)波动率预测 vs 实际值", x = "日期", y = "波动率") +
    scale_color_manual(values = c("实际波动率" = "blue", "GARCH预测" = "red")) +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1), legend.position = "bottom")
}
# Side-by-side (dodged, horizontal) bar chart of mean XGBoost Gain for the
# top `top_n` features of the long-term and short-term models.
#
# Per model, fold tables are stacked and Gain is averaged per feature. The
# top `top_n` are kept via head(). The previous 1:min(top_n, <stacked rows>)
# indexing could produce NA rows when there were fewer distinct features.
# A model with no importance rows is left out instead of erroring.
# A text placeholder plot is returned when either input list is missing or
# empty, or when no rows remain. Any error is reported and rendered as a
# placeholder as well.
plot_feature_importance <- function(long_importance, short_importance, top_n = 10) {
  placeholder_plot <- function(label) {
    ggplot() + annotate("text", x = 0.5, y = 0.5, label = label) + theme_void() + ggtitle("特征重要性对比")
  }
  # Mean Gain per feature (top `top_n`), tagged with the model label; NULL if empty.
  top_gain <- function(fold_tables, type_label) {
    stacked <- rbindlist(fold_tables, fill = TRUE)
    if (nrow(stacked) == 0) return(NULL)
    avg <- stacked[, .(Importance = mean(Gain, na.rm = TRUE)), by = Feature][order(-Importance)]
    avg <- head(avg, top_n)
    avg[, Type := type_label]
    avg
  }
  tryCatch({
    no_input <- is.null(long_importance) || length(long_importance) == 0 ||
      is.null(short_importance) || length(short_importance) == 0
    combined_df <- if (no_input) {
      NULL
    } else {
      rbindlist(list(top_gain(long_importance, "长期XGBoost"),
                     top_gain(short_importance, "短期XGBoost")))
    }
    if (is.null(combined_df) || nrow(combined_df) == 0) {
      placeholder_plot("无特征重要性数据")
    } else {
      ggplot(combined_df, aes(x = reorder(Feature, Importance), y = Importance, fill = Type)) +
        geom_bar(stat = "identity", position = "dodge") + coord_flip() +
        labs(title = "特征重要性对比", x = "特征", y = "重要性(Gain)") +
        theme_minimal() + theme(legend.position = "bottom")
    }
  }, error = function(e) {
    cat(sprintf(" WARNING: plot_feature_importance error: %s\n", e$message))
    placeholder_plot(paste("绘图错误:", e$message))
  })
}
# Overlay of realized volatility (black), Double-XGBoost (dashed red) and
# GARCH(1,1) (dotted green) on the dates both models predicted.
#
# Returns a text placeholder when there are no XGBoost predictions ("无对比数据")
# or no dates shared with GARCH ("无共同日期"). A NULL or empty garch_preds is
# now handled explicitly. The two inputs are aligned with a single inner join
# on date. Lines use `linewidth`, since `size` for lines has been deprecated
# since ggplot2 3.4.0.
plot_all_comparison <- function(double_xgb_preds, garch_preds) {
  placeholder_plot <- function(label) {
    ggplot() + annotate("text", x = 0.5, y = 0.5, label = label) + theme_void() + ggtitle("模型波动率预测对比")
  }
  if (is.null(double_xgb_preds) || nrow(double_xgb_preds) == 0) return(placeholder_plot("无对比数据"))
  if (is.null(garch_preds) || nrow(garch_preds) == 0) return(placeholder_plot("无共同日期"))
  combined <- merge(
    double_xgb_preds[, .(date, actual_vol, xgb = predicted_vol)],
    garch_preds[, .(date, garch_pred)],
    by = "date"
  )
  if (nrow(combined) == 0) return(placeholder_plot("无共同日期"))
  ggplot(combined, aes(x = date)) +
    geom_line(aes(y = actual_vol, color = "实际值"), linewidth = 1.2) +
    geom_line(aes(y = xgb, color = "双重XGBoost"), linewidth = 1, linetype = "dashed") +
    geom_line(aes(y = garch_pred, color = "GARCH(1,1)"), linewidth = 1, linetype = "dotted") +
    labs(title = "模型波动率预测对比", x = "日期", y = "波动率") +
    scale_color_manual(values = c("实际值" = "black", "双重XGBoost" = "red", "GARCH(1,1)" = "green")) +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1), legend.position = "bottom")
}
# ===== MAIN ORCHESTRATION =====
# ---- Step 1: locate the input CSV and preprocess it ----
# Searches /kaggle/input recursively for *.csv files. If none are found it
# prints the directory listing and stops. Only the first match is used.
cat("=== 步骤1: 数据预处理 ===\n")
csv_files <- list.files("/kaggle/input", pattern = "\\.csv$", full.names = TRUE, recursive = TRUE)
cat(sprintf("Found %d CSV files: %s\n", length(csv_files), paste(csv_files, collapse = ", ")))
if (length(csv_files) == 0) {
cat("Listing /kaggle/input recursively:\n")
all_files <- list.files("/kaggle/input", full.names = TRUE, recursive = TRUE)
cat(paste(all_files, collapse = "\n"), "\n")
stop("No CSV file found in /kaggle/input")
}
csv_file <- csv_files[1]
cat(sprintf("Using CSV: %s\n", csv_file))
data <- preprocess_data(csv_file)
cat(sprintf("数据行数: %d, 列数: %d\n", nrow(data), ncol(data)))
# ---- Step 2: feature engineering (feature columns are appended to `data`) ----
cat("\n=== 步骤2: 特征工程 ===\n")
data <- generate_all_features(data)
cat(sprintf("生成特征后列数: %d\n", ncol(data)))
# ---- Step 3: labels ----
# The forecast horizon shrinks for small datasets: 5 rows below 100 rows,
# 8 below 150, otherwise 10.
cat("\n=== 步骤3: 标签构造 ===\n")
prediction_period <- ifelse(nrow(data) < 100, 5, ifelse(nrow(data) < 150, 8, 10))
cat(sprintf("数据集(%d行),预测周期设置为%d天\n", nrow(data), prediction_period))
data <- create_all_labels(data, long_threshold = 0.10, short_threshold = 0.02, prediction_period = prediction_period)
cat(sprintf("构造标签后有效行数: %d\n", nrow(data)))
# Abort with exit status 1 when label construction leaves no rows.
if (nrow(data) == 0) {
cat("错误:没有有效数据,程序退出\n")
quit(save = "no", status = 1)
}
# ---- Step 4: model feature columns (all created by generate_all_features) ----
cat("\n=== 步骤4: 定义特征列表 ===\n")
features <- c(
"vol_10d", "vol_20d", "vol_30d",
"vol_mean_20d", "vol_std_20d",
"RSI_6", "RSI_12", "RSI_24",
"ATR_10", "ATR_20",
"OBV", "MTM",
"spot_price", "inventory", "spot_vol",
"month"
)
cat(sprintf("特征数量: %d\n", length(features)))
# ---- Step 5: cross-validated long/short XGBoost classifiers ----
# Datasets under 100 rows use 3 folds and a 15-row gap; otherwise 5 folds and
# a 20-row gap.
cat("\n=== 步骤5: 训练双重XGBoost模型 ===\n")
n_folds <- ifelse(nrow(data) < 100, 3, 5)
gap_days <- ifelse(nrow(data) < 100, 15, 20)
cat(sprintf("交叉验证折数: %d, 间隔天数: %d\n", n_folds, gap_days))
set.seed(123)
double_xgb_results <- train_double_xgboost(data, features, n_folds = n_folds, gap_days = gap_days)
cat("\n=== 分类模型交叉验证结果 ===\n")
print_classification_results(double_xgb_results$long_xgb, double_xgb_results$short_xgb)
# ---- Step 6: rolling GARCH(1,1) baseline (60-row fitting window) ----
cat("\n=== 步骤6: 训练GARCH(1,1)基线模型 ===\n")
cat("训练GARCH(1,1)...\n")
garch_results <- train_garch(data, window_size = 60, prediction_period = prediction_period)
cat(sprintf("GARCH预测样本数: %d\n", nrow(garch_results)))
# ---- Step 7: regression comparison ----
# If comparison or printing fails, the handler stores a placeholder
# `comparison` in the global environment (via <<-) so the final summary below
# can still read it.
cat("\n=== 步骤7: 模型评估 ===\n")
tryCatch({
comparison <- compare_models(double_xgb_results$ensemble_predictions, garch_results)
print_regression_results(comparison)
}, error = function(e) {
cat(sprintf("WARNING: 模型评估出错: %s\n", e$message))
comparison <<- list(comparison_table = data.table(Model=c("N/A"), MSE=NA, MAE=NA, HMSE=NA, HMAE=NA), mse_reduction=0, meets_target=FALSE)
})
# ---- Step 8: top-5 features per model (errors are reported, not fatal) ----
cat("\n=== 步骤8: 特征重要性 ===\n")
tryCatch({
long_top_features <- get_top_features(double_xgb_results$long_xgb$feature_importance, top_n = 5)
short_top_features <- get_top_features(double_xgb_results$short_xgb$feature_importance, top_n = 5)
cat("\n长期XGBoost特征重要性TOP5:\n")
print(long_top_features)
cat("\n短期XGBoost特征重要性TOP5:\n")
print(short_top_features)
}, error = function(e) {
cat(sprintf("WARNING: 特征重要性出错: %s\n", e$message))
})
# ---- Step 9: plots ----
# Each thunk prints one ggplot into its own 1200x800 PNG (file names pair up
# with plot_funcs by position). A failing plot closes the open device and is
# reported without stopping the loop.
# NOTE(review): only `$p1` (closing price) of plot_price_return_vol() is
# saved; the return (p2) and volatility (p3) plots it also builds are unused.
cat("\n=== 步骤9: 绘图 ===\n")
plot_files <- c("1_price_return_vol.png", "2_double_xgboost_pred.png", "3_garch_pred.png", "4_feature_importance.png", "5_models_comparison.png")
plot_funcs <- list(
function() print(plot_price_return_vol(data)$p1),
function() print(plot_double_xgboost_prediction(double_xgb_results$ensemble_predictions)),
function() print(plot_garch_prediction(garch_results)),
function() print(plot_feature_importance(double_xgb_results$long_xgb$feature_importance, double_xgb_results$short_xgb$feature_importance, top_n = 10)),
function() print(plot_all_comparison(double_xgb_results$ensemble_predictions, garch_results))
)
for (j in seq_along(plot_files)) {
tryCatch({
png(plot_files[j], width = 1200, height = 800)
plot_funcs[[j]]()
dev.off()
cat(sprintf("已保存: %s\n", plot_files[j]))
}, error = function(e) {
if (dev.cur() > 1) dev.off()
cat(sprintf("WARNING: %s 绘图失败: %s\n", plot_files[j], e$message))
})
}
# ---- Final summary and persistence ----
cat("\n=== 任务完成 ===\n")
tryCatch({
cat(sprintf("双重XGBoost MSE相比GARCH降低: %.2f%%\n", comparison$mse_reduction))
cat(sprintf("是否达到目标(>=35%%): %s\n", ifelse(comparison$meets_target, "是", "否")))
}, error = function(e) {
cat("WARNING: 最终统计输出出错\n")
})
# save(list = ls()) writes every object in the global environment (data,
# model results, helper functions, loop variables) to results.RData.
save(list = ls(), file = "results.RData")
cat("结果已保存到 results.RData\n")