Skip to content

Commit 2fcece6

Browse files
Merge pull request #51 from akfamily/dev
feat: 统一时间参数命名并增强回测日期过滤
2 parents 60bfb09 + 2e6adb5 commit 2fcece6

File tree

11 files changed

+202
-50
lines changed

11 files changed

+202
-50
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "akquant"
3-
version = "0.1.30"
3+
version = "0.1.31"
44
edition = "2024"
55
description = "High-performance quantitative trading framework based on Rust and Python"
66
license = "MIT"

docs/en/examples.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,8 @@ import pandas as pd
196196
import numpy as np
197197

198198
# 1. Prepare data (Mock data)
199-
def create_dummy_data(symbol, start_date, n_bars, price=100.0):
200-
dates = pd.date_range(start_date, periods=n_bars, freq="B")
199+
def create_dummy_data(symbol, start_time, n_bars, price=100.0):
200+
dates = pd.date_range(start_time, periods=n_bars, freq="B")
201201
np.random.seed(42)
202202
changes = np.random.randn(n_bars)
203203
prices = price + np.cumsum(changes)

docs/zh/examples.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,8 @@ import pandas as pd
196196
import numpy as np
197197

198198
# 1. 准备数据 (模拟数据)
199-
def create_dummy_data(symbol, start_date, n_bars, price=100.0):
200-
dates = pd.date_range(start_date, periods=n_bars, freq="B")
199+
def create_dummy_data(symbol, start_time, n_bars, price=100.0):
200+
dates = pd.date_range(start_time, periods=n_bars, freq="B")
201201
np.random.seed(42)
202202
changes = np.random.randn(n_bars)
203203
prices = price + np.cumsum(changes)

examples/benchmark_akquant_multi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
# 生成数据 (使用 'B' 代表工作日频率,从 1990 年开始)
4444
# 固定种子以确保结果可复现
4545
df = get_benchmark_data(
46-
DATA_SIZE, symbol, freq="B", start_date="1990-01-01", seed=42 + i
46+
DATA_SIZE, symbol, freq="B", start_time="1990-01-01", seed=42 + i
4747
)
4848

4949
# 重命名列

examples/benchmark_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def get_benchmark_data(
1010
n: int = 200_000,
1111
symbol: str = "BENCHMARK",
1212
freq: str = "min",
13-
start_date: str = "2020-01-01",
13+
start_time: str = "2020-01-01",
1414
seed: Optional[int] = None,
1515
) -> pd.DataFrame:
1616
"""
@@ -21,7 +21,7 @@ def get_benchmark_data(
2121
print(f"Generating {n} rows of dummy data...")
2222
t0 = time.time()
2323

24-
dates = pd.date_range(start=start_date, periods=n, freq=freq, tz="UTC")
24+
dates = pd.date_range(start=start_time, periods=n, freq=freq, tz="UTC")
2525
# Add 15 hours to simulate market close time (15:00:00 UTC)
2626
dates = dates + pd.Timedelta(hours=15)
2727

examples/plot_demo.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,22 @@ def on_bar(self, bar: Bar) -> None:
4949
# Configuration
5050
SYMBOL = "sh600000"
5151
START_DATE = "20120101"
52-
END_DATE = "20231231"
52+
END_DATE = "20261231"
5353
INITIAL_CASH = 100_000.0
5454

5555
df = ak.stock_zh_a_daily(symbol=SYMBOL, start_date=START_DATE, end_date=END_DATE)
5656
df["symbol"] = SYMBOL
5757

58+
# from akquant.config import BacktestConfig, StrategyConfig, RiskConfig
59+
# # 配置风险参数:safety_margin
60+
# risk_config = RiskConfig(safety_margin=0.0001)
61+
# strategy_config = StrategyConfig(risk=risk_config)
62+
# backtest_config = BacktestConfig(
63+
# strategy_config=strategy_config,
64+
# # start_time="20200131",
65+
# # end_time="20260210"
66+
# )
67+
5868
# 2. Run Backtest
5969
print("\nRunning Backtest...")
6070
result = run_backtest(
@@ -63,6 +73,9 @@ def on_bar(self, bar: Bar) -> None:
6373
symbol=SYMBOL,
6474
initial_cash=INITIAL_CASH,
6575
show_progress=True,
76+
# config=backtest_config,
77+
start_time="20160101",
78+
end_time="20201231",
6679
)
6780

6881
# 3. Print Metrics

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "maturin"
44

55
[project]
66
name = "akquant"
7-
version = "0.1.30"
7+
version = "0.1.31"
88
description = "High-performance quantitative trading framework based on Rust and Python"
99
readme = "README.md"
1010
license = {text = "MIT License"}

python/akquant/backtest.py

Lines changed: 159 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import datetime as dt_module
12
import os
23
import sys
34
from functools import cached_property
@@ -23,7 +24,6 @@
2324
ExecutionMode,
2425
Instrument,
2526
Order,
26-
PerformanceMetrics,
2727
)
2828
from .akquant import (
2929
BacktestResult as RustBacktestResult,
@@ -178,14 +178,27 @@ def orders(self) -> List[Order]:
178178
return []
179179

180180
@property
181-
def metrics(self) -> PerformanceMetrics:
182-
"""
183-
Get performance metrics as a raw object (Raw Access).
184-
185-
This is the raw Rust object containing all metrics fields.
186-
For a DataFrame view, use `metrics_df`.
187-
"""
188-
return self._raw.metrics
181+
def metrics(self) -> Any:
182+
"""Get metrics with timezone-aware datetime conversion."""
183+
metrics = self._raw.metrics
184+
185+
class MetricsWrapper:
186+
def __init__(self, raw_metrics: Any, timezone: str) -> None:
187+
self._raw = raw_metrics
188+
self._timezone = timezone
189+
190+
def __getattr__(self, name: str) -> Any:
191+
val = getattr(self._raw, name)
192+
if name in ["start_time", "end_time"]:
193+
# Convert ns timestamp to datetime
194+
if isinstance(val, int):
195+
dt = pd.to_datetime(val, unit="ns", utc=True).tz_convert(
196+
self._timezone
197+
)
198+
return dt
199+
return val
200+
201+
return MetricsWrapper(metrics, self._timezone)
189202

190203
@property
191204
def positions(self) -> pd.DataFrame:
@@ -636,6 +649,8 @@ def run_backtest(
636649
warmup_period: int = 0,
637650
lot_size: Union[int, Dict[str, int], None] = None,
638651
show_progress: Optional[bool] = None,
652+
start_time: Optional[Union[str, Any]] = None,
653+
end_time: Optional[Union[str, Any]] = None,
639654
config: Optional[BacktestConfig] = None,
640655
instruments_config: Optional[
641656
Union[List[InstrumentConfig], Dict[str, InstrumentConfig]]
@@ -663,9 +678,30 @@ def run_backtest(
663678
:param lot_size: 最小交易单位。如果是 int,则应用于所有标的;
664679
如果是 Dict[str, int],则按代码匹配;如果不传(None),默认为 1。
665680
:param show_progress: 是否显示进度条 (默认 True)
681+
:param start_time: 回测开始时间 (e.g., "2020-01-01 09:30"). 优先级高于
682+
config.start_time.
683+
:param end_time: 回测结束时间 (e.g., "2020-12-31 15:00"). 优先级高于
684+
config.end_time.
666685
:param config: BacktestConfig 配置对象 (可选)
667686
:param instruments_config: 标的配置列表或字典 (可选)
668687
:return: 回测结果 Result 对象
688+
689+
配置优先级说明 (Parameter Priority):
690+
----------------------------------
691+
本函数参数采用以下优先级顺序解析(由高到低):
692+
693+
1. **Explicit Arguments (显式参数)**:
694+
直接传递给 `run_backtest` 的参数优先级最高。
695+
例如: `run_backtest(..., start_time="2022-01-01")` 会覆盖 Config 中的设置。
696+
697+
2. **Configuration Objects (配置对象)**:
698+
如果显式参数为 `None`,则尝试从 `config` (`BacktestConfig`) 及其子配置
699+
(`StrategyConfig`) 中读取。
700+
例如: `config.start_time` 或 `config.strategy_config.initial_cash`。
701+
702+
3. **Default Values (默认值)**:
703+
如果上述两者都未提供,则使用系统默认值。
704+
例如: `initial_cash` 默认为 1,000,000。
669705
"""
670706
# 0. 设置默认值 (如果未传入且未在 Config 中设置)
671707
# 优先级: 参数 > Config > 默认值
@@ -687,10 +723,19 @@ def run_backtest(
687723
# Resolve Commission Rate
688724
if commission_rate is None:
689725
if config and config.strategy_config:
690-
commission_rate = config.strategy_config.fee_amount
726+
commission_rate = config.strategy_config.commission_rate
691727
else:
692728
commission_rate = DEFAULT_COMMISSION_RATE
693729

730+
# Resolve Other Fees (if not passed as args, check config)
731+
if config and config.strategy_config:
732+
if stamp_tax_rate == 0.0:
733+
stamp_tax_rate = config.strategy_config.stamp_tax_rate
734+
if transfer_fee_rate == 0.0:
735+
transfer_fee_rate = config.strategy_config.transfer_fee_rate
736+
if min_commission == 0.0:
737+
min_commission = config.strategy_config.min_commission
738+
694739
# Resolve Timezone
695740
if timezone is None:
696741
if config and config.timezone:
@@ -729,11 +774,23 @@ def run_backtest(
729774
)
730775

731776
# 1.5 处理 Config 覆盖 (剩余部分)
732-
if config:
733-
if config.start_date:
734-
kwargs["start_date"] = config.start_date
735-
if config.end_date:
736-
kwargs["end_date"] = config.end_date
777+
# Resolve effective start/end time for filtering
778+
# Priority: explicit argument > config
779+
780+
if start_time is None:
781+
if config and config.start_time:
782+
start_time = config.start_time
783+
784+
if end_time is None:
785+
if config and config.end_time:
786+
end_time = config.end_time
787+
788+
# Update kwargs if needed by strategy (optional, can be removed if strategies
789+
# don't need it)
790+
if start_time:
791+
kwargs["start_time"] = start_time
792+
if end_time:
793+
kwargs["end_time"] = end_time
737794

738795
# 注意: initial_cash, commission_rate, timezone, show_progress, history_depth
739796
# 已经在上方通过优先级逻辑处理过了,这里不需要再覆盖
@@ -844,6 +901,58 @@ def run_backtest(
844901
if data is not None:
845902
# Use provided data
846903
if isinstance(data, pd.DataFrame):
904+
# Ensure index is datetime
905+
if not isinstance(data.index, pd.DatetimeIndex):
906+
# Try to find a date column if index is not date
907+
# Common candidates: "date", "timestamp", "datetime"
908+
found_date = False
909+
for col in ["date", "timestamp", "datetime", "Date", "Timestamp"]:
910+
if col in data.columns:
911+
data = data.set_index(col)
912+
found_date = True
913+
break
914+
915+
if not found_date:
916+
# try convert index
917+
try:
918+
data.index = pd.to_datetime(data.index)
919+
except Exception:
920+
pass
921+
922+
# Ensure index is pd.Timestamp compatible
923+
# (convert datetime.date to Timestamp)
924+
# This is handled by pd.to_datetime but let's be safe for object index
925+
if data.index.dtype == "object":
926+
try:
927+
data.index = pd.to_datetime(data.index)
928+
except Exception:
929+
pass
930+
931+
# Filter by date if provided
932+
if start_time:
933+
# Handle potential mismatch between Timestamp and datetime.date
934+
ts_start = pd.Timestamp(start_time)
935+
# If index is date objects, compare with date()
936+
if (
937+
len(data) > 0
938+
and isinstance(data.index[0], (dt_module.date))
939+
and not isinstance(data.index[0], dt_module.datetime)
940+
):
941+
data = data[data.index >= ts_start.date()]
942+
else:
943+
data = data[data.index >= ts_start]
944+
945+
if end_time:
946+
ts_end = pd.Timestamp(end_time)
947+
if (
948+
len(data) > 0
949+
and isinstance(data.index[0], (dt_module.date))
950+
and not isinstance(data.index[0], dt_module.datetime)
951+
):
952+
data = data[data.index <= ts_end.date()]
953+
else:
954+
data = data[data.index <= ts_end]
955+
847956
# Try to infer symbol from DataFrame if not explicitly provided or default
848957
if (not symbols or symbols == ["BENCHMARK"]) and "symbol" in data.columns:
849958
unique_symbols = data["symbol"].unique()
@@ -872,6 +981,28 @@ def run_backtest(
872981
if filter_symbols and sym not in symbols:
873982
continue
874983

984+
# Ensure index is datetime
985+
if not isinstance(df.index, pd.DatetimeIndex):
986+
# Try to find a date column if index is not date
987+
found_date = False
988+
for col in ["date", "timestamp", "datetime", "Date", "Timestamp"]:
989+
if col in df.columns:
990+
df = df.set_index(col)
991+
found_date = True
992+
break
993+
994+
if not found_date:
995+
try:
996+
df.index = pd.to_datetime(df.index)
997+
except Exception:
998+
pass
999+
1000+
# Filter by date
1001+
if start_time:
1002+
df = df[df.index >= pd.Timestamp(start_time)]
1003+
if end_time:
1004+
df = df[df.index <= pd.Timestamp(end_time)]
1005+
8751006
df_prep = prepare_dataframe(df)
8761007
data_map_for_indicators[sym] = df_prep
8771008
arrays = df_to_arrays(df_prep, symbol=sym)
@@ -881,6 +1012,15 @@ def run_backtest(
8811012
feed.sort()
8821013
elif isinstance(data, list):
8831014
if data:
1015+
# Filter by date
1016+
if start_time:
1017+
# Explicitly convert to int to satisfy mypy
1018+
ts_start: int = int(pd.Timestamp(start_time).value) # type: ignore
1019+
data = [b for b in data if b.timestamp >= ts_start] # type: ignore
1020+
if end_time:
1021+
ts_end: int = int(pd.Timestamp(end_time).value) # type: ignore
1022+
data = [b for b in data if b.timestamp <= ts_end] # type: ignore
1023+
8841024
data.sort(key=lambda b: b.timestamp)
8851025
feed.add_bars(data)
8861026
else:
@@ -889,13 +1029,12 @@ def run_backtest(
8891029
logger.warning("No symbols specified and no data provided.")
8901030

8911031
catalog = ParquetDataCatalog()
892-
start_date = kwargs.get("start_date")
893-
end_date = kwargs.get("end_date")
1032+
# start_time / end_time already resolved above
8941033

8951034
loaded_count = 0
8961035
for sym in symbols:
8971036
# Try Catalog
898-
df = catalog.read(sym, start_date=start_date, end_date=end_date)
1037+
df = catalog.read(sym, start_time=start_time, end_time=end_time)
8991038
if df.empty:
9001039
logger.warning(f"Data not found in catalog for {sym}")
9011040
continue
@@ -1163,8 +1302,6 @@ def plot_result(
11631302
:param benchmark: 基准收益率序列 (可选, Series with DatetimeIndex)
11641303
"""
11651304
try:
1166-
from datetime import datetime
1167-
11681305
import matplotlib.dates as mdates
11691306
import matplotlib.pyplot as plt
11701307
from matplotlib.gridspec import GridSpec
@@ -1190,9 +1327,9 @@ def plot_result(
11901327
if first_ts > 1e11:
11911328
scale = 1e-9
11921329

1193-
from datetime import timezone
1194-
11951330
# Use UTC to avoid local timezone issues and align with benchmark data
1331+
from datetime import datetime, timezone
1332+
11961333
times = [
11971334
datetime.fromtimestamp(t * scale, tz=timezone.utc).replace(tzinfo=None)
11981335
for t, _ in equity_curve

0 commit comments

Comments
 (0)