Implementation code for the paper, based on the starter repo provided by CLVR.
- reward-induced representation model
- initialize a virtual environment
conda create --name myenv python=3.8
conda activate myenv
- configure environment dependencies
pip install -r requirements.txt
- re-implement the reward-induced representation learning model
MODEL = Encoder + MLPs + LSTM + rewards_heads(MLPs)
- replicate the experiment to show the representation ability of the model
EXP = Encoder(reward-induced) + Detached Decoder
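
As a rough formalization of the MODEL and EXP compositions above (a sketch under assumptions, not the exact formulation used in this repo): the symbols $x_t$ (input frame), $z_t$, $h_t$, $K$ (number of reward heads), the squared-error losses, and the stop-gradient operator $\mathrm{sg}$ are all illustrative choices that this README does not specify.

$$
\begin{aligned}
z_t &= \mathrm{MLP}(\mathrm{Enc}(x_t)), \\
h_t &= \mathrm{LSTM}(z_t, h_{t-1}), \\
\hat{r}^{(k)}_t &= \mathrm{Head}_k(h_t), \quad k = 1, \dots, K, \\
\mathcal{L}_{\text{reward}} &= \sum_{t} \sum_{k=1}^{K} \big(\hat{r}^{(k)}_t - r^{(k)}_t\big)^2, \\
\mathcal{L}_{\text{dec}} &= \sum_{t} \big\lVert \mathrm{Dec}(\mathrm{sg}(\mathrm{Enc}(x_t))) - x_t \big\rVert^2 .
\end{aligned}
$$

Under this reading, "detached" means the decoder loss $\mathcal{L}_{\text{dec}}$ updates only the decoder, so the reconstructions show what the reward-induced encoder has already captured.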
- implement PPO to complete the downstream task (agent following target) with the reward-induced representation
- build several representation models as baselines for training the downstream task with PPO
(cnn|image-scratch|image-reconstruction|image-reconstruction-finetune|reward-prediction|reward-prediction-finetune|oracle)
- train all of the above representation models and verify that the reward-induced model performs better
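
For reference, the core of PPO is its clipped surrogate objective. The following is a minimal pure-Python sketch of that objective only. It is not the code in ppo.py, and the function name `ppo_clip_objective`, its arguments, and the default `clip_eps=0.2` are assumptions made for illustration.

```python
import math


def ppo_clip_objective(new_logp, old_logp, advantages, clip_eps=0.2):
    """Mean PPO clipped surrogate objective (a quantity to maximize).

    new_logp / old_logp: log-probabilities of the taken actions under the
    current and the data-collecting policy; advantages: advantage estimates.
    All names here are hypothetical, chosen for this sketch.
    """
    total = 0.0
    for lp_new, lp_old, adv in zip(new_logp, old_logp, advantages):
        # Probability ratio pi_new(a|s) / pi_old(a|s).
        ratio = math.exp(lp_new - lp_old)
        # Keep the ratio inside [1 - eps, 1 + eps].
        clipped = max(1.0 - clip_eps, min(ratio, 1.0 + clip_eps))
        # Pessimistic (lower) bound of the unclipped and clipped terms.
        total += min(ratio * adv, clipped * adv)
    return total / len(advantages)
```

A training loop would typically minimize the negative of this value, alongside a value-function loss and an optional entropy bonus.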
- /presentation: the presentation slides for the implementation task
- /re_implement_paper: notes about the paper
- /scripts: the shell files to run the tasks
- /sprites_datagen: the dataset
- /sprites_env: the data environment
- /src: the source files for the README
- /tmp: the visualization results produced during training
- /weights: the model weights
- baseline.py: the baseline models for final training
- general_utils.py: general utility functions
- model.py: the pre-trained representation learning models
- ppo_train.py: train the whole task with PPO
- ppo.py: PPO implementation
- pre_train.py: pre-train the representation learning models
- README.md: the project info
- requirement.txt: environment dependencies
- a shell file for each task is provided in /scripts
- pre-train representation models
pretrain_image_recon_decoder.sh
pretrain_image_recon_model.sh
pretrain_reward_pred_model.sh
- train the downstream task with PPO
ppotrain_cnn.sh
ppotrain_image_rec_finetune.sh
ppotrain_image_rec.sh
ppotrain_image_scratch.sh
ppotrain_oracle.sh
ppotrain_reward_pred_finetune.sh
ppotrain_reward_pred.sh
- change the parameters in each shell file as needed, such as whether to use GPUs (gpus_num) or whether to log to wandb (is_use_wandb)
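
To illustrate how such parameters might be consumed on the Python side, here is a small argparse sketch. Only the names gpus_num and is_use_wandb come from this README. The `--flag value` style, the types, the defaults, and the helpers `str2bool` and `build_parser` are assumptions, and the actual training scripts may parse these options differently.

```python
import argparse


def str2bool(value):
    """Parse common textual booleans (hypothetical helper for this sketch)."""
    lowered = value.lower()
    if lowered in ("1", "true", "yes"):
        return True
    if lowered in ("0", "false", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    parser = argparse.ArgumentParser(description="Training options (sketch)")
    # Assumed semantics: 0 means CPU only, N > 0 means use N GPUs.
    parser.add_argument("--gpus_num", type=int, default=0)
    # Assumed semantics: enable or disable logging to wandb.
    parser.add_argument("--is_use_wandb", type=str2bool, default=False)
    return parser


# Example: the kind of argument list a shell file could pass along.
args = build_parser().parse_args(["--gpus_num", "1", "--is_use_wandb", "true"])
print(args.gpus_num, args.is_use_wandb)
```

Parsing the boolean explicitly avoids the common pitfall where `type=bool` treats any non-empty string, including "false", as true.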








