From ecbab5755acd4484ddd5e487785448edf3c2c8e2 Mon Sep 17 00:00:00 2001
From: warren618
Date: Mon, 23 Mar 2026 02:16:48 +0800
Subject: [PATCH] fix(model): add pretrain toggle to HIST model

Added a pretrain parameter (default True for backward compatibility) to
control whether pretrained base_model weights are loaded during fit().
When pretrain=False, the model trains from scratch.

Fixes microsoft/qlib#2074
---
 qlib/contrib/model/pytorch_hist.py | 39 ++++++++++++++++--------------
 1 file changed, 21 insertions(+), 18 deletions(-)

diff --git a/qlib/contrib/model/pytorch_hist.py b/qlib/contrib/model/pytorch_hist.py
index 779cde9c859..de2a4cdc7ad 100644
--- a/qlib/contrib/model/pytorch_hist.py
+++ b/qlib/contrib/model/pytorch_hist.py
@@ -59,6 +59,7 @@ def __init__(
         optimizer="adam",
         GPU=0,
         seed=None,
+        pretrain=True,
         **kwargs,
     ):
         # Set logger.
@@ -82,6 +83,7 @@ def __init__(
         self.stock_index = stock_index
         self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
         self.seed = seed
+        self.pretrain = pretrain

         self.logger.info(
             "HIST parameters setting:"
@@ -277,24 +279,25 @@ def fit(
         evals_result["valid"] = []

         # load pretrained base_model
-        if self.base_model == "LSTM":
-            pretrained_model = LSTMModel()
-        elif self.base_model == "GRU":
-            pretrained_model = GRUModel()
-        else:
-            raise ValueError("unknown base model name `%s`" % self.base_model)
-
-        if self.model_path is not None:
-            self.logger.info("Loading pretrained model...")
-            pretrained_model.load_state_dict(torch.load(self.model_path))
-
-        model_dict = self.HIST_model.state_dict()
-        pretrained_dict = {
-            k: v for k, v in pretrained_model.state_dict().items() if k in model_dict  # pylint: disable=E1135
-        }
-        model_dict.update(pretrained_dict)
-        self.HIST_model.load_state_dict(model_dict)
-        self.logger.info("Loading pretrained model Done...")
+        if self.pretrain:
+            if self.base_model == "LSTM":
+                pretrained_model = LSTMModel()
+            elif self.base_model == "GRU":
+                pretrained_model = GRUModel()
+            else:
+                raise ValueError("unknown base model name `%s`" % self.base_model)
+
+            if self.model_path is not None:
+                self.logger.info("Loading pretrained model...")
+                pretrained_model.load_state_dict(torch.load(self.model_path))
+
+            model_dict = self.HIST_model.state_dict()
+            pretrained_dict = {
+                k: v for k, v in pretrained_model.state_dict().items() if k in model_dict  # pylint: disable=E1135
+            }
+            model_dict.update(pretrained_dict)
+            self.HIST_model.load_state_dict(model_dict)
+            self.logger.info("Loading pretrained model Done...")

         # train
         self.logger.info("training...")