Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 65 additions & 25 deletions mrcnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
from collections import OrderedDict
import multiprocessing
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.compat.v1.disable_eager_execution()
import keras
import keras.backend as K
import keras.layers as KL
import keras.engine as KE
import keras.models as KM

from mrcnn import utils
Expand Down Expand Up @@ -49,7 +49,18 @@ def log(text, array=None):
text += " {}".format(array.dtype)
print(text)

class AnchorsLayer(KL.Layer):
    """Keras layer that stores an anchors array and returns it when called.

    The anchors passed to the constructor are wrapped in a tf.Variable once,
    at construction time. Calling the layer ignores its input and yields
    that variable, so the input only serves to connect the layer to a graph.
    """

    def __init__(self, anchors, name="anchors", **kwargs):
        """Create the layer.

        anchors: value handed to tf.Variable and kept on the layer.
        name: layer name, "anchors" unless overridden.
        **kwargs: forwarded unchanged to the base layer constructor.
        """
        super().__init__(name=name, **kwargs)
        self.anchors = tf.Variable(anchors)

    def call(self, dummy):
        """Return the stored anchors variable; `dummy` is never read."""
        return self.anchors

    def get_config(self):
        """Return the base layer config as-is (anchors are not added to it)."""
        return super().get_config()

class BatchNorm(KL.BatchNormalization):
"""Extends the Keras BatchNormalization class to allow a central place
to make changes if needed.
Expand Down Expand Up @@ -252,7 +263,7 @@ def clip_boxes_graph(boxes, window):
return clipped


class ProposalLayer(KE.Layer):
class ProposalLayer(KL.Layer):
"""Receives anchor scores and selects a subset to pass as proposals
to the second stage. Filtering is done based on anchor scores and
non-max suppression to remove overlaps. It also applies bounding
Expand Down Expand Up @@ -338,10 +349,10 @@ def compute_output_shape(self, input_shape):

def log2_graph(x):
"""Implementation of Log2. TF doesn't have a native implementation."""
return tf.log(x) / tf.log(2.0)
return tf.math.log(x) / tf.math.log(2.0)


class PyramidROIAlign(KE.Layer):
class PyramidROIAlign(KL.Layer):
"""Implements ROI Pooling on multiple levels of the feature pyramid.

Params:
Expand Down Expand Up @@ -619,7 +630,7 @@ def detection_targets_graph(proposals, gt_class_ids, gt_boxes, gt_masks, config)
return rois, roi_gt_class_ids, deltas, masks


class DetectionTargetLayer(KE.Layer):
class DetectionTargetLayer(KL.Layer):
"""Subsamples proposals and generates target box refinement, class_ids,
and masks for each.

Expand Down Expand Up @@ -699,7 +710,7 @@ def refine_detections_graph(rois, probs, deltas, window, config):
# Class IDs per ROI
class_ids = tf.argmax(probs, axis=1, output_type=tf.int32)
# Class probability of the top class of each ROI
indices = tf.stack([tf.range(probs.shape[0]), class_ids], axis=1)
indices = tf.stack([tf.range(tf.shape(probs)[0]), class_ids], axis = 1)
class_scores = tf.gather_nd(probs, indices)
# Class-specific bounding box deltas
deltas_specific = tf.gather_nd(deltas, indices)
Expand All @@ -717,7 +728,7 @@ def refine_detections_graph(rois, probs, deltas, window, config):
# Filter out low confidence boxes
if config.DETECTION_MIN_CONFIDENCE:
conf_keep = tf.where(class_scores >= config.DETECTION_MIN_CONFIDENCE)[:, 0]
keep = tf.sets.set_intersection(tf.expand_dims(keep, 0),
keep = tf.compat.v1.sets.set_intersection(tf.expand_dims(keep, 0),
tf.expand_dims(conf_keep, 0))
keep = tf.sparse_tensor_to_dense(keep)[0]

Expand Down Expand Up @@ -755,7 +766,7 @@ def nms_keep_map(class_id):
nms_keep = tf.reshape(nms_keep, [-1])
nms_keep = tf.gather(nms_keep, tf.where(nms_keep > -1)[:, 0])
# 4. Compute intersection between keep and nms_keep
keep = tf.sets.set_intersection(tf.expand_dims(keep, 0),
keep = tf.compat.v1.sets.set_intersection(tf.expand_dims(keep, 0),
tf.expand_dims(nms_keep, 0))
keep = tf.sparse_tensor_to_dense(keep)[0]
# Keep top detections
Expand All @@ -779,7 +790,7 @@ def nms_keep_map(class_id):
return detections


class DetectionLayer(KE.Layer):
class DetectionLayer(KL.Layer):
"""Takes classified proposal boxes and their bounding box deltas and
returns the final detection boxes.

Expand Down Expand Up @@ -948,7 +959,10 @@ def fpn_classifier_graph(rois, feature_maps, image_meta,
name='mrcnn_bbox_fc')(shared)
# Reshape to [batch, num_rois, NUM_CLASSES, (dy, dx, log(dh), log(dw))]
s = K.int_shape(x)
mrcnn_bbox = KL.Reshape((s[1], num_classes, 4), name="mrcnn_bbox")(x)
if s[1]==None:
mrcnn_bbox = KL.Reshape((-1, num_classes, 4), name="mrcnn_bbox")(x)
else:
mrcnn_bbox = KL.Reshape((s[1], num_classes, 4), name="mrcnn_bbox")(x)

return mrcnn_class_logits, mrcnn_probs, mrcnn_bbox

Expand Down Expand Up @@ -1872,8 +1886,9 @@ def build(self, mode, config):
input_gt_boxes = KL.Input(
shape=[None, 4], name="input_gt_boxes", dtype=tf.float32)
# Normalize coordinates
gt_boxes = KL.Lambda(lambda x: norm_boxes_graph(
x, K.shape(input_image)[1:3]))(input_gt_boxes)
# gt_boxes = KL.Lambda(lambda x: norm_boxes_graph(
# x, K.shape(input_image)[1:3]))(input_gt_boxes)
gt_boxes = KL.Lambda(lambda x: norm_boxes_graph2(x))([input_gt_boxes,input_image])
# 3. GT Masks (zero padded)
# [batch, height, width, MAX_GT_INSTANCES]
if config.USE_MINI_MASK:
Expand Down Expand Up @@ -1931,7 +1946,8 @@ def build(self, mode, config):
# TODO: can this be optimized to avoid duplicating the anchors?
anchors = np.broadcast_to(anchors, (config.BATCH_SIZE,) + anchors.shape)
# A hack to get around Keras's bad support for constants
anchors = KL.Lambda(lambda x: tf.Variable(anchors), name="anchors")(input_image)
# anchors = KL.Lambda(lambda x: tf.Variable(anchors), name="anchors")(input_image)
anchors = AnchorsLayer(anchors, name="anchors")(input_image)
else:
anchors = input_anchors

Expand Down Expand Up @@ -2102,7 +2118,7 @@ def load_weights(self, filepath, by_name=False, exclude=None):
# Conditional import to support versions of Keras before 2.2
# TODO: remove in about 6 months (end of 2018)
try:
from keras.engine import saving
from tensorflow.python.keras import saving
except ImportError:
# Keras before 2.2 used the 'topology' namespace.
from keras.engine import topology as saving
Expand All @@ -2127,9 +2143,9 @@ def load_weights(self, filepath, by_name=False, exclude=None):
layers = filter(lambda l: l.name not in exclude, layers)

if by_name:
saving.load_weights_from_hdf5_group_by_name(f, layers)
saving.hdf5_format.load_weights_from_hdf5_group_by_name(f, layers)
else:
saving.load_weights_from_hdf5_group(f, layers)
saving.hdf5_format.load_weights_from_hdf5_group(f, layers)
if hasattr(f, 'close'):
f.close()

Expand All @@ -2155,24 +2171,43 @@ def compile(self, learning_rate, momentum):
metrics. Then calls the Keras compile() function.
"""
# Optimizer object
optimizer = keras.optimizers.SGD(
optimizer = tf.keras.optimizers.SGD(
lr=learning_rate, momentum=momentum,
clipnorm=self.config.GRADIENT_CLIP_NORM)
# Add Losses
# # First, clear previously set losses to avoid duplication
# self.keras_model._losses = []
# self.keras_model._per_input_losses = {}
# First, clear previously set losses to avoid duplication
self.keras_model._losses = []
self.keras_model._per_input_losses = {}
try:
self.keras_model._losses.clear()
except AttributeError:
pass
try:
self.keras_model._per_input_losses.clear()
except AttributeError:
pass
loss_names = [
"rpn_class_loss", "rpn_bbox_loss",
"mrcnn_class_loss", "mrcnn_bbox_loss", "mrcnn_mask_loss"]
# for name in loss_names:
# layer = self.keras_model.get_layer(name)
# if layer.output in self.keras_model.losses:
# continue
# loss = (
# tf.reduce_mean(layer.output, keepdims=True)
# * self.config.LOSS_WEIGHTS.get(name, 1.))
# self.keras_model.add_loss(loss)
existing_layer_names = []
for name in loss_names:
layer = self.keras_model.get_layer(name)
if layer.output in self.keras_model.losses:
if layer is None or name in existing_layer_names:
continue
loss = (
tf.reduce_mean(layer.output, keepdims=True)
* self.config.LOSS_WEIGHTS.get(name, 1.))
existing_layer_names.append(name)
loss = (tf.reduce_mean(layer.output, keepdims=True)
* self.config.LOSS_WEIGHTS.get(name, 1.))
self.keras_model.add_loss(loss)
# loss = tf.reduce_mean(input_tensor=layer.output, keepdims=True)

# Add L2 Regularization
# Skip gamma and beta weights of batch normalization layers.
Expand All @@ -2196,7 +2231,8 @@ def compile(self, learning_rate, momentum):
loss = (
tf.reduce_mean(layer.output, keepdims=True)
* self.config.LOSS_WEIGHTS.get(name, 1.))
self.keras_model.metrics_tensors.append(loss)
# self.keras_model.metrics_tensors.append(loss)
self.keras_model.add_metric(loss, name=name, aggregation='mean')

def set_trainable(self, layer_regex, keras_model=None, indent=0, verbose=1):
"""Sets model layers as trainable if their names match
Expand Down Expand Up @@ -2850,6 +2886,10 @@ def norm_boxes_graph(boxes, shape):
shift = tf.constant([0., 0., 1., 1.])
return tf.divide(boxes - shift, scale)

def norm_boxes_graph2(x):
    """Single-argument wrapper around norm_boxes_graph().

    x: a pair (boxes, tensor_for_shape). Dimensions 1 and 2 of the second
        tensor's dynamic shape are used as the `shape` argument
        (presumably height and width of an image batch - confirm with
        the caller).

    Returns whatever norm_boxes_graph(boxes, shape) returns.
    """
    boxes, reference = x
    # Read the shape at graph-execution time via tf.shape().
    spatial_dims = tf.shape(reference)[1:3]
    return norm_boxes_graph(boxes, spatial_dims)

def denorm_boxes_graph(boxes, shape):
"""Converts boxes from normalized coordinates to pixel coordinates.
Expand Down
4 changes: 2 additions & 2 deletions mrcnn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,8 @@ def box_refinement_graph(box, gt_box):

dy = (gt_center_y - center_y) / height
dx = (gt_center_x - center_x) / width
dh = tf.log(gt_height / height)
dw = tf.log(gt_width / width)
dh = tf.math.log(gt_height / height)
dw = tf.math.log(gt_width / width)

result = tf.stack([dy, dx, dh, dw], axis=1)
return result
Expand Down
12 changes: 12 additions & 0 deletions report.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
TensorFlow 2.0 Upgrade Script
-----------------------------
Converted 0 files
Detected 0 issues that require attention
--------------------------------------------------------------------------------
================================================================================
Detailed log follows:

================================================================================
================================================================================
Input tree: 'Mask_RCNN'
================================================================================
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ scipy
Pillow
cython
matplotlib
scikit-image
scikit-image==0.16.2
tensorflow>=1.3.0
keras>=2.0.8
opencv-python
Expand Down
Loading