Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions agent/backtest/benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""Benchmark ticker resolution and fetch for backtest comparison.

Provides a lightweight, zero-dependency way to fetch benchmark reference
data given a set of strategy codes and a data source.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import pandas as pd

from backtest.loaders.yfinance_loader import DataLoader as YfinanceLoader


# -------------------------------------------------------------------
# Benchmark map: market type → default ticker
# -------------------------------------------------------------------

# A value of None means "no universal benchmark exists for this market"
# (resolve_benchmark then returns None).
# NOTE(review): benchmark data is fetched via the yfinance loader; several of
# these tickers (e.g. "000300.SH", "BTC-USDT", "ES.CME") do not look like
# valid yfinance symbols — confirm they resolve with the configured loader.
MARKET_BENCHMARKS: dict[str, Optional[str]] = {
    "us_equity": "SPY",
    "hk_equity": "HK.03100",  # Hang Seng China Enterprises ETF
    "a_share": "000300.SH",  # CSI 300 (China A-share core index)
    "crypto": "BTC-USDT",
    "futures": "ES.CME",  # E-mini S&P 500 futures
    "forex": None,  # no universal benchmark
}


@dataclass
class BenchmarkResult:
    """Benchmark return data for one resolved ticker over the backtest window."""

    ticker: str  # benchmark symbol actually used (e.g. "SPY")
    ret_series: pd.Series  # per-bar returns, index = timestamps
    total_ret: float  # total return over the period


def resolve_benchmark(
    strategy_codes: list[str],
    source: str,
    start_date: str,
    end_date: str,
    interval: str = "1D",
    explicit: Optional[str] = None,
) -> Optional[BenchmarkResult]:
    """Pick a benchmark ticker for the backtest and compute its return series.

    Args:
        strategy_codes: Instruments being backtested (used for market inference).
        source: Data source name (tushare / yfinance / okx / akshare / ccxt).
        start_date: Backtest start date.
        end_date: Backtest end date.
        interval: Bar interval (1m / 5m / 15m / 30m / 1H / 4H / 1D).
        explicit: Override ticker (e.g. "SPY" passed via config).

    Returns:
        A BenchmarkResult carrying the per-bar return series and the total
        return, or None when no benchmark applies (forex) or the data cannot
        be fetched / is unusable.
    """
    symbol = _resolve_ticker(strategy_codes, source, explicit)
    if symbol is None:
        return None

    try:
        frame = _fetch_benchmark(symbol, start_date, end_date, interval)
    except Exception:
        # Benchmark data is best-effort: any fetch problem means "no benchmark".
        return None

    if frame.empty or "close" not in frame.columns:
        return None

    closes = frame["close"].dropna()
    if len(closes) < 2:
        # Need at least two closes to form one return observation.
        return None

    per_bar = closes.pct_change().fillna(0.0)
    cumulative = float((1.0 + per_bar).prod() - 1.0)

    return BenchmarkResult(ticker=symbol, ret_series=per_bar, total_ret=cumulative)


# -------------------------------------------------------------------
# Internal helpers
# -------------------------------------------------------------------

def _resolve_ticker(
codes: list[str],
source: str,
explicit: Optional[str],
) -> Optional[str]:
"""Pick the benchmark ticker to use."""

if explicit:
return explicit

# Infer market from source + first code pattern
market = _infer_market(codes, source)
ticker = MARKET_BENCHMARKS.get(market)

# yfinance is the universal fallback for benchmark fetch
# but it only works for us_equity / hk_equity market types
if ticker and market not in {"us_equity", "hk_equity"}:
# Only use benchmark if we can actually fetch it
pass

return ticker


def _infer_market(codes: list[str], source: str) -> str:
"""Rough market inference from symbol patterns and source."""
if not codes:
return "us_equity"

first = codes[0].upper()

if source in ("okx", "ccxt") or "-" in first or "/" in first:
return "crypto"
if first.endswith(".US"):
return "us_equity"
if first.endswith(".HK"):
return "hk_equity"
if source in ("tushare", "akshare"):
if first.isdigit() and len(first) == 6:
return "a_share"
if first.startswith(("IF", "IC", "IH", "IM", "T", "TF")):
return "futures"
return "a_share"

return "us_equity"


def _fetch_benchmark(
    ticker: str,
    start_date: str,
    end_date: str,
    interval: str,
) -> pd.DataFrame:
    """Fetch benchmark OHLCV data via yfinance (single symbol, no auth).

    Normalizes the loader's output (dict-of-frames or a single frame) into
    one DataFrame; returns an empty DataFrame when nothing usable came back.
    """
    raw = YfinanceLoader().fetch([ticker], start_date, end_date, interval=interval)

    # The loader may hand back either a per-ticker dict or a bare frame.
    if isinstance(raw, pd.DataFrame):
        frame = raw
    elif isinstance(raw, dict):
        frame = raw.get(ticker)
    else:
        return pd.DataFrame()

    if frame is None or (isinstance(frame, pd.DataFrame) and frame.empty):
        return pd.DataFrame()

    return frame
20 changes: 20 additions & 0 deletions agent/backtest/engines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,26 @@ def run_backtest(
index=[s.timestamp for s in self.equity_snapshots],
)
bench_ret = ret_df.mean(axis=1) if ret_df.shape[1] > 0 else pd.Series(0.0, index=dates)

# ── External benchmark fetch ──────────────────────────────────────────
bench_ticker = config.get("benchmark")
if bench_ticker and bench_ticker != "auto":
from backtest.benchmark import resolve_benchmark
bench_result = resolve_benchmark(
strategy_codes=codes,
source=config.get("source", "yfinance"),
start_date=config.get("start_date", ""),
end_date=config.get("end_date", ""),
interval=interval,
explicit=bench_ticker,
)
if bench_result is not None:
bench_ret = bench_result.ret_series.reindex(dates).fillna(0.0)
bench_equity = self.initial_capital * (1 + bench_ret).cumprod()
m["benchmark_ticker"] = bench_result.ticker
m["benchmark_return"] = bench_result.total_ret
# ── External benchmark fetch ──────────────────────────────────────────

bench_equity = self.initial_capital * (1 + bench_ret).cumprod()

# 6. Metrics
Expand Down
87 changes: 78 additions & 9 deletions agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,23 +295,80 @@ def on_event(event_type: str, data: Dict[str, Any]) -> None:
return agent.run(user_message=prompt, history=history)


def _build_benchmark_table(m: dict) -> Optional[Table]:
    """Build a benchmark comparison table from metrics dict.

    Args:
        m: Metrics dictionary (from _read_metrics or result dict). Values
            read from metrics.csv arrive as strings; ``_benchmark_return_raw``
            may carry an already-parsed float.

    Returns:
        Rich Table, or None if no benchmark data is present.
    """

    def _to_float(value) -> Optional[float]:
        # CSV-sourced metrics are strings; treat unparseable values as absent
        # rather than crashing the result panel.
        if value is None:
            return None
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    bench_ticker = m.get("benchmark_ticker")
    if not bench_ticker:
        return None

    # Prefer the raw (already numeric) value when present.
    bench_ret = m.get("_benchmark_return_raw")
    if bench_ret is None:
        bench_ret = _to_float(m.get("benchmark_return"))

    strategy_ret = _to_float(m.get("total_return"))

    table = Table(show_header=False, padding=(0, 2))
    table.add_column("Label", style="dim", width=20)
    # no_wrap is a column option, not part of the style string.
    table.add_column("Value", style="white", no_wrap=True)

    table.add_row("[dim]Benchmark[/dim]", bench_ticker)

    if bench_ret is not None:
        table.add_row("[dim]Benchmark Return[/dim]", f"{bench_ret * 100:+.2f}%")

    if strategy_ret is not None and bench_ret is not None:
        excess = strategy_ret - bench_ret
        style = "green" if excess >= 0 else "red"
        # The :+ format flag already emits the sign; adding a manual prefix
        # would double it ("++1.23%").
        table.add_row(
            "[dim]vs Benchmark[/dim]",
            f"[{style}]{excess * 100:+.2f}%[/{style}]",
        )

    ir_str = m.get("information_ratio")
    if ir_str:
        table.add_row("[dim]Info Ratio[/dim]", ir_str)

    # Skip the row for any zero representation ("0", "0.0000", "0.00", ...).
    excess_ret = _to_float(m.get("excess_return"))
    if excess_ret is not None and excess_ret != 0.0:
        table.add_row("[dim]Excess Return[/dim]", f"{excess_ret * 100:+.2f}%")

    return table


def _print_result(result: dict, elapsed: float, *, no_rich: bool = False) -> None:
"""Print execution result panel."""
status = result.get("status", "unknown")
ok = status == "success"
style = "green" if ok else "red"

lines = [f"Status: [bold {style}]{status.upper()}[/bold {style}] Time: {elapsed:.1f}s"]

if result.get("run_id"):
lines.append(f"ID: {result['run_id']}")

review = result.get("review")
if review and review.get("overall_score") is not None:
check = "\u2713" if review.get("passed") else "\u2717"
lines.append(f"Review: {review['overall_score']}pts {check}")

run_dir = result.get("run_dir")
m = {}
if run_dir:
m = _read_metrics(Path(run_dir) / "artifacts" / "metrics.csv")
parts = [f"{k}={m[k]}" for k in ("total_return", "sharpe", "max_drawdown", "trade_count") if k in m]
Expand All @@ -323,6 +380,17 @@ def _print_result(result: dict, elapsed: float, *, no_rich: bool = False) -> Non

console.print(Panel("\n".join(lines), border_style=style, title="Result"))

# ── Benchmark comparison panel ─────────────────────────────────────────────
bench_table = _build_benchmark_table(m)
if bench_table:
console.print(Panel(
bench_table,
border_style="cyan",
title="Benchmark Comparison",
padding=(0, 1),
))
# ── Benchmark comparison panel ─────────────────────────────────────────

content = result.get("content", "").strip()
if content:
console.print(f"\n{content}")
Expand Down Expand Up @@ -952,7 +1020,7 @@ def cmd_list(limit: int = 20) -> None:
st = _read_json(d / "state.json").get("status", "?")
m = _read_metrics(d / "artifacts" / "metrics.csv")
c = "green" if st == "success" else "red" if st == "failed" else "dim"
table.add_row(d.name, f"[{c}]{st}[/{c}]", m.get("total_return", ""), m.get("sharpe", ""), (_read_json(d / "req.json").get("prompt") or "")[:40])
table.add_row(d.name, f"[{c}]{st.upper()}[/{c}]", m.get("total_return", ""), m.get("sharpe", ""), (_read_json(d / "req.json").get("prompt") or "")[:40])

console.print(table)

Expand All @@ -970,8 +1038,9 @@ def cmd_show(run_id: str) -> None:

st = state.get("status", "unknown")
c = "green" if st == "success" else "red"
lines = [f"[bold]Status:[/bold] [{c}]{st.upper()}[/{c}]", f"[bold]Prompt:[/bold] {req.get('prompt', '?')}"]

lines = [f"[bold]Status:[/bold] [{c}]{st.upper()}[/{c}]"]
if req.get("prompt"):
lines.append(f"[bold]Prompt:[/bold] {req['prompt'][:500]}{'...' if len(req['prompt']) > 500 else ''}")
if metrics:
lines.append("\n[bold]Metrics:[/bold]")
lines.extend(f" {k}: {v}" for k, v in metrics.items())
Expand Down Expand Up @@ -1702,7 +1771,7 @@ def main(argv: list[str] | None = None) -> int:
if args.command == "list":
return _coerce_exit_code(cmd_list(args.list_limit))
if args.command == "show":
return _coerce_exit_code(cmd_show(args.run_id))
return _coerce_exit_code(cmd_show(args.show))
if args.command == "chat":
return _coerce_exit_code(cmd_interactive(args.chat_max_iter))

Expand Down Expand Up @@ -1741,7 +1810,7 @@ def main(argv: list[str] | None = None) -> int:
if args.chat:
return _coerce_exit_code(cmd_interactive(args.max_iter))
if args.cont:
return cmd_continue(args.cont[0], args.cont[1], args.max_iter, json_mode=args.json, no_rich=args.no_rich)
return _coerce_exit_code(cmd_continue(args.cont[0], args.cont[1], args.max_iter, json_mode=args.json, no_rich=args.no_rich))

# No flags, no subcommand → check if prompt provided, otherwise interactive mode
if args.prompt or args.prompt_file or not sys.stdin.isatty():
Expand Down