6 changes: 6 additions & 0 deletions .gitignore
@@ -0,0 +1,6 @@
*.json
__pycache__
glove.p

outputs/
build/
113 changes: 112 additions & 1 deletion README.md
@@ -2,7 +2,118 @@

This repository contains the implementation of the paper **Language-Assisted 3D Feature Learning for Semantic Scene Understanding**.

Code is coming soon.
![paper](./docs/paper.png)
## Setup

The code was developed and tested on Ubuntu 18.04 with PyTorch 1.6.0 and CUDA 10.2. Run the following commands to create the environment and install PyTorch:

```shell
conda create -n lang-3d python=3.8
conda activate lang-3d
conda install pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=10.2 -c pytorch
```

Install the required packages listed in `requirements.txt`:
```shell
pip install -r requirements.txt
```
After all packages are properly installed, please run the following commands to compile the CUDA modules for the PointNet++ backbone:
```shell
cd lib/pointnet2
python setup.py install
```

__Before moving on to the next step, set `CONF.PATH.BASE` in `lib/config.py` to the project root path.__
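
As a rough illustration of this step, the sketch below shows what setting the base path could look like. It uses a `SimpleNamespace` stand-in for `CONF` (the real `lib/config.py` may use a different container), and both the checkout location and the `CONF.PATH.DATA` entry are assumptions made only for this example.

```python
import os
from types import SimpleNamespace

# Stand-in for the CONF object defined in lib/config.py (assumption: the real
# file may build CONF with a different container type).
CONF = SimpleNamespace(PATH=SimpleNamespace())

# Assumption: the repository was cloned to ~/Language-Assisted-3D.
# Replace this with the absolute path of your own checkout.
CONF.PATH.BASE = os.path.expanduser("~/Language-Assisted-3D")

# Hypothetical derived path, shown only to illustrate why BASE must be correct:
# paths built from it point to the wrong place if BASE is wrong.
CONF.PATH.DATA = os.path.join(CONF.PATH.BASE, "data")

print(CONF.PATH.DATA)
```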

## Dataset

### ScanRefer Dataset

We use the data generated by the ScanRefer codebase. Follow its [Data preparation](https://github.com/daveredrum/ScanRefer#data-preparation) guide to preprocess the data, then put the result under the `data` folder.

### Language Parser

Please follow the [README](./language_parser/README.md) under the `language_parser` folder, then copy the generated files `ScanRefer_filtered_train_parser.json` and `ScanRefer_filtered_val_parser.json` into the `data` folder.

For convenience, we release the language parser results `ScanRefer_filtered_train_parser.json` and `ScanRefer_filtered_val_parser.json` on the [Releases page](https://github.com/Asterisci/Language-Assisted-3D/releases).


### Directory Structure
Finally, the dataset files should be organized as follows.

```
data
├── glove.p
├── scannet
│ ├── batch_load_scannet_data.py
│ ├── load_scannet_data.py
│ ├── meta_data/
│ ├── model_util_scannet.py
│ ├── README.md
│ ├── scannet_data/
│ ├── scannet_utils.py
│ ├── scans/
│ └── visualize.py
├── ScanRefer_filtered_train_parser.json
├── ScanRefer_filtered_val_parser.json
```
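
Before training, it can help to confirm that this layout is in place. The following is a minimal sketch, not part of the repository: the helper name `missing_entries` is hypothetical, and the list of entries is copied from the tree above.

```python
from pathlib import Path

# Entries taken from the directory tree above, relative to the data folder.
EXPECTED = [
    "glove.p",
    "scannet/meta_data",
    "scannet/scannet_data",
    "scannet/scans",
    "ScanRefer_filtered_train_parser.json",
    "ScanRefer_filtered_val_parser.json",
]


def missing_entries(data_root="data"):
    """Return the expected files and folders that do not exist under data_root."""
    root = Path(data_root)
    return [entry for entry in EXPECTED if not (root / entry).exists()]


# Example: list whatever is still missing from ./data (empty list means complete).
print(missing_entries("data"))
```
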
## Usage
### Training

To train the ScanRefer model for detection with RGB values:
```shell
python -u -m torch.distributed.launch --nproc_per_node=8 scripts/train.py --use_color --relation_prediction --color_prediction --size_prediction --shape_prediction --no_reference --batch_size 12 --val_step 1 --lr 8e-3 --epoch 60
```

To train the ScanRefer model for detection with multiview values:
```shell
python -u -m torch.distributed.launch --nproc_per_node=8 scripts/train.py --use_multiview --use_normal --relation_prediction --color_prediction --size_prediction --shape_prediction --no_reference --batch_size 12 --val_step 1 --lr 8e-3 --epoch 60
```

To train the ScanRefer model for visual grounding with RGB values:
```shell
python -u -m torch.distributed.launch --nproc_per_node=8 scripts/train.py --use_color --relation_prediction --color_prediction --size_prediction --shape_prediction --batch_size 12 --val_step 1 --lr 8e-3 --epoch 60
```

To train the ScanRefer model for visual grounding with multiview values:
```shell
python -u -m torch.distributed.launch --nproc_per_node=8 scripts/train.py --use_multiview --use_normal --relation_prediction --color_prediction --size_prediction --shape_prediction --batch_size 12 --val_step 1 --lr 8e-3 --epoch 60
```
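
The four commands above differ only in the input-feature flags and in whether `--no_reference` is present. The sketch below assembles them from two choices; the helper `build_train_command` is hypothetical and uses only flags that appear in the commands above.

```python
def build_train_command(task, use_multiview=False, nproc_per_node=8):
    """Assemble a training command for task 'detection' or 'grounding'."""
    if task not in ("detection", "grounding"):
        raise ValueError("task must be 'detection' or 'grounding'")

    cmd = [
        "python", "-u", "-m", "torch.distributed.launch",
        f"--nproc_per_node={nproc_per_node}", "scripts/train.py",
    ]
    # Input features: multiview features with normals, or RGB values.
    cmd += ["--use_multiview", "--use_normal"] if use_multiview else ["--use_color"]
    # Flags shared by all four commands above.
    cmd += [
        "--relation_prediction", "--color_prediction",
        "--size_prediction", "--shape_prediction",
    ]
    # The detection commands above additionally pass --no_reference.
    if task == "detection":
        cmd.append("--no_reference")
    cmd += ["--batch_size", "12", "--val_step", "1", "--lr", "8e-3", "--epoch", "60"]
    return " ".join(cmd)


print(build_train_command("grounding", use_multiview=True))
```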

### Evaluation
To evaluate the trained ScanRefer models for detection, please find the folder under `outputs/` with the current timestamp and run:
```shell
python scripts/eval.py --folder <folder_name> --detection --use_color --no_nms --force --repeat 5
```

To evaluate the trained ScanRefer models for visual grounding, please find the folder under `outputs/` with the current timestamp and run:
```shell
python scripts/eval.py --folder <folder_name> --reference --use_color --no_nms --force --repeat 5
```
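
To locate the folder to pass as `<folder_name>`, a small helper can pick the most recently modified run. This is only a sketch: the helper `latest_output_folder` is hypothetical, and it assumes that the newest sub-folder of `outputs/` is the run you want to evaluate.

```python
from pathlib import Path


def latest_output_folder(outputs_root="outputs"):
    """Return the name of the most recently modified sub-folder, or None."""
    root = Path(outputs_root)
    if not root.is_dir():
        return None
    folders = [p for p in root.iterdir() if p.is_dir()]
    if not folders:
        return None
    # Assumption: modification time reflects which training run is newest.
    return max(folders, key=lambda p: p.stat().st_mtime).name


print(latest_output_folder("outputs"))
```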

### Visualization
To visualize the localization results predicted by the trained ScanRefer model in a specific scene, please find the corresponding folder under `outputs/` with the current timestamp and run:
```shell
python scripts/visualize.py --folder <folder_name> --scene_id <scene_id> --use_color
```
Note that the flags must match the ones set before training. The training information is stored in `outputs/<folder_name>/info.json`. The output `.ply` files will be stored under `outputs/<folder_name>/vis/<scene_id>/`.
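
To check which flags a run was trained with, the stored training information can be printed. The sketch below assumes only that `info.json` holds a JSON object; its keys are not documented here, so the code prints whatever it finds. The helper names are hypothetical.

```python
import json
from pathlib import Path


def load_training_info(folder_name, outputs_root="outputs"):
    """Load outputs/<folder_name>/info.json and return its contents."""
    info_path = Path(outputs_root) / folder_name / "info.json"
    with open(info_path) as f:
        return json.load(f)


def print_training_info(info):
    """Print every stored entry so the flags can be compared by eye."""
    for key in sorted(info):
        print(f"{key}: {info[key]}")
```

Calling `print_training_info(load_training_info("<folder_name>"))` then lists the stored settings for that run.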

## Cite

If you find our work helpful for your research, please consider citing our paper.

```
@inproceedings{zhang2022language,
title={Language-Assisted 3D Feature Learning for Semantic Scene Understanding},
author={Zhang, Junbo and Fan, Guofan and Wang, Guanghan and Su, Zhengyuan and Ma, Kaisheng and Yi, Li},
booktitle={AAAI},
year={2023}
}
```

## Acknowledgement

Our code is based on [SceneGraphParser](https://github.com/vacancy/SceneGraphParser) and [ScanRefer](https://github.com/daveredrum/ScanRefer). Thanks to all.

## License

11 changes: 11 additions & 0 deletions data/scannet/README.md
@@ -0,0 +1,11 @@
# ScanNet Instructions

To get access to the ScanNet dataset, please refer to the [ScanNet project page](https://github.com/ScanNet/ScanNet) and follow the instructions there. You will get a `download-scannet.py` script after your request for the ScanNet dataset is approved. Note that only a subset of ScanNet is needed. Once you have `download-scannet.py`, use the commands below to download the portion of ScanNet that is necessary for ScanRefer:

```shell
python2 download-scannet.py -o data/scannet --type _vh_clean_2.ply
python2 download-scannet.py -o data/scannet --type .aggregation.json
python2 download-scannet.py -o data/scannet --type _vh_clean_2.0.010000.segs.json
python2 download-scannet.py -o data/scannet --type .txt
```
Roughly 10.6 GB of free disk space is needed.
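
A quick way to check the available space before downloading is sketched below. The helper `has_enough_space` is hypothetical, and it assumes the 10.6 GB figure is in decimal gigabytes (10^9 bytes).

```python
import shutil

REQUIRED_GB = 10.6  # figure quoted above; assumed to mean 10^9-byte gigabytes


def has_enough_space(path=".", required_gb=REQUIRED_GB):
    """Return True if the filesystem containing path has enough free space."""
    free_gb = shutil.disk_usage(path).free / 1e9
    return free_gb >= required_gb


print(has_enough_space("."))
```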
84 changes: 84 additions & 0 deletions data/scannet/batch_load_scannet_data.py
@@ -0,0 +1,84 @@
"""
Modified from: https://github.com/facebookresearch/votenet/blob/master/scannet/batch_load_scannet_data.py

Batch mode in loading Scannet scenes with vertices and ground truth labels for semantic and instance segmentations

Usage example: python ./batch_load_scannet_data.py
"""

import os
import sys
import datetime
import numpy as np
from load_scannet_data import export
import pdb

SCANNET_DIR = 'scans'
SCAN_NAMES = sorted([line.rstrip() for line in open('meta_data/scannetv2.txt')])
LABEL_MAP_FILE = 'meta_data/scannetv2-labels.combined.tsv'
DONOTCARE_CLASS_IDS = np.array([])
OBJ_CLASS_IDS = np.array([3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]) # exclude wall (1), floor (2), ceiling (22)
MAX_NUM_POINT = 50000
OUTPUT_FOLDER = './scannet_data'

def export_one_scan(scan_name, output_filename_prefix):
    mesh_file = os.path.join(SCANNET_DIR, scan_name, scan_name + '_vh_clean_2.ply')
    agg_file = os.path.join(SCANNET_DIR, scan_name, scan_name + '.aggregation.json')
    seg_file = os.path.join(SCANNET_DIR, scan_name, scan_name + '_vh_clean_2.0.010000.segs.json')
    meta_file = os.path.join(SCANNET_DIR, scan_name, scan_name + '.txt') # includes axisAlignment info for the train set scans.
    mesh_vertices, aligned_vertices, semantic_labels, instance_labels, instance_bboxes, aligned_instance_bboxes = export(mesh_file, agg_file, seg_file, meta_file, LABEL_MAP_FILE, None)

    mask = np.logical_not(np.in1d(semantic_labels, DONOTCARE_CLASS_IDS))
    mesh_vertices = mesh_vertices[mask,:]
    aligned_vertices = aligned_vertices[mask,:]
    semantic_labels = semantic_labels[mask]
    instance_labels = instance_labels[mask]

    if instance_bboxes.shape[0] > 1:
        num_instances = len(np.unique(instance_labels))
        print('Num of instances: ', num_instances)

        # bbox_mask = np.in1d(instance_bboxes[:,-1], OBJ_CLASS_IDS)
        bbox_mask = np.in1d(instance_bboxes[:,-2], OBJ_CLASS_IDS) # match the mesh2cap
        instance_bboxes = instance_bboxes[bbox_mask,:]
        aligned_instance_bboxes = aligned_instance_bboxes[bbox_mask,:]
        print('Num of care instances: ', instance_bboxes.shape[0])
    else:
        print("No semantic/instance annotation for test scenes")

    N = mesh_vertices.shape[0]
    if N > MAX_NUM_POINT:
        choices = np.random.choice(N, MAX_NUM_POINT, replace=False)
        mesh_vertices = mesh_vertices[choices, :]
        aligned_vertices = aligned_vertices[choices, :]
        semantic_labels = semantic_labels[choices]
        instance_labels = instance_labels[choices]

    print("Shape of points: {}".format(mesh_vertices.shape))

    np.save(output_filename_prefix+'_vert.npy', mesh_vertices)
    np.save(output_filename_prefix+'_aligned_vert.npy', aligned_vertices)
    np.save(output_filename_prefix+'_sem_label.npy', semantic_labels)
    np.save(output_filename_prefix+'_ins_label.npy', instance_labels)
    np.save(output_filename_prefix+'_bbox.npy', instance_bboxes)
    np.save(output_filename_prefix+'_aligned_bbox.npy', aligned_instance_bboxes)

def batch_export():
    if not os.path.exists(OUTPUT_FOLDER):
        print('Creating new data folder: {}'.format(OUTPUT_FOLDER))
        os.mkdir(OUTPUT_FOLDER)

    for scan_name in SCAN_NAMES:
        output_filename_prefix = os.path.join(OUTPUT_FOLDER, scan_name)
        # if os.path.exists(output_filename_prefix + '_vert.npy'): continue

        print('-'*20+'begin')
        print(datetime.datetime.now())
        print(scan_name)

        export_one_scan(scan_name, output_filename_prefix)

        print('-'*20+'done')

if __name__=='__main__':
    batch_export()
171 changes: 171 additions & 0 deletions data/scannet/load_scannet_data.py
@@ -0,0 +1,171 @@
"""
Modified from: https://github.com/facebookresearch/votenet/blob/master/scannet/load_scannet_data.py

Load Scannet scenes with vertices and ground truth labels for semantic and instance segmentations
"""

# python imports
import math
import os, sys, argparse
import inspect
import json
import pdb
import numpy as np
import scannet_utils

def read_aggregation(filename):
    object_id_to_segs = {}
    label_to_segs = {}
    with open(filename) as f:
        data = json.load(f)
        num_objects = len(data['segGroups'])
        for i in range(num_objects):
            object_id = data['segGroups'][i]['objectId'] + 1 # instance ids should be 1-indexed
            label = data['segGroups'][i]['label']
            segs = data['segGroups'][i]['segments']
            object_id_to_segs[object_id] = segs
            if label in label_to_segs:
                label_to_segs[label].extend(segs)
            else:
                # copy the list so that extending the per-label list later
                # does not also modify this object's own segment list
                label_to_segs[label] = list(segs)
    return object_id_to_segs, label_to_segs


def read_segmentation(filename):
    seg_to_verts = {}
    with open(filename) as f:
        data = json.load(f)
        num_verts = len(data['segIndices'])
        for i in range(num_verts):
            seg_id = data['segIndices'][i]
            if seg_id in seg_to_verts:
                seg_to_verts[seg_id].append(i)
            else:
                seg_to_verts[seg_id] = [i]
    return seg_to_verts, num_verts


def export(mesh_file, agg_file, seg_file, meta_file, label_map_file, output_file=None):
    """ points are XYZ RGB (RGB in 0-255),
    semantic label as nyu40 ids,
    instance label as 1-#instance,
    box as (cx,cy,cz,dx,dy,dz,semantic_label,object_id)
    """
    label_map = scannet_utils.read_label_mapping(label_map_file, label_from='raw_category', label_to='nyu40id')
    # mesh_vertices = scannet_utils.read_mesh_vertices_rgb(mesh_file)
    mesh_vertices = scannet_utils.read_mesh_vertices_rgb_normal(mesh_file)

    # Load scene axis alignment matrix
    with open(meta_file) as f:
        lines = f.readlines()
    axis_align_matrix = None
    for line in lines:
        if 'axisAlignment' in line:
            axis_align_matrix = [float(x) for x in line.rstrip().strip('axisAlignment = ').split(' ')]

    if axis_align_matrix is not None:
        axis_align_matrix = np.array(axis_align_matrix).reshape((4,4))
        pts = np.ones((mesh_vertices.shape[0], 4))
        pts[:,0:3] = mesh_vertices[:,0:3]
        pts = np.dot(pts, axis_align_matrix.transpose()) # Nx4
        aligned_vertices = np.copy(mesh_vertices)
        aligned_vertices[:,0:3] = pts[:,0:3]
    else:
        print("No axis alignment matrix found")
        aligned_vertices = mesh_vertices

    # Load semantic and instance labels
    if os.path.isfile(agg_file):
        object_id_to_segs, label_to_segs = read_aggregation(agg_file)
        seg_to_verts, num_verts = read_segmentation(seg_file)

        label_ids = np.zeros(shape=(num_verts), dtype=np.uint32) # 0: unannotated
        object_id_to_label_id = {}
        for label, segs in label_to_segs.items():
            label_id = label_map[label]
            for seg in segs:
                verts = seg_to_verts[seg]
                label_ids[verts] = label_id
        instance_ids = np.zeros(shape=(num_verts), dtype=np.uint32) # 0: unannotated
        num_instances = len(np.unique(list(object_id_to_segs.keys())))
        for object_id, segs in object_id_to_segs.items():
            for seg in segs:
                verts = seg_to_verts[seg]
                instance_ids[verts] = object_id
                if object_id not in object_id_to_label_id:
                    object_id_to_label_id[object_id] = label_ids[verts][0]

        instance_bboxes = np.zeros((num_instances,8)) # also include object id
        aligned_instance_bboxes = np.zeros((num_instances,8)) # also include object id
        for obj_id in object_id_to_segs:
            label_id = object_id_to_label_id[obj_id]

            # bboxes in the original meshes
            obj_pc = mesh_vertices[instance_ids==obj_id, 0:3]
            if len(obj_pc) == 0: continue
            # Compute axis aligned box
            # An axis aligned bounding box is parameterized by
            # (cx,cy,cz) and (dx,dy,dz) and label id
            # where (cx,cy,cz) is the center point of the box,
            # dx is the x-axis length of the box.
            xmin = np.min(obj_pc[:,0])
            ymin = np.min(obj_pc[:,1])
            zmin = np.min(obj_pc[:,2])
            xmax = np.max(obj_pc[:,0])
            ymax = np.max(obj_pc[:,1])
            zmax = np.max(obj_pc[:,2])
            bbox = np.array([(xmin+xmax)/2, (ymin+ymax)/2, (zmin+zmax)/2, xmax-xmin, ymax-ymin, zmax-zmin, label_id, obj_id-1]) # also include object id
            # NOTE: this assumes obj_id is in 1,2,3,...,NUM_INSTANCES
            instance_bboxes[obj_id-1,:] = bbox

            # bboxes in the aligned meshes
            obj_pc = aligned_vertices[instance_ids==obj_id, 0:3]
            if len(obj_pc) == 0: continue
            # Compute axis aligned box
            # An axis aligned bounding box is parameterized by
            # (cx,cy,cz) and (dx,dy,dz) and label id
            # where (cx,cy,cz) is the center point of the box,
            # dx is the x-axis length of the box.
            xmin = np.min(obj_pc[:,0])
            ymin = np.min(obj_pc[:,1])
            zmin = np.min(obj_pc[:,2])
            xmax = np.max(obj_pc[:,0])
            ymax = np.max(obj_pc[:,1])
            zmax = np.max(obj_pc[:,2])
            bbox = np.array([(xmin+xmax)/2, (ymin+ymax)/2, (zmin+zmax)/2, xmax-xmin, ymax-ymin, zmax-zmin, label_id, obj_id-1]) # also include object id
            # NOTE: this assumes obj_id is in 1,2,3,...,NUM_INSTANCES
            aligned_instance_bboxes[obj_id-1,:] = bbox
    else:
        # use zero as placeholders for the test scene
        print("use placeholders")
        num_verts = mesh_vertices.shape[0]
        label_ids = np.zeros(shape=(num_verts), dtype=np.uint32) # 0: unannotated
        instance_ids = np.zeros(shape=(num_verts), dtype=np.uint32) # 0: unannotated
        instance_bboxes = np.zeros((1, 8)) # also include object id
        aligned_instance_bboxes = np.zeros((1, 8)) # also include object id

    if output_file is not None:
        np.save(output_file+'_vert.npy', mesh_vertices)
        np.save(output_file+'_aligned_vert.npy', aligned_vertices)
        np.save(output_file+'_sem_label.npy', label_ids)
        np.save(output_file+'_ins_label.npy', instance_ids)
        np.save(output_file+'_bbox.npy', instance_bboxes)
        # save the aligned boxes here (previously the unaligned boxes were written twice)
        np.save(output_file+'_aligned_bbox.npy', aligned_instance_bboxes)

    return mesh_vertices, aligned_vertices, label_ids, instance_ids, instance_bboxes, aligned_instance_bboxes

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--scan_path', required=True, help='path to scannet scene (e.g., data/ScanNet/v2/scene0000_00)')
    parser.add_argument('--output_file', required=True, help='output file')
    parser.add_argument('--label_map_file', required=True, help='path to scannetv2-labels.combined.tsv')
    opt = parser.parse_args()

    scan_name = os.path.split(opt.scan_path)[-1]
    mesh_file = os.path.join(opt.scan_path, scan_name + '_vh_clean_2.ply')
    agg_file = os.path.join(opt.scan_path, scan_name + '.aggregation.json')
    seg_file = os.path.join(opt.scan_path, scan_name + '_vh_clean_2.0.010000.segs.json')
    meta_file = os.path.join(opt.scan_path, scan_name + '.txt') # includes axisAlignment info for the train set scans.
    export(mesh_file, agg_file, seg_file, meta_file, opt.label_map_file, opt.output_file)

if __name__ == '__main__':
    main()