Add Molmo (7B-D, 7B-O, 70B) #33962

Open · wants to merge 203 commits into main

Conversation

@molbap (Contributor) commented Oct 4, 2024

What does this PR do?

As mentioned in issue #33710, this is a draft to add native support for Molmo in transformers.
It also uses the new modular framework introduced in #33248.

Molmo has several existing variants:

  • MolmoE, a mixture-of-experts multimodal model, which is not covered in this PR but will be in a follow-up one.
  • Molmo-7B-D, based on Qwen2 + CLIP.
  • Molmo-7B-O, based on a yet-to-be-released Olmo model, plus CLIP.
  • Molmo-70B, a scaled-up version.

The last three models share the same modeling code and are thus all covered by this PR.

Regarding the modular framework:

Choose a base model that's as close as possible to the one you're porting.

In my case, I'm using Llava as a reference. The differences I identify at a glance include the 2D pooling.
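
For context, a modular file is essentially a set of imports from the reference models plus subclasses that override only what changes; the converter then expands it into a flat modeling file. Below is a minimal sketch of that layout (class bodies elided; the exact imports in this PR may differ):

# modular_molmo.py, sketched: subclass the reference implementations and override
# only what differs; the converter generates the flat modeling_molmo.py from this.
from ..clip.modeling_clip import CLIPVisionEmbeddings
from ..llava.modeling_llava import LlavaMultiModalProjector


class MolmoMultiModalProjector(LlavaMultiModalProjector):
    ...  # fully redefined below (gated projection)


class MolmoVisionEmbeddings(CLIPVisionEmbeddings):
    ...  # only the position embedding changes (see below)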

Figure out the differences.

Some differences amount to a complete modification of the original module; in that case, the whole module has to be redefined.

class MolmoMultiModalProjector(LlavaMultiModalProjector):
    def __init__(self, config: MolmoConfig):
        super().__init__()
        self.linear_1 = nn.Linear(
            config.vision_config.hidden_size,
            config.text_config.intermediate_size // 2,
            bias=False,
        )
        self.linear_2 = nn.Linear(
            config.text_config.intermediate_size // 2,
            config.text_config.hidden_size,
            bias=False,
        )
        self.linear_3 = nn.Linear(
            config.vision_config.hidden_size,
            config.text_config.intermediate_size // 2,
            bias=False,
        )

    def forward(self, image_features):
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        intermediate_states = self.linear_3(image_features)
        # gated (SwiGLU-style) combination of the two branches before the output projection
        hidden_states = self.linear_2(hidden_states * intermediate_states)
        return hidden_states

Some differences will be very tiny: some layers are identical but initialized from a different configuration key. The position embeddings, for instance, differ only slightly.

class MolmoVisionEmbeddings(CLIPVisionEmbeddings):
    def __init__(self, config):
        super().__init__(config)
        # same layer as CLIP, but sized from a Molmo-specific configuration key
        self.position_embedding = nn.Embedding(config.num_image_positions, config.hidden_size)

Preserve inheritance across model component renames.

For instance, the code above will trigger:

python utils/modular_model_converter.py --files_to_parse src/transformers/models/molmo/modular_molmo.py  --old_model_name="Llava" --new_model_name="Molmo"

> ValueError: Unable to find dependencies for CLIPVisionEmbeddings in transformers.models.clip.modeling_clip. Here are the dependencies found: {'molmo_loss': {'contrastive_loss'}, 'MOLMOVisionModelOutput': {'ModelOutput'}, 'MOLMOTextModelOutput': {'ModelOutput'}, 'MOLMOOutput': {'ModelOutput'}, 'MOLMOVisionEmbeddings': {'nn.Module'},

This happens because the supported pattern currently searches for a caps-based model name. Even so, using modular is very promising and makes for a much smaller modeling file to review.
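
To illustrate the caps-based renaming issue (a hypothetical helper, not the converter's actual code): uppercasing the new model name when rewriting class names produces an all-caps prefix that no longer matches the expected CamelCase name, so the dependency lookup fails.

# Hypothetical illustration of the renaming issue, not the converter's real logic:
# replacing the old prefix with the new name uppercased yields MOLMO... instead
# of Molmo..., so the generated name no longer matches what the lookup expects.
import re

def rename_class(name: str, old_model: str, new_model: str) -> str:
    return re.sub(old_model.upper(), new_model.upper(), name)

print(rename_class("CLIPVisionEmbeddings", "CLIP", "Molmo"))  # MOLMOVisionEmbeddings
# The converter would need "MolmoVisionEmbeddings" here to resolve the dependency.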

I'll write down hurdles encountered here for future reference so that adding multimodal models to transformers ends up being a breeze.

@ArthurZucker (Collaborator) left a comment:

Wow looks super nice! Will finish #33859 asap to let you continue!

@molbap (Contributor, Author) commented Oct 8, 2024

Still seeing some duplicate imports in the modeling code:

from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from .configuration_molmo import MolmoConfig


if is_flash_attn_2_available():
    from ...modeling_flash_attention_utils import _flash_attention_forward


from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
from ...utils import (
    ModelOutput,
    is_flash_attn_2_available,
    torch_int,
)
from .configuration_molmo import MOLMOConfig, MOLMOVisionConfig

One quick-and-dirty solution would be to do a pass over the imports once the modular converter has finished, so that imports from the various source modules get merged and normalized (a rough sketch of such a pass follows the snippet below). Strangely, some wrongly capitalized model names also remain, like MOLMOEncoder where we should get MolmoEncoder:

class MolmoVisionTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = MolmoVisionEmbeddings(config)
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = MOLMOEncoder(config)  #  wut 
        self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=True)
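
As a rough sketch (illustrative only, not part of this PR), the import-merging pass mentioned above could simply group the generated `from X import ...` statements per module and drop duplicates:

# Illustrative sketch of an import-merging pass over a generated modeling file
# (not part of this PR): group every `from X import a, b` by module and emit
# each module once with the union of imported names.
import ast
from collections import defaultdict

def merge_from_imports(source: str) -> dict[str, set[str]]:
    merged: dict[str, set[str]] = defaultdict(set)
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.ImportFrom):
            module = "." * node.level + (node.module or "")
            merged[module].update(alias.name for alias in node.names)
    return merged

generated = (
    "from ...utils import logging, is_flash_attn_2_available\n"
    "from ...utils import ModelOutput, is_flash_attn_2_available\n"
)
for module, names in merge_from_imports(generated).items():
    print(f"from {module} import {', '.join(sorted(names))}")
# -> from ...utils import ModelOutput, is_flash_attn_2_available, logging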

Getting there, however!

@ArthurZucker (Collaborator):

Do you need a review? 🤗

@d-rau commented Oct 15, 2024

Maybe a bit premature, but when using the script to convert the model to HF I got mismatch issues here:

q_proj, k_proj, v_proj = torch.split(fused_qkv, fused_dims, 0)
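
For context, one common source of such mismatches is that the split sizes must match how the checkpoint fused Q, K, and V, which differ when the text backbone uses grouped-query attention. A minimal sketch with made-up, Qwen2-7B-like shapes (not the conversion script itself):

# Illustration only (hypothetical shapes, not the conversion script): with
# grouped-query attention, K and V are smaller than Q, so the split sizes must
# reflect the actual head layout or torch.split raises a size mismatch.
import torch

hidden_size, head_dim, num_heads, num_kv_heads = 3584, 128, 28, 4  # Qwen2-7B-like values
fused_qkv = torch.randn((num_heads + 2 * num_kv_heads) * head_dim, hidden_size)
fused_dims = (num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim)
q_proj, k_proj, v_proj = torch.split(fused_qkv, fused_dims, 0)
print(q_proj.shape, k_proj.shape, v_proj.shape)
# torch.Size([3584, 3584]) torch.Size([512, 3584]) torch.Size([512, 3584])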

@htahboub (Contributor):

Hi @molbap, was just wondering if you had an ETA on this? Great work here by the way!

@qubvel qubvel removed request for muellerzr and qubvel May 5, 2025 18:32
@molbap molbap requested a review from ArthurZucker May 12, 2025 16:39
@molbap molbap requested a review from zucchini-nlp June 17, 2025 08:38
@zucchini-nlp (Member) left a comment:

Thanks a lot for working on it! Left a few tiny comments here and there, overall looks good to me

Comment on lines +1344 to +1349

valid_positions = image_token_indices_flat >= 0
valid_indices = image_token_indices_flat[valid_positions].long()
valid_features = image_features_flat[valid_positions.to(image_features_flat.device)]
valid_batch_indices = valid_batch_indices_expanded[
    valid_positions.to(valid_batch_indices_expanded.device)
].long()
@zucchini-nlp (Member):

I don't remember anymore why we needed this hehe. Is it possible for us to hide it somewhere in processing so that the model does simple embeds.masked_scatter(ids == image_id, image_features)?
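
For reference, the simpler pattern being suggested looks roughly like this (illustrative shapes and names; the token id and feature counts are made up):

# Rough sketch of the suggested masked_scatter pattern: place precomputed image
# features at the positions of the image placeholder token, instead of tracking
# valid indices manually.
import torch

image_token_id = 99
input_ids = torch.tensor([[1, 99, 99, 99, 2, 3]])  # (batch, seq_len)
inputs_embeds = torch.zeros(1, 6, 4)               # (batch, seq_len, hidden)
image_features = torch.randn(3, 4)                 # one row per placeholder token

special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)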

@molbap (Contributor, Author):

haha, I don't remember either 😆 I'm sure it's doable yes!
