diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bd6e8deb..e628ca2a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,14 +10,14 @@ repos: args: ['--maxkb=1000'] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.14 + rev: v0.15.9 hooks: - id: ruff args: [ --fix ] - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.19.1 + rev: v1.20.0 hooks: - id: mypy pass_filenames: false diff --git a/Cargo.lock b/Cargo.lock index f1e4f5bd..51d605a0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -30,7 +30,7 @@ dependencies = [ [[package]] name = "akquant" -version = "0.2.5" +version = "0.2.6" dependencies = [ "anyhow", "chrono", diff --git a/Cargo.toml b/Cargo.toml index 10f1305f..37a8686e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "akquant" -version = "0.2.5" +version = "0.2.6" edition = "2024" description = "High-performance quantitative trading framework based on Rust and Python" license = "MIT" diff --git a/docs/en/advanced/llm.md b/docs/en/advanced/llm.md index 92ee973c..3feabbe6 100644 --- a/docs/en/advanced/llm.md +++ b/docs/en/advanced/llm.md @@ -147,7 +147,9 @@ Use this template if the user needs to generate a machine learning strategy. * **Configuration**: Call `self.model.set_validation(...)` to configure Walk-Forward Validation. This automatically sets up the rolling window and training triggers. * **Feature Engineering**: Implement `prepare_features(self, df, mode)` method. * **Training**: The framework automatically calls `on_train_signal` -> `prepare_features(mode='training')` -> `model.fit()` based on the validation config. - * **Inference**: In `on_bar`, manually call `prepare_features(mode='inference')` and then `model.predict()`. + * **Inference**: In `on_bar`, first check `self.is_model_ready()` and `self.current_validation_window()`, then call `prepare_features(mode='inference')` and `model.predict()`. + * **Lifecycle**: Training happens on the current bar, but the newly trained model activates on the next bar. `test_window` defines the planned OOS range, and `rolling_step=0` falls back to `test_window`. + * **Clone**: The framework calls `model.clone()` for each training window. Custom models should override it if `deepcopy` is unsafe. 3. **Data Handling**: * `prepare_features(df, mode)`: @@ -174,6 +176,7 @@ class MLStrategy(Strategy): self.model.set_validation( method='walk_forward', train_window='200d', # Train on last 200 days data + test_window='30d', # Planned OOS window for the active model rolling_step='30d', # Retrain every 30 days frequency='1d', verbose=True @@ -204,6 +207,10 @@ class MLStrategy(Strategy): def on_bar(self, bar: Bar): # 3. Inference (Real-time) + window = self.current_validation_window() + if window is None or not self.is_model_ready(): + return + # Ensure enough history for feature calculation hist_df = self.get_history_df(30) # Small buffer for features if len(hist_df) < 10: @@ -216,6 +223,9 @@ class MLStrategy(Strategy): try: pred = self.model.predict(X_curr)[0] pos = self.get_position(bar.symbol) + active_start = window['active_start_bar'] + active_end = window['active_end_bar'] + print(f"Window [{active_start}, {active_end}] | pred={pred}") if pred == 1 and pos == 0: self.buy(bar.symbol, 1000) diff --git a/docs/en/advanced/ml.md b/docs/en/advanced/ml.md index 77d2703a..401d44a6 100644 --- a/docs/en/advanced/ml.md +++ b/docs/en/advanced/ml.md @@ -46,6 +46,17 @@ Feature preprocessing (e.g., standardization, normalization) can also introduce * **Isolation**: During Walk-forward training, Pipeline calls `fit` (calculating mean/variance) only on the current training window data, then applies it to the validation set. * **Consistency**: In the inference phase, Pipeline automatically applies the trained statistics without manual user maintenance. +### 6. Model Lifecycle in the Current Compatibility Mode + +The current Walk-forward implementation uses a compatibility-oriented lifecycle: + +* **Training Window**: After the current bar finishes, the framework trains a new model clone on the latest `train_window` bars. +* **Delayed Activation**: The newly trained model does not predict on the current bar. It becomes active on the next bar. +* **Effective Range**: `test_window` defines the intended out-of-sample range for the active model. +* **Rolling Updates**: `rolling_step` controls when the next retraining is triggered. If it is `0`, the framework falls back to `test_window`. +* **Explicit State Checks**: In `on_bar`, prefer `self.is_model_ready()` and `self.current_validation_window()` before calling `self.model.predict(...)`. +* **Model Cloning**: The framework calls `QuantModel.clone()` to create a pending model for each training window. Override it if your custom model cannot be deep-copied safely. + --- ## Complete Runnable Example @@ -83,6 +94,7 @@ class WalkForwardStrategy(Strategy): self.model.set_validation( method='walk_forward', train_window=50, # Train on past 50 bars + test_window=20, # Keep each fitted model active for 20 OOS bars rolling_step=10, # Retrain every 10 bars frequency='1m', # Data frequency incremental=False, # Whether to use incremental learning (Sklearn supports partial_fit) @@ -92,6 +104,8 @@ class WalkForwardStrategy(Strategy): # Ensure history depth covers training window + feature calculation window # Alternatively use self.warmup_period = 60 self.set_history_depth(60) + self._last_logged_window_index = 0 + self._last_logged_pending_activation = 0 def prepare_features(self, df: pd.DataFrame, mode: str = "training") -> Tuple[Any, Any]: """ @@ -131,6 +145,26 @@ class WalkForwardStrategy(Strategy): def on_bar(self, bar): # 3. Real-time Prediction & Trading + validation_window = self.current_validation_window() + if validation_window is None: + return + + pending_activation = validation_window["pending_activation_bar"] + if ( + not self.is_model_ready() + and pending_activation is not None + and pending_activation != self._last_logged_pending_activation + ): + print( + f"Bar {bar.timestamp}: " + f"Pending Window={validation_window['pending_window_index']} " + f"Activation Bar={pending_activation}" + ) + self._last_logged_pending_activation = int(pending_activation) + return + + if not self.is_model_ready(): + return # Get recent history for feature extraction # Note: Need enough history to calculate features (e.g. pct_change(2) needs at least 3 bars) @@ -148,9 +182,24 @@ class WalkForwardStrategy(Strategy): # Get prediction signal (probability) # SklearnAdapter returns probability of Class 1 for binary classification signal = self.model.predict(X_curr)[0] - - # Print signal for observation - # print(f"Time: {bar.timestamp}, Signal: {signal:.4f}") + window_index = int(validation_window["window_index"]) + active_start_bar = validation_window["active_start_bar"] + active_end_bar = validation_window["active_end_bar"] + + if window_index != self._last_logged_window_index: + print( + f"Bar {bar.timestamp}: " + f"Activated Window={window_index} " + f"ActiveRange=[{active_start_bar}, {active_end_bar}]" + ) + self._last_logged_window_index = window_index + + print( + f"Bar {bar.timestamp}: " + f"Window={window_index} " + f"ActiveRange=[{active_start_bar}, {active_end_bar}] " + f"Signal={signal:.4f}" + ) # Combine with risk rules for ordering # Use self.get_position(symbol) to check position @@ -162,7 +211,7 @@ class WalkForwardStrategy(Strategy): self.sell(bar.symbol, pos) except Exception: - # Model might not be initialized or training failed + # Keep the example resilient to inference-time failures pass if __name__ == "__main__": @@ -306,12 +355,24 @@ def set_validation( * `method`: Currently only supports `'walk_forward'`. * `train_window`: Length of training window. Supports `'1y'` (1 year), `'6m'` (6 months), `'50d'` (50 days), or integer (number of bars). -* `test_window`: Length of testing window (not strictly used in current rolling mode, mainly for evaluation configuration). -* `rolling_step`: Rolling step size, i.e., how often to retrain the model. +* `test_window`: Intended out-of-sample window length for the active model. In compatibility mode, the newly trained model activates on the next bar and covers this range by default. +* `rolling_step`: Rolling step size, i.e., how often to retrain the model. If it is `0`, the framework falls back to `test_window`. * `frequency`: Data frequency, used to correctly convert time strings to bar counts (e.g., 1y = 252 bars under '1d'). -* `incremental`: Whether to use incremental learning (continue training based on last model) or retrain from scratch. Default is `False`. +* `incremental`: Whether to use incremental learning (continue training from the last active model) or retrain from scratch. Default is `False`. * `verbose`: Whether to print training logs. Default is `False`. +### `model.clone` + +Create a model copy for a new training window. + +```python +def clone(self) -> QuantModel +``` + +* The default implementation uses `copy.deepcopy`. +* Override this method if your model owns GPU handles, locks, file descriptors, or any state that should not be copied blindly. +* The framework trains a pending model on the current bar and activates it on the next one, so `clone()` is central to window isolation. + ### `strategy.prepare_features` Callback function that must be implemented by the user for feature engineering. @@ -327,3 +388,30 @@ def prepare_features(self, df: pd.DataFrame, mode: str = "training") -> Tuple[An * `mode="training"`: Return `(X, y)`. * `mode="inference"`: Return `X` (usually the last row). * **Note**: This is a pure function and should not rely on external state. + +### `strategy.is_model_ready` + +Check whether an active model is currently available for inference. + +```python +def is_model_ready(self) -> bool +``` + +* `True` means it is safe to call `self.model.predict(...)` on the current bar. +* Before the first training window completes, this typically returns `False`. + +### `strategy.current_validation_window` + +Return the current Walk-forward lifecycle state. + +```python +def current_validation_window(self) -> dict[str, Any] | None +``` + +The returned dictionary may include: + +* `window_index`: Current active window index +* `active_start_bar` / `active_end_bar`: Planned active range of the current model +* `pending_activation_bar`: Bar index where the pending model will become active +* `pending_window_index`: Pending window index +* `next_train_bar`: Next scheduled retraining bar index diff --git a/docs/zh/advanced/llm.md b/docs/zh/advanced/llm.md index 02a489c1..2aa16145 100644 --- a/docs/zh/advanced/llm.md +++ b/docs/zh/advanced/llm.md @@ -155,7 +155,9 @@ class MovingAverageStrategy(Strategy): * **Configuration**: Call `self.model.set_validation(...)` to configure Walk-Forward Validation. This automatically sets up the rolling window and training triggers. * **Feature Engineering**: Implement `prepare_features(self, df, mode)` method. * **Training**: The framework automatically calls `on_train_signal` -> `prepare_features(mode='training')` -> `model.fit()` based on the validation config. - * **Inference**: In `on_bar`, manually call `prepare_features(mode='inference')` and then `model.predict()`. + * **Inference**: In `on_bar`, first check `self.is_model_ready()` and `self.current_validation_window()`, then call `prepare_features(mode='inference')` and `model.predict()`. + * **Lifecycle**: Training happens on the current bar, but the newly trained model activates on the next bar. `test_window` defines the planned OOS range, and `rolling_step=0` falls back to `test_window`. + * **Clone**: The framework calls `model.clone()` for each training window. Custom models should override it if `deepcopy` is unsafe. 3. **Data Handling**: * `prepare_features(df, mode)`: @@ -182,6 +184,7 @@ class MLStrategy(Strategy): self.model.set_validation( method='walk_forward', train_window='200d', # Train on last 200 days data + test_window='30d', # Planned OOS window for the active model rolling_step='30d', # Retrain every 30 days frequency='1d', verbose=True @@ -212,6 +215,10 @@ class MLStrategy(Strategy): def on_bar(self, bar: Bar): # 3. Inference (Real-time) + window = self.current_validation_window() + if window is None or not self.is_model_ready(): + return + # Ensure enough history for feature calculation hist_df = self.get_history_df(30) # Small buffer for features if len(hist_df) < 10: @@ -224,6 +231,9 @@ class MLStrategy(Strategy): try: pred = self.model.predict(X_curr)[0] pos = self.get_position(bar.symbol) + active_start = window['active_start_bar'] + active_end = window['active_end_bar'] + print(f"Window [{active_start}, {active_end}] | pred={pred}") if pred == 1 and pos == 0: self.buy(bar.symbol, 1000) diff --git a/docs/zh/advanced/ml.md b/docs/zh/advanced/ml.md index 715ba06a..1b6ed3a0 100644 --- a/docs/zh/advanced/ml.md +++ b/docs/zh/advanced/ml.md @@ -46,6 +46,17 @@ AKQuant 内置了一个高性能的机器学习训练框架,专为量化交易 * **隔离**: 在 Walk-forward 训练时,Pipeline 只会在当前的训练窗口数据上调用 `fit`(计算均值/方差),然后应用到验证集上。 * **一致性**: 在推理阶段,Pipeline 会自动应用训练好的统计量,无需用户手动维护。 +### 6. 当前兼容模式下的模型生命周期 + +当前版本的 Walk-forward 采用“兼容式生命周期管理”: + +* **训练窗口**: 在当前 bar 完成后,框架使用最近 `train_window` 根数据训练一个新模型副本。 +* **延迟生效**: 新模型不会在当前 bar 立即参与预测,而是从下一根 bar 开始生效。 +* **生效区间**: `test_window` 用于定义该模型计划覆盖的样本外区间。 +* **滚动更新**: `rolling_step` 决定下一次重训触发点;若设置为 `0`,框架会回退使用 `test_window`。 +* **显式状态**: 在 `on_bar` 中建议通过 `self.is_model_ready()` 和 `self.current_validation_window()` 判断模型是否已可用于推理,并读取当前窗口元数据。 +* **模型克隆**: 框架会调用 `QuantModel.clone()` 为每个训练窗口创建待训练模型副本;如果自定义模型不适合 `deepcopy`,请重写该方法。 + --- ## 完整可运行示例 @@ -83,6 +94,7 @@ class WalkForwardStrategy(Strategy): self.model.set_validation( method='walk_forward', train_window=50, # 使用过去 50 个 bar 训练 + test_window=20, # 每个模型默认覆盖 20 个样本外 bar rolling_step=10, # 每 10 个 bar 重训一次 frequency='1m', # 数据频率 incremental=False, # 是否增量训练 (Sklearn 支持 partial_fit) @@ -92,6 +104,8 @@ class WalkForwardStrategy(Strategy): # 确保历史数据长度足够 (训练窗口 + 特征计算所需窗口) # 也可以使用 self.warmup_period = 60 self.set_history_depth(60) + self._last_logged_window_index = 0 + self._last_logged_pending_activation = 0 def prepare_features(self, df: pd.DataFrame, mode: str = "training") -> Tuple[Any, Any]: """ @@ -131,6 +145,26 @@ class WalkForwardStrategy(Strategy): def on_bar(self, bar): # 3. 实时预测与交易 + validation_window = self.current_validation_window() + if validation_window is None: + return + + pending_activation = validation_window["pending_activation_bar"] + if ( + not self.is_model_ready() + and pending_activation is not None + and pending_activation != self._last_logged_pending_activation + ): + print( + f"Bar {bar.timestamp}: " + f"Pending Window={validation_window['pending_window_index']} " + f"Activation Bar={pending_activation}" + ) + self._last_logged_pending_activation = int(pending_activation) + return + + if not self.is_model_ready(): + return # 获取最近的数据进行特征提取 # 注意:需要足够的历史长度来计算特征 (例如 pct_change(2) 需要至少3根bar) @@ -148,9 +182,25 @@ class WalkForwardStrategy(Strategy): # 获取预测信号 (概率) # SklearnAdapter 对于二分类返回 Class 1 的概率 signal = self.model.predict(X_curr)[0] + window_index = int(validation_window["window_index"]) + active_start_bar = validation_window["active_start_bar"] + active_end_bar = validation_window["active_end_bar"] + + if window_index != self._last_logged_window_index: + print( + f"Bar {bar.timestamp}: " + f"Activated Window={window_index} " + f"ActiveRange=[{active_start_bar}, {active_end_bar}]" + ) + self._last_logged_window_index = window_index # 打印信号方便观察 - # print(f"Time: {bar.timestamp}, Signal: {signal:.4f}") + print( + f"Bar {bar.timestamp}: " + f"Window={window_index} " + f"ActiveRange=[{active_start_bar}, {active_end_bar}] " + f"Signal={signal:.4f}" + ) # 结合风控规则下单 # 使用 self.get_position(symbol) 获取持仓 @@ -162,7 +212,7 @@ class WalkForwardStrategy(Strategy): self.sell(bar.symbol, pos) except Exception: - # 模型可能尚未初始化或训练失败 + # 示例中仍然保留保护,避免推理异常中断回测 pass if __name__ == "__main__": @@ -306,12 +356,24 @@ def set_validation( * `method`: 目前仅支持 `'walk_forward'`。 * `train_window`: 训练窗口长度。支持 `'1y'` (1年), `'6m'` (6个月), `'50d'` (50天) 或整数 (Bar数量)。 -* `test_window`: 测试窗口长度 (在当前滚动训练模式下未严格使用,主要用于评估配置)。 -* `rolling_step`: 滚动步长,即每隔多久重训一次模型。 +* `test_window`: 模型计划生效的样本外窗口长度。在兼容模式下,新模型会从下一根 bar 开始生效,并默认覆盖该窗口长度。 +* `rolling_step`: 滚动步长,即每隔多久重训一次模型。若为 `0`,框架会回退使用 `test_window`。 * `frequency`: 数据的频率,用于将时间字符串正确转换为 Bar 数量 (例如 '1d' 下 1y=252 bars)。 -* `incremental`: 是否使用增量学习(在上次训练的基础上继续训练)还是从头重训。默认为 `False`。 +* `incremental`: 是否使用增量学习(在上次活动模型基础上继续训练)还是从头重训。默认为 `False`。 * `verbose`: 是否打印训练日志,默认为 `False`。 +### `model.clone` + +为新的训练窗口创建模型副本。 + +```python +def clone(self) -> QuantModel +``` + +* 默认实现使用 `copy.deepcopy`。 +* 如果你的模型包含 GPU 句柄、线程锁、外部连接或不可拷贝状态,建议显式重写该方法。 +* 框架会在当前 bar 训练待生效模型,并在下一根 bar 激活它,因此 `clone()` 是隔离窗口生命周期的重要组成部分。 + ### `strategy.prepare_features` 用户必须实现的回调函数,用于特征工程。 @@ -327,3 +389,30 @@ def prepare_features(self, df: pd.DataFrame, mode: str = "training") -> Tuple[An * `mode="training"`: 返回 `(X, y)`。 * `mode="inference"`: 返回 `X` (通常是最后一行)。 * **注意**: 这是一个纯函数,不应依赖外部状态。 + +### `strategy.is_model_ready` + +判断当前是否已有可用于推理的活动模型。 + +```python +def is_model_ready(self) -> bool +``` + +* 返回 `True` 表示当前 bar 可以安全调用 `self.model.predict(...)`。 +* 在首个训练窗口完成前,通常会返回 `False`。 + +### `strategy.current_validation_window` + +返回当前 Walk-forward 生命周期状态。 + +```python +def current_validation_window(self) -> dict[str, Any] | None +``` + +返回值可能包含: + +* `window_index`: 当前活动窗口编号 +* `active_start_bar` / `active_end_bar`: 当前活动模型的计划生效区间 +* `pending_activation_bar`: 待生效模型将在第几根 bar 激活 +* `pending_window_index`: 待生效窗口编号 +* `next_train_bar`: 下一次计划重训的 bar 编号 diff --git a/examples/02_parameter_optimization.py b/examples/02_parameter_optimization.py index 3150ad31..9ef7bc92 100644 --- a/examples/02_parameter_optimization.py +++ b/examples/02_parameter_optimization.py @@ -76,18 +76,21 @@ def on_bar(self, bar: Any) -> None: # 1. 生成模拟数据 print("Generating synthetic data...") dates = pd.date_range(start="2023-01-01", periods=500) - # 随机漫步 - close_prices = 100 + np.cumsum(np.random.randn(500)) - df = pd.DataFrame( - { - "open": close_prices, - "high": close_prices + 1, - "low": close_prices - 1, - "close": close_prices, - "volume": 1000, - }, - index=dates, - ) + symbols = ["OPT_A", "OPT_B"] + data_map: dict[str, pd.DataFrame] = {} + for index, symbol in enumerate(symbols): + close_prices = 100 + index * 5 + np.cumsum(np.random.randn(500)) + data_map[symbol] = pd.DataFrame( + { + "open": close_prices, + "high": close_prices + 1, + "low": close_prices - 1, + "close": close_prices, + "volume": 1000, + "symbol": symbol, + }, + index=dates, + ) # 2. 运行优化 print("Starting optimization...") @@ -110,8 +113,8 @@ def on_bar(self, bar: Any) -> None: results = run_grid_search( strategy=SMACrossStrategy, param_grid=param_grid, - data=df, - symbol="TEST", + data=data_map, + symbols=symbols, sort_by="total_return", # 按总收益排序 max_workers=2, ) diff --git a/examples/10_ml_walk_forward.py b/examples/10_ml_walk_forward.py index eaab55c0..fc8967b6 100644 --- a/examples/10_ml_walk_forward.py +++ b/examples/10_ml_walk_forward.py @@ -28,6 +28,7 @@ def __init__(self) -> None: self.model.set_validation( method="walk_forward", train_window=50, # Use last 50 bars for training + test_window=20, # Configure the out-of-sample horizon rolling_step=10, # Retrain every 10 bars frequency="1m", verbose=True, # Print training logs @@ -35,6 +36,8 @@ def __init__(self) -> None: # Ensure we have enough history for features + training self.set_history_depth(60) + self._last_logged_window_index = 0 + self._last_logged_pending_activation = 0 print("WalkForwardStrategy initialized") @@ -87,33 +90,29 @@ def on_bar(self, bar: Bar) -> None: if self.model is None: return - # Check if model is trained - # A simple heuristic: check if we have passed the first training window - if self._bar_count < 50: + validation_window = self.current_validation_window() + if validation_window is None: return - # 3. Real-time Prediction - # Reuse logic: Get recent history -> Extract features - hist_df = self.get_history_df(5) + pending_activation = validation_window["pending_activation_bar"] + if ( + not self.is_model_ready() + and pending_activation is not None + and pending_activation != self._last_logged_pending_activation + ): + pending_window_index = int(validation_window["pending_window_index"]) + print( + f"Bar {bar.timestamp}: " + f"Pending Window={pending_window_index} " + f"Activation Bar={pending_activation}" + ) + self._last_logged_pending_activation = int(pending_activation) + return - # Manually calculate features for the last bar (or reuse prepare_features if - # designed carefully) - # Here we do it manually for efficiency/clarity or reuse prepare_features with - # a trick - - # Reuse attempt: - # prepare_features drops the last row (because of shift(-1)). - # This means we can't use it directly to get the *current* feature vector for - # *next* prediction? - # Wait, to predict t+1, we need features at t. - # prepare_features(df) -> X[t], y[t] (where y[t] is return at t+1). - # We need X[t]. - # But prepare_features does X.iloc[:-1]. It drops X[t] because y[t] is - # unknown! - - # So we need a separate feature extraction or a mode. - # Let's stick to manual extraction for prediction in this demo, - # or implement a flexible prepare_features. + if not self.is_model_ready(): + return + + hist_df = self.get_history_df(5) current_ret1 = (bar.close - hist_df["close"].iloc[-2]) / hist_df["close"].iloc[ -2 @@ -126,13 +125,27 @@ def on_bar(self, bar: Bar) -> None: X_curr = X_curr.fillna(0) try: - # Predict pred_prob = self.model.predict(X_curr) signal = ( pred_prob[0] if isinstance(pred_prob, (list, np.ndarray)) else pred_prob ) - - print(f"Bar {bar.timestamp}: Pred Signal = {signal:.4f}") + window_index = int(validation_window["window_index"]) + active_start_bar = validation_window["active_start_bar"] + active_end_bar = validation_window["active_end_bar"] + if window_index != self._last_logged_window_index: + print( + f"Bar {bar.timestamp}: " + f"Activated Window={window_index} " + f"ActiveRange=[{active_start_bar}, {active_end_bar}]" + ) + self._last_logged_window_index = window_index + + print( + f"Bar {bar.timestamp}: " + f"Window={window_index} " + f"ActiveRange=[{active_start_bar}, {active_end_bar}] " + f"Pred Signal = {signal:.4f}" + ) if signal > 0.55: self.buy(bar.symbol, 100) diff --git a/examples/12_wfo_integrated.py b/examples/12_wfo_integrated.py index cf3c2255..55297db9 100644 --- a/examples/12_wfo_integrated.py +++ b/examples/12_wfo_integrated.py @@ -70,24 +70,24 @@ def param_constraint(params: Dict[str, Any]) -> bool: # 1. 生成模拟数据 (随机游走) np.random.seed(42) dates = pd.date_range(start="2020-01-01", end="2023-12-31", freq="D") - # 生成带趋势的随机游走 - returns = np.random.normal(0.0002, 0.02, len(dates)) # 每日微涨,波动率2% - price = 100 * np.cumprod(1 + returns) - - df = pd.DataFrame( - { - "date": dates, - "open": price, - "high": price * 1.01, - "low": price * 0.99, - "close": price, - "volume": 10000, - "symbol": "DEMO", - } - ) - df.set_index("date", inplace=True) - - print("Data loaded:", df.shape) + data_map: dict[str, pd.DataFrame] = {} + for index, symbol in enumerate(["DEMO_A", "DEMO_B"]): + returns = np.random.normal(0.0002 + index * 0.00005, 0.02, len(dates)) + price = (100 + index * 10) * np.cumprod(1 + returns) + df = pd.DataFrame( + { + "date": dates, + "open": price, + "high": price * 1.01, + "low": price * 0.99, + "close": price, + "volume": 10000, + "symbol": symbol, + } + ) + data_map[symbol] = df.set_index("date") + + print("Data loaded:", {symbol: frame.shape for symbol, frame in data_map.items()}) # 2. 定义参数网格 param_grid = { @@ -103,7 +103,7 @@ def param_constraint(params: Dict[str, Any]) -> bool: wfo_results = run_walk_forward( strategy=DualMovingAverageStrategy, param_grid=param_grid, - data=df, + data=data_map, train_period=250, test_period=60, metric="sharpe_ratio", # 优化目标: 夏普比率 @@ -111,6 +111,7 @@ def param_constraint(params: Dict[str, Any]) -> bool: warmup_calc=warmup_calc, constraint=param_constraint, compounding=False, # 不使用复利拼接 (简单累加盈亏) + symbols=list(data_map.keys()), ) if not wfo_results.empty: diff --git a/examples/textbook/ch11_optimization.py b/examples/textbook/ch11_optimization.py index dfa423f7..2ad3d7e7 100644 --- a/examples/textbook/ch11_optimization.py +++ b/examples/textbook/ch11_optimization.py @@ -133,9 +133,11 @@ def on_bar(self, bar: Bar) -> None: sorted_results = sorted( results, - key=lambda x: x.metrics.sharpe_ratio - if hasattr(x.metrics, "sharpe_ratio") - else x.metrics.get("sharpe_ratio", -999), + key=lambda x: ( + x.metrics.sharpe_ratio + if hasattr(x.metrics, "sharpe_ratio") + else x.metrics.get("sharpe_ratio", -999) + ), reverse=True, ) diff --git a/pyproject.toml b/pyproject.toml index 84553f68..9d550a51 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "akquant" -version = "0.2.5" +version = "0.2.6" description = "High-performance quantitative trading framework based on Rust and Python" readme = "README.md" license = {text = "MIT License"} diff --git a/python/akquant/backtest/result.py b/python/akquant/backtest/result.py index 2a7bf942..39bf8557 100644 --- a/python/akquant/backtest/result.py +++ b/python/akquant/backtest/result.py @@ -424,9 +424,11 @@ def metrics_df(self) -> pd.DataFrame: # Margin Level = Equity / Used Margin # Avoid division by zero daily_agg["margin_level"] = daily_agg.apply( - lambda row: row["equity"] / row["margin"] - if row["margin"] > 0 - else float("inf"), + lambda row: ( + row["equity"] / row["margin"] + if row["margin"] > 0 + else float("inf") + ), axis=1, ) # Filter out inf (no margin used) diff --git a/python/akquant/ml/model.py b/python/akquant/ml/model.py index cc2c9560..9243b03c 100644 --- a/python/akquant/ml/model.py +++ b/python/akquant/ml/model.py @@ -70,6 +70,10 @@ def set_validation( verbose=verbose, ) + def clone(self) -> "QuantModel": + """Return a deep-copied model instance for a new validation window.""" + return copy.deepcopy(self) + @abstractmethod def fit(self, X: DataType, y: DataType) -> None: """ diff --git a/python/akquant/optimize.py b/python/akquant/optimize.py index 0764acd9..b70d20a8 100644 --- a/python/akquant/optimize.py +++ b/python/akquant/optimize.py @@ -16,7 +16,17 @@ from datetime import date, datetime, timedelta from datetime import time as datetime_time from logging.handlers import QueueHandler, QueueListener -from typing import Any, Dict, List, Mapping, Optional, Sequence, Type, Union, cast +from typing import ( + Any, + Dict, + List, + Mapping, + Optional, + Sequence, + Type, + Union, + cast, +) import numpy as np import pandas as pd @@ -27,6 +37,7 @@ from .strategy import Strategy _WORKER_LOG_QUEUE: Any = None +OptimizationData = Union[pd.DataFrame, Dict[str, pd.DataFrame]] def _normalize_backtest_symbol_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]: @@ -40,6 +51,170 @@ def _normalize_backtest_symbol_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]: return normalized +def _normalize_symbol_values(symbols: Any) -> list[str]: + """标准化 symbols 参数.""" + if symbols is None: + return [] + if isinstance(symbols, str): + normalized = [symbols] + elif isinstance(symbols, (list, tuple, set)): + normalized = [str(item) for item in symbols] + else: + raise TypeError("symbols must be a string, list, tuple, or set") + + cleaned: list[str] = [] + seen: set[str] = set() + for item in normalized: + value = str(item).strip() + if not value: + raise ValueError("symbols cannot contain empty values") + if value in seen: + continue + seen.add(value) + cleaned.append(value) + return cleaned + + +def _infer_symbols_from_data(data: Any) -> list[str]: + """从优化输入数据中推断 symbols.""" + if isinstance(data, dict): + return [str(symbol).strip() for symbol in data.keys() if str(symbol).strip()] + if isinstance(data, pd.DataFrame) and "symbol" in data.columns: + symbol_series = data["symbol"].dropna().astype(str).str.strip() + return [symbol for symbol in symbol_series.unique().tolist() if symbol] + return [] + + +def _resolve_optimization_backtest_kwargs( + data: Any, + kwargs: Dict[str, Any], +) -> Dict[str, Any]: + """解析优化入口的 symbols 参数并做数据一致性校验.""" + normalized = _normalize_backtest_symbol_kwargs(kwargs) + requested_symbols = _normalize_symbol_values(normalized.get("symbols")) + inferred_symbols = _infer_symbols_from_data(data) + available_symbols = set(inferred_symbols) + + if not requested_symbols: + if inferred_symbols: + normalized["symbols"] = inferred_symbols + return normalized + + if available_symbols: + missing_symbols = [ + symbol for symbol in requested_symbols if symbol not in available_symbols + ] + if missing_symbols: + raise ValueError( + "Requested symbols are not available in optimization data: " + f"{missing_symbols}" + ) + + normalized["symbols"] = requested_symbols + return normalized + + +def _ensure_dataframe_time_index(df: pd.DataFrame) -> pd.DataFrame: + """确保 DataFrame 使用 DatetimeIndex 并按时间排序.""" + prepared = df + if not isinstance(prepared.index, pd.DatetimeIndex): + for column in ["date", "timestamp", "datetime", "Date", "Timestamp"]: + if column in prepared.columns: + prepared = prepared.set_index(column) + break + prepared = prepared.copy() + prepared.index = pd.to_datetime(prepared.index) + elif not prepared.index.is_monotonic_increasing: + prepared = prepared.copy() + + if not prepared.index.is_monotonic_increasing: + prepared = prepared.sort_index() + return cast(pd.DataFrame, prepared) + + +def _filter_optimization_data_by_symbols( + data: OptimizationData, + symbols: Sequence[str], +) -> OptimizationData: + """按 symbols 过滤优化数据.""" + if not symbols: + return data + + symbol_set = set(symbols) + if isinstance(data, pd.DataFrame): + if "symbol" not in data.columns: + return data + filtered = data[data["symbol"].astype(str).isin(symbol_set)] + return cast(pd.DataFrame, filtered.copy()) + + filtered_map: dict[str, pd.DataFrame] = {} + for symbol in symbols: + if symbol in data: + filtered_map[symbol] = data[symbol] + return filtered_map + + +def _prepare_optimization_data(data: OptimizationData) -> OptimizationData: + """标准化优化数据的时间索引.""" + if isinstance(data, pd.DataFrame): + return _ensure_dataframe_time_index(data) + + prepared: dict[str, pd.DataFrame] = {} + for symbol, df in data.items(): + prepared[str(symbol)] = _ensure_dataframe_time_index(df) + return prepared + + +def _build_optimization_timeline(data: OptimizationData) -> pd.DatetimeIndex: + """提取优化切窗使用的统一时间轴.""" + if isinstance(data, pd.DataFrame): + if not isinstance(data.index, pd.DatetimeIndex): + raise TypeError( + "Optimization data must use DatetimeIndex after preparation" + ) + return cast(pd.DatetimeIndex, data.index.unique().sort_values()) + + timeline = pd.DatetimeIndex([]) + for df in data.values(): + if df.empty: + continue + if not isinstance(df.index, pd.DatetimeIndex): + raise TypeError( + "Optimization data must use DatetimeIndex after preparation" + ) + timeline = cast(pd.DatetimeIndex, timeline.union(df.index.unique())) + return cast(pd.DatetimeIndex, timeline.sort_values()) + + +def _slice_dataframe_by_time( + df: pd.DataFrame, + start_time: pd.Timestamp, + end_time: Optional[pd.Timestamp], +) -> pd.DataFrame: + """根据时间窗口切片 DataFrame.""" + mask = df.index >= start_time + if end_time is not None: + mask = mask & (df.index < end_time) + return cast(pd.DataFrame, df.loc[mask].copy()) + + +def _slice_optimization_data( + data: OptimizationData, + start_time: pd.Timestamp, + end_time: Optional[pd.Timestamp], +) -> OptimizationData: + """根据统一时间窗口切片优化数据.""" + if isinstance(data, pd.DataFrame): + return _slice_dataframe_by_time(data, start_time, end_time) + + sliced: dict[str, pd.DataFrame] = {} + for symbol, df in data.items(): + window_df = _slice_dataframe_by_time(df, start_time, end_time) + if not window_df.empty: + sliced[symbol] = window_df + return sliced + + @dataclass class OptimizationResult: """ @@ -381,7 +556,7 @@ def run_grid_search( :param kwargs: 传递给 run_backtest 的其他参数 (symbol, cash, etc.) :return: 优化结果 (DataFrame 或 List[OptimizationResult]) """ - backtest_kwargs = _normalize_backtest_symbol_kwargs(dict(kwargs)) + backtest_kwargs = _resolve_optimization_backtest_kwargs(data, dict(kwargs)) backtest_kwargs.setdefault("strict_strategy_params", True) if ( "execution_mode" in backtest_kwargs @@ -648,7 +823,7 @@ def run_grid_search( def run_walk_forward( strategy: Type[Strategy], param_grid: Mapping[str, Sequence[Any]], - data: pd.DataFrame, + data: OptimizationData, train_period: int, test_period: int, metric: Union[str, List[str]] = "sharpe_ratio", @@ -670,7 +845,7 @@ def run_walk_forward( :param strategy: 策略类 :param param_grid: 参数网格 - :param data: 回测数据 (必须是 DataFrame 且包含 DatetimeIndex) + :param data: 回测数据 (支持 DataFrame 或 Dict[str, DataFrame]) :param train_period: 训练窗口长度 (Bar数量) :param test_period: 测试窗口长度 (Bar数量) :param metric: 优化目标指标 (默认: "sharpe_ratio"),支持多字段排序列表。 @@ -688,11 +863,15 @@ def run_walk_forward( :param kwargs: 透传给 run_grid_search 和 run_backtest 的其他参数 :return: 包含拼接后资金曲线的 DataFrame """ - if not isinstance(data, pd.DataFrame): - raise ValueError("run_walk_forward requires data to be a pandas DataFrame.") - kwargs = _normalize_backtest_symbol_kwargs(kwargs) - - total_len = len(data) + kwargs = _resolve_optimization_backtest_kwargs(data, kwargs) + requested_symbols = _normalize_symbol_values(kwargs.get("symbols")) + prepared_data = _prepare_optimization_data(data) + prepared_data = _filter_optimization_data_by_symbols( + prepared_data, + requested_symbols, + ) + timeline = _build_optimization_timeline(prepared_data) + total_len = len(timeline) if total_len < train_period + test_period: raise ValueError( f"Data length ({total_len}) is too short for " @@ -710,14 +889,28 @@ def run_walk_forward( # 滚动窗口循环 # Step size is test_period for i in range(0, total_len - train_period - test_period + 1, test_period): - # 1. 切分训练数据 (In-Sample) train_start_idx = i train_end_idx = i + train_period - train_data = data.iloc[train_start_idx:train_end_idx] + oos_start_idx = train_end_idx + oos_end_idx = min(oos_start_idx + test_period, total_len) + + train_start_time = timeline[train_start_idx] + train_end_exclusive = ( + timeline[train_end_idx] if train_end_idx < total_len else None + ) + train_end_time = timeline[train_end_idx - 1] + oos_start_time = timeline[oos_start_idx] + oos_end_exclusive = timeline[oos_end_idx] if oos_end_idx < total_len else None + oos_end_time = timeline[oos_end_idx - 1] + train_data = _slice_optimization_data( + prepared_data, + train_start_time, + train_end_exclusive, + ) print( f"\n=== Window {i // test_period + 1}: " - f"Train [{train_data.index[0]} - {train_data.index[-1]}] ===" + f"Train [{train_start_time} - {train_end_time}] ===" ) # 2. 样本内优化 (Optimization) @@ -756,10 +949,6 @@ def run_walk_forward( print(f" Best Params: {best_params} ({metric_str})") - # 3. 切分测试数据 (Out-of-Sample) - oos_start_idx = train_end_idx - oos_end_idx = min(oos_start_idx + test_period, total_len) - # 计算实际需要的预热期 current_warmup = warmup_period if warmup_calc: @@ -768,9 +957,13 @@ def run_walk_forward( except Exception: pass - # 确保预热数据存在 - slice_start = max(0, oos_start_idx - current_warmup) - test_data_with_warmup = data.iloc[slice_start:oos_end_idx] + slice_start_idx = max(0, oos_start_idx - current_warmup) + slice_start_time = timeline[slice_start_idx] + test_data_with_warmup = _slice_optimization_data( + prepared_data, + slice_start_time, + oos_end_exclusive, + ) # 4. 样本外验证 (Backtest) # 使用最佳参数运行回测 @@ -780,10 +973,7 @@ def run_walk_forward( backtest_kwargs["initial_cash"] = initial_cash backtest_kwargs["warmup_period"] = current_warmup - print( - f" Test [{data.index[oos_start_idx]} - {data.index[oos_end_idx - 1]}] " - f"(Warmup: {current_warmup})" - ) + print(f" Test [{oos_start_time} - {oos_end_time}] (Warmup: {current_warmup})") bt_result = run_backtest( strategy=strategy, data=test_data_with_warmup, **backtest_kwargs @@ -792,12 +982,6 @@ def run_walk_forward( # 5. 提取并拼接结果 equity_curve = bt_result.equity_curve - # 截取 OOS 真正的时间段 (去除预热期) - # 使用时间戳过滤 - oos_start_time = data.index[oos_start_idx] - - # 确保 equity_curve 索引是 datetime 且有时区信息 (BacktestResult 已经处理了) - # data.index 通常是 naive 或 aware,需要匹配 if equity_curve.empty: print(" Warning: Empty equity curve in OOS.") continue @@ -870,8 +1054,8 @@ def run_walk_forward( current_capital = adjusted_equity.iloc[-1] # 添加元数据 - segment_df["train_start"] = data.index[train_start_idx] - segment_df["train_end"] = data.index[train_end_idx] + segment_df["train_start"] = train_start_time + segment_df["train_end"] = train_end_time for k, v in best_params.items(): segment_df[k] = v diff --git a/python/akquant/strategy.py b/python/akquant/strategy.py index 1c8b5d76..4d3497ea 100644 --- a/python/akquant/strategy.py +++ b/python/akquant/strategy.py @@ -51,6 +51,8 @@ from .strategy_history import set_rolling_window as _set_rolling_window_impl from .strategy_logging import log as _log_impl from .strategy_ml import auto_configure_model as _auto_configure_model_impl +from .strategy_ml import current_validation_window as _current_validation_window_impl +from .strategy_ml import is_model_ready as _is_model_ready_impl from .strategy_ml import on_train_signal as _on_train_signal_impl from .strategy_order_events import ( check_order_events as _check_order_events_impl, @@ -239,6 +241,17 @@ class Strategy: _bar_count: int _model_configured: bool model: Optional["QuantModel"] + _ml_validation_lifecycle: bool + _ml_model_template: Optional["QuantModel"] + _ml_active_model: Optional["QuantModel"] + _ml_pending_model: Optional["QuantModel"] + _ml_pending_activation_bar: Optional[int] + _ml_active_window_index: int + _ml_active_window_start_bar: Optional[int] + _ml_active_window_end_bar: Optional[int] + _ml_pending_window_index: int + _ml_pending_window_start_bar: Optional[int] + _ml_pending_window_end_bar: Optional[int] _known_orders: Dict[str, Order] _seen_trade_keys: set[Tuple[Any, ...]] _seen_trade_key_order: Deque[Tuple[Any, ...]] @@ -351,6 +364,17 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "Strategy": instance._bar_count = 0 instance._model_configured = False instance._start_initialized = False + instance._ml_validation_lifecycle = False + instance._ml_model_template = None + instance._ml_active_model = None + instance._ml_pending_model = None + instance._ml_pending_activation_bar = None + instance._ml_active_window_index = 0 + instance._ml_active_window_start_bar = None + instance._ml_active_window_end_bar = None + instance._ml_pending_window_index = 0 + instance._ml_pending_window_start_bar = None + instance._ml_pending_window_end_bar = None # 初始化通常在 __init__ 中的属性,允许子类省略 super().__init__() instance.model = None @@ -481,6 +505,28 @@ def __setstate__(self, state: Dict[str, Any]) -> None: self._pending_daily_timers = [] if not hasattr(self, "_instrument_snapshots"): self._instrument_snapshots = {} + if not hasattr(self, "_ml_validation_lifecycle"): + self._ml_validation_lifecycle = False + if not hasattr(self, "_ml_model_template"): + self._ml_model_template = None + if not hasattr(self, "_ml_active_model"): + self._ml_active_model = None + if not hasattr(self, "_ml_pending_model"): + self._ml_pending_model = None + if not hasattr(self, "_ml_pending_activation_bar"): + self._ml_pending_activation_bar = None + if not hasattr(self, "_ml_active_window_index"): + self._ml_active_window_index = 0 + if not hasattr(self, "_ml_active_window_start_bar"): + self._ml_active_window_start_bar = None + if not hasattr(self, "_ml_active_window_end_bar"): + self._ml_active_window_end_bar = None + if not hasattr(self, "_ml_pending_window_index"): + self._ml_pending_window_index = 0 + if not hasattr(self, "_ml_pending_window_start_bar"): + self._ml_pending_window_start_bar = None + if not hasattr(self, "_ml_pending_window_end_bar"): + self._ml_pending_window_end_bar = None _ensure_framework_state_impl(self) @property @@ -932,6 +978,14 @@ def on_train_signal(self, context: Any) -> None: """ _on_train_signal_impl(self, context) + def is_model_ready(self) -> bool: + """返回当前是否已有可用于推理的模型.""" + return _is_model_ready_impl(self) + + def current_validation_window(self) -> Optional[Dict[str, Any]]: + """返回当前 walk-forward 验证窗口状态.""" + return _current_validation_window_impl(self) + def prepare_features( self, df: pd.DataFrame, mode: str = "training" ) -> Tuple[Any, Any]: diff --git a/python/akquant/strategy_events.py b/python/akquant/strategy_events.py index ec776e4b..ad187a1c 100644 --- a/python/akquant/strategy_events.py +++ b/python/akquant/strategy_events.py @@ -13,6 +13,13 @@ mark_portfolio_dirty, register_boundary_timers, ) +from .strategy_ml import ( + activate_pending_model, + begin_training_cycle, + consume_training_trigger, + finalize_training_cycle, + should_trigger_training, +) from .strategy_scheduler import flush_pending_schedules @@ -79,10 +86,17 @@ def on_bar_event(strategy: Any, bar: Bar, ctx: StrategyContext) -> None: if strategy._bar_count < strategy.warmup_period: return - if strategy._rolling_step > 0 and strategy._bar_count % strategy._rolling_step == 0: - call_user_callback(strategy, "on_train_signal", strategy, payload=strategy) + activate_pending_model(strategy) + should_train = should_trigger_training(strategy) call_user_callback(strategy, "on_bar", bar, payload=bar) + if should_train: + consume_training_trigger(strategy) + training_cycle = begin_training_cycle(strategy) + try: + call_user_callback(strategy, "on_train_signal", strategy, payload=strategy) + finally: + finalize_training_cycle(strategy, training_cycle) analyzer_manager = getattr(strategy, "_analyzer_manager", None) if analyzer_manager is not None: try: diff --git a/python/akquant/strategy_ml.py b/python/akquant/strategy_ml.py index 025d9f81..b22dab0a 100644 --- a/python/akquant/strategy_ml.py +++ b/python/akquant/strategy_ml.py @@ -1,39 +1,269 @@ -from typing import Any +from typing import Any, Optional from .utils import parse_duration_to_bars +def _get_validation_model(strategy: Any) -> Any: + """返回验证配置使用的模型模板.""" + template_model = getattr(strategy, "_ml_model_template", None) + if template_model is not None: + return template_model + return getattr(strategy, "model", None) + + +def _get_validation_config(strategy: Any) -> Any: + """返回验证配置对象.""" + model = _get_validation_model(strategy) + return getattr(model, "validation_config", None) + + +def _resolve_validation_windows(strategy: Any) -> tuple[int, int, int]: + """解析模型 walk-forward 配置窗口.""" + validation_config = _get_validation_config(strategy) + if validation_config is None: + return 0, 0, 0 + + train_window = parse_duration_to_bars( + validation_config.train_window, + validation_config.frequency, + ) + test_window = parse_duration_to_bars( + validation_config.test_window, + validation_config.frequency, + ) + rolling_step = parse_duration_to_bars( + validation_config.rolling_step, + validation_config.frequency, + ) + return train_window, test_window, rolling_step + + +def _effective_training_step(test_window: int, rolling_step: int) -> int: + """计算有效训练步长.""" + if rolling_step > 0: + return rolling_step + if test_window > 0: + return test_window + return 0 + + +def _validation_lifecycle_enabled(strategy: Any) -> bool: + """返回是否启用验证窗口生命周期管理.""" + return bool(getattr(strategy, "_ml_validation_lifecycle", False)) + + +def _clone_model_for_training(strategy: Any) -> Any: + """为当前训练窗口构建待训练模型副本.""" + template_model = _get_validation_model(strategy) + if template_model is None: + return None + + validation_config = _get_validation_config(strategy) + active_model = getattr(strategy, "_ml_active_model", None) + if ( + validation_config is not None + and bool(getattr(validation_config, "incremental", False)) + and active_model is not None + ): + return active_model.clone() + return template_model.clone() + + +def activate_pending_model(strategy: Any) -> None: + """在计划生效点激活待生效模型.""" + if not _validation_lifecycle_enabled(strategy): + return + + pending_model = getattr(strategy, "_ml_pending_model", None) + activation_bar = getattr(strategy, "_ml_pending_activation_bar", None) + if pending_model is None or activation_bar is None: + return + if int(strategy._bar_count) < int(activation_bar): + return + + strategy._ml_active_model = pending_model + strategy.model = pending_model + strategy._ml_active_window_index = int( + getattr(strategy, "_ml_pending_window_index", 0) + ) + strategy._ml_active_window_start_bar = getattr( + strategy, + "_ml_pending_window_start_bar", + None, + ) + strategy._ml_active_window_end_bar = getattr( + strategy, + "_ml_pending_window_end_bar", + None, + ) + strategy._ml_pending_model = None + strategy._ml_pending_activation_bar = None + strategy._ml_pending_window_index = 0 + strategy._ml_pending_window_start_bar = None + strategy._ml_pending_window_end_bar = None + + def auto_configure_model(strategy: Any) -> None: """应用模型校验配置(如滚动训练窗口参数).""" if strategy._model_configured: return if strategy.model and strategy.model.validation_config: - cfg = strategy.model.validation_config try: - train_window = parse_duration_to_bars(cfg.train_window, cfg.frequency) - step = parse_duration_to_bars(cfg.rolling_step, cfg.frequency) - strategy.set_rolling_window(train_window, step) + train_window, test_window, rolling_step = _resolve_validation_windows( + strategy + ) + effective_step = _effective_training_step(test_window, rolling_step) + strategy.set_rolling_window(train_window, effective_step) + setattr(strategy, "_ml_validation_lifecycle", True) + setattr(strategy, "_ml_model_template", strategy.model) + setattr(strategy, "_ml_active_model", None) + setattr(strategy, "_ml_pending_model", None) + setattr(strategy, "_ml_pending_activation_bar", None) + setattr(strategy, "_ml_active_window_index", 0) + setattr(strategy, "_ml_active_window_start_bar", None) + setattr(strategy, "_ml_active_window_end_bar", None) + setattr(strategy, "_ml_pending_window_index", 0) + setattr(strategy, "_ml_pending_window_start_bar", None) + setattr(strategy, "_ml_pending_window_end_bar", None) + setattr(strategy, "_rolling_test_window", test_window) + setattr(strategy, "_rolling_last_train_bar", 0) + setattr(strategy, "_rolling_window_index", 0) + setattr(strategy, "_rolling_next_train_bar", max(train_window, 1)) except Exception as e: print(f"Failed to configure model validation: {e}") + else: + setattr(strategy, "_ml_validation_lifecycle", False) strategy._model_configured = True +def should_trigger_training(strategy: Any) -> bool: + """返回当前 bar 是否应触发自动训练.""" + if strategy._rolling_step <= 0: + return False + + validation_config = _get_validation_config(strategy) + if validation_config is None: + return bool(int(strategy._bar_count) % int(strategy._rolling_step) == 0) + + next_train_bar = int( + getattr(strategy, "_rolling_next_train_bar", strategy._rolling_train_window) + ) + return bool( + int(strategy._bar_count) + >= max(next_train_bar, int(strategy._rolling_train_window)) + ) + + +def consume_training_trigger(strategy: Any) -> None: + """消费一次训练触发并推进下一窗口.""" + validation_config = _get_validation_config(strategy) + if validation_config is None or strategy._rolling_step <= 0: + return + + current_bar = int(strategy._bar_count) + next_window_index = int(getattr(strategy, "_rolling_window_index", 0)) + 1 + pending_start_bar = current_bar + 1 + pending_end_bar: Optional[int] + if int(getattr(strategy, "_rolling_test_window", 0)) > 0: + pending_end_bar = pending_start_bar + int(strategy._rolling_test_window) - 1 + else: + pending_end_bar = None + + setattr(strategy, "_rolling_last_train_bar", current_bar) + setattr( + strategy, + "_rolling_next_train_bar", + current_bar + int(strategy._rolling_step), + ) + setattr(strategy, "_rolling_window_index", next_window_index) + setattr(strategy, "_ml_pending_activation_bar", pending_start_bar) + setattr(strategy, "_ml_pending_window_index", next_window_index) + setattr(strategy, "_ml_pending_window_start_bar", pending_start_bar) + setattr(strategy, "_ml_pending_window_end_bar", pending_end_bar) + + +def begin_training_cycle(strategy: Any) -> Optional[tuple[Any, Any]]: + """开始一次训练周期并临时挂载待训练模型.""" + if not _validation_lifecycle_enabled(strategy): + return None + + training_model = _clone_model_for_training(strategy) + if training_model is None: + return None + + previous_public_model = getattr(strategy, "model", None) + strategy.model = training_model + return previous_public_model, training_model + + +def finalize_training_cycle( + strategy: Any, + cycle_state: Optional[tuple[Any, Any]], +) -> None: + """结束训练周期并恢复对外模型引用.""" + if cycle_state is None: + return + + previous_public_model, training_model = cycle_state + strategy._ml_pending_model = training_model + active_model = getattr(strategy, "_ml_active_model", None) + if active_model is not None: + strategy.model = active_model + return + strategy.model = previous_public_model + + +def is_model_ready(strategy: Any) -> bool: + """返回当前是否已有可用于推理的活动模型.""" + if _validation_lifecycle_enabled(strategy): + return getattr(strategy, "_ml_active_model", None) is not None + return getattr(strategy, "model", None) is not None + + +def current_validation_window(strategy: Any) -> Optional[dict[str, Any]]: + """返回当前验证窗口状态.""" + if not _validation_lifecycle_enabled(strategy): + return None + + return { + "is_model_ready": is_model_ready(strategy), + "window_index": int(getattr(strategy, "_ml_active_window_index", 0)), + "train_window": int(getattr(strategy, "_rolling_train_window", 0)), + "test_window": int(getattr(strategy, "_rolling_test_window", 0)), + "rolling_step": int(getattr(strategy, "_rolling_step", 0)), + "active_start_bar": getattr(strategy, "_ml_active_window_start_bar", None), + "active_end_bar": getattr(strategy, "_ml_active_window_end_bar", None), + "pending_activation_bar": getattr(strategy, "_ml_pending_activation_bar", None), + "pending_window_index": int(getattr(strategy, "_ml_pending_window_index", 0)), + "next_train_bar": int(getattr(strategy, "_rolling_next_train_bar", 0)), + } + + def on_train_signal(strategy: Any, context: Any) -> None: """滚动训练信号回调.""" if strategy.model: try: X_df, _ = strategy.get_rolling_data() - if ( - strategy.model.validation_config - and strategy.model.validation_config.verbose - ): + validation_config = _get_validation_config(strategy) + if validation_config and validation_config.verbose: ts_str = "" if strategy.current_bar: ts_str = strategy.format_time(strategy.current_bar.timestamp) - print(f"[{ts_str}] Auto-training triggered | Train Size: {len(X_df)}") + train_window = int(getattr(strategy, "_rolling_train_window", 0)) + test_window = int(getattr(strategy, "_rolling_test_window", 0)) + window_index = int(getattr(strategy, "_ml_pending_window_index", 0)) + activation_bar = getattr(strategy, "_ml_pending_activation_bar", None) + print( + f"[{ts_str}] Auto-training triggered | " + f"Window={window_index} | " + f"Train Size={len(X_df)} | " + f"Train Window={train_window} | " + f"Test Window={test_window} | " + f"Activation Bar={activation_bar}" + ) X, y = strategy.prepare_features(X_df, mode="training") strategy.model.fit(X, y) diff --git a/tests/test_engine.py b/tests/test_engine.py index d25d5872..44047012 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -4415,6 +4415,444 @@ def test_run_grid_search_db_path_serializes_timestamp_metrics( assert isinstance(metrics.get("end_time"), str) +def test_run_grid_search_infers_symbols_from_dict_data() -> None: + """Grid search should infer symbols from dict-form multisymbol data.""" + data = { + "OPT_DICT_A": _build_benchmark_data(n=40, symbol="OPT_DICT_A"), + "OPT_DICT_B": _build_benchmark_data(n=40, symbol="OPT_DICT_B"), + } + + results = akquant.run_grid_search( + strategy=NoopStrategy, + param_grid={"dummy": [1]}, + data=data, + max_workers=1, + return_df=True, + show_progress=False, + ) + + assert isinstance(results, pd.DataFrame) + assert len(results) == 1 + assert float(results.iloc[0]["total_bars"]) > 0.0 + assert pd.isna(results.iloc[0].get("error")) + + +def test_run_grid_search_dict_data_rejects_missing_symbols() -> None: + """Grid search should fail fast when requested symbols are absent.""" + data = { + "OPT_DICT_A": _build_benchmark_data(n=20, symbol="OPT_DICT_A"), + "OPT_DICT_B": _build_benchmark_data(n=20, symbol="OPT_DICT_B"), + } + + with pytest.raises( + ValueError, + match="Requested symbols are not available in optimization data", + ): + _ = akquant.run_grid_search( + strategy=NoopStrategy, + param_grid={"dummy": [1]}, + data=data, + symbols=["OPT_DICT_C"], + max_workers=1, + return_df=True, + show_progress=False, + ) + + +def test_run_walk_forward_supports_multisymbol_dict_data() -> None: + """Walk-forward should slice dict-form multisymbol data by timeline.""" + data = { + "WFO_DICT_A": _build_benchmark_data(n=24, symbol="WFO_DICT_A"), + "WFO_DICT_B": _build_benchmark_data(n=24, symbol="WFO_DICT_B"), + } + + results = akquant.run_walk_forward( + strategy=NoopStrategy, + param_grid={"dummy": [1]}, + data=data, + train_period=10, + test_period=5, + initial_cash=100_000.0, + max_tasks_per_child=1, + show_progress=False, + ) + + assert isinstance(results, pd.DataFrame) + assert not results.empty + assert "equity" in results.columns + assert "train_start" in results.columns + assert "train_end" in results.columns + assert results["train_start"].iloc[0] < results["train_end"].iloc[0] + + +def test_run_walk_forward_multisymbol_dataframe_uses_timestamp_windows() -> None: + """Walk-forward should slice multisymbol DataFrame input by unique timestamps.""" + symbols = ["WFO_DF_A", "WFO_DF_B"] + data = _build_multisymbol_benchmark_data(n_timestamps=16, symbols=symbols) + + results = akquant.run_walk_forward( + strategy=NoopStrategy, + param_grid={"dummy": [1]}, + data=data, + train_period=5, + test_period=3, + initial_cash=100_000.0, + timezone="UTC", + show_progress=False, + ) + + assert isinstance(results, pd.DataFrame) + assert not results.empty + first_train_start = pd.Timestamp(results["train_start"].iloc[0]) + first_train_end = pd.Timestamp(results["train_end"].iloc[0]) + assert first_train_start == pd.Timestamp("2020-01-01 00:00:00", tz="UTC") + assert first_train_end == pd.Timestamp("2020-01-01 00:04:00", tz="UTC") + + +def test_run_walk_forward_filters_warmup_period_from_oos_equity() -> None: + """Walk-forward output should exclude warmup timestamps from returned OOS curve.""" + symbol = "WFO_WARMUP_BOUNDARY" + data = _build_benchmark_data(n=14, symbol=symbol) + + results = akquant.run_walk_forward( + strategy=NoopStrategy, + param_grid={"dummy": [1]}, + data=data, + train_period=6, + test_period=3, + warmup_period=2, + initial_cash=100_000.0, + timezone="UTC", + show_progress=False, + ) + + assert isinstance(results, pd.DataFrame) + assert not results.empty + first_result_time = pd.Timestamp(results.index.min()) + assert first_result_time == pd.Timestamp("2020-01-01 00:06:00", tz="UTC") + + +def test_on_train_signal_runs_after_on_bar_for_trigger_bar() -> None: + """Rolling training should execute after the trigger bar callback.""" + + class RollingOrderProbeStrategy(akquant.Strategy): + """Capture callback ordering for rolling training.""" + + def __init__(self) -> None: + """Initialize rolling callback probe state.""" + super().__init__() + self.set_rolling_window(train_window=4, step=2) + self.warmup_period = 4 + self.events: list[tuple[str, int]] = [] + + def on_bar(self, bar: akquant.Bar) -> None: + """Record bar callback order.""" + self.events.append(("bar", int(bar.close))) + + def on_train_signal(self, context: Any) -> None: + """Record train callback order using the rolling window tail.""" + df, _ = self.get_rolling_data() + closes = df["close"].dropna().astype(int).tolist() + self.events.append(("train", int(closes[-1]))) + + symbol = "ROLLING_ORDER" + data = pd.DataFrame( + { + "timestamp": pd.date_range("2020-01-01", periods=8, freq="min", tz="UTC"), + "open": np.arange(1, 9, dtype=float), + "high": np.arange(1, 9, dtype=float), + "low": np.arange(1, 9, dtype=float), + "close": np.arange(1, 9, dtype=float), + "volume": np.full(8, 1000.0), + "symbol": [symbol] * 8, + } + ) + strategy = RollingOrderProbeStrategy() + + _ = akquant.run_backtest( + data=data, + strategy=strategy, + symbols=[symbol], + history_depth=4, + show_progress=False, + ) + + assert strategy.events == [ + ("bar", 4), + ("train", 4), + ("bar", 5), + ("bar", 6), + ("train", 6), + ("bar", 7), + ("bar", 8), + ("train", 8), + ] + + +def test_ml_validation_training_schedule_uses_relative_rolling_step() -> None: + """ML validation should retrain relative to first eligible train bar.""" + from akquant.ml.model import QuantModel, ValidationConfig + + class RecordingModel(QuantModel): + """Minimal model stub that records fit calls.""" + + fit_sizes: list[int] = [] + + def __init__(self) -> None: + """Initialize validation config and fit recorder.""" + super().__init__() + self.validation_config = ValidationConfig( + train_window=5, + test_window=2, + rolling_step=3, + frequency="1m", + ) + + def clone(self) -> "RecordingModel": + """Clone the test model while preserving validation config.""" + cloned = RecordingModel() + cloned.validation_config = self.validation_config + return cloned + + def fit(self, X: Any, y: Any) -> None: + """Record fit sample size.""" + RecordingModel.fit_sizes.append(int(len(X))) + + def predict(self, X: Any) -> np.ndarray: + """Return a deterministic empty prediction vector.""" + return np.zeros(len(X)) + + def save(self, path: str) -> None: + """Satisfy abstract model API for tests.""" + return + + def load(self, path: str) -> None: + """Satisfy abstract model API for tests.""" + return + + class ValidationScheduleStrategy(akquant.Strategy): + """Capture ML train bars under validation config.""" + + def __init__(self) -> None: + """Initialize model and training recorder.""" + super().__init__() + self.model = RecordingModel() + self.train_bars: list[int] = [] + + def prepare_features( + self, df: pd.DataFrame, mode: str = "training" + ) -> tuple[pd.DataFrame, pd.Series]: + """Return simple close-only features and aligned labels.""" + features = pd.DataFrame({"close": df["close"].fillna(0.0)}) + labels = pd.Series(np.zeros(len(features), dtype=int)) + return features, labels + + def on_bar(self, bar: akquant.Bar) -> None: + """Ignore trading logic for schedule test.""" + return + + def on_train_signal(self, context: Any) -> None: + """Record train bar index and delegate to default fit logic.""" + self.train_bars.append(int(self._bar_count)) + super().on_train_signal(context) + + symbol = "ML_RELATIVE_STEP" + data = _build_benchmark_data(n=12, symbol=symbol) + strategy = ValidationScheduleStrategy() + + _ = akquant.run_backtest( + data=data, + strategy=strategy, + symbols=[symbol], + show_progress=False, + ) + + assert strategy.train_bars == [5, 8, 11] + assert RecordingModel.fit_sizes == [5, 5, 5] + + +def test_ml_validation_uses_test_window_when_rolling_step_is_zero() -> None: + """ML validation should fall back to test_window when rolling_step is zero.""" + from akquant.ml.model import QuantModel, ValidationConfig + + class RecordingModel(QuantModel): + """Minimal model stub that records fit calls.""" + + fit_sizes: list[int] = [] + + def __init__(self) -> None: + """Initialize validation config and fit recorder.""" + super().__init__() + self.validation_config = ValidationConfig( + train_window=4, + test_window=2, + rolling_step=0, + frequency="1m", + ) + + def clone(self) -> "RecordingModel": + """Clone the test model while preserving validation config.""" + cloned = RecordingModel() + cloned.validation_config = self.validation_config + return cloned + + def fit(self, X: Any, y: Any) -> None: + """Record fit sample size.""" + RecordingModel.fit_sizes.append(int(len(X))) + + def predict(self, X: Any) -> np.ndarray: + """Return a deterministic empty prediction vector.""" + return np.zeros(len(X)) + + def save(self, path: str) -> None: + """Satisfy abstract model API for tests.""" + return + + def load(self, path: str) -> None: + """Satisfy abstract model API for tests.""" + return + + class TestWindowFallbackStrategy(akquant.Strategy): + """Capture ML train bars under test_window fallback scheduling.""" + + def __init__(self) -> None: + """Initialize model and training recorder.""" + super().__init__() + self.model = RecordingModel() + self.train_bars: list[int] = [] + + def prepare_features( + self, df: pd.DataFrame, mode: str = "training" + ) -> tuple[pd.DataFrame, pd.Series]: + """Return simple close-only features and aligned labels.""" + features = pd.DataFrame({"close": df["close"].fillna(0.0)}) + labels = pd.Series(np.zeros(len(features), dtype=int)) + return features, labels + + def on_bar(self, bar: akquant.Bar) -> None: + """Ignore trading logic for schedule test.""" + return + + def on_train_signal(self, context: Any) -> None: + """Record train bar index and delegate to default fit logic.""" + self.train_bars.append(int(self._bar_count)) + super().on_train_signal(context) + + symbol = "ML_TEST_WINDOW_STEP" + data = _build_benchmark_data(n=8, symbol=symbol) + strategy = TestWindowFallbackStrategy() + + _ = akquant.run_backtest( + data=data, + strategy=strategy, + symbols=[symbol], + show_progress=False, + ) + + assert strategy.train_bars == [4, 6, 8] + assert RecordingModel.fit_sizes == [4, 4, 4] + + +def test_ml_validation_activates_new_model_on_next_bar() -> None: + """Newly trained validation models should activate on the next bar.""" + from akquant.ml.model import QuantModel, ValidationConfig + + class VersionedModel(QuantModel): + """Model stub that exposes fitted version ids in predictions.""" + + next_version = 1 + + def __init__(self, version: int = 0) -> None: + """Initialize validation config and current version.""" + super().__init__() + self.validation_config = ValidationConfig( + train_window=4, + test_window=2, + rolling_step=2, + frequency="1m", + ) + self.version = version + + def clone(self) -> "VersionedModel": + """Clone the model and preserve the current version state.""" + cloned = VersionedModel(version=self.version) + cloned.validation_config = self.validation_config + return cloned + + def fit(self, X: Any, y: Any) -> None: + """Assign a new model version when a training window completes.""" + self.version = VersionedModel.next_version + VersionedModel.next_version += 1 + + def predict(self, X: Any) -> np.ndarray: + """Return the current model version as prediction.""" + return np.full(len(X), self.version, dtype=float) + + def save(self, path: str) -> None: + """Satisfy abstract model API for tests.""" + return + + def load(self, path: str) -> None: + """Satisfy abstract model API for tests.""" + return + + class LifecycleStrategy(akquant.Strategy): + """Capture active model versions and window metadata per bar.""" + + def __init__(self) -> None: + """Initialize model and lifecycle recorder.""" + super().__init__() + self.model = VersionedModel() + self.events: list[tuple[int, bool, int | None, int | None, int | None]] = [] + + def prepare_features( + self, df: pd.DataFrame, mode: str = "training" + ) -> tuple[pd.DataFrame, pd.Series]: + """Return simple close-only features and aligned labels.""" + features = pd.DataFrame({"close": df["close"].fillna(0.0)}) + labels = pd.Series(np.zeros(len(features), dtype=int)) + return features, labels + + def on_bar(self, bar: akquant.Bar) -> None: + """Record the visible active model state for the current bar.""" + window = self.current_validation_window() + version: int | None = None + if self.is_model_ready() and self.model is not None: + prediction = self.model.predict(pd.DataFrame({"close": [bar.close]})) + version = int(prediction[0]) + self.events.append( + ( + int(self._bar_count), + bool(self.is_model_ready()), + version, + None if window is None else window["active_start_bar"], + None if window is None else window["active_end_bar"], + ) + ) + + symbol = "ML_LIFECYCLE" + data = _build_benchmark_data(n=8, symbol=symbol) + strategy = LifecycleStrategy() + + _ = akquant.run_backtest( + data=data, + strategy=strategy, + symbols=[symbol], + show_progress=False, + ) + + assert strategy.events == [ + (1, False, None, None, None), + (2, False, None, None, None), + (3, False, None, None, None), + (4, False, None, None, None), + (5, True, 1, 5, 6), + (6, True, 1, 5, 6), + (7, True, 2, 7, 8), + (8, True, 2, 7, 8), + ] + + def test_run_backtest_expiry_date_str_is_rejected() -> None: """expiry_date should reject string input."""