diff --git a/demo.py b/demo.py index a8d1232..0cc1d2f 100644 --- a/demo.py +++ b/demo.py @@ -1,11 +1,12 @@ -import argparse +from fire import Fire +import numpy as np +from PIL import Image +from scipy.misc import imresize import torch import torch.nn.parallel from models import modules, net, resnet, densenet, senet -import numpy as np import loaddata_demo as loaddata -import pdb import matplotlib.image import matplotlib.pyplot as plt @@ -29,23 +30,34 @@ def define_model(is_resnet, is_densenet, is_senet): return model -def main(): +def main(image_path): model = define_model(is_resnet=False, is_densenet=False, is_senet=True) - model = torch.nn.DataParallel(model).cuda() - model.load_state_dict(torch.load('./pretrained_model/model_senet')) + model = torch.nn.DataParallel(model) + if torch.cuda.is_available(): + model = model.cuda() + model.load_state_dict(torch.load( + f='./pretrained_model/model_senet', + map_location=None if torch.cuda.is_available() else 'cpu')) model.eval() - nyu2_loader = loaddata.readNyu2('data/demo/img_nyu2.png') + nyu2_loader = loaddata.readNyu2(image_path) test(nyu2_loader, model) def test(nyu2_loader, model): for i, image in enumerate(nyu2_loader): - image = torch.autograd.Variable(image, volatile=True).cuda() + image = torch.autograd.Variable(image, volatile=True) + if torch.cuda.is_available(): + image = image.cuda() out = model(image) - - matplotlib.image.imsave('data/demo/out.png', out.view(out.size(2),out.size(3)).data.cpu().numpy()) + + out = out.view(out.size(2), out.size(3)).data.cpu().numpy() + input_shape = image.data.cpu().numpy().shape[2:4] + out = imresize(arr=out, size=input_shape) + Image.fromarray(out.astype(np.uint8)).save('data/demo/out.png') + + # matplotlib.image.imsave('data/demo/out.png', out) if __name__ == '__main__': - main() + Fire(main) diff --git a/loaddata.py b/loaddata.py index fe1e701..cea3b52 100644 --- a/loaddata.py +++ b/loaddata.py @@ -1,9 +1,7 @@ import pandas as pd -import numpy as np from torch.utils.data import Dataset, DataLoader from torchvision import transforms, utils from PIL import Image -import random from nyu_transform import * diff --git a/loaddata_demo.py b/loaddata_demo.py index 915802e..50bcbea 100644 --- a/loaddata_demo.py +++ b/loaddata_demo.py @@ -1,9 +1,6 @@ -import pandas as pd -import numpy as np from torch.utils.data import Dataset, DataLoader from torchvision import transforms, utils from PIL import Image -import random from demo_transform import * diff --git a/models/modules.py b/models/modules.py index af5e05b..5340a98 100644 --- a/models/modules.py +++ b/models/modules.py @@ -1,14 +1,7 @@ -from collections import OrderedDict -import math import torch import torch.nn.functional as F import torch.nn as nn -from torch.utils import model_zoo -import copy -import numpy as np -import senet -import resnet -import densenet + class _UpProjection(nn.Sequential): diff --git a/models/net.py b/models/net.py index 81e99de..1b886bb 100644 --- a/models/net.py +++ b/models/net.py @@ -1,17 +1,7 @@ -from collections import OrderedDict -import math import torch -import torch.nn.functional as F import torch.nn as nn -from torch.utils import model_zoo -import copy -import numpy as np -import modules -from torchvision import utils +from models import modules -import senet -import resnet -import densenet class model(nn.Module): def __init__(self, Encoder, num_features, block_channel): diff --git a/nyu_transform.py b/nyu_transform.py index 114744a..50ab760 100644 --- a/nyu_transform.py +++ b/nyu_transform.py @@ -10,8 +10,6 @@ import random import scipy.ndimage as ndimage -import pdb - def _is_pil_image(img): if accimage is not None: diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..8151247 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +numpy +matplotlib +torch +torchvision \ No newline at end of file diff --git a/test.py b/test.py index 3c10d1b..d400d46 100644 --- a/test.py +++ b/test.py @@ -1,6 +1,4 @@ -import argparse import torch -import torch.nn as nn import torch.nn.parallel from models import modules, net, resnet, densenet, senet diff --git a/train.py b/train.py index c22c3ac..7bfb832 100644 --- a/train.py +++ b/train.py @@ -7,8 +7,6 @@ import torch.backends.cudnn as cudnn import torch.optim import loaddata -import util -import numpy as np import sobel from models import modules, net, resnet, densenet, senet