Skip to content
4 changes: 2 additions & 2 deletions example/tensorflow/code_template/tensorflow_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
# Read the protobuf text and build a tf.GraphDef
with open(model_file_name, 'r') as model_file:
model_protobuf = text_format.Parse(model_file.read(),
tf.GraphDef())
tf.MetaGraphDef())

# Import the GraphDef built above into the default graph
tf.import_graph_def(model_protobuf)
tf.train.import_meta_graph(model_protobuf)

# You can now add operations on top of the imported graph
13 changes: 8 additions & 5 deletions ide/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,19 +311,22 @@ def isProcessPossible(layerId):
json_str = json_str.strip("'<>() ").replace('\'', '\"')
lrnLayer = imp.load_source('LRN', BASE_DIR + '/keras_app/custom_layers/lrn.py')

# clear clutter from previous graph built by keras to avoid duplicates
K.clear_session()

model = model_from_json(json_str, {'LRN': lrnLayer.LRN})

sess = K.get_session()
tf.train.write_graph(sess.graph.as_graph_def(add_shapes=True), output_fld,
output_file + '.pbtxt', as_text=True)
tf.train.export_meta_graph(
Comment thread
abhigyan7 marked this conversation as resolved.
os.path.join(output_fld, output_file + '.meta'),
as_text=True)

Channel(reply_channel).send({
'text': json.dumps({
'result': 'success',
'action': 'ExportNet',
'id': 'randomId',
'name': randomId + '.pbtxt',
'url': '/media/' + randomId + '.pbtxt',
'name': randomId + '.meta',
'url': '/media/' + randomId + '.meta',
'customLayers': custom_layers_response
})
})
Expand Down
6 changes: 3 additions & 3 deletions tensorflow_app/views/export_graphdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def export_to_tensorflow(request):
randomId = response['randomId']
customLayers = response['customLayers']
os.chdir(BASE_DIR + '/tensorflow_app/views/')
os.system('KERAS_BACKEND=tensorflow python json2pbtxt.py -input_file ' +
os.system('KERAS_BACKEND=tensorflow python json2meta.py -input_file ' +
randomId + '.json -output_file ' + randomId)
return JsonResponse({'result': 'success',
'id': randomId,
'name': randomId + '.pbtxt',
'url': '/media/' + randomId + '.pbtxt',
'name': randomId + '.meta',
'url': '/media/' + randomId + '.meta',
'customLayers': customLayers})
46 changes: 42 additions & 4 deletions tensorflow_app/views/import_graphdef.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from django.views.decorators.csrf import csrf_exempt
from django.http import JsonResponse
import math
Expand Down Expand Up @@ -125,6 +124,46 @@ def get_padding(node, layer, session, input_layer_name, input_layer_dim):
return int(pad_h), int(pad_w)


def get_graph_def_from(model_protobuf):
"""
Parses and returns a GraphDef from input protobuf.

Args:
model_protobuf: a binary or text protobuf message.

Returns:
a tf.GraphDef object with the GraphDef from model_protobuf

Raises:
ValueError: if a GraphDef cannot be parsed from model_protobuf
"""
try:
meta_graph_def = text_format.Merge(model_protobuf, tf.MetaGraphDef())
graph_def = meta_graph_def.graph_def
return graph_def
except (text_format.ParseError, UnicodeDecodeError):
# not a valid text metagraphdef
pass
try:
graph_def = text_format.Merge(model_protobuf, tf.GraphDef())
return graph_def
except (text_format.ParseError, UnicodeDecodeError):
pass
try:
graph_def = tf.GraphDef()
graph_def.ParseFromString(model_protobuf)
return graph_def
except Exception:
pass
try:
meta_graph_def = tf.MetaGraphDef()
meta_graph_def.ParseFromString(model_protobuf)
return meta_graph_def.graph_def
except Exception:
pass
raise ValueError('Invalid model protobuf')


@csrf_exempt
def import_graph_def(request):
if request.method == 'POST':
Expand All @@ -151,15 +190,14 @@ def import_graph_def(request):
return JsonResponse({'result': 'error', 'error': 'No GraphDef model found'})

tf.reset_default_graph()
graph_def = graph_pb2.GraphDef()
d = {}
order = []
input_layer_name = ''
input_layer_dim = []

try:
text_format.Merge(config, graph_def)
except Exception:
graph_def = get_graph_def_from(config)
except ValueError:
return JsonResponse({'result': 'error', 'error': 'Invalid GraphDef'})

tf.import_graph_def(graph_def, name='')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
parser.add_argument('-input_file', action="store",
dest='input_file', type=str, default='model.json')
parser.add_argument('-output_file', action="store",
dest='output_file', type=str, default='model.pbtxt')
dest='output_file', type=str, default='model.meta')
args = parser.parse_args()
input_file = args.input_file
output_file = args.output_file
Expand All @@ -30,6 +30,6 @@
lrn = imp.load_source('LRN', BASE_DIR + '/keras_app/custom_layers/lrn.py')
model = model_from_json(json_str, {'LRN': lrn.LRN})

sess = K.get_session()
tf.train.write_graph(sess.graph.as_graph_def(add_shapes=True), output_fld,
output_file + '.pbtxt', as_text=True)
tf.train.export_meta_graph(
os.path.join(output_fld, output_file + '.meta'),
as_text=True)
28 changes: 28 additions & 0 deletions tests/unit/tensorflow_app/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,31 @@ def test_custom_lrn_tf_import(self):
response = self.client.post(reverse('tf-import'), {'file': model_file})
response = json.loads(response.content)
self.assertEqual(response['result'], 'success')


class ExportMetaGraphTest(unittest.TestCase):
def setUp(self):
self.client = Client()

def test_tf_export(self):
model_file = open(os.path.join(settings.BASE_DIR, 'example/keras',
'AlexNet.json'), 'r')
response = self.client.post(reverse('keras-import'), {'file': model_file})
response = json.loads(response.content)
net = get_shapes(response['net'])
response = self.client.post(reverse('tf-export'), {'net': json.dumps(net),
'net_name': ''})
response = json.loads(response.content)
self.assertEqual(response['result'], 'success')


class ImportMetaGraphTest(unittest.TestCase):
def setUp(self):
self.client = Client()

def test_tf_import(self):
model_file = open(os.path.join(settings.BASE_DIR, 'tests/unit/tensorflow_app',
'vgg16_import_test.meta'), 'r')
response = self.client.post(reverse('tf-import'), {'file': model_file})
response = json.loads(response.content)
self.assertEqual(response['result'], 'success')
Loading