Deep Probabilistic Segmentation

Built with PyTorch Lightning; configuration managed with Hydra.

Description

Development of deep probabilistic binary segmentation models for segmenting an object to be tracked in a sequence of monocular RGB images. Rather than producing a mask for the current image only, as is customary, the aim is to produce a segmentation based on color statistics that can be reused on several consecutive images of the sequence (assuming these statistics remain nearly constant).
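To make the idea of a color-statistics segmentation reused across frames concrete, here is a minimal sketch assuming a classical histogram-based model rather than the learned models of this repository. Foreground and background color histograms are estimated once from a frame with a known mask, then reused to compute per-pixel foreground probabilities on later frames. The function names, the number of bins, and the equal-prior assumption are all hypothetical choices made for illustration.

```python
from collections import Counter


def quantize(rgb, bins=8):
    """Map an (r, g, b) triple with 0-255 channels to a coarse color bin."""
    step = 256 // bins
    return tuple(c // step for c in rgb)


def fit_color_model(pixels, mask, bins=8):
    """Estimate foreground and background color histograms from one annotated frame.

    pixels: list of (r, g, b) tuples; mask: list of bools (True = object pixel).
    """
    fg, bg = Counter(), Counter()
    for rgb, is_fg in zip(pixels, mask):
        (fg if is_fg else bg)[quantize(rgb, bins)] += 1
    # max(..., 1) avoids a division by zero when a class has no pixels.
    return fg, bg, max(sum(fg.values()), 1), max(sum(bg.values()), 1)


def foreground_probability(rgb, model, bins=8, eps=1e-6):
    """Posterior P(foreground | color) from the stored histograms, equal priors assumed."""
    fg, bg, n_fg, n_bg = model
    key = quantize(rgb, bins)
    # Class-conditional likelihoods, floored by eps for colors never observed.
    p_fg = fg[key] / n_fg + eps
    p_bg = bg[key] / n_bg + eps
    return p_fg / (p_fg + p_bg)


# Fit the statistics on frame t, then reuse them on frame t + 1.
frame_t = [(250, 10, 10), (245, 20, 15), (10, 10, 240), (20, 30, 250)]
mask_t = [True, True, False, False]
model = fit_color_model(frame_t, mask_t)
print(round(foreground_probability((240, 5, 5), model), 3))  # 1.0: reddish pixel, object
print(round(foreground_probability((5, 5, 240), model), 3))  # 0.0: bluish pixel, background
```

The statistics stay valid only as long as the object's and background's colors do not change much, which is the assumption stated above.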

More specifically, we designed two solutions:

  • A pixel-wise segmentation model using an MLP;
  • A line-wise segmentation model using a U-Net.

Both of these lightweight, local models are conditioned on a global appearance vector predicted by a ResNet.
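The README does not say how the conditioning is implemented. One common scheme is to concatenate the global vector to each pixel's input features, and the sketch below assumes exactly that (the repository may use another mechanism, such as feature-wise modulation). It uses a tiny one-hidden-layer MLP with hand-picked toy weights in plain Python; the layer sizes, the two-dimensional appearance vector, and all names are hypothetical, and a real model would be written and trained in PyTorch.

```python
import math


def linear(weights, bias, inputs):
    """Dense layer: one row of weights and one bias per output unit."""
    return [sum(w * x for w, x in zip(row, inputs)) + b for row, b in zip(weights, bias)]


def pixel_mlp(rgb, appearance, params):
    """Foreground probability of one pixel, conditioned on a global appearance vector."""
    # Normalize the color to [0, 1] and concatenate the appearance vector.
    x = [c / 255.0 for c in rgb] + list(appearance)
    hidden = [max(0.0, h) for h in linear(params["w1"], params["b1"], x)]  # ReLU
    logit = linear(params["w2"], params["b2"], hidden)[0]
    return 1.0 / (1.0 + math.exp(-logit))  # sigmoid


# Hand-picked toy weights: appearance (1, 0) means "the object is red",
# (0, 1) means "the object is blue".
PARAMS = {
    "w1": [[1, 0, 0, 1, 0],   # fires for a red pixel when the object is red
           [0, 0, 1, 0, 1]],  # fires for a blue pixel when the object is blue
    "b1": [-1, -1],
    "w2": [[10, 10]],
    "b2": [-5],
}

red = (255, 0, 0)
print(round(pixel_mlp(red, (1.0, 0.0), PARAMS), 3))  # 0.993: red object, red pixel
print(round(pixel_mlp(red, (0.0, 1.0), PARAMS), 3))  # 0.007: blue object, same pixel
```

The same pixel gets opposite labels depending on the appearance vector, which is what lets a small local model adapt to a new object without retraining.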

Installation

First, clone the main branch of the repository:

git clone https://github.com/TomRavaud/deep_probabilistic_segmentation.git
cd deep_probabilistic_segmentation

Then, install the required dependencies. The project uses conda to manage them; create a new conda environment from the provided environment.yaml file:

conda env create -f environment.yaml

Dataset download

The model is trained on the Google Scanned Objects split of the MegaPose dataset. You can download the dataset, as well as the 3D models, by running the script data/download_gso_data.sh from the directory where you want to store them. For example, to download the dataset into the data directory of the repository, run:

cd data
bash download_gso_data.sh

To avoid downloading the whole dataset, this script allows you to specify the number of shards to download. Follow the displayed instructions when running it.

Mesh decimation

It is possible to train a segmentation model without having to render the 3D models, as the MegaPose dataset already provides the segmentation masks. In this case, you can skip this section.

To reduce the resolution of the 3D models, you can use the src/scripts/meshlab/decimate_meshes.py script. It calls MeshLab through its Python interface, PyMeshLab, to perform a quadric edge collapse decimation on the meshes. Using it is recommended to make batch rendering more efficient, since only the projected silhouette of the object in the image matters here. To use the script, make sure MeshLab and PyMeshLab are installed, then run:

python src/scripts/meshlab/decimate_meshes.py --src_object_set_path {path_to_original_models} --dst_object_set_path {path_to_save_decimated_models} --num_faces {target_number_of_faces}

By default, the script will decimate the meshes to 1000 faces.

MobileSAM weights

The model relies on the pretrained MobileSAM model. Download the weights from the MobileSAM repository, then set their path in the configuration file configs/model/default.yaml or pass it as a command-line argument.

Contour lines generation

Training the line-wise segmentation model requires contour lines generated from the segmentation masks of the objects. To generate them, run:

python src/scripts/clines_extraction/extract_clines.py
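The README does not describe what extract_clines.py computes. Assuming that contour lines are short pixel lines sampled along the normal of the mask boundary, as in correspondence-line approaches to region-based object tracking, a minimal sketch could look like the following. The mask format (a list of lists of 0/1), the half_length parameter, and every function name are hypothetical, and the central-difference normal is a deliberately crude estimate.

```python
import math


def boundary_pixels(mask):
    """Foreground pixels with at least one 4-connected background (or out-of-image) neighbor."""
    h, w = len(mask), len(mask[0])
    pts = []
    for y in range(h):
        for x in range(w):
            if not mask[y][x]:
                continue
            for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                ny, nx = y + dy, x + dx
                if not (0 <= ny < h and 0 <= nx < w) or not mask[ny][nx]:
                    pts.append((y, x))
                    break
    return pts


def outward_normal(mask, y, x):
    """Unit (dy, dx) normal pointing out of the object, from central differences of the mask."""
    h, w = len(mask), len(mask[0])

    def m(yy, xx):
        return 1.0 if 0 <= yy < h and 0 <= xx < w and mask[yy][xx] else 0.0

    gx = m(y, x + 1) - m(y, x - 1)
    gy = m(y + 1, x) - m(y - 1, x)
    norm = math.hypot(gx, gy)
    if norm == 0:
        return None  # ambiguous direction (e.g. a one-pixel-thin structure)
    # The mask gradient points into the object, so negate it.
    return (-gy / norm, -gx / norm)


def contour_lines(mask, half_length=3):
    """One line of (y, x, is_foreground) samples per boundary pixel, taken along the normal."""
    h, w = len(mask), len(mask[0])
    lines = []
    for y, x in boundary_pixels(mask):
        n = outward_normal(mask, y, x)
        if n is None:
            continue
        line = []
        for t in range(-half_length, half_length + 1):
            py, px = round(y + t * n[0]), round(x + t * n[1])
            if 0 <= py < h and 0 <= px < w:
                line.append((py, px, bool(mask[py][px])))
        lines.append(line)
    return lines


square = [[1 if 2 <= y <= 4 and 2 <= x <= 4 else 0 for x in range(7)] for y in range(7)]
for line in contour_lines(square, half_length=1)[:2]:
    print(line)  # each line goes from inside (True) to outside (False) the object
```

Each line then gives a one-dimensional segmentation problem (where along the line does the object end?), which is the kind of input a line-wise model such as the U-Net mentioned above could consume.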

Usage

To activate the conda environment, run:

conda activate deep_probabilistic_segmentation

You can now run scripts located in the src/scripts directory from the root of the repository. For example, to train the model, run:

python src/scripts/train.py

Acknowledgement

Some of the code is borrowed from MegaPose (maintained in happypose) to simplify dataset handling.