Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,24 @@
# CSPN implemented in Pytorch 0.4.1

### Build & Run

Install the following dependencies:
```sh
pip install tensorboardx freeze
```


This branch fixes some bugs and adds an implementation for reading the KITTI dataset.
You need to download the following datasets in HDF5 format:
```sh
mkdir data; cd data
wget http://datasets.lids.mit.edu/sparse-to-dense/data/nyudepthv2.tar.gz
tar -xvf nyudepthv2.tar.gz && rm -f nyudepthv2.tar.gz
wget http://datasets.lids.mit.edu/sparse-to-dense/data/kitti.tar.gz
tar -xvf kitti.tar.gz && rm -f kitti.tar.gz
cd ..
```


### Introduction
This is a PyTorch(0.4.1) implementation of [Depth Estimation via Affinity Learned with Convolutional Spatial Propagation Network](http://arxiv.org/abs/1808.00150). At present, we can provide train script in NYU Depth V2 dataset for depth completion and monocular depth estimation. KITTI will be available soon!
Expand Down
3 changes: 2 additions & 1 deletion dataloaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def create_loader(args, mode='train'):
from dataloaders.nyu_dataloader import create_loader
return create_loader(args, mode=mode)
elif args.dataset.lower() == 'kitti':
return NotImplementedError
from dataloaders.kitti_dataloader import create_loader
return create_loader(args, mode=mode)
else:
return NotImplementedError
62 changes: 51 additions & 11 deletions dataloaders/kitti_dataloader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,60 @@
@Email : wangxin_buaa@163.com
@File : __init__.py.py
"""
def create_loader(args, mode='train'):
    """Build a KITTI ``DataLoader`` for the requested split.

    Args:
        args: parsed CLI namespace. Reads ``dataset``, ``max_depth``,
            ``sparsifier``, ``num_samples``, ``modality``, ``batch_size``
            and ``workers``.
        mode: ``'train'`` or ``'val'`` (case-insensitive).

    Returns:
        torch.utils.data.DataLoader over a ``KITTIDataset`` for the split.

    Raises:
        NotImplementedError: for any mode other than ``'train'``/``'val'``.
    """
    # Data loading code
    print('=> creating ', mode, ' loader ...')
    import os
    from dataloaders.path import Path
    root_dir = Path.db_root_dir(args.dataset)

    # sparsifier is a class for generating random sparse depth input from the ground truth
    import numpy as np
    sparsifier = None
    # A negative max_depth means "no depth limit".
    max_depth = args.max_depth if args.max_depth >= 0.0 else np.inf
    from dataloaders.nyu_dataloader.dense_to_sparse import UniformSampling
    from dataloaders.nyu_dataloader.dense_to_sparse import SimulatedStereo
    if args.sparsifier == UniformSampling.name:
        sparsifier = UniformSampling(num_samples=args.num_samples, max_depth=max_depth)
    elif args.sparsifier == SimulatedStereo.name:
        sparsifier = SimulatedStereo(num_samples=args.num_samples, max_depth=max_depth)

    from dataloaders.kitti_dataloader.kitti_dataloader import KITTIDataset

    import torch
    if mode.lower() == 'train':
        traindir = os.path.join(root_dir, 'train')

        if os.path.exists(traindir):
            print('Train dataset "{}" is existed!'.format(traindir))
        else:
            print('Train dataset "{}" is not existed!'.format(traindir))
            exit(-1)
        train_dataset = KITTIDataset(traindir, type='train',
                                     modality=args.modality, sparsifier=sparsifier)
        # worker_init_fn ensures different sampling patterns for each data loading thread
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True, sampler=None,
            worker_init_fn=lambda work_id: np.random.seed(work_id))

        return train_loader

    elif mode.lower() == 'val':
        valdir = os.path.join(root_dir, 'val')
        if os.path.exists(valdir):
            print('Val dataset "{}" is existed!'.format(valdir))
        else:
            print('Val dataset "{}" is not existed!'.format(valdir))
            exit(-1)
        val_dataset = KITTIDataset(valdir, type='val',
                                   modality=args.modality, sparsifier=sparsifier)
        # batch_size=1 for validation so metrics are computed per image.
        val_loader = torch.utils.data.DataLoader(val_dataset,
            batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)

        return val_loader

    else:
        raise NotImplementedError
136 changes: 136 additions & 0 deletions dataloaders/kitti_dataloader/dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#
# copyright: Fangchang Ma
# https://github.com/fangchangma/sparse-to-dense.pytorch/blob/master/dataloaders/dataloader.py
#
#
import os
import os.path
import numpy as np
import torch.utils.data as data
import h5py
import dataloaders.transforms as transforms

# Dataset samples are stored as HDF5 files.
IMG_EXTENSIONS = ['.h5',]

def is_image_file(filename):
    """Return True if *filename* ends with a recognized dataset extension."""
    return filename.endswith(tuple(IMG_EXTENSIONS))

def find_classes(dir):
    """Return the sorted subdirectory names of *dir* and a name->index map."""
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx

def make_dataset(dir, class_to_idx):
    """Collect (path, class_index) pairs for every dataset file under *dir*.

    Walks each class subdirectory of *dir* in sorted order and keeps the
    files accepted by ``is_image_file``.
    """
    dir = os.path.expanduser(dir)
    samples = []
    for target in sorted(os.listdir(dir)):
        target_dir = os.path.join(dir, target)
        if not os.path.isdir(target_dir):
            continue
        for walk_root, _, fnames in sorted(os.walk(target_dir)):
            samples.extend(
                (os.path.join(walk_root, fname), class_to_idx[target])
                for fname in sorted(fnames)
                if is_image_file(fname)
            )
    return samples

def h5_loader(path):
    """Load one (rgb, depth) sample from an HDF5 file.

    Returns:
        rgb: image array transposed from the on-disk C x H x W layout to
            H x W x C.
        depth: depth array as stored — presumably H x W; TODO confirm.
    """
    # Use a context manager so the file handle is always closed; the
    # original left every handle open for the lifetime of the process.
    # np.array() copies the data, so reading after close is never needed.
    with h5py.File(path, "r") as h5f:
        rgb = np.transpose(np.array(h5f['rgb']), (1, 2, 0))
        depth = np.array(h5f['depth'])
    return rgb, depth

# def rgb2grayscale(rgb):
# return rgb[:,:,0] * 0.2989 + rgb[:,:,1] * 0.587 + rgb[:,:,2] * 0.114

to_tensor = transforms.ToTensor()

class MyDataloader(data.Dataset):
    """Base dataset over a directory tree of HDF5 samples.

    Expects ``root/<class_dir>/**/*.h5`` where each file holds an 'rgb'
    and a 'depth' array (see ``h5_loader``). Subclasses must implement
    ``train_transform`` and ``val_transform``.
    """
    modality_names = ['rgb', 'rgbd', 'd'] # , 'g', 'gd'
    color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4)

    def __init__(self, root, type, sparsifier=None, modality='rgb', loader=h5_loader):
        """
        Args:
            root: dataset split directory containing class subdirectories.
            type: 'train' or 'val'; selects the transform.
            sparsifier: optional object with dense_to_sparse(rgb, depth)
                returning a boolean keep-mask; None keeps the dense depth.
            modality: one of ``modality_names``.
            loader: callable(path) -> (rgb, depth).
        """
        classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root, class_to_idx)
        assert len(imgs)>0, "Found 0 images in subfolders of: " + root + "\n"
        print("Found {} images in {} folder.".format(len(imgs), type))
        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        if type == 'train':
            self.transform = self.train_transform
        elif type == 'val':
            self.transform = self.val_transform
        else:
            raise (RuntimeError("Invalid dataset type: " + type + "\n"
                                "Supported dataset types are: train, val"))
        self.loader = loader
        self.sparsifier = sparsifier

        assert (modality in self.modality_names), "Invalid modality type: " + modality + "\n" + \
                                "Supported dataset types are: " + ''.join(self.modality_names)
        self.modality = modality

    def train_transform(self, rgb, depth):
        """Subclass hook: augment one training pair. Must be overridden."""
        raise (RuntimeError("train_transform() is not implemented. "))

    def val_transform(self, rgb, depth):
        # Bug fix: the original signature was missing ``self``, so the
        # bound-method call ``self.transform(rgb, depth)`` would raise
        # "takes 2 positional arguments but 3 were given" instead of the
        # intended RuntimeError below.
        """Subclass hook: deterministic eval transform. Must be overridden."""
        raise (RuntimeError("val_transform() is not implemented."))

    def create_sparse_depth(self, rgb, depth):
        """Return depth with only the sparsifier-selected samples kept
        (zeros elsewhere); the dense depth when no sparsifier is set."""
        if self.sparsifier is None:
            return depth
        else:
            mask_keep = self.sparsifier.dense_to_sparse(rgb, depth)
            sparse_depth = np.zeros(depth.shape)
            sparse_depth[mask_keep] = depth[mask_keep]
            return sparse_depth

    def create_rgbd(self, rgb, depth):
        """Stack sparse depth as a 4th channel onto the rgb image."""
        sparse_depth = self.create_sparse_depth(rgb, depth)
        rgbd = np.append(rgb, np.expand_dims(sparse_depth, axis=2), axis=2)
        return rgbd

    def __getraw__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (rgb, depth) the raw data.
        """
        path, target = self.imgs[index]
        rgb, depth = self.loader(path)
        return rgb, depth

    def __getitem__(self, index):
        """Return (input_tensor, depth_tensor) for sample *index*,
        where the input channels depend on ``self.modality``."""
        rgb, depth = self.__getraw__(index)
        if self.transform is not None:
            rgb_np, depth_np = self.transform(rgb, depth)
        else:
            raise(RuntimeError("transform not defined"))

        # color normalization
        # rgb_tensor = normalize_rgb(rgb_tensor)
        # rgb_np = normalize_np(rgb_np)

        # modality was validated in __init__, so one branch always runs
        if self.modality == 'rgb':
            input_np = rgb_np
        elif self.modality == 'rgbd':
            input_np = self.create_rgbd(rgb_np, depth_np)
        elif self.modality == 'd':
            input_np = self.create_sparse_depth(rgb_np, depth_np)

        input_tensor = to_tensor(input_np)
        # 'd' modality yields a 2-D tensor; add a channel dim
        while input_tensor.dim() < 3:
            input_tensor = input_tensor.unsqueeze(0)
        depth_tensor = to_tensor(depth_np)
        depth_tensor = depth_tensor.unsqueeze(0)

        return input_tensor, depth_tensor

    def __len__(self):
        """Number of samples found under root."""
        return len(self.imgs)
50 changes: 50 additions & 0 deletions dataloaders/kitti_dataloader/kitti_dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#
# copyright: Fangchang Ma
# https://github.com/fangchangma/sparse-to-dense.pytorch/blob/master/dataloaders/kitti_dataloader.py
#
#
import numpy as np
import dataloaders.kitti_dataloader.transforms as transforms
from dataloaders.kitti_dataloader.dataloader import MyDataloader

class KITTIDataset(MyDataloader):
    """KITTI split of the HDF5 sparse-to-dense dataset."""

    def __init__(self, root, type, sparsifier=None, modality='rgb'):
        super(KITTIDataset, self).__init__(root, type, sparsifier, modality)
        self.output_size = (228, 912)  # (height, width) after center crop

    def train_transform(self, rgb, depth):
        """Augment one training pair: crop, rotate, scale, center-crop,
        flip, and color-jitter; rgb is normalized to [0, 1] floats."""
        s = np.random.uniform(1.0, 1.5) # random scaling
        # zooming in by s makes every depth value appear 1/s as far away
        depth_np = depth / s
        angle = np.random.uniform(-5.0, 5.0) # random rotation degrees
        do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip

        # perform 1st step of data augmentation
        transform = transforms.Compose([
            transforms.Crop(130, 10, 240, 1200),
            transforms.Rotate(angle),
            transforms.Resize(s),
            transforms.CenterCrop(self.output_size),
            transforms.HorizontalFlip(do_flip)
        ])
        rgb_np = transform(rgb)
        rgb_np = self.color_jitter(rgb_np) # random color jittering
        # np.asfarray was removed in NumPy 2.0; np.asarray with an explicit
        # float dtype is the equivalent, forward-compatible call.
        rgb_np = np.asarray(rgb_np, dtype='float') / 255
        # Scipy affine_transform produced RuntimeError when the depth map was
        # given as a 'numpy.ndarray'
        depth_np = np.asarray(depth_np, dtype='float32')
        depth_np = transform(depth_np)

        return rgb_np, depth_np

    def val_transform(self, rgb, depth):
        """Deterministic eval transform: fixed crop + center crop;
        rgb normalized to [0, 1] floats, depth cast to float32."""
        depth_np = depth
        transform = transforms.Compose([
            transforms.Crop(130, 10, 240, 1200),
            transforms.CenterCrop(self.output_size),
        ])
        rgb_np = transform(rgb)
        rgb_np = np.asarray(rgb_np, dtype='float') / 255
        depth_np = np.asarray(depth_np, dtype='float32')
        depth_np = transform(depth_np)

        return rgb_np, depth_np
Loading