diff --git a/examples/CANVAS_extension.jl b/examples/CANVAS_extension.jl index 1da6ecea..d099d5d5 100644 --- a/examples/CANVAS_extension.jl +++ b/examples/CANVAS_extension.jl @@ -8,25 +8,15 @@ There are 4 major changes from the Poledna et al. (2023) to the CANVAS model (Ho 3) Demand pull firm level price and quanitity setting 4) Adaptive learning for the central bank to learn the parameters of the Taylor rule -This script implements changes 3 and 4 by overwriting the methods that govern that behaviour. +This script implements changes 3 and 4 by overwriting the methods that govern that behaviour. To introduce changes 1 and 2 we need dissagregated data at the household and firm level. """ import BeforeIT as Bit -using Plots, Dates, StatsPlots -# get parameters and initial conditions -cal = Bit.ITALY_CALIBRATION -calibration_date = DateTime(2010, 03, 31) - -p, ic = Bit.get_params_and_initial_conditions(cal, calibration_date; scale = 0.001) - -# initialize with historical data only - series will grow dynamically during simulation -Y_EA_series = vec(ic["Y_EA_series"]) -pi_EA_series = vec(ic["pi_EA_series"]) -r_bar_series = vec(ic["r_bar_series"]) - -Bit.@object mutable struct ModelCANVAS(Bit.Model) <: Bit.AbstractModel end +# ===================================================== +# AGENT TYPES +# ===================================================== # define a new central bank for the CANVAS model abstract type AbstractCentralBankCANVAS <: Bit.AbstractCentralBank end @@ -36,7 +26,7 @@ end # define new firms for the CANVAS model abstract type AbstractFirmsCANVAS <: Bit.AbstractFirms end -Bit.@object mutable struct FirmsCANVAS(Firms) <: AbstractFirmsCANVAS end +Bit.@object mutable struct FirmsCANVAS(Bit.Firms) <: AbstractFirmsCANVAS end # define a new rest of the world for the CANVAS model abstract type AbstractRestOfTheWorldCANVAS <: Bit.AbstractRestOfTheWorld end @@ -45,16 +35,44 @@ Bit.@object mutable struct RestOfTheWorldCANVAS(Bit.RestOfTheWorld) <: AbstractR 
pi_EA_series::Vector{Float64} end -# define new functions for the CANVAS-specific agents -function Bit.firms_expectations_and_decisions(model::ModelCANVAS) - firms = model.firms +# ===================================================== +# ADAPTIVE EXPECTATIONS (Eq. 15a-15b) +# ===================================================== +# γ^e(t) = exp(α^γ · γ(t-1) + β^γ + ε^γ) - 1 +# π^e(t) = exp(α^π · π(t-1) + β^π + ε^π) - 1 +# where parameters are re-estimated every period from the full history + +function Bit.growth_inflation_expectations( + model::Bit.Model{<:Bit.AbstractWorkers, <:Bit.AbstractWorkers, <:AbstractFirmsCANVAS, + <:Bit.AbstractBank, <:Bit.AbstractCentralBank, <:Bit.AbstractGovernment, + <:Bit.AbstractRestOfTheWorld, <:Bit.AbstractAggregates}) + Y = model.agg.Y + pi_ = model.agg.pi_ + T_prime = model.prop.T_prime + t = model.agg.t + + Y_slice = Y[1:(T_prime + t - 1)] + + # Eq. 15a: AR(1) on growth rates γ(t) = Y(t)/Y(t-1) - 1 + gamma_series = Y_slice[2:end] ./ Y_slice[1:end-1] .- 1.0 + gamma_e = Bit.estimate_next_value(gamma_series) + Y_e = Y_slice[end] * (1 + gamma_e) + + # Eq. 15b: AR(1) on inflation π(t) + pi_e = Bit.estimate_next_value(1 .+ pi_[1:(T_prime + t - 1)]) -1 + + return Y_e, gamma_e, pi_e +end + +# ===================================================== +# DEMAND-PULL PRICING (Eq. 17) +# ===================================================== - # unpack non-firm variables +function Bit.firms_expectations_and_decisions(firms::AbstractFirmsCANVAS, model::Bit.AbstractModel) P_bar_g = model.agg.P_bar_g gamma_e = model.agg.gamma_e pi_e = model.agg.pi_e - # individual firm quantity and price adjustments I = length(firms.G_i) gamma_d_i = zeros(I) pi_d_i = zeros(I) @@ -74,7 +92,7 @@ function Bit.firms_expectations_and_decisions(model::ModelCANVAS) pi_d_i[i] = 0 end end - #pi_d_i = min.(pi_d_i, 0.3) # cap the price adjustment to 30%. 
Otherwise it can reach 200% in some cases + Q_s_i = firms.Q_s_i .* (1 .+ gamma_e) .* (1 .+ gamma_d_i) # cost push inflation @@ -93,9 +111,15 @@ function Bit.firms_expectations_and_decisions(model::ModelCANVAS) return Q_s_i, I_d_i, DM_d_i, N_d_i, Pi_e_i, DL_d_i, K_e_i, L_e_i, new_P_i end -function Bit.central_bank_rate(model::ModelCANVAS) - cb = model.cb - gamma_EA, pi_EA, T_prime, t = model.rotw.gamma_EA, model.rotw.pi_EA, model.prop.T_prime, model.agg.t +# ===================================================== +# ADAPTIVE TAYLOR RULE (Eq. 19) +# ===================================================== + +function Bit.central_bank_rate(cb::AbstractCentralBankCANVAS, model::Bit.AbstractModel) + gamma_EA = model.rotw.gamma_EA + pi_EA = model.rotw.pi_EA + T_prime = model.prop.T_prime + t = model.agg.t a1 = cb.r_bar_series[1:(T_prime + t - 1)] a2 = model.rotw.Y_EA_series[1:(T_prime + t - 1)] @@ -109,60 +133,81 @@ function Bit.central_bank_rate(model::ModelCANVAS) return r_bar end -function Bit.growth_inflation_EA(model::ModelCANVAS) - rotw = model.rotw +# ===================================================== +# EA DYNAMICS — push! 
to series +# ===================================================== + +function Bit.growth_inflation_EA(rotw::AbstractRestOfTheWorldCANVAS, model::Bit.AbstractModel) epsilon_Y_EA = model.agg.epsilon_Y_EA - Y_EA = exp(rotw.alpha_Y_EA * log(rotw.Y_EA) + rotw.beta_Y_EA + epsilon_Y_EA) # GDP EA - gamma_EA = Y_EA / rotw.Y_EA - 1 # growth EA + Y_EA = exp(rotw.alpha_Y_EA * log(rotw.Y_EA) + rotw.beta_Y_EA + epsilon_Y_EA) + gamma_EA = Y_EA / rotw.Y_EA - 1 epsilon_pi_EA = randn() * rotw.sigma_pi_EA - pi_EA = exp(rotw.alpha_pi_EA * log(1 + rotw.pi_EA) + rotw.beta_pi_EA + epsilon_pi_EA) - 1 # inflation EA + pi_EA = exp(rotw.alpha_pi_EA * log(1 + rotw.pi_EA) + rotw.beta_pi_EA + epsilon_pi_EA) - 1 # push new values to time series push!(rotw.Y_EA_series, Y_EA) push!(rotw.pi_EA_series, pi_EA) return Y_EA, gamma_EA, pi_EA end -# new firms initialisation -firms_st = Bit.Firms(p, ic) -firms = FirmsCANVAS(Bit.fields(firms_st)...) -firms.Q_s_i .= firms.Q_d_i # overwrite to avoid division by zero for new firm price and quantity setting mechanism - -# new central bank initialisation -cb_st = Bit.CentralBank(p, ic) -cb = CentralBankCANVAS(Bit.fields(cb_st)..., r_bar_series) # add new variables to the aggregates - -# new rotw initialisation -rotw_st = Bit.RestOfTheWorld(p, ic) -rotw = RestOfTheWorldCANVAS(Bit.fields(rotw_st)..., Y_EA_series, pi_EA_series) # add new variables to the aggregates - -# standard initialisations: workers, bank, aggregats, government, properties and data -w_act, w_inact = Bit.Workers(p, ic) -bank = Bit.Bank(p, ic) -agg = Bit.Aggregates(p, ic) -gov = Bit.Government(p, ic) -prop = Bit.Properties(p, ic) -data = Bit.Data() - -# define a standard model -model_std = Bit.Model(p, ic) - -# define a CANVAS model -# importantly, initializing with a tuple "((w_act, w_inact, ...))" rathen than with "(w_act, w_inact, ...)" -# will perform extra needed initialization operations internally (for example, updating totals after all agents have been initialized) -model_canvas = 
ModelCANVAS((w_act, w_inact, firms, bank, cb, gov, rotw, agg, prop, data)) - -# The CANVAS model extension is also included in the BeforeIT package. -# You can instantiate a CANVAS model directly from parameters and initial conditions in a single line of code as -model_canvas_2 = Bit.ModelCANVAS(p, ic) - -# run the model(s) -T = 12 -n_sims = 8 -model_vector_std = Bit.ensemblerun(model_std, T, n_sims) -model_vector_canvas = Bit.ensemblerun(model_canvas, T, n_sims) -model_vector_canvas_2 = Bit.ensemblerun(model_canvas_2, T, n_sims) - -# plot the results -ps = Bit.plot_data_vectors([model_vector_std, model_vector_canvas, model_vector_canvas_2]) -plot(ps..., layout = (3, 3)) +# ===================================================== +# FACTORY: Create CANVAS Model +# ===================================================== + +""" + create_model(p, ic) + +Create a CANVAS model with demand-pull pricing and adaptive Taylor rule. + +Standard factory interface for `save_all_simulations`: + include("examples/CANVAS_extension.jl") + Bit.save_all_simulations(folder; model_factory=create_model, output_suffix="canvas") +""" +function create_model(p, ic) + # initialize series + Y_EA_series = Vector{Float64}(vec(ic["Y_EA_series"])) + pi_EA_series = Vector{Float64}(vec(ic["pi_EA_series"])) + r_bar_series = Vector{Float64}(vec(ic["r_bar_series"])) + + # custom agents + firms_st = Bit.Firms(p, ic) + firms = FirmsCANVAS((getfield(firms_st, x) for x in fieldnames(Bit.Firms))...) 
+ firms.Q_s_i .= firms.Q_d_i + + cb_st = Bit.CentralBank(p, ic) + cb = CentralBankCANVAS((getfield(cb_st, x) for x in fieldnames(Bit.CentralBank))..., r_bar_series) + + rotw_st = Bit.RestOfTheWorld(p, ic) + rotw = RestOfTheWorldCANVAS( + (getfield(rotw_st, x) for x in fieldnames(Bit.RestOfTheWorld))..., + Y_EA_series, pi_EA_series) + + # standard agents + w_act, w_inact = Bit.Workers(p, ic) + bank = Bit.Bank(p, ic) + gov = Bit.Government(p, ic) + agg = Bit.Aggregates(p, ic) + prop = Bit.Properties(p, ic) + data = Bit.Data(p) + + return Bit.Model(w_act, w_inact, firms, bank, cb, gov, rotw, agg, prop, data) +end + +# ===================================================== +# DEMO: Only runs when executed directly +# ===================================================== + +#T = 12 +#cal = Bit.ITALY_CALIBRATION +#calibration_date = DateTime(2010, 03, 31) + +#p, ic = Bit.get_params_and_initial_conditions(cal, calibration_date; scale = 0.001) + +#model_std = Bit.Model(p, ic) +#model_canvas = create_model(p, ic) + +#model_vector_std = Bit.ensemblerun(model_std, T, 8) +#model_vector_canvas = Bit.ensemblerun(model_canvas, T, 8) + +#ps = Bit.plot_data_vectors([model_vector_std, model_vector_canvas]) +#plot(ps..., layout = (3, 3)) diff --git a/examples/GrowthRateAR1_extension.jl b/examples/GrowthRateAR1_extension.jl new file mode 100644 index 00000000..410c7221 --- /dev/null +++ b/examples/GrowthRateAR1_extension.jl @@ -0,0 +1,336 @@ +""" +===================================================================== +GROWTH-RATE AR(1) EXTENSION +===================================================================== + +Implementation of AR(1) on growth rates (instead of log-levels) via method overloading. +This follows the CANVAS_extension.jl pattern for type-based dispatch. 
+ +Standard AR(1) on log-levels (default BeforeIT): + log(Y_t) = alpha * log(Y_{t-1}) + beta + epsilon + => Y_t = exp(alpha * log(Y_{t-1}) + beta + epsilon) + +Growth-rate AR(1) (this extension): + g_t = alpha * g_{t-1} + beta + epsilon + where g_t = (Y_t - Y_{t-1}) / Y_{t-1} + => Y_t = Y_{t-1} * (1 + g_t) + +Applied to GDP expectations (Y_e, gamma_e), inflation (pi_e), C_G, C_E, Y_I. +Y_EA keeps the base log-level AR(1) so that gamma_EA (which feeds the Taylor +rule → Euribor) remains consistent with the base calibration. + + +## Usage: + include("examples/GrowthRateAR1_extension.jl") + model_gr = create_model(p, ic) + Bit.run!(model_gr, 12) +""" + +import BeforeIT as Bit +using Statistics, LinearAlgebra + +# ===================================================== +# ABSTRACT TYPES FOR DISPATCH +# ===================================================== + +abstract type AbstractRestOfTheWorldGR <: Bit.AbstractRestOfTheWorld end +abstract type AbstractGovernmentGR <: Bit.AbstractGovernment end + +# ===================================================== +# EXTENDED STRUCTS WITH LAGGED GROWTH RATES +# ===================================================== + +""" +Extended RestOfTheWorld with lagged growth rates for AR(1) dynamics on C_E and Y_I. +Y_EA keeps the base log-level AR(1) to preserve Taylor rule / Euribor consistency. +""" +Bit.@object mutable struct RestOfTheWorldGR(Bit.RestOfTheWorld) <: AbstractRestOfTheWorldGR + # Growth-rate AR(1) parameters for C_E and Y_I (not Y_EA) + alpha_E_gr::Bit.typeFloat + beta_E_gr::Bit.typeFloat + sigma_E_gr::Bit.typeFloat + alpha_I_gr::Bit.typeFloat + beta_I_gr::Bit.typeFloat + sigma_I_gr::Bit.typeFloat + # Lagged growth rates (levels come from actual model values) + g_prev_C_E::Bit.typeFloat + g_prev_Y_I::Bit.typeFloat +end + +""" +Extended Government with lagged growth rate for AR(1) dynamics. 
+""" +Bit.@object mutable struct GovernmentGR(Bit.Government) <: AbstractGovernmentGR + # Growth-rate AR(1) parameters + alpha_G_gr::Bit.typeFloat + beta_G_gr::Bit.typeFloat + sigma_G_gr::Bit.typeFloat + # Lagged growth rate (level comes from actual model value) + g_prev_C_G::Bit.typeFloat +end + +# ===================================================== +# HELPER: ESTIMATE AR(1) ON GROWTH RATES +# ===================================================== + +""" + estimate_gr_ar1(series) + +Estimate AR(1) parameters on growth rates of the series. + +Returns (alpha, beta, sigma) where: + g_t = alpha * g_{t-1} + beta + epsilon + epsilon ~ N(0, sigma^2) +""" +function estimate_gr_ar1(series::Vector) + if length(series) < 3 + return 0.0, 0.0, 0.0 + end + + # Compute growth rates + growth_rates = diff(series) ./ series[1:end-1] + + if length(growth_rates) < 2 + return 0.0, mean(growth_rates), std(growth_rates) + end + + # Estimate AR(1) on growth rates + alpha, beta, sigma, _ = Bit.estimate_for_calibration_script(growth_rates) + + return alpha, beta, sigma +end + +""" + last_growth_rate(series) + +Compute the last growth rate from a historical series. +""" +function last_growth_rate(series::Vector) + if length(series) < 2 + return 0.0 + end + return (series[end] - series[end-1]) / series[end-1] +end + +""" + gr_next_value(current_value, g_prev, alpha, beta, sigma) + +Compute next growth rate using AR(1) and return (new_value, new_growth_rate). + +Uses the actual model value for level (not historical series), ensuring correct scale. +""" +function gr_next_value(current_value::Real, g_prev::Real, alpha::Real, beta::Real, sigma::Real) + # AR(1) on growth rate + epsilon = randn() * sigma + g_new = alpha * g_prev + beta + epsilon + + # Compute new level + Y_new = current_value * (1 + g_new) + + return Y_new, g_new +end + +# NOTE: No override for growth_inflation_EA — Y_EA uses base log-level AR(1) +# so that gamma_EA (Taylor rule input → Euribor) stays consistent. 
+ +# ===================================================== +# OVERRIDE: growth_inflation_expectations (GDP on growth rates) +# ===================================================== + +function Bit.growth_inflation_expectations( + model::Bit.Model{<:Bit.AbstractWorkers, <:Bit.AbstractWorkers, <:Bit.AbstractFirms, + <:Bit.AbstractBank, <:Bit.AbstractCentralBank, <:AbstractGovernmentGR, + <:Bit.AbstractRestOfTheWorld, <:Bit.AbstractAggregates}) + Y = model.agg.Y + pi_ = model.agg.pi_ + T_prime = model.prop.T_prime + t = model.agg.t + + Y_slice = Y[1:(T_prime + t - 1)] + + # AR(1) on growth rates γ(t) = Y(t)/Y(t-1) - 1 (instead of log-level AR) + gamma_series = Y_slice[2:end] ./ Y_slice[1:end-1] .- 1.0 + gamma_e = Bit.estimate_next_value(gamma_series) + Y_e = Y_slice[end] * (1 + gamma_e) + + # AR(1) on (1+π) directly — π is already a rate, no log transform + pi_e = Bit.estimate_next_value(1 .+ pi_[1:(T_prime + t - 1)]) - 1 + + return Y_e, gamma_e, pi_e +end + +# ===================================================== +# OVERRIDE: gov_expenditure (C_G, C_d_j) +# ===================================================== + +function Bit.gov_expenditure(gov::AbstractGovernmentGR, model) + # Unpack non-government arguments + c_G_g = model.prop.c_G_g + P_bar_g = model.agg.P_bar_g + pi_e = model.agg.pi_e + + # Compute C_G using growth-rate AR(1) + C_G_new, g_new = gr_next_value( + gov.C_G, # Use actual model value + gov.g_prev_C_G, # Lagged growth rate + gov.alpha_G_gr, + gov.beta_G_gr, + gov.sigma_G_gr + ) + + # Update lagged growth rate for next iteration + gov.g_prev_C_G = g_new + + # Compute local government consumptions (same as base) + J = size(gov.C_d_j, 1) + C_d_j = C_G_new ./ J .* ones(J) .* sum(c_G_g .* P_bar_g) .* (1 + pi_e) + + return C_G_new, C_d_j +end + +# ===================================================== +# OVERRIDE: rotw_import_export (C_E, Y_I, ...) 
+# ===================================================== + +function Bit.rotw_import_export(rotw::AbstractRestOfTheWorldGR, model) + # Unpack model arguments + c_E_g = model.prop.c_E_g + c_I_g = model.prop.c_I_g + P_bar_g = model.agg.P_bar_g + pi_e = model.agg.pi_e + + L = size(rotw.C_d_l, 1) + + # Compute C_E using growth-rate AR(1) + C_E_new, g_new_E = gr_next_value( + rotw.C_E, # Use actual model value + rotw.g_prev_C_E, # Lagged growth rate + rotw.alpha_E_gr, + rotw.beta_E_gr, + rotw.sigma_E_gr + ) + + # Update lagged growth rate for next iteration + rotw.g_prev_C_E = g_new_E + + # Compute demand for export + C_d_l = C_E_new ./ L .* ones(L) .* sum(c_E_g .* P_bar_g) .* (1 + pi_e) + + # Compute Y_I using growth-rate AR(1) + Y_I_new, g_new_I = gr_next_value( + rotw.Y_I, # Use actual model value + rotw.g_prev_Y_I, # Lagged growth rate + rotw.alpha_I_gr, + rotw.beta_I_gr, + rotw.sigma_I_gr + ) + + # Update lagged growth rate for next iteration + rotw.g_prev_Y_I = g_new_I + + # Compute supply of imports (same as base) + Y_m = c_I_g * Y_I_new + P_m = P_bar_g * (1 + pi_e) + + return C_E_new, Y_I_new, C_d_l, Y_m, P_m +end + +# ===================================================== +# FACTORY: Create Growth-Rate AR(1) Model +# ===================================================== + +""" + create_model(p, ic) + +Create a model using growth-rate AR(1) for C_G, C_E, and Y_I. + +Y_EA keeps the base log-level AR(1) so that gamma_EA (Taylor rule input) +remains consistent with the base calibration, preserving Euribor dynamics. + +The standard BeforeIT model uses log-level AR(1): + log(Y_t) = alpha * log(Y_{t-1}) + beta + epsilon + +This model uses growth-rate AR(1) for selected variables: + g_t = alpha * g_{t-1} + beta + epsilon + Y_t = Y_{t-1} * (1 + g_t) + +AR(1) parameters are re-estimated on growth rates ONCE at model creation. 
+ +This function is the standard factory interface for extensions - when `save_all_simulations` +is called with an extension file, it `include`s that file and calls `create_model(p, ic)`. +""" +function create_model(p, ic) + T_prime = Int(p["T_prime"]) + + # ===================================================== + # 1. Initialize standard agents FIRST (to get correct scales) + # ===================================================== + + w_act, w_inact = Bit.Workers(p, ic) + firms = Bit.Firms(p, ic) + bank = Bit.Bank(p, ic) + cb = Bit.CentralBank(p, ic) + agg = Bit.Aggregates(p, ic) + prop = Bit.Properties(p, ic) + data = Bit.Data(p) + + # Initialize standard RestOfTheWorld and Government to get actual model values + rotw_std = Bit.RestOfTheWorld(p, ic) + gov_std = Bit.Government(p, ic) + + # ===================================================== + # 2. Load historical series and estimate AR(1) on growth rates + # ===================================================== + # Note: Y_EA is NOT included — it keeps base log-level AR(1) + # to preserve Taylor rule / Euribor consistency. + + # C_G, C_E, Y_I series + C_G_series = Vector{Bit.typeFloat}(vec(ic["C_G"]))[1:T_prime] + C_E_series = Vector{Bit.typeFloat}(vec(ic["C_E"]))[1:T_prime] + Y_I_series = Vector{Bit.typeFloat}(vec(ic["Y_I"]))[1:T_prime] + + # Estimate AR(1) parameters on growth rates (scale-invariant) + alpha_G_gr, beta_G_gr, sigma_G_gr = estimate_gr_ar1(C_G_series) + alpha_E_gr, beta_E_gr, sigma_E_gr = estimate_gr_ar1(C_E_series) + alpha_I_gr, beta_I_gr, sigma_I_gr = estimate_gr_ar1(Y_I_series) + + # Compute lagged growth rates from rescaled series + g_prev_C_G = last_growth_rate(C_G_series) + g_prev_C_E = last_growth_rate(C_E_series) + g_prev_Y_I = last_growth_rate(Y_I_series) + + # ===================================================== + # 3. 
Initialize extended agents with growth-rate params + # ===================================================== + + # Create extended RestOfTheWorld (C_E, Y_I use growth-rate AR; Y_EA uses base) + rotw = RestOfTheWorldGR((getfield(rotw_std, f) for f in fieldnames(Bit.RestOfTheWorld))..., + alpha_E_gr, beta_E_gr, sigma_E_gr, + alpha_I_gr, beta_I_gr, sigma_I_gr, + g_prev_C_E, g_prev_Y_I + ) + + # Create extended Government with growth-rate AR(1) + gov = GovernmentGR((getfield(gov_std, f) for f in fieldnames(Bit.Government))..., + alpha_G_gr, beta_G_gr, sigma_G_gr, + g_prev_C_G + ) + + return Bit.Model(w_act, w_inact, firms, bank, cb, gov, rotw, agg, prop, data) +end + + +#T = 12 +#cal = Bit.ITALY_CALIBRATION +#calibration_date = DateTime(2010, 03, 31) + +#p, ic = Bit.get_params_and_initial_conditions(cal, calibration_date; scale = 0.001) + +#model_std = Bit.Model(p, ic) +#model_gr = create_model(p, ic) + +#model_vector_std = Bit.ensemblerun(model_std, T, 8) +#model_vector_gr = Bit.ensemblerun(model_gr, T, 8) + +#ps = Bit.plot_data_vectors([model_vector_std, model_vector_gr]) +#plot(ps..., layout = (3, 3)) diff --git a/examples/analysis/figs/aggregate_forecast_results.jl b/examples/analysis/figs/aggregate_forecast_results.jl new file mode 100644 index 00000000..99f4f344 --- /dev/null +++ b/examples/analysis/figs/aggregate_forecast_results.jl @@ -0,0 +1,90 @@ +# Aggregate Forecast Results - Figures (Heatmaps) +# Creates heatmaps for visual comparison of forecast performance across countries + +using CSV, DataFrames, Statistics, Plots +import BeforeIT as Bit + +include(joinpath(@__DIR__, "../tabs/analysis_utils.jl")) + +# ============================================================================= +# MODEL VARIANT CONFIGURATION +# ============================================================================= +# Change this to create heatmaps for different model variants. 
+# Reads from: data/{country}/analysis/{MODEL_VARIANT}/ +# Writes to: analysis/figs/multicountry/forecast_performance/{MODEL_VARIANT}/heatmaps/ +# +# Options: "base", "growth_rate", "canvas" + +MODEL_VARIANT = "canvas" + +# ============================================================================= +# MAIN SCRIPT +# ============================================================================= + +@info "Creating forecast heatmaps for variant: $(MODEL_VARIANT)..." + +countries = Bit.discover_countries_with_predictions() +@info "Found $(length(countries)) countries" + +output_dir = joinpath("analysis", "figs", "multicountry", "forecast_performance", MODEL_VARIANT, "heatmaps") +mkpath(output_dir) + +for table_type in AGGREGATE_TABLE_TYPES + result = load_country_forecast_data(MODEL_VARIANT, table_type, countries) + result === nothing && continue + + (; valid_countries, variables, has_pvals) = result + matrix, pval_matrix = build_12q_matrices(result) + all(isnan, matrix) && continue + + n_countries = length(valid_countries) + n_variables = length(variables) + + # ── Heatmap ── + clean_vars = [replace(v, " " => "\n") for v in variables] + plot_width = max(600, n_variables * 120) + plot_height = max(400, n_countries * 40) + + is_bias = occursin("bias", table_type) + is_raw_rmse = table_type in ["rmse_abm", "rmse_ar", "rmse_validation_abm", "rmse_validation_var"] + valid_vals = matrix[.!isnan.(matrix)] + + if is_bias || !is_raw_rmse + max_abs = maximum(abs.(valid_vals)) + p = heatmap(matrix, + color=:RdBu, clims=(-max_abs, max_abs), colorbar=false, + size=(plot_width, plot_height), margin=15Plots.mm, + xticks=(1:n_variables, clean_vars), yticks=(1:n_countries, valid_countries), + xrotation=45) + else + p = heatmap(matrix, + color=:YlOrRd, colorbar=false, + size=(plot_width, plot_height), margin=15Plots.mm, + xticks=(1:n_variables, clean_vars), yticks=(1:n_countries, valid_countries), + xrotation=45) + end + + # Annotations with adaptive text color and significance stars 
+ for i in 1:n_countries, j in 1:n_variables + val = matrix[i, j] + isnan(val) && continue + txt = is_bias ? string(round(val, digits=4)) : string(round(val, digits=1)) + if has_pvals && !isnan(pval_matrix[i, j]) + txt *= stars(pval_matrix[i, j]) + end + if is_bias || !is_raw_rmse + txt_color = abs(val) / max_abs > 0.65 ? :white : :black + else + min_val, max_val = extrema(valid_vals) + range = max_val - min_val + normalized = range > 0 ? (val - min_val) / range : 0.0 + txt_color = normalized > 0.7 ? :white : :black + end + annotate!(j, i, text(txt, 6, txt_color, :center)) + end + + savefig(p, joinpath(output_dir, "$(table_type)_heatmap.png")) + @info "✓ $table_type" +end + +@info "Done." diff --git a/examples/analysis/tabs/aggregate_forecast_results.jl b/examples/analysis/tabs/aggregate_forecast_results.jl new file mode 100644 index 00000000..02257bd9 --- /dev/null +++ b/examples/analysis/tabs/aggregate_forecast_results.jl @@ -0,0 +1,105 @@ +# Aggregate Forecast Results - Tables (CSV and LaTeX) +# Aggregates per-country forecast results into cross-country summaries + +using CSV, DataFrames, Statistics, Printf +import BeforeIT as Bit + +include(joinpath(@__DIR__, "analysis_utils.jl")) + +# ============================================================================= +# MODEL VARIANT CONFIGURATION +# ============================================================================= +# Change this to aggregate results for different model variants. +# Reads from: data/{country}/analysis/{MODEL_VARIANT}/ +# Writes to: analysis/tabs/multicountry_forecast_results/{MODEL_VARIANT}/ +# +# Options: "base", "growth_rate", "canvas" + +MODEL_VARIANT = "canvas" + +# ============================================================================= +# MAIN SCRIPT +# ============================================================================= + +@info "Aggregating forecast results for variant: $(MODEL_VARIANT)..." 
+ +countries = Bit.discover_countries_with_predictions() +@info "Found $(length(countries)) countries" + +output_base = joinpath("analysis", "tabs", "multicountry_forecast_results", MODEL_VARIANT) +mkpath(joinpath(output_base, "countries")) +mkpath(joinpath(output_base, "latex")) + +for table_type in AGGREGATE_TABLE_TYPES + result = load_country_forecast_data(MODEL_VARIANT, table_type, countries) + result === nothing && (@warn "No data for $table_type"; continue) + + (; country_data, valid_countries, variables, has_pvals) = result + matrix, pval_matrix = build_12q_matrices(result) + all(isnan, matrix) && continue + + n_countries = length(valid_countries) + n_variables = length(variables) + is_bias = startswith(table_type, "bias_") + + # ── Country x Variable CSV (12q horizon with significance stars) ── + if has_pvals + digits = is_bias ? 4 : 1 + str_matrix = Matrix{String}(undef, n_countries, n_variables) + for i in 1:n_countries, j in 1:n_variables + if isnan(matrix[i, j]) + str_matrix[i, j] = "" + else + val_str = string(round(matrix[i, j], digits=digits)) + star_str = isnan(pval_matrix[i, j]) ? 
"" : stars(pval_matrix[i, j]) + str_matrix[i, j] = val_str * star_str + end + end + summary_df = DataFrame(str_matrix, variables) + else + summary_df = DataFrame(matrix, variables) + end + summary_df.Country = valid_countries + select!(summary_df, :Country, Not(:Country)) + CSV.write(joinpath(output_base, "countries", "$(table_type)_countries.csv"), summary_df) + + # ── Horizon x Variable LaTeX table (cross-country mean +/- SE) ── + n_horizons = length(FORECAST_HORIZONS) + mean_vals = fill(NaN, n_horizons, n_variables) + std_errs = fill(NaN, n_horizons, n_variables) + + for (i, h) in enumerate(FORECAST_HORIZONS), (j, var) in enumerate(variables) + vals = Float64[] + for country in valid_countries + df = country_data[country] + var in names(df) || continue + rows = df[df.Horizon .== "$(h)q", :] + if nrow(rows) == 1 && !ismissing(rows[1, var]) && !isnan(rows[1, var]) + push!(vals, rows[1, var]) + end + end + if !isempty(vals) + mean_vals[i, j] = mean(vals) + std_errs[i, j] = std(vals) / sqrt(length(vals)) + end + end + + open(joinpath(output_base, "latex", "cross_country_$(table_type).tex"), "w") do f + println(f, "% Cross-country aggregated table for $table_type") + for (i, h) in enumerate(FORECAST_HORIZONS) + row = ["$(h)q"] + for j in 1:n_variables + if !isnan(mean_vals[i,j]) && !isnan(std_errs[i,j]) + push!(row, @sprintf("%.1f(%.1f)", mean_vals[i,j], std_errs[i,j])) + else + push!(row, "N/A") + end + end + println(f, join(row, " & "), i < n_horizons ? " \\\\ " : "") + end + end + + @info "✓ $table_type: $(n_countries) countries" +end + +@info "Done." 
diff --git a/examples/analysis/tabs/analysis_utils.jl b/examples/analysis/tabs/analysis_utils.jl index 27053c8e..a95d15f0 100644 --- a/examples/analysis/tabs/analysis_utils.jl +++ b/examples/analysis/tabs/analysis_utils.jl @@ -1,8 +1,7 @@ -function latexTableContent( - input_data::Matrix{String}, tableRowLabels::Vector{String}, - dataFormat::String, tableColumnAlignment, tableBorders::Bool, booktabs::Bool, - makeCompleteLatexDocument::Bool - ) + +function latexTableContent(input_data::Matrix{String}, tableRowLabels::Vector{String}, + dataFormat::String, tableColumnAlignment, tableBorders::Bool, booktabs::Bool, + makeCompleteLatexDocument::Bool) nrows, ncols = size(input_data) latex = [] @@ -44,33 +43,83 @@ function stars(p_value) end end -nanmean(x) = mean(filter(!isnan, x)) -nanmean(x, y) = mapslices(nanmean, x; dims = y) +nanmean(x) = mean(filter(!isnan,x)) +nanmean(x,y) = mapslices(nanmean,x; dims = y) function calculate_forecast_errors(forecast, actual) error = forecast - actual - rmse = dropdims(100 * sqrt.(nanmean(error .^ 2, 1)), dims = 1) - bias = dropdims(nanmean(error, 1), dims = 1) + rmse = dropdims(100 * sqrt.(nanmean(error.^2, 1)), dims=1) + bias = dropdims(nanmean(error, 1), dims=1) return rmse, bias, error end -function write_latex_table(filename, country, input_data_S, horizons) +function write_latex_table(filename, country, input_data_S, horizons; model_variant::String="base") tableRowLabels = ["$(i)q" for i in horizons] dataFormat, tableColumnAlignment = "%.2f", "r" tableBorders, booktabs, makeCompleteLatexDocument = false, false, false latex = latexTableContent(input_data_S, tableRowLabels, dataFormat, tableColumnAlignment, tableBorders, booktabs, makeCompleteLatexDocument) - return open("data/$(country)/analysis/$(filename)", "w") do fid + analysis_dir = "data/$(country)/analysis/$(model_variant)" + mkpath(analysis_dir) + open("$(analysis_dir)/$(filename)", "w") do fid for line in latex write(fid, line * "\n") end end end +function 
write_csv_table(filename, country, input_data, horizons, number_variables; model_variant::String="base") + """ + Write table data to CSV format for easy programmatic access. + + # Arguments + - `filename`: CSV filename (e.g., "rmse_var.csv") + - `country`: Country code + - `input_data`: Matrix of numerical values (horizons × variables) + - `horizons`: Vector of forecast horizons + - `number_variables`: Number of variables (5 for base, 8 for validation) + - `model_variant`: Subfolder for different model variants (default: "base") + """ + + # Variable names based on number of variables + if number_variables == 5 + variable_names = ["Real GDP", "GDP Deflator Growth", "Real Consumption", "Real Investment", "Euribor"] + elseif number_variables == 8 + variable_names = ["Real GDP", "GDP Deflator Growth", "Real Gov Consumption", "Real Exports", "Real Imports", "Real GDP (EA)", "GDP Deflator Growth (EA)", "Euribor"] + else + # Fallback for other cases + variable_names = ["Var$i" for i in 1:number_variables] + end + + # Create DataFrame with proper structure + df = DataFrame() + + # Add horizon column + df.Horizon = ["$(h)q" for h in horizons] + + # Add variable columns + for (j, var_name) in enumerate(variable_names) + if j <= size(input_data, 2) + df[!, var_name] = input_data[:, j] + else + df[!, var_name] = fill(NaN, length(horizons)) + end + end + + # Write CSV file + analysis_dir = "data/$(country)/analysis/$(model_variant)" + mkpath(analysis_dir) + csv_path = "$(analysis_dir)/$(filename)" + CSV.write(csv_path, df) + + @info "Saved CSV table: $csv_path" +end + function generate_dm_test_comparison(error1, error2, rmse1, rmse2, horizons, number_variables) - input_data = -round.(100 * (rmse1 .- rmse2) ./ rmse2, digits = 1) + input_data = -round.(100 * (rmse1 .- rmse2) ./ rmse2, digits=1) input_data_S = fill("", size(input_data)) + pval_matrix = fill(NaN, size(input_data)) for j in 1:length(horizons) h = horizons[j] @@ -78,14 +127,15 @@ function 
generate_dm_test_comparison(error1, error2, rmse1, rmse2, horizons, num dm_error1 = view(error1, :, j, l)[map(!, isnan.(view(error1, :, j, l)))] dm_error2 = view(error2, :, j, l)[map(!, isnan.(view(error2, :, j, l)))] _, p_value = Bit.dmtest_modified(dm_error2, dm_error1, h) - input_data_S[j, l] = string(input_data[j, l]) * "(" * string(round(p_value, digits = 2)) * ", " * string(stars(p_value)) * ")" + input_data_S[j, l] = string(input_data[j, l]) * "(" * string(round(p_value, digits=2)) * ", " * string(stars(p_value)) * ")" + pval_matrix[j, l] = p_value end end - return input_data_S + return input_data_S, pval_matrix end function generate_mz_test_bias(error, actual, bias, horizons, number_variables) - input_data = round.(bias, digits = 4) + input_data = round.(bias, digits=4) input_data_S = fill("", size(input_data)) for j in 1:length(horizons) @@ -94,49 +144,257 @@ function generate_mz_test_bias(error, actual, bias, horizons, number_variables) mz_forecast = (view(error, :, j, l) + view(actual, :, j, l))[map(!, isnan.(view(error, :, j, l) + view(actual, :, j, l)))] mz_actual = view(actual, :, j, l)[map(!, isnan.(view(actual, :, j, l)))] _, _, p_value = Bit.mztest(mz_actual, mz_forecast) - input_data_S[j, l] = string(input_data[j, l]) * " (" * string(round(p_value, digits = 3)) * ", " * stars(p_value) * ")" + input_data_S[j, l] = string(input_data[j, l]) * " (" * string(round(p_value, digits=3)) * ", " * stars(p_value) * ")" end end return input_data_S end -function create_bias_rmse_tables_abm(forecast, actual, horizons, type, number_variables) +function generate_bias_ttest(error, bias, horizons, number_variables) + input_data = round.(bias, digits=4) + input_data_S = fill("", size(input_data)) + pval_matrix = fill(NaN, size(input_data)) + + for j in 1:length(horizons) + h = horizons[j] + for l in 1:number_variables + e = view(error, :, j, l)[map(!, isnan.(view(error, :, j, l)))] + _, p_value = Bit.bias_ttest(e, h) + input_data_S[j, l] = string(input_data[j, l]) * 
" (" * string(round(p_value, digits=3)) * ", " * stars(p_value) * ")" + pval_matrix[j, l] = p_value + end + end + return input_data_S, pval_matrix +end + +function create_bias_rmse_tables_abm(forecast, actual, horizons, type, number_variables, country; model_variant::String="base") type_prefix = type == "validation" ? "validation_" : "" + comparison_model = type == "validation" ? "var" : "ar" rmse_abm, bias_abm, error_abm = calculate_forecast_errors(forecast, actual) - forecast_var = load("data/$(country)/analysis/forecast_$(type_prefix)var.jld2")["forecast"] - rmse_var, _, error_var = calculate_forecast_errors(forecast_var, actual) + # 1. Save ABSOLUTE RMSE + rmse_abs_numeric = round.(rmse_abm, digits=2) + write_csv_table("rmse_$(type_prefix)abm.csv", country, rmse_abs_numeric, horizons, number_variables; model_variant=model_variant) + write_latex_table("rmse_$(type_prefix)abm.tex", country, string.(rmse_abs_numeric), horizons; model_variant=model_variant) + + # 2. Save RELATIVE TO AR/VAR benchmark + base_analysis_dir = "data/$(country)/analysis/base" + forecast_benchmark = load("$(base_analysis_dir)/forecast_$(type_prefix)$(comparison_model).jld2")["forecast"] + rmse_benchmark, _, error_benchmark = calculate_forecast_errors(forecast_benchmark, actual) + + rmse_vs_benchmark_latex, pval_vs_benchmark = generate_dm_test_comparison(error_abm, error_benchmark, rmse_abm, rmse_benchmark, horizons, number_variables) + write_latex_table("rmse_$(type_prefix)abm_vs_$(comparison_model).tex", country, rmse_vs_benchmark_latex, horizons; model_variant=model_variant) - rmse_comparison_data = generate_dm_test_comparison(error_abm, error_var, rmse_abm, rmse_var, horizons, number_variables) - write_latex_table("rmse_$(type_prefix)abm.tex", country, rmse_comparison_data, horizons) + rmse_vs_benchmark_numeric = -round.(100 * (rmse_abm .- rmse_benchmark) ./ rmse_benchmark, digits=1) + write_csv_table("rmse_$(type_prefix)abm_vs_$(comparison_model).csv", country, 
rmse_vs_benchmark_numeric, horizons, number_variables; model_variant=model_variant) + write_csv_table("pval_$(type_prefix)abm_vs_$(comparison_model).csv", country, pval_vs_benchmark, horizons, number_variables; model_variant=model_variant) - bias_data = generate_mz_test_bias(error_abm, actual, bias_abm, horizons, number_variables) - write_latex_table("bias_$(type_prefix)abm.tex", country, bias_data, horizons) + # 3. Save VARIANT VS BASE ABM (only for non-base variants) + if model_variant != "base" + forecast_base_abm = load("$(base_analysis_dir)/forecast_$(type_prefix)abm.jld2")["forecast"] + rmse_base_abm, _, error_base_abm = calculate_forecast_errors(forecast_base_abm, actual) + + rmse_vs_base_latex, pval_vs_base = generate_dm_test_comparison(error_abm, error_base_abm, rmse_abm, rmse_base_abm, horizons, number_variables) + write_latex_table("rmse_$(type_prefix)abm_vs_base.tex", country, rmse_vs_base_latex, horizons; model_variant=model_variant) + + rmse_vs_base_numeric = -round.(100 * (rmse_abm .- rmse_base_abm) ./ rmse_base_abm, digits=1) + write_csv_table("rmse_$(type_prefix)abm_vs_base.csv", country, rmse_vs_base_numeric, horizons, number_variables; model_variant=model_variant) + write_csv_table("pval_$(type_prefix)abm_vs_base.csv", country, pval_vs_base, horizons, number_variables; model_variant=model_variant) + end + + # 4. 
Save bias with t-test p-values (HAC standard errors) + bias_data_latex, pval_bias = generate_bias_ttest(error_abm, bias_abm, horizons, number_variables) + write_latex_table("bias_$(type_prefix)abm.tex", country, bias_data_latex, horizons; model_variant=model_variant) + + bias_data_numeric = round.(bias_abm, digits=4) + write_csv_table("bias_$(type_prefix)abm.csv", country, bias_data_numeric, horizons, number_variables; model_variant=model_variant) + write_csv_table("pval_bias_$(type_prefix)abm.csv", country, pval_bias, horizons, number_variables; model_variant=model_variant) return nothing end -function create_bias_rmse_tables_var(forecast, actual, horizons, type, number_variables, k) - type_prefix = type == "validation" ? "validation_" : "" +# ============================================================================= +# SHARED CONSTANTS FOR AGGREGATE ANALYSIS +# ============================================================================= + +const AGGREGATE_TABLE_TYPES = [ + "rmse_abm", "bias_abm", + "rmse_validation_abm", "bias_validation_abm", + "rmse_abm_vs_ar", "rmse_ar", "bias_ar", + "rmse_validation_abm_vs_var", "rmse_validation_var", "bias_validation_var", + "rmse_abm_vs_base", + "rmse_validation_abm_vs_base", +] + +const AGG_BASE_VARIABLES = ["Real GDP", "GDP Deflator Growth", "Real Consumption", "Real Investment", "Euribor"] +const AGG_VALIDATION_VARIABLES = ["Real GDP", "GDP Deflator Growth", "Real Gov Consumption", "Real Exports", + "Real Imports", "Real GDP (EA)", "GDP Deflator Growth (EA)", "Euribor"] + +const FORECAST_HORIZONS = [1, 2, 4, 8, 12] - return if k == 1 - save("data/$(country)/analysis/forecast_$(type_prefix)var.jld2", "forecast", forecast) +get_variables_for_table(t) = occursin("validation", t) ? 
AGG_VALIDATION_VARIABLES : AGG_BASE_VARIABLES + +const VARIANT_CONFIG = Dict( + "base" => (prediction_folder="abm_predictions", extension_file=nothing), + "growth_rate" => (prediction_folder="abm_predictions_growth_rate", extension_file="../../GrowthRateAR1_extension.jl"), + "canvas" => (prediction_folder="abm_predictions_canvas", extension_file="../../CANVAS_extension.jl"), +) + +# ============================================================================= +# SHARED DATA LOADING FOR AGGREGATE ANALYSIS +# ============================================================================= + +""" + load_country_forecast_data(model_variant, table_type, countries) + +Load per-country CSVs and optional p-value CSVs for a given table type. +Returns a NamedTuple (country_data, pval_data, valid_countries, variables, has_pvals) +or `nothing` if no data found. +""" +function load_country_forecast_data(model_variant, table_type, countries) + variables = get_variables_for_table(table_type) + + country_data = Dict{String, DataFrame}() + for country in countries + csv_file = joinpath("data", country, "analysis", model_variant, "$(table_type).csv") + isfile(csv_file) || continue + try + df = CSV.read(csv_file, DataFrame) + if startswith(table_type, "rmse_") && !occursin("vs_", table_type) + numeric_cols = [c for c in names(df) if c != "Horizon" && nonmissingtype(eltype(df[!, c])) <: Number] + if any(any(x -> !ismissing(x) && !isnan(x) && x < 0, df[!, c]) for c in numeric_cols) + @warn "Skipping $csv_file: negative values in absolute RMSE" + continue + end + end + country_data[country] = df + catch e + @warn "Failed to load $csv_file" exception=(e, catch_backtrace()) + end + end + + isempty(country_data) && return nothing + valid_countries = sort(collect(keys(country_data))) + + is_comparison = occursin("_vs_", table_type) + is_bias = startswith(table_type, "bias_") + has_pvals = is_comparison || is_bias + + pval_data = Dict{String, DataFrame}() + if has_pvals + pval_name = is_comparison ? 
replace(table_type, "rmse_" => "pval_") : "pval_" * table_type + for country in valid_countries + pval_file = joinpath("data", country, "analysis", model_variant, "$(pval_name).csv") + isfile(pval_file) || continue + try + pval_data[country] = CSV.read(pval_file, DataFrame) + catch e + @warn "Failed to load $pval_file" exception=(e, catch_backtrace()) + end + end + end + + return (; country_data, pval_data, valid_countries, variables, has_pvals) +end + +""" + build_12q_matrices(result) + +Build country x variable matrices from 12q horizon data. +Takes the output of `load_country_forecast_data`. +Returns (matrix, pval_matrix). +""" +function build_12q_matrices(result) + (; country_data, pval_data, valid_countries, variables, has_pvals) = result + n_countries = length(valid_countries) + n_variables = length(variables) + + matrix = fill(NaN, n_countries, n_variables) + pval_matrix = fill(NaN, n_countries, n_variables) + + for (i, country) in enumerate(valid_countries) + df = country_data[country] + row_12q = df[df.Horizon .== "12q", :] + nrow(row_12q) == 1 || continue + for (j, var) in enumerate(variables) + if var in names(df) && !ismissing(row_12q[1, var]) && !isnan(row_12q[1, var]) + matrix[i, j] = row_12q[1, var] + end + end + if has_pvals && haskey(pval_data, country) + pdf = pval_data[country] + prow = pdf[pdf.Horizon .== "12q", :] + nrow(prow) == 1 || continue + for (j, var) in enumerate(variables) + if var in names(pdf) && !ismissing(prow[1, var]) && !isnan(prow[1, var]) + pval_matrix[i, j] = prow[1, var] + end + end + end + end + + return matrix, pval_matrix +end + +""" + discover_countries_with_variant_predictions(prediction_folder) + +Find country codes that have JLD2 prediction files in the given folder. 
+""" +function discover_countries_with_variant_predictions(prediction_folder::String) + countries = String[] + isdir("data") || return countries + for entry in readdir("data") + preds_dir = joinpath("data", entry, prediction_folder) + if isdir(preds_dir) && any(endswith(".jld2"), readdir(preds_dir)) + push!(countries, entry) + end + end + return sort(countries) +end + +# ============================================================================= +# PER-COUNTRY TABLE GENERATION +# ============================================================================= + +function create_bias_rmse_tables_var(forecast, actual, horizons, forecast_type, model_type, number_variables, k, country; model_variant::String="base") + type_prefix = forecast_type == "validation" ? "validation_" : "" + analysis_dir = "data/$(country)/analysis/$(model_variant)" + mkpath(analysis_dir) + + if k == 1 + save("$(analysis_dir)/forecast_$(type_prefix)$(model_type).jld2", "forecast", forecast) rmse_var, bias_var, error_var = calculate_forecast_errors(forecast, actual) - input_data_rmse = round.(rmse_var, digits = 2) - write_latex_table("rmse_$(type_prefix)var.tex", country, string.(input_data_rmse), horizons) + # Save RMSE data + input_data_rmse = round.(rmse_var, digits=2) + write_latex_table("rmse_$(type_prefix)$(model_type).tex", country, string.(input_data_rmse), horizons; model_variant=model_variant) + write_csv_table("rmse_$(type_prefix)$(model_type).csv", country, input_data_rmse, horizons, number_variables; model_variant=model_variant) + + # Save bias data with t-test p-values (HAC standard errors) + bias_data_latex, pval_bias = generate_bias_ttest(error_var, bias_var, horizons, number_variables) + write_latex_table("bias_$(type_prefix)$(model_type).tex", country, bias_data_latex, horizons; model_variant=model_variant) + + # For CSV, save just the numerical bias values (without p-values and stars) + bias_data_numeric = round.(bias_var, digits=4) + 
write_csv_table("bias_$(type_prefix)$(model_type).csv", country, bias_data_numeric, horizons, number_variables; model_variant=model_variant) + write_csv_table("pval_bias_$(type_prefix)$(model_type).csv", country, pval_bias, horizons, number_variables; model_variant=model_variant) - bias_data = generate_mz_test_bias(error_var, actual, bias_var, horizons, number_variables) - write_latex_table("bias_$(type_prefix)var.tex", country, bias_data, horizons) else - save("data/$(country)/analysis/forecast_$(type_prefix)var_$(k).jld2", "forecast", forecast) + save("$(analysis_dir)/forecast_$(type_prefix)$(model_type)_$(k).jld2", "forecast", forecast) rmse_var_k, _, error_var_k = calculate_forecast_errors(forecast, actual) - forecast_base_var = load("data/$(country)/analysis/forecast_$(type_prefix)var.jld2")["forecast"] + forecast_base_var = load("$(analysis_dir)/forecast_$(type_prefix)$(model_type).jld2")["forecast"] rmse_base_var, _, error_base_var = calculate_forecast_errors(forecast_base_var, actual) - rmse_comparison_data = generate_dm_test_comparison(error_var_k, error_base_var, rmse_var_k, rmse_base_var, horizons, number_variables) - write_latex_table("rmse_$(type_prefix)var_$(k).tex", country, rmse_comparison_data, horizons) + # Generate comparison data for LaTeX (with p-values and stars) + rmse_comparison_data_latex, pval_comparison = generate_dm_test_comparison(error_var_k, error_base_var, rmse_var_k, rmse_base_var, horizons, number_variables) + write_latex_table("rmse_$(type_prefix)$(model_type)_$(k).tex", country, rmse_comparison_data_latex, horizons; model_variant=model_variant) + + # For CSV, save just the numerical comparison values (percentage improvement) + rmse_comparison_data_numeric = -round.(100 * (rmse_var_k .- rmse_base_var) ./ rmse_base_var, digits=1) + write_csv_table("rmse_$(type_prefix)$(model_type)_$(k).csv", country, rmse_comparison_data_numeric, horizons, number_variables; model_variant=model_variant) + 
write_csv_table("pval_$(type_prefix)$(model_type)_$(k).csv", country, pval_comparison, horizons, number_variables; model_variant=model_variant) end -end +end \ No newline at end of file diff --git a/examples/analysis/tabs/simulate_and_create_tables_multicountry.jl b/examples/analysis/tabs/simulate_and_create_tables_multicountry.jl new file mode 100644 index 00000000..26f04673 --- /dev/null +++ b/examples/analysis/tabs/simulate_and_create_tables_multicountry.jl @@ -0,0 +1,162 @@ +# Multi-Country Simulation and Table Creation +# Runs the prediction pipeline for all countries with calibration data +# +# Supports base model and extension variants (e.g., CANVAS, GrowthRateAR1). +# Set MODEL_VARIANT below — PREDICTION_FOLDER and EXTENSION_FILE are derived automatically. + +import BeforeIT as Bit +using Dates, JLD2, CSV, DataFrames, Statistics + +include(joinpath(@__DIR__, "analysis_utils.jl")) + +# ============================================================================= +# CONFIGURATION +# ============================================================================= + +RUN_SIMULATION = true # Run simulations for all countries +RUN_ANALYSIS = true # Generate error tables + +T = 12 # Forecast horizon (quarters) +N_SIMS = 100 # Number of simulations per quarter + +QUARTERS = DateTime(2010, 03, 31):Dates.Month(3):DateTime(2019, 12, 31) +HORIZONS = [1, 2, 4, 8, 12] + +# ============================================================================= +# MODEL VARIANT CONFIGURATION +# ============================================================================= +# Change MODEL_VARIANT to run for different model variants. +# PREDICTION_FOLDER and EXTENSION_FILE are derived from VARIANT_CONFIG. 
+# Results will be saved to: data/{country}/analysis/{MODEL_VARIANT}/ +# +# Options: "base", "growth_rate", "canvas" + +MODEL_VARIANT = "canvas" + +_config = VARIANT_CONFIG[MODEL_VARIANT] +PREDICTION_FOLDER = _config.prediction_folder +EXTENSION_FILE = _config.extension_file + +# ============================================================================= +# EXTENSION INCLUDE (must be at top level for method dispatch) +# ============================================================================= + +if EXTENSION_FILE !== nothing + include(joinpath(@__DIR__, EXTENSION_FILE)) + @info "Loaded extension: $EXTENSION_FILE (variant: $MODEL_VARIANT)" +end + +# ============================================================================= +# SIMULATION PHASE +# ============================================================================= + +if RUN_SIMULATION + if MODEL_VARIANT != "base" && !@isdefined(create_model) + error("Extension variant '$MODEL_VARIANT' requires create_model() factory. " * + "Check that EXTENSION_FILE '$EXTENSION_FILE' defines it.") + end + + @info "Starting simulations for all countries (variant: $MODEL_VARIANT)..." 
+ + for country in Bit.discover_countries_with_calibration() + folder = "data/$(country)" + + # For non-base variants, check if simulations already exist + if MODEL_VARIANT != "base" + sim_folder = joinpath(folder, "simulations_$(MODEL_VARIANT)") + if isdir(sim_folder) + existing_files = filter(f -> endswith(f, ".jld2"), readdir(sim_folder)) + if length(existing_files) >= 40 + @info "Skipping $country: simulations already exist ($(length(existing_files)) files)" + continue + end + end + end + + @info "Processing $country" + + try + if MODEL_VARIANT == "base" + # Base model: standard pipeline + calibration = Bit.load_calibration_data(country) + Bit.save_all_simulations(folder; T = T, n_sims = N_SIMS) + Bit.save_all_predictions_from_sims(folder, calibration.data) + else + # Extension variant: use model_factory and suffixed folders + Bit.save_all_simulations(folder; + T = T, + n_sims = N_SIMS, + model_factory = create_model, + output_suffix = MODEL_VARIANT + ) + + # Convert simulations to predictions + calibration = Bit.load_calibration_data(country) + mkpath(joinpath(folder, PREDICTION_FOLDER)) + Bit.save_all_predictions_from_sims(folder, calibration.data; + simulation_suffix = "simulations_$(MODEL_VARIANT)", + prediction_suffix = PREDICTION_FOLDER + ) + end + + @info "Completed $country" + catch e + @error "Failed $country: $e" + for (exc, bt) in current_exceptions() + showerror(stderr, exc, bt) + println(stderr) + end + end + end +end + +# ============================================================================= +# ANALYSIS PHASE +# ============================================================================= + +if RUN_ANALYSIS + @info "Generating error tables (variant: $MODEL_VARIANT)..." 
+ + # Load analysis functions (analysis_utils.jl already included at top level) + include(joinpath(@__DIR__, "error_table_ar.jl")) + include(joinpath(@__DIR__, "error_table_abm.jl")) + include(joinpath(@__DIR__, "error_table_validation_var.jl")) + include(joinpath(@__DIR__, "error_table_validation_abm.jl")) + + # Discover countries based on variant + analysis_countries = if MODEL_VARIANT == "base" + Bit.discover_countries_with_predictions() + else + discover_countries_with_variant_predictions(PREDICTION_FOLDER) + end + + for country in analysis_countries + @info "Generating tables for $country (variant: $(MODEL_VARIANT))" + + try + calibration = Bit.load_calibration_data(country) + data = calibration.data + ea = calibration.ea + + # Ensure analysis folder exists + mkpath(joinpath("data", country, "analysis", MODEL_VARIANT)) + + # AR/VAR benchmarks (use same variant folder for consistency) + error_table_ar(country, ea, data, QUARTERS, HORIZONS; + model_variant=MODEL_VARIANT) + error_table_validation_var(country, ea, data, QUARTERS, HORIZONS; model_variant=MODEL_VARIANT) + + # ABM predictions (load from prediction_folder, save to variant folder) + error_table_abm(country, ea, data, QUARTERS, HORIZONS; + model_variant=MODEL_VARIANT, prediction_folder=PREDICTION_FOLDER) + error_table_validation_abm(country, ea, data, QUARTERS, HORIZONS; + model_variant=MODEL_VARIANT, prediction_folder=PREDICTION_FOLDER) + + @info "Completed $country" + catch e + @error "Failed $country" exception=(e, catch_backtrace()) + end + end +end + +@info "Done."