-
Notifications
You must be signed in to change notification settings - Fork 129
Open
Description
Going to preface this with a disclaimer that my understanding of ML models is pretty limited and this is my first foray into implementing a model in Bumblebee.
I've had a go at implementing the TrOCR model from the python transformers library but what I got does not seem to give me the expected results (i.e. the output is just gibberish). Wondering if there are some bumblebee experts that might be able to point me in the right direction because the more I stare at this the more my eyes are turning square.
Here is my attempt at implementing it based on existing implementations of other models in Bumblebee:
defmodule Bumblebee.Text.TrOCR do
alias Bumblebee.Shared
# Spec options for the TrOCR decoder. Keys must match the struct fields
# generated by `defstruct` below and the converter names in the
# `Transformers.Config` implementation at the bottom of this module.
options =
  [
    vocab_size: [
      default: 50265,
      doc: "vocabulary size of the TrOCR model"
    ],
    hidden_size: [
      default: 1024,
      doc: "the dimensionality of hidden layers"
    ],
    decoder_num_blocks: [
      default: 12,
      doc: "the number of Transformer blocks in the decoder"
    ],
    decoder_num_attention_heads: [
      default: 16,
      doc: "the number of attention heads for each attention layer in the decoder"
    ],
    decoder_intermediate_size: [
      default: 4096,
      # Fixed: this key was `docs:` (a typo); `Shared.options_doc/1` reads `:doc`
      doc:
        "the dimensionality of the intermediate layer in the transformer feed-forward network (FFN) in the decoder"
    ],
    activation: [
      default: :gelu,
      doc: "the activation function"
    ],
    max_positions: [
      default: 512,
      doc: """
      the vocabulary size of the position embedding. This corresponds to the maximum sequence
      length that this model can process. Typically this is set to a large value just in case,
      such as 512, 1024 or 2048
      """
    ],
    dropout_rate: [
      default: 0.1,
      doc: "the dropout rate for encoder and decoder"
    ],
    attention_dropout_rate: [
      default: 0.0,
      doc: "the dropout rate for attention weights"
    ],
    activation_dropout_rate: [
      default: 0.0,
      doc: "the dropout rate for activations inside fully connected layers"
    ],
    initializer_scale: [
      default: 0.02,
      doc:
        "the standard deviation of the normal initializer used for initializing kernel parameters"
    ],
    layerdrop: [
      default: 0.0,
      # Fixed: the markdown link target previously contained a stray "see ",
      # which broke the link
      doc: """
      the LayerDrop probability for the decoder. See the
      [LayerDrop paper](https://huggingface.co/papers/1909.11556) for more details.
      """
    ],
    use_cache: [
      default: true,
      doc:
        "whether or not the model should return the last key/values attentions (not used by all models)"
    ],
    scale_embedding: [
      default: false,
      # `token_embedding/3` multiplies by sqrt(hidden_size); the previous doc
      # said "dividing", which contradicted the implementation
      doc: "whether to scale embeddings by multiplying by the square root of `:hidden_size`"
    ],
    positions_embedding: [
      default: true,
      doc:
        "whether or not to use learned position embeddings. If not, sinusoidal position embeddings will be used."
    ]
    # Fixed: removed the trailing comma that previously followed the last
    # entry - Elixir does not permit trailing commas in list literals, so the
    # module did not compile
  ] ++ Shared.common_options([:num_labels, :id_to_label])
@moduledoc """
TrOCR model family.

## Architectures

  * `:for_causal_language_modeling` - TrOCR decoder with a language modeling
    head on top, as used for text generation

## Inputs

  * `"input_ids"` - `{batch_size, sequence_length}`

    Indices of input sequence tokens in the vocabulary.

## Global layer options

#{Shared.global_layer_options_doc([:output_hidden_states])}

## Configuration

#{Shared.options_doc(options)}
"""

# Fixed: the moduledoc previously advertised a `:base` architecture, but
# `architectures/0` only supports `:for_causal_language_modeling` and `model/1`
# only matches that architecture.

# One struct field per option above, plus the architecture tag.
defstruct [architecture: :for_causal_language_modeling] ++ Shared.option_defaults(options)

@behaviour Bumblebee.ModelSpec
@behaviour Bumblebee.Configurable
@behaviour Bumblebee.Text.Generation

import Bumblebee.Utils.Model, only: [join: 2]

alias Bumblebee.Layers
@impl true
# The only supported architecture; matches the default set in `defstruct`.
def architectures(), do: [:for_causal_language_modeling]
@impl true
# Applies user options onto the spec and validates the
# `:num_labels`/`:id_to_label` combination.
def config(spec, opts) do
spec
|> Shared.put_config_attrs(opts)
|> Shared.validate_label_options()
end
@impl true
# Minimal input used to derive parameter shapes: a single decoder token.
# `:u32` is the index type consumed by the embedding lookup.
def input_template(_spec) do
%{
"input_ids" => Nx.template({1, 1}, :u32)
}
end
@impl true
# Builds the decoder-only graph: token + position embeddings, a Transformer
# decoder (with optional cross-attention over `"encoder_hidden_state"`), and
# a language modeling head producing per-token vocabulary logits.
def model(%__MODULE__{architecture: :for_causal_language_modeling} = spec) do
# {batch_size, sequence_length} - both resolved at run time
shape = {nil, nil}
# {batch_size, sequence_length, hidden_size}
hidden_shape = Tuple.insert_at(shape, tuple_size(shape), spec.hidden_size)
# per-block, per-head mask for both self- and cross-attention
decoder_attention_head_mask_shape =
{spec.decoder_num_blocks, spec.decoder_num_attention_heads}
# All inputs are optional; either "input_ids" or "input_embeddings" must be
# provided at run time (see `embedder/5`).
inputs =
Bumblebee.Utils.Model.inputs_to_map([
Axon.input("input_ids", optional: true, shape: shape),
Axon.input("attention_mask", optional: true, shape: shape),
Axon.input("position_ids", optional: true, shape: shape),
Axon.input("attention_head_mask",
optional: true,
shape: decoder_attention_head_mask_shape
),
Axon.input("input_embeddings", optional: true, shape: hidden_shape),
# image-encoder output for cross-attention (e.g. from a ViT encoder)
Axon.input("encoder_hidden_state", optional: true, shape: hidden_shape),
Axon.input("encoder_attention_mask", optional: true, shape: shape),
Axon.input("cross_attention_head_mask",
optional: true,
shape: decoder_attention_head_mask_shape
),
# key/value cache for iterative decoding
Axon.input("cache", optional: true)
])
embeddings =
embedder(inputs["input_ids"], inputs["position_ids"], inputs["input_embeddings"], spec,
name: "embedder"
)
outputs =
decoder(
embeddings,
inputs["attention_mask"],
inputs["attention_head_mask"],
inputs["encoder_hidden_state"],
inputs["encoder_attention_mask"],
inputs["cross_attention_head_mask"],
inputs["cache"],
spec,
name: "decoder"
)
logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head")
Layers.output(%{
logits: logits,
hidden_states: outputs.hidden_states,
attentions: outputs.attentions,
cross_attentions: outputs.cross_attentions,
cache: outputs.cache
})
end
# Combines token and learned position embeddings, then applies layer norm and
# dropout. Explicit `input_embeddings` take precedence over `input_ids`.
defp embedder(input_ids, position_ids, input_embeddings, spec, opts) do
name = opts[:name]
# fall back to embedding `input_ids` unless embeddings were passed directly
input_embeddings =
Layers.default input_embeddings do
token_embedding(input_ids, spec, name: join(name, "token_embedding"))
end
# default to sequential position ids derived from the embeddings' shape
position_ids =
Layers.default position_ids do
Layers.default_position_ids(input_embeddings)
end
position_embeddings =
position_embedding(position_ids, spec, name: join(name, "position_embedding"))
Axon.add([input_embeddings, position_embeddings])
|> Axon.layer_norm(epsilon: 1.0e-5, name: join(name, "norm"))
|> Axon.dropout(rate: spec.dropout_rate)
end
# Embeds token ids into `spec.hidden_size`-dimensional vectors. When
# `:scale_embedding` is set, the embeddings are additionally multiplied by
# sqrt(hidden_size).
defp token_embedding(input_ids, spec, opts) do
  embeddings =
    Axon.embedding(input_ids, spec.vocab_size, spec.hidden_size,
      kernel_initializer: kernel_initializer(spec),
      name: opts[:name]
    )

  if spec.scale_embedding do
    Axon.nx(embeddings, &Nx.multiply(&1, Nx.sqrt(spec.hidden_size)))
  else
    embeddings
  end
end
# Learned position embedding with a fixed offset of 2: ids are shifted before
# lookup and the embedding table is enlarged to `max_positions + offset`.
# NOTE(review): the `:positions_embedding` option (sinusoidal embeddings when
# false) is never consulted - only the learned variant is implemented. Confirm
# whether the checkpoint being loaded uses learned position embeddings.
defp position_embedding(position_ids, spec, opts) do
name = opts[:name]
# For TrOCR we need to offset the embeddings
offset = 2
position_ids
|> Axon.add(Axon.constant(Nx.tensor(offset)))
|> Axon.embedding(spec.max_positions + offset, spec.hidden_size, name: name)
end
# Runs the stack of causal Transformer decoder blocks. Cross-attention layers
# are active whenever `encoder_hidden_state` is provided.
defp decoder(
hidden_state,
attention_mask,
attention_head_mask,
encoder_hidden_state,
encoder_attention_mask,
cross_attention_head_mask,
cache,
spec,
opts
) do
name = opts[:name]
Layers.Transformer.blocks(hidden_state,
attention_mask: attention_mask,
attention_head_mask: attention_head_mask,
cross_hidden_state: encoder_hidden_state,
cross_attention_mask: encoder_attention_mask,
cross_attention_head_mask: cross_attention_head_mask,
cache: cache,
# autoregressive: each position attends only to earlier positions
causal: true,
num_blocks: spec.decoder_num_blocks,
num_attention_heads: spec.decoder_num_attention_heads,
hidden_size: spec.hidden_size,
kernel_initializer: kernel_initializer(spec),
dropout_rate: spec.dropout_rate,
attention_dropout_rate: spec.attention_dropout_rate,
layer_norm: [
epsilon: 1.0e-5
],
ffn: [
intermediate_size: spec.decoder_intermediate_size,
activation: spec.activation
],
name: join(name, "blocks")
)
end
# Projects decoder hidden states to vocabulary logits using a transposed
# dense layer (embedding-shaped kernel).
defp language_modeling_head(hidden_state, spec, opts) do
name = opts[:name]
# TODO: Tie lm-head to word embedding as a spec option
Layers.dense_transposed(hidden_state, spec.vocab_size,
kernel_initializer: kernel_initializer(spec),
name: join(name, "output")
)
end
# Normal(0, initializer_scale) initializer shared by all kernels in this model.
defp kernel_initializer(spec), do: Axon.Initializers.normal(scale: spec.initializer_scale)
@impl true
# Allocates the key/value cache for iterative decoding. The cross-attention
# cache uses the decoder's head count, since TrOCR's cross-attention layers
# live in the decoder.
def init_cache(spec, batch_size, max_length, inputs) do
# nil when no encoder output is given (pure causal LM usage)
encoder_sequence_length =
if encoder_hidden_state = inputs["encoder_hidden_state"] do
Nx.axis_size(encoder_hidden_state, 1)
end
Layers.Decoder.init_cache(batch_size, max_length,
hidden_size: spec.hidden_size,
decoder_num_attention_heads: spec.decoder_num_attention_heads,
encoder_num_attention_heads: spec.decoder_num_attention_heads,
decoder_num_blocks: spec.decoder_num_blocks,
encoder_sequence_length: encoder_sequence_length
)
end
@impl true
# Delegates cache traversal to the shared decoder helpers; no spec-specific
# cache layout is involved.
def traverse_cache(_spec, cache, fun) do
Layers.Decoder.traverse_cache(cache, fun)
end
defimpl Bumblebee.HuggingFace.Transformers.Config do
  # Translates a huggingface/transformers TrOCR configuration map into spec
  # options. Each tuple maps a Python attribute name to a value converter.
  def load(spec, data) do
    import Shared.Converters

    opts =
      convert!(data,
        vocab_size: {"vocab_size", number()},
        hidden_size: {"d_model", number()},
        decoder_num_blocks: {"decoder_layers", number()},
        decoder_num_attention_heads: {"decoder_attention_heads", number()},
        decoder_intermediate_size: {"decoder_ffn_dim", number()},
        activation: {"activation_function", activation()},
        max_positions: {"max_position_embeddings", number()},
        dropout_rate: {"dropout", number()},
        attention_dropout_rate: {"attention_dropout", number()},
        activation_dropout_rate: {"activation_dropout", number()},
        initializer_scale: {"init_std", number()},
        layerdrop: {"decoder_layerdrop", number()},
        use_cache: {"use_cache", boolean()},
        scale_embedding: {"scale_embedding", boolean()},
        # Fixed: this key was previously `position_embeddings`, which is not
        # a declared spec option (the struct field is `positions_embedding`),
        # so applying the loaded options would fail
        positions_embedding: {"use_learned_position_embeddings", boolean()}
      ) ++ Shared.common_options_from_transformers(data, spec)

    @for.config(spec, opts)
  end
end
defimpl Bumblebee.HuggingFace.Transformers.Model do
  # Maps Axon layer names in this model to parameter names in the upstream
  # PyTorch checkpoints, so pretrained weights can be loaded.
  def params_mapping(_spec) do
    %{
      "embedder.token_embedding" => "model.decoder.embed_tokens",
      "embedder.position_embedding" => "model.decoder.embed_positions",
      "embedder.norm" => "model.decoder.layernorm_embedding",
      "decoder.blocks.{n}.self_attention.query" => "model.decoder.layers.{n}.self_attn.q_proj",
      "decoder.blocks.{n}.self_attention.key" => "model.decoder.layers.{n}.self_attn.k_proj",
      "decoder.blocks.{n}.self_attention.value" => "model.decoder.layers.{n}.self_attn.v_proj",
      "decoder.blocks.{n}.self_attention.output" =>
        "model.decoder.layers.{n}.self_attn.out_proj",
      "decoder.blocks.{n}.self_attention_norm" =>
        "model.decoder.layers.{n}.self_attn_layer_norm",
      "decoder.blocks.{n}.cross_attention.query" =>
        "model.decoder.layers.{n}.encoder_attn.q_proj",
      "decoder.blocks.{n}.cross_attention.key" =>
        "model.decoder.layers.{n}.encoder_attn.k_proj",
      "decoder.blocks.{n}.cross_attention.value" =>
        "model.decoder.layers.{n}.encoder_attn.v_proj",
      "decoder.blocks.{n}.cross_attention.output" =>
        "model.decoder.layers.{n}.encoder_attn.out_proj",
      "decoder.blocks.{n}.cross_attention_norm" =>
        "model.decoder.layers.{n}.encoder_attn_layer_norm",
      "decoder.blocks.{n}.ffn.intermediate" => "model.decoder.layers.{n}.fc1",
      "decoder.blocks.{n}.ffn.output" => "model.decoder.layers.{n}.fc2",
      "decoder.blocks.{n}.output_norm" => "model.decoder.layers.{n}.final_layer_norm",
      # NOTE(review): huggingface `TrOCRForCausalLM` names its projection
      # `output_projection`, not `lm_head` - verify against an actual
      # checkpoint. An unmatched name leaves the head randomly initialized,
      # which would explain gibberish output.
      "language_modeling_head.output" => "lm_head",
      # NOTE(review): `final_logits_bias` is a BART parameter; confirm it
      # exists in TrOCR checkpoints and that the model graph actually has a
      # layer named "language_modeling_head.logits_bias".
      # Fixed: removed the trailing comma that previously followed this entry -
      # Elixir does not allow trailing commas in map literals, so the original
      # did not compile.
      "language_modeling_head.logits_bias" => %{
        "bias" => {[{"model", "final_logits_bias"}], fn [value] -> Nx.squeeze(value) end}
      }
    }
  end
end
end

Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels