Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion qlib/data/storage/file_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,29 @@ class FileFeatureStorage(FileStorageMixin, FeatureStorage):
def __init__(self, instrument: str, field: str, freq: str, provider_uri: dict = None, **kwargs):
    """Initialise the storage and resolve the on-disk file name for ``instrument``.

    The instrument directory on disk may use a different letter case than the
    ``instrument`` argument, which matters on case-sensitive file systems.
    Candidate file names are probed in this order and the first one that
    exists is used:

    1. original case (as passed by the caller),
    2. upper case,
    3. lower case (the historical naming scheme).

    If none of them exists, the original-case name is kept so that newly
    created files preserve the caller's casing.

    Args:
        instrument: Instrument code; forms the directory part of the file name.
        field: Field name; lower-cased in the file name.
        freq: Frequency; lower-cased in the file name.
        provider_uri: Optional provider URI mapping. When given it is
            normalised with ``C.DataPathManager.format_provider_uri``;
            otherwise ``_provider_uri`` is left as ``None``.
        **kwargs: Forwarded unchanged to the parent initialiser.
    """
    super(FileFeatureStorage, self).__init__(instrument, field, freq, **kwargs)
    self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri)

    # NOTE(review): feature files are assumed to live under the "features"
    # sub-directory of the data URI (the layout the accompanying test fixture
    # creates). Probing the data URI root directly would never match such
    # files. Confirm this mirrors how the `uri` property builds its path.
    feature_root = self.dpm.get_data_uri(self.freq) / "features"

    leaf = f"{field.lower()}.{freq.lower()}.bin"
    default_name = f"{instrument}/{leaf}"

    # dict.fromkeys keeps the probing order while dropping duplicates, so an
    # instrument that is already all-upper or all-lower is not stat'ed twice.
    candidates = dict.fromkeys(
        f"{directory}/{leaf}" for directory in (instrument, instrument.upper(), instrument.lower())
    )

    self.file_name = next(
        (name for name in candidates if (feature_root / name).exists()),
        default_name,
    )

def clear(self):
with self.uri.open("wb") as _:
Expand Down
2 changes: 2 additions & 0 deletions qlib/workflow/online/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ def prepare_tasks(self, cur_time) -> List[dict]:
self.logger.info(
f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
)
if self.ta.cal_interval(calendar_latest, max_test[0]) < self.rg.step:
return []
res = []
for rec in latest_records:
task = rec.load_object("task")
Expand Down
118 changes: 118 additions & 0 deletions tests/test_issue_2045.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import unittest
from unittest.mock import MagicMock, patch
import pandas as pd
from qlib.workflow.online.strategy import RollingStrategy

class TestIssue2045(unittest.TestCase):
    """Regression tests for the interval guard in ``RollingStrategy.prepare_tasks``.

    ``prepare_tasks`` is expected to return no tasks while the interval
    between the latest calendar date and the last rolling test start is
    smaller than the rolling step, and to generate follow-up tasks otherwise.
    All collaborators (``TimeAdjuster``, ``OnlineToolR``, ``RollingGen``) are
    mocked, so only the guard itself is exercised.
    """

    # Rolling step reported by the mocked RollingGen.
    ROLLING_STEP = 10
    # (start, end) of the "test" segment stored in the mocked recorder's task.
    TEST_SEGMENT = (pd.Timestamp("2021-01-01"), pd.Timestamp("2021-01-10"))

    @classmethod
    def setUpClass(cls):
        """Create a minimal provider directory and initialise qlib against it."""
        import shutil
        import tempfile
        from pathlib import Path

        import qlib

        # A unique temporary directory avoids polluting the working directory
        # and colliding with leftovers from a previous or concurrent run.
        cls.dummy_dir = Path(tempfile.mkdtemp(prefix="dummy_data_2045_"))
        try:
            for sub_dir in ("calendars", "instruments", "features"):
                (cls.dummy_dir / sub_dir).mkdir(parents=True)
            (cls.dummy_dir / "calendars" / "day.txt").write_text("2020-01-01\n2020-01-02\n", encoding="utf-8")

            qlib.init(provider_uri={"day": str(cls.dummy_dir.absolute())})
        except Exception:
            # tearDownClass is not invoked when setUpClass fails, so clean up here.
            shutil.rmtree(cls.dummy_dir, ignore_errors=True)
            raise

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary provider directory."""
        import shutil

        shutil.rmtree(cls.dummy_dir, ignore_errors=True)

    def _prepare_tasks_with_interval(self, interval, cur_time):
        """Run ``prepare_tasks`` with ``cal_interval`` mocked to return ``interval``.

        Args:
            interval: Value the mocked ``TimeAdjuster.cal_interval`` returns.
            cur_time: Timestamp passed to ``prepare_tasks``.

        Returns:
            A ``(tasks, mock_rg)`` tuple: the list returned by
            ``prepare_tasks`` and the mocked rolling generator, so callers can
            inspect how it was used.
        """
        from qlib.workflow.task.gen import RollingGen

        # spec=RollingGen lets the mock pass isinstance checks against RollingGen.
        mock_rg = MagicMock(spec=RollingGen)
        mock_rg.step = self.ROLLING_STEP
        mock_rg.gen_following_tasks.return_value = [{"task": "new_task"}]

        # Task config shape: task["dataset"]["kwargs"]["segments"]["test"] -> (start, end)
        mock_recorder = MagicMock()
        mock_recorder.load_object.return_value = {
            "dataset": {"kwargs": {"segments": {"test": self.TEST_SEGMENT}}}
        }

        # Patch the names as looked up inside the strategy module, and keep the
        # patches active for both construction and the prepare_tasks call.
        with patch("qlib.workflow.online.strategy.TimeAdjuster") as time_adjuster_cls, patch(
            "qlib.workflow.online.strategy.OnlineToolR"
        ) as online_tool_cls:
            time_adjuster_cls.return_value.cal_interval.return_value = interval
            online_tool_cls.return_value.online_models.return_value = [mock_recorder]

            strategy = RollingStrategy(name_id="test_exp", task_template={}, rolling_gen=mock_rg)
            tasks = strategy.prepare_tasks(cur_time)

        return tasks, mock_rg

    def test_prepare_tasks_interval_check(self):
        """Interval (5) < step (10): no tasks and the generator is never consulted."""
        tasks, mock_rg = self._prepare_tasks_with_interval(5, pd.Timestamp("2021-01-15"))

        self.assertEqual(len(tasks), 0, "Should NOT generate tasks when interval < step")
        mock_rg.gen_following_tasks.assert_not_called()

    def test_prepare_tasks_interval_equals_step(self):
        """Interval == step: the guard is a strict ``<``, so tasks are generated."""
        tasks, _ = self._prepare_tasks_with_interval(self.ROLLING_STEP, pd.Timestamp("2021-01-20"))

        self.assertEqual(len(tasks), 1, "Should generate tasks when interval == step")

    def test_prepare_tasks_normal(self):
        """Interval (15) > step (10): follow-up tasks are generated."""
        tasks, _ = self._prepare_tasks_with_interval(15, pd.Timestamp("2021-01-25"))

        self.assertEqual(len(tasks), 1, "Should generate tasks when interval > step")


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
Loading