From 1447bb751c3053893ee323c35681a1d46e9592a1 Mon Sep 17 00:00:00 2001
From: zhaoqibin
Date: Wed, 24 Jan 2024 14:33:15 +0800
Subject: [PATCH] feat: add load_weights_from_dict interface
issue: https://github.com/ModelTC/lightllm/issues/277
---
lightllm/common/basemodel/basemodel.py | 13 +++++++++++--
lightllm/models/llama/model.py | 5 +++--
2 files changed, 14 insertions(+), 4 deletions(-)
diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
index 61a04a752..d6e9ff98b 100644
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -75,6 +75,14 @@ def _verify_params(self):
assert self.load_way == "HF", "only support HF format weights"
assert self.config["num_key_value_heads"] % self.world_size_ == 0
return
+
+ def load_weights_from_dict(self, weight_dict):
+ load_hf_weights(
+ "fp16",
+ weight_dir=self.weight_dir_,
+ pre_post_layer=self.pre_post_weight,
+ transformer_layer_list=self.trans_layers_weight,
+ weight_dict=weight_dict)
def _init_weights(self):
self.pre_post_weight = self.pre_and_post_weight_class(self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode)
@@ -88,8 +96,9 @@ def _init_weights(self):
pre_post_layer=self.pre_post_weight,
transformer_layer_list=self.trans_layers_weight,
weight_dict=self.weight_dict)
- self.pre_post_weight.verify_load()
- [weight.verify_load() for weight in self.trans_layers_weight]
+ if not self.weight_dict == {}:
+ self.pre_post_weight.verify_load()
+ [weight.verify_load() for weight in self.trans_layers_weight]
return
def _init_mem_manager(self):
diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py
index 29583a54c..837a0f9a9 100644
--- a/lightllm/models/llama/model.py
+++ b/lightllm/models/llama/model.py
@@ -95,8 +95,9 @@ def _init_weights(self):
weight_dict=self.weight_dict,
prefix='model.layers.',
num_layer=self.config["n_layer"])
- self.pre_post_weight.verify_load()
- [weight.verify_load() for weight in self.trans_layers_weight]
+ if not self.weight_dict == {}:
+ self.pre_post_weight.verify_load()
+ [weight.verify_load() for weight in self.trans_layers_weight]
return
def _init_to_get_rotary(self, default_base=10000):