diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index ecd58ec2098..afe57a9aad2 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -471,7 +471,8 @@ def _generate(self, **kwargs): setattr(self, k, fill_placeholder(getattr(self, k), placeholder_value)) # if the backtesting time range is not set, it will automatically extract time range from the prediction file - dt_values = pred.index.get_level_values("datetime") + dt_level = pred.index.names.index("datetime") + dt_values = pred.index.get_level_values(dt_level) if self.backtest_config["start_time"] is None: self.backtest_config["start_time"] = dt_values.min() if self.backtest_config["end_time"] is None: @@ -617,7 +618,8 @@ def __init__(self, recorder, pass_num=10, shuffle_init_score=True, **kwargs): def random_init(self): pred_df = self.load("pred.pkl") - all_pred_dates = pred_df.index.get_level_values("datetime") + dt_level = pred_df.index.names.index("datetime") + all_pred_dates = pred_df.index.get_level_values(dt_level) bt_start_date = pd.to_datetime(self.backtest_config.get("start_time")) if bt_start_date is None: first_bt_pred_date = all_pred_dates.min() diff --git a/tests/test_record_temp_datetime.py b/tests/test_record_temp_datetime.py new file mode 100644 index 00000000000..a3d98143c51 --- /dev/null +++ b/tests/test_record_temp_datetime.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Tests for record_temp datetime level fix (Fixes #1909). + +The fix replaces hard-coded get_level_values("datetime") with a +positional lookup via index.names.index("datetime"), so it works +regardless of the position of the datetime level in the MultiIndex. 
import numpy as np
import pandas as pd
import pytest


def _make_multiindex(names, n_dates=3, n_instruments=2):
    """Build a MultiIndex whose levels follow the order given by ``names``.

    Parameters
    ----------
    names : list of str
        Level names, each either ``"datetime"`` or ``"instrument"``. Their
        order decides the order of the levels in the result.
    n_dates : int
        Number of consecutive daily timestamps, starting at 2023-01-01.
    n_instruments : int
        Number of instruments, named ``STOCK_A``, ``STOCK_B``, ...

    Returns
    -------
    pd.MultiIndex
        Index with ``n_dates * n_instruments`` rows. Rows are ordered date by
        date, cycling through every instrument within each date.

    Raises
    ------
    KeyError
        If ``names`` contains anything other than the two supported names.
    """
    dates = pd.date_range("2023-01-01", periods=n_dates, freq="D")
    instruments = [f"STOCK_{chr(65 + i)}" for i in range(n_instruments)]
    arrays = {
        "datetime": np.repeat(dates, n_instruments),
        "instrument": np.tile(instruments, n_dates),
    }
    # Build in the order specified by `names`.
    return pd.MultiIndex.from_arrays([arrays[n] for n in names], names=names)


def _datetime_level_values(index):
    """Return the ``datetime`` level of ``index`` via a positional lookup.

    The level name is first resolved to its position with
    ``index.names.index("datetime")`` and the values are then fetched by
    that position.

    Raises
    ------
    ValueError
        If ``index`` has no level named ``"datetime"``.
    """
    dt_level = index.names.index("datetime")
    return index.get_level_values(dt_level)


def test_datetime_first_level():
    """Standard order: (datetime, instrument)."""
    idx = _make_multiindex(["datetime", "instrument"])
    dt_values = _datetime_level_values(idx)
    assert len(dt_values) == 6
    assert dt_values[0] == pd.Timestamp("2023-01-01")
    # min/max are what the backtest range extraction consumes.
    assert dt_values.min() == pd.Timestamp("2023-01-01")
    assert dt_values.max() == pd.Timestamp("2023-01-03")
    # The positional lookup must agree with the name-based one.
    assert dt_values.equals(idx.get_level_values("datetime"))


def test_datetime_second_level():
    """Reversed order: (instrument, datetime)."""
    idx = _make_multiindex(["instrument", "datetime"])
    # Guard the premise of this test: datetime really is not the first level.
    assert idx.names.index("datetime") == 1
    dt_values = _datetime_level_values(idx)
    assert len(dt_values) == 6
    assert dt_values[0] == pd.Timestamp("2023-01-01")
    assert dt_values.min() == pd.Timestamp("2023-01-01")
    assert dt_values.max() == pd.Timestamp("2023-01-03")
    assert dt_values.equals(idx.get_level_values("datetime"))


def test_missing_datetime_raises():
    """If the datetime level is absent, the lookup should raise ValueError."""
    idx = pd.MultiIndex.from_tuples(
        [("A", 1), ("B", 2)], names=["instrument", "id"]
    )
    with pytest.raises(ValueError):
        _datetime_level_values(idx)


if __name__ == "__main__":
    # Propagate pytest's status so a failing run does not exit with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))