$ conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
$ pip install -r requirements.txt
$ wandb login
(This prints a link to the page with your wandb API key. Copy the key and paste it into the terminal.)
$ wandb sweep translation_sweep.yaml
Run the command printed by the step above, which has the following format:
$ wandb agent <USERNAME/PROJECTNAME/SWEEPID>
$ python main.py
| argument | help |
|---|---|
| --workers | number of data loader workers |
| --batch-size | mini-batch size |
| --learning-rate-weights | base learning rate for weights |
| --learning-rate-biases | base learning rate for biases |
| --weight-decay | weight decay |
| --lambd | weight on off-diagonal terms |
| --projector | projector MLP |
| --print-freq | print frequency |
| --dmodel | dimension of transformer encoder |
| --nhead | number of heads in transformer |
| --dfeedforward | dimension of feedforward layer in transformer encoder |
| --nlayers | number of layers of transformer encoder |
| --tokenizer | tokenizer |
| --mbert-out-size | dimension of mBERT output |
| --checkpoint-dir | path to checkpoint directory |
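
The table above can be read as a command-line interface declaration. The following is a minimal sketch of how such flags might be declared with `argparse`, not the actual contents of `main.py`: the function name `build_parser`, the argument types, and every default value are assumptions chosen for illustration only.

```python
import argparse
from pathlib import Path


def build_parser() -> argparse.ArgumentParser:
    """Declare the flags listed in the table.

    Types and defaults are placeholders (assumptions), not the
    values used by the real main.py.
    """
    p = argparse.ArgumentParser(description="Hypothetical sketch of the main.py CLI")
    # Data loading and optimisation
    p.add_argument("--workers", type=int, default=4, help="number of data loader workers")
    p.add_argument("--batch-size", type=int, default=32, help="mini-batch size")
    p.add_argument("--learning-rate-weights", type=float, default=0.2,
                   help="base learning rate for weights")
    p.add_argument("--learning-rate-biases", type=float, default=0.005,
                   help="base learning rate for biases")
    p.add_argument("--weight-decay", type=float, default=1e-6, help="weight decay")
    p.add_argument("--lambd", type=float, default=0.005,
                   help="weight on off-diagonal terms")
    p.add_argument("--projector", type=str, default="768-256",
                   help="projector MLP (placeholder layer-size string)")
    p.add_argument("--print-freq", type=int, default=100, help="print frequency")
    # Transformer encoder
    p.add_argument("--dmodel", type=int, default=768,
                   help="dimension of transformer encoder")
    p.add_argument("--nhead", type=int, default=4,
                   help="number of heads in transformer")
    p.add_argument("--dfeedforward", type=int, default=256,
                   help="dimension of feedforward layer in transformer encoder")
    p.add_argument("--nlayers", type=int, default=3,
                   help="number of layers of transformer encoder")
    p.add_argument("--tokenizer", type=str, default="bert-base-multilingual-uncased",
                   help="tokenizer (placeholder name)")
    p.add_argument("--mbert-out-size", type=int, default=768,
                   help="dimension of mBERT output")
    p.add_argument("--checkpoint-dir", type=Path, default=Path("./checkpoint"),
                   help="path to checkpoint directory")
    return p


if __name__ == "__main__":
    # Parse an explicit list so the sketch runs the same way everywhere.
    args = build_parser().parse_args(["--batch-size", "64", "--nlayers", "6"])
    print(args.batch_size, args.nlayers, args.checkpoint_dir)
```

Note that `argparse` turns dashes in flag names into underscores, so `--batch-size` is read back as `args.batch_size`. On the command line, overriding a subset of flags would look like `python main.py --batch-size 64 --nlayers 6`, with the values here being illustrative only.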