From f98bd8074568ebd6e74d164fb09df6ecf99ab986 Mon Sep 17 00:00:00 2001 From: Eden Rochman Date: Fri, 27 Mar 2026 14:01:00 +0000 Subject: [PATCH] Fix record_temp to look up datetime level by position PortAnaRecord._generate and MultiPassPortAnaRecord.random_init called get_level_values("datetime") on the prediction index, which relies on pandas resolving the level by name. Now we first resolve the level's position with index.names.index("datetime") and pass that integer position to get_level_values(). The values returned are the same wherever the datetime level sits in the MultiIndex, since pandas resolves level names independently of level order. Behaviour change: a missing datetime level now raises ValueError (from names.index) instead of the KeyError raised by get_level_values("datetime"). Fixes #1909 --- qlib/workflow/record_temp.py | 6 ++-- tests/test_record_temp_datetime.py | 56 ++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 2 deletions(-) create mode 100644 tests/test_record_temp_datetime.py diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index ecd58ec2098..afe57a9aad2 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -471,7 +471,8 @@ def _generate(self, **kwargs): setattr(self, k, fill_placeholder(getattr(self, k), placeholder_value)) # if the backtesting time range is not set, it will automatically extract time range from the prediction file - dt_values = pred.index.get_level_values("datetime") + dt_level = pred.index.names.index("datetime") + dt_values = pred.index.get_level_values(dt_level) if self.backtest_config["start_time"] is None: self.backtest_config["start_time"] = dt_values.min() if self.backtest_config["end_time"] is None: @@ -617,7 +618,8 @@ def __init__(self, recorder, pass_num=10, shuffle_init_score=True, **kwargs): def random_init(self): pred_df = self.load("pred.pkl") - all_pred_dates = pred_df.index.get_level_values("datetime") + dt_level = pred_df.index.names.index("datetime") + all_pred_dates = pred_df.index.get_level_values(dt_level) bt_start_date = pd.to_datetime(self.backtest_config.get("start_time")) if bt_start_date is None: first_bt_pred_date = all_pred_dates.min() diff --git 
a/tests/test_record_temp_datetime.py b/tests/test_record_temp_datetime.py new file mode 100644 index 00000000000..a3d98143c51 --- /dev/null +++ b/tests/test_record_temp_datetime.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Tests for record_temp datetime level fix (Fixes #1909). + +The fix replaces hard-coded get_level_values("datetime") with a +positional lookup via index.names.index("datetime"), so it works +regardless of the position of the datetime level in the MultiIndex. +""" + +import numpy as np +import pandas as pd +import pytest + + +def _make_multiindex(names, n_dates=3, n_instruments=2): + """Helper to build a MultiIndex with the given level names.""" + dates = pd.date_range("2023-01-01", periods=n_dates, freq="D") + instruments = [f"STOCK_{chr(65 + i)}" for i in range(n_instruments)] + arrays = { + "datetime": np.repeat(dates, n_instruments), + "instrument": np.tile(instruments, n_dates), + } + # Build in the order specified by `names` + return pd.MultiIndex.from_arrays([arrays[n] for n in names], names=names) + + +def test_datetime_first_level(): + """Standard order: (datetime, instrument).""" + idx = _make_multiindex(["datetime", "instrument"]) + dt_level = idx.names.index("datetime") + dt_values = idx.get_level_values(dt_level) + assert len(dt_values) == 6 + assert dt_values[0] == pd.Timestamp("2023-01-01") + + +def test_datetime_second_level(): + """Reversed order: (instrument, datetime).""" + idx = _make_multiindex(["instrument", "datetime"]) + dt_level = idx.names.index("datetime") + dt_values = idx.get_level_values(dt_level) + assert len(dt_values) == 6 + assert dt_values[0] == pd.Timestamp("2023-01-01") + + +def test_missing_datetime_raises(): + """If datetime level is absent, index() should raise ValueError.""" + idx = pd.MultiIndex.from_tuples( + [("A", 1), ("B", 2)], names=["instrument", "id"] + ) + with pytest.raises(ValueError): + idx.names.index("datetime") + + +if __name__ == 
"__main__": + pytest.main([__file__, "-v"])