enable llava on torchchat #1183

Open · wants to merge 9 commits into main
187 changes: 185 additions & 2 deletions torchchat/cli/convert_hf_checkpoint.py
@@ -7,10 +7,13 @@
import os
import re
import sys
import glob
from pathlib import Path
from typing import Optional
from typing import Any, Dict, Optional

import torch
import safetensors.torch
import shutil

from torchchat.model import TransformerArgs

@@ -21,9 +24,176 @@

from torchchat.model import ModelArgs

Contributor
Suggested change
"""
Llava Conversion Code
"""

Contributor
Code comment blocks to help us move things around later

def remap_llava_checkpoint(llava_ckpt):
Contributor
Was this written in-house?

Contributor Author
I'm not quite following your question.
This function is consumed by convert_llava_checkpoint to get the remapped checkpoint.
I made it a separate function to simplify the logic.
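For context, a condensed sketch of how the remapper is consumed (based on convert_llava_checkpoint below; merged_weights stands for the HF state dict merged from the safetensor shards):

# Merge the HF safetensor shards elsewhere, then remap and save a torchchat-style checkpoint.
remapped_weights = remap_llava_checkpoint(merged_weights)  # HF llava keys -> torchchat keys
torch.save(remapped_weights, model_dir / "model.pth")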

def _translate_state_dict_for_vision_model(hf_state_dict) -> Dict[str, Any]:
translated_state_dict = {}
hf_weight_prefix = "vision_model."
name_mapping = {
f"{hf_weight_prefix}embeddings.class_embedding": "encoder.cls_token_embedding.weight",
f"{hf_weight_prefix}embeddings.position_embedding.weight": "encoder.token_pos_embedding.positional_embedding",
f"{hf_weight_prefix}embeddings.patch_embedding.weight": "encoder.conv.weight",
f"{hf_weight_prefix}pre_layrnorm.weight": "encoder.ln_pre.weight",
f"{hf_weight_prefix}pre_layrnorm.bias": "encoder.ln_pre.bias",
f"{hf_weight_prefix}post_layernorm.weight": "encoder.ln_post.weight",
f"{hf_weight_prefix}post_layernorm.bias": "encoder.ln_post.bias",
}
patterns = [
(
rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.self_attn\.(k|q|v)_proj\.(weight|bias)",
lambda match: f"encoder.layers.{match.group(1)}.attn.{match.group(2)}_proj.{match.group(3)}",
),
(
rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.(weight|bias)",
lambda match: f"encoder.layers.{match.group(1)}.attn.output_proj.{match.group(2)}",
),
(
rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.mlp\.fc(1|2)\.(weight|bias)",
lambda match: f"encoder.layers.{match.group(1)}.mlp.w{match.group(2)}.{match.group(3)}",
),
(
rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.layer_norm1\.(weight|bias)",
lambda match: f"encoder.layers.{match.group(1)}.sa_norm.{match.group(2)}",
),
(
rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.layer_norm2\.(weight|bias)",
lambda match: f"encoder.layers.{match.group(1)}.mlp_norm.{match.group(2)}",
),
]
for pattern, replacement in patterns:
for key in list(hf_state_dict.keys()):
if re.match(pattern, key):
new_key = re.sub(pattern, replacement, key)
name_mapping[key] = new_key
temp_state_dict = {}
for k, v in hf_state_dict.items():
new_k = name_mapping.get(k, k)
if "in_proj_weight" in new_k or "in_proj_bias" in new_k:
if new_k not in temp_state_dict:
temp_state_dict[new_k] = {"q": None, "k": None, "v": None}
if "q_proj" in k:
temp_state_dict[new_k]["q"] = v
elif "k_proj" in k:
temp_state_dict[new_k]["k"] = v
elif "v_proj" in k:
temp_state_dict[new_k]["v"] = v
else:
temp_state_dict[new_k] = v
for k, v in temp_state_dict.items():
if isinstance(v, dict):
translated_state_dict[k] = torch.cat([v["q"], v["k"], v["v"]], dim=0)
else:
translated_state_dict[k] = v
return translated_state_dict

def _translate_state_dict_for_text_model(hf_state_dict) -> Dict[str, Any]:
key_map = {
r"model.layers.([0-9]+).self_attn.q_proj.": r"decoder.layers.\1.attention.wq.",
r"model.layers.([0-9]+).self_attn.k_proj.": r"decoder.layers.\1.attention.wk.",
r"model.layers.([0-9]+).self_attn.v_proj.": r"decoder.layers.\1.attention.wv.",
r"model.layers.([0-9]+).self_attn.o_proj.": r"decoder.layers.\1.attention.wo.",
r"model.layers.([0-9]+).input_layernorm.": r"decoder.layers.\1.attention_norm.",
r"model.layers.([0-9]+).mlp.gate_proj.": r"decoder.layers.\1.feed_forward.w1.",
r"model.layers.([0-9]+).mlp.down_proj.": r"decoder.layers.\1.feed_forward.w2.",
r"model.layers.([0-9]+).mlp.up_proj.": r"decoder.layers.\1.feed_forward.w3.",
r"model.layers.([0-9]+).post_attention_layernorm.": r"decoder.layers.\1.ffn_norm.",
r"model.norm.": r"decoder.norm.",
# r"model.embed_tokens.": r"tok_embeddings.", # load separately
r"lm_head.": r"decoder.output.",
}
new_state_dict = {}
def get_new_key(old_key: str) -> str:
for old_pattern, replacement in key_map.items():
if (new_key := re.sub(old_pattern, replacement, old_key)) != old_key:
return new_key
return old_key
for old_key in hf_state_dict.keys():
new_key = get_new_key(old_key)
new_state_dict[new_key] = hf_state_dict[old_key]
return new_state_dict

def _translate_state_dict_for_mm_projector_model(hf_state_dict) -> Dict[str, Any]:
new_state_dict = {}
for old_key in hf_state_dict.keys():
new_key = "mm_projector." + old_key
new_state_dict[new_key] = hf_state_dict[old_key]
return new_state_dict

def split_checkpoint(llava_ckpt):
language_model_ckpt = {}
multi_modal_ckpt = {}
vision_tower_ckpt = {}
for key, value in llava_ckpt.items():
if key.startswith("language_model"):
language_model_ckpt[key[len("language_model") + 1:]] = value
elif key.startswith("multi_modal_projector"):
multi_modal_ckpt[key[len("multi_modal_projector") + 1:]] = value
elif key.startswith("vision_tower"):
vision_tower_ckpt[key[len("vision_tower") + 1:]] = value
return language_model_ckpt, multi_modal_ckpt, vision_tower_ckpt
language_model_ckpt, multi_modal_ckpt, vision_tower_ckpt = split_checkpoint(llava_ckpt)
remapped_state_dict = {
"tok_embeddings.weight": language_model_ckpt.pop("model.embed_tokens.weight"),
}
remapped_state_dict.update(_translate_state_dict_for_text_model(language_model_ckpt))
remapped_state_dict.update(_translate_state_dict_for_vision_model(vision_tower_ckpt))
remapped_state_dict.update(_translate_state_dict_for_mm_projector_model(multi_modal_ckpt))
return remapped_state_dict


@torch.inference_mode
def convert_llava_checkpoint(
*,
model_dir: Optional[Path] = None,
) -> None:

"""
Process safetensor files from a specific directory structure and save the remapped model.

Args:
model_dir (str): Base directory containing the model subdirectories.
"""

def _get_llava_files_with_pattern(pattern):
pattern = os.path.join(model_dir, f"models--llava-hf--llava-1.5-7b-hf/snapshots/*/{pattern}")
return glob.glob(pattern)

# get all safetensor files in the model directory
safetensor_files = _get_llava_files_with_pattern("*.safetensors")

if not safetensor_files:
raise ValueError("No safetensor files found.")

merged_weights = {}

# Merge safetensor files into a whole
for file in safetensor_files:
# Load weights from the current file
part_weights = safetensors.torch.load_file(file)

# Iterate over each weight in the current file
for key, value in part_weights.items():
if key in merged_weights:
# If the key already exists, concatenate tensors
merged_weights[key] = torch.cat((merged_weights[key], value), dim=0)
else:
# If the key does not exist, add it to the dictionary
merged_weights[key] = value

# Remap the checkpoint and save it as pth
remapped_weights = remap_llava_checkpoint(merged_weights)
model_path = model_dir / "model.pth"
torch.save(remapped_weights, model_path)

# copy tokenizer
tokenizer_files = _get_llava_files_with_pattern("tokenizer.model")
assert len(tokenizer_files) == 1, "Should get only one tokenizer file, but got {}".format(tokenizer_files)

tokenizer_path = model_dir / "tokenizer.model"
shutil.copy(tokenizer_files[0], tokenizer_path)


Contributor
Suggested change
"""
Text-Only Conversion Code
"""

@torch.inference_mode()
def convert_hf_checkpoint(
def convert_text_only_hf_checkpoint(
*,
model_dir: Optional[Path] = None,
model_name: Optional[str] = None,
@@ -132,6 +302,19 @@ def permute(w, n_heads):
os.remove(file)


@torch.inference_mode()
def convert_hf_checkpoint(
*,
model_dir: Optional[Path] = None,
model_name: Optional[str] = None,
remove_bin_files: bool = False,
):
if "llava" in model_name:
convert_llava_checkpoint(model_dir=model_dir)
else:
convert_text_only_hf_checkpoint(model_dir=model_dir, model_name=model_name, remove_bin_files=remove_bin_files)


if __name__ == "__main__":
import argparse

3 changes: 2 additions & 1 deletion torchchat/cli/download.py
@@ -33,8 +33,9 @@ def _download_hf_snapshot(
local_dir=artifact_dir,
local_dir_use_symlinks=False,
token=hf_token,
ignore_patterns="*safetensors*",
ignore_patterns=None if "llava" in model_config.name else "*safetensors*",
)

except HTTPError as e:
if e.response.status_code == 401: # Missing HuggingFace CLI login.
print(
82 changes: 62 additions & 20 deletions torchchat/generate.py
@@ -36,6 +36,7 @@
from torchchat.model import Model, ModelType
from torchchat.utils.build_utils import device_sync, set_precision
from torchchat.utils.device_info import get_device_info
from torchchat.utils.preprocessors import llava_image_preprocess

# torchtune model definition dependencies
from torchtune.data import Message
@@ -357,8 +358,13 @@ def prefill(

if batch is not None:
# TODO: Verify sequential prefill works with multimodal models
logits = model(**batch)[:, -1]
return tune_sample(logits, 0, 500)
logits = model(**batch)
if model.config.model_type == ModelType.Llava:
context_len, logits = logits[0], logits[1][:, -1]
return context_len, tune_sample(logits, 0, 500)
else:
logits = logits[:, -1]
return tune_sample(logits, 0, 500)
elif sequential_prefill:
for i in range(width):
x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
@@ -622,6 +628,13 @@ def generate(
sequential_prefill=sequential_prefill,
**sampling_kwargs,
)

# For llava with image input, we need to extract next pos id from prefill result
if batch and self.model.config.model_type == ModelType.Llava:
context_len, next_token = next_token
else:
context_len, next_token = T, next_token
Contributor
Suggested change
context_len, next_token = T, next_token
context_len = T


if is_speculative:
self.prefill(
draft_model,
@@ -636,7 +649,7 @@ def generate(
# max_new_tokens <= 2 means we are effectively not calling decode_n_tokens().
callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2)

input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
input_pos = torch.tensor([start_pos + context_len], device=device, dtype=torch.int)
accept_counts = [0] * (
speculate_k + 1
) # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
@@ -726,27 +739,56 @@ def chat(

if generator_args.image_prompts is not None:
print("Image prompts", generator_args.image_prompts)

# Support for just the first image prompt for now
images = [Image.open(generator_args.image_prompts[0])]
messages = [
Message(
role="user",
content=[
{"type": "image", "content": images[0]},
{"type": "text", "content": generator_args.prompt},
],
eot=True,
),
Message(role="assistant", content=""),
]

transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
data = transform({"messages": messages}, inference=True)
batch = padded_collate([data], self.builder_args.device)
batch.pop("mask")
encoded = batch["tokens"]
assert len(images) == 1, "Only one image prompt is supported for now"

#TODO: updated encoded variable for multi-modality models to include image tokens.
Contributor
Can you explain this to me?

if self.model.config.model_type == ModelType.Flamingo:
messages = [
Message(
role="user",
content=[
{"type": "image", "content": images[0]},
{"type": "text", "content": generator_args.prompt},
],
eot=True,
),
Message(role="assistant", content=""),
]

transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
data = transform({"messages": messages}, inference=True)
batch = padded_collate([data], self.builder_args.device)
batch.pop("mask")
encoded = batch["tokens"]
elif self.model.config.model_type == ModelType.Llava:
#TODO: double check the tokenizer.
def find_subtensor(tensor, target):
Contributor
typehints
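For the type-hint request, one possible annotated version (a sketch; behavior unchanged from the PR):

def find_subtensor(tensor: torch.Tensor, target: torch.Tensor) -> int:
    # Return the start index of `target` inside `tensor`, or -1 if it is not present.
    target_len = len(target)
    for i in range(len(tensor) - target_len + 1):
        if torch.all(tensor[i : i + target_len] == target):
            return i
    return -1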

target_len = len(target)
for i in range(len(tensor) - target_len + 1):
if torch.all(tensor[i:i+target_len] == target):
return i
return -1

input_ids = self.encode_tokens(generator_args.prompt, bos=True, device=self.builder_args.device)
image_token_indices = self.encode_tokens("<image>", device=self.builder_args.device)[1:]
index = find_subtensor(input_ids, image_token_indices)

if index == -1:
raise ValueError("Image token not found in prompt")

batch = {
"tokens": input_ids[:index].unsqueeze(0),
"encoder_input": llava_image_preprocess(images[0], device=self.builder_args.device, dtype=self.builder_args.precision),
Contributor
I might be misunderstanding batch, but it looks like the batch variable itself isn't used, just the individual entries of batch?

Especially encoder_input

Contributor Author (@Gasoonjia, Sep 24, 2024)
The whole batch is forwarded into the llava model, so everything here is used by llava's forward function.
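To make that concrete: prefill calls model(**batch), so each key lines up with a keyword argument of LlavaModel.forward further down in this diff (illustrative expansion):

# model(**batch) expands to roughly:
logits = model(
    tokens=batch["tokens"],                # text tokens before the <image> placeholder
    encoder_input=batch["encoder_input"],  # preprocessed image tensor from llava_image_preprocess
    post_tokens=batch["post_tokens"],      # text tokens after the <image> placeholder
)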

"post_tokens": input_ids[index + len(image_token_indices) :].unsqueeze(0),
}

# can not get actual encoded image feature before model inference; pseudo one
pseudo_vision_encoded = torch.zeros(1, 624).to(device=self.builder_args.device, dtype=self.builder_args.precision)
encoded = torch.cat([batch["tokens"].view(1, -1), pseudo_vision_encoded, batch["post_tokens"].view(1, -1)], dim=-1).view(-1)

else:
encoded = self.encode_tokens(
generator_args.prompt, bos=True, device=self.builder_args.device
88 changes: 79 additions & 9 deletions torchchat/model.py
@@ -14,7 +14,7 @@

import torchvision

from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
from collections.abc import Hashable

import torch
@@ -56,7 +56,6 @@ def identity(**kwargs):
return list(kwargs.values())[0]



class MultiModalProjector(nn.Module):
def __init__(self, in_channels: int, out_channels: int, act: nn.Module):
super().__init__()
@@ -126,7 +125,10 @@ def forward(
dtype=torch.int,
)

return self.decoder(decoder_input, input_pos=input_pos)
return decoder_input.shape[1], self.decoder(decoder_input, input_pos=input_pos)
else:
return self.decoder(decoder_input, input_pos=input_pos)
Comment on lines +128 to +130
Contributor
Not sure I follow why we need to do this

Contributor Author
This returns the context length from prefill so that we can set up the correct pos_id for the generation that follows.
Context: we need to know how many tokens the text decoder consumed during the prefill stage. For text-only models we can just count the text tokens forwarded to the model, but here that count is wrong because image tokens produced by the image head are inserted in the middle during model inference.
Therefore we ask the model to return the size of the decoder input and use it to calculate the pos_id for the subsequent generation.
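In other words, the value returned by prefill replaces T when positioning subsequent tokens; a condensed sketch of the flow in generate() above:

# For Llava with an image batch, prefill returns (context_len, next_token);
# context_len counts the text tokens plus the inserted image embeddings.
context_len, next_token = next_token
# The first decode step starts after everything the decoder already consumed,
# not after the raw text length T.
input_pos = torch.tensor([start_pos + context_len], device=device, dtype=torch.int)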



def setup_caches(self, batch_size, max_seq_len) -> None:
self.decoder.setup_caches(batch_size, max_seq_len)
@@ -262,6 +264,7 @@ class TransformerArgs:
use_tiktoken: bool = False
max_seq_length: int = 8192
rope_scaling: Optional[Dict[str, Any]] = None
use_hf_rope: bool = False
# For pipeline parallel
n_stages: int = 1
stage_idx: int = 0
@@ -413,7 +416,6 @@ def __init__(
# print(f"dtype on entry {dtype}")
if not dtype:
dtype = get_precision()
# print(f"dtype on get_prec {dtype}")
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
@@ -553,13 +555,16 @@ def reset_caches(self):


class LlavaModel(Model):
def __init__(self, config: ModelArgs) -> None:
super().__init__(config)
self.text_transformer_args = self.model.decoder.config

def forward(
self,
tokens: Tensor,
*,
input_pos: Optional[Tensor] = None,
encoder_input: Optional[Dict[str, Tensor]] = None,
post_tokens: Optional[Tensor] = None,
input_pos: Optional[Tensor] = None,
) -> Tensor:
return self.model(tokens, encoder_input=encoder_input, post_tokens=post_tokens, input_pos=input_pos)

@@ -605,6 +610,13 @@ def __init__(self, config: TransformerArgs) -> None:

self.max_batch_size = -1
self.max_seq_length = -1
# For supporting sequence parallel (default is off, thus value of 1)
self.seq_parallel_degree = 1
if config.use_hf_rope:
self.precompute_freqs_cis = hf_precompute_freqs_cis
else:
self.precompute_freqs_cis = precompute_freqs_cis
Comment on lines +615 to +618
Contributor
Suggested change
if config.use_hf_rope:
self.precompute_freqs_cis = hf_precompute_freqs_cis
else:
self.precompute_freqs_cis = precompute_freqs_cis
self.precompute_freqs_cis = hf_precompute_freqs_cis if config.use_hf_rope else precompute_freqs_cis



def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
if (
@@ -623,7 +635,7 @@ def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
max_batch_size, max_seq_length, cache_lanes=cache_lanes
)

freqs_cis = precompute_freqs_cis(
freqs_cis = self.precompute_freqs_cis(
self.config.dim // self.config.n_heads,
self.config.block_size * 2,
self.config.rope_base,
@@ -657,8 +669,10 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int
assert self.freqs_cis is not None, "Caches must be initialized first"
mask = self.causal_mask[None, None, input_pos]
freqs_cis = self.freqs_cis[input_pos]

if self.tok_embeddings:
x = self.tok_embeddings(x)


for _, layer in self.layers.items():
x = layer(x, input_pos, freqs_cis, mask, cache_lane=cache_lane)
@@ -715,6 +729,10 @@ def __init__(self, config: TransformerArgs):
self.n_local_heads = config.n_local_heads
self.dim = config.dim
self._register_load_state_dict_pre_hook(self.load_hook)
if config.use_hf_rope:
self.apply_rotary_emb = hf_apply_rotary_emb
else:
self.apply_rotary_emb = apply_rotary_emb
Comment on lines +732 to +735
Contributor
Suggested change
if config.use_hf_rope:
self.apply_rotary_emb = hf_apply_rotary_emb
else:
self.apply_rotary_emb = apply_rotary_emb
self.apply_rotary_emb = hf_apply_rotary_emb if config.use_hf_rope else apply_rotary_emb


def setup_cache(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
n_local_heads = self.n_local_heads
@@ -798,8 +816,8 @@ def forward(
# -1 = self.n_local_heads
v = v.view(bsz, seqlen, -1, self.head_dim)

q = apply_rotary_emb(q, freqs_cis)
k = apply_rotary_emb(k, freqs_cis)
q = self.apply_rotary_emb(q, freqs_cis)
k = self.apply_rotary_emb(k, freqs_cis)

q, k, v = (x.transpose(1, 2) for x in (q, k, v))

@@ -919,6 +937,58 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
return x_out2.type_as(x)
Contributor
nit: Can we move apply_rotary_emb so that it sits directly after hf_apply_rotary_emb?

Mainly for keeping related concepts together

Contributor Author
I'd like to keep the current structure, with all HF rotary embedding functions grouped together and all previous embedding functions in a separate section.




# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L77
def hf_precompute_freqs_cis(dim: int, end: int, theta: float, dtype=None, **kwargs):
if not dtype:
dtype = get_precision()

freqs = 1.0 / (
theta
** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)
)
# pyre-ignore Undefined attribute [16]: `float` has no attribute `device`.
t = torch.arange(end, device=freqs.device, dtype=torch.int64).type_as(
freqs # pyre-ignore
)
freqs = torch.outer(t, freqs).float() # pyre-ignore
emb = torch.cat((freqs, freqs), dim=-1)
freqs_cos = torch.cos(emb)
freqs_sin = torch.sin(emb)
return torch.stack((freqs_cos, freqs_sin), dim=-1).to(dtype=dtype)

# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L135
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def hf_apply_rotary_emb(x, freq_cis, unsqueeze_dim=1, **kwargs):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
Comment on lines +969 to +984
Contributor
Outdated comment?
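It does read like the upstream q/k docstring; a docstring matched to this signature might look like this (sketch only):

"""Applies Rotary Position Embedding to the input tensor.
Args:
    x (`torch.Tensor`): The query or key tensor to rotate.
    freq_cis (`torch.Tensor`): Stacked cos/sin table produced by hf_precompute_freqs_cis.
    unsqueeze_dim (`int`, *optional*, defaults to 1): The dimension along which cos and sin
        are unsqueezed so they broadcast against x.
Returns:
    `torch.Tensor`: The rotated tensor with the same shape and dtype as x.
"""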

"""
cos = freq_cis[..., 0].unsqueeze(unsqueeze_dim)
sin = freq_cis[..., 1].unsqueeze(unsqueeze_dim)
x_out = (x * cos) + (rotate_half(x) * sin)
return x_out.type_as(x)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ExecuTorch model components
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2 changes: 1 addition & 1 deletion torchchat/model_config/model_config.py
@@ -86,6 +86,6 @@ def resolve_model_config(model: str) -> ModelConfig:
model = model_aliases[model]

if model not in model_configs:
raise ValueError(f"Unknown model '{model}'.")
raise ValueError(f"Unknown model '{model}'. Supported models: {model_configs.keys()}")

return model_configs[model]
6 changes: 6 additions & 0 deletions torchchat/model_config/models.json
@@ -1,4 +1,10 @@
{
"llava-hf/llava-1.5-7b-hf": {
"aliases": ["llava-1.5"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "llava-hf/llava-1.5-7b-hf",
"transformer_params_key": "llava-1.5"
},
"meta-llama/Llama-2-7b-hf": {
"aliases": ["llama2-base", "llama2-7b"],
"distribution_channel": "HuggingFaceSnapshot",
4 changes: 2 additions & 2 deletions torchchat/model_params/llava-1.5.json
@@ -1,6 +1,5 @@
{
"model_type": "llava",
"use_tiktoken": true,
"encoder": {
"tile_size": 336,
"patch_size": 14,
@@ -20,6 +19,7 @@
"n_heads": 32,
"dim": 4096,
"vocab_size": 32064,
"max_seq_length": 768
"max_seq_length": 768,
"use_hf_rope": true
}
}
80 changes: 80 additions & 0 deletions torchchat/utils/preprocessors.py
@@ -0,0 +1,80 @@
import torch
import torchvision as tv
Contributor
lintrunner ordering
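A possible ordering that the import sorter would accept (a sketch; the exact grouping depends on the repo's lint configuration):

import os
from typing import List

import torch
import torchvision as tv
from PIL import Image
from torchvision import transforms as tvT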

from torchvision import transforms as tvT
from PIL import Image
import os

from typing import List


def llava_image_preprocess(
image: Image,
*,
target_h: int = 336,
target_w: int = 336,
rescale_factor: float = 0.00392156862745098,
image_mean: List[float] = [0.48145466, 0.4578275, 0.40821073],
image_std: List[float] = [0.26862954, 0.26130258, 0.27577711],
device: torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
"""
Preprocess an image by resizing it to fit a target height and width,
padding with median RGB value to make a square, scaling, and normalizing.
Args:
img_address (str): Address of the local image file will be forwarded to the model.
Contributor
Autogen'd comment?
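It looks that way; the entry also names img_address while the parameter is image, so a corrected line might read (sketch):

image (PIL.Image): Image to preprocess before it is forwarded to the model.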

target_h (int): Target height.
target_w (int): Target width.
rescale_factor (float): Rescaling factor.
image_mean (list): Mean values for normalization.
image_std (list): Standard deviation values for normalization.
Returns:
torch.Tensor: Preprocessed image tensor.
Raises:
FileNotFoundError: If the image file does not exist.
ValueError: If the target height or width is not positive.
"""

# Check if the target height and width are positive
if target_h <= 0 or target_w <= 0:
raise ValueError("Target height and width must be positive")

# Convert the image to a tensor
img = tvT.functional.pil_to_tensor(image)

# Calculate the height and width ratios
ratio_h = img.shape[1] / target_h
ratio_w = img.shape[2] / target_w

# Resize the image to fit in a target_h x target_w canvas
ratio = max(ratio_h, ratio_w)
output_size = (int(img.shape[1] / ratio), int(img.shape[2] / ratio))
img = tvT.Resize(size=output_size)(img)

# Pad the image with median RGB value to make a square
l_pad = (target_w - img.shape[2]) // 2
t_pad = (target_h - img.shape[1]) // 2
r_pad = -((target_w - img.shape[2]) // -2)
b_pad = -((target_h - img.shape[1]) // -2)

torch._check(l_pad >= 0)
torch._check(t_pad >= 0)
torch._check(r_pad >= 0)
torch._check(b_pad >= 0)

# Pad the image
resized = torch.nn.functional.pad(
img,
(l_pad, r_pad, t_pad, b_pad),
)

# Scale the image
scaled = resized * rescale_factor

# Normalize the image
normed = tvT.Normalize(image_mean, image_std)(scaled)

return normed.unsqueeze(0).to(device).to(dtype)