diff --git a/rqalpha_mod_incremental/mod.py b/rqalpha_mod_incremental/mod.py index 62432ce..f972a25 100644 --- a/rqalpha_mod_incremental/mod.py +++ b/rqalpha_mod_incremental/mod.py @@ -41,8 +41,15 @@ def start_up(self, env, mod_config): self._recorder = None self._mod_config = mod_config - if not self._mod_config.persist_folder: - return + if mod_config.recorder == "CsvRecorder": + if mod_config.persist_folder is None: + raise RuntimeError(_(u"You need to set persist_folder to use CsvRecorder!")) + elif mod_config.recorder == "MongodbRecorder": + if mod_config.strategy_id is None or mod_config.mongo_url is None or mod_config.mongo_dbname is None: + raise RuntimeError(_(u"MongodbRecorder requires strategy_id, mongo_url and mongo_dbname! " + u"But got {}").format(mod_config)) + else: + raise RuntimeError(_(u"unknown recorder {}").format(mod_config.recorder)) config = self._env.config if not env.data_source: @@ -60,14 +67,10 @@ def _set_env_and_data_source(self): mod_config = self._mod_config system_log.info("use recorder {}", mod_config.recorder) if mod_config.recorder == "CsvRecorder": - if not mod_config.persist_folder: - raise RuntimeError(_(u"You need to set persist_folder to use CsvRecorder")) persist_folder = os.path.join(mod_config.persist_folder, "persist", str(mod_config.strategy_id)) persist_provider = DiskPersistProvider(persist_folder) self._recorder = recorders.CsvRecorder(persist_folder) elif mod_config.recorder == "MongodbRecorder": - if mod_config.strategy_id is None: - raise RuntimeError(_(u"You need to set strategy_id")) persist_provider = persist_providers.MongodbPersistProvider(mod_config.strategy_id, mod_config.mongo_url, mod_config.mongo_dbname) self._recorder = recorders.MongodbRecorder(mod_config.strategy_id, @@ -92,12 +95,14 @@ def _set_env_and_data_source(self): if persist_meta: # 不修改回测开始时间 self._env.config.base.start_date = datetime.datetime.strptime(persist_meta['start_date'], '%Y-%m-%d').date() - event_start_time = datetime.datetime.strptime(persist_meta['last_end_time'], '%Y-%m-%d').date() + datetime.timedelta(days=1) + event_start_time = datetime.datetime.strptime(persist_meta['last_end_time'], + '%Y-%m-%d').date() + datetime.timedelta(days=1) # 代表历史有运行过,根据历史上次运行的end_date下一天设为事件发送的start_time self._meta["origin_start_date"] = persist_meta["origin_start_date"] self._meta["start_date"] = persist_meta["start_date"] if self._meta["last_end_time"] <= persist_meta["last_end_time"]: - raise ValueError('The end_date should after end_date({}) last time'.format(persist_meta["last_end_time"])) + raise ValueError( + 'The end_date should after end_date({}) last time'.format(persist_meta["last_end_time"])) self._last_end_date = datetime.datetime.strptime(persist_meta["last_end_time"], "%Y-%m-%d").date() self._event_start_time = event_start_time self._overwrite_event_data_source_func() @@ -137,8 +142,6 @@ def on_settlement(self, event): return True def tear_down(self, success, exception=None): - if not self._mod_config.persist_folder: - return if exception is None: self._recorder.store_meta(self._meta) self._recorder.flush() diff --git a/rqalpha_mod_incremental/persist_providers.py b/rqalpha_mod_incremental/persist_providers.py index 289ab48..9ee46b0 100644 --- a/rqalpha_mod_incremental/persist_providers.py +++ b/rqalpha_mod_incremental/persist_providers.py @@ -1,23 +1,52 @@ import os import datetime - +import jsonpickle +import pandas as pd +from rqrisk import Risk from rqalpha.interface import AbstractPersistProvider +def get_performance(strategy_id, analysis_data): + daily_returns = analysis_data['portfolio_daily_returns'] + benchmark = analysis_data['benchmark_daily_returns'] + dates = [p['date'] for p in analysis_data['total_portfolios']] + assert len(daily_returns) == len(benchmark) == len(dates), 'unmatched length' + daily_returns = pd.Series(daily_returns, index=dates) + benchmark = pd.Series(benchmark, index=dates) + risk = Risk(daily_returns, benchmark, 0.) + perf = risk.all() + perf['strategy_id'] = strategy_id + perf['update_time'] = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + perf['start_date'] = analysis_data['total_portfolios'][0]['date'].strftime('%Y-%m-%d') + perf['end_date'] = analysis_data['total_portfolios'][-1]['date'].strftime('%Y-%m-%d') + return perf + + class MongodbPersistProvider(AbstractPersistProvider): def __init__(self, strategy_id, mongo_url, mongo_db): import pymongo import gridfs - persist_db = pymongo.MongoClient(mongo_url)[mongo_db] + self.persist_db = pymongo.MongoClient(mongo_url)[mongo_db] self._strategy_id = strategy_id - self._fs = gridfs.GridFS(persist_db) + self._fs = gridfs.GridFS(self.persist_db) def store(self, key, value): update_time = datetime.datetime.now() self._fs.put(value, strategy_id=self._strategy_id, key=key, update_time=update_time) - for grid_out in self._fs.find({"strategy_id": self._strategy_id, "key": key, "update_time": {"$lt": update_time}}): + for grid_out in self._fs.find( + {"strategy_id": self._strategy_id, "key": key, "update_time": {"$lt": update_time}}): self._fs.delete(grid_out._id) + if key == "mod_sys_analyser": + self._store_performance(value) + + def _store_performance(self, analysis_data): + try: + perf = get_performance(self._strategy_id, + jsonpickle.loads(analysis_data.decode("utf-8"))) + self.persist_db['performance'].update({"strategy_id": self._strategy_id}, perf, upsert=True) + except Exception as e: + print(e) def load(self, key, large_file=False): import gridfs @@ -27,6 +56,12 @@ def load(self, key, large_file=False): except gridfs.errors.NoFile: return None + def should_resume(self): + return False + + def should_run_init(self): + return False + class DiskPersistProvider(AbstractPersistProvider): def __init__(self, path="./persist"):