enable llava on torchchat #1183
base: main
@@ -7,10 +7,13 @@
import os
import re
import sys
import glob
from pathlib import Path
from typing import Optional
from typing import Any, Dict, Optional

import torch
import safetensors.torch
import shutil

from torchchat.model import TransformerArgs


@@ -21,9 +24,176 @@

from torchchat.model import ModelArgs


def remap_llava_checkpoint(llava_ckpt):
Reviewer: Was this written in-house?
Author: I'm not quite following your question.
    def _translate_state_dict_for_vision_model(hf_state_dict) -> Dict[str, Any]:
        translated_state_dict = {}
        hf_weight_prefix = "vision_model."
        name_mapping = {
            f"{hf_weight_prefix}embeddings.class_embedding": "encoder.cls_token_embedding.weight",
            f"{hf_weight_prefix}embeddings.position_embedding.weight": "encoder.token_pos_embedding.positional_embedding",
            f"{hf_weight_prefix}embeddings.patch_embedding.weight": "encoder.conv.weight",
            f"{hf_weight_prefix}pre_layrnorm.weight": "encoder.ln_pre.weight",
            f"{hf_weight_prefix}pre_layrnorm.bias": "encoder.ln_pre.bias",
            f"{hf_weight_prefix}post_layernorm.weight": "encoder.ln_post.weight",
            f"{hf_weight_prefix}post_layernorm.bias": "encoder.ln_post.bias",
        }
        patterns = [
            (
                rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.self_attn\.(k|q|v)_proj\.(weight|bias)",
                lambda match: f"encoder.layers.{match.group(1)}.attn.{match.group(2)}_proj.{match.group(3)}",
            ),
            (
                rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.(weight|bias)",
                lambda match: f"encoder.layers.{match.group(1)}.attn.output_proj.{match.group(2)}",
            ),
            (
                rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.mlp\.fc(1|2)\.(weight|bias)",
                lambda match: f"encoder.layers.{match.group(1)}.mlp.w{match.group(2)}.{match.group(3)}",
            ),
            (
                rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.layer_norm1\.(weight|bias)",
                lambda match: f"encoder.layers.{match.group(1)}.sa_norm.{match.group(2)}",
            ),
            (
                rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.layer_norm2\.(weight|bias)",
                lambda match: f"encoder.layers.{match.group(1)}.mlp_norm.{match.group(2)}",
            ),
        ]
        for pattern, replacement in patterns:
            for key in list(hf_state_dict.keys()):
                if re.match(pattern, key):
                    new_key = re.sub(pattern, replacement, key)
                    name_mapping[key] = new_key
        temp_state_dict = {}
        for k, v in hf_state_dict.items():
            new_k = name_mapping.get(k, k)
            if "in_proj_weight" in new_k or "in_proj_bias" in new_k:
                if new_k not in temp_state_dict:
                    temp_state_dict[new_k] = {"q": None, "k": None, "v": None}
                if "q_proj" in k:
                    temp_state_dict[new_k]["q"] = v
                elif "k_proj" in k:
                    temp_state_dict[new_k]["k"] = v
                elif "v_proj" in k:
                    temp_state_dict[new_k]["v"] = v
            else:
                temp_state_dict[new_k] = v
        for k, v in temp_state_dict.items():
            if isinstance(v, dict):
                translated_state_dict[k] = torch.cat([v["q"], v["k"], v["v"]], dim=0)
            else:
                translated_state_dict[k] = v
        return translated_state_dict
    def _translate_state_dict_for_text_model(hf_state_dict) -> Dict[str, Any]:
        key_map = {
            r"model.layers.([0-9]+).self_attn.q_proj.": r"decoder.layers.\1.attention.wq.",
            r"model.layers.([0-9]+).self_attn.k_proj.": r"decoder.layers.\1.attention.wk.",
            r"model.layers.([0-9]+).self_attn.v_proj.": r"decoder.layers.\1.attention.wv.",
            r"model.layers.([0-9]+).self_attn.o_proj.": r"decoder.layers.\1.attention.wo.",
            r"model.layers.([0-9]+).input_layernorm.": r"decoder.layers.\1.attention_norm.",
            r"model.layers.([0-9]+).mlp.gate_proj.": r"decoder.layers.\1.feed_forward.w1.",
            r"model.layers.([0-9]+).mlp.down_proj.": r"decoder.layers.\1.feed_forward.w2.",
            r"model.layers.([0-9]+).mlp.up_proj.": r"decoder.layers.\1.feed_forward.w3.",
            r"model.layers.([0-9]+).post_attention_layernorm.": r"decoder.layers.\1.ffn_norm.",
            r"model.norm.": r"decoder.norm.",
            # r"model.embed_tokens.": r"tok_embeddings.",  # load separately
            r"lm_head.": r"decoder.output.",
        }
        new_state_dict = {}
        def get_new_key(old_key: str) -> str:
            for old_pattern, replacement in key_map.items():
                if (new_key := re.sub(old_pattern, replacement, old_key)) != old_key:
                    return new_key
            return old_key
        for old_key in hf_state_dict.keys():
            new_key = get_new_key(old_key)
            new_state_dict[new_key] = hf_state_dict[old_key]
        return new_state_dict

    def _translate_state_dict_for_mm_projector_model(hf_state_dict) -> Dict[str, Any]:
        new_state_dict = {}
        for old_key in hf_state_dict.keys():
            new_key = "mm_projector." + old_key
            new_state_dict[new_key] = hf_state_dict[old_key]
        return new_state_dict

    def split_checkpoint(llava_ckpt):
        language_model_ckpt = {}
        multi_modal_ckpt = {}
        vision_tower_ckpt = {}
        for key, value in llava_ckpt.items():
            if key.startswith("language_model"):
                language_model_ckpt[key[len("language_model") + 1:]] = value
            elif key.startswith("multi_modal_projector"):
                multi_modal_ckpt[key[len("multi_modal_projector") + 1:]] = value
            elif key.startswith("vision_tower"):
                vision_tower_ckpt[key[len("vision_tower") + 1:]] = value
        return language_model_ckpt, multi_modal_ckpt, vision_tower_ckpt

    language_model_ckpt, multi_modal_ckpt, vision_tower_ckpt = split_checkpoint(llava_ckpt)
    remapped_state_dict = {
        "tok_embeddings.weight": language_model_ckpt.pop("model.embed_tokens.weight"),
    }
    remapped_state_dict.update(_translate_state_dict_for_text_model(language_model_ckpt))
    remapped_state_dict.update(_translate_state_dict_for_vision_model(vision_tower_ckpt))
    remapped_state_dict.update(_translate_state_dict_for_mm_projector_model(multi_modal_ckpt))
    return remapped_state_dict
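To make the remapping concrete, a few example key translations this function produces are sketched below; the HF-side names follow the llava-1.5-7b-hf checkpoint layout and are illustrative rather than exhaustive.

# Illustrative key translations (assuming the llava-1.5-7b-hf layout):
#   "language_model.model.embed_tokens.weight"               -> "tok_embeddings.weight"
#   "language_model.model.layers.0.self_attn.q_proj.weight"  -> "decoder.layers.0.attention.wq.weight"
#   "vision_tower.vision_model.pre_layrnorm.weight"          -> "encoder.ln_pre.weight"
#   "multi_modal_projector.linear_1.weight"                  -> "mm_projector.linear_1.weight"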


@torch.inference_mode
def convert_llava_checkpoint(
    *,
    model_dir: Optional[Path] = None,
) -> None:

    """
    Process safetensor files from a specific directory structure and save the remapped model.

    Args:
        model_dir (str): Base directory containing the model subdirectories.
    """

    def _get_llava_files_with_pattern(pattern):
        pattern = os.path.join(model_dir, f"models--llava-hf--llava-1.5-7b-hf/snapshots/*/{pattern}")
        return glob.glob(pattern)

    # get all safetensor files in the model directory
    safetensor_files = _get_llava_files_with_pattern("*.safetensors")

    if not safetensor_files:
        raise ValueError("No safetensor files found.")

    merged_weights = {}

    # Merge safetensor files into a whole
    for file in safetensor_files:
        # Load weights from the current file
        part_weights = safetensors.torch.load_file(file)

        # Iterate over each weight in the current file
        for key, value in part_weights.items():
            if key in merged_weights:
                # If the key already exists, concatenate tensors
                merged_weights[key] = torch.cat((merged_weights[key], value), dim=0)
            else:
                # If the key does not exist, add it to the dictionary
                merged_weights[key] = value

    # Remap the checkpoint and save it as pth
    remapped_weights = remap_llava_checkpoint(merged_weights)
    model_path = model_dir / "model.pth"
    torch.save(remapped_weights, model_path)

    # copy tokenizer
    tokenizer_files = _get_llava_files_with_pattern("tokenizer.model")
    assert len(tokenizer_files) == 1, "Should get only one tokenizer file, but got {}".format(tokenizer_files)

    tokenizer_path = model_dir / "tokenizer.model"
    shutil.copy(tokenizer_files[0], tokenizer_path)
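For orientation, a minimal invocation sketch follows; the directory below is hypothetical, and the only assumption is that it contains the models--llava-hf--llava-1.5-7b-hf/snapshots/<revision>/ layout that _get_llava_files_with_pattern globs for.

from pathlib import Path

# Hypothetical download root; after the call, model.pth and tokenizer.model
# are written directly under model_dir.
convert_llava_checkpoint(model_dir=Path("checkpoints/llava-1.5-7b-hf"))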
@torch.inference_mode()
def convert_hf_checkpoint(
def convert_text_only_hf_checkpoint(
    *,
    model_dir: Optional[Path] = None,
    model_name: Optional[str] = None,

@@ -132,6 +302,19 @@ def permute(w, n_heads):
        os.remove(file)


@torch.inference_mode()
def convert_hf_checkpoint(
    *,
    model_dir: Optional[Path] = None,
    model_name: Optional[str] = None,
    remove_bin_files: bool = False,
):
    if "llava" in model_name:
        convert_llava_checkpoint(model_dir=model_dir)
    else:
        convert_text_only_hf_checkpoint(model_dir=model_dir, model_name=model_name, remove_bin_files=remove_bin_files)


if __name__ == "__main__":
    import argparse
@@ -36,6 +36,7 @@
from torchchat.model import Model, ModelType
from torchchat.utils.build_utils import device_sync, set_precision
from torchchat.utils.device_info import get_device_info
from torchchat.utils.preprocessors import llava_image_preprocess

# torchtune model definition dependencies
from torchtune.data import Message

@@ -357,8 +358,13 @@ def prefill(

    if batch is not None:
        # TODO: Verify sequential prefill works with multimodal models
        logits = model(**batch)[:, -1]
        return tune_sample(logits, 0, 500)
        logits = model(**batch)
        if model.config.model_type == ModelType.Llava:
            context_len, logits = logits[0], logits[1][:, -1]
            return context_len, tune_sample(logits, 0, 500)
        else:
            logits = logits[:, -1]
            return tune_sample(logits, 0, 500)
    elif sequential_prefill:
        for i in range(width):
            x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
@@ -622,6 +628,13 @@ def generate(
            sequential_prefill=sequential_prefill,
            **sampling_kwargs,
        )

        # For llava with image input, we need to extract next pos id from prefill result
        if batch and self.model.config.model_type == ModelType.Llava:
            context_len, next_token = next_token
        else:
            context_len, next_token = T, next_token
        if is_speculative:
            self.prefill(
                draft_model,

@@ -636,7 +649,7 @@ def generate(
        # max_new_tokens <= 2 means we are effectively not calling decode_n_tokens().
        callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2)

        input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
        input_pos = torch.tensor([start_pos + context_len], device=device, dtype=torch.int)
        accept_counts = [0] * (
            speculate_k + 1
        )  # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
@@ -726,27 +739,56 @@ def chat(

        if generator_args.image_prompts is not None:
            print("Image prompts", generator_args.image_prompts)

            # Support for just the first image prompt for now
            images = [Image.open(generator_args.image_prompts[0])]
            messages = [
                Message(
                    role="user",
                    content=[
                        {"type": "image", "content": images[0]},
                        {"type": "text", "content": generator_args.prompt},
                    ],
                    eot=True,
                ),
                Message(role="assistant", content=""),
            ]

            transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
            data = transform({"messages": messages}, inference=True)
            batch = padded_collate([data], self.builder_args.device)
            batch.pop("mask")
            encoded = batch["tokens"]
            assert len(images) == 1, "Only one image prompt is supported for now"

            # TODO: updated encoded variable for multi-modality models to include image tokens.
Reviewer: Can you explain this to me?
            if self.model.config.model_type == ModelType.Flamingo:
                messages = [
                    Message(
                        role="user",
                        content=[
                            {"type": "image", "content": images[0]},
                            {"type": "text", "content": generator_args.prompt},
                        ],
                        eot=True,
                    ),
                    Message(role="assistant", content=""),
                ]

                transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
                data = transform({"messages": messages}, inference=True)
                batch = padded_collate([data], self.builder_args.device)
                batch.pop("mask")
                encoded = batch["tokens"]
            elif self.model.config.model_type == ModelType.Llava:
                # TODO: double check the tokenizer.
                def find_subtensor(tensor, target):
Reviewer: typehints
                    target_len = len(target)
                    for i in range(len(tensor) - target_len + 1):
                        if torch.all(tensor[i:i+target_len] == target):
                            return i
                    return -1
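A type-annotated version of this helper, along the lines the reviewer asks for, might look like the sketch below (illustrative only, not part of the diff):

def find_subtensor(tensor: torch.Tensor, target: torch.Tensor) -> int:
    # Return the start index of the first occurrence of `target` in `tensor`, or -1 if absent.
    target_len = len(target)
    for i in range(len(tensor) - target_len + 1):
        if torch.all(tensor[i : i + target_len] == target):
            return i
    return -1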
                input_ids = self.encode_tokens(generator_args.prompt, bos=True, device=self.builder_args.device)
                image_token_indices = self.encode_tokens("<image>", device=self.builder_args.device)[1:]
                index = find_subtensor(input_ids, image_token_indices)

                if index == -1:
                    raise ValueError("Image token not found in prompt")

                batch = {
                    "tokens": input_ids[:index].unsqueeze(0),
                    "encoder_input": llava_image_preprocess(images[0], device=self.builder_args.device, dtype=self.builder_args.precision),
                    "post_tokens": input_ids[index + len(image_token_indices) :].unsqueeze(0),
                }

Reviewer: I might be misunderstanding batch, but it looks like the batch variable isn't used? Especially encoder_input.
Author: The whole batch is forwarded into the llava model, so everything here is used by llava's forward function.
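For context on that exchange: prefill expands the dict with model(**batch), so each key lines up with a keyword argument of LlavaModel.forward. A rough sketch of the correspondence (shapes are assumptions, not values from the diff):

# model(**batch) is equivalent to
# model(
#     tokens=batch["tokens"],                # prompt ids before the <image> placeholder, shape [1, n_pre]
#     encoder_input=batch["encoder_input"],  # preprocessed image tensor, shape [1, 3, 336, 336]
#     post_tokens=batch["post_tokens"],      # prompt ids after the placeholder, shape [1, n_post]
# )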
                # can not get actual encoded image feature before model inference; pseudo one
                pseudo_vision_encoded = torch.zeros(1, 624).to(device=self.builder_args.device, dtype=self.builder_args.precision)
                encoded = torch.cat([batch["tokens"].view(1, -1), pseudo_vision_encoded, batch["post_tokens"].view(1, -1)], dim=-1).view(-1)

        else:
            encoded = self.encode_tokens(
                generator_args.prompt, bos=True, device=self.builder_args.device
@@ -14,7 +14,7 @@

import torchvision

from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
from collections.abc import Hashable

import torch

@@ -56,7 +56,6 @@ def identity(**kwargs):
    return list(kwargs.values())[0]



class MultiModalProjector(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, act: nn.Module):
        super().__init__()
@@ -126,7 +125,10 @@ def forward(
                dtype=torch.int,
            )

            return self.decoder(decoder_input, input_pos=input_pos)
            return decoder_input.shape[1], self.decoder(decoder_input, input_pos=input_pos)
        else:
            return self.decoder(decoder_input, input_pos=input_pos)
Comment on lines +128 to +130:
Reviewer: Not sure I follow why we need to do this.
Author: This is for returning the context length from prefilling, to set up the pos_id for the generation after prefilling.
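To make the round trip concrete: with image input, prefill now hands back the decoder input length alongside the sampled token, and generate() uses that length instead of the raw text-token count T when placing the next position id. Paraphrasing the diff (not an exact excerpt):

# Llava path:
#   context_len, next_token = prefill(...)    # context_len == decoder_input.shape[1], i.e.
#                                             # pre-image tokens + projected image patches + post-image tokens
#   input_pos = torch.tensor([start_pos + context_len], dtype=torch.int)
# The text-only path keeps input_pos = torch.tensor([start_pos + T]).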

    def setup_caches(self, batch_size, max_seq_len) -> None:
        self.decoder.setup_caches(batch_size, max_seq_len)

@@ -262,6 +264,7 @@ class TransformerArgs:
    use_tiktoken: bool = False
    max_seq_length: int = 8192
    rope_scaling: Optional[Dict[str, Any]] = None
    use_hf_rope: bool = False
    # For pipeline parallel
    n_stages: int = 1
    stage_idx: int = 0

@@ -413,7 +416,6 @@ def __init__(
        # print(f"dtype on entry {dtype}")
        if not dtype:
            dtype = get_precision()
        # print(f"dtype on get_prec {dtype}")
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

@@ -553,13 +555,16 @@ def reset_caches(self):


class LlavaModel(Model):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__(config)
        self.text_transformer_args = self.model.decoder.config

    def forward(
        self,
        tokens: Tensor,
        *,
        input_pos: Optional[Tensor] = None,
        encoder_input: Optional[Dict[str, Tensor]] = None,
        post_tokens: Optional[Tensor] = None,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        return self.model(tokens, encoder_input=encoder_input, post_tokens=post_tokens, input_pos=input_pos)


@@ -605,6 +610,13 @@ def __init__(self, config: TransformerArgs) -> None:

        self.max_batch_size = -1
        self.max_seq_length = -1
        # For supporting sequence parallel (default is off, thus value of 1)
        self.seq_parallel_degree = 1
        if config.use_hf_rope:
            self.precompute_freqs_cis = hf_precompute_freqs_cis
        else:
            self.precompute_freqs_cis = precompute_freqs_cis
    def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
        if (

@@ -623,7 +635,7 @@ def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
                max_batch_size, max_seq_length, cache_lanes=cache_lanes
            )

        freqs_cis = precompute_freqs_cis(
        freqs_cis = self.precompute_freqs_cis(
            self.config.dim // self.config.n_heads,
            self.config.block_size * 2,
            self.config.rope_base,

@@ -657,8 +669,10 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int
        assert self.freqs_cis is not None, "Caches must be initialized first"
        mask = self.causal_mask[None, None, input_pos]
        freqs_cis = self.freqs_cis[input_pos]

        if self.tok_embeddings:
            x = self.tok_embeddings(x)


        for _, layer in self.layers.items():
            x = layer(x, input_pos, freqs_cis, mask, cache_lane=cache_lane)

@@ -715,6 +729,10 @@ def __init__(self, config: TransformerArgs):
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self._register_load_state_dict_pre_hook(self.load_hook)
        if config.use_hf_rope:
            self.apply_rotary_emb = hf_apply_rotary_emb
        else:
            self.apply_rotary_emb = apply_rotary_emb
    def setup_cache(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
        n_local_heads = self.n_local_heads

@@ -798,8 +816,8 @@ def forward(
        # -1 = self.n_local_heads
        v = v.view(bsz, seqlen, -1, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)
        q = self.apply_rotary_emb(q, freqs_cis)
        k = self.apply_rotary_emb(k, freqs_cis)

        q, k, v = (x.transpose(1, 2) for x in (q, k, v))


@@ -919,6 +937,58 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    return x_out2.type_as(x)
Reviewer (nit): Can move; mainly for keeping concepts together.
Author: I'd like to keep the current structure, with all HF rotary embedding functions grouped together and all previous embedding functions in a separate section.


# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L77
def hf_precompute_freqs_cis(dim: int, end: int, theta: float, dtype=None, **kwargs):
    if not dtype:
        dtype = get_precision()

    freqs = 1.0 / (
        theta
        ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)
    )
    # pyre-ignore Undefined attribute [16]: `float` has no attribute `device`.
    t = torch.arange(end, device=freqs.device, dtype=torch.int64).type_as(
        freqs  # pyre-ignore
    )
    freqs = torch.outer(t, freqs).float()  # pyre-ignore
    emb = torch.cat((freqs, freqs), dim=-1)
    freqs_cos = torch.cos(emb)
    freqs_sin = torch.sin(emb)
    return torch.stack((freqs_cos, freqs_sin), dim=-1).to(dtype=dtype)


# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L135
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def hf_apply_rotary_emb(x, freq_cis, unsqueeze_dim=1, **kwargs):
    """Applies Rotary Position Embedding to the query and key tensors.
    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
Comment on lines +969 to +984:
Reviewer: Outdated comment?
""" | ||||||||||||||
cos = freq_cis[..., 0].unsqueeze(unsqueeze_dim) | ||||||||||||||
sin = freq_cis[..., 1].unsqueeze(unsqueeze_dim) | ||||||||||||||
x_out = (x * cos) + (rotate_half(x) * sin) | ||||||||||||||
return x_out.type_as(x) | ||||||||||||||
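As a sanity-check sketch of how the two HF-style helpers fit together (the shapes and sizes below are assumptions for illustration, not values from the diff): hf_precompute_freqs_cis returns a [seq_len, head_dim, 2] tensor of (cos, sin) pairs, and hf_apply_rotary_emb broadcasts it against a [batch, seq_len, n_heads, head_dim] tensor, i.e. the q/k layout used just before the transpose in Attention.forward.

head_dim, max_seq, theta = 64, 128, 10000.0
freqs = hf_precompute_freqs_cis(head_dim, max_seq, theta, dtype=torch.float32)  # [128, 64, 2]

q = torch.randn(2, 16, 8, head_dim)           # [batch, seqlen, n_heads, head_dim]
q_rot = hf_apply_rotary_emb(q, freqs[:16])    # same shape as q, rotation applied per position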


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ExecuTorch model components
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -0,0 +1,80 @@
import torch
import torchvision as tv
Reviewer: lintrunner ordering
from torchvision import transforms as tvT
from PIL import Image
import os

from typing import List


def llava_image_preprocess(
    image: Image,
    *,
    target_h: int = 336,
    target_w: int = 336,
    rescale_factor: float = 0.00392156862745098,
    image_mean: List[float] = [0.48145466, 0.4578275, 0.40821073],
    image_std: List[float] = [0.26862954, 0.26130258, 0.27577711],
    device: torch.device = torch.device("cpu"),
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """
    Preprocess an image by resizing it to fit a target height and width,
    padding with median RGB value to make a square, scaling, and normalizing.
    Args:
        img_address (str): Address of the local image file will be forwarded to the model.
Reviewer: Autogen'd comment?
        target_h (int): Target height.
        target_w (int): Target width.
        rescale_factor (float): Rescaling factor.
        image_mean (list): Mean values for normalization.
        image_std (list): Standard deviation values for normalization.
    Returns:
        torch.Tensor: Preprocessed image tensor.
    Raises:
        FileNotFoundError: If the image file does not exist.
        ValueError: If the target height or width is not positive.
    """

    # Check if the target height and width are positive
    if target_h <= 0 or target_w <= 0:
        raise ValueError("Target height and width must be positive")

    # Convert the image to a tensor
    img = tvT.functional.pil_to_tensor(image)

    # Calculate the height and width ratios
    ratio_h = img.shape[1] / target_h
    ratio_w = img.shape[2] / target_w

    # Resize the image to fit in a target_h x target_w canvas
    ratio = max(ratio_h, ratio_w)
    output_size = (int(img.shape[1] / ratio), int(img.shape[2] / ratio))
    img = tvT.Resize(size=output_size)(img)

    # Pad the image with median RGB value to make a square
    l_pad = (target_w - img.shape[2]) // 2
    t_pad = (target_h - img.shape[1]) // 2
    r_pad = -((target_w - img.shape[2]) // -2)
    b_pad = -((target_h - img.shape[1]) // -2)

    torch._check(l_pad >= 0)
    torch._check(t_pad >= 0)
    torch._check(r_pad >= 0)
    torch._check(b_pad >= 0)

    # Pad the image
    resized = torch.nn.functional.pad(
        img,
        (l_pad, r_pad, t_pad, b_pad),
    )

    # Scale the image
    scaled = resized * rescale_factor

    # Normalize the image
    normed = tvT.Normalize(image_mean, image_std)(scaled)

    return normed.unsqueeze(0).to(device).to(dtype)
Review note: Code comment blocks to help us move things around later.
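A minimal usage sketch for the new helper (the file name below is hypothetical; any local RGB image works):

from PIL import Image
import torch

img = Image.open("sample.jpg")   # hypothetical local image
pixels = llava_image_preprocess(img, device=torch.device("cpu"), dtype=torch.float32)
print(pixels.shape)              # torch.Size([1, 3, 336, 336]) for an RGB input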