diff --git a/README.md b/README.md
index e95c24d0..0244e504 100644
--- a/README.md
+++ b/README.md
@@ -164,6 +164,8 @@ The code is developed under the following configurations.
 - Hardware: >=4 GPUs for training, >=1 GPU for testing (set ```[--gpus GPUS]``` accordingly)
 - Software: Ubuntu 16.04.3 LTS, ***CUDA>=8.0, Python>=3.5, PyTorch>=0.4.0***
 - Dependencies: numpy, scipy, opencv, yacs, tqdm
+- Pretrained weights: ```gdown https://drive.google.com/file/d/1Il1Pcb13syeHi9LA9KjXz8KqMFN9izgo -O ckpt.zip```
+- Unpack with ```unzip -j ckpt.zip```
 
 ## Quick start: Test on an image using our trained model
 1. Here is a simple demo to do inference on a single image:
diff --git a/demo_test.sh b/demo_test.sh
index 2af79c0c..1c4dc8df 100755
--- a/demo_test.sh
+++ b/demo_test.sh
@@ -35,3 +35,5 @@ python3 -u test.py \
     TEST.checkpoint epoch_20.pth
 fi
 
+# MODEL_NAME=ade20k-hrnetv2-c1
+# python3 -u test_cpu.py --imgs ADE_val_00001519.jpg --cfg config/ade20k-hrnetv2.yaml
diff --git a/mit_semseg/lib/utils/th.py b/mit_semseg/lib/utils/th.py
index ca6ef938..9e952a99 100644
--- a/mit_semseg/lib/utils/th.py
+++ b/mit_semseg/lib/utils/th.py
@@ -8,17 +8,17 @@ def as_variable(obj):
     if isinstance(obj, Variable):
         return obj
-    if isinstance(obj, collections.Sequence):
+    if isinstance(obj, collections.abc.Sequence):
         return [as_variable(v) for v in obj]
-    elif isinstance(obj, collections.Mapping):
+    elif isinstance(obj, collections.abc.Mapping):
         return {k: as_variable(v) for k, v in obj.items()}
     else:
         return Variable(obj)
 
 
 def as_numpy(obj):
-    if isinstance(obj, collections.Sequence):
+    if isinstance(obj, collections.abc.Sequence):
         return [as_numpy(v) for v in obj]
-    elif isinstance(obj, collections.Mapping):
+    elif isinstance(obj, collections.abc.Mapping):
         return {k: as_numpy(v) for k, v in obj.items()}
     elif isinstance(obj, Variable):
         return obj.data.cpu().numpy()
@@ -33,9 +33,9 @@ def mark_volatile(obj):
     if isinstance(obj, Variable):
         obj.no_grad = True
         return obj
-    elif isinstance(obj, collections.Mapping):
+    elif isinstance(obj, collections.abc.Mapping):
         return {k: mark_volatile(o) for k, o in obj.items()}
-    elif isinstance(obj, collections.Sequence):
+    elif isinstance(obj, collections.abc.Sequence):
         return [mark_volatile(o) for o in obj]
     else:
         return obj
diff --git a/requirements.txt b/requirements.txt
index 154a2c40..8e448121 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,16 @@
+--extra-index-url https://download.pytorch.org/whl/cpu
 numpy
 scipy
-pytorch==0.4.1
-torchvision
-opencv3
+torch==1.12.1+cpu
+torchvision==0.13.1+cpu
+opencv-python
 yacs
 tqdm
+aiohttp
+aiofiles
+aiohttp_cors
+pillow
+gdown
+uvicorn[standard]
+fastapi
+python-multipart
diff --git a/test_cpu.py b/test_cpu.py
new file mode 100644
index 00000000..f5abb31a
--- /dev/null
+++ b/test_cpu.py
@@ -0,0 +1,191 @@
# System libs
import os
import argparse
from distutils.version import LooseVersion
# Numerical libs
import numpy as np
import torch
import torch.nn as nn
from scipy.io import loadmat
import csv
# Our libs
from mit_semseg.dataset import TestDataset
from mit_semseg.models import ModelBuilder, SegmentationModule
from mit_semseg.utils import colorEncode, find_recursive, setup_logger
from mit_semseg.lib.nn import user_scattered_collate, async_copy_to
from mit_semseg.lib.utils import as_numpy
from PIL import Image
from tqdm import tqdm
from mit_semseg.config import cfg

colors = loadmat('data/color150.mat')['colors']
names = {}
with open('data/object150_info.csv') as f:
    reader = csv.reader(f)
    next(reader)
    for row in reader:
        names[int(row[0])] = row[5].split(";")[0]
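# Shape sketch of the loaded metadata (exact strings depend on
# data/object150_info.csv): `colors` is a (150, 3) array of RGB triples
# indexed by class id - 1, and `names` maps 1-based class ids to labels,
# e.g. names[1] == 'wall' in the standard ADE20K file.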


def visualize_result(data, pred, cfg):
    (img, info) = data

    # print predictions in descending order
    pred = np.int32(pred)
    pixs = pred.size
    uniques, counts = np.unique(pred, return_counts=True)
    print("Predictions in [{}]:".format(info))
    for idx in np.argsort(counts)[::-1]:
        name = names[uniques[idx] + 1]
        ratio = counts[idx] / pixs * 100
        if ratio > 0.1:
            print("  {}: {:.2f}%".format(name, ratio))

    # colorize prediction
    pred_color = colorEncode(pred, colors).astype(np.uint8)

    # aggregate images and save
    im_vis = np.concatenate((img, pred_color), axis=1)

    img_name = info.split('/')[-1]
    Image.fromarray(im_vis).save(
        os.path.join(cfg.TEST.result, img_name.replace('.jpg', '.png')))


def test(segmentation_module, loader):
    segmentation_module.eval()

    pbar = tqdm(total=len(loader))
    for batch_data in loader:
        # process data
        batch_data = batch_data[0]
        segSize = (batch_data['img_ori'].shape[0],
                   batch_data['img_ori'].shape[1])
        img_resized_list = batch_data['img_data']

        with torch.no_grad():
            scores = torch.zeros(1, cfg.DATASET.num_class, segSize[0], segSize[1])

            for img in img_resized_list:
                feed_dict = batch_data.copy()
                feed_dict['img_data'] = img
                del feed_dict['img_ori']
                del feed_dict['info']

                # forward pass, averaging predictions over the multi-scale inputs
                pred_tmp = segmentation_module(feed_dict, segSize=segSize)
                scores = scores + pred_tmp / len(cfg.DATASET.imgSizes)

            _, pred = torch.max(scores, dim=1)
            pred = as_numpy(pred.squeeze(0))

        # visualization
        visualize_result(
            (batch_data['img_ori'], batch_data['info']),
            pred,
            cfg
        )

        pbar.update(1)


def main(cfg):
    # Network Builders
    net_encoder = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_encoder)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder,
        fc_dim=cfg.MODEL.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=cfg.MODEL.weights_decoder,
        use_softmax=True)

    crit = nn.NLLLoss(ignore_index=-1)

    # .cpu() keeps all weights and activations on the CPU; no CUDA context is created
    segmentation_module = SegmentationModule(net_encoder, net_decoder, crit).cpu()

    # Dataset and Loader
    dataset_test = TestDataset(
        cfg.list_test,
        cfg.DATASET)
    loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=cfg.TEST.batch_size,
        shuffle=False,
        collate_fn=user_scattered_collate,
        num_workers=5,
        drop_last=True)

    # Main loop
    test(segmentation_module, loader_test)

    print('Inference done!')


if __name__ == '__main__':
    assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), \
        'PyTorch>=0.4.0 is required'

    parser = argparse.ArgumentParser(
        description="PyTorch Semantic Segmentation Testing"
    )
    parser.add_argument(
        "--imgs",
        required=True,
        type=str,
        help="an image path, or a directory name"
    )
    parser.add_argument(
        "--cfg",
        default="config/ade20k-resnet50dilated-ppm_deepsup.yaml",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    cfg.merge_from_file(args.cfg)
    cfg.merge_from_list(args.opts)
    # cfg.freeze()

    logger = setup_logger(distributed_rank=0)  # TODO
    logger.info("Loaded configuration file {}".format(args.cfg))
    logger.info("Running with config:\n{}".format(cfg))

    cfg.MODEL.arch_encoder = cfg.MODEL.arch_encoder.lower()
    cfg.MODEL.arch_decoder = cfg.MODEL.arch_decoder.lower()
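    # Assumed checkpoint layout: with TEST.checkpoint set to 'epoch_20.pth'
    # (as in demo_test.sh), the joins below resolve to
    # <cfg.DIR>/encoder_epoch_20.pth and <cfg.DIR>/decoder_epoch_20.pth,
    # which is what the unzipped ckpt.zip from the README should contain.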
    # absolute paths of model weights
    cfg.MODEL.weights_encoder = os.path.join(
        cfg.DIR, 'encoder_' + cfg.TEST.checkpoint)
    cfg.MODEL.weights_decoder = os.path.join(
        cfg.DIR, 'decoder_' + cfg.TEST.checkpoint)

    print(cfg.MODEL.weights_encoder)
    print(cfg.MODEL.weights_decoder)

    assert os.path.exists(cfg.MODEL.weights_encoder) and \
        os.path.exists(cfg.MODEL.weights_decoder), "checkpoint does not exist!"

    # generate testing image list
    if os.path.isdir(args.imgs):
        imgs = find_recursive(args.imgs)
    else:
        imgs = [args.imgs]
    assert len(imgs), "imgs should be a path to image (.jpg) or directory."
    cfg.list_test = [{'fpath_img': x} for x in imgs]

    if not os.path.isdir(cfg.TEST.result):
        os.makedirs(cfg.TEST.result)

    main(cfg)
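# Untested end-to-end sketch, assuming the checkpoint from the README lands
# where config/ade20k-hrnetv2.yaml expects it:
#
#   pip install -r requirements.txt
#   python3 -u test_cpu.py --imgs ADE_val_00001519.jpg --cfg config/ade20k-hrnetv2.yaml
#
# visualize_result prints the per-class pixel ratios and writes the blended
# input/prediction image to cfg.TEST.result as a .png.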