Commit eb72c6a

[Feature] Add LLaST
1 parent 081c8ca commit eb72c6a

File tree

1 file changed: +294 -0 lines changed

xtuner/model/llast.py

Lines changed: 294 additions & 0 deletions
@@ -0,0 +1,294 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict

import torch
import torch.nn as nn
from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel
from peft import get_peft_model, prepare_model_for_kbit_training
from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN

from xtuner.dataset.llast import prepare_inputs_labels_for_llast
from xtuner.registry import BUILDER
from .modules import dispatch_modules
from .utils import (LoadWoInit, find_all_linear_names,
                    get_peft_model_state_dict, guess_load_checkpoint,
                    make_inputs_require_grad, traverse_dict)


class AudioProjectorConfig(PretrainedConfig):
    model_type = 'projector'
    _auto_class = 'AutoConfig'

    def __init__(
        self,
        audio_hidden_size=4096,
        llm_hidden_size=4096,
        depth=2,
        hidden_act='gelu',
        bias=True,
        **kwargs,
    ):
        self.audio_hidden_size = audio_hidden_size
        self.llm_hidden_size = llm_hidden_size
        self.depth = depth
        self.hidden_act = hidden_act
        self.bias = bias
        super().__init__(**kwargs)


class AudioEncoder(PreTrainedModel):
    """MLP projector that maps speech-encoder features into the LLM
    embedding space."""
    _auto_class = 'AutoModel'
    config_class = AudioProjectorConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True

    def __init__(self, config: AudioProjectorConfig) -> None:
        super().__init__(config)
        self.gradient_checkpointing = False
        modules = [
            nn.Linear(
                config.audio_hidden_size,
                config.llm_hidden_size,
                bias=config.bias)
        ]
        for _ in range(1, config.depth):
            modules.append(ACT2FN[config.hidden_act])
            modules.append(
                nn.Linear(
                    config.llm_hidden_size,
                    config.llm_hidden_size,
                    bias=config.bias))
        self.model = nn.Sequential(*modules)

    def enable_input_require_grads(self):

        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        self.model.register_forward_hook(make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, AudioEncoder):
            module.gradient_checkpointing = value

    def forward(self, x):
        if self.gradient_checkpointing and self.training:
            layer_outputs = torch.utils.checkpoint.checkpoint(self.model, x)
        else:
            layer_outputs = self.model(x)
        return layer_outputs

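# A quick shape sketch (sizes below are hypothetical, for illustration only):
# the projector maps speech-encoder features to the LLM embedding size.
#
#   config = AudioProjectorConfig(audio_hidden_size=1280, llm_hidden_size=4096)
#   projector = AudioEncoder(config)
#   feats = torch.randn(2, 150, 1280)  # (batch, frames, audio_hidden_size)
#   out = projector(feats)             # -> (2, 150, 4096)
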
class LLaSTModel(BaseModel):
    """Implementation of LLaST.

    Acknowledgement: LLaVA: Visual Instruction Tuning
    (https://llava-vl.github.io/)
    """

    def __init__(
        self,
        llm,
        speech_encoder,
        freeze_llm=False,
        freeze_speech_encoder=False,
        speech_select_layer=-1,
        pretrained_pth=None,
        projector_depth=2,
        llm_lora=None,
        speech_encoder_lora=None,
        use_activation_checkpointing=True,
    ):
        super().__init__()
        self.freeze_llm = freeze_llm
        self.freeze_speech_encoder = freeze_speech_encoder
        with LoadWoInit():
            self.llm = self._build_from_cfg_or_module(llm)
            self.speech_encoder = self._build_from_cfg_or_module(
                speech_encoder)

        self.llm.config.use_cache = False
        dispatch_modules(self.llm)

        projector_config = AudioProjectorConfig(
            audio_hidden_size=self.speech_encoder.config.hidden_size,
            llm_hidden_size=self.llm.config.hidden_size,
            depth=projector_depth)
        self.projector = AudioEncoder(projector_config).to(
            self.speech_encoder.dtype)

        if self.freeze_llm:
            self.llm.requires_grad_(False)
        if self.freeze_speech_encoder:
            self.speech_encoder.requires_grad_(False)

        if use_activation_checkpointing:
            # For backward compatibility
            if hasattr(self.llm, 'enable_input_require_grads'):
                self.llm.enable_input_require_grads()
            else:
                self.llm.get_input_embeddings().register_forward_hook(
                    make_inputs_require_grad)
            if hasattr(self.speech_encoder, 'enable_input_require_grads'):
                self.speech_encoder.enable_input_require_grads()
            else:
                self.speech_encoder.get_input_embeddings(
                ).register_forward_hook(make_inputs_require_grad)
            self.projector.enable_input_require_grads()

            # enable gradient (activation) checkpointing for memory efficiency
            self.gradient_checkpointing_enable()

        self.use_llm_lora = llm_lora is not None
        self.use_speech_encoder_lora = speech_encoder_lora is not None

        if self.use_llm_lora:
            self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
        if self.use_speech_encoder_lora:
            self._prepare_speech_encoder_for_lora(
                speech_encoder_lora, use_activation_checkpointing)

        if pretrained_pth is not None:
            pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
            load_result = self.load_state_dict(
                pretrained_state_dict, strict=False)
            assert len(load_result.unexpected_keys) == 0, \
                load_result.unexpected_keys
            print(f'Load pretrained weight from {pretrained_pth}')

        self.speech_select_layer = speech_select_layer

        self._is_init = True

    def _parse_lora_config(self, lora_config):
        if isinstance(lora_config, (dict, Config, ConfigDict)):
            lora_config = BUILDER.build(lora_config)
        return lora_config

    def gradient_checkpointing_enable(self):
        self.activation_checkpointing_enable()

    def activation_checkpointing_enable(self):
        self.llm.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={'use_reentrant': False})
        self.speech_encoder.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={'use_reentrant': False})
        self.projector.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={'use_reentrant': False})

    def gradient_checkpointing_disable(self):
        self.activation_checkpointing_disable()

    def activation_checkpointing_disable(self):
        self.llm.gradient_checkpointing_disable()
        self.speech_encoder.gradient_checkpointing_disable()
        self.projector.gradient_checkpointing_disable()

    def init_weights(self):
        pass

    def state_dict(self, *args, **kwargs):
        state_dict = super().state_dict(*args, **kwargs)
        to_return = OrderedDict()
        # Step 1. speech_encoder
        if self.use_speech_encoder_lora:
            to_return.update(
                get_peft_model_state_dict(
                    self.speech_encoder, state_dict=state_dict))
        elif not self.freeze_speech_encoder:
            to_return.update({
                k: v
                for k, v in state_dict.items() if 'speech_encoder.' in k
            })
        # Step 2. LLM
        if self.use_llm_lora:
            to_return.update(
                get_peft_model_state_dict(self.llm, state_dict=state_dict))
        elif not self.freeze_llm:
            to_return.update(
                {k: v
                 for k, v in state_dict.items() if 'llm.' in k})
        # Step 3. Projector
        to_return.update(
            {k: v
             for k, v in state_dict.items() if 'projector.' in k})
        return to_return
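
    # A minimal sketch of the effect of the filtered `state_dict` (assuming a
    # model built with freeze_llm=True, freeze_speech_encoder=True and no
    # LoRA; the check below is illustrative, not part of the pipeline):
    #
    #   keys = model.state_dict().keys()
    #   assert all(k.startswith('projector.') for k in keys)
    #
    # Only trainable parts are kept, so checkpoints stay far smaller than the
    # full model.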

    def _build_from_cfg_or_module(self, cfg_or_mod):
        if isinstance(cfg_or_mod, nn.Module):
            return cfg_or_mod
        elif isinstance(cfg_or_mod, dict):
            traverse_dict(cfg_or_mod)
            return BUILDER.build(cfg_or_mod)
        else:
            raise NotImplementedError

    def _prepare_llm_for_lora(self,
                              lora_config,
                              use_activation_checkpointing=True):
        lora_config = self._parse_lora_config(lora_config)
        self.llm = prepare_model_for_kbit_training(
            self.llm, use_activation_checkpointing)
        if lora_config.target_modules is None:
            modules = find_all_linear_names(self.llm)
            lora_config.target_modules = modules
        self.llm = get_peft_model(self.llm, lora_config)

    def _prepare_speech_encoder_for_lora(self,
                                         lora_config,
                                         use_activation_checkpointing=True):
        lora_config = self._parse_lora_config(lora_config)
        if lora_config.target_modules is None:
            modules = find_all_linear_names(self.speech_encoder)
            lora_config.target_modules = modules
        self.speech_encoder = get_peft_model(self.speech_encoder, lora_config)
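
    # A minimal LoRA config sketch (hyper-parameters are illustrative; xtuner
    # configs typically pass a dict like this as `llm_lora`, which
    # `_parse_lora_config` builds through BUILDER into a peft.LoraConfig):
    #
    #   from peft import LoraConfig
    #   llm_lora = dict(
    #       type=LoraConfig,
    #       r=64,
    #       lora_alpha=16,
    #       lora_dropout=0.05,
    #       bias='none',
    #       task_type='CAUSAL_LM')
    #
    # target_modules is left unset here, so `_prepare_llm_for_lora` fills it
    # in with find_all_linear_names(self.llm).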

    def forward(self, data, data_samples=None, mode='loss'):
        if 'audio_tokens' in data:
            # Match the audio features to the encoder's weight dtype.
            data['audio_tokens'] = data['audio_tokens'].to(
                self.speech_encoder.encoder.conv1.weight.dtype)
            batch_size = data['audio_tokens'].shape[0]
            # One decoder start token per sample: shape (batch_size, 1).
            decoder_input_ids = torch.full(
                (batch_size, 1),
                self.speech_encoder.config.decoder_start_token_id,
                dtype=torch.long)

            audio_outputs = self.speech_encoder(
                data['audio_tokens'],
                decoder_input_ids=decoder_input_ids.to(
                    data['audio_tokens'].device),
                output_hidden_states=True).encoder_last_hidden_state

            # Drop padded frames beyond the longest sample in the batch.
            audio_outputs = audio_outputs[:, :max(data['audio_lens']), :]
            audio_tokens = self.projector(audio_outputs)
            data['audio_tokens'] = audio_tokens
            data = prepare_inputs_labels_for_llast(llm=self.llm, **data)

        if mode == 'loss':
            return self.compute_loss(data, data_samples)
        elif mode == 'predict':
            return self.predict(data, data_samples)
        elif mode == 'tensor':
            return self._forward(data, data_samples)
        else:
            raise NotImplementedError

    def _forward(self, data, data_samples=None):
        outputs = self.llm(**data)
        return outputs

    def predict(self, data, data_samples=None):
        outputs = self.llm(**data)
        logits_dict = [{'logits': logits} for logits in outputs.logits]
        return logits_dict

    def compute_loss(self, data, data_samples=None):
        outputs = self.llm(**data)
        loss_dict = {'loss': outputs.loss}
        return loss_dict

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.llm, name)
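

# A minimal end-to-end sketch (checkpoint paths and settings are illustrative,
# not a verbatim recipe; WhisperModel matches the `.encoder.conv1` and
# `decoder_start_token_id` usage in `forward` above):
#
#   from transformers import AutoModelForCausalLM, WhisperModel
#
#   model = LLaSTModel(
#       llm=dict(
#           type=AutoModelForCausalLM.from_pretrained,
#           pretrained_model_name_or_path='<llm-checkpoint>'),
#       speech_encoder=dict(
#           type=WhisperModel.from_pretrained,
#           pretrained_model_name_or_path='<whisper-checkpoint>'),
#       freeze_llm=True,
#       freeze_speech_encoder=True,
#       projector_depth=2)
#
#   # `data` provides 'audio_tokens' (audio features), 'audio_lens', and the
#   # text fields consumed by prepare_inputs_labels_for_llast.
#   loss_dict = model(data, mode='loss')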
