# partialconv — Export script for partialconv
# (Extracted from the "partialconv" GitHub wiki page, last edited by sngyo
# on Nov 4, 2019 · 2 revisions.)
import argparse
import os   
import torch
import torchvision.models as models_baseline # networks with zero padding
import models as models_partial # partial conv based padding 
def _public_model_names(module):
    """Return the sorted names in ``module`` that look like model constructors.

    A name qualifies when it is all-lowercase, does not start with ``__`` and
    refers to a callable object (presumably a model factory function such as
    ``resnet50`` — confirm for the local ``models`` package).
    """
    return sorted(
        name for name, obj in vars(module).items()
        if name.islower() and not name.startswith("__") and callable(obj)
    )


model_baseline_names = _public_model_names(models_baseline)
model_partial_names = _public_model_names(models_partial)
# NOTE: a name defined in both modules appears twice in this list; main()
# checks the torchvision baseline module first, so the baseline wins.
model_names = model_baseline_names + model_partial_names

# Command-line interface: only the architecture to export is configurable.
parser = argparse.ArgumentParser(
    description='Export a pre-trained PyTorch model to ONNX')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50',
                    choices=model_names,
                    help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet50)')
def main():
    """Load the pre-trained model selected by ``--arch`` and export it to ONNX.

    Parses the command line with the module-level ``parser`` and stores the
    result in the module-level global ``args``. The architecture is looked up
    first in ``torchvision.models`` (zero-padding baselines) and otherwise in
    the local ``models`` package (partial-convolution padding). The model is
    switched to eval mode and run once on a random 1x3x224x224 tensor as a
    sanity check. It is then written to ``<arch>.onnx``, relative to the
    current working directory, using ONNX opset 10.
    """
    global args
    args = parser.parse_args()

    print("=> using pre-trained model '{}'".format(args.arch))
    # Baseline (torchvision) constructors take precedence when a name exists
    # in both modules.
    if args.arch in models_baseline.__dict__:
        model = models_baseline.__dict__[args.arch](pretrained=True)
    else:
        model = models_partial.__dict__[args.arch](pretrained=True)
    print(model)
    # Put the model in inference mode before tracing for export.
    model.eval()

    # Plain tensors carry autograd state; wrapping them in
    # torch.autograd.Variable is deprecated and has no effect.
    # Input presumably NCHW ImageNet-sized — TODO confirm for every arch.
    dummy = torch.randn(1, 3, 224, 224)
    # Sanity-check forward pass; gradients are not needed, and the output is
    # discarded (torch.onnx.export runs its own trace below).
    with torch.no_grad():
        model(dummy)
    torch.onnx.export(model, dummy, args.arch + '.onnx', verbose=True,
                      opset_version=10)
    print('Export is done')
if __name__ == '__main__':
    main()

# (c) 2019 ax Inc. & AXELL CORPORATION