Add transliteration with beam-search #48

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft · wants to merge 18 commits into main
2 changes: 1 addition & 1 deletion packaging/PKGBUILD.slimt-git
@@ -1,7 +1,7 @@
# Maintainer: jerinphilip<at>live<dot>in
_pkgname=slimt
pkgname=${_pkgname}-git
-pkgver=r50.0208079
+pkgver=r46.c20abc9
pkgrel=1
epoch=
pkgdesc="Inference frontend for tiny11 models"
3 changes: 2 additions & 1 deletion scripts/t12n.py
@@ -33,7 +33,8 @@ def load_config(path):
model_nano = Model(nano, package)

data = sys.stdin.read()
-responses = service.translate(model_nano, [data], html=False)
+lines = data.splitlines()
+responses = service.translate(model_nano, lines, html=False)

for response in responses:
print(response.source.text, "->", response.target.text)
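
With stdin now split into lines, the script gets one Response per input line instead of a single Response for the whole blob. For reference, the C++ type behind `response.source` / `response.target` is plausibly along these lines — a sketch inferred only from the two fields the script reads, not the actual contents of slimt/Response.hh:

```cpp
// Sketch only: inferred from the fields scripts/t12n.py touches
// (response.source.text, response.target.text). The real slimt/Response.hh
// likely carries more, e.g. alignment information.
struct Response {
  AnnotatedText source;  // input line plus sentence/token annotations
  AnnotatedText target;  // translated line plus annotations
};
```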
4 changes: 4 additions & 0 deletions slimt/CMakeLists.txt
@@ -9,11 +9,13 @@ set(SLIMT_PUBLIC_HEADERS
Model.hh
Modules.hh
Response.hh
+Search.hh
Shortlist.hh
Splitter.hh
Tensor.hh
TextProcessor.hh
Transformer.hh
+Transliterator.hh
Types.hh
Vocabulary.hh
Utils.hh)
@@ -50,12 +52,14 @@ set(SLIMT_SOURCES
Regex.cc
Request.cc
Response.cc
+Search.cc
Shortlist.cc
Splitter.cc
Tensor.cc
TensorOps.cc
TextProcessor.cc
Transformer.cc
+Transliterator.cc
Utils.cc
Vocabulary.cc
XHScanner.cc)
9 changes: 7 additions & 2 deletions slimt/Frontend.cc
@@ -18,6 +18,7 @@
#include "slimt/Model.hh"
#include "slimt/Request.hh"
#include "slimt/Response.hh"
#include "slimt/Search.hh"
#include "slimt/TextProcessor.hh"
#include "slimt/Types.hh"
#include "slimt/Utils.hh"
@@ -43,12 +44,14 @@ void exhaust(const Config &config, const Ptr<Model> &model, Batcher &batcher) {
AverageMeter<float> wps;
AverageMeter<float> occupancy;
Batch batch = batcher.generate();
+Greedy greedy(model->transformer(), model->vocabulary(),
+model->shortlist_generator());
while (!batch.empty()) {
// convert between batches.
Timer timer;
Input input = convert(batch, model->vocabulary().pad_id(),
config.tgt_length_limit_factor);
-Histories histories = model->forward(input);
+Histories histories = greedy.generate(input);
batch.complete(histories);
batch = batcher.generate();

@@ -206,7 +209,9 @@ Async::Async(const Config &config)
// convert between batches.
Input input = convert(batch, model->vocabulary().pad_id(),
config_.tgt_length_limit_factor);
-Histories histories = model->forward(input);
+Greedy greedy(model->transformer(), model->vocabulary(),
+model->shortlist_generator());
+Histories histories = greedy.generate(input);
batch.complete(histories);
auto [next_batch, next_model] = batcher_.generate();
batch = std::move(next_batch);
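
slimt/Search.hh itself does not appear in the visible diff, but the two call sites above pin down its shape: a Greedy is constructed from the model's transformer, vocabulary, and shortlist generator, and turns an Input into Histories. A sketch of the assumed header follows — only the constructor arguments and generate()'s in/out types are actually visible in this PR; the member layout is a guess:

```cpp
// Assumed shape of slimt/Search.hh, reconstructed from the Frontend.cc
// call sites. Everything beyond the constructor and generate() signature
// is a guess.
#pragma once

#include <optional>

#include "slimt/Input.hh"
#include "slimt/Shortlist.hh"
#include "slimt/Transformer.hh"
#include "slimt/Types.hh"
#include "slimt/Vocabulary.hh"

namespace slimt {

class Greedy {
 public:
  Greedy(const Transformer &transformer, const Vocabulary &vocabulary,
         const std::optional<ShortlistGenerator> &shortlist_generator);

  // Encode + greedy decode for a whole batch; replaces the removed
  // Model::forward / Model::decode pair.
  Histories generate(const Input &input);

 private:
  const Transformer &transformer_;
  const Vocabulary &vocabulary_;
  const std::optional<ShortlistGenerator> &shortlist_generator_;
};

}  // namespace slimt
```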
139 changes: 0 additions & 139 deletions slimt/Model.cc
@@ -1,20 +1,12 @@
#include "slimt/Model.hh"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "slimt/Aligned.hh"
#include "slimt/Input.hh"
#include "slimt/Io.hh"
#include "slimt/Shortlist.hh"
#include "slimt/Tensor.hh"
#include "slimt/TensorOps.hh"
#include "slimt/Transformer.hh"
#include "slimt/Types.hh"
#include "slimt/Vocabulary.hh"
@@ -72,137 +64,6 @@ Model::Model(const Config &config, const Package<std::string> &package)
shortlist_generator_(make_shortlist_generator(
view_.shortlist, vocabulary_, vocabulary_)) {}

std::optional<ShortlistGenerator> Model::make_shortlist_generator(
View view, const Vocabulary &source, const Vocabulary &target) {
if (view.data == nullptr || view.size == 0) {
return std::nullopt;
}
return ShortlistGenerator(view, source, target);
}

namespace {
void update_alignment(const std::vector<size_t> &lengths,
const std::vector<bool> &finished, const Tensor &attn,
Alignments &alignments) {
const auto *data = attn.data<float>();
// B x H x 1 (T) x S
size_t batch_size = attn.dim(-4);
size_t num_heads = attn.dim(-3);
size_t slice = attn.dim(-2);
size_t source_length = attn.dim(-1);

// https://github.com/marian-nmt/marian-dev/blob/53b0b0d7c83e71265fee0dd832ab3bcb389c6ec3/src/models/transformer.h#L214-L232
for (size_t id = 0; id < batch_size; id++) {
// Copy the elements into the particular alignment index.
size_t head_id = 0;
if (!finished[id]) {
size_t batch_stride = (num_heads * slice * source_length);
size_t head_stride = (slice * source_length);
const float *alignment = data + id * batch_stride + head_id * head_stride;
size_t length = lengths[id];
Distribution distribution(length);
std::copy(alignment, alignment + length, distribution.data());
alignments[id].push_back(std::move(distribution));
}
}
}
} // namespace

Histories Model::decode(const Tensor &encoder_out, const Input &input) const {
// Prepare a shortlist for the entire input.
size_t batch_size = encoder_out.dim(-3);
size_t source_sequence_length = encoder_out.dim(-2);

std::optional<Words> indices = std::nullopt;
if (shortlist_generator_) {
Shortlist shortlist = shortlist_generator_->generate(input.words());
indices = shortlist.words();
}
// The following can be used to check if shortlist is going wrong.
// std::vector<uint32_t> indices(vocabulary_.size());
// std::iota(indices.begin(), indices.end(), 0);

std::vector<bool> complete(batch_size, false);
uint32_t eos = vocabulary_.eos_id();
auto record = [eos, &complete](Words &step, Sentences &sentences) {
size_t finished = 0;
for (size_t i = 0; i < step.size(); i++) {
if (not complete[i]) {
complete[i] = (step[i] == eos);
sentences[i].push_back(step[i]);
}
finished += static_cast<int>(complete[i]);
}
return sentences.size() - finished;
};

// Initialize a first step.
Sentences sentences(batch_size);
Alignments alignments(sentences.size());

const Decoder &decoder = transformer_.decoder();
Words previous_slice = {};
std::vector<Tensor> states = decoder.start_states(batch_size);
auto [logits, attn] =
decoder.step(encoder_out, input.mask(), states, previous_slice, indices);

if (indices) {
previous_slice =
greedy_sample_from_words(logits, vocabulary_, *indices, batch_size);
} else {
previous_slice = greedy_sample(logits, vocabulary_, batch_size);
}

update_alignment(input.lengths(), complete, attn, alignments);
record(previous_slice, sentences);

size_t remaining = sentences.size();
size_t max_seq_length = input.limit_factor() * source_sequence_length;
for (size_t i = 1; i < max_seq_length && remaining > 0; i++) {
auto [logits, attn] = decoder.step(encoder_out, input.mask(), states,
previous_slice, indices);
if (indices) {
previous_slice =
greedy_sample_from_words(logits, vocabulary_, *indices, batch_size);
} else {
previous_slice = greedy_sample(logits, vocabulary_, batch_size);
}
update_alignment(input.lengths(), complete, attn, alignments);
remaining = record(previous_slice, sentences);
}

Histories histories;
for (size_t i = 0; i < sentences.size(); i++) {
Hypothesis hypothesis{
.target = std::move(sentences[i]), //
.alignment = std::move(alignments[i]) //
};
auto history = std::make_shared<Hypothesis>(std::move(hypothesis));
histories.push_back(std::move(history));
}

return histories;
}

Histories Model::forward(const Input &input) const {
const Tensor &indices = input.indices();
const Tensor &mask = input.mask();

// uint64_t batch_size = indices.dim(-2);
// uint64_t sequence_length = indices.dim(-1);
// uint64_t embed_dim = embedding_.dim(-1);

Tensor word_embedding =
index_select(transformer_.embedding(), indices, "word_embedding");
transform_embedding(word_embedding);

// https://github.com/browsermt/marian-dev/blob/14c9d9b0e732f42674e41ee138571d5a7bf7ad94/src/models/transformer.h#L570
// https://github.com/browsermt/marian-dev/blob/14c9d9b0e732f42674e41ee138571d5a7bf7ad94/src/models/transformer.h#L133
Tensor encoder_out = transformer_.encoder().forward(word_embedding, mask);
Histories histories = decode(encoder_out, input);
return histories;
}

namespace preset {
Model::Config tiny() {
// NOLINTBEGIN
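
Everything deleted here — the encoder pass from Model::forward and the greedy sampling loop with alignment bookkeeping from Model::decode — is presumably what Greedy::generate now carries in Search.cc, with the file-local update_alignment assumed to move along with it. A condensed sketch of that relocation, with the first decoder step folded into the loop and member names assumed to mirror the old Model fields:

```cpp
// Condensed sketch of Greedy::generate in Search.cc, assuming it is a
// relocation of the removed Model::forward/Model::decode. Member names
// (transformer_, vocabulary_, shortlist_generator_) are assumptions.
Histories Greedy::generate(const Input &input) {
  // Encode, as the removed Model::forward did.
  Tensor word_embedding =
      index_select(transformer_.embedding(), input.indices(), "word_embedding");
  transform_embedding(word_embedding);
  Tensor encoder_out =
      transformer_.encoder().forward(word_embedding, input.mask());

  // Shortlist once for the entire batch, if a generator is available.
  std::optional<Words> indices = std::nullopt;
  if (shortlist_generator_) {
    indices = shortlist_generator_->generate(input.words()).words();
  }

  size_t batch_size = encoder_out.dim(-3);
  size_t source_sequence_length = encoder_out.dim(-2);

  Sentences sentences(batch_size);
  Alignments alignments(batch_size);
  std::vector<bool> complete(batch_size, false);

  // Append each sampled token to its sentence until EOS; returns the
  // number of sentences still in flight.
  uint32_t eos = vocabulary_.eos_id();
  auto record = [eos, &complete](const Words &step, Sentences &sentences) {
    size_t finished = 0;
    for (size_t i = 0; i < step.size(); i++) {
      if (!complete[i]) {
        complete[i] = (step[i] == eos);
        sentences[i].push_back(step[i]);
      }
      finished += static_cast<size_t>(complete[i]);
    }
    return sentences.size() - finished;
  };

  const Decoder &decoder = transformer_.decoder();
  Words previous_slice = {};
  std::vector<Tensor> states = decoder.start_states(batch_size);

  size_t remaining = batch_size;
  size_t max_seq_length = input.limit_factor() * source_sequence_length;
  for (size_t i = 0; i < max_seq_length && remaining > 0; i++) {
    auto [logits, attn] = decoder.step(encoder_out, input.mask(), states,
                                       previous_slice, indices);
    previous_slice =
        indices ? greedy_sample_from_words(logits, vocabulary_, *indices,
                                           batch_size)
                : greedy_sample(logits, vocabulary_, batch_size);
    update_alignment(input.lengths(), complete, attn, alignments);
    remaining = record(previous_slice, sentences);
  }

  Histories histories;
  for (size_t i = 0; i < batch_size; i++) {
    Hypothesis hypothesis{.target = std::move(sentences[i]),
                          .alignment = std::move(alignments[i])};
    histories.push_back(std::make_shared<Hypothesis>(std::move(hypothesis)));
  }
  return histories;
}
```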
7 changes: 0 additions & 7 deletions slimt/Model.hh
@@ -53,8 +53,6 @@ class SLIMT_EXPORT Model {
explicit Model(const Config &config, const Package<std::string> &package);
explicit Model(const Config &config, const Package<View> &package);

Histories forward(const Input &input) const;

const Config &config() const { return config_; }
const Vocabulary &vocabulary() const { return vocabulary_; }
const TextProcessor &processor() const { return processor_; }
@@ -65,11 +63,6 @@
}

private:
Histories decode(const Tensor &encoder_out, const Input &input) const;

static std::optional<ShortlistGenerator> make_shortlist_generator(
View view, const Vocabulary &source, const Vocabulary &target);

size_t id_;
Config config_;
using Mmap = Package<io::MmapFile>;
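
The PR title promises beam search for transliteration, but none of the files shown here contain it; presumably it lives in the unshown Search.hh/Search.cc and Transliterator.hh/Transliterator.cc. Purely as a hypothetical sketch of where the interface could go — nothing below is taken from this diff:

```cpp
// Hypothetical only: a beam-search sibling to Greedy, suggested by the PR
// title. Neither this class nor its members appear in the visible diff;
// the name and the beam_size parameter are invented for illustration.
class BeamSearch {
 public:
  BeamSearch(const Transformer &transformer, const Vocabulary &vocabulary,
             const std::optional<ShortlistGenerator> &shortlist_generator,
             size_t beam_size);

  // Would mirror Greedy::generate, but keep `beam_size` partial hypotheses
  // per sentence at each step, ranked by accumulated log-probability,
  // instead of committing to a single argmax token.
  Histories generate(const Input &input);
};
```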