A transformer neural network for a gesture keyboard that transduces curves swiped across a keyboard into word candidates
Contribution:
- A new method for constructing swipe point embeddings (SPE) that outperforms existing ones. It leverages a weighted sum of all keyboard key embeddings, resulting in a notable performance boost: a 0.67% increase in Swipe MRR and 0.73% in accuracy compared to the SPE construction methods described in the literature
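As a rough illustration of the idea (not the exact implementation), each swipe point embedding can be formed as a distance-based weighted sum of key embeddings; the names `key_centers` and `tau` and the softmax weighting are assumptions here:

```python
import torch

# Illustrative sketch only; names and the softmax weighting are assumptions.
# key_embeddings: (num_keys, d_model) learnable embedding per keyboard key
# key_centers:    (num_keys, 2)       (x, y) centers of the keys
# points:         (seq_len, 2)        swipe point coordinates
def weighted_sum_point_embeddings(points, key_centers, key_embeddings, tau=100.0):
    dists = torch.cdist(points.float(), key_centers.float())  # (seq_len, num_keys)
    weights = torch.softmax(-dists / tau, dim=-1)              # closer keys get larger weights
    return weights @ key_embeddings                            # (seq_len, d_model)
```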
Other highlights:
- Enhanced Inference with Custom Beam Search: a modified beam search is implemented that masks out logits corresponding to impossible (according to the dictionary) token continuations given an already generated prefix. It is both faster and more accurate than a standard beam search
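A hedged sketch of the masking idea (the prefix index and function names are illustrative, not the repository's API):

```python
import torch

def build_prefix_index(vocab: list[str], encode) -> dict[tuple[int, ...], set[int]]:
    """Map every token-id prefix of a dictionary word to the token ids that may follow it.
    `vocab` and `encode` (a tokenizer's string-to-ids function) are hypothetical inputs."""
    allowed: dict[tuple[int, ...], set[int]] = {}
    for word in vocab:
        ids = encode(word)
        for i in range(len(ids)):
            allowed.setdefault(tuple(ids[:i]), set()).add(ids[i])
    return allowed

def mask_impossible_continuations(logits: torch.Tensor, prefix: list[int],
                                  allowed: dict[tuple[int, ...], set[int]]) -> torch.Tensor:
    """Set logits of tokens that cannot continue `prefix` (per the dictionary) to -inf,
    so beam search never expands a hypothesis outside the vocabulary."""
    keep = list(allowed.get(tuple(prefix), set()))
    masked = torch.full_like(logits, float("-inf"))
    masked[keep] = logits[keep]
    return masked
```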
This repository used to contain my Yandex Cup 2023 solution (7th place), but after many improvements, it has become a standalone project
Try out a live demo with a trained model from the competition through this web app
Note
If the website is not available, you can run the demo yourself by following the instructions in the web app's GitHub repository.
Note
The website may take a minute to load, as it is not yet fully optimized. If you encounter a "Something went wrong" page, try refreshing the page. This usually resolves the issue.
Note
It is not guaranteed that the model used in the demo is up-to-date with the latest improvements in this repository.
There is an Android library that aims to help integrate models from this repository into Android keyboards
The library expects the model to be exported via ExecuTorch (as in the executorch_investigation branch of this repository)
Access a brief research report here, which includes:
- Overview of existing research
- Description of the developed method for constructing swipe point embeddings
- Comparative analysis and results
For in-depth insights, you can refer to my master's thesis (in Russian)
Install the dependencies:
pip install -r requirements/requirements.txt

The code has been tested with python 3.10, 3.11 and 3.12
To acquire and prepare the Yandex Cup dataset, follow the steps below:
cd src
bash ./data_obtaining_and_preprocessing/obtain_and_prepare_data.sh

Note
The pipeline takes approximately 6 hours to complete on the tested machine.
If you prefer to skip the lengthy preprocessing steps, you can directly download the preprocessed dataset:
cd src
python ./data_obtaining_and_preprocessing/download_dataset_preprocessed.py

Transducing swipes to a list of words involves multiple components:
- SwipeFeatureExtractor instance
- neural network architecture
  - swipe point embedder
  - subword embedder
  - encoder
  - decoder
- model weights
- decoding algorithm
A SwipeFeatureExtractor is any callable that takes three integer 1d tensors (x, y, t) representing raw coordinates and time in milliseconds and returns a list of tensors that are inputs of a certain SwipePointEmbedder.
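A hedged sketch of this interface as a typing `Protocol` (the repository's actual definition may differ in details):

```python
from typing import Protocol

import torch

class SwipeFeatureExtractor(Protocol):
    """Paraphrase of the interface described above; the repository's definition may differ."""
    def __call__(
        self,
        x: torch.Tensor,  # 1-d integer tensor of x coordinates
        y: torch.Tensor,  # 1-d integer tensor of y coordinates
        t: torch.Tensor,  # 1-d integer tensor of times in milliseconds
    ) -> list[torch.Tensor]:
        ...
```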
Current implementations of this protocol:
- `TrajectoryFeatureExtractor`: extracts trajectory features such as x, y, dt and coordinate derivatives.
- `CoordinateFunctionFeatureExtractor`: an adapter that makes callables accepting `torch.stack((x, y))` satisfy the `SwipeFeatureExtractor` interface. Examples of these coordinate feature extractors:
  - `DistanceGetter`: for each swipe point, returns the distances to the key centers
  - `NearestKeyGetter`: for each swipe point, returns the id of the nearest key center (see the sketch below)
  - `KeyWeightsGetter`: for each swipe point, returns the weights (importance) of the keys by applying a function to the `DistanceGetter` output
- `MultiFeatureExtractor`: combines multiple feature extractors into one.
SwipeFeatureExtractors are used as a dataset transformation step to extract relevant features from the raw swipe data before feeding it into the model.
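For illustration, a minimal re-implementation of what `NearestKeyGetter` is described to do (tensor names and shapes are assumptions):

```python
import torch

def nearest_key_ids(points: torch.Tensor, key_centers: torch.Tensor) -> torch.Tensor:
    """Illustrative re-implementation of the described NearestKeyGetter behavior.
    points: (seq_len, 2) swipe coordinates; key_centers: (num_keys, 2) key centers."""
    dists = torch.cdist(points.float(), key_centers.float())  # (seq_len, num_keys)
    return dists.argmin(dim=-1)                                # id of the nearest key per point
```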
After collating the dataset, each batch has the form `(packed_model_in, dec_out)`, where `packed_model_in` is `(encoder_input, decoder_input, swipe_pad_mask, word_pad_mask)`. `packed_model_in` is passed to the model via unpacking (`model(*packed_model_in)`).

- `encoder_input` is a list of tensors (padded features from a `SwipeFeatureExtractor`)
- `decoder_input` and `dec_out` are `tokenized_target_word[:-1]` and `tokenized_target_word[1:]` respectively (the target shifted by one position for teacher forcing)
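A minimal sketch of how such a batch might be consumed in a training step (the pad token id and loss setup are assumptions, not the repository's exact code):

```python
from torch import nn

def training_step(model: nn.Module, batch, loss_fn=nn.CrossEntropyLoss(ignore_index=0)):
    """Sketch of consuming one collated batch; the pad token id (0) is an assumption."""
    packed_model_in, dec_out = batch
    # packed_model_in == (encoder_input, decoder_input, swipe_pad_mask, word_pad_mask)
    logits = model(*packed_model_in)                 # (batch, word_len, vocab_size)
    return loss_fn(logits.transpose(1, 2), dec_out)  # cross-entropy over token positions
```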
All current models are instances of `model.EncoderDecoderTransformerLike` and consist of the following components (a sketch of how they might compose follows the list):
- Swipe point embedder
- Subword token embedder (currently char-level)
- Encoder
- Decoder
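An illustrative composition of these four components (names, masks and the output projection are assumptions; the repository's actual class may differ):

```python
from torch import nn

class EncoderDecoderSketch(nn.Module):
    """Illustrative composition of the four components above; not the repository's exact class."""
    def __init__(self, swipe_point_embedder, token_embedder, encoder, decoder, d_model, vocab_size):
        super().__init__()
        self.swipe_point_embedder = swipe_point_embedder  # turns swipe features into vectors
        self.token_embedder = token_embedder              # char-level token embeddings
        self.encoder = encoder                            # e.g. nn.TransformerEncoder
        self.decoder = decoder                            # e.g. nn.TransformerDecoder
        self.out_proj = nn.Linear(d_model, vocab_size)    # hidden states -> token logits

    def forward(self, encoder_input, decoder_input, swipe_pad_mask, word_pad_mask):
        swipe_emb = self.swipe_point_embedder(encoder_input)   # (batch, swipe_len, d_model)
        memory = self.encoder(swipe_emb, src_key_padding_mask=swipe_pad_mask)
        token_emb = self.token_embedder(decoder_input)          # (batch, word_len, d_model)
        # A causal target mask would also be passed in practice.
        hidden = self.decoder(token_emb, memory,
                              tgt_key_padding_mask=word_pad_mask,
                              memory_key_padding_mask=swipe_pad_mask)
        return self.out_proj(hidden)                            # (batch, word_len, vocab_size)
```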
A `WordGenerator` receives the encoded swipe features for a swipe and outputs a sorted list of scored word candidates (a list of `(word: str, score: float)` tuples).

A `WordGenerator` stores:

- A model (`EncoderDecoderTransformerLike`) that processes the encoded swipe features
- A `subword_tokenizer` (`CharLevelTokenizerv2`) that converts characters to tokens and vice versa
- A logit processor (`LogitProcessor`) that manipulates the model's output logits. Currently `VocabularyLogitProcessor` is used to apply vocabulary-based masking, making it impossible for the model to generate tokens outside the vocabulary
- Hyperparameters specific to a particular word generator

Currently, word generators accept non-batched swipe features (they process one swipe at a time), as in the sketch below.
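A minimal usage sketch under those assumptions (`word_generator` and `swipe_features` stand in for the objects described above):

```python
def top_k_candidates(word_generator, swipe_features, k: int = 4) -> list[tuple[str, float]]:
    """Take the k best suggestions for a single (non-batched) swipe.
    `word_generator` and `swipe_features` are placeholders for the objects described above."""
    candidates = word_generator(swipe_features)  # sorted list of (word, score) tuples
    return candidates[:k]
```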
The Dataset class expects a jsonl file with the following structure:
[
{
"word":"на",
"curve":{
"x":[567,567,507, ...],
"y":[66,66,101, ...],
"t":[0,3,24, ...],
"grid_name":"your_keyboard_layout_name"}
},
...
]

You also need to add your keyboard layout to grid_name_to_grid.json and provide a tokenizer config (see the example in tokenizers/keyboard/ru.json)
You may want to remove outliers and errors from the data using src/data_obtaining_and_preprocessing/filter_dataset.py
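A hedged sketch of producing your own data file; it assumes one JSON object per line (the usual jsonl convention) with the fields shown in the structure above, and the output file name is arbitrary:

```python
import json

# Hypothetical example of writing one dataset entry per line.
sample = {
    "word": "на",
    "curve": {
        "x": [567, 567, 507],
        "y": [66, 66, 101],
        "t": [0, 3, 24],
        "grid_name": "your_keyboard_layout_name",
    },
}

with open("my_swipes.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(sample, ensure_ascii=False) + "\n")
```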
Use train.py with a train config. Example:
python -m src.train --train_config configs/train/train_traj_and_nearest.json

You can also use the train_for_kaggle.ipynb Jupyter notebook (for example, if you want to do the training on Kaggle).
You may want to extract model states from checkpoints using the provided ckpt_to_pt.py script.
python -m src.utils.ckpt_to_pt --ckpt-path checkpoints --out-path model_states

word_generation_demo.ipynb serves as an example of how to predict with a trained model.
predict.py is used to obtain word candidates for a whole dataset and pickle them
predict.py usage example:
python src/predict.py --config configs/prediction/prediction_conf__traj_and_nearest.json --num-workers 4

python -m src.predict_all_epochs --config configs/prediction/prediction_conf__traj_and_nearest.json --num-workers 4

Tip
On some systems you may find that multiprocessing with num_workers > 0 is slower than num_workers = 0. Try both options to see which one works better for you.
Evaluate the obtained predictions:

python -m src.evaluate --config configs/config_evaluation.json

Plot metrics from a CSV file obtained during evaluation (evaluate.py):
python -m src.plot_metrics --csv results/evaluation_results.csv --metrics accuracy mmr --output_dir results/plots --colors_config configs/experiment_colors.json

Plot metrics from TensorBoard logs obtained during training (train.py):
python -m src.plot_tb_metrics --tb_logdir_root lightning_logs --output_dir results/plots/tb --colors_config configs/experiment_colors.json

WIP documentation can be found here. It doesn't contain much information yet and will be updated. Please refer to the docstrings in the code for now.
See refactoring plan

