Skip to content

TrOCR model implementation attempt #450

@TomGrozev

Description

@TomGrozev

I'll preface this with a disclaimer: my understanding of ML models is pretty limited, and this is my first foray into implementing a model in Bumblebee.

I've had a go at implementing the TrOCR model from the Python transformers library, but my implementation doesn't give the expected results (the output is just gibberish). I'm wondering whether any Bumblebee experts could point me in the right direction, because the more I stare at this, the more my eyes are turning square.

Here is my attempt at implementing it based on existing implementations of other models in Bumblebee:

defmodule Bumblebee.Text.TrOCR do
  alias Bumblebee.Shared

  options =
    [
      vocab_size: [
        default: 50265,
        doc: "vocabulary size of the TrOCR model"
      ],
      hidden_size: [
        default: 1024,
        doc: "the dimensionality of hidden layers"
      ],
      cross_attention_hidden_size: [
        default: nil,
        doc: """
        the dimensionality of the encoder hidden state consumed by the cross-attention layers.
        When `nil`, `:hidden_size` is used. Set this when the (vision) encoder has a different
        hidden size than the decoder
        """
      ],
      decoder_num_blocks: [
        default: 12,
        doc: "the number of Transformer blocks in the decoder"
      ],
      decoder_num_attention_heads: [
        default: 16,
        doc: "the number of attention heads for each attention layer in the decoder"
      ],
      decoder_intermediate_size: [
        default: 4096,
        doc:
          "the dimensionality of the intermediate layer in the transformer feed-forward network (FFN) in the decoder"
      ],
      activation: [
        default: :gelu,
        doc: "the activation function"
      ],
      max_positions: [
        default: 512,
        doc: """
        the vocabulary size of the position embedding. This corresponds to the maximum sequence
        length that this model can process. Typically this is set to a large value just in case,
        such as 512, 1024 or 2048
        """
      ],
      dropout_rate: [
        default: 0.1,
        doc: "the dropout rate for encoder and decoder"
      ],
      attention_dropout_rate: [
        default: 0.0,
        doc: "the dropout rate for attention weights"
      ],
      activation_dropout_rate: [
        default: 0.0,
        doc: "the dropout rate for activations inside fully connected layers"
      ],
      initializer_scale: [
        default: 0.02,
        doc:
          "the standard deviation of the normal initializer used for initializing kernel parameters"
      ],
      layerdrop: [
        default: 0.0,
        doc: """
        the LayerDrop probability for the decoder. See the
        [LayerDrop paper](https://huggingface.co/papers/1909.11556) for more details
        """
      ],
      use_cache: [
        default: true,
        doc:
          "whether or not the model should return the last key/values attentions (not used by all models)"
      ],
      scale_embedding: [
        default: false,
        doc: "whether to scale token embeddings by multiplying them by the square root of `:hidden_size`"
      ],
      positions_embedding: [
        default: true,
        doc:
          "whether or not to use learned position embeddings. If not, sinusoidal position embeddings will be used"
      ],
      embedding_layer_norm: [
        default: true,
        doc: "whether to apply layer normalization to the summed token and position embeddings"
      ]
    ] ++ Shared.common_options([:num_labels, :id_to_label])

  @moduledoc """
  TrOCR model family.

  ## Architectures

    * `:for_causal_language_modeling` - TrOCR decoder with a language
      modeling head. The head returns logits for each token in the
      original sequence

  ## Inputs

    * `"input_ids"` - `{batch_size, sequence_length}`

      Indices of input sequence tokens in the vocabulary.

    * `"attention_mask"` - `{batch_size, sequence_length}`

      Mask indicating which tokens to attend to.

    * `"position_ids"` - `{batch_size, sequence_length}`

      Indices of positions of each input sequence token in the position
      embeddings.

    * `"attention_head_mask"` - `{decoder_num_blocks, decoder_num_attention_heads}`

      Mask to nullify selected heads of the self-attention blocks.

    * `"input_embeddings"` - `{batch_size, sequence_length, hidden_size}`

      Embedded representation of `"input_ids"`, which can be specified
      for more control over how `"input_ids"` are embedded than the
      model's internal embedding lookup.

    * `"encoder_hidden_state"` - `{batch_size, encoder_sequence_length, cross_attention_hidden_size}`

      Last hidden state output from the (image) encoder, attended to by
      the cross-attention blocks.

    * `"encoder_attention_mask"` - `{batch_size, encoder_sequence_length}`

      Mask indicating which encoder positions to attend to.

    * `"cross_attention_head_mask"` - `{decoder_num_blocks, decoder_num_attention_heads}`

      Mask to nullify selected heads of the cross-attention blocks.

    * `"cache"`

      A container with cached layer results used to speed up sequential
      decoding (autoregression).

  ## Global layer options

  #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])}

  ## Configuration

  #{Shared.options_doc(options)}
  """

  defstruct [architecture: :for_causal_language_modeling] ++ Shared.option_defaults(options)

  @behaviour Bumblebee.ModelSpec
  @behaviour Bumblebee.Configurable
  @behaviour Bumblebee.Text.Generation

  import Bumblebee.Utils.Model, only: [join: 2]

  alias Bumblebee.Layers

  @impl true
  def architectures(), do: [:for_causal_language_modeling]

  @impl true
  def config(spec, opts) do
    spec
    |> Shared.put_config_attrs(opts)
    |> Shared.validate_label_options()
  end

  @impl true
  def input_template(_spec) do
    %{
      "input_ids" => Nx.template({1, 1}, :u32)
    }
  end

  @impl true
  def model(%__MODULE__{architecture: :for_causal_language_modeling} = spec) do
    shape = {nil, nil}
    hidden_shape = {nil, nil, spec.hidden_size}

    # The encoder may be wider than the decoder (e.g. a ViT encoder feeding a
    # narrower text decoder), so the cross-attention input is declared with the
    # encoder width rather than `:hidden_size`.
    encoder_hidden_shape = {nil, nil, cross_attention_hidden_size(spec)}

    decoder_attention_head_mask_shape =
      {spec.decoder_num_blocks, spec.decoder_num_attention_heads}

    inputs =
      Bumblebee.Utils.Model.inputs_to_map([
        Axon.input("input_ids", optional: true, shape: shape),
        Axon.input("attention_mask", optional: true, shape: shape),
        Axon.input("position_ids", optional: true, shape: shape),
        Axon.input("attention_head_mask",
          optional: true,
          shape: decoder_attention_head_mask_shape
        ),
        Axon.input("input_embeddings", optional: true, shape: hidden_shape),
        Axon.input("encoder_hidden_state", optional: true, shape: encoder_hidden_shape),
        Axon.input("encoder_attention_mask", optional: true, shape: shape),
        Axon.input("cross_attention_head_mask",
          optional: true,
          shape: decoder_attention_head_mask_shape
        ),
        Axon.input("cache", optional: true)
      ])

    embeddings =
      embedder(inputs["input_ids"], inputs["position_ids"], inputs["input_embeddings"], spec,
        name: "embedder"
      )

    outputs =
      decoder(
        embeddings,
        inputs["attention_mask"],
        inputs["attention_head_mask"],
        inputs["encoder_hidden_state"],
        inputs["encoder_attention_mask"],
        inputs["cross_attention_head_mask"],
        inputs["cache"],
        spec,
        name: "decoder"
      )

    logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head")

    Layers.output(%{
      logits: logits,
      hidden_states: outputs.hidden_states,
      attentions: outputs.attentions,
      cross_attentions: outputs.cross_attentions,
      cache: outputs.cache
    })
  end

  # Width of the encoder hidden state; falls back to the decoder width when unset.
  defp cross_attention_hidden_size(spec) do
    spec.cross_attention_hidden_size || spec.hidden_size
  end

  # Token embeddings + position embeddings, optionally layer-normalized, then dropout.
  defp embedder(input_ids, position_ids, input_embeddings, spec, opts) do
    name = opts[:name]

    input_embeddings =
      Layers.default input_embeddings do
        token_embedding(input_ids, spec, name: join(name, "token_embedding"))
      end

    position_ids =
      Layers.default position_ids do
        Layers.default_position_ids(input_embeddings)
      end

    position_embeddings =
      position_embedding(position_ids, spec, name: join(name, "position_embedding"))

    embeddings = Axon.add([input_embeddings, position_embeddings])

    # Mirrors the optional `layernorm_embedding` module of the reference model
    embeddings =
      if spec.embedding_layer_norm do
        Axon.layer_norm(embeddings, epsilon: 1.0e-5, name: join(name, "norm"))
      else
        embeddings
      end

    Axon.dropout(embeddings, rate: spec.dropout_rate)
  end

  # Looks up token embeddings; scaled up by sqrt(hidden_size) when `:scale_embedding` is set.
  defp token_embedding(input_ids, spec, opts) do
    name = opts[:name]

    input_embeddings =
      Axon.embedding(input_ids, spec.vocab_size, spec.hidden_size,
        kernel_initializer: kernel_initializer(spec),
        name: name
      )

    if spec.scale_embedding do
      Axon.nx(input_embeddings, fn x -> Nx.multiply(x, Nx.sqrt(spec.hidden_size)) end)
    else
      input_embeddings
    end
  end

  # Position embeddings, either learned (a parameterized lookup table) or fixed
  # sinusoidal, selected by `:positions_embedding`.
  defp position_embedding(position_ids, spec, opts) do
    name = opts[:name]

    # TrOCR reserves the first two rows of its position table, so positions are
    # shifted by two before lookup. For the sinusoidal variant this matches a
    # padding index of 1; padding tokens are not special-cased here, so results
    # differ from the reference only for padded positions.
    offset = 2

    if spec.positions_embedding do
      position_ids
      |> Axon.add(Axon.constant(Nx.tensor(offset)))
      |> Axon.embedding(spec.max_positions + offset, spec.hidden_size, name: name)
    else
      Axon.nx(position_ids, fn position_ids ->
        position_ids
        |> Nx.add(offset)
        |> sinusoidal_embeddings(spec.hidden_size)
      end)
    end
  end

  # Fixed sinusoidal embeddings for a `{batch, seq}` tensor of positions.
  # For position p, the first half of the vector is sin(p * f_i) and the second
  # half cos(p * f_i), with f_i = exp(-i * ln(10000) / (half_dim - 1)). An odd
  # `hidden_size` gets a trailing zero column.
  defp sinusoidal_embeddings(positions, hidden_size) do
    half_dim = div(hidden_size, 2)
    scale = :math.log(10_000) / (half_dim - 1)

    frequencies = Nx.exp(Nx.multiply(Nx.iota({half_dim}, type: :f32), -scale))

    angles =
      positions
      |> Nx.as_type(:f32)
      |> Nx.new_axis(2)
      |> Nx.multiply(frequencies)

    embeddings = Nx.concatenate([Nx.sin(angles), Nx.cos(angles)], axis: 2)

    if rem(hidden_size, 2) == 1 do
      Nx.pad(embeddings, 0.0, [{0, 0, 0}, {0, 0, 0}, {0, 1, 0}])
    else
      embeddings
    end
  end

  defp decoder(
         hidden_state,
         attention_mask,
         attention_head_mask,
         encoder_hidden_state,
         encoder_attention_mask,
         cross_attention_head_mask,
         cache,
         spec,
         opts
       ) do
    name = opts[:name]

    Layers.Transformer.blocks(hidden_state,
      attention_mask: attention_mask,
      attention_head_mask: attention_head_mask,
      cross_hidden_state: encoder_hidden_state,
      cross_attention_mask: encoder_attention_mask,
      cross_attention_head_mask: cross_attention_head_mask,
      cache: cache,
      causal: true,
      num_blocks: spec.decoder_num_blocks,
      num_attention_heads: spec.decoder_num_attention_heads,
      hidden_size: spec.hidden_size,
      kernel_initializer: kernel_initializer(spec),
      dropout_rate: spec.dropout_rate,
      attention_dropout_rate: spec.attention_dropout_rate,
      layer_norm: [
        epsilon: 1.0e-5
      ],
      ffn: [
        intermediate_size: spec.decoder_intermediate_size,
        activation: spec.activation
      ],
      name: join(name, "blocks")
    )
  end

  # Projects hidden states to vocabulary logits. The kernel is stored as
  # `{vocab_size, hidden_size}`, the same layout as the reference
  # `output_projection` weight. There is no bias.
  defp language_modeling_head(hidden_state, spec, opts) do
    name = opts[:name]

    # TODO: Tie lm-head to word embedding as a spec option
    Layers.dense_transposed(hidden_state, spec.vocab_size,
      kernel_initializer: kernel_initializer(spec),
      name: join(name, "output")
    )
  end

  defp kernel_initializer(spec) do
    Axon.Initializers.normal(scale: spec.initializer_scale)
  end

  @impl true
  def init_cache(spec, batch_size, max_length, inputs) do
    encoder_sequence_length =
      if encoder_hidden_state = inputs["encoder_hidden_state"] do
        Nx.axis_size(encoder_hidden_state, 1)
      end

    Layers.Decoder.init_cache(batch_size, max_length,
      hidden_size: spec.hidden_size,
      decoder_num_attention_heads: spec.decoder_num_attention_heads,
      encoder_num_attention_heads: spec.decoder_num_attention_heads,
      decoder_num_blocks: spec.decoder_num_blocks,
      encoder_sequence_length: encoder_sequence_length
    )
  end

  @impl true
  def traverse_cache(_spec, cache, fun) do
    Layers.Decoder.traverse_cache(cache, fun)
  end

  defimpl Bumblebee.HuggingFace.Transformers.Config do
    def load(spec, data) do
      import Shared.Converters

      opts =
        convert!(data,
          vocab_size: {"vocab_size", number()},
          hidden_size: {"d_model", number()},
          cross_attention_hidden_size: {"cross_attention_hidden_size", optional(number())},
          decoder_num_blocks: {"decoder_layers", number()},
          decoder_num_attention_heads: {"decoder_attention_heads", number()},
          decoder_intermediate_size: {"decoder_ffn_dim", number()},
          activation: {"activation_function", activation()},
          max_positions: {"max_position_embeddings", number()},
          dropout_rate: {"dropout", number()},
          attention_dropout_rate: {"attention_dropout", number()},
          activation_dropout_rate: {"activation_dropout", number()},
          initializer_scale: {"init_std", number()},
          layerdrop: {"decoder_layerdrop", number()},
          use_cache: {"use_cache", boolean()},
          scale_embedding: {"scale_embedding", boolean()},
          # Must match the struct field name declared in the options above
          positions_embedding: {"use_learned_position_embeddings", boolean()},
          embedding_layer_norm: {"layernorm_embedding", boolean()}
        ) ++ Shared.common_options_from_transformers(data, spec)

      @for.config(spec, opts)
    end
  end

  defimpl Bumblebee.HuggingFace.Transformers.Model do
    # Parameter names follow the `TrOCRForCausalLM` module tree. Checkpoints
    # saved from a `VisionEncoderDecoderModel` nest these under an extra
    # `decoder.` prefix; verify against the checkpoint being loaded.
    def params_mapping(_spec) do
      %{
        "embedder.token_embedding" => "model.decoder.embed_tokens",
        "embedder.position_embedding" => "model.decoder.embed_positions",
        "embedder.norm" => "model.decoder.layernorm_embedding",
        "decoder.blocks.{n}.self_attention.query" => "model.decoder.layers.{n}.self_attn.q_proj",
        "decoder.blocks.{n}.self_attention.key" => "model.decoder.layers.{n}.self_attn.k_proj",
        "decoder.blocks.{n}.self_attention.value" => "model.decoder.layers.{n}.self_attn.v_proj",
        "decoder.blocks.{n}.self_attention.output" =>
          "model.decoder.layers.{n}.self_attn.out_proj",
        "decoder.blocks.{n}.self_attention_norm" =>
          "model.decoder.layers.{n}.self_attn_layer_norm",
        "decoder.blocks.{n}.cross_attention.query" =>
          "model.decoder.layers.{n}.encoder_attn.q_proj",
        "decoder.blocks.{n}.cross_attention.key" =>
          "model.decoder.layers.{n}.encoder_attn.k_proj",
        "decoder.blocks.{n}.cross_attention.value" =>
          "model.decoder.layers.{n}.encoder_attn.v_proj",
        "decoder.blocks.{n}.cross_attention.output" =>
          "model.decoder.layers.{n}.encoder_attn.out_proj",
        "decoder.blocks.{n}.cross_attention_norm" =>
          "model.decoder.layers.{n}.encoder_attn_layer_norm",
        "decoder.blocks.{n}.ffn.intermediate" => "model.decoder.layers.{n}.fc1",
        "decoder.blocks.{n}.ffn.output" => "model.decoder.layers.{n}.fc2",
        "decoder.blocks.{n}.output_norm" => "model.decoder.layers.{n}.final_layer_norm",
        # TrOCR's LM head is `output_projection` (not `lm_head`) and has no
        # `final_logits_bias`, unlike BART
        "language_modeling_head.output" => "output_projection"
      }
    end
  end
end

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions