
Make it possible to load siglip models from local files #22


Open · wants to merge 2 commits into main

Conversation

@maxlund commented Apr 5, 2025

Read image and patch size if supplied in the model_config arg.

What is the context for the regex parsing of the repo name? I guess the image/patch size isn't always correct in the config.json file? Anyway, this small change makes it possible to load a local model while offline:

```python
from mlx_embeddings.utils import load

local_model_dir_path = "/Users/maxlund/mlx-models/mlx-siglip-large-384"
model, processor = load(
    path_or_hf_repo=local_model_dir_path,
    model_config={"image_size": 384, "patch_size": 16},
)
```
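For context, the change amounts to something like this (an illustrative sketch, not the actual mlx-embeddings internals; the helper and parameter names are hypothetical):

```python
def resolve_vision_sizes(model_config: dict, name_image_size: int, name_patch_size: int):
    # Prefer caller-supplied values from model_config; fall back to whatever
    # the existing repo-name regex inferred (hypothetical helper for illustration).
    image_size = model_config.get("image_size", name_image_size)
    patch_size = model_config.get("patch_size", name_patch_size)
    return image_size, patch_size
```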

FWIW, the image and patch sizes seemed to be correct in the downloaded config.json for both mlx-community/siglip-large-patch16-384 and mlx-community/siglip-so400m-patch14-384.

Max Lund added 2 commits April 5, 2025 10:17
- read img and patch size if supplied in model_config arg
@maxlund (Author)

Not sure if you want to pin this to some version; just thought I'd add it since I can't run the repo without torch installed.

@Blaizzy (Owner) commented Apr 5, 2025

Hey @maxlund

Thanks for the PR!

The context is that certain models don't supply the patch and image size in config.json; I can only find them in the repo name.

Besides torch, which I will address today, are you having trouble with any SigLIP model in particular?
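For illustration, the sizes can be recovered from a name like mlx-community/siglip-large-patch16-384 with a small regex (a sketch, not necessarily the exact pattern the repo uses):

```python
import re

def sizes_from_repo_name(repo: str):
    """Return (patch_size, image_size) parsed from names like
    'siglip-large-patch16-384', or None if the name doesn't encode them."""
    m = re.search(r"patch(\d+)-(\d+)", repo)
    return (int(m.group(1)), int(m.group(2))) if m else None

print(sizes_from_repo_name("mlx-community/siglip-large-patch16-384"))   # (16, 384)
print(sizes_from_repo_name("mlx-community/siglip-so400m-patch14-384"))  # (14, 384)
```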

@maxlund (Author) commented Apr 5, 2025

Hey, no problem! I'm messing around with it now and running into some issues. This seems to work fine and gives me embeddings for both text and images, but I want to extract them in separate steps of my pipeline.

```python
from mlx_embeddings.utils import load, generate
import requests
from PIL import Image

# Load vision model and processor
model, processor = load("mlx-community/siglip-large-patch16-384", {"num_classes": 0})

# Load multiple images
image_urls = [
    "./images/cats.jpg",  # cats
    "./images/desktop_setup.png",  # desktop setup
]
images = [
    Image.open(requests.get(url, stream=True).raw) if url.startswith("http") else Image.open(url)
    for url in image_urls
]

# Text descriptions
texts = ["a photo of cats", "a photo of a desktop setup", "a photo of a person"]

outputs = generate(model, processor, texts=texts, images=images)
```

This:

```python
outputs = generate(model, processor, texts=texts, images=None)
```

gives me:

```
(<class 'AttributeError'>, AttributeError("'SiglipProcessor' object has no attribute 'batch_encode_plus'"), <traceback object at 0x136cc5f00>)
```

I also tried get_text_features and get_image_features using:

```python
import mlx.core as mx

inputs_text = processor(text=texts, images=None, padding="max_length", return_tensors="pt")
inputs_imgs = processor(text=None, images=images, return_tensors="pt")
input_ids = mx.array(inputs_text.input_ids)
pixel_values = mx.array(inputs_imgs.pixel_values)
```

but ran into other issues. Just about to have lunch; I'll be back in a bit and can give more details.
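A minimal sketch of completing the text side from there (assuming model and input_ids from the snippet above; the image side needs more massaging, as the next comments show):

```python
# Text-only features, bypassing generate() and its batch_encode_plus call
text_embs = model.get_text_features(input_ids=input_ids)
print(text_embs.shape)  # e.g. (num_texts, hidden_dim)
```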

@maxlund (Author) commented Apr 5, 2025

Okay, some progress:

```python
import mlx.core as mx
from mlx_embeddings.utils import load, generate
import requests
from PIL import Image

model, processor = load("mlx-community/siglip-large-patch16-384", {"num_classes": 0})
image_urls = [
    "./images/cats.jpg",  # cats
    "./images/desktop_setup.png",  # desktop setup
]
images = [
    Image.open(requests.get(url, stream=True).raw) if url.startswith("http") else Image.open(url)
    for url in image_urls
]

texts = "a sentence"
inputs_text = processor(text=texts, images=None, padding="max_length", return_tensors="pt")
inputs_imgs = processor(text=None, images=images, return_tensors="pt")
input_ids = mx.array(inputs_text.input_ids)
pixel_values = mx.array(inputs_imgs.pixel_values)
print(f"{input_ids.shape=}")
print(f"{pixel_values.shape=}")
try:
    text_embs = model.get_text_features(input_ids=input_ids)
    print(f"{type(text_embs)}")
    print(f"{type(text_embs.shape)}")
    print(text_embs)
except Exception as e:
    print(f"model.get_text_features(input_ids=input_ids) error: {e}")

try:
    img_embs = model.get_image_features(pixel_values=pixel_values)
    print(f"{type(img_embs)}")
    print(f"{type(img_embs.shape)}")
except Exception as e:
    print(f"model.get_image_features(pixel_values=pixel_values) error: {e}")
```

Output:

```
input_ids.shape=(1, 64)
pixel_values.shape=(2, 3, 384, 384)
<class 'mlx.core.array'>
<class 'tuple'>
array([[-0.580078, -0.153076, -0.0585327, ..., 0.469727, 0.0390015, 0.192871]], dtype=float16)
model.get_image_features(pixel_values=pixel_values) error: 'ModelArgs' object has no attribute 'use_return_dict'
```

@maxlund (Author) commented Apr 5, 2025

Passing return_dict=False gets past the use_return_dict error, but then:

```python
img_embs = model.get_image_features(pixel_values=pixel_values, return_dict=False)
```

```
[conv] Expect the input channels in the input and weight array to match but got shapes - input: (2,3,384,384) and weight: (1024,16,16,3)
```

@maxlund (Author) commented Apr 5, 2025

Okay, this did the trick, I think:

```python
# Match the patch-embedding weight dtype and move channels last (NCHW -> NHWC)
dtype = model.vision_model.vision_model.embeddings.patch_embedding.weight.dtype
img_embs = model.get_image_features(
    pixel_values=pixel_values.transpose(0, 2, 3, 1).astype(dtype),
    return_dict=False,
)
print(f"{type(img_embs)=}")
print(f"{img_embs.shape=}")
```

```
type(img_embs)=<class 'mlx.core.array'>
img_embs.shape=(2, 1024)
```

I might be able to get some benchmarks soon if there are no other road bumps.
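Pulling the fixes together: MLX convolutions expect channels-last (NHWC) input, while the Hugging Face processor returns channels-first (NCHW), hence the transpose; the cast matches the patch-embedding weight dtype. A hedged helper sketch based on the snippets above:

```python
import mlx.core as mx

def image_features(model, processor, images):
    """Sketch: NCHW -> NHWC transpose plus dtype cast, with return_dict=False
    to sidestep the 'use_return_dict' attribute error seen above."""
    inputs = processor(text=None, images=images, return_tensors="pt")
    pixel_values = mx.array(inputs.pixel_values)  # (N, C, H, W) from the HF processor
    dtype = model.vision_model.vision_model.embeddings.patch_embedding.weight.dtype
    return model.get_image_features(
        pixel_values=pixel_values.transpose(0, 2, 3, 1).astype(dtype),  # -> (N, H, W, C)
        return_dict=False,
    )
```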
