1+ import datetime as dt_module
12import os
23import sys
34from functools import cached_property
2324 ExecutionMode ,
2425 Instrument ,
2526 Order ,
26- PerformanceMetrics ,
2727)
2828from .akquant import (
2929 BacktestResult as RustBacktestResult ,
@@ -178,14 +178,27 @@ def orders(self) -> List[Order]:
178178 return []
179179
180180 @property
181- def metrics (self ) -> PerformanceMetrics :
182- """
183- Get performance metrics as a raw object (Raw Access).
184-
185- This is the raw Rust object containing all metrics fields.
186- For a DataFrame view, use `metrics_df`.
187- """
188- return self ._raw .metrics
181+ def metrics (self ) -> Any :
182+ """Get metrics with timezone-aware datetime conversion."""
183+ metrics = self ._raw .metrics
184+
185+ class MetricsWrapper :
186+ def __init__ (self , raw_metrics : Any , timezone : str ) -> None :
187+ self ._raw = raw_metrics
188+ self ._timezone = timezone
189+
190+ def __getattr__ (self , name : str ) -> Any :
191+ val = getattr (self ._raw , name )
192+ if name in ["start_time" , "end_time" ]:
193+ # Convert ns timestamp to datetime
194+ if isinstance (val , int ):
195+ dt = pd .to_datetime (val , unit = "ns" , utc = True ).tz_convert (
196+ self ._timezone
197+ )
198+ return dt
199+ return val
200+
201+ return MetricsWrapper (metrics , self ._timezone )
189202
190203 @property
191204 def positions (self ) -> pd .DataFrame :
@@ -636,6 +649,8 @@ def run_backtest(
636649 warmup_period : int = 0 ,
637650 lot_size : Union [int , Dict [str , int ], None ] = None ,
638651 show_progress : Optional [bool ] = None ,
652+ start_time : Optional [Union [str , Any ]] = None ,
653+ end_time : Optional [Union [str , Any ]] = None ,
639654 config : Optional [BacktestConfig ] = None ,
640655 instruments_config : Optional [
641656 Union [List [InstrumentConfig ], Dict [str , InstrumentConfig ]]
@@ -663,9 +678,30 @@ def run_backtest(
663678 :param lot_size: 最小交易单位。如果是 int,则应用于所有标的;
664679 如果是 Dict[str, int],则按代码匹配;如果不传(None),默认为 1。
665680 :param show_progress: 是否显示进度条 (默认 True)
681+ :param start_time: 回测开始时间 (e.g., "2020-01-01 09:30"). 优先级高于
682+ config.start_time.
683+ :param end_time: 回测结束时间 (e.g., "2020-12-31 15:00"). 优先级高于
684+ config.end_time.
666685 :param config: BacktestConfig 配置对象 (可选)
667686 :param instruments_config: 标的配置列表或字典 (可选)
668687 :return: 回测结果 Result 对象
688+
689+ 配置优先级说明 (Parameter Priority):
690+ ----------------------------------
691+ 本函数参数采用以下优先级顺序解析(由高到低):
692+
693+ 1. **Explicit Arguments (显式参数)**:
694+ 直接传递给 `run_backtest` 的参数优先级最高。
695+ 例如: `run_backtest(..., start_time="2022-01-01")` 会覆盖 Config 中的设置。
696+
697+ 2. **Configuration Objects (配置对象)**:
698+ 如果显式参数为 `None`,则尝试从 `config` (`BacktestConfig`) 及其子配置
699+ (`StrategyConfig`) 中读取。
700+ 例如: `config.start_time` 或 `config.strategy_config.initial_cash`。
701+
702+ 3. **Default Values (默认值)**:
703+ 如果上述两者都未提供,则使用系统默认值。
704+ 例如: `initial_cash` 默认为 1,000,000。
669705 """
670706 # 0. 设置默认值 (如果未传入且未在 Config 中设置)
671707 # 优先级: 参数 > Config > 默认值
@@ -687,10 +723,19 @@ def run_backtest(
687723 # Resolve Commission Rate
688724 if commission_rate is None :
689725 if config and config .strategy_config :
690- commission_rate = config .strategy_config .fee_amount
726+ commission_rate = config .strategy_config .commission_rate
691727 else :
692728 commission_rate = DEFAULT_COMMISSION_RATE
693729
730+ # Resolve Other Fees (if not passed as args, check config)
731+ if config and config .strategy_config :
732+ if stamp_tax_rate == 0.0 :
733+ stamp_tax_rate = config .strategy_config .stamp_tax_rate
734+ if transfer_fee_rate == 0.0 :
735+ transfer_fee_rate = config .strategy_config .transfer_fee_rate
736+ if min_commission == 0.0 :
737+ min_commission = config .strategy_config .min_commission
738+
694739 # Resolve Timezone
695740 if timezone is None :
696741 if config and config .timezone :
@@ -729,11 +774,23 @@ def run_backtest(
729774 )
730775
731776 # 1.5 处理 Config 覆盖 (剩余部分)
732- if config :
733- if config .start_date :
734- kwargs ["start_date" ] = config .start_date
735- if config .end_date :
736- kwargs ["end_date" ] = config .end_date
777+ # Resolve effective start/end time for filtering
778+ # Priority: explicit argument > config
779+
780+ if start_time is None :
781+ if config and config .start_time :
782+ start_time = config .start_time
783+
784+ if end_time is None :
785+ if config and config .end_time :
786+ end_time = config .end_time
787+
788+ # Update kwargs if needed by strategy (optional, can be removed if strategies
789+ # don't need it)
790+ if start_time :
791+ kwargs ["start_time" ] = start_time
792+ if end_time :
793+ kwargs ["end_time" ] = end_time
737794
738795 # 注意: initial_cash, commission_rate, timezone, show_progress, history_depth
739796 # 已经在上方通过优先级逻辑处理过了,这里不需要再覆盖
@@ -844,6 +901,58 @@ def run_backtest(
844901 if data is not None :
845902 # Use provided data
846903 if isinstance (data , pd .DataFrame ):
904+ # Ensure index is datetime
905+ if not isinstance (data .index , pd .DatetimeIndex ):
906+ # Try to find a date column if index is not date
907+ # Common candidates: "date", "timestamp", "datetime"
908+ found_date = False
909+ for col in ["date" , "timestamp" , "datetime" , "Date" , "Timestamp" ]:
910+ if col in data .columns :
911+ data = data .set_index (col )
912+ found_date = True
913+ break
914+
915+ if not found_date :
916+ # try convert index
917+ try :
918+ data .index = pd .to_datetime (data .index )
919+ except Exception :
920+ pass
921+
922+ # Ensure index is pd.Timestamp compatible
923+ # (convert datetime.date to Timestamp)
924+ # This is handled by pd.to_datetime but let's be safe for object index
925+ if data .index .dtype == "object" :
926+ try :
927+ data .index = pd .to_datetime (data .index )
928+ except Exception :
929+ pass
930+
931+ # Filter by date if provided
932+ if start_time :
933+ # Handle potential mismatch between Timestamp and datetime.date
934+ ts_start = pd .Timestamp (start_time )
935+ # If index is date objects, compare with date()
936+ if (
937+ len (data ) > 0
938+ and isinstance (data .index [0 ], (dt_module .date ))
939+ and not isinstance (data .index [0 ], dt_module .datetime )
940+ ):
941+ data = data [data .index >= ts_start .date ()]
942+ else :
943+ data = data [data .index >= ts_start ]
944+
945+ if end_time :
946+ ts_end = pd .Timestamp (end_time )
947+ if (
948+ len (data ) > 0
949+ and isinstance (data .index [0 ], (dt_module .date ))
950+ and not isinstance (data .index [0 ], dt_module .datetime )
951+ ):
952+ data = data [data .index <= ts_end .date ()]
953+ else :
954+ data = data [data .index <= ts_end ]
955+
847956 # Try to infer symbol from DataFrame if not explicitly provided or default
848957 if (not symbols or symbols == ["BENCHMARK" ]) and "symbol" in data .columns :
849958 unique_symbols = data ["symbol" ].unique ()
@@ -872,6 +981,28 @@ def run_backtest(
872981 if filter_symbols and sym not in symbols :
873982 continue
874983
984+ # Ensure index is datetime
985+ if not isinstance (df .index , pd .DatetimeIndex ):
986+ # Try to find a date column if index is not date
987+ found_date = False
988+ for col in ["date" , "timestamp" , "datetime" , "Date" , "Timestamp" ]:
989+ if col in df .columns :
990+ df = df .set_index (col )
991+ found_date = True
992+ break
993+
994+ if not found_date :
995+ try :
996+ df .index = pd .to_datetime (df .index )
997+ except Exception :
998+ pass
999+
1000+ # Filter by date
1001+ if start_time :
1002+ df = df [df .index >= pd .Timestamp (start_time )]
1003+ if end_time :
1004+ df = df [df .index <= pd .Timestamp (end_time )]
1005+
8751006 df_prep = prepare_dataframe (df )
8761007 data_map_for_indicators [sym ] = df_prep
8771008 arrays = df_to_arrays (df_prep , symbol = sym )
@@ -881,6 +1012,15 @@ def run_backtest(
8811012 feed .sort ()
8821013 elif isinstance (data , list ):
8831014 if data :
1015+ # Filter by date
1016+ if start_time :
1017+ # Explicitly convert to int to satisfy mypy
1018+ ts_start : int = int (pd .Timestamp (start_time ).value ) # type: ignore
1019+ data = [b for b in data if b .timestamp >= ts_start ] # type: ignore
1020+ if end_time :
1021+ ts_end : int = int (pd .Timestamp (end_time ).value ) # type: ignore
1022+ data = [b for b in data if b .timestamp <= ts_end ] # type: ignore
1023+
8841024 data .sort (key = lambda b : b .timestamp )
8851025 feed .add_bars (data )
8861026 else :
@@ -889,13 +1029,12 @@ def run_backtest(
8891029 logger .warning ("No symbols specified and no data provided." )
8901030
8911031 catalog = ParquetDataCatalog ()
892- start_date = kwargs .get ("start_date" )
893- end_date = kwargs .get ("end_date" )
1032+ # start_time / end_time already resolved above
8941033
8951034 loaded_count = 0
8961035 for sym in symbols :
8971036 # Try Catalog
898- df = catalog .read (sym , start_date = start_date , end_date = end_date )
1037+ df = catalog .read (sym , start_time = start_time , end_time = end_time )
8991038 if df .empty :
9001039 logger .warning (f"Data not found in catalog for { sym } " )
9011040 continue
@@ -1163,8 +1302,6 @@ def plot_result(
11631302 :param benchmark: 基准收益率序列 (可选, Series with DatetimeIndex)
11641303 """
11651304 try :
1166- from datetime import datetime
1167-
11681305 import matplotlib .dates as mdates
11691306 import matplotlib .pyplot as plt
11701307 from matplotlib .gridspec import GridSpec
@@ -1190,9 +1327,9 @@ def plot_result(
11901327 if first_ts > 1e11 :
11911328 scale = 1e-9
11921329
1193- from datetime import timezone
1194-
11951330 # Use UTC to avoid local timezone issues and align with benchmark data
1331+ from datetime import datetime , timezone
1332+
11961333 times = [
11971334 datetime .fromtimestamp (t * scale , tz = timezone .utc ).replace (tzinfo = None )
11981335 for t , _ in equity_curve
0 commit comments