Unable to parse huggingface models #887

@JSeam2

Describe the bug

We are currently unable to parse a Hugging Face model exported through the optimum-cli flow, so calling gen-settings on it fails.

Expected behaviors

We should be able to parse a huggingface model directly.

optimum-cli export onnx -m sshleifer/tiny-gpt2 --optimize O1 --device cpu --opset 14 --sequence_length 64 ./tinygpt2

and then run gen-settings on the exported model:

ezkl gen-settings -M model.onnx

Steps to reproduce the bug

  1. Download and export a model from Hugging Face with optimum-cli. We use tiny-gpt2, which is small (about 500 KB) and should be usable on most devices running ezkl.
optimum-cli export onnx -m sshleifer/tiny-gpt2 --optimize O1 --device cpu --opset 14 --sequence_length 64 ./tinygpt2

  2. Run gen-settings on the exported model:
ezkl gen-settings -M model.onnx

[E] [2024-12-10 07:15:32:345, ezkl] - [graph] [tract] Undetermined symbol in expression:sequence_length
  3. Hardcoding the symbolic dimensions directly in the ONNX file led to further errors, which suggests a deeper incompatibility with tract.
    Examples of errors:
[E] [2024-12-10 07:15:48:138, ezkl] - [graph] [tract] Can not broadcast 128 against 256
[E] [2024-12-10 07:18:30:675, ezkl] - [graph] [tract] Failed analyse for node #96 "/transformer/h.0/attn/Concat_3" InferenceConcat

Script used to perform the surgery on the ONNX model:

import onnx
from onnx import helper
import numpy as np

def print_tensor_shapes(model, prefix=""):
    def dims(value_info):
        # HasField tells a fixed dim_value apart from a symbolic dim_param;
        # hasattr() is always True on protobuf messages, so it cannot.
        return [d.dim_value if d.HasField("dim_value") else d.dim_param
                for d in value_info.type.tensor_type.shape.dim]

    print(f"\n{prefix} Tensor shapes:")
    for graph_input in model.graph.input:
        print(f"Input {graph_input.name}: {dims(graph_input)}")
    for graph_output in model.graph.output:
        print(f"Output {graph_output.name}: {dims(graph_output)}")

def hardcode_sequence_lengths(model_path, past_sequence_length, sequence_length, batch_size, output_path):
    """
    Modify ONNX model to replace both past_sequence_length and sequence_length with fixed values
    
    Args:
        model_path: Path to input ONNX model
        past_sequence_length: Integer value to replace past_sequence_length
        sequence_length: Integer value to replace sequence_length
        batch_size: Integer value for batch size
        output_path: Path to save modified model
    """
    # Load the model
    model = onnx.load(model_path)
    
    # Print original shapes
    print_tensor_shapes(model, "Before modification")

    # Update input shapes
    for input in model.graph.input:
        tensor_type = input.type.tensor_type
        
        # Handle different input types
        if input.name == 'input_ids':
            tensor_type.shape.dim[0].dim_value = batch_size
            tensor_type.shape.dim[1].dim_value = sequence_length

        elif input.name == 'attention_mask':
            tensor_type.shape.dim[0].dim_value = batch_size
            tensor_type.shape.dim[1].dim_value = sequence_length + past_sequence_length

        elif input.name == 'position_ids':
            tensor_type.shape.dim[0].dim_value = batch_size
            tensor_type.shape.dim[1].dim_value = sequence_length

        elif 'past_key_values' in input.name:
            tensor_type.shape.dim[0].dim_value = batch_size
            # dim[1] is num_heads (2)
            tensor_type.shape.dim[2].dim_value = past_sequence_length
            tensor_type.shape.dim[3].dim_value = 64  # head dimension
    
    # Update output shapes
    for output in model.graph.output:
        tensor_type = output.type.tensor_type

        if output.name == 'logits':
            tensor_type.shape.dim[0].dim_value = batch_size
            tensor_type.shape.dim[1].dim_value = sequence_length
            # dim[2] is vocab_size (50257)

        elif 'present' in output.name:
            tensor_type.shape.dim[0].dim_value = batch_size
            # dim[1] is num_heads (2)
            tensor_type.shape.dim[2].dim_value = sequence_length + past_sequence_length
            tensor_type.shape.dim[3].dim_value = 64  # head dimension

    # Print modified shapes
    print_tensor_shapes(model, "After modification")
    
    # Check model validity
    onnx.checker.check_model(model)
    
    # Save the modified model
    onnx.save(model, output_path)

if __name__ == "__main__":
    sequence_length = 128      # Your desired sequence length
    past_sequence_length = 128 # Your desired past sequence length
    batch_size = 1            # Your desired batch size
    hardcode_sequence_lengths(
        "model.onnx",
        past_sequence_length,
        sequence_length,
        batch_size,
        "surgery.onnx"
    )
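To make the arithmetic in the script explicit, the sketch below computes the shapes it assigns for the values in the `__main__` block. `hardcoded_shapes` and `can_broadcast` are hypothetical helpers, and `num_heads=2`, `head_dim=64`, and the vocabulary size 50257 are taken from the script's own comments. Whether the 128 vs 256 mismatch here is actually the source of tract's `Can not broadcast 128 against 256` is only a guess.

```python
def hardcoded_shapes(batch_size, sequence_length, past_sequence_length,
                     num_heads=2, head_dim=64, vocab_size=50257):
    """Shapes the surgery script writes onto the graph inputs and outputs."""
    total = sequence_length + past_sequence_length
    return {
        "input_ids": (batch_size, sequence_length),
        "attention_mask": (batch_size, total),
        "position_ids": (batch_size, sequence_length),
        "past_key_values": (batch_size, num_heads, past_sequence_length, head_dim),
        "logits": (batch_size, sequence_length, vocab_size),
        "present": (batch_size, num_heads, total, head_dim),
    }


def can_broadcast(a, b):
    """NumPy-style rule for one axis: sizes must match or one must be 1."""
    return a == b or a == 1 or b == 1


shapes = hardcoded_shapes(batch_size=1, sequence_length=128,
                          past_sequence_length=128)
for name, shape in shapes.items():
    print(name, shape)

# Sequence axis of input_ids against the sequence axis of attention_mask.
print(can_broadcast(shapes["input_ids"][1], shapes["attention_mask"][1]))
```

The script edits only the graph inputs and outputs, so any intermediate tensor that assumed a single sequence length could now see both 128 and 256 on the same axis.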

Labels: bug
