This repository contains the results and code for the Multi-Task CrabNet (MT CrabNet). MT CrabNet is a modification of the original Compositionally restricted attention-based Network (CrabNet) model. There are two main modifications:
- Adding external features (e.g., temperature) as inputs to one of the layers in the model. There are three options: concat_at_input, tile_at_input, and concat_at_output. The model supports more than one external feature.
- Enabling multi-task learning with selector vectors. A selector vector indicates which target property is being learned, and the approach works even when some data are missing. The selector vector can be added either to the start of the attention heads or to one of the layers in the residual block.
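As a rough illustration of what injecting external features can mean, the sketch below (plain Python, no PyTorch) shows two generic variants: tiling an external-feature vector onto every per-element vector before the attention layers, and appending it to a pooled per-composition vector near the output. The function names, the toy 2-dimensional element vectors and the plain mean pooling are assumptions made for illustration. They are not the repository's code and may not match exactly how concat_at_input, tile_at_input or concat_at_output are implemented.

```python
from typing import List

Vector = List[float]


def tile_and_concat(element_vectors: List[Vector], external: Vector) -> List[Vector]:
    """Append the same external-feature vector to every per-element vector.

    Illustrative only: output shape is (n_elements, d + n_external), so every
    element in the composition carries the conditions (e.g. temperature).
    """
    return [list(vec) + list(external) for vec in element_vectors]


def mean_pool(element_vectors: List[Vector]) -> Vector:
    """Average per-element vectors into one composition-level vector."""
    n = len(element_vectors)
    dim = len(element_vectors[0])
    return [sum(vec[i] for vec in element_vectors) / n for i in range(dim)]


def concat_after_pooling(element_vectors: List[Vector], external: Vector) -> Vector:
    """Illustrative only: pool first, then append external features before the output layers."""
    return mean_pool(element_vectors) + list(external)


# Toy example: two elements with 2-d vectors and temperature (K) as the only
# external feature. The raw value is used for readability; in practice such
# inputs are usually scaled.
elements = [[1.0, 2.0], [3.0, 4.0]]
temperature = [300.0]
print(tile_and_concat(elements, temperature))       # [[1.0, 2.0, 300.0], [3.0, 4.0, 300.0]]
print(concat_after_pooling(elements, temperature))  # [2.0, 3.0, 300.0]
```

Passing a longer `external` list (for example `[300.0, 0.5]`) works the same way, which mirrors the README's note that more than one external feature is supported.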
The inputs to the model are the compositions and the external features, and the outputs are the predicted properties. The model can learn more than one property simultaneously.
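One way to picture how selector vectors cope with missing data is to reshape a table of materials and properties into one training row per measured value, each tagged with a one-hot selector for its property. The sketch below does this in plain Python. The function names, the task list and the dummy values are hypothetical, and the repository's own data pipeline (and where the selector enters the network) may differ.

```python
from typing import Dict, List, Optional, Sequence, Tuple

Row = Tuple[str, List[int], float]


def one_hot(index: int, n_tasks: int) -> List[int]:
    """Selector vector with a 1 at the position of the task being predicted."""
    return [1 if i == index else 0 for i in range(n_tasks)]


def to_selector_rows(
    records: Sequence[Tuple[str, Dict[str, Optional[float]]]],
    tasks: Sequence[str],
) -> List[Row]:
    """Turn (composition, {task: value}) records into one row per measured value.

    Missing values (absent keys or None) produce no row, so a material with
    only some properties measured still contributes to the tasks it has.
    """
    rows: List[Row] = []
    for composition, targets in records:
        for i, task in enumerate(tasks):
            value = targets.get(task)
            if value is None:
                continue  # nothing measured for this task: skip it
            rows.append((composition, one_hot(i, len(tasks)), float(value)))
    return rows


tasks = ["seebeck", "electrical_conductivity", "thermal_conductivity"]  # hypothetical task list
records = [
    ("comp_1", {"seebeck": 1.0, "thermal_conductivity": 2.0}),  # dummy values
    ("comp_2", {"electrical_conductivity": 3.0, "seebeck": None}),
]
for row in to_selector_rows(records, tasks):
    print(row)
```

Because each row carries exactly one target, the loss for a row only involves the property its selector points at. An alternative with the same effect is to predict all properties at once and mask the missing targets out of the loss.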
We hypothesize that the attention mechanism can implicitly learn composition-specific dopant effects, which traditional composition-based models tend to struggle to capture. Additionally, multi-task learning of related properties should improve overall predictive accuracy. To test these hypotheses, we focused on predicting thermoelectric (TE) properties, since doping is a common strategy in TE materials design. The model was trained on an experimental TE dataset from our previous work, the Systematically Verified Thermoelectric (sysTEm) Dataset.
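Composition-based models such as CrabNet receive a material as a set of elements with fractional amounts, so a dopant shows up as a small fraction alongside the host elements. The minimal parser below illustrates that conversion for simple formulas. `element_fractions` is a hypothetical helper, not the repository's code, and it does not handle parentheses, hydrates or checking symbols against the periodic table.

```python
import re
from typing import Dict

# One element symbol followed by an optional amount such as 2, 0.98 or .5.
_TOKEN = re.compile(r"([A-Z][a-z]?)(\d+(?:\.\d+)?|\.\d+)?")


def element_fractions(formula: str) -> Dict[str, float]:
    """Parse a flat formula like 'Pb0.98Na0.02Te' into normalised element fractions."""
    amounts: Dict[str, float] = {}
    pos = 0
    for match in _TOKEN.finditer(formula):
        if match.start() != pos:  # unparsed characters, e.g. '(' or whitespace
            raise ValueError(f"Unsupported formula: {formula!r}")
        element, count = match.group(1), match.group(2)
        amounts[element] = amounts.get(element, 0.0) + (float(count) if count else 1.0)
        pos = match.end()
    if pos != len(formula) or not amounts:
        raise ValueError(f"Unsupported formula: {formula!r}")
    total = sum(amounts.values())
    return {element: amount / total for element, amount in amounts.items()}


print(element_fractions("Bi2Te3"))          # Bi 0.4, Te 0.6
print(element_fractions("Pb0.98Na0.02Te"))  # Na appears as a ~1% fraction
```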
Aside from MT CrabNet, this repository also contains the code for the baseline models used for comparison (DopNet and Random Forest) and for the Single-Task CrabNet (ST CrabNet). MT CrabNet is also intended for use in domains beyond TE materials.
Figure 1: Original CrabNet architecture.
Figure 2: Adding External Features (Absolute Temperature used as an example).
Figure 3: Multi-task CrabNet with selector vectors.
If this repository has been useful for your work, please consider citing the paper and GitHub repo:
Tang LZ, Mohanty T, Baird SG, Ng LWT, Sparks TD. Multi-task Attention for Doped Thermoelectric Properties Prediction. ChemRxiv. 2025; doi:10.26434/chemrxiv-2025-jcgjz
In BibTeX format:
@misc{Tang_MultitaskAttentionDoped_2025,
title = {Multi-Task {{Attention}} for {{Doped Thermoelectric Properties Prediction}}},
author = {Tang, Leng Ze and Mohanty, Trupti and Baird, Sterling G. and Ng, Leonard W. T. and Sparks, Taylor D.},
year = 2025,
month = sep,
publisher = {ChemRxiv},
doi = {10.26434/chemrxiv-2025-jcgjz},
url = {https://chemrxiv.org/engage/chemrxiv/article-details/68b42efda94eede154b776b1},
urldate = {2025-09-03},
archiveprefix = {ChemRxiv},
langid = {english},
keywords = {attention mechanism,doped materials,figure of merit,machine learning,materials informatics,multi-task learning,thermoelectric}
}
We recommend cloning the entire repository. The local_pkgs folder contains the modules for the different models, along with functions for generating figures and formatting data.
Either conda or venv can be used to install the required dependencies. The CrabNet model that we modified has some known compatibility issues with newer PyTorch versions. For more information, see the section "Known issues with CrabNet right now".
environment.yml - contains the dependencies for creating a conda environment. The PyTorch version specified there should work with CrabNet. The full list of dependencies is given in full_dependencies.txt.
requirements.txt - should work with pip install.
- local_pkgs/: Contains all core Python packages and modules for the models, data processing, and utilities.
  - proj_pkg/: Main project code (data handling, preprocessing, ML pipeline, plotting, etc.).
  - dopnet_pkg/: DopNet model code and related utilities.
  - crabnet_pkg/: Modified CrabNet and MT CrabNet code.
- data/: Contains the input datasets and feature files used for training and evaluation. dataset_formatting.ipynb, in the repository root, contains the code for generating these files from the original data.
- training_models/: Scripts for training the models.
  - results/: Contains the predictions made by the various models.
- figures/: Output figures found in the work.
  - figures_code/: Contains the scripts for generating these figures.
The latest CrabNet repository prior to this work had the following issues:
- CrabNet does not work with Python versions > 3.10 (sparks-baird/CrabNet#78).
- CrabNet requires PyTorch version < 2.0 (sparks-baird/CrabNet#70).
As we modified the CrabNet code directly, these issues carry over to our MT CrabNet and ST CrabNet models.
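Given these two constraints, a quick pre-flight check of the active environment can save a failed training run. The script below is a hypothetical helper, not part of the repository. It reads the Python version from `sys.version_info` and the installed PyTorch version through `importlib.metadata`, and reports which of the known limits (Python > 3.10, PyTorch >= 2.0) are violated.

```python
from __future__ import annotations

import re
import sys
from importlib import metadata


def parse_major_minor(version: str) -> tuple[int, int]:
    """Return (major, minor) from strings such as '3.10.12' or '1.13.1+cu117'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def crabnet_compat_problems(python_version: str, torch_version: str | None) -> list[str]:
    """List the known CrabNet constraints that the given versions violate."""
    problems = []
    if parse_major_minor(python_version) > (3, 10):
        problems.append(f"Python {python_version} is newer than 3.10 (sparks-baird/CrabNet#78)")
    if torch_version is None:
        problems.append("PyTorch is not installed")
    elif parse_major_minor(torch_version) >= (2, 0):
        problems.append(f"PyTorch {torch_version} is not < 2.0 (sparks-baird/CrabNet#70)")
    return problems


if __name__ == "__main__":
    python_version = ".".join(str(part) for part in sys.version_info[:3])
    try:
        torch_version = metadata.version("torch")
    except metadata.PackageNotFoundError:
        torch_version = None
    problems = crabnet_compat_problems(python_version, torch_version)
    for message in problems or ["No known CrabNet version conflicts detected."]:
        print(message)
```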
This repository adapted code from the following works: