diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index 9daba911533..e6bde47c745 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -116,6 +116,7 @@ def create_account_instance( benchmark: Optional[str], account: Union[float, int, dict], pos_type: str = "Position", + freq: str = "day", ) -> Account: """ # TODO: is very strange pass benchmark_config in the account (maybe for report) @@ -161,6 +162,7 @@ def create_account_instance( return Account( init_cash=init_cash, position_dict=position_dict, + freq=freq, pos_type=pos_type, benchmark_config=( {} @@ -183,6 +185,7 @@ def get_strategy_executor( account: Union[float, int, dict] = 1e9, exchange_kwargs: dict = {}, pos_type: str = "Position", + freq: str = "day", ) -> Tuple[BaseStrategy, BaseExecutor]: # NOTE: # - for avoiding recursive import @@ -196,6 +199,7 @@ def get_strategy_executor( benchmark=benchmark, account=account, pos_type=pos_type, + freq=freq, ) exchange_kwargs = copy.copy(exchange_kwargs) @@ -223,6 +227,7 @@ def backtest( account: Union[float, int, dict] = 1e9, exchange_kwargs: dict = {}, pos_type: str = "Position", + freq: str = "day", ) -> Tuple[PORT_METRIC, INDICATOR_METRIC]: """initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and executor in the nested decision execution @@ -272,6 +277,7 @@ def backtest( account, exchange_kwargs, pos_type=pos_type, + freq=freq, ) return backtest_loop(start_time, end_time, trade_strategy, trade_executor) @@ -286,6 +292,7 @@ def collect_data( exchange_kwargs: dict = {}, pos_type: str = "Position", return_value: dict | None = None, + freq: str = "day", ) -> Generator[object, None, None]: """initialize the strategy and executor, then collect the trade decision data for rl training @@ -305,6 +312,7 @@ def collect_data( account, exchange_kwargs, pos_type=pos_type, + freq=freq, ) yield from collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value=return_value) diff --git a/tests/test_backtest_account_freq.py b/tests/test_backtest_account_freq.py new file mode 100644 index 00000000000..1075fbfa090 --- /dev/null +++ b/tests/test_backtest_account_freq.py @@ -0,0 +1,72 @@ +"""Test for Issue #1846: Backtest should thread freq to Account, not hardcode 'day'.""" +import pytest +from unittest.mock import patch +from qlib.backtest import create_account_instance +from qlib.backtest.account import Account + + +class TestBacktestAccountFreq: + """Verify that freq parameter flows from backtest config to Account.""" + + def test_account_direct_freq_day(self): + """Account class should default to freq='day'.""" + account = Account(init_cash=1e6, freq="day", port_metr_enabled=False) + assert account.freq == "day" + + def test_account_direct_freq_custom(self): + """Account class should store custom freq.""" + account = Account(init_cash=1e6, freq="30min", port_metr_enabled=False) + assert account.freq == "30min" + + def test_account_direct_freq_60min(self): + """Verify 60min freq is threaded correctly to Account.""" + account = Account(init_cash=1e6, freq="60min", port_metr_enabled=False) + assert account.freq == "60min" + + @patch("qlib.backtest.Account") + def test_create_account_instance_passes_freq(self, mock_account_cls): + """create_account_instance should forward freq to Account constructor.""" + mock_account_cls.return_value = mock_account_cls + create_account_instance( + start_time="2020-01-01", + end_time="2020-12-31", + benchmark=None, + account=1e6, + freq="60min", + ) + # Verify Account was called with freq="60min" + mock_account_cls.assert_called_once() + call_kwargs = mock_account_cls.call_args + assert call_kwargs.kwargs.get("freq") == "60min" or \ + (len(call_kwargs.args) > 2 and call_kwargs.args[2] == "60min") + + @patch("qlib.backtest.Account") + def test_create_account_instance_default_freq_is_day(self, mock_account_cls): + """create_account_instance without freq should default to 'day'.""" + mock_account_cls.return_value = mock_account_cls + create_account_instance( + start_time="2020-01-01", + end_time="2020-12-31", + benchmark=None, + account=1e6, + ) + call_kwargs = mock_account_cls.call_args + assert call_kwargs.kwargs.get("freq") == "day" + + @patch("qlib.backtest.Account") + def test_create_account_freq_not_hardcoded(self, mock_account_cls): + """Ensure freq='1min' doesn't silently become 'day'.""" + mock_account_cls.return_value = mock_account_cls + create_account_instance( + start_time="2020-01-01", + end_time="2020-12-31", + benchmark=None, + account=1e6, + freq="1min", + ) + call_kwargs = mock_account_cls.call_args + assert call_kwargs.kwargs.get("freq") == "1min" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])