Skip to content

I want to set the GPU memory fraction, but it fails: the GPU is always fully occupied  #170

@zfs1993

Description

@zfs1993

I added the following code to set the GPU memory fraction.
The YOLO part (yolo.py):
class YOLO(object):
    """YOLO detector wrapper that creates its TensorFlow session with explicit GPU options.

    Only the constructor is shown in this excerpt; ``_get_class``,
    ``_get_anchors`` and ``generate`` are defined elsewhere in yolo.py.
    """

    def __init__(self):
        """Set paths and thresholds, create the GPU-limited session, build the outputs."""
        self.model_path = 'model_data/yolo.h5'
        # Alternative weights file: 'model_data/yolo_tiny.h5'
        self.anchors_path = 'model_data/yolo_anchors.txt'
        self.classes_path = 'model_data/coco_classes.txt'
        # NOTE(review): presumably the detection score and IoU thresholds
        # consumed by generate() -- confirm against the rest of yolo.py.
        self.score = 0.5
        self.iou = 0.5
        self.class_names = self._get_class()
        self.anchors = self._get_anchors()

        # Request at most 30% of the GPU memory for this process and let the
        # allocation grow on demand instead of being reserved up front.
        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = 0.3
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)
        # Register the session before generate() runs.
        # NOTE(review): set_session is presumably the Keras backend helper;
        # the intent appears to be that the model is built in this session --
        # confirm.
        set_session(self.sess)

        self.model_image_size = (416, 416)  # fixed size or (None, None)
        self.is_fixed_size = self.model_image_size != (None, None)
        self.boxes, self.scores, self.classes = self.generate()

The feature-extraction part (tools/generate_detections.py):
class ImageEncoder(object):
    """Evaluate a serialized TensorFlow graph to turn image batches into feature vectors.

    The graph is imported under the ``net`` name scope of the default graph
    and evaluated in a session created with ``allow_growth`` enabled.
    """

    def __init__(self, checkpoint_filename, input_name="images",
                 output_name="features"):
        """Load the graph and look up its input and output tensors.

        Args:
            checkpoint_filename: Path to the serialized ``GraphDef`` file.
            input_name: Name of the input tensor inside the graph.
            output_name: Name of the output tensor inside the graph.

        Raises:
            AssertionError: If the output tensor is not rank 2 or the input
                tensor is not rank 4.
        """
        # Only allow_growth is set here; no per-process memory fraction is
        # configured for this session.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        self.session = tf.Session(config=config)

        with tf.gfile.GFile(checkpoint_filename, "rb") as file_handle:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(file_handle.read())
        tf.import_graph_def(graph_def, name="net")

        self.input_var = tf.get_default_graph().get_tensor_by_name(
            "net/%s:0" % input_name)
        self.output_var = tf.get_default_graph().get_tensor_by_name(
            "net/%s:0" % output_name)

        assert len(self.output_var.get_shape()) == 2
        assert len(self.input_var.get_shape()) == 4
        # The last output dimension is the feature size; the input shape
        # without its leading dimension is the per-image shape.
        self.feature_dim = self.output_var.get_shape().as_list()[-1]
        self.image_shape = self.input_var.get_shape().as_list()[1:]

    def __call__(self, data_x, batch_size=32):
        """Compute features for ``data_x`` in batches.

        Args:
            data_x: Sequence of inputs fed to the graph's input tensor.
            batch_size: Batch size passed to ``_run_in_batches``.

        Returns:
            A float32 array of shape ``(len(data_x), self.feature_dim)``.
        """
        out = np.zeros((len(data_x), self.feature_dim), np.float32)
        _run_in_batches(
            lambda x: self.session.run(self.output_var, feed_dict=x),
            {self.input_var: data_x}, out, batch_size)
        return out

But it failed. Has anyone else run into the same problem?

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions