Description
In `contrib\model\pytorch_hist.py`:

Currently, the HIST model implementation in qlib lacks an explicit parameter to control the pre-training weight-loading logic. This makes it difficult for users to skip the weight-loading phase when a model path is specified, e.g. when they wish to train the architecture from scratch.
Proposed changes:
- `__init__` method: add a `pretrain` parameter (defaulting to `True` for backward compatibility), as sketched below.
- Logic update: wrap the weight-loading and `state_dict` update logic inside `if self.pretrain:`.
```python
# In __init__ (existing parameters elided with "..."):
def __init__(self, ..., pretrain=True, **kwargs):
    ...
    self.pretrain = pretrain  # store the flag for later use

# In the weight-loading logic, gate the existing block on the new flag:
if self.pretrain:
    if self.model_path is not None:
        self.logger.info("Loading pretrained model...")
        pretrained_model.load_state_dict(torch.load(self.model_path))
        # ... (state_dict update logic)
        self.logger.info("Loading pretrained model Done...")
```